From cb20a446ae5f21c2cc32442ccacde0be0b3060b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0tampar?= Date: Sun, 28 Jun 2026 14:28:42 +0200 Subject: [PATCH] Update of unit tests --- .github/workflows/tests.yml | 8 + .gitignore | 1 + data/txt/sha256sums.txt | 55 +- lib/core/settings.py | 2 +- lib/parse/cmdline.py | 2 +- tests/test_agent_dialects.py | 274 ++++++++++ tests/test_api.py | 619 ++++++++++++++++++++++ tests/test_checks.py | 504 ++++++++++++++++++ tests/test_common_parsers.py | 466 +++++++++++++++++ tests/test_common_utils.py | 340 +++++++++++++ tests/test_compat.py | 290 +++++++++++ tests/test_core_extra.py | 676 +++++++++++++++++++++++++ tests/test_core_final.py | 605 ++++++++++++++++++++++ tests/test_core_more.py | 706 ++++++++++++++++++++++++++ tests/test_databases_enum.py | 511 +++++++++++++++++++ tests/test_dbms_enum.py | 98 ++++ tests/test_dbms_enum_a.py | 215 ++++++++ tests/test_dbms_enum_b.py | 469 +++++++++++++++++ tests/test_deps.py | 113 +++++ tests/test_dialectdbms.py | 11 +- tests/test_dns_engine.py | 119 ++++- tests/test_dns_server.py | 163 +++--- tests/test_dump_format.py | 410 +++++++++++++++ tests/test_filesystem.py | 736 +++++++++++++++++++++++++++ tests/test_fingerprint.py | 203 ++++++++ tests/test_generic_enum_more.py | 865 +++++++++++++++++++++++++++++++ tests/test_generic_more.py | 873 ++++++++++++++++++++++++++++++++ tests/test_graphql.py | 59 ++- tests/test_gui_helpers.py | 118 +++++ tests/test_har.py | 171 +++++++ tests/test_hash_crack.py | 218 ++++++++ tests/test_hashdb.py | 13 +- tests/test_inference.py | 293 +++++++++++ tests/test_ldap.py | 119 +++-- tests/test_option_more.py | 663 ++++++++++++++++++++++++ tests/test_option_setup.py | 739 +++++++++++++++++++++++++++ tests/test_parse_modules.py | 175 +++++++ tests/test_payload_marking.py | 163 ++++-- tests/test_progress.py | 78 +++ tests/test_purge.py | 124 +++++ tests/test_search_enum.py | 475 +++++++++++++++++ tests/test_sgmllib.py | 267 ++++++++++ tests/test_tamper.py | 40 +- tests/test_target_parsing.py | 521 +++++++++++++++++++ tests/test_techniques.py | 769 ++++++++++++++++++++++++++++ tests/test_techniques_more.py | 540 ++++++++++++++++++++ tests/test_threads.py | 171 +++++++ tests/test_users_enum.py | 256 ++++++++++ 48 files changed, 15116 insertions(+), 190 deletions(-) create mode 100644 tests/test_agent_dialects.py create mode 100644 tests/test_api.py create mode 100644 tests/test_checks.py create mode 100644 tests/test_common_parsers.py create mode 100644 tests/test_common_utils.py create mode 100644 tests/test_compat.py create mode 100644 tests/test_core_extra.py create mode 100644 tests/test_core_final.py create mode 100644 tests/test_core_more.py create mode 100644 tests/test_databases_enum.py create mode 100644 tests/test_dbms_enum.py create mode 100644 tests/test_dbms_enum_a.py create mode 100644 tests/test_dbms_enum_b.py create mode 100644 tests/test_deps.py create mode 100644 tests/test_dump_format.py create mode 100644 tests/test_filesystem.py create mode 100644 tests/test_fingerprint.py create mode 100644 tests/test_generic_enum_more.py create mode 100644 tests/test_generic_more.py create mode 100644 tests/test_gui_helpers.py create mode 100644 tests/test_har.py create mode 100644 tests/test_hash_crack.py create mode 100644 tests/test_inference.py create mode 100644 tests/test_option_more.py create mode 100644 tests/test_option_setup.py create mode 100644 tests/test_parse_modules.py create mode 100644 tests/test_progress.py create mode 100644 tests/test_purge.py create mode 100644 tests/test_search_enum.py create mode 100644 tests/test_sgmllib.py create mode 100644 tests/test_target_parsing.py create mode 100644 tests/test_techniques.py create mode 100644 tests/test_techniques_more.py create mode 100644 tests/test_threads.py create mode 100644 tests/test_users_enum.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e5629645b..7f3268e69 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -99,6 +99,14 @@ jobs: # 'binary' instead of 'text'. Keeping this step byte-compile-free leaves --smoke clean. run: python -B -m unittest discover -s tests -p "test_*.py" + - name: Coverage + if: matrix.python-version != 'pypy-2.7' + run: | + python -m pip install coverage + python -m coverage run --source=lib,plugins,tamper -m unittest discover -s tests -p "test_*.py" + python -m coverage run -a --source=lib,plugins,tamper sqlmap.py --doc-test + python -m coverage report --fail-under=50 + - name: Smoke test run: python sqlmap.py --smoke-test diff --git a/.gitignore b/.gitignore index 78c5d1d9b..07ca46e6e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ lib/.DS_Store plugins/.DS_Store thirdparty/.DS_Store CLAUDE.md +.coverage diff --git a/data/txt/sha256sums.txt b/data/txt/sha256sums.txt index 8767f793a..d8d8b6631 100644 --- a/data/txt/sha256sums.txt +++ b/data/txt/sha256sums.txt @@ -189,7 +189,7 @@ e033b20a0f7821797a10f4bf4235723f38c7db551c611fbb713faa621b123c4a lib/core/optio 9bf174058f15d14e24e94f9aaf42df045119d3617c6c54bd2f3af79b462f331d lib/core/replication.py 0b8c38a01bb01f843d94a6c5f2075ee47520d0c4aa799cecea9c3e2c5a4a23a6 lib/core/revision.py 888daba83fd4a34e9503fe21f01fef4cc730e5cde871b1d40e15d4cbc847d56c lib/core/session.py -1a569b5bcd33ae45d95c140fd3bae2f12ad54640d938172de3cb99f73a549b47 lib/core/settings.py +5cbf5f4bc21f21873df79babd91da8f7fea5ec3c1999f108f005ca6fb4d453b6 lib/core/settings.py c7804223319e18eb0b8e2cbf0a8b6896d1cefb7b0b1a2e9f1cf826a8a3b56750 lib/core/shell.py a2e98a94b231432736d6b304fc75525c8b5fdb4768c418387c5b4c1a610dad64 lib/core/subprocessng.py 19f1e3c5e3ba703d28d510cd7a9ab8284d5fbe9df5ce7e77c86e5931571364b7 lib/core/target.py @@ -200,7 +200,7 @@ b9aacb840310173202f79c2ba125b0243003ee6b44c92eca50424f2bdfc83c02 lib/core/unesc 2400e465fa4d13e4c32795910878c71ff212e4361b46428d57ce43983f5e997c lib/core/wordlist.py 1966ca704961fb987ab757f0a4afddbf841d1a880631b701487c75cef63d60c3 lib/__init__.py 54bfd31ebded3ffa5848df1c644f196eb704116517c7a3d860b5d081e984d821 lib/parse/banner.py -a6440d24f8d6b772221fc78a655d3df07a000ba23e7924bd51cf5068097ee1fb lib/parse/cmdline.py +8351588876a7579fa96b3ab860ef2254487de34ea624c0a7696f2428c24ceb98 lib/parse/cmdline.py 02d82e4069bd98c52755417f8b8e306d79945672656ac24f1a45e7a6eff4b158 lib/parse/configfile.py c5b258be7485089fac9d9cd179960e774fbd85e62836dc67cce76cc028bb6aeb lib/parse/handler.py 5c9a9caee948843d5537745640cc7b98d70a0412cc0949f59d4ebe8b2907c06c lib/parse/headers.py @@ -579,50 +579,85 @@ dcdeed9ee285e63cf06baf8347e3db7f210ef25a63869bab78ce1ec6898ae191 tamper/unional 0694e721b07b8242245688be5c7951a3a22f512ed73776a998885e4b1bc82bc7 tamper/versionedmorekeywords.py ce1b6bf8f296de27014d6f21aa8b3df9469d418740cd31c93d1f5e36d6c509cf tamper/xforwardedfor.py 44401cad3e39ae9fb899ed5d0e2fdd0879561de05c3117f17f3b0db54f4e3724 tests/__init__.py +d2c27dff782dbe119a4cb5041f374d87b67e3da523ee3a7ad584d34721b6c564 tests/test_agent_dialects.py bfb553602eb5d20b4ab5928dbcf8e6a3e7e5ff69f7d30d1f53ef6d323c237f6c tests/test_agent.py +138381e05a860272fedab780e6c38ab74c59c879048b11b909d23f8df654352a tests/test_api.py feb763ddcbf4f32822372ca53f8c71c754af7b72510ef06e1e9c77927fc90b10 tests/test_bigarray.py 27ad87c0ea377e0657bd6f6a4eaa0e9756aa9d28ec0483bdadeb3f66dcc4660d tests/test_charset.py +c99b77cc5d85334f147a1a6d4b2867af396f70e9f2609f8587344e084910e893 tests/test_checks.py 9e678a56e16211c49ab4995b6c658d3f122bfa3b357d9e17ff38f5a489ace6ad tests/test_cloak.py 2ec894f49ca9bd750a23ead16dae176bcbc57d18ec5847fa4a5eeb886d75c1bd tests/test_common_helpers.py +c6338f74230b758cb41adacf4f04593e70b4b11e054ea0b35712607a781e0d55 tests/test_common_parsers.py +b1540c5f2be80ee3d870d7c373adfca23f33adb06724db00335adbd79bea4272 tests/test_common_utils.py 899bc085e96d68f8a8cbe0d7e55863e98ef37b73ab0e4234f7d969e31ea2d23a tests/test_comparison_json.py 7b72d4f850bbd059b8e95fceb45a58470354cb7270c99b0e9981aaa189af20d1 tests/test_comparison.py +a0a29231acbbe6bec11400e28b39b76eaf812c03bf79d5f0dbdd68cd54a052f8 tests/test_compat.py 75357efd92f3f57cc05244a0f40985108077479fd192caaaa81e14f61c13783d tests/test_convert.py +d2c52b1c9b0f31e2d30e1fc3942986692a815e76fa8e39903c3824d6d6d0ee71 tests/test_core_extra.py +7c6d542bf96e8962ecdf8607f93e84babe4820045533bded170955e95727d630 tests/test_core_final.py +e42f6dd46fa7f2d1e666116e2244fa02e7b9d930a005e2bbeea89cfe3f2215b6 tests/test_core_more.py +951822c0d6ea62dc91cc4a7614059788b256cac06167f4767721f2ad5d54a78b tests/test_databases_enum.py c17544be5e945dc8c4fbb5c3b922da8eceec30b0fb239c32fb5f40e1660a197f tests/test_datafiles.py 9c240d4f796e56376374d4ce46f358ceb7d48cc6a7427760c5bfb89ff01cb545 tests/test_datatypes.py +c9f7c5219e379b0242914f79f1e5d3b8b7d1a4c5e9f77cd05d0ec382d4fbed88 tests/test_dbms_enum_a.py +866978b7d5d0270a54465897932fe645c7e0360d73b0e4086540558c107e680d tests/test_dbms_enum_b.py +a3628b7f22dcc0ff4cd9ed8a1e70519a340f40fa4d73e9220c7d11f5088d9c01 tests/test_dbms_enum.py 3804eb2d730220360f9dc07d5994eb64e9f65acf3b0d8648df8df2a2177ba8fd tests/test_decodepage.py -b6d8a4bc9c46a332a2dc7b3cf862ea67e38b5c5701cfd8eb3556021f6b611416 tests/test_dialectdbms.py +8e469e4e29319bcb718803a9e109e742965875c985fa8e8d3bb5b18c922ec597 tests/test_deps.py +b01343eb8aa42ea5c2c483ec028a24f6451aa6f668fdc0c289d5ff9554c277d7 tests/test_dialectdbms.py e40a49cfa73c45b3c3c6d1d1d00738861e270cb7a07b28f5a5356f9c7c800cf2 tests/test_dialect.py 993a2d4d87c4fbaf261663b069629acc95ee4405aa0c42cf5a8f39649fdb0fff tests/test_dicts.py -ed5a0e453b811dc3dcc5ca28e14a9d7552aacaa7e316e1bca1b042dc5939e204 tests/test_dns_engine.py -703faac01f38224ba85bd0fc398d939ea034f1d7fd641cdc15da4f77ec049443 tests/test_dns_server.py +7f9180a53dbf0bb3e52801fdbfffd31f365a0bff77bf90e58d2ef63a0c23026f tests/test_dns_engine.py +ec58ba0849d90d2bb7580fe2b8b96cd8299ddfc25f14dc27d9de9d41f152c78a tests/test_dns_server.py +4556bb0bfa6fcd5b98552426c57c99942ee8274eaefec7c316fd64247e4fcd6a tests/test_dump_format.py 9cd5841349bc4db818658d12184929a96f7f279eff1f53ad18a54dbefbd6b276 tests/test_dump_jsonl.py 2bbe4b01f79992cfa8884651fc0a28dbd0e3abb0cbea9eb7eadf1f98ca3c3420 tests/test_encoding.py bb6991260a994fcbe79e05febaa34affd5631d02299fbc626820addd5f6ea4f4 tests/test_error_engine.py -4a5f9392b7fec7b40c4d865b83306b58b76f3423cebc2876e6e75fb91b037202 tests/test_graphql.py -8105de9978fe286a29f6b635a58db1e9998d86e8dded54d7efdfb9d52a121094 tests/test_hashdb.py +31354d3cff0d26ecf3b42e949a2780ae3d286cdf206b59404e18a96e7a2cddd2 tests/test_filesystem.py +6a9d95f64c7892957742534a14e8f094c6ed9ebc91b7059f4f1665049228a5a6 tests/test_fingerprint.py +4f3cfb830b323a3423b0f80985b9a0bbbe4ef77350b762f103dcd8936cca67c6 tests/test_generic_enum_more.py +9874920d18fc30736630df6b14a70b230504d2e4d0c035971a9aa285ee623839 tests/test_generic_more.py +bde97a4781c4ee84e0fe86f7a33206f114167eb14b704013ecf1c26b838193d7 tests/test_graphql.py +50b71422ee91b9a4864f4d5ce6c9bdf169dc5f57ed1db05c152eb010c282136b tests/test_gui_helpers.py +92648f2fe81e22c5726b198bbbda14961cd4d3294a0d9139dcea808b324142ac tests/test_har.py +da2efd1b7457ff619d98a2ae5045f072fdd34be2aa1c18f17d74d7518eeb6707 tests/test_hash_crack.py +0336c875dd2b6554bff6eafd746229e38c69ca8070cd933d45cf27c82ef3e05f tests/test_hashdb.py c04e8358fb6df45f69f2f26435c971acde280535bf304e84d30cf2681158c6a7 tests/test_hash.py d539d0ae758b5bb91e314ab82ab4fe03d6fb2f8b377d16aefa6d7d1d77a7d5a9 tests/test_identifiers_output.py 5372270b7ed82b62f273c2e9bd1f7ecd8605371e66cd0ad70663762cb08d42f1 tests/test_inference_engine.py -13d0369f3fea7262f7944999f559da38e5284cbc76660fd7aeffedad78e65f5f tests/test_ldap.py +280afe64cabac3a737d2574f4e2873760c3883eaac1b7ba0f8fed4b82b91c9c2 tests/test_inference.py +0fc7bd9bae4fbd09f51027780b7a8e72eab73810dccdfdf87ed9e489e6e671c9 tests/test_ldap.py caa06fed7323b2bb6d0f2443ce343de94f75bf8ad012c055d5e07741d908ebad tests/test_misc.py 790b78c600b61eb0bdd6e07e14b1db3eb2ddd5fc5d4edb9e975f85ced38558c7 tests/test_nosql.py 88a8c7ce0ba0ca721dffbcf9351cd07f7e471ad2fe667a10608c18952b09868d tests/test_openapi_drift.py +647d782395fe88dcda775808b9988a0809b208d1df9412d89dc8b6809bd15de6 tests/test_option_more.py +a5743989442de51b3689b30c27118249502bb462788abeeb1ddb27cd176cd363 tests/test_option_setup.py cde0bea1263ae857561f91ed2bd515e972b716743f017d31b1718a8546c72759 tests/test_pagecontent.py -4bac34af2abddce003756d6776e89b2fda220bb7603ef3761f4f37ee29f9c369 tests/test_payload_marking.py +7554a918309cf0f2cd8a63a3bb7659708f13beffbcd5ce498ece9f9167d55c97 tests/test_parse_modules.py +064617c6a3d28ecd75136318b4f515ab1adefbf830da17667f105337b419c184 tests/test_payload_marking.py 6bfc8201724078bd9d6d559916ef73c9ff97e19b0f2948f37e588a49b027795f tests/test_payloads_structure.py +d6ffa83bd56ae98e7f55307b72dd7ea4802bccea9a85bb8f062619fb0a88913e tests/test_progress.py a6d013104601c0414628aff3d8b5b69bee3e6733781d8f8da880457d8b44bd3a tests/test_property.py +c4c6f500bb71c3e430da343a49e8c8b8b3c919f438b6e6130597ce68dd856487 tests/test_purge.py 2dfefb4bfaee3868152835502ec43da317c4f274b1d55cd2ef21e4f7390c9bea tests/test_replication.py 67a5241aeebc20eb1c20cfc490422a59af5179040824e5731bd785db2e6bf750 tests/test_report.py cec98d72992c0799229a780fa7f0d7f3fb01ec2d708187ce0e4a05c8612f291b tests/test_safe2bin.py +d4f6e60c23db67430cf68dc2d90317d69391a19feff0f842c08ae2443b481857 tests/test_search_enum.py a1c6cda1e5b483f61e6a4f8ddd0b06a15ddaa3fd2119bfb9dbd9cc970d7a751d tests/test_settings_regex.py +d6bcba7232fff834737c094679c92e7a69cab5721bc87cb10bcab868c6a8115f tests/test_sgmllib.py d3d991331096e16e5019de3d652e9fff92c09bd9f97c50b1c2c3ceb0ed49b17e tests/test_sqlparse.py 8bcbf1091134dd0a62f6201f8b3645ed87b5ff2f7ba40a87231a29dac412591f tests/test_strings.py -f3a628db8a3e05baee580c02132e95b164695e4b3ee1785707e3ea148702449a tests/test_tamper.py +8f1c5f0f337ecd26d35c5551060034e0aa33a62cce5385fc1227fdc485f6383e tests/test_tamper.py +44954b916f1e4a4bb217516a65cf330fca922600d484f732525e0e4a2a553167 tests/test_target_parsing.py b3e13febe9e0ff6f97334f2868655bfdbaa18755e464a6dc4c6d424f513bad02 tests/test_targeturl.py +d070a72ae9529182d6dfc0884f7720d42a5f0cd8cd865dd4c2d209389c3ade85 tests/test_techniques_more.py +f2e8b5b9799f4e591462f53a97bb643c6399acf703f33e119c03d991971274ab tests/test_techniques.py 639851dc68f62b559b200b09c308e64e453f414969940005bac75dc0ab07a6b6 tests/test_texthelpers.py +f49bcce1df533ffa1acfd02af43faf6687b21eebda9362ceb1e5871b8cb37fd4 tests/test_threads.py 708b3c040f8b677a84020dd6f7c4242f77260b3c6d2697fe8189e1881b0e1365 tests/test_union_engine.py 48b0ae4abe0fdde8ce4975c5cbf4c3514a2815021cb2e3a490a189bea5edfe78 tests/test_unpickle_security.py 4b646f513c6da1e33200184ed6eabe0aa345eb2e2a19598dc123e191168591bf tests/test_urls.py +e7793907ce4dad9034d61f2a3cdfec8af33b96f8e6f67138b09daf81a825c13f tests/test_users_enum.py 23ffd75b5aec33066e6d6aad01ab2c9c1b12ee20c1a0990f8f1be81f1ad16161 tests/_testutils.py 2364db35025a53ea4e5a0a80c034997642785f7e6d1566d0d0f1db959fe3c82e tests/test_utils.py 93ef9944effc62d4f744c57bd643137c90fd92205c6a6cbe891e0e99efb80a7f tests/test_wafbypass.py diff --git a/lib/core/settings.py b/lib/core/settings.py index bb3a3ada1..79a4e7ea0 100644 --- a/lib/core/settings.py +++ b/lib/core/settings.py @@ -20,7 +20,7 @@ from lib.core.enums import OS from thirdparty import six # sqlmap version (...) -VERSION = "1.10.6.185" +VERSION = "1.10.6.186" TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable" TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34} VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE) diff --git a/lib/parse/cmdline.py b/lib/parse/cmdline.py index ea79f3115..72e43e1e6 100644 --- a/lib/parse/cmdline.py +++ b/lib/parse/cmdline.py @@ -898,7 +898,7 @@ def cmdLineParser(argv=None): parser.add_argument("--non-interactive", dest="nonInteractive", action="store_true", help=SUPPRESS) - parser.add_argument("--smoke-test", dest="smokeTest", action="store_true", + parser.add_argument("--smoke-test", "--doc-test", dest="smokeTest", action="store_true", help=SUPPRESS) parser.add_argument("--vuln-test", dest="vulnTest", action="store_true", diff --git a/tests/test_agent_dialects.py b/tests/test_agent_dialects.py new file mode 100644 index 000000000..72b9007a5 --- /dev/null +++ b/tests/test_agent_dialects.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Cross-dialect exercise of lib/core/agent.py payload-assembly helpers. + +agent.py builds SQL payloads from per-DBMS dialect templates (queries.xml). +The helpers are pure given the identified back-end DBMS, so driving each one +across EVERY supported dialect walks the dialect-specific branches (CAST forms, +concatenation operators, LIMIT/TOP/ROWNUM shapes, ...) without a live target. + +These are smoke-level assertions (right type, dialect tokens present) rather than +golden strings: the goal is to traverse the dialect branches the single-DBMS +tests in test_agent.py do not reach. +""" + +import os +import re +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.agent import agent +from lib.core.data import queries + +DIALECTS = sorted(queries.keys()) + +# --------------------------------------------------------------------------- # +# Per-dialect expectation maps (keyed by the DBMS display name == queries key). +# +# These were derived by inspecting the actual agent.py output for every dialect +# (the queries.xml templates drive the branches). They pin the *distinctive* +# dialect token so an assertion fails if the dialect branch collapses to the +# wrong form (e.g. concat operator swapped, null-wrapper dropped). +# --------------------------------------------------------------------------- # + +# concatQuery / simpleConcatenate join operator per dialect. +CONCAT_OPERATOR = { + "ClickHouse": "CONCAT(", + "Informix": "CONCAT(", + "MySQL": "CONCAT(", + "SAP MaxDB": "CONCAT(", + "Microsoft SQL Server": "+", + "Sybase": "+", + "Microsoft Access": "&", +} +# everything not listed above uses the SQL standard "||" +CONCAT_OPERATOR_DEFAULT = "||" + +# nullAndCastField / nullCastConcatFields NULL-wrapper function per dialect. +NULL_WRAPPER = { + "Altibase": "NVL", + "Apache Derby": "COALESCE", + "ClickHouse": "ifNull", + "CrateDB": "COALESCE", + "Cubrid": "IFNULL", + "Firebird": "COALESCE", + "FrontBase": "COALESCE", + "H2": "IFNULL", + "HSQLDB": "IFNULL", + "IBM DB2": "COALESCE", + "Informix": "NVL", + "InterSystems Cache": "COALESCE", + "Mckoi": "IF(", + "Microsoft Access": "IIF", + "Microsoft SQL Server": "ISNULL", + "MimerSQL": "COALESCE", + "MonetDB": "COALESCE", + "MySQL": "IFNULL", + "Oracle": "NVL", + "PostgreSQL": "COALESCE", + "Presto": "COALESCE", + "Raima Database Manager": "IFNULL", + "SAP MaxDB": "VALUE", + "SQLite": "COALESCE", + "Snowflake": "NVL", + "Spanner": "IFNULL", + "Sybase": "ISNULL", + "Vertica": "COALESCE", + "Virtuoso": "__MAX_NOTNULL", + "eXtremeDB": "IFNULL", +} + +# hexConvertField: dialects that DO have a hex function, mapped to its token. +HEX_FUNCTION = { + "Altibase": "HEX_ENCODE(", + "Cubrid": "HEX(", + "H2": "RAWTOHEX(", + "IBM DB2": "HEX(", + "Microsoft SQL Server": "fn_varbintohexstr", + "MySQL": "HEX(", + "Oracle": "RAWTOHEX(", + "PostgreSQL": "ENCODE(", + "Presto": "TO_HEX(", + "SAP MaxDB": "HEX(", + "SQLite": "HEX(", + "Spanner": "TO_HEX(", + "Sybase": "BINTOSTR", + "Vertica": "TO_HEX(", +} +# dialects that intentionally do NOT support hex conversion and return the +# field unchanged (a no-op the old "colname in out" check silently masked). +HEX_NOOP = set(DIALECTS) - set(HEX_FUNCTION) + +# limitQuery: dialects whose limit template is empty so the call legitimately +# raises (no .limit.query). These are skipped by name in the limit-token test. +LIMIT_RAISES = {"Mckoi", "Raima Database Manager"} +# dialects with no special limitQuery branch: the query is returned unchanged +# (no limit token is emitted). +LIMIT_PASSTHROUGH = {"Informix", "Microsoft Access", "SAP MaxDB"} +# broad set of dialect limit tokens; every running, non-passthrough dialect +# emits at least one of these. +LIMIT_TOKENS = ("LIMIT", "TOP", "ROWNUM", "FETCH", "ROWS", "OFFSET", "ROW_NUMBER") + + +class TestNullCastConcatFields(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + out = agent.nullCastConcatFields("user,password") + self.assertIsInstance(out, str, msg=dbms) + # both column names survive the null/cast/concat rewrite + self.assertIn("user", out, msg=dbms) + self.assertIn("password", out, msg=dbms) + # the dialect-specific NULL-wrapper must be present (the column-name + # check above is always satisfied and so cannot catch a broken + # branch); this fails if the wrapper collapses to the wrong form. + self.assertIn(NULL_WRAPPER[dbms], out, msg="%s: %s" % (dbms, out)) + + def test_literal_passthrough(self): + for dbms in DIALECTS: + set_dbms(dbms) + # a bare quoted literal is returned untouched + self.assertEqual(agent.nullCastConcatFields("'abc'"), "'abc'", msg=dbms) + + +class TestNullAndCastField(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + out = agent.nullAndCastField("colname") + self.assertIsInstance(out, str, msg=dbms) + self.assertIn("colname", out, msg=dbms) + # dialect-specific NULL wrapper (IFNULL/COALESCE/NVL/ISNULL/IIF/...) + self.assertIn(NULL_WRAPPER[dbms], out, msg="%s: %s" % (dbms, out)) + + +class TestHexConvertField(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + out = agent.hexConvertField("colname") + self.assertIsInstance(out, str, msg=dbms) + self.assertIn("colname", out, msg=dbms) + if dbms in HEX_FUNCTION: + # the dialect's hex function wraps the field + self.assertIn(HEX_FUNCTION[dbms], out, msg="%s: %s" % (dbms, out)) + else: + # intentional no-op: the field is returned verbatim. The old + # "colname in out" check masked this; pin the exact identity. + self.assertEqual(out, "colname", msg="%s expected no-op: %s" % (dbms, out)) + + +class TestConcatQuery(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + out = agent.concatQuery("SELECT user FROM users") + self.assertIsInstance(out, str, msg=dbms) + # concatQuery output is dialect-specific: MySQL/ClickHouse/Informix/ + # SAP MaxDB use CONCAT(...), MSSQL/Sybase use +, Access uses &, and + # the rest use the SQL-standard ||. Assert the right operator so the + # test fails if the dialect collapses to the wrong concatenation. + expected = CONCAT_OPERATOR.get(dbms, CONCAT_OPERATOR_DEFAULT) + self.assertIn(expected, out, msg="%s: %s" % (dbms, out)) + + +class TestSimpleConcatenate(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + out = agent.simpleConcatenate("a", "b") + self.assertIsInstance(out, str, msg=dbms) + self.assertIn("a", out, msg=dbms) + self.assertIn("b", out, msg=dbms) + + +class TestForgeUnionQuery(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + count = 3 + out = agent.forgeUnionQuery("SELECT user FROM users", -1, count, None, + None, None, "NULL", None) + self.assertIsInstance(out, str, msg=dbms) + self.assertIn("UNION", out.upper(), msg=dbms) + # position -1 with char NULL fills every one of the `count` columns + # with the char, so the NULL char must appear exactly `count` times. + # (a hardcoded "UNION in out" check could not catch a wrong column + # count.) Match NULL as a whole token to avoid matching substrings. + self.assertEqual(re.findall(r"\bNULL\b", out).__len__(), count, + msg="%s expected %d NULLs: %s" % (dbms, count, out)) + + +class TestLimitQuery(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + + # Only Mckoi/Raima have an empty limit template and legitimately + # raise; skip exactly those by name rather than swallowing *any* + # exception (which would hide a real regression in another dialect). + if dbms in LIMIT_RAISES: + with self.assertRaises(Exception, msg=dbms): + agent.limitQuery(0, "SELECT user FROM users", "user") + continue + + out = agent.limitQuery(0, "SELECT user FROM users", "user") + self.assertIsInstance(out, str, msg=dbms) + + if dbms in LIMIT_PASSTHROUGH: + # these dialects have no dedicated limitQuery branch and return + # the query unchanged (documented no-op). + self.assertEqual(out, "SELECT user FROM users", msg=dbms) + else: + # every other running dialect emits a real limit construct + self.assertTrue(any(tok in out.upper() for tok in LIMIT_TOKENS), + msg="%s missing limit token: %s" % (dbms, out)) + + +class TestForgeCaseStatement(unittest.TestCase): + def test_all_dialects(self): + for dbms in DIALECTS: + set_dbms(dbms) + out = agent.forgeCaseStatement("1=1") + self.assertIsInstance(out, str, msg=dbms) + # dialects vary on the conditional form (CASE / IIF / IF); the + # condition itself is always embedded + self.assertIn("1=1", out, msg=dbms) + # ...but the conditional construct itself must also be present, + # otherwise the "1=1" check alone could pass on a degenerate output. + self.assertTrue("CASE" in out or "IIF" in out or "IF(" in out, + msg="%s missing conditional construct: %s" % (dbms, out)) + + +class TestPrefixSuffixAcrossDialects(unittest.TestCase): + def test_prefix_suffix(self): + for dbms in DIALECTS: + set_dbms(dbms) + prefix = agent.prefixQuery("1=1") + suffix = agent.suffixQuery("1=1") + self.assertIsInstance(prefix, str, msg=dbms) + self.assertIsInstance(suffix, str, msg=dbms) + # prefixQuery pads a leading space ahead of the expression by default + self.assertEqual(prefix, " 1=1", msg="%s prefix: %r" % (dbms, prefix)) + # suffixQuery returns the expression itself (no extra clause/comment) + self.assertEqual(suffix, "1=1", msg="%s suffix: %r" % (dbms, suffix)) + + +class TestRunAsDBMSUserAndWhere(unittest.TestCase): + def test_run_as_user_noop_without_conf(self): + for dbms in DIALECTS: + set_dbms(dbms) + # without conf.dbmsCred the query is returned unchanged + self.assertEqual(agent.runAsDBMSUser("SELECT 1"), "SELECT 1", msg=dbms) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 000000000..a76d814d6 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Unit tests for the sqlmap REST API (lib/utils/api.py). + +Two complementary angles: + 1. Pure helpers / objects called directly (is_admin, validate_task_options, + the Database and Task classes, the StdDbOut/LogRecorder IPC writers). + 2. The bottle HTTP routes driven through the WSGI app via a minimal in-process + test client (no sockets, no network, no scan subprocess) - task lifecycle, + option get/set, scan status/data/log, admin list/flush, version, auth. + +The scan-data assembler/collector helpers (_storeData / _assembleData / +_sanitizeScanData / _cleanIdentifier / writeReportJson) are pinned separately in +test_report.py; here we focus on what that file does not exercise. +""" + +import io +import json +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +import lib.utils.api as api +from lib.core.data import conf +from lib.core.convert import encodeBase64 +from lib.core.enums import CONTENT_STATUS, CONTENT_TYPE +from thirdparty.bottle.bottle import default_app + + +def _wsgi_call(method, path, body=None, headers=None, remote_addr="127.0.0.1"): + """ + Drive the module's bottle routes through the WSGI interface in-process. + Returns (status_code_int, parsed_json_or_None, raw_text). + """ + + app = default_app() + environ = { + "REQUEST_METHOD": method, + "PATH_INFO": path, + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "REMOTE_ADDR": remote_addr, + "wsgi.input": io.BytesIO(), + "wsgi.errors": sys.stderr, + "wsgi.url_scheme": "http", + } + + if body is not None: + data = json.dumps(body).encode("utf-8") + environ["CONTENT_TYPE"] = "application/json" + environ["CONTENT_LENGTH"] = str(len(data)) + environ["wsgi.input"] = io.BytesIO(data) + + for key, value in (headers or {}).items(): + environ["HTTP_%s" % key.upper().replace("-", "_")] = value + + captured = {} + + def start_response(status, response_headers, exc_info=None): + captured["status"] = status + + chunks = app(environ, start_response) + raw = b"".join(chunks).decode("utf-8", "replace") + code = int(captured["status"].split(" ", 1)[0]) + + try: + parsed = json.loads(raw) + except ValueError: + parsed = None + + return code, parsed, raw + + +class _ApiServerCase(unittest.TestCase): + """ + Stands up just enough of the API server state (IPC database + DataStore globals) + to drive the routes, snapshotting and restoring every global it touches. + """ + + def setUp(self): + conf.batch = True + + # snapshot mutated globals + self._saved = { + "current_db": api.DataStore.current_db, + "tasks": api.DataStore.tasks, + "admin_token": api.DataStore.admin_token, + "username": api.DataStore.username, + "password": api.DataStore.password, + "filepath": api.Database.filepath, + } + + # fresh in-memory IPC database (same init the server() function performs) + self.db = api.Database(":memory:") + self.db.connect() + self.db.init() + + api.DataStore.current_db = self.db + api.DataStore.tasks = {} + api.DataStore.admin_token = "a" * 32 + api.DataStore.username = None + api.DataStore.password = None + api.Database.filepath = ":memory:" + + def tearDown(self): + try: + self.db.disconnect() + except Exception: + pass + + api.DataStore.current_db = self._saved["current_db"] + api.DataStore.tasks = self._saved["tasks"] + api.DataStore.admin_token = self._saved["admin_token"] + api.DataStore.username = self._saved["username"] + api.DataStore.password = self._saved["password"] + api.Database.filepath = self._saved["filepath"] + + def _new_task(self): + code, parsed, _ = _wsgi_call("GET", "/task/new") + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + return parsed["taskid"] + + +# --------------------------------------------------------------------------- +# Pure helpers / objects +# --------------------------------------------------------------------------- + +class TestGenericHelpers(unittest.TestCase): + def setUp(self): + self._saved_token = api.DataStore.admin_token + + def tearDown(self): + api.DataStore.admin_token = self._saved_token + + def test_is_admin_constant_time_compare(self): + api.DataStore.admin_token = "deadbeef" + self.assertTrue(api.is_admin("deadbeef")) + self.assertFalse(api.is_admin("deadbeer")) + self.assertFalse(api.is_admin(None)) + self.assertFalse(api.is_admin("")) + + +class TestValidateTaskOptions(unittest.TestCase): + def setUp(self): + self._saved_tasks = api.DataStore.tasks + api.DataStore.tasks = {"t1": api.Task("t1", "127.0.0.1")} + + def tearDown(self): + api.DataStore.tasks = self._saved_tasks + + def test_non_dict_rejected(self): + msg = api.validate_task_options("t1", ["level"], "scan_start") + self.assertEqual(msg, "Invalid JSON options") + + def test_unsupported_option_rejected(self): + # reportJson is in RESTAPI_UNSUPPORTED_OPTIONS + msg = api.validate_task_options("t1", {"reportJson": "x.json"}, "scan_start") + self.assertIn("Unsupported option", msg) + self.assertIn("reportJson", msg) + + def test_readonly_option_rejected(self): + # taskid is in RESTAPI_READONLY_OPTIONS + msg = api.validate_task_options("t1", {"taskid": "haxx"}, "option_set") + self.assertIn("Unsupported option", msg) + self.assertIn("taskid", msg) + + def test_unknown_option_rejected(self): + msg = api.validate_task_options("t1", {"nosuchoption": 1}, "option_set") + self.assertIn("Unknown option", msg) + self.assertIn("nosuchoption", msg) + + def test_valid_options_accepted(self): + # a real, supported option returns None (no error message) + self.assertIsNone(api.validate_task_options("t1", {"level": 3, "risk": 2}, "scan_start")) + + +class TestDatabase(unittest.TestCase): + """The IPC Database wrapper: connect/init schema, execute SELECT vs DML, disconnect.""" + + def setUp(self): + self.db = api.Database(":memory:") + self.db.connect("test") + self.db.init() + + def tearDown(self): + self.db.disconnect() + + def test_init_creates_expected_schema(self): + names = set(row[0] for row in self.db.execute("SELECT name FROM sqlite_master WHERE type='table'")) + self.assertTrue({"logs", "data", "errors"}.issubset(names)) + + def test_init_is_idempotent(self): + # "CREATE TABLE IF NOT EXISTS" - running init twice must not raise + self.db.init() + + def test_execute_select_returns_rows_dml_returns_none(self): + self.assertIsNone(self.db.execute("INSERT INTO errors VALUES(NULL, ?, ?)", ("t1", "boom"))) + rows = self.db.execute("SELECT taskid, error FROM errors") + self.assertEqual(rows, [("t1", "boom")]) + + def test_disconnect_is_safe_without_connection(self): + fresh = api.Database(":memory:") # never connected + fresh.disconnect() # must not raise + + +class TestTask(unittest.TestCase): + """The Task object: option defaults, set/get/reset, and the no-process engine paths.""" + + def test_initialize_options_sets_api_markers(self): + t = api.Task("abc123", "10.0.0.1") + self.assertEqual(t.remote_addr, "10.0.0.1") + self.assertIs(t.options.api, True) + self.assertEqual(t.options.taskid, "abc123") + self.assertIs(t.options.batch, True) + self.assertIs(t.options.disableColoring, True) + self.assertIs(t.options.eta, False) + + def test_set_get_reset_options(self): + t = api.Task("abc123", "10.0.0.1") + original_level = t.get_option("level") + t.set_option("level", original_level + 4) + self.assertEqual(t.get_option("level"), original_level + 4) + t.reset_options() + self.assertEqual(t.get_option("level"), original_level) + + def test_get_options_returns_attribdict(self): + t = api.Task("abc123", "10.0.0.1") + opts = t.get_options() + self.assertIs(opts, t.options) + self.assertIn("level", opts) + + def test_engine_paths_without_process(self): + t = api.Task("abc123", "10.0.0.1") + self.assertIsNone(t.engine_process()) + self.assertIsNone(t.engine_get_id()) + self.assertIsNone(t.engine_get_returncode()) + self.assertFalse(t.engine_has_terminated()) + self.assertIsNone(t.engine_stop()) + self.assertIsNone(t.engine_kill()) + + +class TestStdDbOutAndLogRecorder(unittest.TestCase): + """ + StdDbOut and LogRecorder write engine output/logs into the IPC database + (conf.databaseCursor). Verify both write paths land the expected rows. + """ + + def setUp(self): + self.db = api.Database(":memory:") + self.db.connect("client") + self.db.init() + self._saved = { + "stdout": sys.stdout, + "stderr": sys.stderr, + "databaseCursor": conf.get("databaseCursor"), + "taskid": conf.get("taskid"), + "partRun": getattr(__import__("lib.core.data", fromlist=["kb"]).kb, "partRun", None), + } + conf.databaseCursor = self.db + conf.taskid = "t1" + + def tearDown(self): + sys.stdout = self._saved["stdout"] + sys.stderr = self._saved["stderr"] + conf.databaseCursor = self._saved["databaseCursor"] + conf.taskid = self._saved["taskid"] + self.db.disconnect() + + def test_stdout_write_stores_typed_data(self): + # StdDbOut hijacks sys.stdout in __init__; restore it immediately and call write() directly + std = api.StdDbOut("t1", messagetype="stdout") + sys.stdout = self._saved["stdout"] + std.write("MySQL >= 5.0", status=CONTENT_STATUS.COMPLETE, content_type=CONTENT_TYPE.DBMS_FINGERPRINT) + rows = self.db.execute("SELECT taskid, status, content_type FROM data") + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], "t1") + self.assertEqual(rows[0][2], CONTENT_TYPE.DBMS_FINGERPRINT) + # the helpers are noops but must not raise + std.flush(); std.close(); std.seek() + + def test_stderr_write_stores_error(self): + std = api.StdDbOut("t1", messagetype="stderr") + sys.stderr = self._saved["stderr"] + std.write("something failed") + rows = self.db.execute("SELECT taskid, error FROM errors") + self.assertEqual(rows, [("t1", "something failed")]) + + def test_logrecorder_emit_stores_log(self): + import logging + rec = api.LogRecorder() + record = logging.LogRecord("sqlmap", logging.INFO, __file__, 1, "hello %s", ("world",), None) + rec.emit(record) + rows = self.db.execute("SELECT taskid, level, message FROM logs") + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], "t1") + self.assertEqual(rows[0][1], "INFO") + self.assertEqual(rows[0][2], "hello world") + + +# --------------------------------------------------------------------------- +# HTTP routes (WSGI test client) +# --------------------------------------------------------------------------- + +class TestVersionRoute(_ApiServerCase): + def test_version(self): + code, parsed, _ = _wsgi_call("GET", "/version") + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + self.assertIn("version", parsed) + self.assertEqual(parsed["api_version"], 2) # MAJOR of RESTAPI_VERSION "2.0.0" + + def test_security_headers_applied(self): + app = default_app() + environ = { + "REQUEST_METHOD": "GET", "PATH_INFO": "/version", + "SERVER_NAME": "localhost", "SERVER_PORT": "80", "REMOTE_ADDR": "127.0.0.1", + "wsgi.input": io.BytesIO(), "wsgi.errors": sys.stderr, "wsgi.url_scheme": "http", + } + captured = {} + + def start_response(status, response_headers, exc_info=None): + captured["headers"] = dict(response_headers) + + b"".join(app(environ, start_response)) + headers = captured["headers"] + self.assertEqual(headers.get("X-Frame-Options"), "DENY") + self.assertEqual(headers.get("X-Content-Type-Options"), "nosniff") + self.assertIn("application/json", headers.get("Content-Type", "")) + + +class TestTaskLifecycle(_ApiServerCase): + def test_new_and_delete(self): + taskid = self._new_task() + self.assertIn(taskid, api.DataStore.tasks) + + code, parsed, _ = _wsgi_call("GET", "/task/%s/delete" % taskid) + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + self.assertNotIn(taskid, api.DataStore.tasks) + + def test_delete_unknown_task_404(self): + code, parsed, _ = _wsgi_call("GET", "/task/deadbeef/delete") + self.assertEqual(code, 404) + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Non-existing task ID") + + +class TestOptionRoutes(_ApiServerCase): + def test_option_list(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/option/%s/list" % taskid) + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + self.assertIn("level", parsed["options"]) + + def test_option_list_invalid_task(self): + code, parsed, _ = _wsgi_call("GET", "/option/nope/list") + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Invalid task ID") + + def test_option_set_then_get(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("POST", "/option/%s/set" % taskid, {"level": 4, "risk": 3}) + self.assertTrue(parsed["success"]) + + code, parsed, _ = _wsgi_call("POST", "/option/%s/get" % taskid, ["level", "risk"]) + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["options"], {"level": 4, "risk": 3}) + + def test_option_set_invalid_task(self): + code, parsed, _ = _wsgi_call("POST", "/option/nope/set", {"level": 1}) + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Invalid task ID") + + def test_option_set_rejects_unsupported(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("POST", "/option/%s/set" % taskid, {"reportJson": "x"}) + self.assertFalse(parsed["success"]) + self.assertIn("Unsupported option", parsed["message"]) + + def test_option_get_unknown_option(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("POST", "/option/%s/get" % taskid, ["nosuchoption"]) + self.assertFalse(parsed["success"]) + self.assertIn("Unknown option", parsed["message"]) + + def test_option_get_invalid_task(self): + code, parsed, _ = _wsgi_call("POST", "/option/nope/get", ["level"]) + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Invalid task ID") + + +class TestScanQueryRoutes(_ApiServerCase): + """status/data/log on a task that has never launched a subprocess (no scan started).""" + + def test_status_not_running(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/scan/%s/status" % taskid) + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["status"], "not running") + self.assertIsNone(parsed["returncode"]) + + def test_status_invalid_task(self): + code, parsed, _ = _wsgi_call("GET", "/scan/nope/status") + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Invalid task ID") + + def test_data_empty(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/scan/%s/data" % taskid) + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["data"], []) + self.assertEqual(parsed["error"], []) + + def test_data_returns_stored_rows(self): + taskid = self._new_task() + # store a result row directly into the shared IPC db, then read it back via the route + api._storeData(self.db, taskid, "MySQL >= 5.0", CONTENT_STATUS.COMPLETE, CONTENT_TYPE.DBMS_FINGERPRINT) + code, parsed, _ = _wsgi_call("GET", "/scan/%s/data" % taskid) + self.assertTrue(parsed["success"]) + self.assertEqual(len(parsed["data"]), 1) + self.assertEqual(parsed["data"][0]["type_name"], "DBMS_FINGERPRINT") + self.assertEqual(parsed["data"][0]["value"], "MySQL >= 5.0") + + def test_data_invalid_task(self): + code, parsed, _ = _wsgi_call("GET", "/scan/nope/data") + self.assertFalse(parsed["success"]) + + def test_log_empty(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/scan/%s/log" % taskid) + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["log"], []) + + def test_log_returns_stored_rows(self): + taskid = self._new_task() + self.db.execute("INSERT INTO logs VALUES(NULL, ?, ?, ?, ?)", (taskid, "00:00:00", "INFO", "started")) + code, parsed, _ = _wsgi_call("GET", "/scan/%s/log" % taskid) + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["log"], [{"time": "00:00:00", "level": "INFO", "message": "started"}]) + + def test_log_invalid_task(self): + code, parsed, _ = _wsgi_call("GET", "/scan/nope/log") + self.assertFalse(parsed["success"]) + + def test_log_limited_subset(self): + taskid = self._new_task() + for i in range(1, 4): + self.db.execute("INSERT INTO logs VALUES(NULL, ?, ?, ?, ?)", (taskid, "00:00:0%d" % i, "INFO", "m%d" % i)) + code, parsed, _ = _wsgi_call("GET", "/scan/%s/log/1/2" % taskid) + self.assertTrue(parsed["success"]) + self.assertEqual([m["message"] for m in parsed["log"]], ["m1", "m2"]) + + def test_log_limited_bad_range(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/scan/%s/log/5/2" % taskid) + self.assertFalse(parsed["success"]) + self.assertIn("must be digits", parsed["message"]) + + def test_log_limited_invalid_task(self): + code, parsed, _ = _wsgi_call("GET", "/scan/nope/log/1/2") + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Invalid task ID") + + def test_scan_stop_invalid_when_not_running(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/scan/%s/stop" % taskid) + self.assertFalse(parsed["success"]) + + def test_scan_kill_invalid_when_not_running(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/scan/%s/kill" % taskid) + self.assertFalse(parsed["success"]) + + +class TestScanStart(_ApiServerCase): + """scan_start, with the subprocess-spawning seam (engine_start) monkeypatched.""" + + def test_scan_start_invalid_task(self): + code, parsed, _ = _wsgi_call("POST", "/scan/nope/start", {}) + self.assertFalse(parsed["success"]) + self.assertEqual(parsed["message"], "Invalid task ID") + + def test_scan_start_rejects_unsupported_option(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("POST", "/scan/%s/start" % taskid, {"wizard": True}) + self.assertFalse(parsed["success"]) + self.assertIn("Unsupported option", parsed["message"]) + + def test_scan_start_launches_engine(self): + taskid = self._new_task() + task = api.DataStore.tasks[taskid] + + calls = {"started": False} + + class _FakeProc(object): + pid = 4242 + returncode = None + + def poll(self): + return None + + def terminate(self): + pass + + def kill(self): + pass + + def wait(self): + return 0 + + def fake_engine_start(): + calls["started"] = True + task.process = _FakeProc() + + original = task.engine_start + task.engine_start = fake_engine_start + try: + code, parsed, _ = _wsgi_call("POST", "/scan/%s/start" % taskid, {"url": "http://t/?id=1"}) + finally: + task.engine_start = original + + self.assertTrue(calls["started"]) + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["engineid"], 4242) + # the provided option was applied to the task + self.assertEqual(task.get_option("url"), "http://t/?id=1") + + +class TestAdminRoutes(_ApiServerCase): + def test_admin_list_with_token(self): + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/admin/%s/list" % api.DataStore.admin_token) + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + self.assertIn(taskid, parsed["tasks"]) + self.assertEqual(parsed["tasks_num"], 1) + + def test_admin_list_same_remote_addr_without_token(self): + # /admin/list (no token) sees only tasks from the requesting remote_addr + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/admin/list", remote_addr="127.0.0.1") + self.assertTrue(parsed["success"]) + self.assertIn(taskid, parsed["tasks"]) + + def test_admin_list_other_remote_addr_excluded(self): + self._new_task() # created from 127.0.0.1 + code, parsed, _ = _wsgi_call("GET", "/admin/list", remote_addr="10.9.9.9") + self.assertTrue(parsed["success"]) + self.assertEqual(parsed["tasks_num"], 0) + + def test_admin_flush_with_token(self): + self._new_task() + self._new_task() + self.assertEqual(len(api.DataStore.tasks), 2) + code, parsed, _ = _wsgi_call("GET", "/admin/%s/flush" % api.DataStore.admin_token) + self.assertTrue(parsed["success"]) + self.assertEqual(len(api.DataStore.tasks), 0) + + def test_admin_flush_only_own_remote_addr(self): + # task from .1, flush requested by .2 (no token) -> task survives + taskid = self._new_task() + code, parsed, _ = _wsgi_call("GET", "/admin/flush", remote_addr="10.0.0.2") + self.assertTrue(parsed["success"]) + self.assertIn(taskid, api.DataStore.tasks) + + +class TestAuthentication(_ApiServerCase): + """check_authentication before_request hook (HTTP Basic) when credentials are configured.""" + + def test_no_credentials_allows_access(self): + api.DataStore.username = None + api.DataStore.password = None + code, parsed, _ = _wsgi_call("GET", "/version") + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + + def test_missing_auth_header_denied(self): + api.DataStore.username = "user" + api.DataStore.password = "pass" + code, _, raw = _wsgi_call("GET", "/version") + self.assertEqual(code, 401) + + def test_wrong_credentials_denied(self): + api.DataStore.username = "user" + api.DataStore.password = "pass" + token = encodeBase64("user:wrong", binary=False) + code, _, raw = _wsgi_call("GET", "/version", headers={"Authorization": "Basic %s" % token}) + self.assertEqual(code, 401) + + def test_correct_credentials_allowed(self): + api.DataStore.username = "user" + api.DataStore.password = "pass" + token = encodeBase64("user:pass", binary=False) + code, parsed, _ = _wsgi_call("GET", "/version", headers={"Authorization": "Basic %s" % token}) + self.assertEqual(code, 200) + self.assertTrue(parsed["success"]) + + def test_malformed_basic_credentials_denied(self): + # base64 of a string without ':' separator -> denied + api.DataStore.username = "user" + api.DataStore.password = "pass" + token = encodeBase64("nocolon", binary=False) + code, _, _ = _wsgi_call("GET", "/version", headers={"Authorization": "Basic %s" % token}) + self.assertEqual(code, 401) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_checks.py b/tests/test_checks.py new file mode 100644 index 000000000..d0fe284c9 --- /dev/null +++ b/tests/test_checks.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Unit tests for lib/controller/checks.py driven with a MOCKED HTTP layer. + +checks.py is the injection-detection controller; almost everything in it goes +through the network seam (lib.request.connect.Connect, imported into the module +as `Request`). By monkeypatching `Request.queryPage` / `Request.getPage` to +return canned (page, headers/ratio, code) tuples - and stubbing `agent.payload` +where the real payload machinery would require a fully-built target - the +decision logic of each check (the kb.*/conf.*/return-value verdict) can be +exercised offline, without a live target, DBMS, or DNS. + +Every test snapshots and restores the conf/kb fields it touches AND every +module attribute it monkeypatches, so ordering between tests (and with the rest +of the suite) is irrelevant. conf.batch is forced on to avoid interactive +prompts, and readInput is stubbed per-test where a branch would prompt. +""" + +import os +import re +import sys +import time +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +import lib.controller.checks as checks +from lib.core.data import conf, kb +from lib.core.datatype import AttribDict, InjectionDict +from lib.core.dicts import FROM_DUMMY_TABLE +from lib.core.enums import DBMS +from lib.core.enums import HEURISTIC_TEST +from lib.core.enums import HTTP_HEADER +from lib.core.enums import HTTPMETHOD +from lib.core.enums import NULLCONNECTION +from lib.core.enums import PLACE +from lib.core.settings import SINGLE_QUOTE_MARKER +from lib.core.common import getCurrentThreadData +from lib.parse.html import htmlParser + + +# conf/kb fields any of the checks read or write; snapshotted wholesale so a +# test never leaks state into another test or the rest of the suite. +_CONF_KEYS = ( + "paramDict", "parameters", "url", "hostname", "method", "skipHeuristics", + "prefix", "suffix", "nosql", "graphql", "ldap", "beep", "string", + "notString", "regexp", "regex", "dummy", "offline", "skipWaf", "data", + "hashDB", "cj", "cookie", "dropSetCookie", "httpHeaders", "proxy", "tor", + "tamper", "timeout", "retries", "textOnly", "ignoreCode", "disablePrecon", + "ipv6", "multipleTargets", "level", "base64Parameter", "batch", +) +_KB_KEYS = ( + "heavilyDynamic", "dynamicParameter", "originalPage", "originalPageTime", + "originalCode", "ignoreCasted", "heuristicMode", "disableHtmlDecoding", + "heuristicTest", "heuristicPage", "heuristicCode", "pageStable", + "nullConnection", "pageCompress", "matchRatio", "skipSeqMatcher", + "choices", "injection", "errorIsNone", "serverHeader", "identifiedWafs", + "tamperFunctions", "resendPostOnRedirect", "checkWafMode", "wafBypass", + "heuristicExtendedDbms", "resumeValues", "mergeCookies", "httpErrorCodes", +) + + +def _snapshot(): + return ( + dict((k, conf.get(k)) for k in _CONF_KEYS), + dict((k, kb.get(k)) for k in _KB_KEYS), + ) + + +def _restore(snap): + confSnap, kbSnap = snap + for k, v in confSnap.items(): + conf[k] = v + for k, v in kbSnap.items(): + kb[k] = v + + +class _ChecksTestBase(unittest.TestCase): + """Snapshots conf/kb and the patchable seams; restores them in tearDown.""" + + def setUp(self): + self._snap = _snapshot() + # remember the real seams so monkeypatches can't leak. agent.payload / + # addPayloadDelimiters are class methods on a shared singleton: patching + # sets an *instance* attribute, so it's restored by deleting that + # attribute (reassigning would leave a stale bound method behind). + self._origQueryPage = checks.Request.queryPage + self._origGetPage = checks.Request.getPage + self._agentHadPayload = "payload" in checks.agent.__dict__ + self._agentHadAddDelims = "addPayloadDelimiters" in checks.agent.__dict__ + self._origReadInput = checks.readInput + self._origDbmsErr = checks.wasLastResponseDBMSError + self._origHttpErr = checks.wasLastResponseHTTPError + self._origCBE = checks.checkBooleanExpression + + # sane offline baseline shared by most checks + conf.batch = True + conf.skipHeuristics = False + conf.prefix = conf.suffix = None + conf.hashDB = None + conf.dummy = conf.offline = conf.proxy = conf.tor = None + kb.choices = AttribDict(keycheck=False) + + def tearDown(self): + checks.Request.queryPage = self._origQueryPage + checks.Request.getPage = self._origGetPage + if not self._agentHadPayload and "payload" in checks.agent.__dict__: + del checks.agent.payload + if not self._agentHadAddDelims and "addPayloadDelimiters" in checks.agent.__dict__: + del checks.agent.addPayloadDelimiters + checks.readInput = self._origReadInput + checks.wasLastResponseDBMSError = self._origDbmsErr + checks.wasLastResponseHTTPError = self._origHttpErr + checks.checkBooleanExpression = self._origCBE + _restore(self._snap) + + # --- helpers --- + + def _patchQueryPage(self, fn): + checks.Request.queryPage = staticmethod(fn) + + def _patchGetPage(self, fn): + checks.Request.getPage = staticmethod(fn) + + @staticmethod + def _contentQuery(page, code=200, headers=None): + """A queryPage that returns (page, headers/ratio, code) when content is + requested and a plain truthiness otherwise.""" + def _fn(*args, **kwargs): + if kwargs.get("content"): + return (page, headers, code) + return bool(page) + return _fn + + @staticmethod + def _detectingContentQuery(page, code=200, headers=None): + """Like _contentQuery, but mirrors the real connection layer's + error-detection seam: it advances the request UID and runs the REAL + htmlParser() over the page (exactly as Connect.getPage() does), so the + page is classified by sqlmap's genuine error regexes. The unstubbed + wasLastResponseDBMSError() then reads the threadData.lastErrorPage this + leaves behind - the heuristic verdict is the detector's, not the stub's.""" + def _fn(*args, **kwargs): + threadData = getCurrentThreadData() + kb.requestCounter = (kb.get("requestCounter") or 0) + 1 + threadData.lastRequestUID = kb.requestCounter + htmlParser(page or "") + if kwargs.get("content"): + return (page, headers, code) + return bool(page) + return _fn + + @staticmethod + def _comparingQuery(page, code=200, headers=None): + """A queryPage that, for a non-content request, runs the REAL + comparison() engine of the injected page against kb.pageTemplate (the + same call Connect.queryPage makes for its True/False verdict). The + matchRatio/seqMatcher dynamicity logic therefore actually executes - + the verdict is computed, not hard-coded.""" + def _fn(*args, **kwargs): + if kwargs.get("content"): + return (page, headers, code) + return checks.comparison(page, headers, code, getRatioValue=False) + return _fn + + +class TestHeuristicCheckSqlInjection(_ChecksTestBase): + def setUp(self): + super(TestHeuristicCheckSqlInjection, self).setUp() + conf.paramDict = {PLACE.GET: {"id": "1"}} + conf.parameters = {PLACE.GET: "id=1"} + conf.url = "http://test.invalid/index.php?id=1" + conf.method = None + conf.nosql = conf.graphql = conf.ldap = False + conf.beep = False + kb.heavilyDynamic = False + kb.dynamicParameter = False + kb.originalPage = "" + kb.ignoreCasted = False + # clear any error-page marker left by an earlier request so the real + # wasLastResponseDBMSError() starts from a clean slate + td = getCurrentThreadData() + td.lastErrorPage = tuple() + td.lastRequestUID = 0 + # bypass the full payload-building machinery (needs a built target) + checks.agent.payload = lambda *a, **kw: "PAYLOAD" + + def test_skip_heuristics_returns_none(self): + conf.skipHeuristics = True + self.assertIsNone(checks.heuristicCheckSqlInjection(PLACE.GET, "id")) + + def test_positive_on_dbms_error(self): + # Feed a GENUINE MySQL error page (matches sqlmap's real error regex in + # data/xml/errors.xml) through the detecting stub and let the UNSTUBBED + # wasLastResponseDBMSError() classify it. The POSITIVE verdict is then + # the real detector's, not a hard-coded True. + page = ("You have an error in your SQL syntax; check the " + "manual that corresponds to your MySQL server version") + self._patchQueryPage(self._detectingContentQuery(page)) + result = checks.heuristicCheckSqlInjection(PLACE.GET, "id") + self.assertEqual(result, HEURISTIC_TEST.POSITIVE) + self.assertEqual(kb.heuristicTest, HEURISTIC_TEST.POSITIVE) + + def test_negative_on_clean_page(self): + # A clean page matches none of sqlmap's error regexes, so the unstubbed + # wasLastResponseDBMSError() returns false -> NEGATIVE verdict. + self._patchQueryPage(self._detectingContentQuery("a perfectly ordinary page")) + result = checks.heuristicCheckSqlInjection(PLACE.GET, "id") + self.assertEqual(result, HEURISTIC_TEST.NEGATIVE) + self.assertEqual(kb.heuristicTest, HEURISTIC_TEST.NEGATIVE) + + def test_records_page_and_resets_mode(self): + self._patchQueryPage(self._detectingContentQuery("nothing special here")) + checks.heuristicCheckSqlInjection(PLACE.GET, "id") + # mode flags must be flipped back off after the check + self.assertFalse(kb.heuristicMode) + self.assertFalse(kb.disableHtmlDecoding) + + +class TestHeuristicCheckDbms(_ChecksTestBase): + def setUp(self): + super(TestHeuristicCheckDbms, self).setUp() + kb.injection = InjectionDict() + + def test_skip_heuristics_returns_false(self): + conf.skipHeuristics = True + self.assertFalse(checks.heuristicCheckDbms(InjectionDict())) + + def test_no_match_when_all_expressions_false(self): + checks.checkBooleanExpression = lambda expr: False + self.assertFalse(checks.heuristicCheckDbms(InjectionDict())) + + def test_identifies_dbms_on_distinguishing_pair(self): + # An expr-AWARE oracle that recognises ONLY the predicate + # heuristicCheckDbms() builds for one CHOSEN target DBMS. The function + # iterates every DBMS, forging for each the pair + # positive: (SELECT '')= -> must be True + # negative: (SELECT '')= -> must be False + # ( == SINGLE_QUOTE_MARKER, r1 != r2). The DBMS is reported only when + # the positive holds AND the negative fails. The oracle below returns + # True exactly for that shape - it keys off the chosen DBMS's UNIQUE + # FROM clause (so no other DBMS's predicate matches) and off the two + # quoted literals being equal (so the "must differ" negative is False). + # Firebird is chosen because its FROM clause (' FROM RDB$DATABASE') is + # unique in FROM_DUMMY_TABLE and it is not a HEURISTIC_NULL_EVAL DBMS, + # so heuristicCheckDbms() takes the SELECT-literal predicate path for it. + target = DBMS.FIREBIRD + targetFrom = FROM_DUMMY_TABLE[target] + predicate = re.compile( + r"\(SELECT '([^']*)'( FROM [^)]*)?\)=" + + re.escape(SINGLE_QUOTE_MARKER) + r"(.*?)" + re.escape(SINGLE_QUOTE_MARKER) + ) + + def oracle(expr): + match = predicate.search(expr) + if not match: + return False + selected, fromClause, compared = match.group(1), match.group(2) or "", match.group(3) + # True only for the target DBMS's FROM clause with matching literals + return fromClause == targetFrom and selected == compared + + checks.checkBooleanExpression = oracle + result = checks.heuristicCheckDbms(InjectionDict()) + # real predicate matching must single out the chosen DBMS, not whatever + # getPublicTypeMembers() happens to yield first + self.assertEqual(result, target) + self.assertEqual(kb.heuristicExtendedDbms, target) + + +class TestCheckDynParam(_ChecksTestBase): + # A stable baseline page that checkDynParam's injected response is compared + # against by the REAL comparison() engine. Long enough that difflib's + # quick_ratio is meaningful rather than degenerate. + _BASELINE = ("Welcome" + + "the quick brown fox jumps over the lazy dog. " * 20 + + "") + + def setUp(self): + super(TestCheckDynParam, self).setUp() + conf.method = None + checks.agent.payload = lambda *a, **kw: "PAYLOAD" + # state the real comparison() engine reads + conf.string = conf.notString = conf.regexp = conf.code = None + conf.titles = conf.textOnly = False + kb.nullConnection = False + kb.heavilyDynamic = False + kb.skipSeqMatcher = False + kb.errorIsNone = False + kb.negativeLogic = False + kb.pageCompress = False + kb.matchRatio = None + kb.pageTemplate = self._BASELINE + + def test_redirect_short_circuits(self): + kb.choices.redirect = "yes" + self.assertIsNone(checks.checkDynParam(PLACE.GET, "id", "1")) + + def test_dynamic_when_page_differs(self): + # A response wildly different from the baseline drives the real + # comparison() ratio below LOWER_RATIO_BOUND -> queryPage returns False + # (page differs) -> parameter is dynamic. + self._patchQueryPage(self._comparingQuery("totally unrelated content " + "Z" * 200)) + result = checks.checkDynParam(PLACE.GET, "id", "1") + self.assertTrue(result) + self.assertTrue(kb.dynamicParameter) + + def test_not_dynamic_when_page_same(self): + # An identical response yields ratio 1.0 (> UPPER_RATIO_BOUND) from the + # real comparison() -> queryPage returns True (page same) -> not dynamic. + self._patchQueryPage(self._comparingQuery(self._BASELINE)) + result = checks.checkDynParam(PLACE.GET, "id", "1") + self.assertFalse(result) + self.assertFalse(kb.dynamicParameter) + + +class TestCheckDynamicContent(_ChecksTestBase): + def setUp(self): + super(TestCheckDynamicContent, self).setUp() + kb.nullConnection = False + + def test_null_connection_skips(self): + kb.nullConnection = NULLCONNECTION.HEAD + self.assertIsNone(checks.checkDynamicContent("a", "b")) + + def test_missing_page_aborts(self): + self.assertIsNone(checks.checkDynamicContent(None, "x")) + + def test_identical_pages_no_dynamicity(self): + # high ratio -> no dynamic-content engine, no further requests + self._patchQueryPage(lambda *a, **kw: self.fail("should not request")) + self.assertIsNone(checks.checkDynamicContent("identical content", "identical content")) + + +class TestCheckStability(_ChecksTestBase): + def setUp(self): + super(TestCheckStability, self).setUp() + kb.originalPageTime = time.time() + kb.nullConnection = False + + def test_stable_when_pages_match(self): + kb.originalPage = "SAME PAGE" + self._patchQueryPage(self._contentQuery("SAME PAGE")) + self.assertTrue(checks.checkStability()) + self.assertTrue(kb.pageStable) + + def test_redirect_returns_none(self): + kb.originalPage = "SAME PAGE" + self._patchQueryPage(self._contentQuery("SAME PAGE")) + kb.choices.redirect = "yes" + self.assertIsNone(checks.checkStability()) + + def test_unstable_continue_choice(self): + kb.originalPage = "FIRST PAGE CONTENT" + conf.retries = 0 + kb.heavilyDynamic = False + checks.readInput = lambda *a, **kw: "C" + + def _q(*a, **kw): + if kw.get("content"): + return ("SECOND DIFFERENT PAGE", None, 200) + return True # keeps checkDynamicContent's retry loop from firing + self._patchQueryPage(_q) + + result = checks.checkStability() + self.assertFalse(result) + self.assertFalse(kb.pageStable) + + def test_unstable_string_choice_sets_conf_string(self): + kb.originalPage = "FIRST" + self._patchQueryPage(self._contentQuery("SECOND")) + replies = iter(["S", "MATCHME"]) + checks.readInput = lambda *a, **kw: next(replies) + checks.checkStability() + self.assertEqual(conf.string, "MATCHME") + + +class TestCheckNullConnection(_ChecksTestBase): + def setUp(self): + super(TestCheckNullConnection, self).setUp() + conf.data = None + kb.pageCompress = False + kb.nullConnection = None + + def test_post_data_disables_null_connection(self): + conf.data = "a=b" + self.assertFalse(checks.checkNullConnection()) + + def test_head_content_length(self): + def _getPage(*a, **kw): + if kw.get("method") == HTTPMETHOD.HEAD: + return ("", {HTTP_HEADER.CONTENT_LENGTH: "1234"}, 200) + return ("x", {}, 200) + self._patchGetPage(_getPage) + self.assertTrue(checks.checkNullConnection()) + self.assertEqual(kb.nullConnection, NULLCONNECTION.HEAD) + + def test_range_content_range(self): + def _getPage(*a, **kw): + if kw.get("method") == HTTPMETHOD.HEAD: + return ("", {}, 200) # no Content-Length on HEAD + if kw.get("auxHeaders"): + return ("A", {HTTP_HEADER.CONTENT_RANGE: "bytes 0-0/100"}, 206) + return ("x", {}, 200) + self._patchGetPage(_getPage) + self.assertTrue(checks.checkNullConnection()) + self.assertEqual(kb.nullConnection, NULLCONNECTION.RANGE) + + def test_not_supported(self): + # nothing usable on any method -> nullConnection ends up False + self._patchGetPage(lambda *a, **kw: ("xx", {}, 200)) + self.assertFalse(checks.checkNullConnection()) + self.assertFalse(kb.nullConnection) + + +class TestCheckConnection(_ChecksTestBase): + def setUp(self): + super(TestCheckConnection, self).setUp() + conf.hostname = "1.2.3.4" # dotted-quad -> no DNS resolution + conf.string = conf.regexp = None + conf.cj = None + conf.ignoreCode = None + kb.httpErrorCodes = {} + checks.wasLastResponseHTTPError = lambda: False + checks.wasLastResponseDBMSError = lambda: False + td = getCurrentThreadData() + td.lastPage = "PAGE CONTENT" + td.lastCode = 200 + + class _Headers(object): + headers = "Server: test\r\n" + + def test_success_sets_error_is_none(self): + self._patchQueryPage(lambda *a, **kw: ("PAGE CONTENT", self._Headers(), 200)) + self.assertTrue(checks.checkConnection()) + self.assertTrue(kb.errorIsNone) + self.assertEqual(kb.originalPage, "PAGE CONTENT") + + def test_dbms_error_clears_error_is_none(self): + self._patchQueryPage(lambda *a, **kw: ("oops SQL error", self._Headers(), 200)) + checks.wasLastResponseDBMSError = lambda: True + self.assertTrue(checks.checkConnection()) + self.assertFalse(kb.errorIsNone) + + def test_string_not_in_response_still_continues(self): + conf.string = "NEEDLE-NOT-PRESENT" + self._patchQueryPage(lambda *a, **kw: ("haystack only", self._Headers(), 200)) + # warns but carries on (returns True) + self.assertTrue(checks.checkConnection()) + + +class TestCheckWaf(_ChecksTestBase): + def setUp(self): + super(TestCheckWaf, self).setUp() + conf.string = conf.notString = conf.regexp = None + conf.dummy = conf.offline = conf.skipWaf = None + kb.originalCode = 200 + kb.originalPage = "page" + conf.parameters = {PLACE.GET: "id=1"} + kb.resendPostOnRedirect = False + conf.timeout = 30 + kb.identifiedWafs = [] + conf.tamper = None + kb.tamperFunctions = [] + checks.agent.addPayloadDelimiters = lambda v: v + + def test_skips_when_string_set(self): + conf.string = "x" + self.assertIsNone(checks.checkWaf()) + + def test_not_detected_on_high_ratio(self): + # queryPage()[1] is the ratio; high ratio -> not blocked + self._patchQueryPage(lambda *a, **kw: ("ok", 0.9, 200)) + self.assertFalse(checks.checkWaf()) + + def test_detected_on_low_ratio(self): + self._patchQueryPage(lambda *a, **kw: ("blocked", 0.1, 403)) + checks.readInput = lambda *a, **kw: True # continue + accept bypass + import lib.utils.wafbypass as wafbypass + orig = wafbypass.neutralizeFingerprint + wafbypass.neutralizeFingerprint = lambda: None + try: + self.assertTrue(checks.checkWaf()) + finally: + wafbypass.neutralizeFingerprint = orig + + +class TestCheckInternet(_ChecksTestBase): + def test_internet_available(self): + self._patchGetPage(lambda *a, **kw: ("ok", None, checks.CHECK_INTERNET_CODE)) + self.assertTrue(checks.checkInternet()) + + def test_internet_unavailable(self): + self._patchGetPage(lambda *a, **kw: ("captive portal", None, 500)) + self.assertFalse(checks.checkInternet()) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_common_parsers.py b/tests/test_common_parsers.py new file mode 100644 index 000000000..4c2882990 --- /dev/null +++ b/tests/test_common_parsers.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Pure / near-pure parsers and state helpers in lib/core/common.py that are NOT +already exercised by tests/test_common_utils.py. + +Covered here: + * proxy-log parsers reached through parseRequestFile() + (_parseBurpLog plain log, _parseBurpLog Burp XML history, _parseWebScarabLog) + * parseTargetDirect() non-smoke branch (driver resolution for SQLite) + * removeReflectiveValues() reflected-payload masking + * findPageForms() HTML
and inline JS POST discovery + * saveConfig() .ini serialization + * getSQLSnippet() proc-file loading + variable substitution + * checkSystemEncoding() (no-op on a normal default encoding) + * Format.getOs() fingerprint humanizer + * Backend setters/getters (setOs/getOs, setOsVersion, setOsServicePack, + setVersion/getVersion/setVersionList) + * urlencode() extra branches (LIKE percent-encoding, convall, limit, direct) + * safeStringFormat() extra branches (PAYLOAD_DELIMITER region, scalar percent) + +Everything is run in isolation (no network, no DBMS). Any function that +reads/writes global conf/kb/Backend state has that state saved and restored +around the call so test ordering stays irrelevant. Temp files go to the +session scratchpad and are removed. +""" + +import os +import sys +import base64 +import tempfile +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.core.common import ( + parseRequestFile, + parseTargetDirect, + removeReflectiveValues, + findPageForms, + saveConfig, + getSQLSnippet, + checkSystemEncoding, + urlencode, + safeStringFormat, + Format, + Backend, +) +from lib.core.data import kb, conf +from lib.core.enums import DBMS, HTTPMETHOD +from lib.core.settings import REFLECTED_VALUE_MARKER, PAYLOAD_DELIMITER + +SCRATCH = "/tmp/claude-1000/-tmp-tmp-oUnlQJzlQN/fcd55d25-6313-49ed-817e-dcbe7fc2bf22/scratchpad" + + +def _write_temp(content, suffix): + """Write `content` (str) to a scratchpad temp file, return its path.""" + if not os.path.isdir(SCRATCH): + os.makedirs(SCRATCH) + handle, path = tempfile.mkstemp(suffix=suffix, dir=SCRATCH) + os.write(handle, content.encode("utf-8") if isinstance(content, str) else content) + os.close(handle) + return path + + +class TestParseRequestFileBurp(unittest.TestCase): + """_parseBurpLog via parseRequestFile (plain '=====' log + Burp XML history).""" + + def setUp(self): + self._scope = conf.scope + self._method = conf.method + self._headers = conf.headers + conf.scope = None + + def tearDown(self): + conf.scope = self._scope + conf.method = self._method + conf.headers = self._headers + + def test_plain_burp_log_get(self): + content = ( + "======================================================\n" + "GET http://www.target.com:80/vuln.php?id=1 HTTP/1.1\n" + "Host: www.target.com\n" + "Cookie: PHPSESSID=abc\n" + "======================================================\n" + ) + path = _write_temp(content, ".log") + try: + targets = list(parseRequestFile(path)) + finally: + os.unlink(path) + + self.assertEqual(len(targets), 1) + url, method, data, cookie, headers = targets[0] + self.assertEqual(url, "http://www.target.com:80/vuln.php?id=1") + self.assertEqual(method, HTTPMETHOD.GET) + self.assertIsNone(data) + self.assertEqual(cookie, "PHPSESSID=abc") + self.assertIn(("Host", "www.target.com"), headers) + + def test_burp_xml_history_base64_request(self): + req = "GET /vuln.php?id=1 HTTP/1.1\r\nHost: www.target.com\r\nCookie: SID=xyz\r\n\r\n" + b64 = base64.b64encode(req.encode()).decode() + xml = ('80' + '' + '' % b64) + path = _write_temp(xml, ".xml") + try: + targets = list(parseRequestFile(path)) + finally: + os.unlink(path) + + self.assertEqual(len(targets), 1) + url, method, data, cookie, headers = targets[0] + self.assertEqual(url, "http://www.target.com:80/vuln.php?id=1") + self.assertEqual(method, HTTPMETHOD.GET) + self.assertEqual(cookie, "SID=xyz") + + def test_post_body_captured(self): + content = ( + "======================================================\n" + "POST http://www.target.com:80/login HTTP/1.1\n" + "Host: www.target.com\n" + "Content-Length: 17\n" + "\n" + "user=admin&pw=1\n" + "======================================================\n" + ) + path = _write_temp(content, ".log") + try: + targets = list(parseRequestFile(path)) + finally: + os.unlink(path) + + self.assertEqual(len(targets), 1) + url, method, data, cookie, headers = targets[0] + self.assertEqual(method, HTTPMETHOD.POST) + self.assertEqual(data, "user=admin&pw=1") + + def test_scope_filters_out_nonmatching(self): + content = ( + "======================================================\n" + "GET http://www.target.com:80/vuln.php?id=1 HTTP/1.1\n" + "Host: www.target.com\n" + "======================================================\n" + ) + path = _write_temp(content, ".log") + try: + conf.scope = r"example\.org" # does not match target.com + targets = list(parseRequestFile(path)) + finally: + os.unlink(path) + self.assertEqual(targets, []) + + +class TestParseRequestFileWebScarab(unittest.TestCase): + """_parseWebScarabLog via parseRequestFile.""" + + def setUp(self): + self._scope = conf.scope + conf.scope = None + + def tearDown(self): + conf.scope = self._scope + + def test_get_conversation(self): + content = ( + "### Conversation : 1\n" + "URL: http://www.target.com/vuln.php?id=1\n" + "METHOD: GET\n" + "COOKIE: SID=abc\n" + ) + path = _write_temp(content, ".log") + try: + targets = list(parseRequestFile(path)) + finally: + os.unlink(path) + + self.assertEqual(len(targets), 1) + url, method, data, cookie, headers = targets[0] + self.assertEqual(url, "http://www.target.com/vuln.php?id=1") + self.assertEqual(method, "GET") + self.assertIsNone(data) + self.assertEqual(cookie, "SID=abc") + self.assertEqual(headers, tuple()) + + def test_post_conversation_skipped(self): + # POST bodies live in separate files -> WebScarab POSTs are skipped + content = ( + "### Conversation : 1\n" + "URL: http://www.target.com/login\n" + "METHOD: POST\n" + ) + path = _write_temp(content, ".log") + try: + targets = list(parseRequestFile(path)) + finally: + os.unlink(path) + self.assertEqual(targets, []) + + +class TestParseTargetDirectNonSmoke(unittest.TestCase): + """parseTargetDirect() non-smoke branch: resolves the canonical DBMS name. + + Uses SQLite because its driver (stdlib sqlite3) is always importable. + """ + + _KEYS = ("direct", "dbms", "dbmsUser", "dbmsPass", "dbmsDb", "hostname", "port") + + def setUp(self): + self._saved = {k: conf.get(k) for k in self._KEYS} + self._smoke = kb.smokeMode + self._params_none = conf.parameters.get(None) + + def tearDown(self): + for k, v in self._saved.items(): + conf[k] = v + kb.smokeMode = self._smoke + if self._params_none is None: + conf.parameters.pop(None, None) + else: + conf.parameters[None] = self._params_none + + def test_sqlite_local_dsn(self): + kb.smokeMode = False + conf.direct = "sqlite://%s" % os.path.join(SCRATCH, "test.db") + parseTargetDirect() + # non-smoke path canonicalizes the DBMS name via DBMS_DICT + self.assertEqual(conf.dbms, DBMS.SQLITE) + # local file DBMS: hostname forced to localhost, port 0 + self.assertEqual(conf.hostname, "localhost") + self.assertEqual(conf.port, 0) + self.assertEqual(conf.parameters[None], "direct connection") + + +class TestRemoveReflectiveValues(unittest.TestCase): + def setUp(self): + self._mech = kb.reflectiveMechanism + self._heur = kb.heuristicMode + kb.reflectiveMechanism = True + kb.heuristicMode = False + + def tearDown(self): + kb.reflectiveMechanism = self._mech + kb.heuristicMode = self._heur + + def test_reflected_payload_masked(self): + content = u"You searched for 1 AND 1=2 here" + out = removeReflectiveValues(content, "1 AND 1=2") + self.assertIn(REFLECTED_VALUE_MARKER, out) + self.assertNotIn("AND 1=2", out) + + def test_no_reflection_returns_content_unchanged(self): + content = u"nothing interesting" + out = removeReflectiveValues(content, "1 AND 1=2") + self.assertEqual(out, content) + + def test_none_payload_returns_content(self): + content = u"x" + self.assertEqual(removeReflectiveValues(content, None), content) + + def test_bytes_content_returned_as_is(self): + # non-text content short-circuits (isinstance text_type check) + content = b"1 AND 1=2" + self.assertEqual(removeReflectiveValues(content, "1 AND 1=2"), content) + + +class TestFindPageForms(unittest.TestCase): + def setUp(self): + self._scope = conf.scope + self._crawlExclude = conf.crawlExclude + self._cookie = conf.cookie + conf.scope = None + conf.crawlExclude = None + conf.cookie = None + + def tearDown(self): + conf.scope = self._scope + conf.crawlExclude = self._crawlExclude + conf.cookie = self._cookie + + def test_post_form_discovered(self): + html = ('' + '' + '
') + forms = findPageForms(html, "http://www.site.com") + self.assertEqual(forms, set([("http://www.site.com/input.php", "POST", "id=1", None, None)])) + + def test_get_form_discovered(self): + html = ('
' + '' + '
') + forms = findPageForms(html, "http://www.site.com") + self.assertEqual(len(forms), 1) + url, method, data, _cookie, _ = list(forms)[0] + self.assertEqual(method, "GET") + self.assertIn("q=x", url) + + def test_inline_js_post_discovered(self): + # the `.post('url', {k: v})` regex branch (independent of HTML form parsing) + html = "" + forms = findPageForms(html, "http://www.site.com") + self.assertTrue(any(m == HTTPMETHOD.POST and u.endswith("/api/save") for (u, m, d, c, e) in forms)) + + def test_blank_content_returns_empty_set(self): + self.assertEqual(findPageForms("", "http://www.site.com"), set()) + + +class TestSaveConfig(unittest.TestCase): + def test_writes_ini_with_sections(self): + path = _write_temp("", ".ini") + try: + saveConfig(conf, path) + with open(path) as f: + data = f.read() + finally: + os.unlink(path) + + # optDict families become [Section] headers + self.assertIn("[Target]", data) + self.assertIn("[Request]", data) + self.assertIn("[Enumeration]", data) + self.assertTrue(len(data) > 0) + + +class TestGetSQLSnippet(unittest.TestCase): + def test_mssql_proc_loaded(self): + snippet = getSQLSnippet(DBMS.MSSQL, "activate_sp_oacreate") + self.assertIn("RECONFIGURE", snippet) + + def test_variable_substitution(self): + # %VAR% placeholders are substituted from kwargs (here %ENABLE%); + # supplying it avoids the interactive "provide substitution values" prompt. + snippet = getSQLSnippet(DBMS.MSSQL, "configure_xp_cmdshell", ENABLE="1") + self.assertIn("xp_cmdshell", snippet) + self.assertIn("RECONFIGURE", snippet) + # comments (#...) are stripped and the placeholder is fully resolved + self.assertNotIn("#", snippet) + self.assertNotIn("%ENABLE%", snippet) + + +class TestCheckSystemEncoding(unittest.TestCase): + def test_noop_on_normal_encoding(self): + # On a normal default encoding this is a no-op and must not raise. + self.assertIsNone(checkSystemEncoding()) + + +class TestFormatGetOs(unittest.TestCase): + def setUp(self): + self._api = conf.api + conf.api = False + + def tearDown(self): + conf.api = self._api + + def test_humanizes_type_and_technology(self): + info = { + "type": set(["Linux"]), + "distrib": set(["Ubuntu"]), + "release": set(["8.10"]), + "technology": set(["PHP 5.2.6", "Apache 2.2.9"]), + } + out = Format.getOs("back-end DBMS", info) + self.assertTrue(out.startswith("back-end DBMS operating system: Linux")) + self.assertIn("Ubuntu", out) + self.assertIn("8.10", out) + self.assertIn("web application technology:", out) + + def test_api_mode_returns_dict(self): + orig = conf.api + try: + conf.api = True + info = {"type": set(["Windows"]), "technology": set(["IIS"])} + out = Format.getOs("back-end DBMS", info) + self.assertIsInstance(out, dict) + self.assertIn("web application technology", out) + finally: + conf.api = orig + + +class TestBackendSetters(unittest.TestCase): + """Backend OS/version setters write kb state; save and restore it.""" + + _KEYS = ("os", "osVersion", "osSP", "dbmsVersion") + + def setUp(self): + self._saved = {k: kb.get(k) for k in self._KEYS} + + def tearDown(self): + for k, v in self._saved.items(): + kb[k] = v + + def test_set_get_os(self): + kb.os = None + self.assertEqual(Backend.setOs("windows"), "Windows") # capitalized + self.assertEqual(Backend.getOs(), "Windows") + + def test_set_os_none_returns_none(self): + self.assertIsNone(Backend.setOs(None)) + + def test_set_os_version(self): + kb.osVersion = None + Backend.setOsVersion("2008") + self.assertEqual(Backend.getOsVersion(), "2008") + + def test_set_os_service_pack(self): + kb.osSP = None + Backend.setOsServicePack(3) + self.assertEqual(Backend.getOsServicePack(), 3) + + def test_set_get_version(self): + kb.dbmsVersion = [] + self.assertEqual(Backend.setVersion("5.7"), ["5.7"]) + self.assertEqual(Backend.getVersion(), "5.7") + + def test_set_version_list(self): + kb.dbmsVersion = [] + Backend.setVersionList(["8.0", "8.1"]) + self.assertEqual(Backend.getVersionList(), ["8.0", "8.1"]) + + +class TestUrlencodeExtraBranches(unittest.TestCase): + def test_like_percent_encoded(self): + # '%' inside a LIKE '...' literal is encoded to %25 + self.assertEqual(urlencode("AND name LIKE '%DBA%'"), + "AND%20name%20LIKE%20%27%25DBA%25%27") + + def test_convall_drops_safe_set(self): + self.assertEqual(urlencode("a&b", convall=True), "a%26b") + + def test_limit_does_not_crash_on_long_input(self): + out = urlencode("x " * 4000, limit=True) + self.assertTrue(len(out) > 0) + + def test_direct_mode_returns_value_unchanged(self): + orig = conf.direct + try: + conf.direct = "mysql://u:p@h:3306/d" + self.assertEqual(urlencode("a b"), "a b") + finally: + conf.direct = orig + + +class TestSafeStringFormatExtraBranches(unittest.TestCase): + def test_percent_d_in_payload_region_becomes_string(self): + fmt = "SELECT %s" + PAYLOAD_DELIMITER + " AND %d " + PAYLOAD_DELIMITER + self.assertEqual( + safeStringFormat(fmt, ("a", "5")), + "SELECT a" + PAYLOAD_DELIMITER + " AND 5 " + PAYLOAD_DELIMITER) + + def test_scalar_string_percent_preserved(self): + # single-string param path: plain replace, embedded '%' survives + self.assertEqual(safeStringFormat("LIKE %s", "100%done"), "LIKE 100%done") + + def test_two_params_list(self): + self.assertEqual(safeStringFormat("%s/%s", ("a", "b")), "a/b") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_common_utils.py b/tests/test_common_utils.py new file mode 100644 index 000000000..9faa815f7 --- /dev/null +++ b/tests/test_common_utils.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Pure / near-pure helpers in lib/core/common.py. + +These cover the request/parameter parsing, charset construction, limit-range +generation, safe string formatting, URL encoding, UNION page parsing, target +URL/direct-connection parsing and SQL identifier quoting. They are exercised +in isolation (no network, no DBMS, no filesystem mutation); any function that +reads/writes global conf/kb state has that state saved and restored around the +call so test ordering stays irrelevant. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.common import ( + paramToDict, + getCharset, + getLimitRange, + parseUnionPage, + safeStringFormat, + urlencode, + parseTargetUrl, + parseTargetDirect, + safeSQLIdentificatorNaming, + getPartRun, + getText, +) +from lib.core.data import kb, conf +from lib.core.enums import PLACE, CHARSET_TYPE, DBMS + + +class TestParamToDict(unittest.TestCase): + """Parameter string -> OrderedDict for the various injection places.""" + + def test_get_two_params(self): + result = paramToDict(PLACE.GET, "id=1&name=foo") + self.assertEqual(list(result.items()), [("id", "1"), ("name", "foo")]) + + def test_get_preserves_order(self): + result = paramToDict(PLACE.GET, "c=3&a=1&b=2") + self.assertEqual(list(result.keys()), ["c", "a", "b"]) + + def test_post_place(self): + result = paramToDict(PLACE.POST, "user=admin&pass=secret") + self.assertEqual(result["user"], "admin") + self.assertEqual(result["pass"], "secret") + + def test_empty_value(self): + result = paramToDict(PLACE.GET, "id=&name=x") + self.assertEqual(result["id"], "") + self.assertEqual(result["name"], "x") + + def test_value_with_equal_signs(self): + # value is re-joined on '=' so embedded '=' survives + result = paramToDict(PLACE.GET, "token=a=b=c") + self.assertEqual(result["token"], "a=b=c") + + def test_cookie_delimiter(self): + # COOKIE place splits on ';' rather than '&' + result = paramToDict(PLACE.COOKIE, "foo=bar;baz=qux") + self.assertEqual(list(result.items()), [("foo", "bar"), ("baz", "qux")]) + + def test_param_without_equals_ignored(self): + # an element with no '=' has len(parts) < 2 and is skipped + result = paramToDict(PLACE.GET, "lonely&id=1") + self.assertEqual(list(result.items()), [("id", "1")]) + + +class TestGetCharset(unittest.TestCase): + """Inference charsets are fixed integer tables.""" + + def test_binary(self): + self.assertEqual(getCharset(CHARSET_TYPE.BINARY), [0, 1, 47, 48, 49]) + + def test_default_is_full_ascii(self): + self.assertEqual(getCharset(None), list(range(0, 128))) + + def test_digits(self): + result = getCharset(CHARSET_TYPE.DIGITS) + self.assertEqual(result, list(range(0, 10)) + list(range(47, 58))) + + def test_alpha_has_no_digits(self): + result = getCharset(CHARSET_TYPE.ALPHA) + # ASCII codes for '0'..'9' are 48..57; ALPHA must exclude them + self.assertFalse(any(48 <= _ <= 57 for _ in result)) + self.assertIn(ord("A"), result) + self.assertIn(ord("z"), result) + + def test_alphanum_superset_of_alpha(self): + alpha = set(getCharset(CHARSET_TYPE.ALPHA)) + alphanum = set(getCharset(CHARSET_TYPE.ALPHANUM)) + self.assertTrue(alpha.issubset(alphanum)) + self.assertIn(ord("5"), alphanum) + + def test_hexadecimal_contains_hex_letters(self): + result = getCharset(CHARSET_TYPE.HEXADECIMAL) + for ch in "0123456789abcdefABCDEF": + self.assertIn(ord(ch), result, msg="missing %r" % ch) + + +class TestGetLimitRange(unittest.TestCase): + def test_basic(self): + self.assertEqual(list(getLimitRange(10)), list(range(0, 10))) + + def test_plus_one(self): + self.assertEqual(list(getLimitRange(3, plusOne=True)), [1, 2, 3]) + + def test_string_count_coerced(self): + # count is int()-coerced internally + self.assertEqual(list(getLimitRange("4")), [0, 1, 2, 3]) + + def test_length(self): + self.assertEqual(len(getLimitRange(7)), 7) + + +class TestParseUnionPage(unittest.TestCase): + def test_none(self): + self.assertIsNone(parseUnionPage(None)) + + def test_two_entries(self): + page = "%sfoo%s%sbar%s" % (kb.chars.start, kb.chars.stop, kb.chars.start, kb.chars.stop) + # returns a BigArray; compare element-wise + self.assertEqual(list(parseUnionPage(page)), ["foo", "bar"]) + + def test_single_entry_unwrapped(self): + # a lone wrapped string is returned as the bare string, not a 1-element list + page = "%shello%s" % (kb.chars.start, kb.chars.stop) + self.assertEqual(parseUnionPage(page), "hello") + + def test_multi_column_row(self): + # a single row whose values are joined by kb.chars.delimiter becomes one + # nested list entry + page = "%sa%sb%s" % (kb.chars.start, kb.chars.delimiter, kb.chars.stop) + self.assertEqual(list(parseUnionPage(page)), [["a", "b"]]) + + def test_unmarked_page_returned_verbatim(self): + self.assertEqual(parseUnionPage("no markers here"), "no markers here") + + +class TestSafeStringFormat(unittest.TestCase): + def test_basic_tuple(self): + self.assertEqual(safeStringFormat("SELECT foo FROM %s LIMIT %d", ("bar", "1")), + "SELECT foo FROM bar LIMIT 1") + + def test_literal_percent_preserved(self): + self.assertEqual( + safeStringFormat("SELECT foo FROM %s WHERE name LIKE '%susan%' LIMIT %d", ("bar", "1")), + "SELECT foo FROM bar WHERE name LIKE '%susan%' LIMIT 1") + + def test_single_string_param(self): + self.assertEqual(safeStringFormat("a %s b", "X"), "a X b") + + def test_scalar_non_string(self): + self.assertEqual(safeStringFormat("n=%d", 5), "n=5") + + +class TestUrlencode(unittest.TestCase): + def test_basic(self): + self.assertEqual(urlencode("AND 1>(2+3)#"), "AND%201%3E%282%2B3%29%23") + + def test_none(self): + self.assertIsNone(urlencode(None)) + + def test_spaceplus(self): + self.assertEqual(urlencode("a b", spaceplus=True), "a+b") + + def test_convall_encodes_safe_chars(self): + # with convall the explicit 'safe' set is dropped, so '/' gets encoded + self.assertEqual(urlencode("a/b", convall=True), "a%2Fb") + + def test_safe_char_default_kept(self): + # by default '-' and '_' are in the safe set + self.assertEqual(urlencode("a-b_c"), "a-b_c") + + +class TestParseTargetUrl(unittest.TestCase): + """parseTargetUrl mutates conf.* in place; save and restore everything touched.""" + + def _save(self): + return {k: conf.get(k) for k in + ("url", "scheme", "path", "hostname", "port", "ipv6")} + + def _restore(self, saved): + for k, v in saved.items(): + conf[k] = v + + def test_https_url(self): + saved = self._save() + orig_params = conf.parameters.get(PLACE.GET) + try: + conf.url = "https://www.test.com/?id=1" + parseTargetUrl() + self.assertEqual(conf.hostname, "www.test.com") + self.assertEqual(conf.scheme, "https") + self.assertEqual(conf.port, 443) + self.assertEqual(conf.parameters[PLACE.GET], "id=1") + finally: + self._restore(saved) + if orig_params is None: + conf.parameters.pop(PLACE.GET, None) + else: + conf.parameters[PLACE.GET] = orig_params + + def test_scheme_defaulted_and_port(self): + saved = self._save() + try: + conf.url = "example.org:8080/app" + parseTargetUrl() + self.assertEqual(conf.hostname, "example.org") + self.assertEqual(conf.scheme, "http") + self.assertEqual(conf.port, 8080) + finally: + self._restore(saved) + + def test_empty_url_returns_none(self): + saved = self._save() + try: + conf.url = "" + self.assertIsNone(parseTargetUrl()) + finally: + self._restore(saved) + + +class TestParseTargetDirect(unittest.TestCase): + """parseTargetDirect under smokeMode (early-returns before driver imports).""" + + def _save(self): + return {k: conf.get(k) for k in + ("direct", "dbms", "dbmsUser", "dbmsPass", "dbmsDb", "hostname", "port")} + + def _restore(self, saved): + for k, v in saved.items(): + conf[k] = v + + def test_full_mysql_dsn(self): + saved = self._save() + orig_smoke = kb.smokeMode + orig_none = conf.parameters.get(None) + try: + kb.smokeMode = True + conf.direct = "mysql://root:testpass@127.0.0.1:3306/testdb" + parseTargetDirect() + self.assertEqual(conf.dbms, "mysql") + self.assertEqual(conf.dbmsUser, "root") + self.assertEqual(conf.dbmsPass, "testpass") + self.assertEqual(conf.dbmsDb, "testdb") + self.assertEqual(conf.hostname, "127.0.0.1") + self.assertEqual(conf.port, 3306) + finally: + self._restore(saved) + kb.smokeMode = orig_smoke + if orig_none is None: + conf.parameters.pop(None, None) + else: + conf.parameters[None] = orig_none + + def test_quoted_password(self): + saved = self._save() + orig_smoke = kb.smokeMode + orig_none = conf.parameters.get(None) + try: + kb.smokeMode = True + conf.direct = "mysql://user:'P@ssw0rd'@127.0.0.1:3306/test" + parseTargetDirect() + self.assertEqual(conf.dbmsPass, "P@ssw0rd") + self.assertEqual(conf.hostname, "127.0.0.1") + finally: + self._restore(saved) + kb.smokeMode = orig_smoke + if orig_none is None: + conf.parameters.pop(None, None) + else: + conf.parameters[None] = orig_none + + def test_empty_direct_returns_none(self): + saved = self._save() + try: + conf.direct = None + self.assertIsNone(parseTargetDirect()) + finally: + self._restore(saved) + + +class TestSafeSQLIdentificatorNaming(unittest.TestCase): + """Quoting of identifiers is DBMS-specific; drive it via kb.forcedDbms.""" + + def _run(self, dbms, name, **kw): + orig = kb.forcedDbms + try: + kb.forcedDbms = dbms + return getText(safeSQLIdentificatorNaming(name, **kw)) + finally: + kb.forcedDbms = orig + + def test_mssql_keyword_bracketed(self): + self.assertEqual(self._run(DBMS.MSSQL, "begin"), "[begin]") + + def test_plain_name_unquoted(self): + self.assertEqual(self._run(DBMS.MSSQL, "foobar"), "foobar") + + def test_firebird_name_with_space_double_quoted(self): + self.assertEqual(self._run(DBMS.FIREBIRD, "foo bar"), '"foo bar"') + + def test_mysql_keyword_backticked(self): + self.assertEqual(self._run(DBMS.MYSQL, "select"), "`select`") + + def test_oracle_keyword_uppercased(self): + # Oracle quotes AND uppercases reserved words + self.assertEqual(self._run(DBMS.ORACLE, "table"), '"TABLE"') + + def test_unsafe_naming_passthrough(self): + orig = conf.unsafeNaming + try: + conf.unsafeNaming = True + self.assertEqual(self._run(DBMS.MYSQL, "select"), "select") + finally: + conf.unsafeNaming = orig + + +class TestGetPartRun(unittest.TestCase): + def test_no_dbms_handler_in_stack(self): + # called from a test (no conf.dbmsHandler.* on the stack) -> None + self.assertIsNone(getPartRun()) + + def test_non_alias_form_also_none(self): + self.assertIsNone(getPartRun(alias=False)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 000000000..69edf2e7a --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Tests for lib/core/compat.py -- cross-version compatibility utilities, +including WichmannHill RNG, patchHeaders, cmp_to_key, LooseVersion, +MixedWriteTextIO, and _codecs_open. +""" + +import io +import os +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.core.compat import (WichmannHill, patchHeaders, cmp, choose_boundary, + round, cmp_to_key, LooseVersion, _is_write_mode, + MixedWriteTextIO, _codecs_open, codecs_open) +from lib.core.compat import xrange + + +class TestWichmannHill(unittest.TestCase): + def test_seed_and_random(self): + r = WichmannHill(42) + self.assertIsInstance(r.random(), float) + self.assertGreaterEqual(r.random(), 0.0) + self.assertLess(r.random(), 1.0) + + def test_deterministic_seed(self): + r1 = WichmannHill(123) + r2 = WichmannHill(123) + # First random numbers should match + self.assertEqual([r1.random() for _ in range(10)], + [r2.random() for _ in range(10)]) + + def test_getstate_setstate(self): + r = WichmannHill(7) + for _ in range(20): + r.random() + state = r.getstate() + saved = [r.random() for _ in range(5)] + r.setstate(state) + self.assertEqual(saved, [r.random() for _ in range(5)]) + + def test_jumpahead(self): + r1 = WichmannHill(99) + r2 = WichmannHill(99) + for _ in range(10): + r1.random() + r2.jumpahead(10) + self.assertEqual(r1.getstate()[1], r2.getstate()[1]) + + def test_jumpahead_negative_raises(self): + r = WichmannHill() + with self.assertRaises(ValueError): + r.jumpahead(-1) + + def test_whseed(self): + # a fixed integer whseed must be deterministic across instances ... + r1 = WichmannHill() + r1.whseed(12345) + r2 = WichmannHill() + r2.whseed(12345) + self.assertEqual([r1.random() for _ in range(10)], + [r2.random() for _ in range(10)]) + # ... and pin the known sequence (hash(int) == int, so stable across processes) + r3 = WichmannHill() + r3.whseed(12345) + self.assertEqual([round(r3.random(), 6) for _ in range(3)], + [0.600031, 0.872148, 0.039151]) + + def test_whseed_none(self): + r = WichmannHill() + r.whseed() # seeds from current time; must not raise + # the time-derived seed must still drive a valid in-range sequence. (Non-determinism is NOT + # asserted here: __whseed() derives its seed from int(time.time()*256) masked to 24 bits, so + # two back-to-back instances legitimately collide - that would be a timing-fragile test. The + # os.urandom-backed seed() None path IS asserted non-deterministic in test_seed_none.) + seq = [r.random() for _ in range(10)] + self.assertTrue(all(isinstance(x, float) and 0.0 <= x < 1.0 for x in seq)) + # the seed must actually advance the generator (not stuck on a constant) + self.assertGreater(len(set(seq)), 1) + + def test_seed_none(self): + r = WichmannHill() + r.seed() # seeds from os.urandom/time; must not raise + seq = [r.random() for _ in range(10)] + self.assertTrue(all(isinstance(x, float) and 0.0 <= x < 1.0 for x in seq)) + other = WichmannHill() + other.seed() + self.assertNotEqual(seq, [other.random() for _ in range(10)]) + + def test_seed_hashable(self): + # a non-int hashable seed goes through hash(a); two instances seeded with the same + # object in the same process must produce the same sequence (determinism). The literal + # values are NOT pinned because hash() of a str is randomized per process. + r1 = WichmannHill("a_string_seed") + r2 = WichmannHill("a_string_seed") + seq = [r1.random() for _ in range(10)] + self.assertEqual(seq, [r2.random() for _ in range(10)]) + self.assertTrue(all(0.0 <= x < 1.0 for x in seq)) + # a different seed must yield a different sequence + r3 = WichmannHill("different_seed") + self.assertNotEqual(seq, [r3.random() for _ in range(10)]) + + def test_setstate_bad_version(self): + r = WichmannHill() + with self.assertRaises(ValueError): + r.setstate((999, (1, 1, 1), None)) + + +class TestPatchHeaders(unittest.TestCase): + def test_patches_dict_to_header_obj(self): + h = patchHeaders({"Host": "example.com", "Content-Type": "text/html"}) + self.assertEqual(h["host"], "example.com") + self.assertEqual(h["content-type"], "text/html") + self.assertEqual(h.get("HOST"), "example.com") + self.assertIsNone(h.get("missing")) + self.assertIsNotNone(h.headers) + self.assertTrue(any("Host: example.com" in _ for _ in h.headers)) + + def test_passthrough_none(self): + self.assertIsNone(patchHeaders(None)) + + def test_passthrough_existing_headers_attr(self): + d = {"A": "1"} + d["headers"] = [] + result = patchHeaders(d) + self.assertEqual(result, d) # unchanged + + +class TestCmp(unittest.TestCase): + def test_less(self): + self.assertEqual(cmp("a", "b"), -1) + + def test_greater(self): + self.assertEqual(cmp(2, 1), 1) + + def test_equal(self): + self.assertEqual(cmp(5, 5), 0) + + +class TestRound(unittest.TestCase): + def test_positive(self): + self.assertEqual(round(2.0), 2.0) + self.assertEqual(round(2.5), 3.0) + self.assertEqual(round(2.499), 2.0) + + def test_negative(self): + self.assertEqual(round(-2.5), -3.0) + self.assertEqual(round(-2.0), -2.0) + + def test_with_decimals(self): + self.assertAlmostEqual(round(2.567, d=2), 2.57) + + +class TestCmpToKey(unittest.TestCase): + def test_sort_with_cmp(self): + items = [3, 1, 4, 1, 5] + key_func = cmp_to_key(lambda a, b: (a > b) - (a < b)) + self.assertEqual(sorted(items, key=key_func), [1, 1, 3, 4, 5]) + + def test_reverse_sort(self): + items = [3, 1, 2] + key_func = cmp_to_key(lambda a, b: (b > a) - (b < a)) + self.assertEqual(sorted(items, key=key_func), [3, 2, 1]) + + def test_hash_raises(self): + k = cmp_to_key(lambda a, b: 0)(5) + with self.assertRaises(TypeError): + hash(k) + + +class TestLooseVersion(unittest.TestCase): + def test_basic(self): + self.assertEqual(LooseVersion("1.0"), (1, 0)) + self.assertEqual(LooseVersion("1.0.1"), (1, 0, 1)) + + def test_comparison(self): + self.assertTrue(LooseVersion("1.0.1") > LooseVersion("1.0")) + self.assertTrue(LooseVersion("8.0.22") > LooseVersion("8.0.2")) + + def test_no_digits(self): + self.assertEqual(LooseVersion("alpha"), ()) + self.assertEqual(LooseVersion(""), ()) + self.assertEqual(LooseVersion(None), ()) + + def test_with_suffix(self): + self.assertEqual(LooseVersion("1.0alpha"), (1, 0)) + self.assertEqual(LooseVersion("10.5.3-beta"), (10, 5, 3)) + + +class TestIsWriteMode(unittest.TestCase): + def test_write_modes(self): + for mode in ("w", "a", "x", "w+", "a+", "x+", "w+b", "ab"): + self.assertTrue(_is_write_mode(mode), msg="mode %r" % mode) + + def test_read_modes(self): + for mode in ("r", "rb", ""): + self.assertFalse(_is_write_mode(mode), msg="mode %r" % mode) + + +class TestMixedWriteTextIO(unittest.TestCase): + def test_text_write(self): + buf = io.StringIO() + w = MixedWriteTextIO(buf, "utf-8", "strict") + w.write(u"hello") + self.assertEqual(buf.getvalue(), "hello") + + def test_bytes_write_decodes(self): + buf = io.StringIO() + w = MixedWriteTextIO(buf, "utf-8", "strict") + w.write(b"world") + self.assertEqual(buf.getvalue(), "world") + + def test_writelines(self): + buf = io.StringIO() + w = MixedWriteTextIO(buf, "utf-8", "strict") + w.writelines([u"a", u"b", u"c"]) + self.assertEqual(buf.getvalue(), "abc") + + def test_iterator(self): + buf = io.StringIO(u"line1\nline2\n") + w = MixedWriteTextIO(buf, "utf-8", "strict") + self.assertEqual(list(w), ["line1\n", "line2\n"]) + + def test_enter_exit(self): + buf = io.StringIO() + w = MixedWriteTextIO(buf, "utf-8", "strict") + with w as f: + f.write(u"test") + self.assertTrue(buf.closed) + + +class TestCodecsOpen(unittest.TestCase): + def test_no_encoding_returns_io_open(self): + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) + tmp.close() + try: + f = _codecs_open(tmp.name, "w", encoding=None) + f.write(u"test") + f.close() + with open(tmp.name) as fh: + self.assertIn("test", fh.read()) + finally: + os.unlink(tmp.name) + + def test_with_encoding(self): + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) + tmp.close() + try: + f = _codecs_open(tmp.name, "w", encoding="utf-8") + f.write(u"caf\xe9") + f.close() + with open(tmp.name, "rb") as fh: + self.assertIn(b"caf\xc3\xa9", fh.read()) + finally: + os.unlink(tmp.name) + + def test_with_encoding_and_bytes(self): + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) + tmp.close() + try: + f = _codecs_open(tmp.name, "w", encoding="utf-8") + # MixedWriteTextIO should accept bytes too + f.write(b"bytes_input") + f.close() + with open(tmp.name) as fh: + self.assertIn("bytes_input", fh.read()) + finally: + os.unlink(tmp.name) + + +class TestChooseBoundary(unittest.TestCase): + def test_length(self): + self.assertEqual(len(choose_boundary()), 32) + + def test_hex_chars(self): + b = choose_boundary() + self.assertTrue(all(c in "0123456789abcdef" for c in b)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_core_extra.py b/tests/test_core_extra.py new file mode 100644 index 000000000..5c1a5a282 --- /dev/null +++ b/tests/test_core_extra.py @@ -0,0 +1,676 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Additional REAL unit coverage for genuinely-uncovered PURE functions in: + + * lib/core/common.py + * lib/core/option.py + * lib/core/agent.py + * lib/request/basic.py + +Every test asserts a concrete, independently-reasoned known-correct value that +would FAIL if the function under test regressed. No isinstance-only checks, no +tautologies, no swallowed exceptions. + +Functions targeted here are deliberately DIFFERENT from those already exercised +by tests/test_common_utils.py, test_common_parsers.py, test_core_more.py, +test_core_final.py, test_option_setup.py, test_option_more.py, test_agent.py, +test_agent_dialects.py, test_decodepage.py and test_charset.py. + +stdlib unittest only (no pytest / no pip); works on Python 2.7 and 3.x. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from tests._testutils import bootstrap, set_dbms + +bootstrap() + +from lib.core.data import conf, kb +from lib.core.defaults import defaults +from lib.core.common import Backend +from lib.core.enums import DBMS + + +class TestCommonStringHelpers(unittest.TestCase): + """Small pure string/list/regex/encoding helpers in lib/core/common.py.""" + + def test_posix_to_nt_slashes(self): + from lib.core.common import posixToNtSlashes + self.assertEqual(posixToNtSlashes("C:/Windows"), "C:\\Windows") + self.assertEqual(posixToNtSlashes("a/b/c"), "a\\b\\c") + # falsy input returned unchanged + self.assertEqual(posixToNtSlashes(""), "") + self.assertIsNone(posixToNtSlashes(None)) + + def test_nt_to_posix_slashes(self): + from lib.core.common import ntToPosixSlashes + self.assertEqual(ntToPosixSlashes("C:\\Windows"), "C:/Windows") + self.assertEqual(ntToPosixSlashes("a\\b\\c"), "a/b/c") + self.assertEqual(ntToPosixSlashes(""), "") + + def test_is_hex_encoded_string(self): + from lib.core.common import isHexEncodedString + self.assertTrue(isHexEncodedString("DEADBEEF")) + self.assertTrue(isHexEncodedString("0x1234")) # 'x' is allowed by the regex + self.assertFalse(isHexEncodedString("test")) + self.assertFalse(isHexEncodedString("12 34")) # space breaks it + + def test_is_digit(self): + from lib.core.common import isDigit + self.assertTrue(isDigit("123456")) + self.assertFalse(isDigit("3b3")) + self.assertFalse(isDigit(u"\xb2")) # superscript-2: str.isdigit() True, isDigit False + self.assertFalse(isDigit("")) # empty -> no match + self.assertFalse(isDigit(None)) + + def test_sanitize_str(self): + from lib.core.common import sanitizeStr + self.assertEqual(sanitizeStr("foo\n\rbar"), "foo bar") + self.assertEqual(sanitizeStr("a\r\nb"), "a b") + self.assertEqual(sanitizeStr(None), "None") + + def test_filter_control_chars(self): + from lib.core.common import filterControlChars + self.assertEqual(filterControlChars("AND 1>(2+3)\n--"), "AND 1>(2+3) --") + # custom replacement character + self.assertEqual(filterControlChars("a\tb", replacement="_"), "a_b") + + def test_normalize_path(self): + from lib.core.common import normalizePath + self.assertEqual(normalizePath("//var///log/apache.log"), "/var/log/apache.log") + self.assertEqual(normalizePath("/a/b/../c"), "/a/c") + + def test_directory_path(self): + from lib.core.common import directoryPath + self.assertEqual(directoryPath("/var/log/apache.log"), "/var/log") + # no extension -> returned unchanged + self.assertEqual(directoryPath("/var/log"), "/var/log") + + def test_longest_common_prefix(self): + from lib.core.common import longestCommonPrefix + self.assertEqual(longestCommonPrefix("foobar", "fobar"), "fo") + self.assertEqual(longestCommonPrefix("abc", "abd", "abe"), "ab") + # single sequence returned verbatim + self.assertEqual(longestCommonPrefix("only"), "only") + + def test_first_not_none(self): + from lib.core.common import firstNotNone + self.assertEqual(firstNotNone(None, None, 1, 2, 3), 1) + self.assertEqual(firstNotNone(None, 0), 0) # 0 is not None + self.assertIsNone(firstNotNone(None, None)) + + def test_decode_string_escape(self): + from lib.core.common import decodeStringEscape + self.assertEqual(decodeStringEscape("a\\tb"), "a\tb") + self.assertEqual(decodeStringEscape("a\\nb"), "a\nb") + # no backslash -> unchanged + self.assertEqual(decodeStringEscape("plain"), "plain") + + def test_encode_string_escape(self): + from lib.core.common import encodeStringEscape + self.assertEqual(encodeStringEscape("a\tb"), "a\\tb") + self.assertEqual(encodeStringEscape("a\nb"), "a\\nb") + self.assertEqual(encodeStringEscape("plain"), "plain") + + def test_decode_encode_string_escape_roundtrip(self): + from lib.core.common import decodeStringEscape, encodeStringEscape + self.assertEqual(decodeStringEscape(encodeStringEscape("x\ty\nz")), "x\ty\nz") + + def test_escape_json_value(self): + from lib.core.common import escapeJsonValue + # newline gets escaped (literal '\n' becomes the two chars backslash+n) + self.assertNotIn("\n", escapeJsonValue("foo\nbar")) + self.assertIn("\\n", escapeJsonValue("foo\nbar")) + # tab gets escaped to '\t' + self.assertIn("\\t", escapeJsonValue("foo\tbar")) + # quote and backslash escaped + self.assertEqual(escapeJsonValue('a"b'), 'a\\"b') + self.assertEqual(escapeJsonValue("a\\b"), "a\\\\b") + # ordinary characters untouched + self.assertEqual(escapeJsonValue("plain text"), "plain text") + + def test_clean_query(self): + from lib.core.common import cleanQuery + self.assertEqual(cleanQuery("select id from users"), "SELECT id FROM users") + # already-uppercase keywords stay; identifiers untouched + self.assertEqual(cleanQuery("SELECT a FROM t"), "SELECT a FROM t") + + def test_json_minimize_canonical(self): + from lib.core.common import jsonMinimize + # key order / whitespace independence + self.assertEqual(jsonMinimize('{"b": 2, "a": 1}'), jsonMinimize('{"a":1, "b":2}')) + # nested leaf path + self.assertEqual(jsonMinimize('{"a": {"b": 1}}'), ".a.b=1") + # empty object + self.assertEqual(jsonMinimize("{}"), "") + # not parseable -> None (and only None) + self.assertIsNone(jsonMinimize("not json")) + + def test_json_minimize_array_length_registers(self): + from lib.core.common import jsonMinimize + # array length change must perturb the projection + self.assertNotEqual(jsonMinimize('{"a": [1, 2]}'), jsonMinimize('{"a": [1, 2, 3]}')) + + def test_list_to_str_value(self): + from lib.core.common import listToStrValue + self.assertEqual(listToStrValue([1, 2, 3]), "1, 2, 3") + # set/tuple/generator normalized via list first + self.assertEqual(listToStrValue((1, 2)), "1, 2") + # non-list passes through + self.assertEqual(listToStrValue("abc"), "abc") + + def test_intersect(self): + from lib.core.common import intersect + self.assertEqual(intersect([1, 2, 3], set([1, 3])), [1, 3]) + # order follows containerA + self.assertEqual(intersect([3, 2, 1], [1, 2]), [2, 1]) + # case-insensitive option + self.assertEqual(intersect(["FOO", "bar"], ["foo"], lowerCase=True), ["foo"]) + + def test_priority_sort_columns(self): + from lib.core.common import prioritySortColumns + # 'id'-containing columns first, then by ascending length + self.assertEqual( + prioritySortColumns(["password", "userid", "name", "id"]), + ["id", "userid", "name", "password"], + ) + + def test_safe_variable_naming(self): + from lib.core.common import safeVariableNaming + self.assertEqual(safeVariableNaming("class.id"), "EVAL_636c6173732e6964") + # plain identifier left untouched + self.assertEqual(safeVariableNaming("foobar"), "foobar") + + def test_unsafe_variable_naming(self): + from lib.core.common import unsafeVariableNaming + self.assertEqual(unsafeVariableNaming("EVAL_636c6173732e6964"), "class.id") + self.assertEqual(unsafeVariableNaming("foobar"), "foobar") + + def test_variable_naming_roundtrip(self): + from lib.core.common import safeVariableNaming, unsafeVariableNaming + self.assertEqual(unsafeVariableNaming(safeVariableNaming("a-b")), "a-b") + + def test_average(self): + from lib.core.common import average + self.assertAlmostEqual(average([0.9, 0.9, 0.9, 1.0, 0.8, 0.9]), 0.9, places=6) + self.assertEqual(average([2, 4]), 3.0) + self.assertIsNone(average([])) + + def test_stdev(self): + from lib.core.common import stdev + self.assertEqual("%.3f" % stdev([0.9, 0.9, 0.9, 1.0, 0.8, 0.9]), "0.063") + # fewer than 2 values -> None + self.assertIsNone(stdev([1.0])) + self.assertIsNone(stdev([])) + + +class TestCommonSafeCompare(unittest.TestCase): + """Constant-time / checksum helpers.""" + + def test_safe_compare_strings(self): + from lib.core.common import safeCompareStrings + self.assertTrue(safeCompareStrings("test", "test")) + self.assertFalse(safeCompareStrings("test1", "test2")) + self.assertFalse(safeCompareStrings("test", None)) + # both None compares equal (a == b path) + self.assertTrue(safeCompareStrings(None, None)) + + def test_safe_cs_value(self): + from lib.core.common import safeCSValue + # ensure deterministic delimiter + old = conf.get("csvDel") + conf.csvDel = defaults.csvDel + try: + self.assertEqual(safeCSValue("foo, bar"), '"foo, bar"') + self.assertEqual(safeCSValue("foobar"), "foobar") + self.assertEqual(safeCSValue("foo\rbar"), '"foo\rbar"') + self.assertEqual(safeCSValue('foo"bar'), '"foo""bar"') + finally: + conf.csvDel = old + + +class TestCommonSafeExString(unittest.TestCase): + def test_sqlmap_exception_message(self): + from lib.core.common import getSafeExString + from lib.core.exception import SqlmapBaseException + self.assertEqual(getSafeExString(SqlmapBaseException("foobar")), "foobar") + + def test_oserror_prefixed_with_type(self): + from lib.core.common import getSafeExString + self.assertEqual(getSafeExString(OSError(0, "foobar")), "OSError: foobar") + + def test_generic_value_error(self): + from lib.core.common import getSafeExString + self.assertEqual(getSafeExString(ValueError("bad input")), "ValueError: bad input") + + +class TestCommonHostHeader(unittest.TestCase): + def test_plain_host(self): + from lib.core.common import getHostHeader + self.assertEqual(getHostHeader("http://www.target.com/vuln.php?id=1"), "www.target.com") + + def test_default_port_stripped(self): + from lib.core.common import getHostHeader + self.assertEqual(getHostHeader("http://www.target.com:80/x"), "www.target.com") + self.assertEqual(getHostHeader("https://www.target.com:443/x"), "www.target.com") + + def test_nondefault_port_kept(self): + from lib.core.common import getHostHeader + self.assertEqual(getHostHeader("http://www.target.com:8080/x"), "www.target.com:8080") + + def test_ipv6_brackets(self): + from lib.core.common import getHostHeader + self.assertEqual(getHostHeader("http://[::1]:8080/vuln.php?id=1"), "[::1]:8080") + self.assertEqual(getHostHeader("http://[::1]/vuln.php?id=1"), "[::1]") + + +class TestCommonCheckSameHost(unittest.TestCase): + def test_same_host(self): + from lib.core.common import checkSameHost + self.assertTrue(checkSameHost( + "http://www.target.com/page1.php?id=1", + "http://www.target.com/images/page2.php", + )) + + def test_different_host(self): + from lib.core.common import checkSameHost + self.assertFalse(checkSameHost( + "http://www.target.com/page1.php?id=1", + "http://www.target2.com/images/page2.php", + )) + + def test_www_prefix_ignored(self): + from lib.core.common import checkSameHost + # leading 'www.' is stripped before comparison + self.assertTrue(checkSameHost("http://www.target.com/a", "http://target.com/b")) + + def test_single_url_true_and_empty_none(self): + from lib.core.common import checkSameHost + self.assertTrue(checkSameHost("http://only.com/a")) + self.assertIsNone(checkSameHost()) + + +class TestCommonUrldecode(unittest.TestCase): + def test_convall_true(self): + from lib.core.common import urldecode + self.assertEqual(urldecode("AND%201%3E%282%2B3%29%23", convall=True), "AND 1>(2+3)#") + + def test_convall_false_keeps_unsafe(self): + from lib.core.common import urldecode + # %2B (plus) is in the default 'unsafe' set so it stays encoded when convall=False + self.assertEqual(urldecode("AND%201%3E%282%2B3%29%23", convall=False), "AND 1>(2%2B3)#") + + def test_bytes_input(self): + from lib.core.common import urldecode + self.assertEqual(urldecode(b"AND%201%3E%282%2B3%29%23", convall=False), "AND 1>(2%2B3)#") + + def test_spaceplus(self): + from lib.core.common import urldecode + # with spaceplus the '+' becomes a space + self.assertEqual(urldecode("a+b", convall=False, spaceplus=True), "a b") + # without spaceplus the '+' stays + self.assertEqual(urldecode("a+b", convall=False, spaceplus=False), "a+b") + + +class TestCommonChunkSplit(unittest.TestCase): + def test_chunk_split_post_data(self): + import random + from lib.core.common import chunkSplitPostData + from lib.core.patch import unisonRandom + # The pinned docstring value is produced under sqlmap's cross-version PRNG; install it + # (then restore the stdlib functions) so the expectation is deterministic here too. + _saved = (random.choice, random.randint, random.sample, random.seed) + unisonRandom() + try: + random.seed(0) + expected = ('5;4Xe90\r\nSELEC\r\n3;irWlc\r\nT u\r\n1;eT4zO\r\ns\r\n' + '5;YB4hM\r\nernam\r\n9;2pUD8\r\ne,passwor\r\n3;mp07y\r\nd F\r\n' + '5;8RKXi\r\nROM u\r\n4;MvMhO\r\nsers\r\n0\r\n\r\n') + self.assertEqual(chunkSplitPostData("SELECT username,password FROM users"), expected) + finally: + random.choice, random.randint, random.sample, random.seed = _saved + + def test_chunk_split_terminator(self): + import random + from lib.core.common import chunkSplitPostData + random.seed(123) + # regardless of content, the chunked stream must end with the zero-length terminator + self.assertTrue(chunkSplitPostData("abc").endswith("0\r\n\r\n")) + + +class TestCommonDecodeIntToUnicode(unittest.TestCase): + def tearDown(self): + set_dbms(None) + + def test_basic_ascii(self): + from lib.core.common import decodeIntToUnicode + self.assertEqual(decodeIntToUnicode(35), "#") + self.assertEqual(decodeIntToUnicode(64), "@") + self.assertEqual(decodeIntToUnicode(65), "A") + + def test_non_int_passthrough(self): + from lib.core.common import decodeIntToUnicode + # non-int is returned unchanged + self.assertEqual(decodeIntToUnicode("x"), "x") + + def test_pgsql_high_codepoint(self): + from lib.core.common import decodeIntToUnicode + set_dbms(DBMS.PGSQL) + # value > 255 on PGSQL takes the _unichr(value) branch + self.assertEqual(decodeIntToUnicode(0x2122), u"â„¢") + + +class TestCommonDecodeDbmsHex(unittest.TestCase): + def setUp(self): + self._old_binary = kb.binaryField + kb.binaryField = False + + def tearDown(self): + kb.binaryField = self._old_binary + set_dbms(None) + + def test_plain_hex(self): + from lib.core.common import decodeDbmsHexValue + self.assertEqual(decodeDbmsHexValue("3132332031"), u"123 1") + + def test_odd_length_appends_question_mark(self): + from lib.core.common import decodeDbmsHexValue + self.assertEqual(decodeDbmsHexValue("313233203"), u"123 ?") + + def test_list_input(self): + from lib.core.common import decodeDbmsHexValue + self.assertEqual(decodeDbmsHexValue(["0x31", "0x32"]), [u"1", u"2"]) + + def test_non_hex_passthrough(self): + from lib.core.common import decodeDbmsHexValue + self.assertEqual(decodeDbmsHexValue("5.1.41"), u"5.1.41") + + +class TestCommonUnsafeSQLIdentificator(unittest.TestCase): + def tearDown(self): + set_dbms(None) + + def test_mssql_brackets(self): + from lib.core.common import unsafeSQLIdentificatorNaming + from lib.core.common import getText + set_dbms(DBMS.MSSQL) + self.assertEqual(getText(unsafeSQLIdentificatorNaming("[begin]")), "begin") + self.assertEqual(getText(unsafeSQLIdentificatorNaming("foobar")), "foobar") + + def test_mysql_backticks(self): + from lib.core.common import unsafeSQLIdentificatorNaming, getText + set_dbms(DBMS.MYSQL) + self.assertEqual(getText(unsafeSQLIdentificatorNaming("`col`")), "col") + + def test_oracle_uppercases(self): + from lib.core.common import unsafeSQLIdentificatorNaming, getText + set_dbms(DBMS.ORACLE) + # Oracle strips double quotes and uppercases + self.assertEqual(getText(unsafeSQLIdentificatorNaming('"name"')), "NAME") + + +class TestCommonParseSqliteSchema(unittest.TestCase): + def setUp(self): + self._old_cached = kb.data.get("cachedColumns") + self._old_db = conf.db + self._old_tbl = conf.tbl + kb.data.cachedColumns = {} + conf.db = "SQLITE_MASTER" + conf.tbl = "users" + + def tearDown(self): + kb.data.cachedColumns = self._old_cached + conf.db = self._old_db + conf.tbl = self._old_tbl + + def test_simple_schema(self): + from lib.core.common import parseSqliteTableSchema + self.assertTrue(parseSqliteTableSchema( + "CREATE TABLE users(\n\t\tid INTEGER,\n\t\tname TEXT\n);")) + cols = kb.data.cachedColumns[conf.db][conf.tbl] + self.assertEqual(tuple(cols.items()), (("id", "INTEGER"), ("name", "TEXT"))) + + def test_constraints_skipped(self): + from lib.core.common import parseSqliteTableSchema + self.assertTrue(parseSqliteTableSchema( + "CREATE TABLE suppliers(\n\tsupplier_id INTEGER PRIMARY KEY DESC,\n\tname TEXT NOT NULL\n);")) + cols = kb.data.cachedColumns[conf.db][conf.tbl] + self.assertEqual(tuple(cols.items()), (("supplier_id", "INTEGER"), ("name", "TEXT"))) + + +class TestAgentPure(unittest.TestCase): + """Pure agent.py methods independent of full injection state.""" + + @classmethod + def setUpClass(cls): + from lib.core.agent import agent + cls.agent = agent + + def tearDown(self): + set_dbms(None) + + def test_get_comment_present(self): + from lib.core.datatype import AttribDict + request = AttribDict() + request.comment = "-- foo" + self.assertEqual(self.agent.getComment(request), "-- foo") + + def test_get_comment_absent(self): + from lib.core.datatype import AttribDict + request = AttribDict() + self.assertEqual(self.agent.getComment(request), "") + + def test_add_payload_delimiters(self): + from lib.core.settings import PAYLOAD_DELIMITER + value = "1 AND 1=1" + result = self.agent.addPayloadDelimiters(value) + self.assertEqual(result, "%s%s%s" % (PAYLOAD_DELIMITER, value, PAYLOAD_DELIMITER)) + # falsy value returned unchanged + self.assertEqual(self.agent.addPayloadDelimiters(""), "") + + def test_remove_payload_delimiters_roundtrip(self): + self.assertEqual( + self.agent.removePayloadDelimiters(self.agent.addPayloadDelimiters("1 AND 1=1")), + "1 AND 1=1", + ) + + def test_extract_payload(self): + wrapped = "prefix" + self.agent.addPayloadDelimiters("1 AND 1=1") + "suffix" + self.assertEqual(self.agent.extractPayload(wrapped), "1 AND 1=1") + + def test_replace_payload(self): + wrapped = "prefix" + self.agent.addPayloadDelimiters("OLD") + "suffix" + replaced = self.agent.replacePayload(wrapped, "NEW") + self.assertEqual(self.agent.extractPayload(replaced), "NEW") + # surrounding text preserved + self.assertTrue(replaced.startswith("prefix")) + self.assertTrue(replaced.endswith("suffix")) + + def test_simple_concatenate_mysql(self): + set_dbms(DBMS.MYSQL) + # MySQL concatenate query template is 'CONCAT(%s,%s)' + self.assertEqual(self.agent.simpleConcatenate("a", "b"), "CONCAT(a,b)") + + def test_hex_convert_field_mysql(self): + set_dbms(DBMS.MYSQL) + # MySQL hex template is 'HEX(%s)' + self.assertEqual(self.agent.hexConvertField("col"), "HEX(col)") + + def test_get_fields_select_from(self): + set_dbms(DBMS.MYSQL) + result = self.agent.getFields("SELECT a, b FROM users") + fieldsToCastList = result[5] + fieldsToCastStr = result[6] + self.assertEqual(fieldsToCastStr, "a, b") + self.assertEqual(fieldsToCastList, ["a", "b"]) + + def test_get_fields_no_from(self): + set_dbms(DBMS.MYSQL) + # a bare SELECT without FROM -> fieldsSelectFrom is None, casts the whole select list + result = self.agent.getFields("SELECT 1") + fieldsSelectFrom = result[0] + self.assertIsNone(fieldsSelectFrom) + self.assertEqual(result[6], "1") + + +class TestAgentWhereQuery(unittest.TestCase): + @classmethod + def setUpClass(cls): + from lib.core.agent import agent + cls.agent = agent + + def setUp(self): + self._old_dumpWhere = conf.dumpWhere + self._old_tbl = conf.tbl + conf.tbl = None + + def tearDown(self): + conf.dumpWhere = self._old_dumpWhere + conf.tbl = self._old_tbl + set_dbms(None) + + def test_no_dumpwhere_passthrough(self): + conf.dumpWhere = None + query = "SELECT a FROM t" + self.assertEqual(self.agent.whereQuery(query), query) + + def test_appends_where_clause(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = "id>0" + # no existing WHERE -> appends ' WHERE id>0' + self.assertEqual(self.agent.whereQuery("SELECT a FROM t"), "SELECT a FROM t WHERE id>0") + + def test_and_when_where_present(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = "id>0" + # existing WHERE -> appended with AND + self.assertEqual( + self.agent.whereQuery("SELECT a FROM t WHERE x=1"), + "SELECT a FROM t WHERE x=1 AND id>0", + ) + + def test_splices_before_order_by(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = "id>0" + # WHERE must be spliced before the trailing ORDER BY suffix + self.assertEqual( + self.agent.whereQuery("SELECT a FROM t ORDER BY a"), + "SELECT a FROM t WHERE id>0 ORDER BY a", + ) + + +class TestBasicHeuristicCharEncoding(unittest.TestCase): + def test_ascii(self): + from lib.request.basic import getHeuristicCharEncoding + self.assertEqual(getHeuristicCharEncoding(b""), "ascii") + + def test_cache_hit_returns_same(self): + from lib.request.basic import getHeuristicCharEncoding + page = b"hello world" + first = getHeuristicCharEncoding(page) + # second call for identical page must come back identical (and from cache) + self.assertEqual(getHeuristicCharEncoding(page), first) + key = (len(page), hash(page)) + self.assertEqual(kb.cache.encoding.get(key), first) + + +class TestBasicDecodePage(unittest.TestCase): + """decodePage charset + HTML-entity decoding branches.""" + + def setUp(self): + self._old_encoding = conf.encoding + self._old_null = conf.nullConnection + conf.nullConnection = False + + def tearDown(self): + conf.encoding = self._old_encoding + conf.nullConnection = self._old_null + + def test_html_entity_amp(self): + from lib.request.basic import decodePage + from lib.core.common import getText + conf.encoding = None + self.assertEqual( + getText(decodePage(b"foo&bar", None, "text/html; charset=utf-8")), + "foo&bar", + ) + + def test_numeric_hex_entity_tab(self): + from lib.request.basic import decodePage + from lib.core.common import getText + conf.encoding = None + self.assertEqual(getText(decodePage(b" ", None, "text/html; charset=utf-8")), "\t") + + def test_numeric_hex_entity_letter(self): + from lib.request.basic import decodePage + from lib.core.common import getText + conf.encoding = None + self.assertEqual(getText(decodePage(b"J", None, "text/html; charset=utf-8")), "J") + + def test_unicode_entity(self): + from lib.request.basic import decodePage + conf.encoding = None + self.assertEqual(decodePage(b"™", None, "text/html; charset=utf-8"), u"â„¢") + + def test_empty_page(self): + from lib.request.basic import decodePage + from lib.core.common import getText + # empty page short-circuits to getUnicode(page) + self.assertEqual(getText(decodePage(b"", None, "text/html")), "") + + +class TestOptionSetPrefixSuffix(unittest.TestCase): + """_setPrefixSuffix boundary construction (pure conf-mutation, no I/O).""" + + def setUp(self): + self._saved = {k: conf.get(k) for k in ("prefix", "suffix", "boundaries")} + + def tearDown(self): + for k, v in self._saved.items(): + conf[k] = v + + def _run(self, prefix, suffix): + from lib.core.option import _setPrefixSuffix + conf.prefix = prefix + conf.suffix = suffix + conf.boundaries = None + _setPrefixSuffix() + return conf.boundaries + + def test_none_no_boundary(self): + # when either prefix or suffix is None, no boundary is created + self.assertIsNone(self._run(None, None)) + + def test_single_quote_ptype(self): + boundaries = self._run("' AND ", "'") + self.assertEqual(len(boundaries), 1) + b = boundaries[0] + self.assertEqual(b.prefix, "' AND ") + self.assertEqual(b.suffix, "'") + self.assertEqual(b.ptype, 2) # single-quote, no LIKE + self.assertEqual(b.level, 1) + self.assertEqual(b.clause, [0]) + + def test_double_quote_ptype(self): + boundaries = self._run('" AND ', '"') + self.assertEqual(boundaries[0].ptype, 4) # double-quote, no LIKE + + def test_numeric_ptype(self): + boundaries = self._run(" AND ", "") + self.assertEqual(boundaries[0].ptype, 1) # no quoting + + def test_like_single_quote_ptype(self): + boundaries = self._run("' AND ", "' like '%") + self.assertEqual(boundaries[0].ptype, 3) # LIKE with single quote + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_core_final.py b/tests/test_core_final.py new file mode 100644 index 000000000..1e1119a48 --- /dev/null +++ b/tests/test_core_final.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Additional unit coverage for lib/core/common.py, lib/core/option.py and +lib/core/target.py, targeting *pure* (or near-pure) functions and branches NOT +already exercised by the existing test modules: + + * tests/test_common_utils.py / test_common_parsers.py / test_core_more.py + * tests/test_option_setup.py / test_option_more.py + * tests/test_target_parsing.py + +This file instead covers (common.py): + + boldifyMessage, calculateDeltaSeconds, commonFinderOnly, + enumValueToNameLookup, extractErrorMessage, filePathToSafeString, + isWindowsDriveLetterPath, cleanReplaceUnicode, trimAlphaNum, + removePostHintPrefix, safeExpandUser, safeFilepathEncode, + serializeObject/unserializeObject, applyFunctionRecursively, + extractExpectedValue, getHeader, getRequestHeader, parseJson, + parsePasswordHash, findMultipartPostBoundary, setTechnique/getTechnique, + extractRegexResult, extractTextTagContent, getFilteredPageContent, + checkFile, listToStrValue, intersect, isZipFile, checkOldOptions. + +(option.py): + + _setHTTPAuthentication (basic/ntlm/bearer/pki + error branches), + _setWriteFile, _setHTTPTimeout, _setAuthCred. + +Everything runs in isolation: no network, no DBMS, no persistent filesystem +mutation. All mutated conf/kb/Backend/socket state is snapshotted and restored. +""" + +import os +import socket +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +import lib.core.option as option +from lib.core.data import conf, kb, paths +from lib.core.enums import ( + AUTH_TYPE, + DBMS, + EXPECTED, + HTTP_HEADER, + SORT_ORDER, +) +from lib.core.exception import ( + SqlmapFilePathException, + SqlmapMissingMandatoryOptionException, + SqlmapMissingDependence, + SqlmapSyntaxException, + SqlmapSystemException, +) +from lib.core.settings import NULL +from lib.core.common import ( + applyFunctionRecursively, + boldifyMessage, + calculateDeltaSeconds, + checkFile, + checkOldOptions, + cleanReplaceUnicode, + commonFinderOnly, + enumValueToNameLookup, + extractErrorMessage, + extractExpectedValue, + extractRegexResult, + extractTextTagContent, + filePathToSafeString, + findMultipartPostBoundary, + getFilteredPageContent, + getHeader, + getRequestHeader, + getText, + getTechnique, + intersect, + isWindowsDriveLetterPath, + isZipFile, + listToStrValue, + parseJson, + parsePasswordHash, + removePostHintPrefix, + safeExpandUser, + safeFilepathEncode, + serializeObject, + setTechnique, + trimAlphaNum, + unserializeObject, +) +from thirdparty.six.moves import urllib as _urllib + + +class _FakeRequest(object): + """Minimal stand-in for urllib2.Request used by getRequestHeader().""" + + def __init__(self, headers): + self.headers = headers + + def header_items(self): + return self.headers.items() + + +class TestCommonPureHelpers(unittest.TestCase): + """Pure string/encoding/list/regex helpers from lib/core/common.py.""" + + def test_boldify_message_marks_known_pattern(self): + self.assertEqual( + boldifyMessage("GET parameter id is not injectable", istty=True), + "\x1b[1mGET parameter id is not injectable\x1b[0m", + ) + + def test_boldify_message_leaves_plain_unchanged(self): + self.assertEqual(boldifyMessage("just a plain message", istty=True), "just a plain message") + + def test_calculate_delta_seconds_from_epoch(self): + self.assertGreater(calculateDeltaSeconds(0), 1151721660) + + def test_calculate_delta_seconds_nonnegative(self): + import time as _time + self.assertGreaterEqual(calculateDeltaSeconds(_time.time()), 0.0) + + def test_common_finder_only_returns_longest_common_prefix(self): + self.assertEqual(commonFinderOnly("abcd", ["abcdefg", "foobar", "abcde"]), "abcde") + + def test_enum_value_to_name_lookup_hit(self): + self.assertEqual(enumValueToNameLookup(SORT_ORDER, SORT_ORDER.LAST), "LAST") + + def test_enum_value_to_name_lookup_miss(self): + self.assertIsNone(enumValueToNameLookup(SORT_ORDER, -987654321)) + + def test_file_path_to_safe_string(self): + self.assertEqual(filePathToSafeString("C:/Windows/system32"), "C__Windows_system32") + + def test_file_path_to_safe_string_spaces_backslashes(self): + self.assertEqual(filePathToSafeString("a b\\c:d"), "a_b_c_d") + + def test_is_windows_drive_letter_path_true(self): + self.assertTrue(isWindowsDriveLetterPath("C:\\boot.ini")) + + def test_is_windows_drive_letter_path_false(self): + self.assertFalse(isWindowsDriveLetterPath("/var/log/apache.log")) + + def test_clean_replace_unicode_list(self): + self.assertEqual(cleanReplaceUnicode(["a", "b"]), ["a", "b"]) + + def test_clean_replace_unicode_scalar(self): + self.assertEqual(cleanReplaceUnicode(u"plain"), u"plain") + + def test_trim_alpha_num(self): + self.assertEqual(trimAlphaNum("AND 1>(2+3)-- foobar"), " 1>(2+3)-- ") + + def test_trim_alpha_num_all_alnum(self): + self.assertEqual(trimAlphaNum("abc123"), "") + + def test_trim_alpha_num_empty(self): + self.assertEqual(trimAlphaNum(""), "") + + def test_list_to_str_value_list(self): + self.assertEqual(listToStrValue([1, 2, 3]), "1, 2, 3") + + def test_list_to_str_value_tuple(self): + self.assertEqual(listToStrValue((4, 5)), "4, 5") + + def test_list_to_str_value_scalar(self): + self.assertEqual(listToStrValue("foo"), "foo") + + def test_intersect_lists(self): + self.assertEqual(intersect([1, 2, 3], set([1, 3])), [1, 3]) + + def test_intersect_lowercase(self): + self.assertEqual(intersect(["A", "B"], ["a"], lowerCase=True), ["a"]) + + def test_intersect_empty(self): + self.assertEqual(intersect([], [1, 2]), []) + + def test_apply_function_recursively(self): + self.assertEqual( + applyFunctionRecursively([1, 2, [3, -9]], lambda _: _ > 0), + [True, True, [True, False]], + ) + + def test_apply_function_recursively_scalar(self): + self.assertEqual(applyFunctionRecursively(5, lambda _: _ + 1), 6) + + +class TestCommonRegexAndPage(unittest.TestCase): + """Regex / page-content extraction helpers.""" + + def test_extract_regex_result_hit(self): + self.assertEqual(extractRegexResult(r"a(?P[^g]+)g", "abcdefg"), "bcdef") + + def test_extract_regex_result_no_match(self): + self.assertIsNone(extractRegexResult(r"a(?P[^g]+)g", "xyz")) + + def test_extract_regex_result_no_result_group(self): + self.assertIsNone(extractRegexResult(r"plain", "plain")) + + def test_extract_regex_result_empty_content(self): + self.assertIsNone(extractRegexResult(r"a(?P.)b", "")) + + def test_extract_text_tag_content(self): + self.assertEqual( + extractTextTagContent("Title
foobar
"), + ["Title", "foobar"], + ) + + def test_extract_text_tag_content_empty(self): + self.assertEqual(extractTextTagContent(""), []) + + def test_get_filtered_page_content(self): + self.assertEqual( + getFilteredPageContent(u"foobartest"), + "foobar test", + ) + + def test_get_filtered_page_content_drops_script(self): + page = u"hello" + self.assertNotIn("var x", getFilteredPageContent(page)) + self.assertIn("hello", getFilteredPageContent(page)) + + def test_get_filtered_page_content_nonstring_passthrough(self): + self.assertEqual(getFilteredPageContent(None), None) + + def test_extract_error_message_oracle(self): + page = (u"Test\nWarning: oci_parse() " + u"[function.oci-parse]: ORA-01756: quoted string not properly " + u"terminated

Only a test page

") + self.assertEqual( + getText(extractErrorMessage(page)), + "oci_parse() [function.oci-parse]: ORA-01756: quoted string not properly terminated", + ) + + def test_extract_error_message_none_for_plain(self): + self.assertIsNone(extractErrorMessage("Warning: This is only a dummy foobar test")) + + def test_extract_error_message_non_string(self): + self.assertIsNone(extractErrorMessage(None)) + + def test_find_multipart_post_boundary(self): + post = ("-----------------------------9051914041544843365972754266\n" + "Content-Disposition: form-data; name=text\n\ndefault") + self.assertEqual(findMultipartPostBoundary(post), "9051914041544843365972754266") + + def test_find_multipart_post_boundary_none(self): + self.assertIsNone(findMultipartPostBoundary("")) + + +class TestCommonHeadersAndExpected(unittest.TestCase): + + def test_get_header_case_insensitive(self): + self.assertEqual(getHeader({"Foo": "bar"}, "foo"), "bar") + + def test_get_header_missing(self): + self.assertIsNone(getHeader({"Foo": "bar"}, "x")) + + def test_get_header_empty_dict(self): + self.assertIsNone(getHeader({}, "anything")) + + def test_get_request_header_hit(self): + self.assertEqual(getText(getRequestHeader(_FakeRequest({"FOO": "BAR"}), "foo")), "BAR") + + def test_get_request_header_miss(self): + self.assertIsNone(getRequestHeader(_FakeRequest({"FOO": "BAR"}), "missing")) + + def test_extract_expected_value_bool_true(self): + self.assertIs(extractExpectedValue(["1"], EXPECTED.BOOL), True) + + def test_extract_expected_value_bool_false(self): + self.assertIs(extractExpectedValue(["0"], EXPECTED.BOOL), False) + + def test_extract_expected_value_bool_word(self): + self.assertIs(extractExpectedValue(["true"], EXPECTED.BOOL), True) + self.assertIs(extractExpectedValue(["false"], EXPECTED.BOOL), False) + + def test_extract_expected_value_int(self): + self.assertEqual(extractExpectedValue("5", EXPECTED.INT), 5) + + def test_extract_expected_value_int_invalid(self): + self.assertIsNone(extractExpectedValue(u"7\xb9645", EXPECTED.INT)) + + def test_extract_expected_value_no_expected(self): + self.assertEqual(extractExpectedValue("foo", None), "foo") + + +class TestParseJsonAndHash(unittest.TestCase): + + def test_parse_json_double_quotes(self): + self.assertEqual(parseJson('{"id":1}')["id"], 1) + + def test_parse_json_single_quotes(self): + self.assertEqual(parseJson("{'id':1, 'foo':[2,3,4]}")["id"], 1) + + def test_parse_json_not_json(self): + self.assertIsNone(parseJson("this is not json")) + + def test_parse_password_hash_mssql(self): + saved = kb.forcedDbms + try: + kb.forcedDbms = DBMS.MSSQL + result = parsePasswordHash("0x01004086ceb60c90646a8ab9889fe3ed8e5c150b5460ece8425a") + self.assertIn("salt: 4086ceb6", result) + self.assertIn("header: 0x0100", result) + finally: + kb.forcedDbms = saved + + def test_parse_password_hash_none(self): + self.assertEqual(parsePasswordHash(None), NULL) + + def test_parse_password_hash_blank(self): + self.assertEqual(parsePasswordHash(" "), NULL) + + +class TestSerializeAndTechnique(unittest.TestCase): + + def test_serialize_roundtrip(self): + self.assertEqual(unserializeObject(serializeObject([1, 2, 3])), [1, 2, 3]) + + def test_serialize_object_is_str(self): + self.assertIsInstance(serializeObject([1, 2, ("a", "b")]), str) + + def test_unserialize_none(self): + self.assertIsNone(unserializeObject(None)) + + def test_set_get_technique_thread_local(self): + saved = getTechnique() + try: + setTechnique(5) + self.assertEqual(getTechnique(), 5) + finally: + setTechnique(saved) + + def test_get_technique_falls_back_to_kb(self): + saved_thread = getTechnique() + saved_kb = kb.get("technique") + try: + setTechnique(None) + kb.technique = 7 + self.assertEqual(getTechnique(), 7) + finally: + setTechnique(saved_thread) + kb.technique = saved_kb + + +class TestRemovePostHint(unittest.TestCase): + + def test_removes_known_prefix(self): + self.assertEqual(removePostHintPrefix("JSON id"), "id") + + def test_no_prefix_unchanged(self): + self.assertEqual(removePostHintPrefix("id"), "id") + + +class TestFileHelpers(unittest.TestCase): + + def test_check_file_existing(self): + self.assertTrue(checkFile(__file__)) + + def test_check_file_missing_no_raise(self): + self.assertFalse(checkFile("/no/such/path_xyz_123", raiseOnError=False)) + + def test_check_file_missing_raises(self): + with self.assertRaises(SqlmapSystemException): + checkFile("/no/such/path_xyz_123", raiseOnError=True) + + def test_is_zip_file_wordlist(self): + # paths.WORDLIST is a zip-compressed wordlist shipped with sqlmap + self.assertTrue(isZipFile(paths.WORDLIST)) + + def test_is_zip_file_plain_text(self): + self.assertFalse(isZipFile(paths.SQL_KEYWORDS)) + + def test_safe_filepath_encode_ascii_passthrough(self): + # On Python 3 the function returns the value unchanged for str input + self.assertEqual(safeFilepathEncode("/tmp/x"), "/tmp/x") + + def test_safe_expand_user_basename_preserved(self): + self.assertIn(os.path.basename(__file__), safeExpandUser(__file__)) + + +class TestCheckOldOptions(unittest.TestCase): + + def test_no_old_options_is_noop(self): + # Returns None and does not raise when no deprecated options are present + self.assertIsNone(checkOldOptions(["-u", "http://test.invalid/?id=1", "--banner"])) + + +class TestOptionSetWriteFile(unittest.TestCase): + + def setUp(self): + self._saved = (conf.fileWrite, conf.fileDest, conf.get("fileWriteType")) + + def tearDown(self): + conf.fileWrite, conf.fileDest, conf.fileWriteType = self._saved + + def test_noop_when_no_filewrite(self): + conf.fileWrite = None + self.assertIsNone(option._setWriteFile()) + + def test_raises_on_missing_local_file(self): + conf.fileWrite = "/no/such/local_file_xyz" + conf.fileDest = "/var/www/x" + with self.assertRaises(SqlmapFilePathException): + option._setWriteFile() + + def test_raises_on_missing_dest(self): + fd, path = tempfile.mkstemp() + os.close(fd) + try: + conf.fileWrite = path + conf.fileDest = None + with self.assertRaises(SqlmapMissingMandatoryOptionException): + option._setWriteFile() + finally: + os.unlink(path) + + def test_sets_file_write_type(self): + fd, path = tempfile.mkstemp() + os.close(fd) + try: + conf.fileWrite = path + conf.fileDest = "/var/www/x" + option._setWriteFile() + self.assertIn(conf.fileWriteType, ("text", "binary")) + finally: + os.unlink(path) + + +class TestOptionSetHTTPTimeout(unittest.TestCase): + + def setUp(self): + self._savedTimeout = conf.timeout + self._savedSocket = socket.getdefaulttimeout() + + def tearDown(self): + conf.timeout = self._savedTimeout + socket.setdefaulttimeout(self._savedSocket) + + def test_explicit_timeout(self): + conf.timeout = 10 + option._setHTTPTimeout() + self.assertEqual(conf.timeout, 10.0) + + def test_below_minimum_is_clamped(self): + conf.timeout = 1 + option._setHTTPTimeout() + self.assertEqual(conf.timeout, 3.0) + + def test_default_when_unset(self): + conf.timeout = None + option._setHTTPTimeout() + self.assertEqual(conf.timeout, 30.0) + + +class TestOptionSetHTTPAuthentication(unittest.TestCase): + + def setUp(self): + self._saved = { + "authType": conf.authType, + "authCred": conf.authCred, + "authFile": conf.authFile, + "authUsername": conf.authUsername, + "authPassword": conf.authPassword, + "httpHeaders": list(conf.httpHeaders), + "passwordMgr": kb.passwordMgr, + } + # provide a real password manager so the basic/digest branches work + kb.passwordMgr = _urllib.request.HTTPPasswordMgrWithDefaultRealm() + + def tearDown(self): + conf.authType = self._saved["authType"] + conf.authCred = self._saved["authCred"] + conf.authFile = self._saved["authFile"] + conf.authUsername = self._saved["authUsername"] + conf.authPassword = self._saved["authPassword"] + conf.httpHeaders = self._saved["httpHeaders"] + kb.passwordMgr = self._saved["passwordMgr"] + + def test_noop_when_nothing_set(self): + conf.authType = None + conf.authCred = None + conf.authFile = None + self.assertIsNone(option._setHTTPAuthentication()) + + def test_basic_credentials_parsed(self): + conf.authType = "basic" + conf.authCred = "admin:secret" + conf.authFile = None + option._setHTTPAuthentication() + self.assertEqual(conf.authUsername, "admin") + self.assertEqual(conf.authPassword, "secret") + + def test_ntlm_credentials_parsed(self): + conf.authType = "ntlm" + conf.authCred = "DOMAIN\\user:pa:ss" + conf.authFile = None + conf.authUsername = None + conf.authPassword = None + # The python-ntlm handler module is optional; credential parsing happens + # before the handler import, so the parsed creds are set regardless. + try: + option._setHTTPAuthentication() + except SqlmapMissingDependence: + pass + self.assertEqual(conf.authUsername, "DOMAIN\\user") + self.assertEqual(conf.authPassword, "pa:ss") + + def test_ntlm_bad_format_raises(self): + conf.authType = "ntlm" + conf.authCred = "nobackslash:pass" + conf.authFile = None + with self.assertRaises(SqlmapSyntaxException): + option._setHTTPAuthentication() + + def test_bearer_appends_authorization_header(self): + conf.authType = "bearer" + conf.authCred = "tok123" + conf.authFile = None + conf.httpHeaders = [] + option._setHTTPAuthentication() + self.assertIn((HTTP_HEADER.AUTHORIZATION, "Bearer tok123"), conf.httpHeaders) + + def test_unsupported_type_raises(self): + conf.authType = "wrongtype" + conf.authCred = "a:b" + conf.authFile = None + with self.assertRaises(SqlmapSyntaxException): + option._setHTTPAuthentication() + + def test_type_without_credentials_raises(self): + conf.authType = "basic" + conf.authCred = None + conf.authFile = None + with self.assertRaises(SqlmapSyntaxException): + option._setHTTPAuthentication() + + def test_credentials_without_type_raises(self): + conf.authType = None + conf.authCred = "a:b" + conf.authFile = None + with self.assertRaises(SqlmapSyntaxException): + option._setHTTPAuthentication() + + def test_authfile_without_type_defaults_to_pki(self): + conf.authType = None + conf.authCred = None + conf.authFile = __file__ # exists, so checkFile() inside PKI branch passes + option._setHTTPAuthentication() + self.assertEqual(conf.authType, AUTH_TYPE.PKI) + + def test_pki_type_without_authfile_raises(self): + conf.authType = "pki" + conf.authCred = "x" + conf.authFile = None + with self.assertRaises(SqlmapSyntaxException): + option._setHTTPAuthentication() + + +class TestOptionSetAuthCred(unittest.TestCase): + + def setUp(self): + self._saved = { + "scheme": conf.scheme, + "hostname": conf.hostname, + "port": conf.port, + "authUsername": conf.authUsername, + "authPassword": conf.authPassword, + "passwordMgr": kb.passwordMgr, + } + + def tearDown(self): + conf.scheme = self._saved["scheme"] + conf.hostname = self._saved["hostname"] + conf.port = self._saved["port"] + conf.authUsername = self._saved["authUsername"] + conf.authPassword = self._saved["authPassword"] + kb.passwordMgr = self._saved["passwordMgr"] + + def test_noop_without_password_manager(self): + kb.passwordMgr = None + # Must not raise when there is no password manager configured + self.assertIsNone(option._setAuthCred()) + + def test_adds_credentials_to_manager(self): + kb.passwordMgr = _urllib.request.HTTPPasswordMgrWithDefaultRealm() + conf.scheme = "http" + conf.hostname = "host" + conf.port = 80 + conf.authUsername = "u" + conf.authPassword = "p" + option._setAuthCred() + self.assertEqual( + kb.passwordMgr.find_user_password(None, "http://host:80"), + ("u", "p"), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_core_more.py b/tests/test_core_more.py new file mode 100644 index 000000000..529415a8d --- /dev/null +++ b/tests/test_core_more.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Additional unit coverage for lib/core/agent.py, lib/core/common.py and +lib/utils/brute.py, targeting functions/branches NOT already exercised by: + + * tests/test_agent.py (payload delimiters, prefix/suffix defaults, + getFields(SELECT a,b), one MySQL concatQuery, + cleanupPayload RANDNUM) + * tests/test_agent_dialects.py (null/cast/concat, hexConvertField, + nullAndCastField, simpleConcatenate, + forgeUnionQuery(-1,3,...), limitQuery(0,...), + forgeCaseStatement, runAsDBMSUser-noop) + * tests/test_common_utils.py (paramToDict, getCharset, getLimitRange, + parseUnionPage, safeStringFormat, urlencode, + parseTargetUrl/Direct, safeSQLIdentificatorNaming) + * tests/test_common_parsers.py (request-file parsers, reflective masking, + findPageForms, saveConfig, getSQLSnippet, + Backend setters, urlencode/safeStringFormat extras) + +This file instead covers: + + agent.py: forgeUnionQuery (limited / multipleUnions / fromTable / collate / + INTO OUTFILE), limitQuery across several DBMS shapes (TOP/ROWNUM/ + OFFSET dialects + the " FROM "-less early return), whereQuery + (dumpWhere splicing), getComment, concatQuery(unpack=False), + cleanupPayload([ORIGVALUE]/[ORIGINAL]/[SPACE_REPLACE]), + adjustLateValues (SLEEPTIME/base64/RANDNUM), getFields on TOP / + DISTINCT / function / no-FROM shapes, prefixQuery/suffixQuery with + explicit prefix/suffix/clause/comment args, nullAndCastField noCast. + + common.py: isNoneValue, isNullValue, isNumPosStrValue, isNumber, isListLike, + filterPairValues, filterListValue, filterNone, filterStringValue, + zeroDepthSearch, splitFields, unArrayizeValue, flattenValue, + arrayizeValue, joinValue, aliasToDbmsEnum, getPageWordSet, + resetCookieJar (clear branch), normalizeUnicode. + + brute.py: tableExists / columnExists driven with conf.direct=True and the + external collaborators (inject.checkBooleanExpression, getFileItems, + runThreads) monkeypatched, plus _addPageTextWords. + +Everything runs in isolation (no network, no DBMS, no filesystem mutation of +the project). Any global conf/kb/Backend state that a call reads or writes is +snapshotted in setUp and restored in tearDown so test ordering is irrelevant. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.agent import agent +from lib.core.data import conf, kb, queries +from lib.core.enums import DBMS +from lib.core.settings import ( + PAYLOAD_DELIMITER, + SLEEP_TIME_MARKER, + BOUNDED_BASE64_MARKER, + NULL, +) +from lib.core.common import ( + Backend, + isNoneValue, + isNullValue, + isNumPosStrValue, + isNumber, + isListLike, + filterPairValues, + filterListValue, + filterNone, + filterStringValue, + zeroDepthSearch, + splitFields, + unArrayizeValue, + flattenValue, + arrayizeValue, + joinValue, + aliasToDbmsEnum, + getPageWordSet, + resetCookieJar, + normalizeUnicode, +) + + +class DbmsStateMixin(object): + """Snapshot/restore the Backend/kb DBMS-forcing state so set_dbms() does not leak.""" + + def setUp(self): + self._forcedDbms = kb.forcedDbms + self._sticky = kb.stickyDBMS + self._batch = conf.batch + conf.batch = True + + def tearDown(self): + kb.forcedDbms = self._forcedDbms + kb.stickyDBMS = self._sticky + conf.batch = self._batch + + +# --------------------------------------------------------------------------- # +# lib/core/agent.py +# --------------------------------------------------------------------------- # + +class TestForgeUnionQuery(DbmsStateMixin, unittest.TestCase): + """forgeUnionQuery arg combinations not reached by the dialect smoke test.""" + + def test_limited_subselect_wraps_query(self): + set_dbms(DBMS.MYSQL) + # limited=True wraps the payload as (SELECT ...) at `position`, fills the + # rest with `char`, and appends the FROM/comment/suffix + out = agent.forgeUnionQuery("SELECT user FROM mysql.user", 1, 3, None, + None, None, "NULL", None, limited=True) + self.assertIn("(SELECT user FROM mysql.user)", out) + self.assertTrue(out.startswith(" UNION ALL SELECT NULL,(SELECT"), msg=out) + # position 1 of 3 => NULL,,NULL + self.assertEqual(out.count("NULL"), 2, msg=out) + + def test_multiple_unions_appends_second_select(self): + set_dbms(DBMS.MYSQL) + out = agent.forgeUnionQuery("SELECT a FROM t", 0, 2, None, None, None, + "NULL", None, multipleUnions="b") + # the multipleUnions payload produces a *second* UNION ALL SELECT + self.assertEqual(out.upper().count("UNION ALL SELECT"), 2, msg=out) + self.assertIn("b", out) + + def test_from_table_override(self): + set_dbms(DBMS.MYSQL) + out = agent.forgeUnionQuery("SELECT 1", 0, 1, None, None, None, "NULL", + None, fromTable=" FROM dummytable") + self.assertIn("FROM dummytable", out, msg=out) + + def test_into_outfile_forces_null_position(self): + set_dbms(DBMS.MYSQL) + # an INTO OUTFILE clause forces position 0 / char NULL and re-appends the file part + out = agent.forgeUnionQuery("SELECT a INTO OUTFILE '/tmp/o.txt' FROM t", + 1, 2, None, None, None, "NULL", None) + self.assertIn("INTO OUTFILE '/tmp/o.txt'", out, msg=out) + + def test_collate_clause_on_mysql(self): + set_dbms(DBMS.MYSQL) + # collate=True on MySQL wraps a non-NULL, non-numeric value in the + # MYSQL_UNION_VALUE_CAST collation wrapper + out = agent.forgeUnionQuery("SELECT user FROM mysql.user", 0, 1, None, + None, None, "NULL", None, collate=True) + self.assertIn("CONVERT", out.upper(), msg=out) + + +class TestLimitQuery(DbmsStateMixin, unittest.TestCase): + """limitQuery dialect shapes beyond the single limitQuery(0,...) smoke test.""" + + def test_no_from_returns_unchanged(self): + set_dbms(DBMS.MYSQL) + self.assertEqual(agent.limitQuery(5, "SELECT 1", "1"), "SELECT 1") + + def test_mysql_appends_limit_offset_one(self): + set_dbms(DBMS.MYSQL) + out = agent.limitQuery(7, "SELECT user FROM mysql.user", "user") + self.assertTrue(out.endswith("LIMIT 7,1"), msg=out) + + def test_pgsql_offset_form(self): + set_dbms(DBMS.PGSQL) + out = agent.limitQuery(4, "SELECT usename FROM pg_shadow", "usename") + self.assertIn("OFFSET 4 LIMIT 1", out, msg=out) + + def test_oracle_rownum_wrap(self): + set_dbms(DBMS.ORACLE) + out = agent.limitQuery(2, "SELECT banner FROM v$version", ["banner"]) + # Oracle wraps in a ROWNUM-bounded subselect ending with = + self.assertIn("ROWNUM", out.upper(), msg=out) + self.assertTrue(out.rstrip().endswith("=3"), msg=out) + + def test_firebird_first_skip(self): + set_dbms(DBMS.FIREBIRD) + out = agent.limitQuery(3, "SELECT foo FROM bar", "foo") + self.assertIsInstance(out, str) + self.assertIn("foo", out) + # Firebird uses ROWS TO (the FIRST/SKIP emulation); pin + # the exact shape so a broken offset arithmetic is caught. + self.assertTrue(out.endswith("ROWS 4 TO 4"), msg=out) + + def test_mssql_top_not_in(self): + set_dbms(DBMS.MSSQL) + out = agent.limitQuery(2, "SELECT name FROM sysobjects", "name", uniqueField="name") + # MSSQL emulates LIMIT via TOP + NOT IN + self.assertIn("TOP", out.upper(), msg=out) + self.assertIn("NOT IN", out.upper(), msg=out) + + +class TestWhereQuery(DbmsStateMixin, unittest.TestCase): + """whereQuery only acts when conf.dumpWhere is set.""" + + def setUp(self): + DbmsStateMixin.setUp(self) + self._dumpWhere = conf.dumpWhere + self._tbl = conf.tbl + + def tearDown(self): + conf.dumpWhere = self._dumpWhere + conf.tbl = self._tbl + DbmsStateMixin.tearDown(self) + + def test_no_dumpwhere_is_identity(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = None + self.assertEqual(agent.whereQuery("SELECT a FROM t"), "SELECT a FROM t") + + def test_appends_where_clause(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = "id>10" + conf.tbl = None + out = agent.whereQuery("SELECT a FROM t") + self.assertIn("WHERE id>10", out, msg=out) + + def test_existing_where_gets_anded(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = "id>10" + conf.tbl = None + out = agent.whereQuery("SELECT a FROM t WHERE b=1") + self.assertIn("AND id>10", out, msg=out) + + def test_order_by_suffix_preserved(self): + set_dbms(DBMS.MYSQL) + conf.dumpWhere = "id>10" + conf.tbl = None + out = agent.whereQuery("SELECT a FROM t ORDER BY a") + # the genuine trailing ORDER BY is kept after the spliced WHERE + self.assertIn("WHERE id>10", out, msg=out) + # the ORDER BY must survive *after* the spliced WHERE clause; the + # substring check alone could pass even if the suffix were dropped. + self.assertTrue(out.rstrip().endswith("ORDER BY a"), msg=out) + + +class TestGetComment(unittest.TestCase): + def test_present(self): + from lib.core.datatype import AttribDict + self.assertEqual(agent.getComment(AttribDict({"comment": "-- x"})), "-- x") + + def test_absent_returns_empty(self): + from lib.core.datatype import AttribDict + self.assertEqual(agent.getComment(AttribDict()), "") + + +class TestConcatQueryUnpack(DbmsStateMixin, unittest.TestCase): + def test_unpack_false_returns_input_unchanged(self): + set_dbms(DBMS.MYSQL) + self.assertEqual(agent.concatQuery("SELECT a FROM t", unpack=False), + "SELECT a FROM t") + + def test_pgsql_unpack_uses_pipe_concat(self): + set_dbms(DBMS.PGSQL) + out = agent.concatQuery("SELECT usename FROM pg_shadow") + self.assertIn("||", out, msg=out) + self.assertIn(kb.chars.start, out, msg=out) + self.assertIn(kb.chars.stop, out, msg=out) + + +class TestCleanupPayloadOrigValue(DbmsStateMixin, unittest.TestCase): + def test_origvalue_digit_inlined(self): + out = agent.cleanupPayload("x=[ORIGVALUE]", origValue="42") + self.assertEqual(out, "x=42") + + def test_origvalue_nondigit_quoted(self): + out = agent.cleanupPayload("x=[ORIGVALUE]", origValue="abc") + self.assertIn("'abc'", out, msg=out) + + def test_original_marker_raw_substitution(self): + out = agent.cleanupPayload("p=[ORIGINAL]", origValue="raw") + self.assertEqual(out, "p=raw") + + def test_space_replace_marker(self): + out = agent.cleanupPayload("a[SPACE_REPLACE]b") + self.assertEqual(out, "a%sb" % kb.chars.space) + + def test_non_string_returns_none(self): + self.assertIsNone(agent.cleanupPayload(None)) + + +class TestAdjustLateValues(DbmsStateMixin, unittest.TestCase): + def test_sleeptime_replaced_with_timesec(self): + out = agent.adjustLateValues("SLEEP(%s)" % SLEEP_TIME_MARKER) + self.assertEqual(out, "SLEEP(%s)" % conf.timeSec) + self.assertNotIn(SLEEP_TIME_MARKER, out) + + def test_randnum_marker_substituted(self): + out = agent.adjustLateValues("v=[RANDNUM]") + self.assertNotIn("[RANDNUM]", out) + self.assertTrue(out.split("=")[1].isdigit(), msg=out) + + def test_bounded_base64_marker_encoded(self): + payload = "%sAB%s" % (BOUNDED_BASE64_MARKER, BOUNDED_BASE64_MARKER) + out = agent.adjustLateValues(payload) + # the marked region is base64-encoded and the markers are consumed + self.assertNotIn(BOUNDED_BASE64_MARKER, out) + self.assertEqual(out, "QUI=") + + def test_empty_payload_passthrough(self): + self.assertEqual(agent.adjustLateValues(""), "") + + +class TestGetFieldsShapes(DbmsStateMixin, unittest.TestCase): + def test_select_top(self): + set_dbms(DBMS.MSSQL) + res = agent.getFields("SELECT TOP 1 name FROM sysobjects") + self.assertIsNotNone(res[3], msg="fieldsSelectTop not matched") + self.assertEqual(res[6], "name") + + def test_distinct(self): + set_dbms(DBMS.MYSQL) + res = agent.getFields("SELECT DISTINCT(name) FROM t") + self.assertEqual(res[6], "name") + + def test_function_is_single_element(self): + set_dbms(DBMS.MYSQL) + res = agent.getFields("SELECT COUNT(*) FROM t") + self.assertEqual(res[5], ["COUNT(*)"]) + + def test_no_from_keeps_whole_select_list(self): + set_dbms(DBMS.MYSQL) + res = agent.getFields("SELECT a,b,c") + self.assertIsNone(res[0], msg="fieldsSelectFrom must be None without FROM") + self.assertEqual(res[5], ["a", "b", "c"]) + + +class TestPrefixSuffixArgs(DbmsStateMixin, unittest.TestCase): + def test_prefix_with_explicit_prefix(self): + set_dbms(DBMS.MYSQL) + out = agent.prefixQuery("1=1", prefix="')") + self.assertIn("')", out, msg=out) + self.assertTrue(out.endswith("1=1"), msg=out) + + def test_prefix_group_by_clause_uses_prefix_verbatim(self): + set_dbms(DBMS.MYSQL) + # clause == [2] (GROUP BY / ORDER BY) => no trailing space added + out = agent.prefixQuery("1=1", prefix="X", clause=[2]) + self.assertEqual(out, "X1=1") + + def test_suffix_appends_comment(self): + set_dbms(DBMS.MYSQL) + out = agent.suffixQuery("1=1", comment="-- -") + self.assertTrue(out.startswith("1=1"), msg=out) + self.assertIn("-", out) + + def test_suffix_appends_suffix_no_comment(self): + set_dbms(DBMS.MYSQL) + out = agent.suffixQuery("1=1", suffix="')") + self.assertIn("')", out, msg=out) + + +class TestNullAndCastFieldNoCast(DbmsStateMixin, unittest.TestCase): + def setUp(self): + DbmsStateMixin.setUp(self) + self._noCast = conf.noCast + + def tearDown(self): + conf.noCast = self._noCast + DbmsStateMixin.tearDown(self) + + def test_nocast_returns_field_unchanged(self): + set_dbms(DBMS.MYSQL) + conf.noCast = True + self.assertEqual(agent.nullAndCastField("colname"), "colname") + + def test_cast_present_when_nocast_off(self): + set_dbms(DBMS.MYSQL) + conf.noCast = False + out = agent.nullAndCastField("colname") + self.assertIn("CAST", out.upper(), msg=out) + self.assertIn("colname", out) + + +# --------------------------------------------------------------------------- # +# lib/core/common.py +# --------------------------------------------------------------------------- # + +class TestSmallPredicates(unittest.TestCase): + def test_is_none_value(self): + self.assertTrue(isNoneValue(None)) + self.assertTrue(isNoneValue("None")) + self.assertTrue(isNoneValue("")) + self.assertTrue(isNoneValue([])) + self.assertTrue(isNoneValue(["None", ""])) + self.assertTrue(isNoneValue({})) + self.assertFalse(isNoneValue([2])) + self.assertFalse(isNoneValue("x")) + + def test_is_null_value(self): + self.assertTrue(isNullValue(u"NULL")) + self.assertTrue(isNullValue(u"null")) + self.assertFalse(isNullValue(u"foobar")) + self.assertFalse(isNullValue(5)) + + def test_is_num_pos_str_value(self): + self.assertTrue(isNumPosStrValue(1)) + self.assertTrue(isNumPosStrValue("1")) + self.assertFalse(isNumPosStrValue(0)) + self.assertFalse(isNumPosStrValue("-2")) + self.assertFalse(isNumPosStrValue("100000000000000000000")) + self.assertFalse(isNumPosStrValue("abc")) + + def test_is_number(self): + self.assertTrue(isNumber(1)) + self.assertTrue(isNumber("0")) + self.assertTrue(isNumber("3.14")) + self.assertFalse(isNumber("foobar")) + self.assertFalse(isNumber(None)) + + def test_is_list_like(self): + self.assertTrue(isListLike([1])) + self.assertTrue(isListLike((1,))) + self.assertTrue(isListLike(set([1]))) + self.assertFalse(isListLike("x")) + self.assertFalse(isListLike(5)) + + +class TestValueShaping(unittest.TestCase): + def test_filter_pair_values(self): + self.assertEqual(filterPairValues([[1, 2], [3], 1, [4, 5]]), [[1, 2], [4, 5]]) + self.assertEqual(filterPairValues(None), []) + + def test_filter_list_value(self): + self.assertEqual(filterListValue(["users", "admins", "logs"], r"(users|admins)"), + ["users", "admins"]) + # non-list input returned unchanged + self.assertEqual(filterListValue("notlist", r"x"), "notlist") + # no regex returns input + self.assertEqual(filterListValue(["a"], None), ["a"]) + + def test_filter_none(self): + self.assertEqual(filterNone([1, 2, "", None, 3, 0]), [1, 2, 3, 0]) + + def test_filter_string_value(self): + self.assertEqual(filterStringValue("wzydeadbeef0123#", r"[0-9a-f]"), "deadbeef0123") + + def test_un_arrayize_value(self): + self.assertEqual(unArrayizeValue(["1"]), "1") + self.assertEqual(unArrayizeValue("1"), "1") + self.assertEqual(unArrayizeValue(["1", "2"]), "1") + self.assertEqual(unArrayizeValue([["a", "b"], "c"]), "a") + self.assertIsNone(unArrayizeValue([])) + + def test_flatten_value(self): + self.assertEqual(list(flattenValue([["1"], [["2"], "3"]])), ["1", "2", "3"]) + + def test_arrayize_value(self): + self.assertEqual(arrayizeValue("1"), ["1"]) + self.assertEqual(arrayizeValue(["1"]), ["1"]) + + def test_join_value(self): + self.assertEqual(joinValue(["1", "2"]), "1,2") + self.assertEqual(joinValue("1"), "1") + self.assertEqual(joinValue(["1", None]), "1,None") + + +class TestZeroDepthAndSplit(unittest.TestCase): + def test_zero_depth_search_skips_parens(self): + expr = "SELECT (SELECT id FROM users WHERE 2>1) AS r FROM DUAL" + idx = zeroDepthSearch(expr, " FROM ") + # only the outer top-level FROM is found, not the one inside the subselect + self.assertEqual(len(idx), 1) + self.assertTrue(expr[idx[0]:].startswith(" FROM DUAL")) + + def test_zero_depth_search_ignores_quoted(self): + expr = "a , 'b , c' , d" + # commas inside the quoted literal are not reported + self.assertEqual(len(zeroDepthSearch(expr, ",")), 2) + + def test_split_fields_basic(self): + self.assertEqual(splitFields("foo, bar, max(foo, bar)"), + ["foo", "bar", "max(foo,bar)"]) + + def test_split_fields_quoted(self): + self.assertEqual(splitFields("a, 'b, c', d"), ["a", "'b, c'", "d"]) + + def test_split_fields_custom_delimiter(self): + self.assertEqual(splitFields("a; b; max(c; d)", delimiter=";"), + ["a", "b", "max(c;d)"]) + + +class TestAliasToDbmsEnum(unittest.TestCase): + def test_known_aliases(self): + self.assertEqual(aliasToDbmsEnum("mssql"), DBMS.MSSQL) + self.assertEqual(aliasToDbmsEnum("mysql"), DBMS.MYSQL) + self.assertEqual(aliasToDbmsEnum("postgres"), DBMS.PGSQL) + + def test_unknown_alias_returns_none(self): + self.assertIsNone(aliasToDbmsEnum("definitely_not_a_dbms")) + + def test_empty_returns_none(self): + self.assertIsNone(aliasToDbmsEnum("")) + + +class TestGetPageWordSet(unittest.TestCase): + def test_word_extraction(self): + words = getPageWordSet(u"foobartest") + self.assertEqual(sorted(words), [u"foobar", u"test"]) + + def test_non_string_returns_empty(self): + self.assertEqual(getPageWordSet(None), set()) + + +class TestNormalizeUnicode(unittest.TestCase): + def test_accents_stripped(self): + # normalizeUnicode collapses accented chars to their ASCII base + self.assertEqual(normalizeUnicode(u"éè"), "ee") + + def test_plain_ascii_unchanged(self): + self.assertEqual(normalizeUnicode(u"abc123"), "abc123") + + def test_none_returns_none(self): + self.assertIsNone(normalizeUnicode(None)) + + +class TestResetCookieJar(unittest.TestCase): + """resetCookieJar's clear branch (conf.loadCookies falsy).""" + + def setUp(self): + self._loadCookies = conf.loadCookies + conf.loadCookies = None + + def tearDown(self): + conf.loadCookies = self._loadCookies + + def test_clear_branch(self): + try: + from http.cookiejar import CookieJar + except ImportError: # Python 2 + from cookielib import CookieJar + + jar = CookieJar() + cleared = {"called": False} + + class _Jar(object): + def clear(self): + cleared["called"] = True + + resetCookieJar(_Jar()) + self.assertTrue(cleared["called"]) + # also accepts a real jar without raising + self.assertIsNone(resetCookieJar(jar)) + + +# --------------------------------------------------------------------------- # +# lib/utils/brute.py +# --------------------------------------------------------------------------- # + +import lib.utils.brute as brute +from lib.request import inject +import lib.core.threads as threads_mod +import lib.core.common as common_mod + + +class TestBrute(DbmsStateMixin, unittest.TestCase): + """Drive tableExists / columnExists with all external collaborators stubbed. + + conf.direct=True skips the time/stacked recommendation prompt. checkBooleanExpression, + getFileItems and runThreads are monkeypatched so the check runs synchronously, + deterministically and offline. getPageWordSet is neutralized so the wordlist is + just what the stub returns. + """ + + def setUp(self): + DbmsStateMixin.setUp(self) + self._saved_conf = {k: conf.get(k) for k in + ("direct", "db", "tbl", "threads", "api", "verbose")} + self._choices = kb.choices + self._cachedTables = kb.data.get("cachedTables") + self._cachedColumns = kb.data.get("cachedColumns") + self._brute = kb.brute + self._origPage = kb.originalPage + + # stub the collaborators + self._orig_cbe = inject.checkBooleanExpression + self._orig_brute_cbe = brute.inject.checkBooleanExpression + self._orig_getFileItems = brute.getFileItems + self._orig_runThreads = brute.runThreads + self._orig_getPageWordSet = brute.getPageWordSet + + from lib.core.datatype import AttribDict + kb.choices = AttribDict(keycheck=False) + kb.choices.tableExists = None + kb.choices.columnExists = None + kb.data.cachedTables = {} + kb.data.cachedColumns = {} + kb.brute = AttribDict({"tables": [], "columns": []}) + kb.originalPage = None + + conf.direct = True + conf.db = None + conf.threads = 1 + conf.api = False + conf.verbose = 0 + + # runThreads -> just call the worker once synchronously + def _fakeRunThreads(numThreads, threadFunction, *args, **kwargs): + kb.threadContinue = True + threadFunction() + brute.runThreads = _fakeRunThreads + # no page words injected into the wordlist + brute.getPageWordSet = lambda page: set() + # wordlist file -> small fixed list + brute.getFileItems = lambda *a, **k: ["users", "logs", "secret_t"] + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + kb.choices = self._choices + if self._cachedTables is None: + kb.data.pop("cachedTables", None) + else: + kb.data.cachedTables = self._cachedTables + if self._cachedColumns is None: + kb.data.pop("cachedColumns", None) + else: + kb.data.cachedColumns = self._cachedColumns + kb.brute = self._brute + kb.originalPage = self._origPage + brute.inject.checkBooleanExpression = self._orig_brute_cbe + brute.getFileItems = self._orig_getFileItems + brute.runThreads = self._orig_runThreads + brute.getPageWordSet = self._orig_getPageWordSet + DbmsStateMixin.tearDown(self) + + def test_table_exists_collects_true_results(self): + set_dbms(DBMS.MYSQL) + + def _cbe(expression, expectingNone=True): + # initial sanity probe (random table) -> must be False, otherwise the + # function raises SqlmapDataException; then only "users" exists. + return "users" in expression + brute.inject.checkBooleanExpression = _cbe + + result = brute.tableExists("/nonexistent/tables.txt") + # cachedTables keyed by conf.db (None here) holds the discovered table + self.assertIn(None, result) + self.assertIn("users", result[None]) + self.assertNotIn("logs", result.get(None, [])) + # also recorded in kb.brute.tables as (db, table) + self.assertIn((None, "users"), kb.brute.tables) + + def test_table_exists_invalid_results_raises(self): + from lib.core.exception import SqlmapDataException + set_dbms(DBMS.MYSQL) + # the initial random-table probe returns True -> "invalid results" guard + brute.inject.checkBooleanExpression = lambda *a, **k: True + with self.assertRaises(SqlmapDataException): + brute.tableExists("/nonexistent/tables.txt") + + def test_column_exists_requires_table(self): + from lib.core.exception import SqlmapMissingMandatoryOptionException + set_dbms(DBMS.MYSQL) + conf.tbl = None + # the sanity probe is False so we reach the missing-table guard + brute.inject.checkBooleanExpression = lambda *a, **k: False + with self.assertRaises(SqlmapMissingMandatoryOptionException): + brute.columnExists("/nonexistent/columns.txt") + + def test_column_exists_collects_and_types(self): + set_dbms(DBMS.MYSQL) + conf.tbl = "users" + brute.getFileItems = lambda *a, **k: ["id", "name"] + + calls = {"n": 0} + + def _cbe(expression, expectingNone=True): + calls["n"] += 1 + # initial sanity probe uses two random strings (no real column name) + if "id" not in expression and "name" not in expression: + return False + # MySQL numeric-type follow-up: `not checkBooleanExpression(... REGEXP '[^0-9]')`. + # 'id' is numeric (no non-digit chars => probe False => numeric); + # 'name' is non-numeric (has non-digit chars => probe True => non-numeric). + if "REGEXP" in expression: + return "name" in expression + # plain existence check (EXISTS(SELECT FROM )) => both columns exist + return True + brute.inject.checkBooleanExpression = _cbe + + result = brute.columnExists("/nonexistent/columns.txt") + self.assertIn(None, result) + cols = result[None]["users"] + # column names are run through safeSQLIdentificatorNaming, so the MySQL + # reserved word "name" comes back backtick-quoted + from lib.core.common import safeSQLIdentificatorNaming, getText + self.assertEqual(cols.get(getText(safeSQLIdentificatorNaming("id"))), "numeric") + self.assertEqual(cols.get(getText(safeSQLIdentificatorNaming("name"))), "non-numeric") + + def test_add_page_text_words_filters(self): + # restore the real getPageWordSet for this one and drive it directly + brute.getPageWordSet = self._orig_getPageWordSet + kb.originalPage = u"admin password 1abc xy verylongword" + words = brute._addPageTextWords() + # words <= 2 chars or starting with a digit are dropped + self.assertIn("admin", words) + self.assertIn("password", words) + self.assertNotIn("xy", words) + self.assertNotIn("1abc", words) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_databases_enum.py b/tests/test_databases_enum.py new file mode 100644 index 000000000..323a4a728 --- /dev/null +++ b/tests/test_databases_enum.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Unit tests for the enumeration methods of plugins/generic/databases.py. + +The injection layer (lib.request.inject.getValue) is mocked so no network or +live DBMS is required; each test drives a single enumeration method down a +specific branch (conf.direct "inband" path or the isInferenceAvailable() blind +path) and asserts on the returned value / kb.data.cached* state. + +CRITICAL: every test restores conf.*, the patched dbmod.inject.getValue, and the +mutated kb.data flags in tearDown so global state does not leak into the rest of +the suite. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms + +bootstrap() + +from lib.core.data import conf, kb +from lib.core.enums import EXPECTED, PAYLOAD +import plugins.generic.databases as dbmod +from plugins.generic.databases import Databases + +# Databases.forceDbmsEnum() is supplied at runtime by the concrete dbms fingerprint +# plugin mixin (plugins/dbms/*/fingerprint.py); a bare Databases() instance lacks it, +# so neutralize it for the duration of these tests. Restored in tearDown via the saved ref. +_NOOP = lambda self: None + + +class _BaseEnumTest(unittest.TestCase): + """Shared setup/teardown that snapshots and restores all touched global state.""" + + # conf keys every test may read/write + _CONF_KEYS = ("direct", "technique", "db", "tbl", "col", "exclude", + "getComments", "excludeSysDbs", "search", "freshQueries") + + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in self._CONF_KEYS} + self._saved_getValue = dbmod.inject.getValue + self._saved_injection_data = kb.injection.data + self._saved_has_is = kb.data.get("has_information_schema") + # the inference paths of getTables/getColumns set kb.hintValue as a side effect; + # snapshot it so we never leak a stale hint into other test files (e.g. the + # inference engine's tryHint(), whose setUp does not reset it). + self._saved_hintValue = kb.get("hintValue") + self._saved_forceDbmsEnum = getattr(Databases, "forceDbmsEnum", None) + Databases.forceDbmsEnum = _NOOP + + # sane defaults shared by most tests + conf.getComments = False + conf.excludeSysDbs = False + conf.exclude = None + conf.search = False + conf.freshQueries = False + conf.col = None + kb.data.has_information_schema = True + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + dbmod.inject.getValue = self._saved_getValue + kb.injection.data = self._saved_injection_data + kb.data.has_information_schema = self._saved_has_is + kb.hintValue = self._saved_hintValue + if self._saved_forceDbmsEnum is not None: + Databases.forceDbmsEnum = self._saved_forceDbmsEnum + else: + try: + del Databases.forceDbmsEnum + except AttributeError: + pass + + # helpers ----------------------------------------------------------------- + + def _fresh(self): + """Return a Databases() instance with every cache reset to empty.""" + d = Databases() + kb.data.currentDb = "" + kb.data.cachedDbs = [] + kb.data.cachedTables = {} + kb.data.cachedColumns = {} + kb.data.cachedCounts = {} + kb.data.cachedStatements = [] + kb.data.cachedProcedures = [] + return d + + def _enable_inference(self): + """Take the blind inference branch: conf.direct off, a BOOLEAN technique present.""" + conf.direct = False + conf.technique = None + kb.injection.data = {PAYLOAD.TECHNIQUE.BOOLEAN: {"title": "AND boolean-based blind"}} + + +class TestGetCurrentDb(_BaseEnumTest): + def test_current_db_mysql(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: "testdb" + self.assertEqual(d.getCurrentDb(), "testdb") + self.assertEqual(kb.data.currentDb, "testdb") + + def test_current_db_cached(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + kb.data.currentDb = "already" + + def _boom(*a, **k): + raise AssertionError("inject.getValue must not be called when currentDb is cached") + + dbmod.inject.getValue = _boom + self.assertEqual(d.getCurrentDb(), "already") + + def test_current_db_oracle_schema_warning_branch(self): + # Oracle takes the schema-name warning branch; result still returned. + set_dbms("Oracle") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: "SYSTEM" + self.assertEqual(d.getCurrentDb(), "SYSTEM") + + +class TestGetDbs(_BaseEnumTest): + def test_get_dbs_direct_mysql(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: [["information_schema"], ["mysql"], ["testdb"]] + result = d.getDbs() + self.assertEqual(sorted(result), ["information_schema", "mysql", "testdb"]) + self.assertIn("testdb", kb.data.cachedDbs) + + def test_get_dbs_cached_short_circuit(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + kb.data.cachedDbs = ["pre", "cached"] + + def _boom(*a, **k): + raise AssertionError("must not query when cachedDbs is populated") + + dbmod.inject.getValue = _boom + self.assertEqual(d.getDbs(), ["pre", "cached"]) + + def test_get_dbs_direct_pgsql_schema_branch(self): + set_dbms("PostgreSQL") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: [["public"], ["information_schema"]] + result = d.getDbs() + self.assertEqual(sorted(result), ["information_schema", "public"]) + + def test_get_dbs_mysql_no_information_schema(self): + # MySQL < 5: query2 / count2 branch; still inband under conf.direct. + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + kb.data.has_information_schema = False + dbmod.inject.getValue = lambda query, *a, **k: [["mysql"], ["app"]] + result = d.getDbs() + self.assertEqual(sorted(result), ["app", "mysql"]) + + def test_get_dbs_inference(self): + set_dbms("MySQL") + self._enable_inference() + d = self._fresh() + + names = ["alpha", "beta", "gamma"] + state = {"i": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(names)) + val = names[state["i"]] + state["i"] += 1 + return [val] + + dbmod.inject.getValue = gv + result = d.getDbs() + self.assertEqual(sorted(result), sorted(names)) + + def test_get_dbs_fallback_to_current(self): + # No dbs returned inband -> falls back to current database. + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + state = {"n": 0} + + def gv(query, *a, **k): + state["n"] += 1 + if state["n"] == 1: + return None # getDbs inband: nothing + return "fallbackdb" # getCurrentDb + + dbmod.inject.getValue = gv + result = d.getDbs() + self.assertEqual(result, ["fallbackdb"]) + + +class TestGetTables(_BaseEnumTest): + def test_get_tables_direct_mysql(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = None + dbmod.inject.getValue = lambda query, *a, **k: [["testdb", "users"], ["testdb", "posts"]] + result = d.getTables() + self.assertIn("testdb", result) + self.assertEqual(sorted(result["testdb"]), ["posts", "users"]) + + def test_get_tables_cached_short_circuit(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + kb.data.cachedTables = {"db": ["t1"]} + + def _boom(*a, **k): + raise AssertionError("must not query when cachedTables is populated") + + dbmod.inject.getValue = _boom + self.assertEqual(d.getTables(), {"db": ["t1"]}) + + def test_get_tables_direct_pgsql(self): + set_dbms("PostgreSQL") + conf.direct = True + d = self._fresh() + conf.db = "public" + conf.tbl = None + dbmod.inject.getValue = lambda query, *a, **k: [["public", "accounts"]] + result = d.getTables() + self.assertEqual(result.get("public"), ["accounts"]) + + def test_get_tables_inference(self): + set_dbms("MySQL") + self._enable_inference() + d = self._fresh() + conf.db = "testdb" + conf.tbl = None + + tables = ["t_a", "t_b"] + state = {"i": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(tables)) + val = tables[state["i"] % len(tables)] + state["i"] += 1 + return [val] + + dbmod.inject.getValue = gv + result = d.getTables() + self.assertIn("testdb", result) + self.assertEqual(sorted(result["testdb"]), sorted(tables)) + + +class TestGetColumns(_BaseEnumTest): + def _run_direct(self, dbms, db, tbl, rows): + set_dbms(dbms) + conf.direct = True + d = self._fresh() + conf.db = db + conf.tbl = tbl + dbmod.inject.getValue = lambda query, *a, **k: rows + return d.getColumns() + + def test_columns_direct_mysql(self): + result = self._run_direct("MySQL", "testdb", "users", [["id", "int"], ["age", "int"]]) + self.assertIn("testdb", result) + cols = result["testdb"]["users"] + self.assertEqual(cols.get("id"), "int") + self.assertEqual(cols.get("age"), "int") + + def test_columns_direct_pgsql(self): + result = self._run_direct("PostgreSQL", "public", "users", [["id", "integer"]]) + self.assertEqual(result["public"]["users"].get("id"), "integer") + + def test_columns_direct_oracle_uppercase(self): + # Oracle is an UPPER_CASE dbms: conf.db/tbl get upcased internally. + result = self._run_direct("Oracle", "system", "users", [["ID", "NUMBER"]]) + # Oracle quotes the identifier ("SYSTEM"); assert the column landed regardless. + flat = {} + for tables in result.values(): + for cols in tables.values(): + flat.update(cols) + self.assertEqual(flat.get("ID"), "NUMBER") + + def test_columns_direct_mssql(self): + result = self._run_direct("Microsoft SQL Server", "master", "users", [["id", "int"]]) + # MSSQL wraps the db identifier in [brackets]; assert the column landed. + flat = {} + for tables in result.values(): + for cols in tables.values(): + flat.update(cols) + self.assertEqual(flat.get("id"), "int") + + def test_columns_only_names(self): + # onlyColNames is ONLY read in the inference branch (the INBAND path + # ignores it), so drive the blind inference path like + # test_columns_inference_mysql but with onlyColNames=True. The flag must + # SUPPRESS the type lookup: each column's value lands as None instead of + # the real type. Asserting cols.get("id") is None proves the flag took + # effect (otherwise the type query would run and return "int"). + set_dbms("MySQL") + self._enable_inference() + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users" + + colnames = ["id", "name"] + state = {"i": 0} + type_queries = {"n": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(colnames)) + # With onlyColNames the second-stage type query (blind.query2, which + # selects column_type) must NEVER be issued. + if "column_type" in query.lower(): + type_queries["n"] += 1 + return ["int"] + val = colnames[state["i"] % len(colnames)] + state["i"] += 1 + return [val] + + dbmod.inject.getValue = gv + result = d.getColumns(onlyColNames=True) + cols = result["testdb"]["users"] + # both column names enumerated... + self.assertEqual(len(cols), len(colnames)) + self.assertIn("id", cols) + # ...but their types were suppressed (None), and no type query ran. + self.assertIsNone(cols.get("id")) + self.assertEqual(type_queries["n"], 0) + + def test_columns_inference_mysql(self): + set_dbms("MySQL") + self._enable_inference() + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users" + + colnames = ["id", "name"] + state = {"i": 0, "names": True} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(colnames)) + # alternate: column name then its type + if state["names"]: + val = colnames[state["i"] % len(colnames)] + state["i"] += 1 + state["names"] = False + return [val] + else: + state["names"] = True + return ["int"] + + dbmod.inject.getValue = gv + result = d.getColumns() + self.assertIn("testdb", result) + cols = result["testdb"]["users"] + # both columns enumerated (reserved words like "name" get quoted, so count, not exact keys) + self.assertEqual(len(cols), len(colnames)) + self.assertEqual(cols.get("id"), "int") + + +class TestGetCount(_BaseEnumTest): + def test_count_single_table_mysql(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users" + dbmod.inject.getValue = lambda query, *a, **k: "42" + result = d.getCount() + self.assertEqual(result, {"testdb": {42: ["users"]}}) + + def test_count_dotted_table_splits_db(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = None + conf.tbl = "shop.orders" + dbmod.inject.getValue = lambda query, *a, **k: "7" + result = d.getCount() + self.assertEqual(result, {"shop": {7: ["orders"]}}) + + def test_count_multiple_tables(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users,posts" + counts = {"users": "3", "posts": "5"} + + def gv(query, *a, **k): + # the table name appears in the FROM clause of the generated query + for t, c in counts.items(): + if t in query: + return c + return "0" + + dbmod.inject.getValue = gv + result = d.getCount() + self.assertIn("testdb", result) + self.assertIn("users", result["testdb"][3]) + self.assertIn("posts", result["testdb"][5]) + + +class TestGetStatements(_BaseEnumTest): + def test_statements_direct_mysql(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: [["SELECT 1"], ["SELECT 2"]] + result = d.getStatements() + self.assertEqual(sorted(result), ["SELECT 1", "SELECT 2"]) + + def test_statements_direct_pgsql(self): + set_dbms("PostgreSQL") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: [["SELECT now()"]] + result = d.getStatements() + self.assertEqual(result, ["SELECT now()"]) + + def test_statements_inference(self): + set_dbms("PostgreSQL") + self._enable_inference() + d = self._fresh() + stmts = ["SELECT a", "SELECT b"] + state = {"i": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(stmts)) + val = stmts[state["i"] % len(stmts)] + state["i"] += 1 + return [val] + + dbmod.inject.getValue = gv + result = d.getStatements() + self.assertEqual(sorted(result), sorted(stmts)) + + +class TestGetSchema(_BaseEnumTest): + def test_schema_mysql(self): + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = None + conf.col = None + state = {"n": 0} + + def gv(query, *a, **k): + state["n"] += 1 + if state["n"] == 1: + # getTables call + return [["testdb", "users"]] + # getColumns call + return [["id", "int"]] + + dbmod.inject.getValue = gv + result = d.getSchema() + self.assertIn("testdb", result) + self.assertIn("users", result["testdb"]) + self.assertEqual(result["testdb"]["users"].get("id"), "int") + + +class TestGetProcedures(_BaseEnumTest): + def test_procedures_direct_pgsql(self): + set_dbms("PostgreSQL") + conf.direct = True + d = self._fresh() + dbmod.inject.getValue = lambda query, *a, **k: [["proc_a"], ["proc_b"]] + result = d.getProcedures() + self.assertEqual(sorted(result), ["proc_a", "proc_b"]) + + def test_procedures_inference_mysql(self): + set_dbms("MySQL") + self._enable_inference() + d = self._fresh() + procs = ["sp_one", "sp_two"] + state = {"i": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(procs)) + val = procs[state["i"] % len(procs)] + state["i"] += 1 + return [val] + + dbmod.inject.getValue = gv + result = d.getProcedures() + self.assertEqual(sorted(result), sorted(procs)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dbms_enum.py b/tests/test_dbms_enum.py new file mode 100644 index 000000000..8188f3c0e --- /dev/null +++ b/tests/test_dbms_enum.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +DBMS-specific enumeration overrides (plugins/dbms//enumeration.py), +driven through each full DBMS handler with the injection layer mocked, so the +dialect-specific table/column discovery paths run without a live target. The +in-band (UNION/error/direct) branch is taken via conf.direct=True and +inject.getValue is stubbed with canned result rows. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.data import conf, kb +from lib.core.enums import EXPECTED + + +class _EnumBase(unittest.TestCase): + """Snapshot/restore the global state these enumerators mutate.""" + module = None # the enumeration module whose inject.getValue we patch + + def setUp(self): + self._direct = conf.direct + self._db = conf.db + self._gv = self.module.inject.getValue + self._cachedTables = kb.data.get("cachedTables") + self._cachedColumns = kb.data.get("cachedColumns") + conf.direct = True + kb.data.cachedTables = {} + kb.data.cachedColumns = {} + + def tearDown(self): + conf.direct = self._direct + conf.db = self._db + self.module.inject.getValue = self._gv + kb.data.cachedTables = self._cachedTables + kb.data.cachedColumns = self._cachedColumns + + +class TestMSSQLServerEnum(_EnumBase): + import plugins.dbms.mssqlserver.enumeration as module + + def _handler(self): + from plugins.dbms.mssqlserver import MSSQLServerMap + set_dbms("Microsoft SQL Server") + return MSSQLServerMap() + + def test_get_tables(self): + # one database (conf.db), single-column rows: getTables keys the cache by + # the db loop variable and stores the rows run through + # arrayize -> unArrayize -> safeSQLIdentificatorNaming -> sorted(). + conf.db = "appdb" + self.module.inject.getValue = lambda q, *a, **k: ( + 3 if k.get("expected") == EXPECTED.INT else [["users"], ["products"], ["customers"]] + ) + self._handler().getTables() + tables = kb.data.cachedTables + self.assertEqual(list(tables.keys()), ["appdb"]) + stored = tables["appdb"] + # value is a real sorted list (the final sort step), not an echo of input + self.assertEqual(stored, sorted(stored)) + # MSSQL qualifies bare names with the dbo schema; assert exact membership + self.assertIn("dbo.users", stored) + self.assertEqual(stored, ["dbo.customers", "dbo.products", "dbo.users"]) + + def test_get_tables_multiple_dbs(self): + # exercise the per-database keying with two DBs (conf.db = "a,b"): each db + # in the loop gets its OWN sorted table list. Rows are single-column; + # unArrayizeValue collapses each 1-tuple row to the scalar table name. + conf.db = "appdb,salesdb" + + def getValue(q, *a, **k): + if k.get("expected") == EXPECTED.INT: + return 3 + # the query carries the db name (%s substituted); route per database + if "appdb" in q: + return [["users"], ["sessions"], ["accounts"]] + return [["orders"], ["invoices"]] + + self.module.inject.getValue = getValue + self._handler().getTables() + tables = kb.data.cachedTables + # exactly the two requested databases, each mapped to its own sorted list + self.assertEqual(sorted(tables.keys()), ["appdb", "salesdb"]) + self.assertEqual(tables["appdb"], ["dbo.accounts", "dbo.sessions", "dbo.users"]) + self.assertEqual(tables["salesdb"], ["dbo.invoices", "dbo.orders"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dbms_enum_a.py b/tests/test_dbms_enum_a.py new file mode 100644 index 000000000..4c9948fd1 --- /dev/null +++ b/tests/test_dbms_enum_a.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +DBMS-specific enumeration overrides for Oracle, PostgreSQL, MySQL and SQLite +(plugins/dbms//enumeration.py), driven through each full DBMS handler with +the injection layer mocked, so the dialect-specific discovery paths run without a +live target. The in-band (UNION/error/direct) branch is taken via conf.direct=True +and inject.getValue is stubbed with canned result rows. + +Companion to tests/test_dbms_enum.py (which covers Microsoft SQL Server). +""" + +import importlib +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.common import Backend +from lib.core.data import conf, kb +from lib.core.enums import EXPECTED +from lib.core.exception import SqlmapUnsupportedFeatureException + + +class _EnumBase(unittest.TestCase): + """Snapshot/restore the global state these enumerators mutate. + + Other tests in the suite depend on clean globals (a leaked kb.hintValue + breaks test_inference_engine; a leaked forced DBMS breaks others), so every + knob touched here is captured in setUp and put back in tearDown. + """ + + # the enumeration module whose inject.getValue we patch (overridden per DBMS) + module = None + + def setUp(self): + # conf knobs + self._direct = conf.direct + self._batch = conf.batch + self._user = conf.user + self._db = conf.get("db") + self._tbl = conf.get("tbl") + self._exclude = conf.get("exclude") + + # injection layer (some override modules - e.g. SQLite/PostgreSQL - do not + # import inject because their overrides return constants without querying) + self._has_inject = hasattr(self.module, "inject") + if self._has_inject: + self._gv = self.module.inject.getValue + + # kb.data cached* containers + self._cachedTables = kb.data.get("cachedTables") + self._cachedColumns = kb.data.get("cachedColumns") + self._cachedDbs = kb.data.get("cachedDbs") + self._cachedUsers = kb.data.get("cachedUsers") + self._cachedUsersRoles = kb.data.get("cachedUsersRoles") + self._cachedUsersPrivileges = kb.data.get("cachedUsersPrivileges") + self._has_information_schema = kb.data.get("has_information_schema") + + # state other tests are sensitive to + self._hintValue = kb.hintValue + self._injectionData = kb.injection.data + self._forcedDbms = Backend.getForcedDbms() + self._stickyDBMS = kb.stickyDBMS + + # avoid readInput EOFError flakiness and interactive prompts + conf.direct = True + conf.batch = True + + def tearDown(self): + conf.direct = self._direct + conf.batch = self._batch + conf.user = self._user + conf.db = self._db + conf.tbl = self._tbl + conf.exclude = self._exclude + + if self._has_inject: + self.module.inject.getValue = self._gv + + kb.data.cachedTables = self._cachedTables + kb.data.cachedColumns = self._cachedColumns + kb.data.cachedDbs = self._cachedDbs + kb.data.cachedUsers = self._cachedUsers + kb.data.cachedUsersRoles = self._cachedUsersRoles + kb.data.cachedUsersPrivileges = self._cachedUsersPrivileges + kb.data.has_information_schema = self._has_information_schema + + kb.hintValue = self._hintValue + kb.injection.data = self._injectionData + kb.stickyDBMS = self._stickyDBMS + if self._forcedDbms is not None: + Backend.forceDbms(self._forcedDbms) + else: + kb.forcedDbms = None + + +class TestOracleEnum(_EnumBase): + module = importlib.import_module("plugins.dbms.oracle.enumeration") + + def _handler(self): + from plugins.dbms.oracle import OracleMap + set_dbms("Oracle") + return OracleMap() + + def test_get_roles(self): + # rows are [GRANTEE, GRANTED_ROLE]; first column is the user, the rest roles + conf.user = None + kb.data.cachedUsersRoles = {} + self.module.inject.getValue = lambda q, *a, **k: [ + ["SYS", "DBA"], ["SYS", "CONNECT"], ["SCOTT", "RESOURCE"] + ] + roles, areAdmins = self._handler().getRoles() + self.assertIn("SYS", roles) + self.assertIn("SCOTT", roles) + self.assertEqual(set(roles["SYS"]), {"DBA", "CONNECT"}) + # DBA implies administrator + self.assertIn("SYS", areAdmins) + + def test_get_roles_filtered_by_user(self): + # conf.user populates a WHERE clause; canned rows still drive the parse + conf.user = "SCOTT" + kb.data.cachedUsersRoles = {} + self.module.inject.getValue = lambda q, *a, **k: [["SCOTT", "RESOURCE"]] + roles, _ = self._handler().getRoles() + self.assertEqual(list(roles.keys()), ["SCOTT"]) + self.assertEqual(roles["SCOTT"], ["RESOURCE"]) + + def test_get_roles_multiple_roles_per_user(self): + # a user appearing across several rows accumulates all granted roles + conf.user = None + kb.data.cachedUsersRoles = {} + self.module.inject.getValue = lambda q, *a, **k: [ + ["APP", "CONNECT"], ["APP", "RESOURCE"], ["APP", "CREATE SESSION"] + ] + roles, _ = self._handler().getRoles() + self.assertEqual( + set(roles["APP"]), {"CONNECT", "RESOURCE", "CREATE SESSION"} + ) + + +class TestPostgreSQLEnum(_EnumBase): + module = importlib.import_module("plugins.dbms.postgresql.enumeration") + + def _handler(self): + from plugins.dbms.postgresql import PostgreSQLMap + set_dbms("PostgreSQL") + return PostgreSQLMap() + + def test_get_hostname_unsupported(self): + # PostgreSQL overrides getHostname purely to warn; it returns None + self.assertIsNone(self._handler().getHostname()) + + +class TestMySQLEnum(_EnumBase): + # MySQL's enumeration.py adds no overrides (it is a bare `pass`); cover the + # generic discovery path through the full MySQL handler instead. + module = importlib.import_module("plugins.generic.enumeration") + + def _handler(self): + from plugins.dbms.mysql import MySQLMap + set_dbms("MySQL") + return MySQLMap() + + def test_get_dbs(self): + conf.db = None + kb.data.cachedDbs = [] + kb.data.has_information_schema = True + self.module.inject.getValue = lambda q, *a, **k: ( + 3 if k.get("expected") == EXPECTED.INT + else [["information_schema"], ["testdb"], ["mysql"]] + ) + dbs = self._handler().getDbs() + self.assertIn("testdb", dbs) + self.assertEqual(set(kb.data.cachedDbs), set(dbs)) + + +class TestSQLiteEnum(_EnumBase): + module = importlib.import_module("plugins.dbms.sqlite.enumeration") + + def _handler(self): + from plugins.dbms.sqlite import SQLiteMap + set_dbms("SQLite") + return SQLiteMap() + + def test_unsupported_simple_overrides(self): + # SQLite overrides these to a warning + an empty/neutral return value + h = self._handler() + self.assertIsNone(h.getCurrentUser()) + self.assertIsNone(h.getCurrentDb()) + self.assertIsNone(h.getHostname()) + self.assertEqual(h.getUsers(), []) + self.assertEqual(h.getDbs(), []) + self.assertEqual(h.searchDb(), []) + self.assertEqual(h.getStatements(), []) + self.assertEqual(h.getPasswordHashes(), {}) + self.assertEqual(h.getPrivileges(), {}) + + def test_is_dba_always_true(self): + # on SQLite the current user is treated as having all privileges + self.assertTrue(self._handler().isDba()) + + def test_search_column_raises(self): + with self.assertRaises(SqlmapUnsupportedFeatureException): + self._handler().searchColumn() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dbms_enum_b.py b/tests/test_dbms_enum_b.py new file mode 100644 index 000000000..b0622366d --- /dev/null +++ b/tests/test_dbms_enum_b.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Second batch of DBMS-specific enumeration override tests (companion to +tests/test_dbms_enum.py, which covers Microsoft SQL Server getTables). + +Each test drives a FULL per-DBMS handler (the *Map class in +plugins/dbms//__init__.py) with the injection layer mocked, so the +dialect-specific table/column/user/privilege discovery paths run without a live +target, network, or DBMS. The in-band (UNION/error/direct) branch is taken via +conf.direct=True; conf.batch=True avoids interactive prompts. + +Covered here: + * Sybase - getUsers, getDbs, getTables, getColumns, getPrivileges, + searchDb/searchTable/searchColumn, getHostname, getStatements + * SAP MaxDB - getDbs, getTables, getColumns, getPrivileges, + getPasswordHashes, getHostname, getStatements + * Microsoft SQL Server - getPrivileges, searchTable, searchColumn + (getTables already covered by test_dbms_enum.py) + * IBM DB2 - getPasswordHashes, getStatements + * Informix - searchDb, searchTable, searchColumn, getStatements + * Firebird - getDbs, getPasswordHashes, searchDb, getHostname, getStatements + * HSQLDB - getBanner, getPrivileges, getHostname, getStatements, + getCurrentDb + +Sybase/MaxDB enumeration goes through lib.utils.pivotdumptable.pivotDumpTable +(imported into the module namespace), so for those we mock that wrapper - it is +part of the same data-retrieval layer - and mock inject.getValue elsewhere. + +stdlib unittest only (no pytest / no pip); works on Python 2.7 and 3.x. +""" + +import importlib +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.data import conf, kb +from lib.core.common import Backend +from lib.core.enums import EXPECTED +from lib.request import inject + + +def _fresh_cached(): + kb.data.cachedDbs = [] + kb.data.cachedTables = {} + kb.data.cachedColumns = {} + kb.data.cachedUsers = [] + kb.data.cachedUsersPrivileges = {} + kb.data.cachedCounts = {} + kb.data.cachedStatements = [] + kb.data.banner = None + + +class _NoOpDumper(object): + """Swallow every dumper call so search methods don't emit/prompt.""" + + def __getattr__(self, name): + return lambda *a, **k: None + + +def _handler(display_name, dirname): + """Instantiate the full *Map handler for the given DBMS.""" + set_dbms(display_name) + main = importlib.import_module("plugins.dbms.%s" % dirname) + cls = [getattr(main, n) for n in dir(main) if n.endswith("Map")][0] + return cls() + + +class _EnumBase(unittest.TestCase): + """Snapshot/restore every global these enumerators mutate.""" + + # subclasses set these + display_name = None + dirname = None + + def setUp(self): + # config snapshot + self._direct = conf.direct + self._batch = conf.batch + self._db = conf.db + self._tbl = conf.tbl + self._col = conf.col + self._user = conf.user + self._exclude = conf.exclude + self._search = conf.search + self._getBanner = conf.getBanner + self._excludeSysDbs = conf.excludeSysDbs + self._dumper = conf.get("dumper") + + # kb snapshot + self._cached = {k: kb.data.get(k) for k in ( + "cachedDbs", "cachedTables", "cachedColumns", "cachedUsers", + "cachedUsersPrivileges", "cachedCounts", "cachedStatements", "banner", + )} + self._hintValue = kb.hintValue + self._injectionData = kb.injection.data + self._currentDb = kb.data.get("currentDb") + self._hasIS = kb.data.get("has_information_schema") + + # injection layer snapshot + self._gv = inject.getValue + self._cbe = getattr(inject, "checkBooleanExpression", None) + + # baseline config the in-band/non-interactive paths need + conf.direct = True + conf.batch = True + kb.data.has_information_schema = True + _fresh_cached() + + # restore the chosen DBMS for every test + self.handler = _handler(self.display_name, self.dirname) + # the enumeration module whose pivotDumpTable some tests stub + self.em = importlib.import_module("plugins.dbms.%s.enumeration" % self.dirname) + + def tearDown(self): + conf.direct = self._direct + conf.batch = self._batch + conf.db = self._db + conf.tbl = self._tbl + conf.col = self._col + conf.user = self._user + conf.exclude = self._exclude + conf.search = self._search + conf.getBanner = self._getBanner + conf.excludeSysDbs = self._excludeSysDbs + conf.dumper = self._dumper + + for k, v in self._cached.items(): + kb.data[k] = v + kb.hintValue = self._hintValue + kb.injection.data = self._injectionData + kb.data.currentDb = self._currentDb + kb.data.has_information_schema = self._hasIS + + inject.getValue = self._gv + if self._cbe is not None: + inject.checkBooleanExpression = self._cbe + if hasattr(self.em, "pivotDumpTable"): + # restore the pristine reference from the wrapper module + import lib.utils.pivotdumptable as _pdt + self.em.pivotDumpTable = _pdt.pivotDumpTable + + +# --------------------------------------------------------------------------- +# Sybase +# --------------------------------------------------------------------------- + +class TestSybaseEnum(_EnumBase): + display_name = "Sybase" + dirname = "sybase" + + def _pivot(self, *value_lists): + """Make em.pivotDumpTable return canned (entries, lengths) per call. + + Each successive call pops the next mapping of {colName: [values]}. + """ + calls = list(value_lists) + + def fake(table, colList, count=None, blind=True, alias=None): + mapping = calls.pop(0) if calls else {} + entries = {} + lengths = {} + for col in colList: + vals = mapping.get(col.split(".")[-1], []) + entries[col] = list(vals) + lengths[col] = 0 + return entries, lengths + + self.em.pivotDumpTable = fake + + def test_get_users(self): + self._pivot({"name": ["sa", "guest"]}) + users = self.handler.getUsers() + self.assertIn("sa", users) + self.assertIn("guest", users) + + def test_get_dbs(self): + self._pivot({"name": ["master", "model"]}) + dbs = self.handler.getDbs() + self.assertEqual(sorted(dbs), ["master", "model"]) + + def test_get_tables(self): + conf.db = "testdb" + self._pivot({"name": ["users", "logs"]}) + tables = self.handler.getTables() + self.assertIn("testdb", tables) + self.assertEqual(sorted(tables["testdb"]), ["logs", "users"]) + + def test_get_columns(self): + conf.db = "testdb" + conf.tbl = "users" + # column pivot returns name + usertype: REAL Sybase numeric type ids that + # getColumns resolves through SYBASE_TYPES (7 -> "int", 2 -> "varchar"). + from lib.core.dicts import SYBASE_TYPES + self._pivot({"name": ["id", "name"], "usertype": ["7", "2"]}) + cols = self.handler.getColumns() + self.assertIn("testdb", cols) + # table key is identifier-normalized (may be schema-qualified) + tbls = cols["testdb"] + self.assertTrue(any("users" in t for t in tbls)) + colset = list(tbls.values())[0] + # the VALUE is the resolved type name, not the raw usertype number: + # proves the SYBASE_TYPES numeric->name mapping actually ran. + self.assertEqual(colset["id"], SYBASE_TYPES[7]) # "int" + self.assertEqual(colset["name"], SYBASE_TYPES[2]) # "varchar" + + def test_get_privileges(self): + # getPrivileges -> getUsers (pivot) then isDba (checkBooleanExpression). + # Drive the admin-set branch BOTH ways via the isDba oracle so the result + # is not forced by a constant-True stub. + conf.user = None + + # oracle True: every user is flagged DBA -> admins == all users + self._pivot({"name": ["sa", "guest"]}) + inject.checkBooleanExpression = lambda *a, **k: True + privs, admins = self.handler.getPrivileges() + self.assertIn("sa", privs) # users still enumerated as privilege keys + self.assertIn("guest", privs) + self.assertEqual(admins, set(["sa", "guest"])) + + # oracle False: nobody is a DBA -> admins is empty, but users still listed + _fresh_cached() + self._pivot({"name": ["sa", "guest"]}) + inject.checkBooleanExpression = lambda *a, **k: False + privs, admins = self.handler.getPrivileges() + self.assertIn("sa", privs) + self.assertEqual(admins, set()) + + def test_search_not_implemented(self): + # these intentionally return [] with a warning on Sybase + self.assertEqual(self.handler.searchDb(), []) + self.assertEqual(self.handler.searchTable(), []) + self.assertEqual(self.handler.searchColumn(), []) + + def test_get_hostname(self): + # not possible on Sybase; just must not raise + self.assertIsNone(self.handler.getHostname()) + + def test_get_statements(self): + self.assertEqual(self.handler.getStatements(), []) + + +# --------------------------------------------------------------------------- +# SAP MaxDB +# --------------------------------------------------------------------------- + +class TestMaxDBEnum(_EnumBase): + display_name = "SAP MaxDB" + dirname = "maxdb" + + def _pivot(self, *value_lists): + calls = list(value_lists) + + def fake(table, colList, count=None, blind=True, alias=None): + mapping = calls.pop(0) if calls else {} + entries = {} + lengths = {} + for col in colList: + vals = mapping.get(col.split(".")[-1], []) + entries[col] = list(vals) + lengths[col] = 0 + return entries, lengths + + self.em.pivotDumpTable = fake + + def test_get_dbs(self): + self._pivot({"schemaname": ["SYSTEM", "DOMAIN"]}) + dbs = self.handler.getDbs() + self.assertEqual(sorted(dbs), ["DOMAIN", "SYSTEM"]) + + def test_get_tables(self): + conf.db = "SYSTEM" + self._pivot({"tablename": ["USERS", "TABLES"]}) + tables = self.handler.getTables() + # db key is identifier-normalized (uppercase names get quoted) + self.assertEqual(len(tables), 1) + tbls = list(tables.values())[0] + self.assertEqual(sorted(tbls), ["TABLES", "USERS"]) + + def test_get_columns(self): + conf.db = "SYSTEM" + conf.tbl = "USERS" + self._pivot({ + "columnname": ["ID", "NAME"], + "datatype": ["INTEGER", "CHAR"], + "len": ["4", "32"], + }) + cols = self.handler.getColumns() + self.assertEqual(len(cols), 1) + tbls = list(cols.values())[0] + self.assertIn("USERS", tbls) + self.assertEqual(tbls["USERS"]["ID"], "INTEGER(4)") + + def test_get_privileges_empty(self): + self.assertEqual(self.handler.getPrivileges(), {}) + + def test_get_password_hashes_empty(self): + self.assertEqual(self.handler.getPasswordHashes(), {}) + + def test_get_hostname(self): + self.assertIsNone(self.handler.getHostname()) + + def test_get_statements(self): + self.assertEqual(self.handler.getStatements(), []) + + +# --------------------------------------------------------------------------- +# Microsoft SQL Server (methods NOT covered by test_dbms_enum.py) +# --------------------------------------------------------------------------- + +class TestMSSQLServerExtraEnum(_EnumBase): + display_name = "Microsoft SQL Server" + dirname = "mssqlserver" + + def test_get_privileges(self): + # getPrivileges -> getUsers (generic, inject.getValue) then isDba. + # Exercise the admin-set branch BOTH ways via the isDba oracle. + conf.user = None + inject.getValue = lambda q, *a, **k: ["sa", "BUILTIN\\Administrators"] + + # oracle True: all users flagged DBA + inject.checkBooleanExpression = lambda *a, **k: True + privs, admins = self.handler.getPrivileges() + self.assertIn("sa", privs) + self.assertEqual(admins, set(["sa", "BUILTIN\\Administrators"])) + + # oracle False: none are DBA -> empty admin set, users still enumerated + _fresh_cached() + inject.getValue = lambda q, *a, **k: ["sa", "BUILTIN\\Administrators"] + inject.checkBooleanExpression = lambda *a, **k: False + privs, admins = self.handler.getPrivileges() + self.assertIn("sa", privs) + self.assertEqual(admins, set()) + + def test_search_table(self): + conf.db = "testdb" + conf.tbl = "users" + # in-band branch: getValue returns matching table name(s) + inject.getValue = lambda q, *a, **k: ["users"] + # capture the discovered tables instead of dumping them + captured = {} + conf.dumper = _NoOpDumper() + self.handler.dumpFoundTables = lambda tables: captured.update(tables) + self.handler.searchTable() + # at least one database mapped to the matched table + flat = set() + for tbls in captured.values(): + flat.update(tbls) + self.assertTrue(any("users" in t for t in flat)) + + def test_search_column(self): + conf.db = "testdb" + conf.tbl = None + conf.col = "password" + # exact match (no wildcard) so no recursive getColumns call; + # getValue returns the tables that contain the column + inject.getValue = lambda q, *a, **k: ["users"] + captured = {} + conf.dumper = _NoOpDumper() + self.handler.dumpFoundColumn = lambda dbs, foundCols, colConsider: captured.update(dbs) + self.handler.searchColumn() + # the searched column was located in at least one table + flat = set() + for tbls in captured.values(): + flat.update(tbls) + self.assertTrue(any("users" in t for t in flat)) + + +# --------------------------------------------------------------------------- +# IBM DB2 +# --------------------------------------------------------------------------- + +class TestDB2Enum(_EnumBase): + display_name = "IBM DB2" + dirname = "db2" + + def test_get_password_hashes_empty(self): + self.assertEqual(self.handler.getPasswordHashes(), {}) + + def test_get_statements_empty(self): + self.assertEqual(self.handler.getStatements(), []) + + +# --------------------------------------------------------------------------- +# Informix +# --------------------------------------------------------------------------- + +class TestInformixEnum(_EnumBase): + display_name = "Informix" + dirname = "informix" + + def test_search_db(self): + self.assertEqual(self.handler.searchDb(), []) + + def test_search_table(self): + self.assertEqual(self.handler.searchTable(), []) + + def test_search_column(self): + self.assertEqual(self.handler.searchColumn(), []) + + def test_get_statements(self): + self.assertEqual(self.handler.getStatements(), []) + + +# --------------------------------------------------------------------------- +# Firebird +# --------------------------------------------------------------------------- + +class TestFirebirdEnum(_EnumBase): + display_name = "Firebird" + dirname = "firebird" + + def test_get_dbs_empty(self): + self.assertEqual(self.handler.getDbs(), []) + + def test_get_password_hashes_empty(self): + self.assertEqual(self.handler.getPasswordHashes(), {}) + + def test_search_db_empty(self): + self.assertEqual(self.handler.searchDb(), []) + + def test_get_hostname(self): + self.assertIsNone(self.handler.getHostname()) + + def test_get_statements_empty(self): + self.assertEqual(self.handler.getStatements(), []) + + +# --------------------------------------------------------------------------- +# HSQLDB +# --------------------------------------------------------------------------- + +class TestHSQLDBEnum(_EnumBase): + display_name = "HSQLDB" + dirname = "hsqldb" + + def test_get_banner(self): + conf.getBanner = True + kb.data.banner = None + # getValue returns a single-element LIST; getBanner pipes it through + # unArrayizeValue, which must unwrap it to the scalar banner string. + inject.getValue = lambda q, *a, **k: ["HSQLDB 2.5.1"] + banner = self.handler.getBanner() + self.assertEqual(banner, "HSQLDB 2.5.1") + + def test_get_privileges_empty(self): + self.assertEqual(self.handler.getPrivileges(), {}) + + def test_get_hostname(self): + self.assertIsNone(self.handler.getHostname()) + + def test_get_statements_empty(self): + self.assertEqual(self.handler.getStatements(), []) + + def test_get_current_db_default_schema(self): + from lib.core.settings import HSQLDB_DEFAULT_SCHEMA + self.assertEqual(self.handler.getCurrentDb(), HSQLDB_DEFAULT_SCHEMA) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deps.py b/tests/test_deps.py new file mode 100644 index 000000000..0f09e5cdd --- /dev/null +++ b/tests/test_deps.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Optional-dependency probe (lib/utils/deps.py, the --dependencies feature). +checkDependencies() attempts to import every supported DBMS driver and warns +on the ones missing; it must never raise regardless of what's installed. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +import lib.utils.deps as deps +from lib.utils.deps import checkDependencies + + +class _RecordingLogger(object): + """Captures every (level, message) emitted while installed as deps.logger.""" + + def __init__(self): + self.records = [] + + def warning(self, msg, *args): + self.records.append(("warning", msg % args if args else msg)) + + def info(self, msg, *args): + self.records.append(("info", msg % args if args else msg)) + + def debug(self, msg, *args): + self.records.append(("debug", msg % args if args else msg)) + + def error(self, msg, *args): + self.records.append(("error", msg % args if args else msg)) + + def messages(self, level=None): + return [m for (lvl, m) in self.records if level is None or lvl == level] + + +class TestCheckDependencies(unittest.TestCase): + def setUp(self): + self._real_logger = deps.logger + self.rec = _RecordingLogger() + deps.logger = self.rec + + def tearDown(self): + deps.logger = self._real_logger + + def test_missing_driver_warns_with_library_name(self): + # 'kinterbasdb' (Firebird driver) is essentially never installed, so the + # probe must hit the except branch and emit a warning naming the library. + try: + import kinterbasdb # noqa: F401 + self.skipTest("kinterbasdb is unexpectedly installed") + except ImportError: + pass + + checkDependencies() + + warnings = self.rec.messages("warning") + self.assertTrue(warnings, msg="no warnings captured for a missing driver") + # the Firebird entry must name its third-party library in a warning + self.assertTrue( + any("kinterbasdb" in w for w in warnings), + msg="missing Firebird driver did not produce a library-naming warning: %r" % warnings, + ) + + def test_all_present_emits_all_installed_info(self): + # force every __import__ to succeed so no library is ever recorded as + # missing; the empty-missing-set branch must emit the summary info line. + import builtins + + class _FakeModule(object): + __version__ = "999.0.0" + + real_import = builtins.__import__ + + def _always_succeed(name, *args, **kwargs): + try: + return real_import(name, *args, **kwargs) + except Exception: + return _FakeModule() + + builtins.__import__ = _always_succeed + try: + checkDependencies() + finally: + builtins.__import__ = real_import + + infos = self.rec.messages("info") + self.assertTrue( + any("all dependencies are installed" in m for m in infos), + msg="all-present path did not emit the summary info: %r" % infos, + ) + # and with nothing missing there must be no missing-library warnings + self.assertFalse( + any("third-party library" in w and "requires" in w for w in self.rec.messages("warning")), + msg="unexpected missing-library warning when all imports succeed", + ) + + def test_returns_none(self): + # contract: the probe is purely advisory and never returns a value + self.assertIsNone(checkDependencies()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dialectdbms.py b/tests/test_dialectdbms.py index 81de07ece..5dc28ac98 100644 --- a/tests/test_dialectdbms.py +++ b/tests/test_dialectdbms.py @@ -72,13 +72,16 @@ class TestDialectClassification(unittest.TestCase): self.assertEqual(_classify(base + (shift,)), expected, "engine %r misclassified (shift=%s)" % (engine, shift)) def test_no_false_positive_across_measured_set(self): + # non-collision property: every measured engine maps to EXACTLY its expected DBMS (or None), + # never to some other back-end. The shift flag is irrelevant for these (non-shift-sensitive) + # engines, so assert it both ways. for engine, (base, expected) in MEASURED.items(): for shift in (False, True): result = _classify(base + (shift,)) - if expected is None: - self.assertIsNone(result, "ambiguous engine %r leaked a DBMS prior" % engine) - else: - self.assertIn(result, (DBMS.MYSQL, DBMS.MSSQL, DBMS.PGSQL, DBMS.SQLITE, DBMS.MONETDB, DBMS.ORACLE)) + self.assertEqual(result, expected, "engine %r misclassified (shift=%s): got %r, expected %r" % (engine, shift, result, expected)) + # the only non-None DBMS priors the measured set can yield (sanity on the mapping itself) + produced = set(expected for _, expected in MEASURED.values() if expected is not None) + self.assertEqual(produced, {DBMS.MYSQL, DBMS.PGSQL, DBMS.SQLITE}) def test_all_error_signature_yields_no_prior(self): # an all-error signature (Oracle, ClickHouse, IRIS, or simply a WAF-blocked channel) is not diff --git a/tests/test_dns_engine.py b/tests/test_dns_engine.py index 5eaf2c0a7..767a5019c 100644 --- a/tests/test_dns_engine.py +++ b/tests/test_dns_engine.py @@ -29,6 +29,7 @@ character-based and a chunk could split a code point, need the real-DBMS run. import binascii import os +import re import socket import struct import sys @@ -251,16 +252,114 @@ class TestDnsExfilEngineMssql(TestDnsExfilEngine): DBMS_NAME = "Microsoft SQL Server" -class TestDnsLabelInvariant(unittest.TestCase): - """The exfil chunk is hex-encoded into ONE DNS label, so 2*chunk_length must never exceed the - 63-octet DNS label limit - otherwise the query carries an invalid (over-long) label and exfil - silently breaks. Guards the chunk_length arithmetic in dnsUse for every supported DBMS.""" - def test_hex_label_within_max_dns_label(self): - for dbms in (DBMS.MYSQL, DBMS.ORACLE, DBMS.PGSQL, DBMS.MSSQL): - chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2 - self.assertGreater(chunk_length, 0, "%s: non-positive chunk_length" % dbms) - self.assertLessEqual(2 * chunk_length, MAX_DNS_LABEL, - "%s: hex label (%d) exceeds MAX_DNS_LABEL (%d)" % (dbms, 2 * chunk_length, MAX_DNS_LABEL)) +class TestDnsLabelInvariant(_DnsCase): + """The exfil chunk is hex-encoded into ONE DNS label, so the label dnsUse emits must never + exceed the 63-octet DNS label limit - otherwise the query carries an invalid (over-long) label + and exfil silently breaks. + + Unlike a static formula check, this drives the REAL dnsUse() chunking through the REAL DNSServer + and asserts the invariant on the ACTUAL labels that reach the wire. The mock oracle does NOT + re-derive the chunk size: it slices each chunk to exactly the length dnsUse itself rendered into + its SUBSTRING call (captured live from agent.hexConvertField, whose input is the source's + substring expression). So if the chunk_length arithmetic in dnsUse regresses, the emitted hex + label grows past 63 octets and this test goes red - it observes the source's output, it does not + recompute it. + """ + + def _drive_and_collect_labels(self, secret): + """ + Runs dnsUse for L{secret} end-to-end against the real DNS server, slicing each chunk to the + length the SOURCE asked for (parsed from the live SUBSTRING expression dnsUse builds), and + returns (every label seen in every emitted query name, list of source chunk_lengths seen). + """ + secret_bytes = secret.encode("utf-8") + boundaries = [] + served = [0] + source_chunk_lengths = [] + # Snapshot the names the REAL DNSServer parsed off the wire, captured the moment they land + # in _requests - dnsUse's own .pop() consumes them, so we must grab them before that. + captured_names = [] + + real_randomStr = self._saved_randomStr + def spy_randomStr(length=4, alphabet=None, **kw): + if alphabet == DNS_BOUNDARIES_ALPHABET and length == 3: + out = real_randomStr(length=length, alphabet=alphabet, **kw) + boundaries.append(out) + return out + return real_randomStr(length=length, alphabet=alphabet, **kw) if alphabet is not None else real_randomStr(length=length, **kw) + dnsmod.randomStr = spy_randomStr + + # agent.hexConvertField receives the rendered SUBSTRING call, e.g. "MID((...),1,31)" / + # "SUBSTRING((...) FROM 1 FOR 13)"; the substring LENGTH argument (the source's real + # chunk_length) is the last integer literal in it. Capture it per iteration so the oracle + # emits a chunk of exactly that size - the source's arithmetic, not a copy of it. + saved_hexConvertField = agent.hexConvertField + def spy_hexConvertField(field): + source_chunk_lengths.append(int(re.findall(r"\d+", field)[-1])) + return saved_hexConvertField(field) + agent.hexConvertField = spy_hexConvertField + + def oracle(payload=None, *args, **kwargs): + prefix, suffix = boundaries[-2], boundaries[-1] + chunk_length = source_chunk_lengths[-1] + chunk = secret_bytes[served[0]:served[0] + chunk_length] + if chunk: + host = "%s.%s.%s.%s" % (prefix, binascii.hexlify(chunk).decode(), suffix, conf.dnsDomain) + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(3) + c.sendto(_build_query(host), ("127.0.0.1", self.server.port)) + try: + c.recvfrom(512) + finally: + c.close() + served[0] += len(chunk) + for _ in range(100): + with self.server._lock: + matched = [r for r in self.server._requests if host.encode() in r] + if matched: + captured_names.extend(r.decode() if isinstance(r, bytes) else r for r in matched) + break + time.sleep(0.01) + return None + + Connect.queryPage = staticmethod(oracle) + dnsmod.Request.queryPage = staticmethod(oracle) + + try: + result = dnsmod.dnsUse("%s AND %d=%d", "user()") + finally: + agent.hexConvertField = saved_hexConvertField + + # round-trip must still work (the source must actually reassemble what it chunked) + self.assertEqual(result, secret) + + labels = [] + for name in captured_names: + labels.extend(label for label in name.split(".") if label) + return labels, source_chunk_lengths + + def test_emitted_dns_labels_within_max_dns_label(self): + # long enough that every supported dialect's chunk_length forces several chunks (>1 label of + # hex payload), so the chunking loop - not just a single-shot path - is what we measure + secret = ("The quick brown fox jumps over the lazy dog " + "0123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz") * 3 + for dbms_name in ("MySQL", "Oracle", "PostgreSQL", "Microsoft SQL Server"): + self.DBMS_NAME = dbms_name + set_dbms(dbms_name) + labels, source_chunk_lengths = self._drive_and_collect_labels(secret) + + # the source must have actually chunked (multiple SUBSTRING iterations), otherwise we + # would not be testing the chunking output at all + self.assertGreater(len(source_chunk_lengths), 1, + "%s: payload did not force multiple chunks (got %d)" % (dbms_name, len(source_chunk_lengths))) + self.assertTrue(all(cl > 0 for cl in source_chunk_lengths), + "%s: non-positive chunk_length from source: %r" % (dbms_name, source_chunk_lengths)) + + self.assertTrue(labels, "%s: no DNS query labels were captured" % dbms_name) + for label in labels: + self.assertLessEqual(len(label), MAX_DNS_LABEL, + "%s: emitted DNS label %r is %d octets, exceeds MAX_DNS_LABEL (%d)" + % (dbms_name, label, len(label), MAX_DNS_LABEL)) class TestDnsChannelDetection(_DnsCase): diff --git a/tests/test_dns_server.py b/tests/test_dns_server.py index 9e566e3d7..613518b7a 100644 --- a/tests/test_dns_server.py +++ b/tests/test_dns_server.py @@ -36,18 +36,69 @@ def build_query(name, tid=b"\x12\x34", qtype=1): class _HighPortDNSServer(DNSServer): - """Real DNSServer logic, bound on a high port (no root, no :53 probe)""" - def __init__(self, port, sock=None, maxlen=MAX_DNS_REQUESTS): + """Real DNSServer logic, bound on an ephemeral high port (no root, no :53 probe). + + Binds to port 0 and reads the kernel-chosen port back via getsockname() (same pattern + as tests/test_dns_engine.py) so concurrent/repeated runs never collide on a hardcoded + port. The actual port is exposed as L{self.port}. + """ + def __init__(self, sock=None, maxlen=MAX_DNS_REQUESTS): self._requests = collections.deque(maxlen=maxlen) self._lock = threading.Lock() if sock is None: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("127.0.0.1", port)) + sock.bind(("127.0.0.1", 0)) self._socket = sock + self.port = self._socket.getsockname()[1] self._running = False self._initialized = False + def close(self): + self._running = False + try: + self._socket.close() + except socket.error: + pass + + +# Maximum time (seconds) to wait for the daemon server thread to come up, or for a sent +# query to be recorded, before failing loudly instead of spinning/sleeping forever. +WAIT_TIMEOUT = 5.0 + + +def _wait_initialized(srv, timeout=WAIT_TIMEOUT): + """Bounded wait for the server thread to flip _initialized; fail fast if it never does.""" + deadline = time.time() + timeout + while not srv._initialized: + if time.time() > deadline: + raise RuntimeError("DNS server failed to initialize within %.1fs" % timeout) + time.sleep(0.01) + + +def _wait_recorded(srv, token, timeout=WAIT_TIMEOUT): + """Bounded wait until L{token} appears in a recorded request; False on timeout.""" + if hasattr(token, "encode"): + token = token.encode() + deadline = time.time() + timeout + while time.time() <= deadline: + with srv._lock: + if any(token in r for r in srv._requests): + return True + time.sleep(0.01) + return False + + +def _wait_popped(srv, prefix, suffix, timeout=WAIT_TIMEOUT): + """Bounded wait until pop(prefix, suffix) yields a value; returns it or None on timeout.""" + deadline = time.time() + timeout + while time.time() <= deadline: + popped = srv.pop(prefix, suffix) + if popped: + return popped + time.sleep(0.01) + return None + class _SendFailOnceSocket(object): """Wraps a real UDP socket; first sendto() raises (simulated transient failure)""" @@ -95,31 +146,30 @@ class TestDNSQuery(unittest.TestCase): class TestDNSServerRoundTrip(unittest.TestCase): - PORT = 5471 - @classmethod def setUpClass(cls): - cls.srv = _HighPortDNSServer(cls.PORT) + cls.srv = _HighPortDNSServer() cls.srv.run() - while not cls.srv._initialized: - time.sleep(0.02) + _wait_initialized(cls.srv) + + @classmethod + def tearDownClass(cls): + srv = getattr(cls, "srv", None) + if srv is not None: + srv.close() + cls.srv = None def _send(self, name): c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) c.settimeout(3) - c.sendto(build_query(name), ("127.0.0.1", self.PORT)) + c.sendto(build_query(name), ("127.0.0.1", self.srv.port)) try: c.recvfrom(512) except socket.timeout: pass finally: c.close() - for _ in range(100): - with self.srv._lock: - if any(name.encode() in r for r in self.srv._requests): - return True - time.sleep(0.01) - return False + return _wait_recorded(self.srv, name) def test_roundtrip_and_pop(self): self.assertTrue(self._send("aaa.cafe.bbb.exfil.test")) @@ -132,49 +182,40 @@ class TestDNSServerRoundTrip(unittest.TestCase): # labels regardless of qtype, and the server records before crafting the (A) response c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) c.settimeout(2) - c.sendto(build_query("ggg.beef.hhh.exfil.test", qtype=28), ("127.0.0.1", self.PORT)) + c.sendto(build_query("ggg.beef.hhh.exfil.test", qtype=28), ("127.0.0.1", self.srv.port)) try: c.recvfrom(512) except socket.timeout: pass finally: c.close() - for _ in range(200): - if self.srv.pop("ggg", "hhh"): - return - time.sleep(0.01) - self.fail("AAAA-type query was not recorded (exfil would be lost for AAAA-resolving DBMSes)") + if not _wait_popped(self.srv, "ggg", "hhh"): + self.fail("AAAA-type query was not recorded (exfil would be lost for AAAA-resolving DBMSes)") class TestDNSServerMemoryBound(unittest.TestCase): """The server records every received query (it listens on :53); only matching ones are popped. Unrelated/stray traffic and resolver retries must not grow memory without bound.""" - PORT = 5475 def test_requests_are_bounded_and_recent_kept(self): - srv = _HighPortDNSServer(self.PORT, maxlen=50) + srv = _HighPortDNSServer(maxlen=50) + self.addCleanup(srv.close) srv.run() - while not srv._initialized: - time.sleep(0.02) + _wait_initialized(srv) c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) for i in range(200): # flood well past the bound - c.sendto(build_query("noise%d.unrelated.test" % i), ("127.0.0.1", self.PORT)) + c.sendto(build_query("noise%d.unrelated.test" % i), ("127.0.0.1", srv.port)) c.close() # a legit exfil query right after the flood must still be capturable c2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); c2.settimeout(2) - c2.sendto(build_query("ppp.d00d.qqq.exfil.test"), ("127.0.0.1", self.PORT)) + c2.sendto(build_query("ppp.d00d.qqq.exfil.test"), ("127.0.0.1", srv.port)) try: c2.recvfrom(512) except socket.timeout: pass finally: c2.close() - popped = None - for _ in range(200): - popped = srv.pop("ppp", "qqq") - if popped: - break - time.sleep(0.01) + popped = _wait_popped(srv, "ppp", "qqq") with srv._lock: n = len(srv._requests) self.assertLessEqual(n, 50, "request buffer exceeded its bound (%d)" % n) @@ -182,11 +223,11 @@ class TestDNSServerMemoryBound(unittest.TestCase): class TestDNSServerResilience(unittest.TestCase): - def _make(self, port, sock=None): - srv = _HighPortDNSServer(port, sock=sock) + def _make(self, sock=None): + srv = _HighPortDNSServer(sock=sock) + self.addCleanup(srv.close) srv.run() - while not srv._initialized: - time.sleep(0.02) + _wait_initialized(srv) return srv def _query(self, port, name): @@ -200,34 +241,28 @@ class TestDNSServerResilience(unittest.TestCase): finally: c.close() - def _recorded(self, srv, token, tries=120): - for _ in range(tries): - with srv._lock: - if any(token.encode() in r for r in srv._requests): - return True - time.sleep(0.01) - return False + def _recorded(self, srv, token): + return _wait_recorded(srv, token) def test_survives_transient_send_error(self): - port = 5472 + # ephemeral bind, then wrap the bound socket so its first sendto() raises s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("127.0.0.1", port)) - srv = self._make(port, sock=_SendFailOnceSocket(s)) - self._query(port, "aaa.11.bbb.exfil.test") # first sendto raises - self._query(port, "ccc.22.ddd.exfil.test") # must still be served + s.bind(("127.0.0.1", 0)) + srv = self._make(sock=_SendFailOnceSocket(s)) + self._query(srv.port, "aaa.11.bbb.exfil.test") # first sendto raises + self._query(srv.port, "ccc.22.ddd.exfil.test") # must still be served self.assertTrue(self._recorded(srv, "ccc.22.ddd"), "DNS server died after one failing sendto (lost subsequent exfil)") self.assertTrue(srv._running) def test_survives_malformed_packets(self): - port = 5473 - srv = self._make(port) + srv = self._make() c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) for junk in (b"", b"\x00", b"\xff" * 7, b"\x12\x34\x01\x00\x00\x01" + b"\x20abc"): - c.sendto(junk, ("127.0.0.1", port)) + c.sendto(junk, ("127.0.0.1", srv.port)) c.close() - self._query(port, "ok.33.fine.exfil.test") + self._query(srv.port, "ok.33.fine.exfil.test") self.assertTrue(self._recorded(srv, "ok.33.fine"), "DNS server died on a malformed packet") @@ -235,14 +270,19 @@ class TestDNSServerResilience(unittest.TestCase): class TestDNSServerConcurrency(unittest.TestCase): """Under --threads, many workers fire DNS queries and call pop() while the server thread appends - all guarded by one lock. Each worker must get back exactly its own data.""" - PORT = 5474 @classmethod def setUpClass(cls): - cls.srv = _HighPortDNSServer(cls.PORT) + cls.srv = _HighPortDNSServer() cls.srv.run() - while not cls.srv._initialized: - time.sleep(0.02) + _wait_initialized(cls.srv) + + @classmethod + def tearDownClass(cls): + srv = getattr(cls, "srv", None) + if srv is not None: + srv.close() + cls.srv = None def test_concurrent_send_and_pop_no_crosstalk(self): import binascii, re @@ -258,19 +298,14 @@ class TestDNSServerConcurrency(unittest.TestCase): c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) c.settimeout(2) try: - c.sendto(build_query(host), ("127.0.0.1", self.PORT)) + c.sendto(build_query(host), ("127.0.0.1", self.srv.port)) try: c.recvfrom(512) except socket.timeout: pass finally: c.close() - got = None - for _ in range(200): - got = self.srv.pop(prefix, suffix) - if got: - break - time.sleep(0.01) + got = _wait_popped(self.srv, prefix, suffix) if not got: errors.append("worker %d: never popped its query" % i); return m = re.search(r"%s\.(?P.+?)\.%s" % (prefix, suffix), got, re.I) diff --git a/tests/test_dump_format.py b/tests/test_dump_format.py new file mode 100644 index 000000000..ce9076c6b --- /dev/null +++ b/tests/test_dump_format.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Output formatting of the result dumper (lib/core/dump.py) and the SQLite +replication backend (lib/core/replication.py). + +dump.Dump turns extracted DB structures (schemas, table/column listings, row +counts, single facts, user lists) into the human-readable ASCII tables printed +to the console, and serializes per-table row data to CSV / HTML / SQLite files. +None of that needs a live target, network or DBMS: the console renderers route +every line through Dump._write (overridden here to capture instead of print), +and the file renderers just write to a path we point at a temp dir. These tests +pin the rendered layout/escaping contracts so a formatting regression is caught +without an end-to-end scan. +""" + +import io +import os +import shutil +import sys +import tempfile +import unittest + +from collections import OrderedDict as _PlainOrderedDict + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.core.common import Backend +from lib.core.data import conf, kb +from lib.core.dump import Dump +from lib.core.enums import DUMP_FORMAT +from lib.core.replication import Replication + + +# --- console-rendering tests (no files): capture every Dump._write line -------------------------- + +class _CaptureCase(unittest.TestCase): + """Base for the console renderers: pins a neutral case-preserving DBMS, disables api/report + side channels, and replaces Dump._write with an in-memory capture so nothing hits stdout.""" + + _CONF_KEYS = ("api", "reportCollector", "dumpFormat", "col", "csvDel", "dumpPath", "dumpFile", + "limitStart", "limitStop", "forceDbms", "dbms") + _KB_KEYS = ("forcedDbms", "dbms") + + def setUp(self): + self._saved = dict((k, conf.get(k)) for k in self._CONF_KEYS) + self._savedKb = dict((k, kb.get(k)) for k in self._KB_KEYS) + conf.forceDbms = conf.dbms = None + kb.dbms = None + Backend.forceDbms("MySQL") + conf.api = False + conf.reportCollector = None + conf.col = None + conf.csvDel = "," + self.lines = [] + self.d = Dump() + self.d._write = self._capture + + def tearDown(self): + for k, v in self._saved.items(): + conf[k] = v + for k, v in self._savedKb.items(): + kb[k] = v + + def _capture(self, data, newline=True, console=True, content_type=None): + # mirror Dump._write's own line-vs-space join so multi-call lines reassemble faithfully + self.lines.append("%s%s" % (data, "\n" if newline else " ")) + + def text(self): + return "".join(self.lines) + + +class TestStringAndLister(_CaptureCase): + def test_string_scalar_quoted(self): + # a plain string fact is rendered as "header: 'value'" + self.d.string("current user", "root@localhost") + self.assertIn("current user: 'root@localhost'", self.text()) + + def test_string_multiline_block(self): + # a value containing a newline switches to the fenced ---\n...\n--- block form + self.d.string("banner", "line1\nline2") + out = self.text() + self.assertIn("banner:\n---\nline1\nline2\n---", out) + + def test_string_singleton_list_unwrapped(self): + # a one-element list is unwrapped to the scalar form (not the lister "[N]:" form) + self.d.string("current database", ["testdb"]) + out = self.text() + self.assertIn("current database: 'testdb'", out) + self.assertNotIn("[1]", out) + + def test_lister_sorts_and_counts(self): + # lister prints a "[count]:" header, one "[*] item" per element, sorted case-insensitively + self.d.lister("available databases", ["mysql", "Alpha", "zebra"]) + out = self.text() + self.assertIn("available databases [3]:", out) + body = out[out.index("[3]:"):] + # case-insensitive ascending: Alpha, mysql, zebra + self.assertLess(body.index("[*] Alpha"), body.index("[*] mysql")) + self.assertLess(body.index("[*] mysql"), body.index("[*] zebra")) + + def test_lister_dedupes(self): + # the sort path also de-duplicates (set()) before listing + self.d.lister("database management system users", ["root", "root", "guest"]) + out = self.text() + self.assertIn("database management system users [2]:", out) + self.assertEqual(out.count("[*] root"), 1) + + def test_lister_unsorted_preserves_order(self): + # sort=False (e.g. rFile) keeps insertion order + self.d.lister("files saved to", ["/z", "/a", "/m"], sort=False) + out = self.text() + self.assertLess(out.index("[*] /z"), out.index("[*] /a")) + self.assertLess(out.index("[*] /a"), out.index("[*] /m")) + + +class TestCurrentDb(_CaptureCase): + def test_label_default_dbms(self): + # MySQL is not in the schema/owner special-cased lists -> plain "current database" + self.d.currentDb("testdb") + self.assertIn("current database: 'testdb'", self.text()) + + def test_label_schema_dbms(self): + # Oracle is in the schema-equivalent list -> the label is annotated accordingly + Backend.forceDbms("Oracle") + self.d.currentDb("SYSTEM") + out = self.text() + self.assertIn("equivalent to schema on Oracle", out) + self.assertIn("SYSTEM", out) + + +class TestDbTables(_CaptureCase): + def test_table_listing_box(self): + self.d.dbTables({"testdb": ["users", "logs"]}) + out = self.text() + self.assertIn("Database: testdb", out) + self.assertIn("[2 tables]", out) + self.assertIn("| users", out) + self.assertIn("| logs", out) + # box borders present + self.assertIn("+", out) + + def test_single_table_singular(self): + self.d.dbTables({"testdb": ["only"]}) + self.assertIn("[1 table]", self.text()) + + def test_no_tables(self): + self.d.dbTables({}) + self.assertIn("No tables found", self.text()) + + def test_box_width_matches_longest_table(self): + # the border length tracks the longest table name (+2 padding) + self.d.dbTables({"testdb": ["a", "elephant"]}) + out = self.text() + # "elephant" is 8 chars -> a border line of 8+2 = 10 dashes exists + self.assertIn("+%s+" % ("-" * 10), out) + + +class TestDbTableColumns(_CaptureCase): + def test_typed_columns_two_column_box(self): + self.d.dbTableColumns({"testdb": {"users": {"id": "int", "name": "varchar(50)"}}}) + out = self.text() + self.assertIn("Database: testdb", out) + self.assertIn("Table: users", out) + self.assertIn("[2 columns]", out) + self.assertIn("| Column", out) + self.assertIn("| Type", out) + self.assertIn("int", out) + self.assertIn("varchar(50)", out) + + def test_typeless_columns_single_box(self): + # when no column carries a type, only the Column box is rendered (no Type header) + self.d.dbTableColumns({"testdb": {"users": {"id": None, "name": None}}}) + out = self.text() + self.assertIn("| Column", out) + self.assertNotIn("| Type", out) + + def test_mixed_types_still_show_type_header(self): + # even if the alphabetically-last column is type-less, a Type column must appear + self.d.dbTableColumns({"testdb": {"t": {"aaa": "int", "zzz": None}}}) + self.assertIn("| Type", self.text()) + + +class TestDbTablesCount(_CaptureCase): + def test_count_box_sorted_desc(self): + self.d.dbTablesCount({"testdb": {5: ["small"], 100: ["big"]}}) + out = self.text() + self.assertIn("Database: testdb", out) + self.assertIn("| Table", out) + self.assertIn("| Entries", out) + # higher count first (reverse sort) + self.assertLess(out.index("big"), out.index("small")) + self.assertIn("100", out) + + +class TestUserSettings(_CaptureCase): + def test_privileges_listed_with_admin_flag(self): + # userSettings accepts (settingsDict, adminsSet); admins get an "(administrator)" tag + settings = ({"root": ["ALL"], "guest": ["SELECT"]}, set(["root"])) + self.d.userSettings("database management system users privileges", settings, "privilege") + out = self.text() + self.assertIn("[*] root (administrator)", out) + self.assertIn("[*] guest", out) + self.assertNotIn("guest (administrator)", out) + self.assertIn("privilege: ALL", out) + self.assertIn("privilege: SELECT", out) + + +# --- file-rendering tests (CSV / HTML / SQLite): point output at a temp dir ---------------------- + +class _FileDumpCase(unittest.TestCase): + _CONF_KEYS = ("dumpFormat", "dumpPath", "dumpFile", "col", "api", "reportCollector", + "limitStart", "limitStop", "csvDel", "forceDbms", "dbms") + _KB_KEYS = ("forcedDbms", "dbms") + + def setUp(self): + self._saved = dict((k, conf.get(k)) for k in self._CONF_KEYS) + self._savedKb = dict((k, kb.get(k)) for k in self._KB_KEYS) + conf.forceDbms = conf.dbms = None + kb.dbms = None + Backend.forceDbms("MySQL") + self.tmp = tempfile.mkdtemp(prefix="sqlmap-dumpfmt-test") + conf.dumpPath = self.tmp + conf.dumpFile = None + conf.col = None + conf.api = False + conf.reportCollector = None + conf.limitStart = conf.limitStop = None + conf.csvDel = "," + self.d = Dump() + self.d._write = lambda *a, **k: None # silence the console table + + def tearDown(self): + for k, v in self._saved.items(): + conf[k] = v + for k, v in self._savedKb.items(): + kb[k] = v + shutil.rmtree(self.tmp, ignore_errors=True) + + def _path(self, table_values, ext): + db = table_values["__infos__"]["db"] or "All" + return os.path.join(self.tmp, db, "%s.%s" % (table_values["__infos__"]["table"], ext)) + + def _dump(self, table_values, fmt, ext): + conf.dumpFormat = fmt + self.d.dbTableValues(table_values) + with io.open(self._path(table_values, ext), encoding="utf-8") as f: + return f.read() + + +class TestCsvDump(_FileDumpCase): + def _sample(self): + return _PlainOrderedDict([ + ("__infos__", {"count": 2, "db": "testdb", "table": "users"}), + ("id", {"length": 2, "values": ["1", "2"]}), + ("name", {"length": 6, "values": ["luther", "fluffy"]}), + ]) + + def test_header_and_rows(self): + content = self._dump(self._sample(), DUMP_FORMAT.CSV, "csv") + lines = [l for l in content.splitlines() if l.strip()] + self.assertEqual(lines[0].split(","), ["id", "name"]) + self.assertEqual(lines[1].split(","), ["1", "luther"]) + self.assertEqual(lines[2].split(","), ["2", "fluffy"]) + + def test_delimiter_in_value_is_quoted(self): + # RFC-4180: a value containing the delimiter must be wrapped in quotes + tv = _PlainOrderedDict([ + ("__infos__", {"count": 1, "db": "testdb", "table": "t"}), + ("a", {"length": 8, "values": ["x,y"]}), + ("b", {"length": 1, "values": ["z"]}), + ]) + content = self._dump(tv, DUMP_FORMAT.CSV, "csv") + self.assertIn('"x,y"', content) + + def test_null_and_blank_markers(self): + # the display replacements apply to CSV too: DB NULL (" ") -> NULL, empty ("") -> + tv = _PlainOrderedDict([ + ("__infos__", {"count": 1, "db": "testdb", "table": "t"}), + ("a", {"length": 4, "values": [" "]}), + ("b", {"length": 7, "values": [""]}), + ("c", {"length": 1, "values": ["x"]}), + ]) + content = self._dump(tv, DUMP_FORMAT.CSV, "csv") + row = [l for l in content.splitlines() if l.strip()][1] + self.assertEqual(row.split(","), ["NULL", "", "x"]) + + def test_custom_delimiter(self): + conf.csvDel = ";" + content = self._dump(self._sample(), DUMP_FORMAT.CSV, "csv") + self.assertEqual(content.splitlines()[0].split(";"), ["id", "name"]) + + +class TestHtmlDump(_FileDumpCase): + def _sample(self): + return _PlainOrderedDict([ + ("__infos__", {"count": 1, "db": "testdb", "table": "users"}), + ("id", {"length": 2, "values": ["1"]}), + ("name", {"length": 6, "values": ["luther"]}), + ]) + + def test_html_scaffold_and_cells(self): + content = self._dump(self._sample(), DUMP_FORMAT.HTML, "html") + self.assertIn("", content) + self.assertIn("testdb.users", content) + self.assertIn("id", content) + self.assertIn(">name", content) + self.assertIn("1", content) + self.assertIn("luther", content) + self.assertIn("", content) + self.assertIn("", content) + + def test_html_escapes_markup(self): + # a value with HTML metacharacters must be escaped, not emitted raw + tv = _PlainOrderedDict([ + ("__infos__", {"count": 1, "db": "testdb", "table": "t"}), + ("payload", {"length": 16, "values": [""]}), + ]) + content = self._dump(tv, DUMP_FORMAT.HTML, "html") + self.assertNotIn("", content) + self.assertIn("<", content) + + +class TestSqliteDump(_FileDumpCase): + def test_rows_and_inferred_types(self): + tv = _PlainOrderedDict([ + ("__infos__", {"count": 2, "db": "testdb", "table": "people"}), + ("id", {"length": 2, "values": ["1", "2"]}), # all ints -> INTEGER + ("ratio", {"length": 4, "values": ["1.5", "2.0"]}), # floats -> REAL + ("name", {"length": 6, "values": ["alice", " "]}), # text with a NULL marker + ]) + conf.dumpFormat = DUMP_FORMAT.SQLITE + self.d.dbTableValues(tv) + + import sqlite3 + dbfile = os.path.join(self.tmp, "testdb.sqlite3") + self.assertTrue(os.path.exists(dbfile)) + conn = sqlite3.connect(dbfile) + try: + cur = conn.cursor() + cur.execute("SELECT id, ratio, name FROM people ORDER BY id") + rows = cur.fetchall() + self.assertEqual(rows[0], (1, 1.5, "alice")) + # the DB NULL marker (" ") was stored as a real NULL, not the "NULL" text + self.assertEqual(rows[1], (2, 2.0, None)) + # column affinities inferred from the values + cur.execute("PRAGMA table_info(people)") + types = {name: ctype for (_cid, name, ctype, _nn, _dv, _pk) in cur.fetchall()} + self.assertEqual(types["id"], "INTEGER") + self.assertEqual(types["ratio"], "REAL") + self.assertEqual(types["name"], "TEXT") + finally: + conn.close() + + +# --- replication backend tests (pure sqlite3, no network/DBMS) ----------------------------------- + +class TestReplication(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp(prefix="sqlmap-repl-test") + self.path = os.path.join(self.tmp, "out.sqlite3") + self.repl = Replication(self.path) + + def tearDown(self): + try: + self.repl.connection.close() + except Exception: + pass + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_create_insert_select_roundtrip(self): + t = self.repl.createTable("t", [("id", Replication.INTEGER), ("name", Replication.TEXT)]) + t.beginTransaction() + t.insert(["1", "alice"]) + t.insert(["2", "bob"]) + t.endTransaction() + rows = sorted(t.select()) + self.assertEqual(rows, [(1, "alice"), (2, "bob")]) + + def test_select_with_condition(self): + t = self.repl.createTable("t", [("id", Replication.INTEGER), ("name", Replication.TEXT)]) + t.insert(["1", "alice"]) + t.insert(["2", "bob"]) + self.assertEqual(t.select("name = 'bob'"), [(2, "bob")]) + + def test_insert_wrong_arity_raises(self): + from lib.core.exception import SqlmapValueException + t = self.repl.createTable("t", [("id", Replication.INTEGER), ("name", Replication.TEXT)]) + with self.assertRaises(SqlmapValueException): + t.insert(["only-one-value"]) + + def test_typeless_table(self): + t = self.repl.createTable("t", ["a", "b"], typeless=True) + t.insert(["x", "y"]) + self.assertEqual(t.select(), [("x", "y")]) + + def test_datatype_str(self): + self.assertEqual(str(Replication.TEXT), "TEXT") + self.assertEqual(str(Replication.INTEGER), "INTEGER") + self.assertIn("DataType", repr(Replication.REAL)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 000000000..bcf6da6a4 --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Unit coverage for the file-read/file-write/UDF-injection SQL & command builders: + + - plugins/generic/filesystem.py (encoding, INSERT/UPDATE query forging, + length probe, read/write dispatch) + - plugins/dbms/mssqlserver/filesystem.py + (debug.exe SCR script, BULK INSERT / + bin->hex extraction, PowerShell & + certutil base64 upload commands) + - lib/takeover/udf.py (sys_exec/sys_eval calls, CREATE FUNCTION + SQL for MySQL/PostgreSQL, remote-path + selection, UDF pruning) + +These methods are (near-)pure string builders given conf/kb plus the injection +layer. Each test drives the real method with inject.goStacked / inject.getValue +(and, for MSSQL, xpCmdshellWriteFile/execCmd) captured, and asserts the EXACT +SQL / command / encoded payload produced -- so a regression in the assembly +logic fails the test. No live target / network / DBMS involved. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.data import conf, kb +from lib.core.convert import encodeHex, encodeBase64, getText +from lib.core.enums import PAYLOAD + + +# --------------------------------------------------------------------------- # +# shared base: snapshot/restore every global + monkeypatch these tests touch # +# --------------------------------------------------------------------------- # +class _FsBase(unittest.TestCase): + # subclasses set `target_modules` = list of modules whose inject.* we patch + target_modules = () + + # conf fields read by the methods under test + _CONF_KEYS = ("batch", "direct", "fileRead", "fileWrite", "filePath", + "commonFiles", "osPwn", "osCmd", "osShell", "regRead", + "regAdd", "regDel", "tmpPath", "shLib", "encoding") + _KB_KEYS = ("bruteMode", "binaryField", "fileReadMode") + + def setUp(self): + self._conf = {k: conf.get(k) for k in self._CONF_KEYS} + self._kb = {k: kb.get(k) for k in self._KB_KEYS} + self._patched = [] # (obj, attr, original) + + conf.batch = True + conf.direct = True + kb.bruteMode = False + + def tearDown(self): + for obj, attr, orig in reversed(self._patched): + setattr(obj, attr, orig) + for k, v in self._conf.items(): + conf[k] = v + for k, v in self._kb.items(): + kb[k] = v + + def patch(self, obj, attr, value): + self._patched.append((obj, attr, getattr(obj, attr))) + setattr(obj, attr, value) + return value + + +# --------------------------------------------------------------------------- # +# plugins/generic/filesystem.py # +# --------------------------------------------------------------------------- # +class TestGenericFilesystem(_FsBase): + import plugins.generic.filesystem as module + + def _fs(self): + return self.module.Filesystem() + + # -- fileContentEncode ------------------------------------------------- # + def test_fileContentEncode_hex_single(self): + # single=True -> one element, 0x-prefixed, exact lower-case hex of bytes + out = self._fs().fileContentEncode(b"ABC", "hex", True) + self.assertEqual(out, ["0x414243"]) + + def test_fileContentEncode_base64_single(self): + out = self._fs().fileContentEncode(b"ABC", "base64", True) + self.assertEqual(out, ["'QUJD'"]) + + def test_fileContentEncode_hex_chunked(self): + # 4 bytes -> 8 hex chars; chunkSize=4 -> two 0x-prefixed chunks of 4 chars + out = self._fs().fileContentEncode(b"ABCD", "hex", False, chunkSize=4) + self.assertEqual(out, ["0x4142", "0x4344"]) + + def test_fileContentEncode_base64_chunked(self): + # "ABCD" -> base64 "QUJDRA==" (8 chars); chunkSize=4 -> two quoted chunks + out = self._fs().fileContentEncode(b"ABCD", "base64", False, chunkSize=4) + self.assertEqual(out, ["'QUJD'", "'RA=='"]) + + def test_fileContentEncode_chunk_below_threshold_is_single(self): + # content shorter than chunkSize, single=False -> still one 0x chunk + out = self._fs().fileContentEncode(b"AB", "hex", False, chunkSize=256) + self.assertEqual(out, ["0x4142"]) + + def test_fileEncode_reads_then_encodes(self): + # fileEncode must read the file bytes and delegate to fileContentEncode + path = os.path.join( + os.environ.get("TMPDIR", "/tmp"), "sqlmap_fe_%d.bin" % os.getpid()) + with open(path, "wb") as f: + f.write(b"hello") + try: + out = self._fs().fileEncode(path, "hex", True) + finally: + os.remove(path) + self.assertEqual(out, ["0x%s" % getText(encodeHex(b"hello"))]) + self.assertEqual(out, ["0x68656c6c6f"]) + + # -- fileToSqlQueries -------------------------------------------------- # + def test_fileToSqlQueries_insert_then_concat_update(self): + # first chunk -> INSERT; subsequent -> UPDATE using the DBMS concatenate + # template (MySQL: CONCAT(field, chunk)). + set_dbms("MySQL") + fs = self._fs() + queries = fs.fileToSqlQueries(["0x4142", "0x4344", "0x4546"]) + tbl, fld = fs.fileTblName, fs.tblField + self.assertEqual(queries[0], + "INSERT INTO %s(%s) VALUES (0x4142)" % (tbl, fld)) + self.assertEqual(queries[1], + "UPDATE %s SET %s=CONCAT(%s,0x4344)" % (tbl, fld, fld)) + self.assertEqual(queries[2], + "UPDATE %s SET %s=CONCAT(%s,0x4546)" % (tbl, fld, fld)) + + # -- _checkFileLength -------------------------------------------------- # + def test_checkFileLength_mysql_query_and_samefile(self): + # MySQL builds LENGTH(LOAD_FILE('')) and compares to local size. + set_dbms("MySQL") + path = os.path.join( + os.environ.get("TMPDIR", "/tmp"), "sqlmap_cl_%d.bin" % os.getpid()) + with open(path, "wb") as f: + f.write(b"12345") # 5 bytes + captured = {} + + def getValue(query, *a, **k): + captured["query"] = query + return "5" + + self.patch(self.module.inject, "getValue", getValue) + try: + same = self._fs()._checkFileLength(path, "/etc/passwd") + finally: + os.remove(path) + self.assertEqual(captured["query"], + "LENGTH(LOAD_FILE('/etc/passwd'))") + self.assertIs(same, True) + + def test_checkFileLength_size_differs(self): + set_dbms("MySQL") + path = os.path.join( + os.environ.get("TMPDIR", "/tmp"), "sqlmap_cl2_%d.bin" % os.getpid()) + with open(path, "wb") as f: + f.write(b"12345") # local 5 + self.patch(self.module.inject, "getValue", lambda q, *a, **k: "9") + try: + same = self._fs()._checkFileLength(path, "/etc/passwd") + finally: + os.remove(path) + # remote 9 != local 5 -> not the same file + self.assertIs(same, False) + + def test_checkFileLength_mssql_openrowset_stacked(self): + # MSSQL path issues an OPENROWSET BULK INSERT then DATALENGTH probe. + # createSupportTbl lives in the misc mixin; stub it on a subclass so the + # OPENROWSET-building branch runs in isolation. + set_dbms("Microsoft SQL Server") + path = os.path.join( + os.environ.get("TMPDIR", "/tmp"), "sqlmap_cl3_%d.bin" % os.getpid()) + with open(path, "wb") as f: + f.write(b"ABCD") # 4 bytes + stacked = [] + + class FS(self.module.Filesystem): + def createSupportTbl(self, *a, **k): + pass + + self.patch(self.module.inject, "goStacked", + lambda q, *a, **k: stacked.append(q)) + self.patch(self.module.inject, "getValue", lambda q, *a, **k: "4") + fs = FS() + try: + same = fs._checkFileLength(path, "C:\\boot.ini") + finally: + os.remove(path) + tbl, fld = fs.fileTblName, fs.tblField + # createSupportTbl DROP+CREATE, then the OPENROWSET insert + insert = ("INSERT INTO %s(%s) SELECT %s FROM OPENROWSET(BULK " + "'C:\\boot.ini', SINGLE_BLOB) AS %s(%s)" + % (tbl, fld, fld, tbl, fld)) + self.assertIn(insert, stacked) + self.assertIs(same, True) + + def test_checkFileLength_not_written_warns_false(self): + # non-positive remote size -> treated as "not written" -> sameFile False + set_dbms("MySQL") + path = os.path.join( + os.environ.get("TMPDIR", "/tmp"), "sqlmap_cl4_%d.bin" % os.getpid()) + with open(path, "wb") as f: + f.write(b"x") + self.patch(self.module.inject, "getValue", lambda q, *a, **k: None) + try: + same = self._fs()._checkFileLength(path, "/etc/passwd") + finally: + os.remove(path) + self.assertIs(same, False) + + # -- readFile ---------------------------------------------------------- # + def test_readFile_decodes_hex_and_writes(self): + # Drive the generic readFile orchestration with a stubbed stackedReadFile + # returning canned hex; assert the bytes handed to dataToOutFile are the + # decoded content (raw bytes), and the remote name is passed through. + set_dbms("MySQL") + written = {} + + class FS(self.module.Filesystem): + def checkDbmsOs(self): + pass + + def cleanup(self, *a, **k): + pass + + def stackedReadFile(self, remoteFile): + return encodeHex(b"secret-data", binary=False) + + def askCheckReadFile(self, localFile, remoteFile): + return None + + def grab(name, data): + written["d"] = (name, data) + return "/out/path" + + self.patch(self.module, "dataToOutFile", grab) + out = FS().readFile("/etc/shadow") + self.assertEqual(written["d"][0], "/etc/shadow") + self.assertEqual(written["d"][1], b"secret-data") + self.assertEqual(out, ["/out/path"]) + + def test_readFile_listlike_chunks_joined(self): + # list-of-chunks return value gets flattened before hex-decoding + set_dbms("MySQL") + written = {} + + class FS(self.module.Filesystem): + def checkDbmsOs(self): + pass + + def cleanup(self, *a, **k): + pass + + def stackedReadFile(self, remoteFile): + # two chunks (each a 1-element list, as inject.getValue returns) + return [[encodeHex(b"AB", binary=False)], + [encodeHex(b"CD", binary=False)]] + + def askCheckReadFile(self, localFile, remoteFile): + return True + + def grab(name, data): + written["d"] = data + return "/out" + + self.patch(self.module, "dataToOutFile", grab) + out = FS().readFile("/f") + self.assertEqual(written["d"], b"ABCD") + # askCheckReadFile True -> suffix annotation + self.assertEqual(out, ["/out (same file)"]) + + # -- writeFile dispatch ------------------------------------------------ # + def test_writeFile_dispatches_to_stacked(self): + # With stacking available (conf.direct True), writeFile must route to + # stackedWriteFile and return its result. + set_dbms("MySQL") + path = os.path.join( + os.environ.get("TMPDIR", "/tmp"), "sqlmap_wf_%d.bin" % os.getpid()) + with open(path, "wb") as f: + f.write(b"data") + calls = {} + + class FS(self.module.Filesystem): + def checkDbmsOs(self): + pass + + def cleanup(self, *a, **k): + calls["cleanup"] = True + + def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): + calls["args"] = (localFile, remoteFile, fileType, forceCheck) + return True + + try: + res = FS().writeFile(path, "/var/www/x", "text", forceCheck=True) + finally: + os.remove(path) + self.assertIs(res, True) + self.assertEqual(calls["args"], (path, "/var/www/x", "text", True)) + self.assertTrue(calls["cleanup"]) + + +# --------------------------------------------------------------------------- # +# plugins/dbms/mssqlserver/filesystem.py # +# --------------------------------------------------------------------------- # +class TestMSSQLFilesystem(_FsBase): + import plugins.dbms.mssqlserver.filesystem as module + + def _handler(self): + from plugins.dbms.mssqlserver import MSSQLServerMap + set_dbms("Microsoft SQL Server") + return MSSQLServerMap() + + # -- _dataToScr (debug.exe script) ------------------------------------- # + def test_dataToScr_header_and_hex_bytes(self): + fs = self._handler() + lines = fs._dataToScr(b"AB", "chunk1") + # header: name / rcx / size(hex) / fill + self.assertEqual(lines[0], "n chunk1") + self.assertEqual(lines[1], "rcx") + self.assertEqual(lines[2], "%x" % 2) # size = 2 bytes + self.assertEqual(lines[3], "f 0100 %x 00" % 2) + # the data 'e' line: base addr 0x100, hex of 'A'(41) and 'B'(42) + self.assertEqual(lines[4], "e 100 41 42") + self.assertEqual(lines[-2], "w") + self.assertEqual(lines[-1], "q") + + def test_dataToScr_wraps_lines_and_advances_address(self): + # lineLen=20, so 21 bytes -> two 'e' lines; second starts at 0x100+20=0x114 + fs = self._handler() + content = bytes(bytearray(range(21))) # 21 bytes 0x00..0x14 + lines = fs._dataToScr(content, "c") + eLines = [ln for ln in lines if ln.startswith("e ")] + self.assertEqual(len(eLines), 2) + self.assertTrue(eLines[0].startswith("e 100 00 01 02")) + # 20 bytes consumed -> next address 0x100+0x14 = 0x114 + self.assertTrue(eLines[1].startswith("e 114 14")) + + # -- stackedReadFile (BULK INSERT + bin->hex extraction) --------------- # + def test_stackedReadFile_builds_bulk_insert_and_decodes(self): + fs = self._handler() + stacked = [] + self.patch(self.module.inject, "goStacked", + lambda q, *a, **k: stacked.append(q)) + + # UNION available -> single getValue returns the hex content directly + def getValue(query, *a, **k): + return encodeHex(b"file-bytes", binary=False) + + self.patch(self.module.inject, "getValue", getValue) + self.patch(self.module, "isTechniqueAvailable", lambda *a, **k: True) + + result = fs.stackedReadFile("C:\\secret.txt") + + # the BULK INSERT statement loading the file into the support table + bulk = [q for q in stacked if q.startswith("BULK INSERT ")] + self.assertEqual(len(bulk), 1) + self.assertIn("FROM 'C:\\secret.txt'", bulk[0]) + self.assertIn("CODEPAGE='RAW'", bulk[0]) + # the bin->hex conversion routine must reference the 0..F charset + binhex = [q for q in stacked if "0123456789ABCDEF" in q] + self.assertEqual(len(binhex), 1) + self.assertIn("DATALENGTH", binhex[0]) + # result is the raw hex string returned by getValue + self.assertEqual(result, encodeHex(b"file-bytes", binary=False)) + + def test_stackedReadFile_chunked_when_no_union(self): + # No UNION technique -> COUNT(*) then per-row TOP-1 retrieval into a list + fs = self._handler() + self.patch(self.module.inject, "goStacked", lambda q, *a, **k: None) + self.patch(self.module, "isTechniqueAvailable", lambda *a, **k: False) + + chunks = ["41", "42"] + + def getValue(query, *a, **k): + if query.startswith("SELECT COUNT(*)"): + return "2" + # the per-index extraction query + if "NOT IN (SELECT TOP" in query: + return chunks.pop(0) + return None + + self.patch(self.module.inject, "getValue", getValue) + result = fs.stackedReadFile("C:\\x") + self.assertEqual(result, ["41", "42"]) + + # -- unionWriteFile is explicitly unsupported -------------------------- # + def test_unionWriteFile_unsupported(self): + from lib.core.exception import SqlmapUnsupportedFeatureException + fs = self._handler() + self.assertRaises(SqlmapUnsupportedFeatureException, + fs.unionWriteFile, "a", "b", "binary") + + # -- _stackedWriteFilePS (PowerShell base64) --------------------------- # + def test_stackedWriteFilePS_uploads_base64_and_builds_ps(self): + fs = self._handler() + writes = [] + cmds = [] + self.patch(fs, "xpCmdshellWriteFile", + lambda content, path, name: writes.append((content, name))) + self.patch(fs, "execCmd", lambda cmd: cmds.append(cmd)) + + fs._stackedWriteFilePS("C:\\Windows\\Temp", b"payload", + "C:\\out.exe", "binary") + + expected_b64 = encodeBase64(b"payload", binary=False) + # the base64 payload goes to the .txt file; the .ps1 holds the decoder. + uploaded = "".join(c for c, name in writes if name.endswith(".txt")) + self.assertEqual(uploaded, expected_b64) + # the powershell command line: ByPass + reference to the .ps1 script + self.assertEqual(len(cmds), 1) + self.assertIn("powershell -ExecutionPolicy ByPass -File", cmds[0]) + + def test_stackedWriteFilePS_script_decodes_to_remote(self): + # Assert the PS script body contains the FromBase64String + Set-Content + # targeting the exact remote file path. + fs = self._handler() + script = {} + + def grab(content, path, name): + if name.endswith(".ps1"): + script["body"] = content + + self.patch(fs, "xpCmdshellWriteFile", grab) + self.patch(fs, "execCmd", lambda cmd: None) + fs._stackedWriteFilePS("C:\\T", b"abc", "C:\\target.dll", "binary") + self.assertIn("[System.Convert]::FromBase64String($Base64)", script["body"]) + self.assertIn('Set-Content -Path "C:\\target.dll"', script["body"]) + + # -- _stackedWriteFileCertutilExe (certutil base64) -------------------- # + def test_stackedWriteFileCertutil_splits_b64_and_decodes(self): + fs = self._handler() + writes = [] + cmds = [] + self.patch(fs, "xpCmdshellWriteFile", + lambda content, path, name: writes.append(content)) + self.patch(fs, "execCmd", lambda cmd: cmds.append(cmd)) + + # >500 chars of base64 so the splitter actually wraps lines + content = b"Z" * 600 + fs._stackedWriteFileCertutilExe("C:\\T", "local", content, + "C:\\out.bin", "binary") + + b64 = encodeBase64(content, binary=False) + # uploaded text == base64 rejoined on newline at 500-char boundaries + uploaded = writes[0] + self.assertEqual(uploaded.replace("\n", ""), b64) + self.assertEqual(uploaded.split("\n")[0], b64[:500]) + # certutil -decode command targeting the remote file + self.assertEqual(len(cmds), 1) + self.assertIn("certutil -f -decode", cmds[0]) + self.assertIn("C:\\out.bin", cmds[0]) + + +# --------------------------------------------------------------------------- # +# lib/takeover/udf.py (+ MySQL/PostgreSQL CREATE FUNCTION overrides) # +# --------------------------------------------------------------------------- # +class TestUDF(_FsBase): + import lib.takeover.udf as module + + def _udf(self): + u = self.module.UDF() + u.cmdTblName = "cmdtbl" + u.tblField = "data" + return u + + # -- udfForgeCmd ------------------------------------------------------- # + def test_udfForgeCmd_wraps_quotes(self): + u = self._udf() + self.assertEqual(u.udfForgeCmd("whoami"), "'whoami'") + # already partially quoted -> not doubled + self.assertEqual(u.udfForgeCmd("'whoami"), "'whoami'") + self.assertEqual(u.udfForgeCmd("whoami'"), "'whoami'") + + def _escaped(self, u, cmd): + # mirror udfExecCmd's argument preparation: forge then escape via the + # active DBMS unescaper. (The escaper may hex-encode the literal; we want + # to assert the SELECT wrapping/udf-name wiring, not re-test escaping.) + return self.module.unescaper.escape(u.udfForgeCmd(cmd)) + + # -- udfExecCmd -------------------------------------------------------- # + def test_udfExecCmd_builds_select_call(self): + set_dbms("MySQL") + u = self._udf() + captured = {} + self.patch(self.module.inject, "goStacked", + lambda q, silent=False: captured.setdefault("q", q)) + u.udfExecCmd("id") + # default udfName is sys_exec; arg is the forged+escaped command + self.assertEqual(captured["q"], + "SELECT sys_exec(%s)" % self._escaped(u, "id")) + + def test_udfExecCmd_custom_udf_name(self): + set_dbms("MySQL") + u = self._udf() + captured = {} + self.patch(self.module.inject, "goStacked", + lambda q, silent=False: captured.setdefault("q", q)) + u.udfExecCmd("id", udfName="my_fn") + self.assertEqual(captured["q"], + "SELECT my_fn(%s)" % self._escaped(u, "id")) + + # -- udfEvalCmd -------------------------------------------------------- # + def test_udfEvalCmd_direct_joins_lines(self): + # conf.direct -> uses udfExecCmd output, converting \r to \n + set_dbms("MySQL") + conf.direct = True + u = self._udf() + self.patch(self.module.inject, "goStacked", + lambda q, silent=False: ["foo\rbar", "baz"]) + out = u.udfEvalCmd("id") + self.assertEqual(out, "foo\nbarbaz") + + def test_udfEvalCmd_stacked_insert_select_delete(self): + # non-direct -> INSERT via UDF, SELECT back, then DELETE + set_dbms("MySQL") + conf.direct = False + u = self._udf() + stacked = [] + self.patch(self.module.inject, "goStacked", + lambda q, *a, **k: stacked.append(q)) + self.patch(self.module.inject, "getValue", + lambda q, *a, **k: "RESULT") + out = u.udfEvalCmd("id", udfName="sys_eval") + self.assertEqual( + stacked[0], + "INSERT INTO cmdtbl(data) VALUES (sys_eval(%s))" + % self._escaped(u, "id")) + self.assertEqual(stacked[1], "DELETE FROM cmdtbl") + self.assertEqual(out, "RESULT") + + # -- udfCheckNeeded (pruning of the sys UDF set) ----------------------- # + def test_udfCheckNeeded_prunes_unrequested_udfs(self): + set_dbms("MySQL") + u = self._udf() + u.sysUdfs = { + "sys_fileread": {}, "sys_bineval": {}, + "sys_eval": {}, "sys_exec": {}, + } + # nothing requested -> everything irrelevant gets popped + conf.fileRead = conf.commonFiles = None + conf.osPwn = conf.osCmd = conf.osShell = conf.regRead = False + conf.regAdd = conf.regDel = False + u.udfCheckNeeded() + self.assertEqual(u.sysUdfs, {}) + + def test_udfCheckNeeded_keeps_exec_for_oscmd(self): + set_dbms("MySQL") + u = self._udf() + u.sysUdfs = { + "sys_fileread": {}, "sys_bineval": {}, + "sys_eval": {}, "sys_exec": {}, + } + conf.fileRead = conf.commonFiles = None + conf.osPwn = False + conf.osCmd = True # requests command exec + conf.osShell = conf.regRead = conf.regAdd = conf.regDel = False + u.udfCheckNeeded() + # sys_eval & sys_exec retained; fileread/bineval pruned + self.assertIn("sys_eval", u.sysUdfs) + self.assertIn("sys_exec", u.sysUdfs) + self.assertNotIn("sys_fileread", u.sysUdfs) + self.assertNotIn("sys_bineval", u.sysUdfs) + + def test_udfCheckNeeded_keeps_fileread_for_pgsql_fileread(self): + # sys_fileread is retained ONLY when a file read is requested AND the + # back-end is PostgreSQL (per the explicit DBMS.PGSQL guard). + set_dbms("PostgreSQL") + u = self._udf() + u.sysUdfs = {"sys_fileread": {}, "sys_bineval": {}, + "sys_eval": {}, "sys_exec": {}} + conf.fileRead = "/etc/passwd" + conf.commonFiles = None + conf.osPwn = conf.osCmd = conf.osShell = conf.regRead = False + conf.regAdd = conf.regDel = False + u.udfCheckNeeded() + self.assertIn("sys_fileread", u.sysUdfs) + + def test_udfCheckNeeded_drops_fileread_for_mysql_fileread(self): + # On MySQL the same file-read request still prunes sys_fileread (the + # guard keeps it only for PostgreSQL). + set_dbms("MySQL") + u = self._udf() + u.sysUdfs = {"sys_fileread": {}, "sys_bineval": {}, + "sys_eval": {}, "sys_exec": {}} + conf.fileRead = "/etc/passwd" + conf.commonFiles = None + conf.osPwn = conf.osCmd = conf.osShell = conf.regRead = False + conf.regAdd = conf.regDel = False + u.udfCheckNeeded() + self.assertNotIn("sys_fileread", u.sysUdfs) + + # -- udfCheckAndOverwrite --------------------------------------------- # + def test_udfCheckAndOverwrite_new_udf_scheduled(self): + # UDF does not exist -> no overwrite prompt -> scheduled for creation + set_dbms("MySQL") + u = self._udf() + self.patch(self.module.inject, "getValue", lambda q, *a, **k: False) + u.udfCheckAndOverwrite("sys_eval") + self.assertIn("sys_eval", u.udfToCreate) + + def test_udfCheckAndOverwrite_existing_no_overwrite(self): + # UDF exists and user declines overwrite -> NOT scheduled + set_dbms("MySQL") + u = self._udf() + self.patch(self.module.inject, "getValue", lambda q, *a, **k: True) + self.patch(u, "_askOverwriteUdf", lambda udf: False) + u.udfCheckAndOverwrite("sys_eval") + self.assertNotIn("sys_eval", u.udfToCreate) + + # -- udfInjectCore ----------------------------------------------------- # + def test_udfInjectCore_uploads_and_creates(self): + # Drive the full inject orchestration with the file write succeeding: + # every requested UDF must end up created and the support table built. + set_dbms("MySQL") + calls = {"created": [], "supportType": None} + + class U(self.module.UDF): + def __init__(self): + super(U, self).__init__() + self.cmdTblName = "cmdtbl" + self.tblField = "data" + self.udfLocalFile = __file__ # any existing file (checkFile passes) + self.udfRemoteFile = "/tmp/lib.so" + + def udfSetRemotePath(self): + pass + + def writeFile(self, localFile, remoteFile, fileType, forceCheck=False): + calls["write"] = (remoteFile, fileType, forceCheck) + return True + + def udfCreateFromSharedLib(self, udf, inpRet): + calls["created"].append(udf) + self.createdUdf.add(udf) + + def udfCreateSupportTbl(self, dataType): + calls["supportType"] = dataType + + u = U() + self.patch(self.module.inject, "getValue", lambda q, *a, **k: False) + result = u.udfInjectCore({"sys_eval": {"return": "string"}}) + self.assertIs(result, True) + # binary upload forced; remote path threaded through + self.assertEqual(calls["write"], ("/tmp/lib.so", "binary", True)) + self.assertEqual(calls["created"], ["sys_eval"]) + # MySQL support table uses longtext + self.assertEqual(calls["supportType"], "longtext") + + def test_udfInjectCore_noop_when_all_already_created(self): + # If every UDF is already created, nothing is uploaded and it returns True + set_dbms("MySQL") + + class U(self.module.UDF): + def writeFile(self, *a, **k): + raise AssertionError("writeFile must not be called") + + u = U() + u.createdUdf = {"sys_eval"} + result = u.udfInjectCore({"sys_eval": {"return": "string"}}) + self.assertIs(result, True) + self.assertEqual(u.udfToCreate, set()) + + # -- MySQL udfCreateFromSharedLib (CREATE FUNCTION ... SONAME) --------- # + def test_mysql_udfCreateFromSharedLib_sql(self): + import plugins.dbms.mysql.takeover as mod + set_dbms("MySQL") + t = mod.Takeover() + t.udfToCreate = {"sys_eval"} + t.createdUdf = set() + t.udfSharedLibName = "libsabc" + t.udfSharedLibExt = "so" + stacked = [] + self.patch(mod.inject, "goStacked", lambda q, *a, **k: stacked.append(q)) + t.udfCreateFromSharedLib("sys_eval", {"return": "string"}) + self.assertEqual(stacked[0], "DROP FUNCTION sys_eval") + self.assertEqual( + stacked[1], + "CREATE FUNCTION sys_eval RETURNS string SONAME 'libsabc.so'") + self.assertIn("sys_eval", t.createdUdf) + + # -- PostgreSQL udfCreateFromSharedLib (CREATE OR REPLACE FUNCTION) ---- # + def test_pgsql_udfCreateFromSharedLib_sql(self): + import plugins.dbms.postgresql.takeover as mod + set_dbms("PostgreSQL") + t = mod.Takeover() + t.udfToCreate = {"sys_eval"} + t.createdUdf = set() + t.udfRemoteFile = "/tmp/libsabc.so" + stacked = [] + self.patch(mod.inject, "goStacked", lambda q, *a, **k: stacked.append(q)) + t.udfCreateFromSharedLib( + "sys_eval", {"input": ["text"], "return": "text"}) + self.assertEqual(stacked[0], "DROP FUNCTION sys_eval(text)") + self.assertEqual( + stacked[1], + "CREATE OR REPLACE FUNCTION sys_eval(text) RETURNS text AS " + "'/tmp/libsabc.so', 'sys_eval' LANGUAGE C RETURNS NULL ON NULL " + "INPUT IMMUTABLE") + + # -- PostgreSQL udfSetRemotePath (OS-dependent path) ------------------- # + def test_pgsql_udfSetRemotePath_linux_and_windows(self): + # Linux -> /tmp/; Windows -> bare (saved into the data dir). + # Set kb.os directly to avoid Backend.setOs()'s interactive OS-mismatch + # prompt when flipping the OS mid-test. + import plugins.dbms.postgresql.takeover as mod + from lib.core.enums import OS + set_dbms("PostgreSQL") + t = mod.Takeover() + t.udfSharedLibName = "libsxyz" + t.udfSharedLibExt = "so" + + _os = kb.os + try: + kb.os = OS.LINUX + t.udfSetRemotePath() + self.assertEqual(t.udfRemoteFile, "/tmp/libsxyz.so") + + kb.os = OS.WINDOWS + t.udfSharedLibExt = "dll" + t.udfSetRemotePath() + self.assertEqual(t.udfRemoteFile, "libsxyz.dll") + finally: + kb.os = _os + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py new file mode 100644 index 000000000..879420feb --- /dev/null +++ b/tests/test_fingerprint.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +DBMS version/fork fingerprinting (plugins/dbms//fingerprint.py). Each +plugin's getFingerprint()/checkDbms() probes the backend with a cascade of +boolean expressions (inject.checkBooleanExpression) and version reads +(inject.getValue). Those are the network seam: stubbing them lets the dialect's +whole detection cascade run offline. We drive every targeted plugin with the +oracle pinned both True and False so opposite branches of the cascade execute. +""" + +import importlib +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.data import conf, kb +from lib.core.common import Backend + +# (display name, fingerprint module, handler package) +TARGETS = [ + ("MySQL", "plugins.dbms.mysql.fingerprint", "plugins.dbms.mysql"), + ("PostgreSQL", "plugins.dbms.postgresql.fingerprint", "plugins.dbms.postgresql"), + ("Microsoft SQL Server", "plugins.dbms.mssqlserver.fingerprint", "plugins.dbms.mssqlserver"), + ("Oracle", "plugins.dbms.oracle.fingerprint", "plugins.dbms.oracle"), + ("IBM DB2", "plugins.dbms.db2.fingerprint", "plugins.dbms.db2"), + ("Microsoft Access", "plugins.dbms.access.fingerprint", "plugins.dbms.access"), + ("Firebird", "plugins.dbms.firebird.fingerprint", "plugins.dbms.firebird"), + ("Sybase", "plugins.dbms.sybase.fingerprint", "plugins.dbms.sybase"), + ("SAP MaxDB", "plugins.dbms.maxdb.fingerprint", "plugins.dbms.maxdb"), + ("HSQLDB", "plugins.dbms.hsqldb.fingerprint", "plugins.dbms.hsqldb"), + ("H2", "plugins.dbms.h2.fingerprint", "plugins.dbms.h2"), + ("Presto", "plugins.dbms.presto.fingerprint", "plugins.dbms.presto"), + ("Vertica", "plugins.dbms.vertica.fingerprint", "plugins.dbms.vertica"), + ("Informix", "plugins.dbms.informix.fingerprint", "plugins.dbms.informix"), + ("InterSystems Cache", "plugins.dbms.cache.fingerprint", "plugins.dbms.cache"), + ("MonetDB", "plugins.dbms.monetdb.fingerprint", "plugins.dbms.monetdb"), + ("Altibase", "plugins.dbms.altibase.fingerprint", "plugins.dbms.altibase"), + ("ClickHouse", "plugins.dbms.clickhouse.fingerprint", "plugins.dbms.clickhouse"), + ("CrateDB", "plugins.dbms.cratedb.fingerprint", "plugins.dbms.cratedb"), + ("Cubrid", "plugins.dbms.cubrid.fingerprint", "plugins.dbms.cubrid"), + ("Mckoi", "plugins.dbms.mckoi.fingerprint", "plugins.dbms.mckoi"), + ("Virtuoso", "plugins.dbms.virtuoso.fingerprint", "plugins.dbms.virtuoso"), + ("Raima Database Manager", "plugins.dbms.raima.fingerprint", "plugins.dbms.raima"), + ("eXtremeDB", "plugins.dbms.extremedb.fingerprint", "plugins.dbms.extremedb"), + ("FrontBase", "plugins.dbms.frontbase.fingerprint", "plugins.dbms.frontbase"), + ("Apache Derby", "plugins.dbms.derby.fingerprint", "plugins.dbms.derby"), + ("MimerSQL", "plugins.dbms.mimersql.fingerprint", "plugins.dbms.mimersql"), +] + + +def _handler_cls(pkg): + main = importlib.import_module(pkg) + return [getattr(main, n) for n in dir(main) if n.endswith("Map")][0] + + +# Dialects whose non-extensive getFingerprint emits Format.getDbms() (i.e. +# " ") rather than a hard-coded DBMS.* constant, so the version +# that flowed through (Backend.setVersionList(["1.0"])) actually appears in the +# output. (In the test harness Backend.getDbms() is None because set_dbms uses +# forceDbms, so for these the dialect NAME is absent but "1.0" is load-bearing.) +ACTVER_DBMS = frozenset(( + "MySQL", "Microsoft SQL Server", "Firebird", "HSQLDB", +)) + +# Dialects whose getFingerprint has a fork concept: with the oracle pinned True +# the first fork-detection branch fires (MySQL->MariaDB, PostgreSQL->CockroachDB, +# Oracle->DM8, Cache->Iris, H2->Apache Ignite, Presto->Trino) and the output +# gains a " (... fork)" suffix. Pinned False, no fork is emitted. +FORK_DBMS = frozenset(( + "MySQL", "PostgreSQL", "Oracle", "InterSystems Cache", "H2", "Presto", +)) + +# Dialects whose getFingerprint genuinely needs more extraction state under +# conf.extensiveFp and raises a narrow KeyError before completing. +EXTENSIVE_RAISERS = frozenset(( + "SAP MaxDB", +)) + + +class TestFingerprint(unittest.TestCase): + def setUp(self): + self._saved = {k: conf.get(k) for k in ("batch", "extensiveFp", "api", "dbms", "forceDbms")} + self._kb = {k: kb.get(k) for k in ("dbmsVersion", "forcedDbms", "dbms", "stickyDBMS", + "resolutionDbms", "os", "osVersion", "osSP")} + conf.batch = True + conf.extensiveFp = False + conf.api = False + + def tearDown(self): + for k, v in self._saved.items(): + conf[k] = v + for k, v in self._kb.items(): + kb[k] = v + + def _drive(self, name, modpath, pkg, oracle): + set_dbms(name) + Backend.setVersionList(["1.0"]) + mod = importlib.import_module(modpath) + if hasattr(mod, "inject"): + mod.inject.checkBooleanExpression = lambda e, *a, **k: oracle + mod.inject.getValue = lambda q, *a, **k: "1.0" + handler = _handler_cls(pkg)() + fp = handler.getFingerprint() + self.assertIsInstance(fp, str) + + # Real content: the dialect's own identity must have flowed into the + # output, not merely the constant "back-end DBMS: " prefix. + if name in ACTVER_DBMS: + # Format.getDbms() embedded the version list -> "1.0" must appear. + self.assertIn("1.0", fp, + "%s fp lost the version that flowed through: %r" % (name, fp)) + else: + # the dialect name (DBMS.* constant) must appear verbatim. + self.assertIn(Backend.getForcedDbms(), fp, + "%s fp lost its dialect name: %r" % (name, fp)) + + # Fork detection: with the oracle pinned True the first fork branch + # fires for the fork-bearing dialects; pinned False none do. This is the + # only thing distinguishing the True/False runs for those dialects. + if name in FORK_DBMS: + if oracle: + self.assertIn("fork)", fp, + "%s did not emit a fork label with oracle=True: %r" % (name, fp)) + else: + self.assertNotIn("fork)", fp, + "%s emitted a fork label with oracle=False: %r" % (name, fp)) + else: + # dialects with no fork concept never emit a fork label + self.assertNotIn("fork)", fp) + + # checkDbms walks the dialect's detection cascade end-to-end; it must + # return a real boolean verdict (True/False), never None or a raise. + verdict = handler.checkDbms() + self.assertIn(verdict, (True, False), + "%s checkDbms() returned a non-bool: %r" % (name, verdict)) + return fp + + def test_fingerprint_oracle_true(self): + for name, modpath, pkg in TARGETS: + self._drive(name, modpath, pkg, True) + + def test_fingerprint_oracle_false(self): + for name, modpath, pkg in TARGETS: + self._drive(name, modpath, pkg, False) + + def test_fingerprint_extensive(self): + # conf.extensiveFp drives the deeper comment-/version-/dbms-check cascades + # (getFingerprint past the early return) — much more code per dialect. + # In this mode every dialect's output is built around an + # "active fingerprint: " line, so that header is the + # real content proof; the version "1.0" rides along for the ACTVER set. + conf.extensiveFp = True + try: + for name, modpath, pkg in TARGETS: + for oracle in (True, False): + set_dbms(name) + Backend.setVersionList(["1.0"]) + mod = importlib.import_module(modpath) + if hasattr(mod, "inject"): + mod.inject.checkBooleanExpression = lambda e, *a, **k: oracle + mod.inject.getValue = lambda q, *a, **k: "1.0" + handler = _handler_cls(pkg)() + if name in EXTENSIVE_RAISERS: + # this dialect genuinely needs extra extraction state under + # extensiveFp; assert it gets exactly that far and no further. + with self.assertRaises(KeyError): + handler.getFingerprint() + continue + fp = handler.getFingerprint() + self.assertIsInstance(fp, str) + self.assertIn("active fingerprint:", fp, + "%s extensiveFp produced no active-fingerprint line: %r" % (name, fp)) + if name in ACTVER_DBMS: + self.assertIn("1.0", fp, + "%s extensiveFp lost the version: %r" % (name, fp)) + finally: + conf.extensiveFp = False + + +def _make(name, modpath, pkg): + def _t(self): + # _drive already asserts real, dialect-specific content (version/name + + # fork label + a boolean checkDbms verdict) for both oracle states. + self._drive(name, modpath, pkg, True) + self._drive(name, modpath, pkg, False) + return _t + + +# one named test per DBMS for clearer reporting +for _name, _mod, _pkg in TARGETS: + setattr(TestFingerprint, "test_fp_%s" % _pkg.split(".")[-1], _make(_name, _mod, _pkg)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generic_enum_more.py b/tests/test_generic_enum_more.py new file mode 100644 index 000000000..683a459b7 --- /dev/null +++ b/tests/test_generic_enum_more.py @@ -0,0 +1,865 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Additional unit tests for the generic enumeration mixins, deliberately targeting +branches NOT already exercised by tests/test_databases_enum.py, +tests/test_users_enum.py, tests/test_search_enum.py and tests/test_generic_more.py +(which cover the conf.direct INBAND happy paths). + +This file drives the OTHER branches: + + * plugins/generic/databases.py - the INFERENCE paths (conf.direct=False + + isInferenceAvailable via kb.injection BOOLEAN state: count -> per-row getValue), + the MSSQL inband-paging fallback in getDbs(), getColumns onlyColNames / dumpMode, + the getColumns MySQL<5 / ACCESS bruteforce fallback, getCount over cachedTables, + and getStatements/getProcedures empty/none branches. + * plugins/generic/users.py - getPrivileges role/grant parsing per DBMS in BOTH the + inband path (PGSQL digit columns, MySQL<5 Y/N, Firebird letters, DB2 grant codes) + and the INFERENCE path (count then per-index privilege), getPasswordHashes + grouping/dedup in the inference path, getUsers inference, isDba MSSQL. + * plugins/generic/entries.py - dumpTable INFERENCE path (count -> column-pivot via + per-(index,column) getValue), the empty-table branch, the count-failure skip, + and the resolveKeysetCursor disabling via conf.noKeyset. + * plugins/generic/search.py - searchDb / searchTable / searchColumn INFERENCE + paths (count then per-index getValue), and the MySQL<5 bruteforce branch of + searchTable / searchColumn. + +Recipe (proven in tests/test_databases_enum.py): patch the module's inject.getValue +with canned rows in the EXACT shape the branch parses; for inference branches return +a positive int for EXPECTED.INT count calls then the per-row/per-index values; set the +needed kb.data flags; assert the exact resulting structure (sorted lists, +{db:{tbl:{col:type}}} dicts, privilege sets, dumpedTable values). + +CRITICAL STATE HYGIENE: every test snapshots and restores conf.*, the patched +inject.getValue (per module), kb.data.cached*, kb.hintValue, kb.injection.data, +Backend/forcedDbms in tearDown so nothing leaks into the rest of the suite. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms + +bootstrap() + +from lib.core.data import conf, kb +from lib.core.enums import EXPECTED, PAYLOAD + +import plugins.generic.databases as dbmod +import plugins.generic.users as umod +import plugins.generic.search as smod +import plugins.generic.entries as emod +from plugins.generic.databases import Databases +from plugins.generic.users import Users +from plugins.generic.search import Search +from plugins.generic.entries import Entries + +_NOOP = lambda self: None + + +def _inference_gv(count, sequence): + """Build an inject.getValue stub for blind inference branches. + + Returns `count` (as str) whenever the caller asks for EXPECTED.INT, otherwise + yields the next item from `sequence` wrapped as a single-cell row ([value]), + cycling if exhausted. This mirrors the count-then-per-row contract of every + isInferenceAvailable() branch. + """ + state = {"i": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(count) + val = sequence[state["i"] % len(sequence)] + state["i"] += 1 + return [val] + + return gv + + +# --------------------------------------------------------------------------- # +# databases.py +# --------------------------------------------------------------------------- # + +class _DbBase(unittest.TestCase): + _CONF_KEYS = ("direct", "technique", "db", "tbl", "col", "exclude", + "getComments", "excludeSysDbs", "search", "freshQueries") + + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in self._CONF_KEYS} + self._saved_getValue = dbmod.inject.getValue + self._saved_checkBool = dbmod.inject.checkBooleanExpression + self._saved_injection_data = kb.injection.data + self._saved_has_is = kb.data.get("has_information_schema") + self._saved_hintValue = kb.get("hintValue") + self._saved_choices = dict(kb.choices) + self._saved_readInput = dbmod.readInput + self._saved_forceDbmsEnum = getattr(Databases, "forceDbmsEnum", None) + Databases.forceDbmsEnum = _NOOP + + conf.getComments = False + conf.excludeSysDbs = False + conf.exclude = None + conf.search = False + conf.freshQueries = False + conf.col = None + kb.data.has_information_schema = True + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + dbmod.inject.getValue = self._saved_getValue + dbmod.inject.checkBooleanExpression = self._saved_checkBool + dbmod.readInput = self._saved_readInput + kb.injection.data = self._saved_injection_data + kb.data.has_information_schema = self._saved_has_is + kb.hintValue = self._saved_hintValue + kb.choices.clear() + kb.choices.update(self._saved_choices) + if self._saved_forceDbmsEnum is not None: + Databases.forceDbmsEnum = self._saved_forceDbmsEnum + else: + try: + del Databases.forceDbmsEnum + except AttributeError: + pass + + def _fresh(self): + d = Databases() + kb.data.currentDb = "" + kb.data.cachedDbs = [] + kb.data.cachedTables = {} + kb.data.cachedColumns = {} + kb.data.cachedCounts = {} + kb.data.cachedStatements = [] + kb.data.cachedProcedures = [] + return d + + def _inference(self): + conf.direct = False + conf.technique = None + kb.injection.data = {PAYLOAD.TECHNIQUE.BOOLEAN: {"title": "AND boolean-based blind"}} + + +class TestDatabasesInference(_DbBase): + def test_get_columns_inference_pgsql_types(self): + # Blind column enumeration on PostgreSQL: a count, then for each index a + # column name followed by its type. Assert the {db:{tbl:{col:type}}} parse. + set_dbms("PostgreSQL") + self._inference() + d = self._fresh() + conf.db = "public" + conf.tbl = "users" + + names = ["id", "email"] + state = {"i": 0, "name": True} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return str(len(names)) + if state["name"]: + val = names[state["i"] % len(names)] + state["i"] += 1 + state["name"] = False + return [val] + state["name"] = True + return ["integer"] + + dbmod.inject.getValue = gv + result = d.getColumns() + cols = result["public"]["users"] + self.assertEqual(len(cols), 2) + self.assertEqual(cols.get("id"), "integer") + + def test_get_columns_inference_dump_mode_collist(self): + # dumpMode with an explicit conf.col list: in the inference branch the + # columns are taken straight from colList (no count/type queries at all) + # and stored with value None. Asserting no getValue ran proves the + # dump-mode shortcut, not a network round-trip. + set_dbms("MySQL") + self._inference() + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users" + conf.col = "id,name" + + def boom(*a, **k): + raise AssertionError("dumpMode+colList must not query in inference branch") + + dbmod.inject.getValue = boom + result = d.getColumns(dumpMode=True) + cols = result["testdb"]["users"] + # "name" is a reserved word -> safeSQLIdentificatorNaming backtick-quotes it; + # both columns must be present (count, since exact key varies by quoting). + self.assertEqual(len(cols), 2) + self.assertIn("id", cols) + self.assertIsNone(cols.get("id")) + + def test_get_count_over_cached_tables_inference(self): + # getCount with no conf.tbl: it calls getTables() then per-table _tableGetCount. + # Drive the inband table fetch + per-table count and assert the + # {db:{count:[tables]}} grouping (tables sharing a count are grouped). + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = None + kb.data.cachedTables = {"testdb": ["users", "posts"]} + + counts = {"users": "5", "posts": "5"} + + def gv(query, *a, **k): + for t, c in counts.items(): + if t in query: + return c + return "0" + + dbmod.inject.getValue = gv + result = d.getCount() + # both tables have count 5 -> grouped under the same key + self.assertEqual(sorted(result["testdb"][5]), ["posts", "users"]) + + def test_get_statements_count_zero_returns_empty(self): + # Inference path: a zero count short-circuits to the (empty) cache. + set_dbms("PostgreSQL") + self._inference() + d = self._fresh() + # getStatements compares the count with the int literal 0 (count == 0), so + # the count stub must return an int 0 (not "0") to take the empty branch. + dbmod.inject.getValue = lambda query, *a, **k: 0 if k.get("expected") == EXPECTED.INT else self.fail("must not fetch rows when count is 0") + result = d.getStatements() + self.assertEqual(result, []) + + def test_get_procedures_inference(self): + set_dbms("PostgreSQL") + self._inference() + d = self._fresh() + dbmod.inject.getValue = _inference_gv(2, ["sp_a", "sp_b"]) + result = d.getProcedures() + self.assertEqual(sorted(result), ["sp_a", "sp_b"]) + + def test_get_dbs_mssql_inband_paging(self): + # MSSQL with no rows from the primary query falls into the query2 paging + # loop (one indexed query per db until a blank value stops it). + set_dbms("Microsoft SQL Server") + conf.direct = True + d = self._fresh() + dbs = ["master", "model"] + + def gv(query, *a, **k): + # The primary inband query is 'SELECT name FROM master..sysdatabases' + # (no DB_NAME); make it return nothing so getDbs falls into the + # 'SELECT DB_NAME()' paging loop (query2). + if "DB_NAME" not in query: + return None + import re as _re + idx = int(_re.findall(r"DB_NAME\((\d+)\)", query)[0]) + return dbs[idx] if idx < len(dbs) else "" + + dbmod.inject.getValue = gv + result = d.getDbs() + self.assertEqual(sorted(result), ["master", "model"]) + + def test_get_tables_inference_grouped_per_db(self): + # Blind table enumeration: count for the db, then one table name per index. + set_dbms("MySQL") + self._inference() + d = self._fresh() + conf.db = "shop" + conf.tbl = None + dbmod.inject.getValue = _inference_gv(2, ["orders", "items"]) + result = d.getTables() + self.assertIn("shop", result) + self.assertEqual(sorted(result["shop"]), ["items", "orders"]) + + +class TestDatabasesBruteForce(_DbBase): + def test_get_columns_mysql_lt5_bruteforce_decline(self): + # MySQL < 5 (no information_schema) forces bruteForce in getColumns; with + # the common-column-existence prompt answered 'N' it returns None without + # issuing any column query. + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users" + kb.data.has_information_schema = False + kb.choices.columnExists = None + dbmod.readInput = lambda *a, **k: "N" + + def boom(*a, **k): + raise AssertionError("bruteForce decline must not query columns") + + dbmod.inject.getValue = boom + result = d.getColumns() + self.assertIsNone(result) + + def test_get_columns_bruteforce_dumpmode_collist_on_decline(self): + # bruteForce + decline + dumpMode + colList: the columns from colList are + # stored with None type (the dump-mode salvage branch), not dropped. + set_dbms("MySQL") + conf.direct = True + d = self._fresh() + conf.db = "testdb" + conf.tbl = "users" + conf.col = "a,b" + kb.data.has_information_schema = False + kb.choices.columnExists = None + dbmod.readInput = lambda *a, **k: "N" + dbmod.inject.getValue = lambda *a, **k: None + result = d.getColumns(dumpMode=True) + cols = result["testdb"]["users"] + self.assertEqual(sorted(cols.keys()), ["a", "b"]) + self.assertIsNone(cols.get("a")) + + +# --------------------------------------------------------------------------- # +# users.py +# --------------------------------------------------------------------------- # + +class _UsersBase(unittest.TestCase): + def setUp(self): + self._direct = conf.direct + self._technique = conf.technique + self._user = conf.user + self._gv = umod.inject.getValue + self._cbe = umod.inject.checkBooleanExpression + self._store = umod.storeHashesToFile + self._attack = umod.attackCachedUsersPasswords + self._readInput = umod.readInput + self._his = kb.data.get("has_information_schema") + self._injection_data = kb.injection.data + + set_dbms("MySQL") + conf.direct = True + conf.user = None + kb.data.has_information_schema = True + + umod.storeHashesToFile = lambda *a, **k: None + umod.attackCachedUsersPasswords = lambda *a, **k: None + umod.readInput = lambda *a, **k: "N" + + def tearDown(self): + conf.direct = self._direct + conf.technique = self._technique + conf.user = self._user + umod.inject.getValue = self._gv + umod.inject.checkBooleanExpression = self._cbe + umod.storeHashesToFile = self._store + umod.attackCachedUsersPasswords = self._attack + umod.readInput = self._readInput + kb.injection.data = self._injection_data + if self._his is None: + kb.data.pop("has_information_schema", None) + else: + kb.data.has_information_schema = self._his + + def _inference(self): + conf.direct = False + conf.technique = None + kb.injection.data = {PAYLOAD.TECHNIQUE.BOOLEAN: {"title": "AND boolean-based blind"}} + + +class TestUsersPrivilegesInband(_UsersBase): + def test_privileges_pgsql_multiple_digit_columns(self): + # PostgreSQL: privilege columns are digit flags; a column index maps to + # PGSQL_PRIVS only when its value is "1". Set createdb(1)=1 and super(2)=1, + # leave the rest 0; assert exactly those two privileges are parsed and that + # "super" makes the user an admin. + set_dbms("PostgreSQL") + from lib.core.dicts import PGSQL_PRIVS + ncols = max(PGSQL_PRIVS.keys()) + row = ["pguser"] + ["0"] * ncols + row[1] = "1" # createdb + row[2] = "1" # super + umod.inject.getValue = lambda query, *a, **k: [row] + users = Users() + kb.data.cachedUsersPrivileges = {} + privileges, areAdmins = users.getPrivileges() + self.assertEqual(set(privileges["pguser"]), {PGSQL_PRIVS[1], PGSQL_PRIVS[2]}) + self.assertIn("pguser", areAdmins) + + def test_privileges_mysql_lt5_yn_flags(self): + # MySQL < 5 (no information_schema): privilege columns are 'Y'/'N' flags + # mapped to MYSQL_PRIVS by column position. Y in col 1 -> select_priv. + set_dbms("MySQL") + from lib.core.dicts import MYSQL_PRIVS + kb.data.has_information_schema = False + ncols = max(MYSQL_PRIVS.keys()) + row = ["root"] + ["N"] * ncols + row[1] = "Y" # select_priv + row[3] = "Y" # update_priv + umod.inject.getValue = lambda query, *a, **k: [row] + users = Users() + kb.data.cachedUsersPrivileges = {} + privileges, areAdmins = users.getPrivileges() + self.assertIn(MYSQL_PRIVS[1], privileges["root"]) + self.assertIn(MYSQL_PRIVS[3], privileges["root"]) + self.assertNotIn(MYSQL_PRIVS[2], privileges["root"]) + + def test_privileges_firebird_letter_codes(self): + # Firebird: each privilege is a single letter mapped via FIREBIRD_PRIVS. + set_dbms("Firebird") + from lib.core.dicts import FIREBIRD_PRIVS + umod.inject.getValue = lambda query, *a, **k: [["fbuser", "S"], ["fbuser", "I"]] + users = Users() + kb.data.cachedUsersPrivileges = {} + privileges, areAdmins = users.getPrivileges() + self.assertEqual(set(privileges["fbuser"]), + {FIREBIRD_PRIVS["S"], FIREBIRD_PRIVS["I"]}) + + def test_privileges_db2_grant_codes(self): + # DB2: privilege string is ","; each 'Y'/'G' letter at + # position i appends the DB2_PRIVS[i] name to the privilege. + set_dbms("DB2") + from lib.core.dicts import DB2_PRIVS + conf.user = "db2admin" + # "DBADM" plus a grant string whose first letter (position 1) is 'Y' -> + # DB2_PRIVS[1] ("CONTROLAUTH") is appended. + umod.inject.getValue = lambda query, *a, **k: [["DB2ADMIN", "DBADM,Y"]] + users = Users() + kb.data.cachedUsersPrivileges = {} + privileges, areAdmins = users.getPrivileges() + joined = " ".join(privileges["DB2ADMIN"]) + self.assertIn("DBADM", joined) + self.assertIn(DB2_PRIVS[1], joined) + + +class TestUsersPrivilegesInference(_UsersBase): + def test_privileges_inference_mysql(self): + # Blind privilege enumeration for a named user: count, then one privilege + # string per index. MySQL >= 5 adds each verbatim. + set_dbms("MySQL") + self._inference() + conf.user = "root" + privs = ["SELECT", "SUPER"] + umod.inject.getValue = _inference_gv(2, privs) + users = Users() + kb.data.cachedUsersPrivileges = {} + privileges, areAdmins = users.getPrivileges() + # the user key is wildcard-wrapped for the MySQL information_schema LIKE + key = [k for k in privileges if "root" in k][0] + self.assertEqual(set(privileges[key]), {"SELECT", "SUPER"}) + self.assertTrue(areAdmins) # SUPER => admin + + def test_privileges_inference_oracle(self): + set_dbms("Oracle") + self._inference() + conf.user = "system" + umod.inject.getValue = _inference_gv(1, ["DBA"]) + users = Users() + kb.data.cachedUsersPrivileges = {} + privileges, areAdmins = users.getPrivileges() + self.assertIn("SYSTEM", privileges) + self.assertEqual(privileges["SYSTEM"], ["DBA"]) + self.assertIn("SYSTEM", areAdmins) + + +class TestUsersPasswordHashesInference(_UsersBase): + def test_password_hashes_inference_grouping(self): + # Blind password-hash enumeration for two users: per-user count, then one + # hash per index. Assert each user maps to its own hash list. + set_dbms("MySQL") + self._inference() + conf.user = "root,guest" + + # per-user single hash; count is 1 for every user + hashes = {"root": "*ROOTHASH", "guest": "*GUESTHASH"} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return "1" + for u, h in hashes.items(): + if u in query: + return [h] + return [None] + + umod.inject.getValue = gv + users = Users() + kb.data.cachedUsersPasswords = {} + res = users.getPasswordHashes() + self.assertEqual(res["root"], ["*ROOTHASH"]) + self.assertEqual(res["guest"], ["*GUESTHASH"]) + + def test_password_hashes_inference_dedup(self): + # The same hash returned twice for a user must be de-duplicated at the end + # (kb.data.cachedUsersPasswords[user] = list(set(...))). + set_dbms("MySQL") + self._inference() + conf.user = "root" + umod.inject.getValue = _inference_gv(2, ["*DUP", "*DUP"]) + users = Users() + kb.data.cachedUsersPasswords = {} + res = users.getPasswordHashes() + self.assertEqual(res["root"], ["*DUP"]) + + +class TestUsersGetUsersInference(_UsersBase): + def test_get_users_inference(self): + set_dbms("MySQL") + self._inference() + umod.inject.getValue = _inference_gv(2, ["root@localhost", "guest@%"]) + users = Users() + kb.data.cachedUsers = [] + res = users.getUsers() + self.assertEqual(sorted(res), ["guest@%", "root@localhost"]) + + def test_is_dba_mssql(self): + # MSSQL isDba goes through the generic checkBooleanExpression branch. + set_dbms("Microsoft SQL Server") + umod.inject.checkBooleanExpression = lambda query, *a, **k: True + users = Users() + kb.data.isDba = None + self.assertTrue(users.isDba()) + + +# --------------------------------------------------------------------------- # +# entries.py - inference (blind) dump path +# --------------------------------------------------------------------------- # + +class _RecordingDumper(object): + def __init__(self): + self.tableValues = [] + + def dbTableValues(self, tableValues): + self.tableValues.append(tableValues) + + +class _TestEntries(Entries): + def __init__(self): + Entries.__init__(self) + self.getColumnsResult = {} + self.getTablesResult = {} + + def forceDbmsEnum(self): + pass + + def getCurrentDb(self): + return "testdb" + + def getColumns(self, onlyColNames=False, colTuple=None, bruteForce=None, dumpMode=False): + kb.data.cachedColumns = dict(self.getColumnsResult) + + def getTables(self, bruteForce=None): + kb.data.cachedTables = dict(self.getTablesResult) + + +class _EntriesBase(unittest.TestCase): + _CONF_KEYS = ("db", "tbl", "col", "direct", "technique", "exclude", "search", + "disableHashing", "noKeyset", "keyset", "forcePivoting", "dumpWhere") + + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in self._CONF_KEYS} + self._saved_dumper = conf.get("dumper") + self._gv = emod.inject.getValue + self._cbe = emod.inject.checkBooleanExpression + self._readInput = emod.readInput + self._saved_has_is = kb.data.get("has_information_schema") + self._saved_cachedColumns = kb.data.get("cachedColumns") + self._saved_cachedTables = kb.data.get("cachedTables") + self._saved_dumpedTable = kb.data.get("dumpedTable") + self._saved_dumpKbInt = kb.get("dumpKeyboardInterrupt") + self._saved_permissionFlag = kb.get("permissionFlag") + self._saved_injection_data = kb.injection.data + + set_dbms("MySQL") + conf.direct = False + conf.technique = None + conf.exclude = None + conf.search = False + conf.disableHashing = True + conf.noKeyset = True + conf.keyset = False + conf.forcePivoting = False + conf.dumpWhere = None + conf.dumper = _RecordingDumper() + + kb.data.has_information_schema = True + kb.data.cachedColumns = {} + kb.data.cachedTables = {} + kb.data.dumpedTable = {} + kb.dumpKeyboardInterrupt = False + kb.permissionFlag = False + kb.injection.data = {PAYLOAD.TECHNIQUE.BOOLEAN: {"title": "AND boolean-based blind"}} + + emod.readInput = lambda *a, **k: (k.get("default") if k.get("default") is not None else (a[1] if len(a) > 1 else None)) + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + conf.dumper = self._saved_dumper + emod.inject.getValue = self._gv + emod.inject.checkBooleanExpression = self._cbe + emod.readInput = self._readInput + kb.data.has_information_schema = self._saved_has_is + kb.data.cachedColumns = self._saved_cachedColumns + kb.data.cachedTables = self._saved_cachedTables + kb.data.dumpedTable = self._saved_dumpedTable + kb.dumpKeyboardInterrupt = self._saved_dumpKbInt + kb.permissionFlag = self._saved_permissionFlag + kb.injection.data = self._saved_injection_data + + +class TestEntriesInference(_EntriesBase): + def _entries(self, db="testdb", tbl="users", cols=("id", "name")): + e = _TestEntries() + e.getColumnsResult = {db: {tbl: {c: "varchar" for c in cols}}} + return e + + def test_dump_table_inference_column_pivot(self): + # Blind dump (conf.direct=False, BOOLEAN available): a row count, then one + # value per (index, column). Assert the per-column pivoted values match. + set_dbms("MySQL") + e = self._entries(cols=("id", "name")) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + + # data[index][column] -> value. 2 rows, columns id/name. + data = {0: {"id": "1", "name": "alice"}, 1: {"id": "2", "name": "bob"}} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return "2" # row count + # MySQL blind cell query: 'SELECT FROM testdb.users ORDER BY ... + # LIMIT ,1'. The row index is the LIMIT offset; the column is the + # SELECT projection. + import re as _re + idx = int(_re.search(r"LIMIT\s+(\d+)\s*,\s*1", query).group(1)) + proj = query.split(" FROM ", 1)[0] + col = "name" if "name" in proj else "id" + return data[idx][col] + + emod.inject.getValue = gv + e.dumpTable() + dumped = conf.dumper.tableValues[-1] + self.assertEqual(dumped["__infos__"]["count"], 2) + self.assertEqual(list(dumped["id"]["values"]), ["1", "2"]) + self.assertEqual(list(dumped["name"]["values"]), ["alice", "bob"]) + + def test_dump_table_inference_empty_table(self): + # A zero row count in the inference path yields empty per-column value + # lists and no dbTableValues emission (dumpedTable stays effectively empty). + set_dbms("MySQL") + e = self._entries(cols=("id",)) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + + emod.inject.getValue = lambda query, *a, **k: ("0" if k.get("expected") == EXPECTED.INT else self.fail("must not fetch cells for empty table")) + e.dumpTable() + # count 0 => empty entries => nothing dumped + self.assertEqual(conf.dumper.tableValues, []) + + def test_dump_table_inference_count_failure_skips(self): + # A non-numeric count in the inference path => the table is skipped with a + # warning, no values dumped. + set_dbms("MySQL") + e = self._entries(cols=("id",)) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + return None # count failed + self.fail("must not fetch cells when count failed") + + emod.inject.getValue = gv + e.dumpTable() + self.assertEqual(conf.dumper.tableValues, []) + + +# --------------------------------------------------------------------------- # +# search.py - inference (blind) paths +# --------------------------------------------------------------------------- # + +class _TestSearch(Search): + excludeDbsList = ["information_schema", "mysql"] + + def __init__(self): + Search.__init__(self) + self.like = ('2', "='%s'") # exact match (colConsider '2') + self.dumpFoundTablesCalls = [] + self.dumpFoundColumnCalls = [] + + def likeOrExact(self, what): + return self.like + + def forceDbmsEnum(self): + pass + + def getCurrentDb(self): + return "testdb" + + def dumpFoundTables(self, tables): + self.dumpFoundTablesCalls.append(tables) + + def dumpFoundColumn(self, dbs, foundCols, colConsider): + self.dumpFoundColumnCalls.append((dbs, foundCols, colConsider)) + + def getColumns(self, onlyColNames=False, colTuple=None, bruteForce=None, dumpMode=False): + db, tbl, col = conf.db, conf.tbl, conf.col + if db and tbl: + kb.data.cachedColumns.setdefault(db, {}).setdefault(tbl, {}) + kb.data.cachedColumns[db][tbl][col] = "varchar" + + +class _RecDumper(object): + def __init__(self): + self.listed = [] + self.dbTablesArg = None + self.dbColumnsArg = None + + def lister(self, header, elements, content_type=None, sort=True): + self.listed.append((header, list(elements) if elements else [])) + + def dbTables(self, dbTables): + self.dbTablesArg = dbTables + + def dbColumns(self, dbColumnsDict, colConsider, dbs): + self.dbColumnsArg = (dbColumnsDict, colConsider, dbs) + + +class _SearchBase(unittest.TestCase): + _CONF_KEYS = ("db", "tbl", "col", "direct", "technique", "excludeSysDbs", + "exclude", "search") + + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in self._CONF_KEYS} + self._saved_dumper = conf.get("dumper") + self._gv = smod.inject.getValue + self._readInput = smod.readInput + self._saved_has_is = kb.data.get("has_information_schema") + self._saved_cachedColumns = kb.data.get("cachedColumns") + self._saved_hintValue = kb.get("hintValue") + self._saved_injection_data = kb.injection.data + + set_dbms("MySQL") + conf.direct = False + conf.technique = None + conf.excludeSysDbs = False + conf.exclude = None + conf.search = True + conf.dumper = _RecDumper() + + kb.data.has_information_schema = True + kb.data.cachedColumns = {} + kb.injection.data = {PAYLOAD.TECHNIQUE.BOOLEAN: {"title": "AND boolean-based blind"}} + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + conf.dumper = self._saved_dumper + smod.inject.getValue = self._gv + smod.readInput = self._readInput + kb.data.has_information_schema = self._saved_has_is + kb.data.cachedColumns = self._saved_cachedColumns + kb.hintValue = self._saved_hintValue + kb.injection.data = self._saved_injection_data + + +class TestSearchInference(_SearchBase): + def test_search_db_inference(self): + # Blind searchDb: count of matching dbs, then one db name per index. + s = _TestSearch() + conf.db = "testdb" + smod.inject.getValue = _inference_gv(2, ["testdb", "testdb2"]) + s.searchDb() + self.assertEqual(conf.dumper.listed[-1][0], "found databases") + self.assertEqual(sorted(conf.dumper.listed[-1][1]), ["testdb", "testdb2"]) + + def test_search_db_inference_no_match(self): + # Count fails (non-numeric) => no databases appended, empty listing. + s = _TestSearch() + conf.db = "ghost" + smod.inject.getValue = lambda query, *a, **k: (None if k.get("expected") == EXPECTED.INT else self.fail("must not page when count fails")) + s.searchDb() + self.assertEqual(conf.dumper.listed[-1][1], []) + + def test_search_table_inference_grouped(self): + # Blind searchTable, no conf.db: outer count of dbs holding the table, then + # per-db a name, then per-db a count of matching tables, then table names. + s = _TestSearch() + conf.tbl = "users" + conf.db = None + + # Sequencing by the EXPECTED.INT counts + the per-index string results. + # 1st count: number of databases with the table -> 1 + # 1st db name -> "testdb" + # 2nd count: number of tables in testdb -> 1 + # table name -> "users" + seq = {"counts": ["1", "1"], "ci": 0, "vals": ["testdb", "users"], "vi": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + v = seq["counts"][seq["ci"] % len(seq["counts"])] + seq["ci"] += 1 + return v + v = seq["vals"][seq["vi"] % len(seq["vals"])] + seq["vi"] += 1 + return [v] + + smod.inject.getValue = gv + s.searchTable() + self.assertEqual(conf.dumper.dbTablesArg, {"testdb": ["users"]}) + self.assertEqual(s.dumpFoundTablesCalls[-1], {"testdb": ["users"]}) + + def test_search_table_mysql_lt5_bruteforce_decline(self): + # MySQL < 5 forces the bruteforce path; declining the prompt returns None + # without any injection. + s = _TestSearch() + conf.tbl = "users" + conf.db = None + kb.data.has_information_schema = False + smod.readInput = lambda *a, **k: "N" + smod.inject.getValue = lambda *a, **k: self.fail("bruteforce decline must not query") + self.assertIsNone(s.searchTable()) + + def test_search_column_inference(self): + # Blind searchColumn, no db/tbl: count of dbs with the column, then db name; + # then per-db count of tables with the column, then table name -> getColumns + # folds the column into dbs. + s = _TestSearch() + conf.col = "password" + conf.db = None + conf.tbl = None + + seq = {"counts": ["1", "1"], "ci": 0, "vals": ["testdb", "users"], "vi": 0} + + def gv(query, *a, **k): + if k.get("expected") == EXPECTED.INT: + v = seq["counts"][seq["ci"] % len(seq["counts"])] + seq["ci"] += 1 + return v + v = seq["vals"][seq["vi"] % len(seq["vals"])] + seq["vi"] += 1 + return [v] + + smod.inject.getValue = gv + s.searchColumn() + dbs = conf.dumper.dbColumnsArg[2] + self.assertIn("testdb", dbs) + self.assertIn("users", dbs["testdb"]) + self.assertIn("password", dbs["testdb"]["users"]) + + def test_search_column_mysql_lt5_bruteforce_decline(self): + s = _TestSearch() + conf.col = "password" + conf.db = None + conf.tbl = None + kb.data.has_information_schema = False + smod.readInput = lambda *a, **k: "N" + smod.inject.getValue = lambda *a, **k: self.fail("bruteforce decline must not query") + # Declining returns None and never reaches dbColumns. + self.assertIsNone(s.searchColumn()) + self.assertIsNone(conf.dumper.dbColumnsArg) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generic_more.py b/tests/test_generic_more.py new file mode 100644 index 000000000..00bcd0c8d --- /dev/null +++ b/tests/test_generic_more.py @@ -0,0 +1,873 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Additional unit tests for the generic plugin mixins, driving branches NOT already +covered by tests/test_search_enum.py / tests/test_databases_enum.py: + + * plugins/generic/entries.py - dumpTable column/table --exclude filtering, the + --where (conf.dumpWhere) query rewrite, disableHashing toggle, METADB suffix + db handling, the "no usable columns" / "missing columns" skip branches, and + dumpAll over multiple dbs/tables (dict and list shapes) plus dumpFoundTables / + dumpFoundColumn interactive flows. + * plugins/generic/custom.py - sqlQuery SELECT/non-query/stacked branches, the + MSSQL FROM rewrite, METADB suffix stripping, SqlmapNoneDataException handling, + and sqlFile. + * plugins/generic/misc.py - getRemoteTempPath (posix / windows-direct / MSSQL + ErrorLog), getVersionFromBanner, delRemoteFile, createSupportTbl, likeOrExact. + * plugins/generic/takeover.py - the PURE helpers only: Takeover.__init__ table + naming and the regRead/regAdd/regDel/osBof/osSmb control flow with the process/ + network collaborators stubbed out (no metasploit/icmpsh/UDF spawning). + +The injection layer (lib.request.inject.{getValue,goStacked}) is patched per +module, conf.direct=True selects the simple inband branches, conf.batch=True keeps +prompts non-interactive, and conf.dumper is a recording stub. Every test restores +all touched conf.* / kb.* / patched module attributes in tearDown so nothing leaks. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms + +bootstrap() + +from lib.core.common import Backend +from lib.core.data import conf, kb +from lib.core.enums import DBMS, OS +from lib.core.settings import NULL + +import plugins.generic.entries as emod +import plugins.generic.custom as cmod +import plugins.generic.misc as mmod +import plugins.generic.takeover as tmod +from plugins.generic.entries import Entries +from plugins.generic.custom import Custom +from plugins.generic.misc import Miscellaneous + + +class _RecordingDumper(object): + """Recording stand-in for conf.dumper (no printing / file writing).""" + + def __init__(self): + self.tableValues = [] + self.sqlQueries = [] + + def dbTableValues(self, tableValues): + self.tableValues.append(tableValues) + + def sqlQuery(self, query, queryRes): + self.sqlQueries.append((query, queryRes)) + + +# --------------------------------------------------------------------------- # +# entries.py +# --------------------------------------------------------------------------- # + +class _TestEntries(Entries): + """Entries with cross-mixin collaborators stubbed. + + forceDbmsEnum / getCurrentDb / getColumns / getTables are normally supplied by + sibling mixins; we emulate column/table discovery by populating kb.data.cached* + from canned attributes, exactly as the production plugins do. + """ + + def __init__(self): + Entries.__init__(self) + self.getColumnsResult = {} # assigned to kb.data.cachedColumns + self.getTablesResult = {} # assigned to kb.data.cachedTables + self.getColumnsCalls = [] + self.getTablesCalls = 0 + + def forceDbmsEnum(self): + pass + + def getCurrentDb(self): + return "testdb" + + def getColumns(self, onlyColNames=False, colTuple=None, bruteForce=None, dumpMode=False): + self.getColumnsCalls.append((conf.db, conf.tbl)) + kb.data.cachedColumns = dict(self.getColumnsResult) + + def getTables(self, bruteForce=None): + self.getTablesCalls += 1 + kb.data.cachedTables = dict(self.getTablesResult) + + +class _GenericBase(unittest.TestCase): + """Snapshot/restore for everything the generic mixins touch.""" + + _CONF_KEYS = ( + "db", "tbl", "col", "direct", "batch", "exclude", "search", + "disableHashing", "noKeyset", "keyset", "forcePivoting", "dumpWhere", + "tmpPath", "sqlQuery", "sqlFile", "regKey", "regVal", "regData", + "regType", "osPwn", "osShell", "cleanup", "privEsc", + ) + + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in self._CONF_KEYS} + self._saved_dumper = conf.get("dumper") + + self._saved_getValue = { + emod: emod.inject.getValue, + cmod: cmod.inject.getValue, + mmod: mmod.inject.getValue, + } + self._saved_goStacked = { + cmod: cmod.inject.goStacked, + mmod: mmod.inject.goStacked, + } + self._saved_emod_readInput = emod.readInput + self._saved_mmod_readInput = mmod.readInput + + self._saved_kb = { + "cachedColumns": kb.data.get("cachedColumns"), + "cachedTables": kb.data.get("cachedTables"), + "dumpedTable": kb.data.get("dumpedTable"), + "has_information_schema": kb.data.get("has_information_schema"), + "dumpKeyboardInterrupt": kb.get("dumpKeyboardInterrupt"), + "permissionFlag": kb.get("permissionFlag"), + "hintValue": kb.get("hintValue"), + "injection_data": kb.injection.data, + "bannerFp": kb.get("bannerFp"), + "os": kb.get("os"), + } + self._saved_forceDbms = kb.get("forcedDbms") + + conf.direct = True + conf.batch = True + conf.exclude = None + conf.search = False + conf.disableHashing = True + conf.noKeyset = True + conf.keyset = False + conf.forcePivoting = False + conf.dumpWhere = None + conf.dumper = _RecordingDumper() + + kb.data.cachedColumns = {} + kb.data.cachedTables = {} + kb.data.dumpedTable = {} + kb.data.has_information_schema = True + kb.dumpKeyboardInterrupt = False + kb.permissionFlag = False + + def _readInput(message, default=None, checkBatch=True, boolean=False): + if boolean: + return default in (None, 'Y', 'y', True) + return default + + emod.readInput = _readInput + mmod.readInput = _readInput + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + conf.dumper = self._saved_dumper + + for mod, fn in self._saved_getValue.items(): + mod.inject.getValue = fn + for mod, fn in self._saved_goStacked.items(): + mod.inject.goStacked = fn + emod.readInput = self._saved_emod_readInput + mmod.readInput = self._saved_mmod_readInput + + kb.data.cachedColumns = self._saved_kb["cachedColumns"] + kb.data.cachedTables = self._saved_kb["cachedTables"] + kb.data.dumpedTable = self._saved_kb["dumpedTable"] + kb.data.has_information_schema = self._saved_kb["has_information_schema"] + kb.dumpKeyboardInterrupt = self._saved_kb["dumpKeyboardInterrupt"] + kb.permissionFlag = self._saved_kb["permissionFlag"] + kb.hintValue = self._saved_kb["hintValue"] + kb.injection.data = self._saved_kb["injection_data"] + kb.bannerFp = self._saved_kb["bannerFp"] + kb.os = self._saved_kb["os"] + kb.forcedDbms = self._saved_forceDbms + + @staticmethod + def _force_os(os_name): + # Backend.setOs only assigns when kb.os is currently None; reset first so + # tests can deterministically pin the back-end OS. + kb.os = None + Backend.setOs(os_name) + + +class TestEntriesDumpTable(_GenericBase): + def _entries(self, db="testdb", tbl="users", cols=("id", "name")): + e = _TestEntries() + e.getColumnsResult = {db: {tbl: {c: "varchar" for c in cols}}} + return e + + def test_exclude_filters_columns(self): + set_dbms("MySQL") + e = self._entries(cols=("id", "secret")) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + conf.exclude = "secret" + emod.inject.getValue = lambda *a, **k: [["1"]] + + e.dumpTable() + dumped = conf.dumper.tableValues[-1] + self.assertIn("id", dumped) + self.assertNotIn("secret", dumped) + + def test_exclude_all_columns_skips(self): + set_dbms("MySQL") + e = self._entries(cols=("secret",)) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + conf.exclude = "secret" + emod.inject.getValue = lambda *a, **k: self.fail("should not fetch entries") + + e.dumpTable() + # all columns excluded => "no usable column names" => nothing dumped + self.assertEqual(conf.dumper.tableValues, []) + + def test_dumpwhere_rewrites_query(self): + set_dbms("MySQL") + e = self._entries(cols=("id",)) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + conf.dumpWhere = "id>5" + captured = {} + + def gv(query, *a, **k): + captured["query"] = query + return [["9"]] + + emod.inject.getValue = gv + e.dumpTable() + # agent.whereQuery folds conf.dumpWhere into the dump query + self.assertIn("id>5", captured["query"]) + self.assertEqual(list(conf.dumper.tableValues[-1]["id"]["values"]), ["9"]) + + def test_disablehashing_false_path(self): + # conf.disableHashing False => attackDumpedTable() is invoked; with no + # hashes present it must complete without raising and still emit values. + set_dbms("MySQL") + e = self._entries(cols=("id", "name")) + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + conf.disableHashing = False + emod.inject.getValue = lambda *a, **k: [["1", "alice"]] + + # Spy on attackDumpedTable: with disableHashing False it MUST be invoked + # after the values are dumped. A recorder replaces it so we can assert the + # call happened (and no real dictionary attack runs). + saved_attack = emod.attackDumpedTable + calls = {"n": 0} + emod.attackDumpedTable = lambda *a, **k: calls.__setitem__("n", calls["n"] + 1) + try: + e.dumpTable() + finally: + emod.attackDumpedTable = saved_attack + + self.assertEqual(calls["n"], 1) + self.assertEqual(conf.dumper.tableValues[-1]["__infos__"]["count"], 1) + + def test_missing_columns_skips_table(self): + # getColumns yields nothing for the targeted table => skip without fetching. + set_dbms("MySQL") + e = _TestEntries() + e.getColumnsResult = {"testdb": {"other": {"id": "int"}}} + conf.db = "testdb" + conf.tbl = "users" + conf.col = None + emod.inject.getValue = lambda *a, **k: self.fail("should not fetch entries") + + e.dumpTable() + self.assertEqual(conf.dumper.tableValues, []) + + def test_multiple_tables_one_dumped(self): + set_dbms("MySQL") + e = _TestEntries() + e.getColumnsResult = {"testdb": {"users": {"id": "int"}, "posts": {"pid": "int"}}} + conf.db = "testdb" + conf.tbl = "users,posts" + conf.col = None + emod.inject.getValue = lambda *a, **k: [["1"]] + + e.dumpTable() + # both tables share the same cachedColumns dict => both dumped + tables = [tv["__infos__"]["table"] for tv in conf.dumper.tableValues] + self.assertIn("users", tables) + self.assertIn("posts", tables) + + def test_metadb_suffix_db(self): + # A db whose name carries the METADB_SUFFIX must not get a "db" prefix in + # kb.dumpTable, and dumping still succeeds. + from lib.core.settings import METADB_SUFFIX + set_dbms("MySQL") + metadb = "x%s" % METADB_SUFFIX + e = self._entries(db=metadb, tbl="t", cols=("c",)) + conf.db = metadb + conf.tbl = "t" + conf.col = None + emod.inject.getValue = lambda *a, **k: [["v"]] + + e.dumpTable() + self.assertEqual(list(conf.dumper.tableValues[-1]["c"]["values"]), ["v"]) + + +class TestEntriesDumpAll(_GenericBase): + def test_dumpall_multiple_dbs_tables(self): + set_dbms("MySQL") + e = _TestEntries() + conf.db = None + conf.tbl = None + conf.col = None + e.getTablesResult = {"db1": ["t1"], "db2": ["t2"]} + # dumpTable re-discovers columns per (db, tbl); supply both. + e.getColumnsResult = { + "db1": {"t1": {"a": "int"}}, + "db2": {"t2": {"b": "int"}}, + } + emod.inject.getValue = lambda *a, **k: [["x"]] + + e.dumpAll() + # Every table contributed a values batch. + self.assertEqual(len(conf.dumper.tableValues), 2) + + def test_dumpall_list_cached_tables(self): + # cachedTables as a bare list => wrapped under {None: [...]}. + set_dbms("MySQL") + e = _TestEntries() + conf.db = None + conf.tbl = None + conf.col = None + + # getTables sets cachedTables; emulate the list shape directly. + class _ListTables(_TestEntries): + def getTables(self_inner, bruteForce=None): + kb.data.cachedTables = ["users"] + + e = _ListTables() + # dumpAll wraps a bare list as {None: [...]}; dumpTable then resolves the + # None db via getCurrentDb() -> "testdb", so columns live under "testdb". + e.getColumnsResult = {"testdb": {"users": {"id": "int"}}} + emod.inject.getValue = lambda *a, **k: [["1"]] + + e.dumpAll() + self.assertTrue(conf.dumper.tableValues) + # The bare-list None db must be resolved via getCurrentDb() -> "testdb" + # before the dump; assert the dumped __infos__ carries the real db (not + # None) for the requested "users" table. + infos = conf.dumper.tableValues[-1]["__infos__"] + self.assertEqual(infos["db"], "testdb") + self.assertEqual(infos["table"], "users") + + def test_dumpall_exclude_skips_table(self): + set_dbms("MySQL") + e = _TestEntries() + conf.db = None + conf.tbl = None + conf.col = None + conf.exclude = "secret" + e.getTablesResult = {"db1": ["secret", "users"]} + e.getColumnsResult = {"db1": {"users": {"id": "int"}, "secret": {"id": "int"}}} + emod.inject.getValue = lambda *a, **k: [["1"]] + + e.dumpAll() + tables = [tv["__infos__"]["table"] for tv in conf.dumper.tableValues] + self.assertIn("users", tables) + self.assertNotIn("secret", tables) + + +class TestEntriesDumpFound(_GenericBase): + def _entries(self): + e = _TestEntries() + e.getColumnsResult = {"testdb": {"users": {"id": "int"}}} + return e + + def test_dump_found_tables_yes_all(self): + set_dbms("MySQL") + e = self._entries() + emod.inject.getValue = lambda *a, **k: [["1"]] + # batch readInput -> 'Y' (boolean True) and 'a'/'a' for db/table choices. + e.dumpFoundTables({"testdb": ["users"]}) + self.assertTrue(conf.dumper.tableValues) + # The interactive selection must dump the REQUESTED db/table, not just + # "something": assert the dumped __infos__ maps to testdb.users. + infos = conf.dumper.tableValues[-1]["__infos__"] + self.assertEqual(infos["db"], "testdb") + self.assertEqual(infos["table"], "users") + + def test_dump_found_tables_declined(self): + set_dbms("MySQL") + e = self._entries() + + def _no(message, default=None, checkBatch=True, boolean=False): + if boolean: + return False + return default + + emod.readInput = _no + emod.inject.getValue = lambda *a, **k: self.fail("must not dump when declined") + e.dumpFoundTables({"testdb": ["users"]}) + self.assertEqual(conf.dumper.tableValues, []) + + def test_dump_found_column_yes_all(self): + set_dbms("MySQL") + e = self._entries() + emod.inject.getValue = lambda *a, **k: [["1"]] + dbs = {"testdb": {"users": {"id": "int"}}} + e.dumpFoundColumn(dbs, foundCols=None, colConsider='1') + self.assertTrue(conf.dumper.tableValues) + # The selection must dump the REQUESTED db/table mapping, not just + # "something": assert the dumped __infos__ maps to testdb.users. + infos = conf.dumper.tableValues[-1]["__infos__"] + self.assertEqual(infos["db"], "testdb") + self.assertEqual(infos["table"], "users") + + +# --------------------------------------------------------------------------- # +# custom.py +# --------------------------------------------------------------------------- # + +class TestCustomSqlQuery(_GenericBase): + def test_select_joins_listlike_rows(self): + set_dbms("MySQL") + c = Custom() + cmod.inject.getValue = lambda query, **k: [["1", "alice"], ["2", "bob"]] + out = c.sqlQuery("SELECT id, name FROM users;") + # SELECT + list-like rows => each row joined into a single scalar string. + self.assertEqual(len(out), 2) + self.assertTrue(all(isinstance(_, str) for _ in out)) + + def test_select_scalar_passthrough(self): + set_dbms("MySQL") + c = Custom() + captured = {} + + def gv(query, **k): + captured["query"] = query + captured["fromUser"] = k.get("fromUser") + return "42" + + cmod.inject.getValue = gv + out = c.sqlQuery("SELECT COUNT(*) FROM users") + self.assertEqual(out, "42") + self.assertTrue(captured["fromUser"]) + + def test_metadb_suffix_stripped(self): + from lib.core.settings import METADB_SUFFIX + set_dbms("MySQL") + c = Custom() + captured = {} + + def gv(query, **k): + captured["query"] = query + return "x" + + cmod.inject.getValue = gv + c.sqlQuery("SELECT * FROM foo%s.bar" % METADB_SUFFIX) + # the METADB-suffixed schema qualifier is stripped before injection + self.assertNotIn(METADB_SUFFIX, captured["query"]) + + def test_mssql_from_dbo_rewrite(self): + set_dbms("Microsoft SQL Server") + c = Custom() + captured = {} + + def gv(query, **k): + captured["query"] = query + return "x" + + cmod.inject.getValue = gv + c.sqlQuery("SELECT * FROM mydb.users") + # single-dot FROM target gets the .dbo. schema spliced in for MSSQL + self.assertIn("mydb.dbo.users", captured["query"]) + + def test_nonquery_without_stacking_warns_none(self): + set_dbms("MySQL") + conf.direct = False + kb.injection.data = {} # no stacking technique available + c = Custom() + cmod.inject.getValue = lambda *a, **k: self.fail("must not run a query") + out = c.sqlQuery("DELETE FROM users") + self.assertIsNone(out) + + def test_nonquery_stacked_returns_null(self): + set_dbms("MySQL") + conf.direct = True # direct => stacked execution allowed + c = Custom() + calls = {} + + def go(query, *a, **k): + calls["query"] = query + + cmod.inject.goStacked = go + out = c.sqlQuery("DROP TABLE users") + self.assertEqual(out, NULL) + self.assertIn("DROP TABLE users", calls["query"]) + + def test_nonedata_exception_handled(self): + from lib.core.exception import SqlmapNoneDataException + set_dbms("MySQL") + c = Custom() + + def boom(*a, **k): + raise SqlmapNoneDataException("no data") + + cmod.inject.getValue = boom + # exception is swallowed and logged; output stays None + self.assertIsNone(c.sqlQuery("SELECT 1")) + + +class TestCustomSqlFile(_GenericBase): + def test_sqlfile_select_snippets(self): + set_dbms("MySQL") + c = Custom() + cmod.inject.getValue = lambda query, **k: "r" + + # getSQLSnippet reads from disk; patch it to return inline SQL. + saved = cmod.getSQLSnippet + try: + cmod.getSQLSnippet = lambda dbms, filename, **kw: "SELECT 1;SELECT 2" + conf.sqlFile = "dummy.sql" + c.sqlFile() + # two SELECT statements => two recorded dumper.sqlQuery calls + self.assertEqual(len(conf.dumper.sqlQueries), 2) + finally: + cmod.getSQLSnippet = saved + + def test_sqlfile_nonselect_snippet(self): + set_dbms("MySQL") + conf.direct = True + c = Custom() + cmod.inject.goStacked = lambda *a, **k: None + + saved = cmod.getSQLSnippet + try: + cmod.getSQLSnippet = lambda dbms, filename, **kw: "DROP TABLE x" + conf.sqlFile = "dummy.sql" + c.sqlFile() + # non-SELECT => single recorded call with the whole snippet + self.assertEqual(len(conf.dumper.sqlQueries), 1) + self.assertEqual(conf.dumper.sqlQueries[0][0], "DROP TABLE x") + finally: + cmod.getSQLSnippet = saved + + +# --------------------------------------------------------------------------- # +# misc.py +# --------------------------------------------------------------------------- # + +class _TestMisc(Miscellaneous): + """Miscellaneous with the OS/exec collaborators stubbed.""" + + cmdTblName = "sqlmapoutput" + + def __init__(self): + Miscellaneous.__init__(self) + self.checkDbmsOsCalls = 0 + self.execCmdCalls = [] + + def checkDbmsOs(self, detailed=False, vatch=False): + self.checkDbmsOsCalls += 1 + + def execCmd(self, cmd, silent=False): + self.execCmdCalls.append((cmd, silent)) + + +class TestMisc(_GenericBase): + def test_remote_temp_path_posix(self): + set_dbms("MySQL") + self._force_os(OS.LINUX) + conf.tmpPath = None + m = _TestMisc() + out = m.getRemoteTempPath() + self.assertEqual(out, "/tmp") + self.assertEqual(conf.tmpPath, "/tmp") + + def test_remote_temp_path_windows_direct(self): + set_dbms("MySQL") + self._force_os(OS.WINDOWS) + conf.tmpPath = None + conf.direct = True + m = _TestMisc() + out = m.getRemoteTempPath() + self.assertEqual(out, "%TEMP%") + + def test_remote_temp_path_explicit_windows_drive(self): + # An explicit Windows-style drive path flips Backend OS to Windows. + set_dbms("MySQL") + conf.tmpPath = "C:\\Temp" + m = _TestMisc() + out = m.getRemoteTempPath() + self.assertTrue(Backend.isOs(OS.WINDOWS)) + self.assertIn("Temp", out) + self.assertNotIn("\\", out) # ntToPosixSlashes normalized the path + + def test_remote_temp_path_mssql_errorlog(self): + set_dbms("Microsoft SQL Server") + conf.tmpPath = None + mmod.inject.getValue = lambda query, **k: "C:\\Logs\\ERRORLOG" + m = _TestMisc() + out = m.getRemoteTempPath() + # ntpath.dirname strips the ERRORLOG filename, then ntToPosixSlashes + # normalizes the slashes: the exact temp dir must be "C:/Logs". Asserting + # the full path (and that the filename is gone) proves dirname ran. + self.assertEqual(out, "C:/Logs") + self.assertNotIn("ERRORLOG", out) + + def test_get_version_from_banner(self): + set_dbms("MySQL") + conf.direct = True + kb.bannerFp = {} + mmod.inject.getValue = lambda query, **k: "5.7.31-log" + m = _TestMisc() + m.getVersionFromBanner() + # regex \d[\d.-]* extracts the leading numeric-ish run (trailing '-' kept) + self.assertEqual(kb.bannerFp["dbmsVersion"], "5.7.31-") + + def test_get_version_from_banner_cached(self): + set_dbms("MySQL") + kb.bannerFp = {"dbmsVersion": "8.0"} + mmod.inject.getValue = lambda *a, **k: self.fail("must not query when cached") + m = _TestMisc() + m.getVersionFromBanner() + self.assertEqual(kb.bannerFp["dbmsVersion"], "8.0") + + def test_del_remote_file_posix(self): + set_dbms("MySQL") + self._force_os(OS.LINUX) + m = _TestMisc() + m.delRemoteFile("/tmp/foo") + self.assertEqual(m.execCmdCalls[-1], ("rm -f /tmp/foo", True)) + + def test_del_remote_file_windows(self): + set_dbms("MySQL") + self._force_os(OS.WINDOWS) + m = _TestMisc() + m.delRemoteFile("C:/tmp/foo") + cmd, silent = m.execCmdCalls[-1] + self.assertTrue(cmd.startswith("del /F /Q")) + self.assertTrue(silent) + + def test_del_remote_file_empty_noop(self): + set_dbms("MySQL") + m = _TestMisc() + m.delRemoteFile(None) + self.assertEqual(m.execCmdCalls, []) + self.assertEqual(m.checkDbmsOsCalls, 0) + + def test_create_support_tbl(self): + set_dbms("MySQL") + m = _TestMisc() + stacked = [] + mmod.inject.goStacked = lambda query, **k: stacked.append(query) + m.createSupportTbl("mytbl", "data", "TEXT") + joined = " | ".join(stacked) + self.assertIn("DROP TABLE mytbl", joined) + self.assertIn("CREATE TABLE mytbl(data TEXT)", joined) + + def test_create_support_tbl_mssql_cmdtbl(self): + set_dbms("Microsoft SQL Server") + m = _TestMisc() + stacked = [] + mmod.inject.goStacked = lambda query, **k: stacked.append(query) + m.createSupportTbl(m.cmdTblName, "data", "NVARCHAR(4000)") + joined = " | ".join(stacked) + # MSSQL cmd output table gets an IDENTITY id column + self.assertIn("IDENTITY", joined) + + def test_like_or_exact_default(self): + m = _TestMisc() + mmod.readInput = lambda *a, **k: '1' + choice, cond = m.likeOrExact("table") + self.assertEqual(choice, '1') + self.assertIn("LIKE", cond) + + def test_like_or_exact_exact(self): + m = _TestMisc() + mmod.readInput = lambda *a, **k: '2' + choice, cond = m.likeOrExact("table") + self.assertEqual(choice, '2') + self.assertEqual(cond, "='%s'") + + def test_like_or_exact_invalid(self): + from lib.core.exception import SqlmapNoneDataException + m = _TestMisc() + mmod.readInput = lambda *a, **k: '9' + self.assertRaises(SqlmapNoneDataException, m.likeOrExact, "table") + + +# --------------------------------------------------------------------------- # +# takeover.py (pure helpers only) +# --------------------------------------------------------------------------- # + +class _TestTakeover(tmod.Takeover): + """Takeover with all process/network collaborators stubbed. + + Only the pure control-flow helpers (table naming, reg read/add/del dispatch, + osBof/osSmb guards) are exercised; metasploit/icmpsh/UDF spawning is replaced + with recorders so no external process or socket is ever created. + """ + + def __init__(self): + tmod.Takeover.__init__(self) + self.regCalls = [] + self.osVal = OS.WINDOWS + self.smbCalled = False + self.bofCalled = False + self._regInitCalled = 0 + + # neutralize environment setup / OS detection + def _regInit(self): + self._regInitCalled += 1 + + def checkDbmsOs(self, detailed=False, vatch=False): + pass + + def initEnv(self, *a, **k): + pass + + def getRemoteTempPath(self): + return "/tmp" + + def createMsfShellcode(self, *a, **k): + pass + + def readRegKey(self, regKey, regValue, parse=False): + self.regCalls.append(("read", regKey, regValue)) + return "value" + + def addRegKey(self, regKey, regValue, regType, regData): + self.regCalls.append(("add", regKey, regValue, regType, regData)) + + def delRegKey(self, regKey, regValue): + self.regCalls.append(("del", regKey, regValue)) + + def smb(self): + self.smbCalled = True + + def bof(self): + self.bofCalled = True + + +class TestTakeover(_GenericBase): + def _saved_takeover_readInput(self): + return tmod.readInput + + def setUp(self): + _GenericBase.setUp(self) + self._saved_t_readInput = tmod.readInput + + def tearDown(self): + tmod.readInput = self._saved_t_readInput + _GenericBase.tearDown(self) + + def test_init_cmd_table_name(self): + set_dbms("MySQL") + t = _TestTakeover() + self.assertEqual(t.cmdTblName, "%soutput" % conf.tablePrefix) + self.assertEqual(t.tblField, "data") + + def test_reg_read_from_conf(self): + set_dbms("Microsoft SQL Server") + conf.regKey = "HKLM\\Soft" + conf.regVal = "Name" + t = _TestTakeover() + out = t.regRead() + self.assertEqual(out, "value") + self.assertEqual(t.regCalls[-1], ("read", "HKLM\\Soft", "Name")) + self.assertEqual(t._regInitCalled, 1) + + def test_reg_read_defaults(self): + set_dbms("Microsoft SQL Server") + conf.regKey = None + conf.regVal = None + tmod.readInput = lambda message, default=None, **k: default + t = _TestTakeover() + t.regRead() + kind, regKey, regVal = t.regCalls[-1] + self.assertEqual(kind, "read") + self.assertIn("CurrentVersion", regKey) + self.assertEqual(regVal, "ProductName") + + def test_reg_add_from_conf(self): + set_dbms("Microsoft SQL Server") + conf.regKey = "HKLM\\Soft" + conf.regVal = "Name" + conf.regData = "data" + conf.regType = "REG_SZ" + t = _TestTakeover() + t.regAdd() + self.assertEqual(t.regCalls[-1], ("add", "HKLM\\Soft", "Name", "REG_SZ", "data")) + + def test_reg_add_missing_key_raises(self): + from lib.core.exception import SqlmapMissingMandatoryOptionException + set_dbms("Microsoft SQL Server") + conf.regKey = None + conf.regVal = None + conf.regData = None + conf.regType = None + tmod.readInput = lambda *a, **k: "" # empty -> missing mandatory option + t = _TestTakeover() + self.assertRaises(SqlmapMissingMandatoryOptionException, t.regAdd) + + def test_reg_del_confirmed(self): + set_dbms("Microsoft SQL Server") + conf.regKey = "HKLM\\Soft" + conf.regVal = "Name" + tmod.readInput = lambda message, default=None, boolean=False, **k: True if boolean else default + t = _TestTakeover() + t.regDel() + self.assertEqual(t.regCalls[-1], ("del", "HKLM\\Soft", "Name")) + + def test_reg_del_declined(self): + set_dbms("Microsoft SQL Server") + conf.regKey = "HKLM\\Soft" + conf.regVal = "Name" + tmod.readInput = lambda message, default=None, boolean=False, **k: False if boolean else default + t = _TestTakeover() + t.regDel() + # declined => no delRegKey call recorded + self.assertEqual([c for c in t.regCalls if c[0] == "del"], []) + + def test_osbof_wrong_dbms_raises(self): + from lib.core.exception import SqlmapUnsupportedDBMSException + set_dbms("MySQL") + conf.direct = True + t = _TestTakeover() + self.assertRaises(SqlmapUnsupportedDBMSException, t.osBof) + + def test_osbof_no_stacking_returns(self): + set_dbms("Microsoft SQL Server") + conf.direct = False + kb.injection.data = {} # no stacking, not direct => early return + t = _TestTakeover() + self.assertIsNone(t.osBof()) + self.assertFalse(t.bofCalled) + + def test_ossmb_non_windows_raises(self): + from lib.core.exception import SqlmapUnsupportedDBMSException + set_dbms("MySQL") + conf.direct = True + t = _TestTakeover() + + # checkDbmsOs is a no-op here, so force the non-Windows OS explicitly + self._force_os(OS.LINUX) + self.assertRaises(SqlmapUnsupportedDBMSException, t.osSmb) + self.assertFalse(t.smbCalled) + + def test_ossmb_windows_invokes_smb(self): + set_dbms("MySQL") + conf.direct = True + self._force_os(OS.WINDOWS) + t = _TestTakeover() + t.osSmb() + self.assertTrue(t.smbCalled) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 64a76e930..753c5dba3 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -448,16 +448,67 @@ class TestGraphqlDialects(unittest.TestCase): self.assertIsNone(gi.DIALECTS["SQLite"].delay) +def _dbmsTruth(dbms): + """A truth() oracle that behaves like a real `dbms` back-end: it answers each + dialect's fingerprint predicate by the SQL *semantics* a genuine instance would + exhibit, keyed on the function tokens the predicate emits - never on the + fingerprint constant itself. A predicate referencing a function the back-end does + not implement raises an error on a real server and is therefore falsy here.""" + + # Which vendor-specific tokens each back-end actually understands. A predicate is + # true only if every vendor token it mentions belongs to this back-end (mirroring + # an unknown function being a hard error rather than a false comparison). + knows = { + "SQLite": ("SQLITE_VERSION()",), + "Microsoft SQL Server": ("@@VERSION",), + "PostgreSQL": ("version()",), + "MySQL": ("@@VERSION_COMMENT", "@@VERSION"), + } + # @@VERSION exists on both MSSQL and MySQL; the distinguishing factor is the + # '%Microsoft%' banner match, which only an actual Microsoft server satisfies. + vendorTokens = ("SQLITE_VERSION()", "@@VERSION_COMMENT", "@@VERSION", "version()") + owned = knows[dbms] + + def truth(cond): + # Any vendor token the predicate names must be implemented by this back-end, + # else the probe errors out (falsy). + for token in vendorTokens: + if token in cond and token not in owned: + # @@VERSION is shared; let the banner clause below decide instead. + if token == "@@VERSION" and "@@VERSION_COMMENT" not in cond: + continue + return False + if not any(token in cond for token in vendorTokens): + return False + # @@VERSION LIKE '%Microsoft%' is only true on a real Microsoft server. + if "@@VERSION" in cond and "Microsoft" in cond: + return dbms == "Microsoft SQL Server" + # version() LIKE 'PostgreSQL%' is only true on a real PostgreSQL server. + if "version()" in cond and "PostgreSQL" in cond: + return dbms == "PostgreSQL" + return True + + return truth + + class TestGraphqlFingerprint(unittest.TestCase): """DBMS fingerprinting drives off the universal truth() predicate""" def test_identifies_sqlite(self): - truth = lambda cond: cond == gi.DIALECTS["SQLite"].fingerprint - self.assertEqual(gi._fingerprint(truth), "SQLite") + # A SQLite-modelled oracle answers only SQLite's own probe; _fingerprint must + # discriminate to land on SQLite rather than echo the asserted constant. + self.assertEqual(gi._fingerprint(_dbmsTruth("SQLite")), "SQLite") def test_identifies_mysql(self): - truth = lambda cond: cond == gi.DIALECTS["MySQL"].fingerprint - self.assertEqual(gi._fingerprint(truth), "MySQL") + self.assertEqual(gi._fingerprint(_dbmsTruth("MySQL")), "MySQL") + + def test_identifies_mssql(self): + # @@VERSION is shared with MySQL; only the '%Microsoft%' banner match resolves it. + self.assertEqual(gi._fingerprint(_dbmsTruth("Microsoft SQL Server")), + "Microsoft SQL Server") + + def test_identifies_postgresql(self): + self.assertEqual(gi._fingerprint(_dbmsTruth("PostgreSQL")), "PostgreSQL") def test_unknown_backend(self): self.assertIsNone(gi._fingerprint(lambda cond: False)) diff --git a/tests/test_gui_helpers.py b/tests/test_gui_helpers.py new file mode 100644 index 000000000..bc8fc37b3 --- /dev/null +++ b/tests/test_gui_helpers.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Parser-introspection helpers in lib/utils/gui.py. The GUI itself needs a live +display (Tk), so it is excluded from the smoke test and never imported there; +these module-level helpers, however, are pure and work on argparse/optparse +parser+option objects. We exercise BOTH backends (argparse natively, optparse +via a lightweight stand-in) so the compatibility branches are walked. Importing +the module also covers its (otherwise-uncovered) top-level definitions. +""" + +import argparse +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.utils import gui + + +class _OptparseLikeOption(object): + """Minimal optparse.Option stand-in (drives the non-argparse branches).""" + def __init__(self, short, long_, dest, help_, type_=None, takes=True): + self._short_opts = [short] if short else [] + self._long_opts = [long_] if long_ else [] + self.dest = dest + self.help = help_ + self.type = type_ + self._takes = takes + + def takes_value(self): + return self._takes + + +class _OptparseLikeGroup(object): + def __init__(self, title, description, options): + self.title = title + self.description = description + self.option_list = options + + def get_description(self): + return self.description + + +def _build_argparse(): + p = argparse.ArgumentParser() + g = p.add_argument_group("Target", "options for the target") + g.add_argument("-u", "--url", dest="url", help="target url") + g.add_argument("--level", dest="level", type=int, help="level", choices=[1, 2, 3]) + g.add_argument("--flag", dest="flag", action="store_true", help="a boolean") + return p, g + + +class TestArgparseBackend(unittest.TestCase): + def setUp(self): + self.parser, self.group = _build_argparse() + + def test_parser_groups_found(self): + groups = gui._parserGroups(self.parser) + titles = [gui._groupTitle(g) for g in groups] + self.assertIn("Target", titles) + + def test_group_options_and_metadata(self): + opts = gui._groupOptions(self.group) + self.assertTrue(opts) + self.assertEqual(gui._groupDescription(self.group), "options for the target") + + def test_opt_accessors(self): + opts = gui._groupOptions(self.group) + by_dest = dict((gui._optDest(o), o) for o in opts) + url = by_dest["url"] + self.assertIn("--url", gui._optStrings(url)) + self.assertEqual(gui._optHelp(url), "target url") + self.assertTrue(gui._optTakesValue(url)) + self.assertEqual(gui._optValueType(url), "string") + self.assertIn("--url", gui._optionLabel(url)) + + def test_int_type_and_choices(self): + opts = gui._groupOptions(self.group) + by_dest = dict((gui._optDest(o), o) for o in opts) + level = by_dest["level"] + self.assertEqual(gui._optValueType(level), "int") + self.assertEqual(gui._optChoices(level), [1, 2, 3]) + + def test_store_true_takes_no_value(self): + opts = gui._groupOptions(self.group) + by_dest = dict((gui._optDest(o), o) for o in opts) + self.assertFalse(gui._optTakesValue(by_dest["flag"])) + + +class TestOptparseBackend(unittest.TestCase): + def setUp(self): + self.opt = _OptparseLikeOption("-u", "--url", "url", "target url", type_="string") + self.intopt = _OptparseLikeOption(None, "--level", "level", "level", type_="int") + self.boolopt = _OptparseLikeOption(None, "--flag", "flag", "flag", takes=False) + self.group = _OptparseLikeGroup("Target", "target opts", [self.opt, self.intopt, self.boolopt]) + + def test_opt_strings_from_short_long(self): + self.assertEqual(gui._optStrings(self.opt), ["-u", "--url"]) + + def test_value_type_and_takes(self): + self.assertEqual(gui._optValueType(self.intopt), "int") + self.assertTrue(gui._optTakesValue(self.opt)) + self.assertFalse(gui._optTakesValue(self.boolopt)) + + def test_group_description_via_method(self): + self.assertEqual(gui._groupDescription(self.group), "target opts") + self.assertEqual(gui._groupOptions(self.group), [self.opt, self.intopt, self.boolopt]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_har.py b/tests/test_har.py new file mode 100644 index 000000000..56e9b69b5 --- /dev/null +++ b/tests/test_har.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Tests for lib/utils/har.py -- HAR (HTTP Archive) collector and HTTP +request/response parsing used by sqlmap's --har-file feature. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.utils import har as H + + +class TestFakeSocket(unittest.TestCase): + def test_makefile_returns_bytesio(self): + sock = H.FakeSocket(b"hello\r\n") + f = sock.makefile() + self.assertEqual(f.read(), b"hello\r\n") + + +class TestRawPair(unittest.TestCase): + def test_stores_fields(self): + pair = H.RawPair(b"GET / HTTP/1.0\r\n\r\n", + b"HTTP/1.0 200 OK\r\n\r\n", + startTime=1000, endTime=2000) + self.assertEqual(pair.request, b"GET / HTTP/1.0\r\n\r\n") + self.assertEqual(pair.response, b"HTTP/1.0 200 OK\r\n\r\n") + self.assertEqual(pair.startTime, 1000) + self.assertEqual(pair.endTime, 2000) + + +class TestHTTPCollector(unittest.TestCase): + def test_collect_and_obtain(self): + c = H.HTTPCollector() + c.collectRequest(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n", + b"HTTP/1.0 200 OK\r\nContent-Type: text/html\r\n\r\nbody", + startTime=1000, endTime=2000) + result = c.obtain() + log = result["log"] + self.assertEqual(log["version"], "1.2") + self.assertEqual(log["creator"]["name"], "sqlmap") + entries = log["entries"] + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0]["request"]["method"], "GET") + self.assertEqual(entries[0]["response"]["status"], 200) + + +class TestHTTPCollectorFactory(unittest.TestCase): + def test_create_returns_collector(self): + f = H.HTTPCollectorFactory(harFile=True) + c = f.create() + self.assertIsInstance(c, H.HTTPCollector) + + +class TestEntry(unittest.TestCase): + def test_toDict(self): + req = H.Request("GET", "/path", "HTTP/1.1", + {"Host": "example.com"}) + resp = H.Response("HTTP/1.1", 200, "OK", + {"Content-Type": "text/html"}, b"body") + entry = H.Entry(req, resp, startTime=1000, endTime=2000, + extendedArguments={}) + d = entry.toDict() + self.assertEqual(d["request"]["method"], "GET") + self.assertEqual(d["response"]["status"], 200) + self.assertEqual(d["time"], 1000000) + self.assertIn("startedDateTime", d) + + +class TestRequest(unittest.TestCase): + def test_parse_simple_get(self): + raw = b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n" + req = H.Request.parse(raw) + self.assertEqual(req.method, "GET") + self.assertEqual(req.path, "/path") + self.assertEqual(req.httpVersion, "HTTP/1.1") + self.assertEqual(req.headers.get("Host"), "example.com") + + def test_parse_with_comment(self): + raw = (b"HTTP request [#1]:\r\n" + b"POST /submit HTTP/1.0\r\n" + b"Host: example.com\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"body") + req = H.Request.parse(raw) + self.assertEqual(req.method, "POST") + self.assertEqual(req.path, "/submit") + self.assertEqual(req.comment, b"HTTP request [#1]:") + self.assertIn(b"body", req.postBody) + + def test_toDict(self): + req = H.Request("GET", "/", "HTTP/1.0", + {"Host": "test.com", "Accept": "*/*"}) + d = req.toDict() + self.assertEqual(d["method"], "GET") + self.assertEqual(d["url"], "http://test.com/") + self.assertEqual(len(d["headers"]), 2) + + def test_toDict_with_postbody(self): + req = H.Request("POST", "/", "HTTP/1.1", + {"Host": "test.com", "Content-Type": "application/json"}, + postBody=b'{"a":1}') + d = req.toDict() + self.assertEqual(d["postData"]["mimeType"], "application/json") + self.assertIn('{"a":1}', d["postData"]["text"]) + + def test_url_property(self): + req = H.Request("GET", "/path?q=1", "HTTP/1.0", + {"Host": "example.com"}) + self.assertEqual(req.url, "http://example.com/path?q=1") + + def test_url_no_host_header(self): + req = H.Request("GET", "/", "HTTP/1.0", {}) + self.assertIn("unknown", req.url) + + +class TestResponse(unittest.TestCase): + def test_parse_simple(self): + raw = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: 4\r\n\r\nbody" + resp = H.Response.parse(raw) + self.assertEqual(resp.status, 200) + self.assertEqual(resp.statusText, "OK") + self.assertEqual(resp.headers.get("Content-Type"), "text/html") + self.assertEqual(resp.content, b"body") + + def test_parse_with_comment(self): + raw = (b"HTTP response [#1] (200 Fine):\r\n" + b"HTTP/1.0 200 Fine\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"response body") + resp = H.Response.parse(raw) + self.assertEqual(resp.status, 200) + self.assertEqual(resp.statusText, "Fine") + self.assertIn(b"HTTP response", resp.comment) + + def test_toDict(self): + resp = H.Response("HTTP/1.1", 404, "Not Found", + {"Content-Type": "text/html"}, b"not found") + d = resp.toDict() + self.assertEqual(d["status"], 404) + self.assertEqual(d["statusText"], "Not Found") + self.assertEqual(d["content"]["text"], "not found") + self.assertEqual(d["content"]["size"], 9) + + def test_toDict_binary_content_encoded(self): + resp = H.Response("HTTP/1.1", 200, "OK", + {"Content-Type": "application/octet-stream"}, + b"\x00\x01\xff") + d = resp.toDict() + self.assertEqual(d["content"]["encoding"], "base64") + + def test_toDict_non_text_content(self): + resp = H.Response("HTTP/1.1", 200, "OK", + {"Content-Type": "text/plain"}, b"plain text") + d = resp.toDict() + self.assertEqual(d["content"]["text"], "plain text") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_hash_crack.py b/tests/test_hash_crack.py new file mode 100644 index 000000000..f23838e0e --- /dev/null +++ b/tests/test_hash_crack.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Dictionary-attack machinery in lib/utils/hash.py (the cracking loop, hash-file +parsing, result storage and table/cache post-processing) - the part NOT covered +by tests/test_hash.py, which only exercises the pure hash-format functions. + +These run the single-process cracking path (conf.disableMulti=True) against a +TINY temp wordlist that contains the known plaintext, so a known hash is cracked +deterministically in milliseconds without interactive prompts, multiprocessing +pools, network, or the real default dictionary. conf.hashDB is forced to None so +hashDBRetrieve/hashDBWrite become no-ops (no session DB side effects). +""" + +import glob +import hashlib +import os +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.utils import hash as H +from lib.core.data import conf, kb +from lib.core.enums import MKSTEMP_PREFIX + +SCRATCH = "/tmp/claude-1000/-tmp-tmp-oUnlQJzlQN/fcd55d25-6313-49ed-817e-dcbe7fc2bf22/scratchpad" + +# known plaintext / hashes shared across tests +PW = "testpass" +MD5_HASH = hashlib.md5(PW.encode("utf-8")).hexdigest() + + +class _CrackBase(unittest.TestCase): + """Sets up a tiny wordlist and non-interactive, no-DB, single-process state.""" + + @classmethod + def setUpClass(cls): + cls._tmpfiles = [] + + # tiny wordlist containing the known plaintext (plus decoys) + cls.wordlist = os.path.join(SCRATCH, "test_hash_crack_wl.txt") + with open(cls.wordlist, "w") as f: + f.write("foo\nbar\n%s\nbaz\n" % PW) + cls._tmpfiles.append(cls.wordlist) + + @classmethod + def tearDownClass(cls): + for path in cls._tmpfiles: + try: + os.remove(path) + except OSError: + pass + + def setUp(self): + # snapshot global state we mutate + self._saved = { + "disableMulti": conf.disableMulti, + "hashDB": conf.hashDB, + "hashFile": conf.hashFile, + "wordlists": kb.wordlists, + "cachedUsersPasswords": kb.data.cachedUsersPasswords if "cachedUsersPasswords" in kb.data else None, + "storeHashes": kb.choices.storeHashes if "storeHashes" in kb.choices else None, + } + + # deterministic, fast, side-effect-free cracking + conf.disableMulti = True + conf.hashDB = None + kb.wordlists = [self.wordlist] + + def tearDown(self): + conf.disableMulti = self._saved["disableMulti"] + conf.hashDB = self._saved["hashDB"] + conf.hashFile = self._saved["hashFile"] + kb.wordlists = self._saved["wordlists"] + kb.data.cachedUsersPasswords = self._saved["cachedUsersPasswords"] + kb.choices.storeHashes = self._saved["storeHashes"] + + +class TestDictionaryAttack(_CrackBase): + def test_crack_md5_generic_variant_a(self): + # generic (no-salt) algorithms go through _bruteProcessVariantA + results = H.dictionaryAttack({"admin": [MD5_HASH]}) + self.assertEqual(results, [("admin", MD5_HASH, PW)]) + + def test_crack_postgres_variant_b(self): + # username-dependent algorithm goes through _bruteProcessVariantB + h = H.postgres_passwd(PW, "testuser", uppercase=False) + results = H.dictionaryAttack({"testuser": [h]}) + self.assertEqual(results, [("testuser", h, PW)]) + + def test_crack_django_md5_salted_variant_b(self): + # salted algorithm: salt is parsed out of the stored hash by dictionaryAttack + h = H.django_md5_passwd(PW, "salt") + results = H.dictionaryAttack({"u2": [h]}) + self.assertEqual(results, [("u2", h, PW)]) + + def test_no_password_found_returns_empty(self): + # plaintext not in wordlist -> nothing cracked + h = hashlib.md5(b"not-in-wordlist-xyz").hexdigest() + results = H.dictionaryAttack({"admin": [h]}) + self.assertEqual(results, []) + + def test_unknown_hash_format_ignored(self): + # a value that hashRecognition rejects produces no hash_regexes and no results + results = H.dictionaryAttack({"admin": ["not_a_hash"]}) + self.assertEqual(results, []) + + def test_empty_attack_dict(self): + self.assertEqual(H.dictionaryAttack({}), []) + + +class TestCrackHashFile(_CrackBase): + def setUp(self): + super(TestCrackHashFile, self).setUp() + # capture the parsed attack_dict that crackHashFile feeds to dictionaryAttack + self._captured = {} + self._real_attack = H.dictionaryAttack + + def _capture(attack_dict): + self._captured.clear() + self._captured.update(attack_dict) + return [] + + H.dictionaryAttack = _capture + + def tearDown(self): + H.dictionaryAttack = self._real_attack + super(TestCrackHashFile, self).tearDown() + + def test_user_colon_hash_file(self): + path = os.path.join(SCRATCH, "test_hash_crack_hashes.txt") + with open(path, "w") as f: + f.write("admin:%s\n" % MD5_HASH) + self._tmpfiles.append(path) + + conf.hashFile = path + self.assertIsNone(H.crackHashFile(path)) + + # the "user:hash" line is parsed into {username: [hash]} + self.assertEqual(self._captured, {"admin": [MD5_HASH]}) + + def test_bare_hash_file(self): + # no "user:hash" structure -> a dummy user is synthesised per line + path = os.path.join(SCRATCH, "test_hash_crack_bare.txt") + with open(path, "w") as f: + f.write("%s\n" % MD5_HASH) + self._tmpfiles.append(path) + + conf.hashFile = path + self.assertIsNone(H.crackHashFile(path)) + + from lib.core.settings import DUMMY_USER_PREFIX + self.assertEqual(len(self._captured), 1) + (key, value), = self._captured.items() + # the synthesised key uses the dummy-user prefix and maps to the bare hash + self.assertTrue(key.startswith(DUMMY_USER_PREFIX), + msg="bare line was not assigned a dummy user: %r" % key) + self.assertEqual(value, [MD5_HASH]) + + +class TestAttackCachedUsersPasswords(_CrackBase): + def test_annotates_cleartext(self): + kb.data.cachedUsersPasswords = {"admin": [MD5_HASH]} + H.attackCachedUsersPasswords() + # the original value is augmented in place with the recovered clear-text + self.assertIn("clear-text password: %s" % PW, kb.data.cachedUsersPasswords["admin"][0]) + + def test_no_cached_data_is_noop(self): + kb.data.cachedUsersPasswords = {} + # must simply return without touching anything + self.assertIsNone(H.attackCachedUsersPasswords()) + + +class TestStoreHashesToFile(_CrackBase): + def _hash_tempfiles(self): + pattern = os.path.join(tempfile.gettempdir(), MKSTEMP_PREFIX.HASHES + "*") + return set(glob.glob(pattern)) + + def test_store_disabled_writes_nothing(self): + kb.choices.storeHashes = False + before = self._hash_tempfiles() + H.storeHashesToFile({"admin": [MD5_HASH]}) + self.assertEqual(self._hash_tempfiles(), before) + + def test_store_enabled_writes_recognised_hash(self): + kb.choices.storeHashes = True + before = self._hash_tempfiles() + try: + H.storeHashesToFile({"admin": [MD5_HASH]}) + new = self._hash_tempfiles() - before + self.assertEqual(len(new), 1) + with open(next(iter(new))) as fh: + written = fh.read() + self.assertIn(MD5_HASH, written) + self.assertIn("admin", written) + finally: + for path in self._hash_tempfiles() - before: + try: + os.remove(path) + except OSError: + pass + + def test_empty_attack_dict_is_noop(self): + kb.choices.storeHashes = True + before = self._hash_tempfiles() + H.storeHashesToFile({}) + self.assertEqual(self._hash_tempfiles(), before) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_hashdb.py b/tests/test_hashdb.py index 597925c62..36bbd4dc9 100644 --- a/tests/test_hashdb.py +++ b/tests/test_hashdb.py @@ -109,10 +109,21 @@ class TestSerialized(_HashDBCase): def test_bytes_containing_value_survives(self): # REGRESSION (base64-pickle bytes fix): silently failed to restore on py3 before the fix. + # Must round-trip through SQLite, not the in-memory caches: write+flush here, then open a + # FRESH HashDB on the same file (empty read/write caches) so retrieve() hits the disk path. value = {"raw": b"\x00\x01\xff", "items": [b"ab", "s", 1]} self.db.write("bytesval", value, True) self.db.flush() - self.assertEqual(self.db.retrieve("bytesval", True), value) + + fresh = HashDB(self.path) + try: + # sanity: the value is genuinely not in the fresh in-memory caches + self.assertFalse(fresh._write_cache) + hash_ = HashDB.hashKey("bytesval") + self.assertIsNone(fresh._read_cache.get(hash_)) + self.assertEqual(fresh.retrieve("bytesval", True), value) + finally: + fresh.closeAll() class TestKeyHashing(_HashDBCase): diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 000000000..adac33d58 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Edge cases / control-flow branches of the blind-SQLi inference engine +(lib/techniques/blind/inference.py) plus the pure UNION configuration helper +(lib/techniques/union/use.py configUnion). + +Complements tests/test_inference_engine.py (which covers the happy-path char-by-char +extraction). Here we drive the REAL bisection() / queryOutputLength() against a mock +oracle (Request.queryPage replaced by a parser of our own parseable payload template) +to exercise the branches the engine test does not reach: + + * trivial returns: payload is None, length == 0 + * --first-char / --last-char range limiting (both via the function args and via + conf.firstChar / conf.lastChar) + * --hex output decoding of the assembled value + * kb.data.processChar post-processing hook + * session resume from HashDB: a fully cached value, and a PARTIAL_VALUE_MARKER + partial value that bisection continues from (against a REAL temp SQLite HashDB) + * queryOutputLength() forging + DIGITS-charset length retrieval + +No network, no live target, no real DBMS - exactly like the sibling engine test. +""" + +import os +import re +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.data import conf, kb +from lib.core.common import decodeDbmsHexValue +from lib.core.common import getCurrentThreadData +from lib.core.common import hashDBWrite +from lib.core.enums import CHARSET_TYPE +from lib.core.exception import SqlmapSyntaxException +from lib.core.settings import PARTIAL_VALUE_MARKER +from lib.request.connect import Connect +from lib.utils.hashdb import HashDB +import lib.techniques.blind.inference as inf +import lib.techniques.union.use as uu + +# bisection forges: safeStringFormat(payload, (expression, idx, posValue)); '>' is the +# greater-char marker (swapped to '=' on the final equality check). A parseable template +# lets the mock oracle recover (idx, operator, threshold) and answer against a known secret. +TEMPLATE = "EXPR=%s IDX=%d CMP>%d" +_PARSE = re.compile(r"IDX=(\d+) CMP(.)(\d+)") + +# conf/kb knobs bisection reads on the simple single-threaded, no-prediction path +_CONF = {"predictOutput": False, "threads": 1, "api": False, "verbose": 0, "hexConvert": False, + "charset": None, "firstChar": None, "lastChar": None, "timeSec": 5, "eta": False, + "repair": False, "flushSession": None, "freshQueries": None, "hashDB": None} +_KB = {"partRun": None, "safeCharEncode": False, "bruteMode": False, "fileReadMode": False, + "disableShiftTable": False, "originalTimeDelay": 5, "prependFlag": False, + "resumeValues": True, "inferenceMode": False} + + +class _InferenceCase(unittest.TestCase): + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in _CONF} + self._saved_kb = {k: kb.get(k) for k in _KB} + self._saved_qp = Connect.queryPage + self._saved_processChar = kb.data.get("processChar") + for k, v in _CONF.items(): + conf[k] = v + for k, v in _KB.items(): + kb[k] = v + kb.data.processChar = None + set_dbms("MySQL") + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + for k, v in self._saved_kb.items(): + kb[k] = v + kb.data.processChar = self._saved_processChar + Connect.queryPage = self._saved_qp + inf.Request.queryPage = self._saved_qp + + def _install_oracle(self, secret): + def oracle(payload=None, *args, **kwargs): + m = _PARSE.search(payload) + idx, op, threshold = int(m.group(1)), m.group(2), int(m.group(3)) + ch = ord(secret[idx - 1]) if 0 <= idx - 1 < len(secret) else 0 + return (ch > threshold) if op == ">" else (ch == threshold) + + Connect.queryPage = staticmethod(oracle) + inf.Request.queryPage = staticmethod(oracle) + + @staticmethod + def _reset_thread(): + td = getCurrentThreadData() + td.shared.value = "" + td.shared.index = [0] + td.shared.start = 0 + td.shared.count = 0 + + def _bisect(self, secret, expression="SELECT secret", length=None, **kwargs): + self._install_oracle(secret) + self._reset_thread() + if length is None: + length = len(secret) + return inf.bisection(TEMPLATE, expression, length=length, **kwargs) + + +class TestTrivialReturns(_InferenceCase): + def test_none_payload(self): + # payload is None -> (0, None) without ever touching the oracle + self.assertEqual(inf.bisection(None, "SELECT x"), (0, None)) + + def test_zero_length(self): + # length == 0 -> (0, "") short-circuit + self._install_oracle("ignored") + self._reset_thread() + self.assertEqual(inf.bisection(TEMPLATE, "SELECT x", length=0), (0, "")) + + +class TestRangeLimiting(_InferenceCase): + SECRET = "ABCDEFGH" + + def test_first_char_arg(self): + # firstChar=3 -> start from the 3rd character (1-based) -> drop "AB" + _, value = self._bisect(self.SECRET, firstChar=3) + self.assertEqual(value, "CDEFGH") + + def test_last_char_arg(self): + # lastChar=4 -> stop after the 4th character + _, value = self._bisect(self.SECRET, lastChar=4) + self.assertEqual(value, "ABCD") + + def test_conf_first_char(self): + conf.firstChar = 4 + _, value = self._bisect(self.SECRET) + self.assertEqual(value, "DEFGH") + + def test_conf_last_char(self): + conf.lastChar = 3 + _, value = self._bisect(self.SECRET) + self.assertEqual(value, "ABC") + + def test_first_and_last_window(self): + # combined window: chars 3..6 inclusive -> "CDEF" + _, value = self._bisect(self.SECRET, firstChar=3, lastChar=6) + self.assertEqual(value, "CDEF") + + +class TestHexConvert(_InferenceCase): + def test_hex_output_decoded(self): + # --hex: the retrieved value is a hex string the engine decodes on the way out + conf.hexConvert = True + hexed = "48656C6C6F" # "Hello" + _, value = self._bisect(hexed) + self.assertEqual(value, "Hello") + self.assertEqual(value, decodeDbmsHexValue(hexed)) + + +class TestProcessCharHook(_InferenceCase): + def test_process_char_applied_to_each_char(self): + # kb.data.processChar transforms every assembled character + kb.data.processChar = lambda c: c.upper() + _, value = self._bisect("abcde") + self.assertEqual(value, "ABCDE") + + +class TestResumeFromHashDB(_InferenceCase): + """bisection() consults the session store first (hashDBRetrieve(checkConf=True)). + Exercised against a REAL temporary SQLite HashDB (same approach as test_hashdb.py).""" + + def setUp(self): + _InferenceCase.setUp(self) + fd, self.path = tempfile.mkstemp(suffix=".sqlite") + os.close(fd) + os.remove(self.path) # HashDB creates it lazily + conf.hashDB = HashDB(self.path) + # hashDBRetrieve/Write key off these + self._saved_loc = (conf.get("hostname"), conf.get("path"), conf.get("port")) + conf.hostname = "test.invalid" + conf.path = "/" + conf.port = 80 + + def tearDown(self): + conf.hostname, conf.path, conf.port = self._saved_loc + try: + conf.hashDB.closeAll() + except Exception: + pass + if os.path.exists(self.path): + os.remove(self.path) + _InferenceCase.tearDown(self) + + def test_full_value_resumed(self): + # a complete cached value short-circuits the whole bisection (0 queries) + hashDBWrite("SELECT cached", "RESUMED") + conf.hashDB.flush() + count, value = self._bisect("ignored-secret", expression="SELECT cached", length=7) + self.assertEqual(value, "RESUMED") + self.assertEqual(count, 0) + + def test_partial_value_continued(self): + # a PARTIAL_VALUE_MARKER value is resumed-from: bisection keeps the prefix + # and extracts only the remaining characters + kb.inferenceMode = True # partial markers are honored only in inference mode + hashDBWrite("SELECT partial", "%sAB" % PARTIAL_VALUE_MARKER) + conf.hashDB.flush() + count, value = self._bisect("ABCDE", expression="SELECT partial", length=5) + self.assertEqual(value, "ABCDE") + self.assertGreater(count, 0) # it did real work for "CDE" + + +class TestQueryOutputLength(_InferenceCase): + def test_length_retrieved(self): + # queryOutputLength forges a LENGTH() expression and runs bisection with the + # DIGITS charset; the mock "secret" is the textual length itself + self._install_oracle("42") + self._reset_thread() + self.assertEqual(int(inf.queryOutputLength("SELECT data", TEMPLATE)), 42) + + def test_length_single_digit(self): + self._install_oracle("7") + self._reset_thread() + self.assertEqual(int(inf.queryOutputLength("SELECT data", TEMPLATE)), 7) + + def test_digits_charset_extracts_number(self): + # direct bisection with the DIGITS charset (queryOutputLength's inner call) + _, value = self._bisect("2026", charsetType=CHARSET_TYPE.DIGITS) + self.assertEqual(value, "2026") + + +class TestConfigUnion(unittest.TestCase): + """lib/techniques/union/use.py configUnion - pure parsing of --union-char / --union-cols.""" + + _CONF = {"uChar": None, "uCols": None, "uColsStart": 1, "uColsStop": 50} + + def setUp(self): + self._saved = {k: conf.get(k) for k in self._CONF} + self._saved_uchar = kb.get("uChar") + for k, v in self._CONF.items(): + conf[k] = v + + def tearDown(self): + for k, v in self._saved.items(): + conf[k] = v + kb.uChar = self._saved_uchar + + def test_char_and_range(self): + uu.configUnion(char="NULL", columns="2-6") + self.assertEqual(kb.uChar, "NULL") + self.assertEqual((conf.uColsStart, conf.uColsStop), (2, 6)) + + def test_single_column(self): + uu.configUnion(char="NULL", columns="4") + self.assertEqual((conf.uColsStart, conf.uColsStop), (4, 4)) + + def test_uchar_substitution_quoted(self): + # conf.uChar (non-digit) gets quoted and substituted into the [CHAR] template + conf.uChar = "test" + uu.configUnion(char="x[CHAR]x", columns="1") + self.assertEqual(kb.uChar, "x'test'x") + + def test_uchar_substitution_digit(self): + # a digit conf.uChar is substituted unquoted + conf.uChar = "88" + uu.configUnion(char="[CHAR]", columns="1") + self.assertEqual(kb.uChar, "88") + + def test_conf_ucols_overrides_columns_arg(self): + # conf.uCols takes precedence over the columns argument + conf.uCols = "3-9" + uu.configUnion(char="NULL", columns="1-2") + self.assertEqual((conf.uColsStart, conf.uColsStop), (3, 9)) + + def test_non_integer_range_raises(self): + self.assertRaises(SqlmapSyntaxException, uu.configUnion, char="NULL", columns="abc") + + def test_inverted_range_raises(self): + self.assertRaises(SqlmapSyntaxException, uu.configUnion, char="NULL", columns="9-2") + + def test_non_string_char_ignored(self): + # a non-string char leaves kb.uChar untouched (early return) + kb.uChar = "SENTINEL" + uu.configUnion(char=None, columns="1") + self.assertEqual(kb.uChar, "SENTINEL") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_ldap.py b/tests/test_ldap.py index b4bc24086..f590dcfb8 100644 --- a/tests/test_ldap.py +++ b/tests/test_ldap.py @@ -102,32 +102,56 @@ class TestHelpers(unittest.TestCase): class TestFingerprinting(unittest.TestCase): + # The mapping branches recognise a distinctive vendor substring *anywhere* inside + # a realistic error banner and normalise it to a canonical backend name. Feeding + # an embedded substring (not the bare canonical name) proves the source performs + # real substring discrimination rather than echoing its input. def test_fingerprintByError_ad(self): - self.assertEqual(ldap._fingerprintByError("Microsoft Active Directory"), - "Microsoft Active Directory") + self.assertEqual( + ldap._fingerprintByError("LDAP error from Microsoft Active Directory server"), + "Microsoft Active Directory") def test_fingerprintByError_openldap(self): - self.assertEqual(ldap._fingerprintByError("OpenLDAP"), "OpenLDAP") + self.assertEqual(ldap._fingerprintByError("OpenLDAP 2.4.57 SERVER_DOWN"), + "OpenLDAP") def test_fingerprintByError_apacheds(self): - self.assertEqual(ldap._fingerprintByError("ApacheDS"), "ApacheDS") + self.assertEqual(ldap._fingerprintByError("org.apache.directory.ApacheDS 2.0"), + "ApacheDS") def test_fingerprintByError_oracle(self): - self.assertEqual(ldap._fingerprintByError("Oracle Directory Server"), + self.assertEqual(ldap._fingerprintByError("Oracle Internet Directory / Oracle stack"), "Oracle Directory Server") def test_fingerprintByError_389(self): - self.assertEqual(ldap._fingerprintByError("389 Directory Server"), + self.assertEqual(ldap._fingerprintByError("Red Hat 389 ns-slapd"), "389 Directory Server") - def test_fingerprintByError_generic(self): - self.assertEqual(ldap._fingerprintByError("Generic LDAP"), "Generic LDAP") + def test_fingerprintByError_precedence_ad_over_oracle(self): + # A banner carrying two recognised substrings resolves to the earlier branch + # (Active Directory), proving the result is driven by branch order, not by an + # echo of whichever name happens to appear. + self.assertEqual( + ldap._fingerprintByError("Microsoft Active Directory bridged to Oracle"), + "Microsoft Active Directory") - def test_fingerprintByError_jndi(self): - self.assertEqual(ldap._fingerprintByError("Java JNDI"), "Java JNDI") + def test_fingerprintByError_none_and_empty(self): + # The only real branch reachable by non-mapping banners: the falsy guard. + self.assertIsNone(ldap._fingerprintByError(None)) + self.assertIsNone(ldap._fingerprintByError("")) - def test_fingerprintByError_pythonldap(self): - self.assertEqual(ldap._fingerprintByError("python-ldap"), "python-ldap") + def test_fingerprintByError_passthrough_when_unmatched(self): + # Banners that match no vendor branch (including the "python-ldap"/"Java JNDI" + # case, whose source branch is observationally identical to the catch-all) are + # returned verbatim. This single test documents that pass-through contract and, + # crucially, asserts such banners are NOT misclassified into a specific backend. + for banner in ("Generic LDAP", "python-ldap 3.4.0", "Caused by: Java JNDI", + "some unrecognised directory service"): + result = ldap._fingerprintByError(banner) + self.assertEqual(result, banner) + self.assertNotIn(result, ("Microsoft Active Directory", "OpenLDAP", + "ApacheDS", "Oracle Directory Server", + "389 Directory Server")) class TestGrid(unittest.TestCase): @@ -367,54 +391,41 @@ class TestCookiePlace(unittest.TestCase): class TestNestedFilterParsing(unittest.TestCase): + def setUp(self): + # Import the REAL vulnserver parser (same technique as + # tests/test_graphql.py :: TestVulnserverGraphqlParser). `extra` and + # `extra/vulnserver` are packages, so a plain import works. + from extra.vulnserver import vulnserver + self.vs = vulnserver + def test_nested_compound_parses_all_siblings(self): """Blockers 3: nested (&) inside (|) must parse all siblings.""" - # Inline copies of the vulnserver helpers so the test is self-contained - def _ldap_match(text, start): - depth = 0 - i = start - while i < len(text): - ch = text[i] - if ch == '(': - depth += 1 - elif ch == ')': - depth -= 1 - if depth == 0: - return i + 1 - elif ch == '\\': - i += 1 - i += 1 - return len(text) - - def _ldap_parse_value(text, start): - retVal = [] - i = start - while i < len(text) and text[i] not in (')',): - if text[i] == '\\' and i + 2 < len(text): - retVal.append(chr(int(text[i+1:i+3], 16))) - i += 3 - else: - retVal.append(text[i]) - i += 1 - return ''.join(retVal), i - - # Minimum reproduction of the fixed _ldap_filter_to_sql - # (the real function is in extra/vulnserver/vulnserver.py) - import sys, os - sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'extra', 'vulnserver')) - # Can't cleanly import vulnserver because of the __main__ guard. - # Instead we verify the fixed _ldap_match returns the correct end - # position for a nested compound filter, which was the root cause. f = '(|(&(uid=a)(cn=b))(mail=*))' - # The outer (| ... ) starts at 0 and should end at len(f) - outer_end = _ldap_match(f, 0) + + # The REAL _ldap_match must balance brackets across nested compounds. + # Outer (| ... ) starts at 0 and ends at len(f). + outer_end = self.vs._ldap_match(f, 0) self.assertEqual(outer_end, len(f)) - # The inner (& ... ) compound's opening '(' is at position 2 - # (f[2] == '('). _ldap_match must return the position after the - # matching ')' that closes the compound, i.e. right before (mail=*). - inner_end = _ldap_match(f, 2) + # Inner (& ... )'s opening '(' is at position 2; _ldap_match must + # return the position right before the (mail=*) sibling. + inner_end = self.vs._ldap_match(f, 2) self.assertEqual(f[inner_end:inner_end+8], '(mail=*)') + # The REAL filter->SQL conversion must surface EVERY sibling condition: + # both members of the nested (&) AND the (mail=*) sibling of the (|). + clause, params, end = self.vs._ldap_filter_to_sql(f) + self.assertEqual(end, len(f)) + self.assertIsNotNone(clause) + # nested-(&) siblings -> AND-joined, both columns present + self.assertIn(" AND ", clause) + self.assertIn("uid", clause) + self.assertIn("cn", clause) + # outer-(|) sibling must NOT be dropped + self.assertIn(" OR ", clause) + self.assertIn("mail", clause) + # the two equality values are parameterized in order + self.assertEqual(params, ["a", "b"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_option_more.py b/tests/test_option_more.py new file mode 100644 index 000000000..3e49b83e0 --- /dev/null +++ b/tests/test_option_more.py @@ -0,0 +1,663 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +Additional coverage for option setup / normalization helpers in +lib/core/option.py, targeting functions and branches NOT already exercised by +tests/test_option_setup.py: + + * _setTamperingFunctions (loads real tamper modules into kb.tamperFunctions) + * _setPreprocessFunctions (loads a preprocess(req) script into kb.preprocessFunctions) + * _setPostprocessFunctions (loads a postprocess(page, headers, code) script) + * _setSafeVisit (parses a safe request file into kb.safeReq) + * _cleanupOptions (additional normalization branches: delay cast, + csvDel/paramDel escape, col/binaryFields split, + torType upper, abortCode, getAll, dummy->batch) + * _basicOptionValidation (additional illegal option combinations / branches) + * _normalizeOptions (string + boolean option coercion) + * setVerbosity (eta clamp + high verbose) + +As in test_option_setup.py, option.py mutates the global conf/kb singletons +aggressively, so every test saves and restores the conf/kb fields it touches via +the _preserve() context manager so the shared state stays pristine for the rest +of the suite. +""" + +import contextlib +import logging +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap +bootstrap() + +from lib.core.data import conf, kb, logger, paths +from lib.core.exception import SqlmapSyntaxException +from lib.core.exception import SqlmapSystemException +from lib.core.exception import SqlmapGenericException +from lib.core.exception import SqlmapFilePathException +from lib.core.exception import SqlmapValueException +from lib.core.settings import MAX_CONNECT_RETRIES + +import lib.core.option as option + +_SENTINEL = object() + +# scratchpad for the preprocess/postprocess/safe-req fixture files +_SCRATCH = os.environ.get("CLAUDE_SCRATCH") or os.path.join(os.path.dirname(os.path.abspath(__file__)), "_option_more_tmp") + + +def tearDownModule(): + """Remove the scratch fixture directory so it never lingers on disk (and so a + stray __init__.py there can't shadow imports in a subsequent run).""" + import shutil + if os.path.isdir(_SCRATCH): + shutil.rmtree(_SCRATCH, ignore_errors=True) + + +@contextlib.contextmanager +def _preserve(target, *keys): + """Save the given keys of an AttribDict (conf/kb), then restore on exit. + + Missing keys are restored to absent so a test can't leak a brand-new field. + """ + saved = {} + for key in keys: + saved[key] = target[key] if key in target else _SENTINEL + try: + yield + finally: + for key in keys: + if saved[key] is _SENTINEL: + try: + del target[key] + except KeyError: + pass + else: + target[key] = saved[key] + + +class _ImportSandboxMixin(object): + """Loaders in option.py (tamper/preprocess/postprocess) permanently + `sys.path.insert(0,