#!/usr/bin/env python3
"""Catch-all authoritative DNS server built on dnslib.

Names listed in ``ZONES`` are answered from their static records.  Any other
name is handled by a catch-all: A queries are answered with ``1.3.3.7``,
CNAME queries with ``testing.com.``, and every other query type receives an
empty (NOERROR, no answer) reply.  Each catch-all hit is printed to stdout.
"""
import binascii
import sys
import time
from datetime import datetime
from time import sleep

from dnslib import DNSLabel, QTYPE, RD, RR, RCODE
from dnslib import A, AAAA, CNAME, MX, NS, SOA, TXT
from dnslib.server import DNSServer

EPOCH = datetime(1970, 1, 1)
# Server start-up time as a string.  Request logging computes its own
# timestamp; this constant is kept for anything that still references it.
TIMESTAMP = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
# Seconds since the Unix epoch at start-up, used as the default SOA serial.
# Same value as ``(datetime.utcnow() - EPOCH).total_seconds()`` without the
# deprecated ``utcnow()`` call.
SERIAL = int(time.time())

# Maps dnslib RDATA classes to their numeric query type.
TYPE_LOOKUP = {
    A: QTYPE.A,
    AAAA: QTYPE.AAAA,
    CNAME: QTYPE.CNAME,
    MX: QTYPE.MX,
    NS: QTYPE.NS,
    SOA: QTYPE.SOA,
    TXT: QTYPE.TXT,
}

# Label -> description table.  Not referenced elsewhere in this file.
SUSPICIOUS_RECORDS = {
    'ftp': 'FTP ',
    'cpanel': 'cPanel hosting control panel',
    'admin': 'common record for admin control panels',
    'ssh': 'SSH service',
    'wordpress': 'common blogging platform',
    'store': 'common record for stores',
    'staging': 'common record for staging environments',
    'mail': 'common record for email',
}


class Record:
    """Template for one DNS resource record attached to a zone name.

    Args:
        rdata_type: a dnslib RDATA class (``A``, ``MX``, ``SOA``, ...) that is
            instantiated with ``*args``, or an already-built RDATA instance.
        *args: constructor arguments for ``rdata_type`` when it is a class.
            An ``SOA`` given only ``(mname, rname)`` gets default timers.
        rtype: overrides the numeric record type derived from ``rdata_type``.
        rname: fixed owner name; when ``None`` the queried name is used.
        ttl: TTL override; when ``None`` :meth:`sensible_ttl` is used.
        **kwargs: extra keyword arguments forwarded to :class:`dnslib.RR`.

    Raises:
        KeyError: if the RDATA class is not in ``TYPE_LOOKUP``.
    """

    def __init__(self, rdata_type, *args, rtype=None, rname=None, ttl=None, **kwargs):
        if isinstance(rdata_type, RD):
            # Already an RDATA instance, not a class.
            self._rtype = TYPE_LOOKUP[rdata_type.__class__]
            rdata = rdata_type
        else:
            self._rtype = TYPE_LOOKUP[rdata_type]
            if rdata_type is SOA and len(args) == 2:
                # Only mname/rname supplied: append default SOA timers.
                args += ((
                    SERIAL,          # serial number
                    60 * 60 * 1,     # refresh
                    60 * 60 * 3,     # retry
                    60 * 60 * 24,    # expire
                    60 * 60 * 1,     # minimum
                ),)
            rdata = rdata_type(*args)

        if rtype:
            self._rtype = rtype
        self._rname = rname
        self.kwargs = dict(
            rdata=rdata,
            ttl=self.sensible_ttl() if ttl is None else ttl,
            **kwargs
        )

    def try_rr(self, q):
        """Return an RR answering question ``q``, or ``None`` if the type differs.

        ANY queries match every record.
        """
        if q.qtype == QTYPE.ANY or q.qtype == self._rtype:
            return self.as_rr(q.qname)
        return None

    def as_rr(self, alt_rname):
        """Build the RR, owned by the fixed ``rname`` or else by ``alt_rname``."""
        return RR(rname=self._rname or alt_rname, rtype=self._rtype, **self.kwargs)

    def sensible_ttl(self):
        """Default TTL: one day for NS/SOA, five minutes for everything else."""
        if self._rtype in (QTYPE.NS, QTYPE.SOA):
            return 60 * 60 * 24
        return 300

    @property
    def is_soa(self):
        """True if this is an SOA record."""
        return self._rtype == QTYPE.SOA

    def __str__(self):
        return '{} {}'.format(QTYPE[self._rtype], self.kwargs)


# Static zone data: exact query name -> list of records served for it.
ZONES = {
    'example.com': [
        Record(A, '1.2.3.4'),
        Record(CNAME, 'www.example.com.'),
        Record(MX, 'whatever.com.', 5),
        Record(MX, 'mx2.whatever.com.', 10),
        Record(MX, 'mx3.whatever.com.', 20),
        Record(NS, 'mx2.whatever.com.'),
        Record(NS, 'mx3.whatever.com.'),
        Record(TXT, 'hack the planet!!!'),
        Record(SOA, 'ns1.example.com', 'dns.example.com'),
    ],
    'testing.com': [
        Record(A, '185.92.221.142'),
        Record(TXT, 'hack the planet!!!'),
        Record(NS, 'ns1.lol.systems'),
        Record(NS, 'ns2.lol.systems'),
        Record(SOA, 'ns1.lol.systems', 'ns2.lol.systems'),
    ],
    'www.testing.com': [
        Record(CNAME, 'testing.com.'),
    ],
    'derp.testing.com': [
        Record(TXT, 'rekt lmao'),
    ],
}

# Answers synthesised for names not present in ZONES, keyed by query type.
# Query types missing here receive an empty NOERROR reply.
CATCHALL_RECORDS = {
    QTYPE.CNAME: Record(CNAME, 'testing.com.'),
    QTYPE.A: Record(A, '1.3.3.7'),
}


class Resolver:
    """dnslib resolver serving ``ZONES`` with a catch-all for unknown names."""

    def __init__(self):
        # DNSLabel keys so lookups compare the way dnslib compares names.
        self.zones = {DNSLabel(k): v for k, v in ZONES.items()}

    def resolve(self, request, handler):
        """Build and return the reply for ``request``.

        Known names are answered from their zone records (only the records
        whose type matches the query, or all of them for ANY).  Unknown names
        get the catch-all record for the query type, if one exists. Either way,
        the query is logged with the client address and the current time.
        """
        reply = request.reply()
        question = request.q

        zone = self.zones.get(question.qname)
        if zone is not None:
            for record in zone:
                rr = record.try_rr(question)
                if rr is not None:
                    reply.add_answer(rr)
            return reply

        # Catch-all for names that are not configured.  Types without a
        # catch-all record fall through to an empty NOERROR reply instead of
        # raising inside the request handler.
        catchall = CATCHALL_RECORDS.get(question.qtype)
        if catchall is not None:
            reply.add_answer(catchall.try_rr(question))

        # Timestamp is taken per request; TIMESTAMP is the server start time.
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print('{} -- {} requested an invalid {} record: {}'.format(
            now, handler.client_address[0], QTYPE[question.qtype], question.qname))
        return reply


class DNSLogger:
    """Console logger for the DNS servers.

    Args:
        log: comma-separated event names to enable (``recv``, ``send``,
            ``request``, ``reply``, ``truncated``, ``error``, ``data``).
            Plain names replace the default set
            (``request,reply,truncated,error``); ``+name`` adds to it and
            ``-name`` removes from it.
        prefix: when true, :meth:`log_prefix` adds a time/handler/resolver
            prefix to the lines that use it.

    Disabled events have their ``log_*`` method replaced by :meth:`log_pass`.
    """

    def __init__(self, log="", prefix=True):
        default = ["request", "reply", "truncated", "error"]
        entries = log.split(",") if log else []
        enabled = {e for e in entries if not e.startswith(('+', '-'))} or set(default)
        enabled.update(e[1:] for e in entries if e.startswith('+'))
        enabled.difference_update(e[1:] for e in entries if e.startswith('-'))
        for method_name in ['log_recv', 'log_send', 'log_request', 'log_reply',
                            'log_truncated', 'log_error', 'log_data']:
            # 'log_reply'[4:] == 'reply', etc.
            if method_name[4:] not in enabled:
                setattr(self, method_name, self.log_pass)
        self.prefix = prefix

    def log_pass(self, *args):
        """No-op replacement for disabled log events."""
        pass

    def log_prefix(self, handler):
        """Return the line prefix (time, handler class, resolver class) or ''."""
        if self.prefix:
            return "%s [%s:%s] " % (time.strftime("%Y-%m-%d %X"),
                                    handler.__class__.__name__,
                                    handler.server.resolver.__class__.__name__)
        return ""

    def log_recv(self, handler, data):
        print("%sReceived: [%s:%d] (%s) <%d> : %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            len(data),
            binascii.hexlify(data)))

    def log_send(self, handler, data):
        print("%sSent: [%s:%d] (%s) <%d> : %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            len(data),
            binascii.hexlify(data)))

    def log_request(self, handler, request):
        self.log_data(request)

    def log_reply(self, handler, reply):
        """Log one line per successful reply that carries a non-SOA answer."""
        if reply.header.rcode == RCODE.NOERROR and any(
                rr.rtype != QTYPE.SOA for rr in reply.rr):
            print('{} requested a {} record: {}'.format(
                handler.client_address[0], QTYPE[reply.q.qtype], reply.q.qname))
        self.log_data(reply)

    def log_truncated(self, handler, reply):
        print("%sTruncated Reply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            reply.q.qname,
            QTYPE[reply.q.qtype],
            ",".join([QTYPE[a.rtype] for a in reply.rr])))
        self.log_data(reply)

    def log_error(self, handler, e):
        print("%sInvalid Request: [%s:%d] (%s) :: %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            e))

    def log_data(self, dnsobj):
        print("\n", dnsobj.toZone("    "), "\n", sep="")


resolver = Resolver()
logger = DNSLogger()

# One TCP and one UDP listener on port 53, sharing the resolver and logger.
servers = [
    DNSServer(resolver, port=53, address='0.0.0.0', tcp=True, logger=logger),
    DNSServer(resolver, port=53, address='0.0.0.0', tcp=False, logger=logger),
]

if __name__ == '__main__':
    for s in servers:
        s.start_thread()

    try:
        # Keep the main thread alive while the server threads run.
        while True:
            sleep(0.1)
    except KeyboardInterrupt:
        pass
    finally:
        for s in servers:
            s.stop()