Skip to content

Instantly share code, notes, and snippets.

@a290
Last active January 15, 2024 07:08
Show Gist options
  • Select an option

  • Save a290/8e57ab0e1a6087f1206a0643e848e81a to your computer and use it in GitHub Desktop.

Select an option

Save a290/8e57ab0e1a6087f1206a0643e848e81a to your computer and use it in GitHub Desktop.
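Linux BRAS stuff: an ixgbe QinQ (double VLAN) driver patch, a billing-to-NAS sync client, the NAS RPC daemon (nasd) and a DHCP server (nasdhcp).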
diff -rupN ixgbe-3.22.3//src/ixgbe_main.c ixgbe-3.22.3-qinq//src/ixgbe_main.c
--- ixgbe-3.22.3//src/ixgbe_main.c 2014-08-15 04:17:42.000000000 +0500
+++ ixgbe-3.22.3-qinq//src/ixgbe_main.c 2014-11-18 19:07:49.000000000 +0500
@@ -3629,6 +3629,11 @@ static void ixgbe_configure_tx(struct ix
dmatxctl |= IXGBE_DMATXCTL_TE;
IXGBE_WRITE_REG(hw, IXGBE_DMATXCTL, dmatxctl);
}
+
+ /* Enable Global Double VLAN */
+ dmatxctl = IXGBE_READ_REG(hw, IXGBE_DMATXCTL);
+ dmatxctl |= IXGBE_DMATXCTL_GDV;
+ IXGBE_WRITE_REG(hw, IXGBE_DMATXCTL, dmatxctl);
/* Setup the HW Tx Head and Tail descriptor pointers */
for (i = 0; i < adapter->num_tx_queues; i++)
@@ -5719,6 +5724,12 @@ static void ixgbe_up_complete(struct ixg
mod_timer(&adapter->service_timer, jiffies);
ixgbe_clear_vf_stats_counters(adapter);
+
+ /* Set Extended VLAN bit */
+ ctrl_ext = IXGBE_READ_REG(hw, IXGBE_CTRL_EXT);
+ ctrl_ext |= IXGBE_CTRL_EXT_EXTENDED_VLAN;
+ IXGBE_WRITE_REG(hw, IXGBE_CTRL_EXT, ctrl_ext);
+
/* Set PF Reset Done bit so PF/VF Mail Ops can work */
ctrl_ext = IXGBE_READ_REG(hw, IXGBE_CTRL_EXT);
ctrl_ext |= IXGBE_CTRL_EXT_PFRSTD;
diff -rupN ixgbe-3.22.3//src/ixgbe_type.h ixgbe-3.22.3-qinq//src/ixgbe_type.h
--- ixgbe-3.22.3//src/ixgbe_type.h 2014-08-15 04:17:42.000000000 +0500
+++ ixgbe-3.22.3-qinq//src/ixgbe_type.h 2014-11-18 19:08:04.000000000 +0500
@@ -1119,6 +1119,7 @@ struct ixgbe_thermal_sensor_data {
#define IXGBE_CTRL_EXT_NS_DIS 0x00010000 /* No Snoop disable */
#define IXGBE_CTRL_EXT_RO_DIS 0x00020000 /* Relaxed Ordering disable */
#define IXGBE_CTRL_EXT_DRV_LOAD 0x10000000 /* Driver loaded bit for FW */
+#define IXGBE_CTRL_EXT_EXTENDED_VLAN 0x04000000 /* Extended VLAN bit */
/* Direct Cache Access (DCA) definitions */
#define IXGBE_DCA_CTRL_DCA_ENABLE 0x00000000 /* DCA Enable *
#!/usr/bin/python3
import Pyro4
from ipaddress import IPv4Address, IPv4Network
import yaml
import mysql.connector
import re
import time
import logging
# Parsed config.yaml (without the 'servers' section); filled by load_config()
config = None
# NAS objects built from config['servers'] by load_config()
servers = []
POL_QUANTUM = 8000 # ipt_ratelimit rounds CIR on [CONFIG_HZ * BITS_PER_BYTE] boundary
class NAS:
    """
    One NAS (BRAS) reachable through the ``nasd`` Pyro4 RPC daemon.

    ``subscribers`` caches the NAS-side state as a dict keyed by ``(spvlan, cvlan)``
    whose values hold ``ipaddrs`` (set of IPv4Address), ``ratelimit`` and ``blocked``.
    :meth:`sync` reconciles that state with the billing database via RPC calls.
    """
    # Subscriber interface name: sub<spvlan>.<cvlan>, with a "b_" prefix while blocked
    re_vlan = re.compile(r'^(b_)?sub([0-9]{1,4})\.([0-9]{1,4})$')
    # ratelimit set line beginning with "<ifname> cir <rate>"
    re_ratelimit = re.compile(r'^(?:b_sub|sub)(\d{1,4})\.(\d{1,4}) cir (\d+)')

    def __init__(self, ipaddr, port, hmac_key, networks):
        """
        :param ipaddr: address of the nasd RPC endpoint
        :param port: nasd RPC port
        :param hmac_key: Pyro4 HMAC key shared with nasd
        :param networks: CIDR strings; only subscriber IPs inside them are managed on this NAS
        """
        self.ipaddr = IPv4Address(ipaddr)
        self.port = int(port)
        self.hmac = hmac_key
        self.nets = set()
        for net in networks:
            self.nets.add(IPv4Network(net))
        self.syncts = int(time.time())
        self.rpc = Pyro4.Proxy("PYRO:nasd@{}:{}".format(self.ipaddr, self.port))
        self.rpc._pyroHmacKey = self.hmac
        self.rpc._pyroTimeout = 1.5
        self.subscribers = None

    def check_ip(self, ipaddr: IPv4Address):
        """Return True if ``ipaddr`` belongs to one of this NAS's networks."""
        return any(ipaddr in net for net in self.nets)

    def _sync_active_subscribers(self):
        """ Get list of subscribers from NAS and store it in self.subscribers """
        raw_data = self.rpc.dump_addresses()
        raw_rt_data = self.rpc.dump_ratelimit()
        ratelimits = dict()
        for line in raw_rt_data:
            m = self.re_ratelimit.match(line)
            if m:
                spvlan, cvlan, rate = m.groups()
                ratelimits[(int(spvlan), int(cvlan))] = int(rate)
        subscribers = dict()
        for data in raw_data:
            m = self.re_vlan.match(data[0])
            if not m:
                continue  # not a subscriber interface
            blocked, spvlan, cvlan = m.groups()
            key = (int(spvlan), int(cvlan))
            ipaddr = IPv4Address(data[1])
            entry = subscribers.setdefault(key, dict(ipaddrs=set()))
            entry['ipaddrs'].add(ipaddr)
            entry['ratelimit'] = ratelimits.get(key)
            entry['blocked'] = 1 if blocked else 0
        self.subscribers = subscribers

    def _push_ratelimit(self, subscriber, rate):
        """Send one subscriber's bandwidth limit to the NAS; skipped when the rate is unknown."""
        if rate is None:
            # A NULL rate from the database used to crash on `None // 1024`
            logging.warning("No ratelimit for subscriber {} {}, skipping".format(subscriber[0], subscriber[1]))
            return
        logging.info("Adjust bandwidth for {} {} to {} kbit".format(subscriber[0], subscriber[1], rate // 1024))
        self.rpc.set_ratelimit(rate, subscriber[0], subscriber[1])

    def _push_block(self, subscriber, blocked):
        """Send one subscriber's block state to the NAS (truthy ``blocked`` blocks)."""
        # The message used to be inverted ("Unblock" logged when blocking)
        logging.info("{} subscriber {} {}".format('Block' if blocked else 'Unblock', subscriber[0], subscriber[1]))
        self.rpc.set_block(blocked, subscriber[0], subscriber[1])

    def _apply(self, db_subs):
        """Issue the RPC calls that turn the cached NAS state into ``db_subs``."""
        nas_subs = self.subscribers
        for subscriber in nas_subs.keys() - db_subs.keys():  # Present only on NAS
            logging.info("Del subscriber {} {}".format(subscriber[0], subscriber[1]))
            for ipaddr in nas_subs[subscriber]['ipaddrs']:
                logging.info('Remove IP {} vlan {} {}'.format(ipaddr, subscriber[0], subscriber[1]))
                self.rpc.del_ipaddr(str(ipaddr), subscriber[0], subscriber[1])
        for subscriber in db_subs.keys() - nas_subs.keys():  # Present only in database
            logging.info("Add subscriber {} {}".format(subscriber[0], subscriber[1]))
            for ipaddr in db_subs[subscriber]['ipaddrs']:
                logging.info('Add IP {} vlan {} {}'.format(ipaddr, subscriber[0], subscriber[1]))
                self.rpc.add_ipaddr(str(ipaddr), subscriber[0], subscriber[1])
            self._push_ratelimit(subscriber, db_subs[subscriber]['ratelimit'])
            self._push_block(subscriber, db_subs[subscriber]['blocked'])
        for subscriber in nas_subs.keys() & db_subs.keys():  # Present on both
            wanted = db_subs[subscriber]
            current = nas_subs[subscriber]
            diff_field = [x for x in wanted.keys() if current[x] != wanted[x]]
            if 'ipaddrs' in diff_field:
                for ipaddr in wanted['ipaddrs'] - current['ipaddrs']:
                    logging.info('Add IP {} vlan {} {}'.format(ipaddr, subscriber[0], subscriber[1]))
                    self.rpc.add_ipaddr(str(ipaddr), subscriber[0], subscriber[1])
                for ipaddr in current['ipaddrs'] - wanted['ipaddrs']:
                    logging.info('Remove IP {} vlan {} {}'.format(ipaddr, subscriber[0], subscriber[1]))
                    self.rpc.del_ipaddr(str(ipaddr), subscriber[0], subscriber[1])
            if 'ratelimit' in diff_field:
                self._push_ratelimit(subscriber, wanted['ratelimit'])
            if 'blocked' in diff_field:
                self._push_block(subscriber, wanted['blocked'])

    def sync(self, db_allsubs, discard_cache=False):
        """
        Sync NAS to billing database.

        :param db_allsubs: all subscribers from the database, keyed by (spvlan, cvlan)
        :param discard_cache: re-read the NAS state instead of trusting self.subscribers
        Only subscriber IPs inside this NAS's networks are considered. RPC errors propagate.
        """
        # Filter out vlans with non-matching IPs
        db_subs = dict()
        for key, sub in db_allsubs.items():
            matching_ipaddrs = {ipaddr for ipaddr in sub['ipaddrs'] if self.check_ip(ipaddr)}
            if matching_ipaddrs:
                db_subs[key] = dict(ipaddrs=matching_ipaddrs, ratelimit=sub['ratelimit'],
                                    blocked=sub['blocked'])
        if discard_cache or self.subscribers is None or (self.rpc.get_syncts() != self.syncts):
            logging.info("Syncing active subscribers")
            self._sync_active_subscribers()
            self.syncts = int(time.time())
            self.rpc.update_syncts(self.syncts)
        try:
            self._apply(db_subs)
        except Exception:
            # The NAS may now be partially updated: invalidate our timestamp so the
            # next sync re-reads the real NAS state instead of the stale cache.
            self.syncts = None
            raise
        self.subscribers = db_subs
def load_config(fname):
    """
    Load the YAML config file and build one NAS object per ``servers`` entry.

    Sets the module globals ``config`` (the parsed file without its ``servers`` key)
    and ``servers`` (list of NAS instances).
    """
    global config
    global servers
    # `with` closes the file; the handle used to be leaked
    with open(fname, 'r') as cfg_file:
        config = yaml.safe_load(cfg_file)
    servers = [NAS(srv['rpc_ipaddr'], srv['rpc_port'], srv['hmac_key'], srv['networks'])
               for srv in config['servers']]
    del config['servers']
def get_db_subscribers():
    """
    Get all subscribers from db.

    Returns a dict keyed by ``(spvlan, cvlan)`` whose values hold:
      * ``ipaddrs``   - set of IPv4Address parsed from the comma-separated ``ip`` column
      * ``ratelimit`` - ``rx_speed * 1024`` rounded down to a multiple of POL_QUANTUM
      * ``blocked``   - value of ``mode > 0``
    """
    # The table name cannot be a bound parameter; it comes from the local config file.
    # A leftover "limit 50" used to truncate this result, and sync() would then
    # delete every subscriber outside those 50 rows from the NAS.
    query = ('SELECT ip, spvlan, vlan as cvlan, (rx_speed * 1024 DIV {0} * {0}) as rate , (mode > 0) as blocked'
             ' FROM {1};'.format(POL_QUANTUM, config['db_table']))
    cnx = mysql.connector.connect(user=config['db_user'], password=config['db_password'],
                                  host=config['db_host'],
                                  database=config['db_schema'])
    try:
        cursor = cnx.cursor(dictionary=True, buffered=True)
        cursor.execute(query)
        subscribers = dict()
        for record in cursor:
            # Tolerate NULL and spaces around the comma-separated addresses
            raw_ips = (record['ip'] or '').split(',')
            subscribers[(record['spvlan'], record['cvlan'])] = {
                'ipaddrs': {IPv4Address(x.strip()) for x in raw_ips if x.strip()},
                'ratelimit': record['rate'],
                'blocked': record['blocked']
            }
    finally:
        # Close the connection on errors too
        cnx.close()
    return subscribers
if __name__ == '__main__':
    load_config('config.yaml')
    logging.basicConfig(level=logging.DEBUG)
    is_active = True
    lastrun = 0
    cache_expire_counter = 0
    while is_active:
        current_time = time.time()
        if current_time - lastrun <= config['sync_interval']:
            time.sleep(1)
            continue
        lastrun = current_time
        try:
            db_subscribers = get_db_subscribers()
        except Exception:
            # A database outage must not kill the sync loop; retry next interval
            logging.exception("Failed to fetch subscribers from database")
            continue
        # Force a full re-read of NAS state on every server once per `cache_expire` cycles.
        # (The counter used to advance once per server, so with several servers the
        # forced refresh rotated between them instead of covering all of them.)
        discard_cache = cache_expire_counter == 0
        logging.debug("Run {}".format(cache_expire_counter))
        cache_expire_counter = (cache_expire_counter + 1) % config['cache_expire']
        for server in servers:
            try:
                server.sync(db_subscribers, discard_cache=discard_cache)
            except Exception as e:
                logging.error("Exception while syncing server {} : {}".format(server.ipaddr, e.args))
#!/usr/bin/python3
import struct
import yaml
import os
import Pyro4
from ipaddress import IPv4Address, IPv4Network
from pyroute2 import IPRoute
from pyroute2 import NetlinkError
from pyroute2 import IPBatch
from pyroute2 import IPRSocket
from pyroute2.netlink.rtnl import RTM_GETADDR
from pyroute2.netlink.rtnl import RTM_NEWADDR
from pyroute2.netlink import NLM_F_ROOT
from pyroute2.netlink import NLM_F_REQUEST
from pyroute2.netlink import NLMSG_DONE
from socket import AF_INET, ntohl
# nasd configuration loaded from YAML by load_config()
_config = None
# Module-wide pyroute2 IPRoute netlink handle used by the interface/address helpers below
ipr = IPRoute()
def load_config(fname):
    """
    Load the nasd YAML config into the module global ``_config``.

    Each ``networks[i]['net']`` is converted to an IPv4Network.
    """
    global _config
    # safe_load: yaml.load() without a Loader can construct arbitrary Python objects
    # and raises TypeError on PyYAML >= 6. `with` closes the file handle.
    with open(fname, 'r') as cfg_file:
        _config = yaml.safe_load(cfg_file)
    for net in _config['networks']:
        net['net'] = IPv4Network(net['net'])
def parse_ifa_data(data, result):
    """
    Parse one netlink receive buffer answering an RTM_GETADDR dump.

    For each RTM_NEWADDR message carrying IFA_ADDRESS, IFA_LOCAL and IFA_LABEL with
    IFA_ADDRESS != IFA_LOCAL, append ``(ifname, address)`` to ``result``, where
    ``address`` is the IFA_ADDRESS value converted with ntohl (an int).

    Returns True once the NLMSG_DONE terminator has been seen, False if the caller
    must receive and parse another buffer.
    Raises Exception on a truncated/malformed message or a netlink error reply.

    On systems with thousands of interfaces manually parsing nlmsgs is 10x faster than doing it with pyroute2
    """
    NLMSG_HDRLEN = 16  # nlmsg header length
    NLMSG_ALIGNTO = 4  # messages inside one buffer start on 4-byte boundaries
    NLMSG_ERROR = 0x2  # error/ack message type
    IFADDRMSG_HDRLEN = 8  # ifaddrmsg header length
    NLA_ALIGNTO = 4  # All NLAs in netlink messages are aligned on 4-byte boundary
    NLA_HDRLEN = 4  # 16 bit length + 16 bit NLA type
    # NLA type
    IFA_ADDRESS = 0x1  # Remote address
    IFA_LOCAL = 0x2  # Local address
    IFA_LABEL = 0x3  # interface name

    def align(length, boundary):
        return (length + boundary - 1) & ~(boundary - 1)

    offset = 0
    while offset <= len(data) - NLMSG_HDRLEN:
        msglen, msgtype = struct.unpack_from('IH', data, offset=offset)
        if msglen < NLMSG_HDRLEN:
            # A zero/short length would otherwise make this loop spin forever
            raise Exception("Malformed netlink message")
        msgborder = offset + msglen
        if msgborder > len(data):
            raise Exception("Buffer overflow")
        if msgtype == NLMSG_DONE:  # All parts received. Nothing to do.
            return True
        if msgtype == NLMSG_ERROR:
            error = 0
            if msglen >= NLMSG_HDRLEN + 4:
                error, = struct.unpack_from('i', data, offset=offset + NLMSG_HDRLEN)
            if error != 0:
                raise Exception("Netlink error {}".format(-error))
        elif msgtype == RTM_NEWADDR:
            addr = None
            local = None
            label = None
            nla = offset + NLMSG_HDRLEN + IFADDRMSG_HDRLEN
            while nla <= msgborder - NLA_HDRLEN:  # Should always be able to get the NLA header
                nlalen, nlatype = struct.unpack_from('HH', data, offset=nla)
                if nlalen < NLA_HDRLEN or nla + nlalen > msgborder:
                    break  # Malformed attribute: ignore the rest of this message
                payload = nla + NLA_HDRLEN
                if nlatype == IFA_ADDRESS:
                    addr, = struct.unpack_from('I', data, offset=payload)
                elif nlatype == IFA_LOCAL:
                    local, = struct.unpack_from('I', data, offset=payload)
                elif nlatype == IFA_LABEL and nlalen > NLA_HDRLEN:  # last byte is NUL
                    label, = struct.unpack_from('{}s'.format(nlalen - NLA_HDRLEN - 1), data, offset=payload)
                nla += align(nlalen, NLA_ALIGNTO)
            if (addr is not None) and (local is not None) and (addr != local) and (label is not None):
                result.append((label.decode(), ntohl(addr)))
        # Always advance to the next message; unknown types used to hang this loop
        offset += align(msglen, NLMSG_ALIGNTO)
    return False
def dump_addr_table():
    """
    Dump IPv4 addresses with a raw RTM_GETADDR netlink request.

    Returns the list of ``(ifname, address_int)`` tuples collected by parse_ifa_data.
    """
    ipb = IPBatch()
    ipb.addr((RTM_GETADDR, NLM_F_REQUEST | NLM_F_ROOT), family=AF_INET)
    request = ipb.batch
    addrs = []
    sock = IPRSocket()
    try:
        sock.sendto(request, (0, 0))
        allrecv = False
        while not allrecv:
            # 65536 (was the typo 65336): room for a full netlink dump datagram
            allrecv = parse_ifa_data(sock.recv(65536), addrs)
    finally:
        # Without this every dump_addresses RPC leaked a netlink socket
        sock.close()
    return addrs
def dump_ratelimit_set(setname: str):
    """
    Dump specified ratelimit set as str list.

    Returns the lines of /proc/net/ipt_ratelimit/<setname> (newlines kept).
    Raises ValueError if the set does not exist.
    """
    path = "/proc/net/ipt_ratelimit/{}".format(setname)
    if not os.path.isfile(path):
        raise ValueError("Specified ratelimit set does not exist")
    # `with` closes the proc file; it used to be left open
    with open(path, 'r') as f:
        return f.readlines()
def get_interface_index(ifname: str):
    """ Look up the kernel ifindex of ``ifname``; None when the interface is absent """
    try:
        links = ipr.link("get", ifname=ifname)
    except NetlinkError:
        return None
    return links[0]['index']
def _link_up(ifindex):
    """Bring a link up, ignoring Error 100 "Network is down" (see create_iface)."""
    try:
        ipr.link("set", index=ifindex, state='up')
    except NetlinkError as e:
        if e.code != 100:
            raise


def create_iface(spvlan: int, cvlan: int):
    """
    Create subscriber interface. Also creates supervlan interface if needed.

    If sub<spvlan>.<cvlan> exists its index is returned. If only the blocked name
    b_sub<spvlan>.<cvlan> exists, the subscriber is unblocked and that index returned.
    Otherwise vlan<spvlan> is created when missing, on the physical interface whose
    configured vlan-range covers spvlan. The new subscriber vlan is then created
    and brought up, and the configured PPS limits are applied to it.

    Returns vlan subif index.
    Raises ValueError for out-of-range vlan ids or when no physical interface is
    configured for spvlan; Exception when interface creation fails.
    """
    spvlan = int(spvlan)
    cvlan = int(cvlan)
    if not ((1 < spvlan < 4092) and (1 < cvlan < 4092)):
        raise ValueError("vlan id must be between 1 and 4092")
    id_cvlan = get_interface_index('sub{}.{}'.format(spvlan, cvlan))
    if id_cvlan is not None:
        return id_cvlan  # Interface already exists
    id_cvlan = get_interface_index('b_sub{}.{}'.format(spvlan, cvlan))
    if id_cvlan is not None:
        user_unblock(spvlan, cvlan)
        return id_cvlan  # Interface already exists
    # Check if spvlan iface exists and create it if necessary
    id_spvlan = get_interface_index('vlan{}'.format(spvlan))
    if not id_spvlan:
        phy = None  # used to be unbound when no vlan-range matched
        for iface in _config['interfaces']:
            if iface['vlan-range'][0] <= spvlan < iface['vlan-range'][1]:
                phy = get_interface_index(iface['ifname'])
                break
        if not phy:
            # The old message read a non-existent _config['phy'] key (KeyError)
            raise ValueError('No usable physical interface for spvlan {}'.format(spvlan))
        ipr.link("add", ifname="vlan{}".format(spvlan), kind='vlan', vlan_id=spvlan, link=phy)
        id_spvlan = get_interface_index('vlan{}'.format(spvlan))
        if not id_spvlan:
            raise Exception("Failed to create spvlan {}!".format(spvlan))
    # Backup bras would have master interface down
    # Bringing slave interface up will generate Error 100 "Network is down",
    # which can be safely ignored in this case. Setting an up link up is a no-op.
    _link_up(id_spvlan)
    # Create cvlan interface
    ipr.link("add", ifname="sub{}.{}".format(spvlan, cvlan), kind='vlan',
             vlan_id=cvlan, link=id_spvlan)
    id_cvlan = get_interface_index('sub{}.{}'.format(spvlan, cvlan))
    if not id_cvlan:
        raise Exception("Failed to create subscriber interface {} {}".format(spvlan, cvlan))
    _link_up(id_cvlan)
    # PPS Limits for common DDoS types
    for limit_name in ('pps_dns', 'pps_ntp', 'pps_dhcp'):
        path = '/proc/net/ipt_ratelimit/{}'.format(limit_name)
        if limit_name in _config and os.path.isfile(path):
            with open(path, 'w') as f:
                f.write("@+{} {}\n".format(id_cvlan, _config[limit_name]))
    return id_cvlan
def ipaddr_add(ipaddr: str, spvlan: int, cvlan: int):
    """
    Attach subscriber IP ``ipaddr`` as a /32 peer address to sub<spvlan>.<cvlan>
    (or its blocked variant b_sub<spvlan>.<cvlan>), with the gateway of the
    configured network containing the IP as the local address.

    An address that is already present (EEXIST, code 17) is not an error.
    Raises ValueError for bad vlan ids, an invalid IP, a missing interface,
    or an IP outside all subscriber networks.
    """
    spvlan, cvlan = int(spvlan), int(cvlan)
    vlans_ok = 1 < spvlan < 4092 and 1 < cvlan < 4092
    if not vlans_ok:
        raise ValueError("vlan id must be between 1 and 4092")
    address = IPv4Address(ipaddr)
    ifidx = (get_interface_index('sub{}.{}'.format(spvlan, cvlan))
             or get_interface_index('b_sub{}.{}'.format(spvlan, cvlan)))
    if not ifidx:
        raise ValueError("Specified interface does not exist")
    # When several networks match, the last one wins
    gateways = [net['gw'] for net in _config['networks'] if address in net['net']]
    gw = gateways[-1] if gateways else None
    if not gw:
        raise ValueError("Specified IP address {} does not belong to any subscriber network".format(address))
    try:
        ipr.addr('add', address=str(address), mask=32, index=ifidx, local=gw)
    except NetlinkError as e:
        if e.code == 17:  # IP already exists on interface. Can be safely ignored.
            return
        raise
def ipaddr_delete(ipaddr: str, spvlan: int, cvlan: int):
    """
    Remove subscriber IP ``ipaddr`` (/32, with its network's gateway as local
    address) from sub<spvlan>.<cvlan> or its blocked variant b_sub<spvlan>.<cvlan>.

    Raises ValueError for bad vlan ids, an invalid IP, a missing interface, or an IP
    outside all subscriber networks. Kernel errors (NetlinkError) propagate.
    """
    spvlan, cvlan = int(spvlan), int(cvlan)
    vlans_ok = 1 < spvlan < 4092 and 1 < cvlan < 4092
    if not vlans_ok:
        raise ValueError("vlan id must be between 1 and 4092")
    address = IPv4Address(ipaddr)
    ifidx = (get_interface_index('sub{}.{}'.format(spvlan, cvlan))
             or get_interface_index('b_sub{}.{}'.format(spvlan, cvlan)))
    if not ifidx:
        raise ValueError("Specified interface does not exist")
    # When several networks match, the last one wins
    gateways = [net['gw'] for net in _config['networks'] if address in net['net']]
    gw = gateways[-1] if gateways else None
    if not gw:
        raise ValueError("Specified IP address does not belong to any subscriber network")
    ipr.addr('del', address=str(address), mask=32, index=ifidx, local=gw)
def set_ratelimit(rate, spvlan, cvlan):
    """
    Limit bandwidth on user vlan.

    Writes "@+<ifindex> <rate>" to both the downstream (``rtset_down``) and upstream
    (``rtset_up``) ipt_ratelimit sets named in the config, for sub<spvlan>.<cvlan>
    or its blocked variant b_sub<spvlan>.<cvlan>.

    Raises ValueError on bad vlan ids, a non-positive rate or a missing interface;
    FileNotFoundError when either ratelimit set is unconfigured or absent.
    """
    spvlan = int(spvlan)
    cvlan = int(cvlan)
    if not ((1 < spvlan < 4092) and (1 < cvlan < 4092)):
        raise ValueError("vlan id must be between 1 and 4092")
    rate = int(rate)
    if rate <= 0:
        raise ValueError("Rate must be positive integer")  # was "must me"
    ifidx = get_interface_index('sub{}.{}'.format(spvlan, cvlan))
    if not ifidx:
        ifidx = get_interface_index('b_sub{}.{}'.format(spvlan, cvlan))
    if not ifidx:
        raise ValueError("Specified interface does not exist")
    # Validate both sets before writing either, so a missing set changes nothing
    paths = []
    for key, direction in (('rtset_down', 'Downstream'), ('rtset_up', 'Upstream')):
        path = '/proc/net/ipt_ratelimit/{}'.format(_config[key]) if key in _config else None
        if path is None or not os.path.isfile(path):
            raise FileNotFoundError("{} rtset does not exist".format(direction))
        paths.append(path)
    for path in paths:
        with open(path, 'w') as f:
            f.write("@+{} {}\n".format(ifidx, rate))
def user_block(spvlan: int, cvlan: int):
    """
    Block user by renaming sub<spvlan>.<cvlan> to b_sub<spvlan>.<cvlan>.

    The link is taken down for the rename and brought back up afterwards, also
    when the rename fails. Does nothing if sub<spvlan>.<cvlan> does not exist.
    """
    iface_idx = get_interface_index('sub{}.{}'.format(spvlan, cvlan))
    if not iface_idx:
        return
    ipr.link('set', index=iface_idx, state='down')
    try:
        ipr.link('set', index=iface_idx, ifname='b_sub{}.{}'.format(spvlan, cvlan))
    finally:
        # Never leave the subscriber's link down because the rename failed
        ipr.link('set', index=iface_idx, state='up')
def user_unblock(spvlan: int, cvlan: int):
    """
    Unblock user by renaming b_sub<spvlan>.<cvlan> back to sub<spvlan>.<cvlan>.

    The link is taken down for the rename and brought back up afterwards, also
    when the rename fails. Does nothing if b_sub<spvlan>.<cvlan> does not exist.
    """
    iface_idx = get_interface_index('b_sub{}.{}'.format(spvlan, cvlan))
    if not iface_idx:
        return
    ipr.link('set', index=iface_idx, state='down')
    try:
        ipr.link('set', index=iface_idx, ifname='sub{}.{}'.format(spvlan, cvlan))
    finally:
        # Never leave the subscriber's link down because the rename failed
        ipr.link('set', index=iface_idx, state='up')
@Pyro4.expose
class NasdRemote(object):
    """
    RPC facade exposed through Pyro4 and used by the billing sync client.

    ``syncts`` holds the last sync timestamp reported by the client via
    update_syncts(); the client compares it with its own to detect a stale cache.
    """
    # NOTE(review): the class (not an instance) is registered with the daemon, so
    # Pyro4's default instance_mode decides whether syncts is shared or per
    # client session — confirm this is what the client expects.

    def __init__(self):
        self.syncts = 0

    @staticmethod
    def test():
        return "test"

    def get_syncts(self):
        return self.syncts

    def update_syncts(self, ts):
        self.syncts = int(ts)

    @staticmethod
    def add_ipaddr(ipaddr: str, spvlan: int, cvlan: int):
        """ Ensure the subscriber's vlan interface exists, then attach the IP to it """
        create_iface(spvlan, cvlan)
        ipaddr_add(ipaddr, spvlan, cvlan)

    @staticmethod
    def del_ipaddr(ipaddr: str, spvlan: int, cvlan: int):
        """ Detach the IP from the subscriber's vlan interface """
        ipaddr_delete(ipaddr, spvlan, cvlan)

    @staticmethod
    def dump_addresses():
        """ List of (ifname, address_int) tuples for peer addresses on this host """
        return dump_addr_table()

    @staticmethod
    def dump_ratelimit():
        """ Raw lines of the downstream ratelimit set """
        return dump_ratelimit_set(_config['rtset_down'])

    @staticmethod
    def set_ratelimit(rate, spvlan, cvlan):
        # Resolves to the module-level set_ratelimit(), not this method
        set_ratelimit(rate, spvlan, cvlan)

    @staticmethod
    def set_block(blocked, spvlan, cvlan):
        """ Block the subscriber unless ``blocked`` == 0, which unblocks it """
        action = user_unblock if blocked == 0 else user_block
        action(spvlan, cvlan)
if __name__ == "__main__":
    load_config('/etc/nasd.yaml')
    daemon = Pyro4.Daemon(host=_config['rpc_ipaddr'], port=_config['rpc_port'])
    daemon._pyroHmacKey = _config['hmac_key']
    uri = daemon.register(NasdRemote, 'nasd')
    print("Managed interfaces:")
    for iface in _config['interfaces']:
        name = iface['ifname']
        first_svlan, last_svlan = iface['vlan-range'][0], iface['vlan-range'][1]
        print("Iface: {} => SVLANs {}-{}".format(name, first_svlan, last_svlan))
    print("Ready. Object uri =", uri)
    daemon.requestLoop()
#include <iostream>
#include <unordered_map>
#include <thread>
#include <tins/tins.h>
#include <thread>
#include <iterator>
#include <netinet/in.h>
#include <linux/if_packet.h>
#include <sys/stat.h>
#include <fstream>
#include <utility>
#include <regex>
#include <atomic>
#include <array>
#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>
#include "log4cpp/Category.hh"
#include "log4cpp/Appender.hh"
#include "log4cpp/OstreamAppender.hh"
#include "log4cpp/Layout.hh"
#include "log4cpp/BasicLayout.hh"
#include "log4cpp/Priority.hh"
#include "log4cpp/PatternLayout.hh"
#include <net/if.h>
#include <net/ethernet.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>
// NOTE(review): file-scope using-directive; the code below appears to qualify every
// standard name with std::, so this looks removable — confirm before dropping.
using namespace std;
// Root log4cpp category used for all logging; appender, layout and priority are set up in main().
log4cpp::Category& logroot = log4cpp::Category::getRoot();
// Address data published for one subscriber interface by maintain_addr_map().
struct iface_addr {
    Tins::IPv4Address ipaddr;    // subscriber address (host route on the interface)
    Tins::IPv4Address netmask;   // netmask of the configured network containing ipaddr
    Tins::IPv4Address gateway;   // gateway of that network (used as server address)
    bool multiple_addrs{false};  // more than one matching address seen on the interface
};
// ifindex -> address data
using iface_addrs = std::unordered_map<int, iface_addr>;
// Double-buffered address maps shared by maintain_addr_map() (writer) and
// dhcp_send_reply() (reader). Readers use slot `active_index`; the writer is meant
// to fill the other slot and then flip the index.
struct _addr_table {
    std::array<iface_addrs, 2> ipaddrs;   // regular subscriber networks
    std::array<iface_addrs, 2> stbaddrs;  // set-top-box networks
    std::atomic<int> active_index;        // 0 or 1
};
_addr_table addr_table;
// One configured subscriber network, from a network= or stb_network= config line.
struct network {
    Tins::IPv4Range adrange{Tins::IPv4Range(0, 0)};  // addresses belonging to the network
    Tins::IPv4Address netmask;                       // netmask handed to clients
    Tins::IPv4Address gateway;                       // router / server identifier
    bool stb{false};                                 // true for stb_network= entries
};

// Reply data for set-top boxes: a client whose DHCP option 60 equals vendor_class
// is served from the stb networks and receives opt43 as option 43.
struct stboption {
    std::string vendor_class;    // expected option 60 value
    std::vector<uint8_t> opt43;  // raw option 43 payload
};

// Global settings, filled from the config file by parse_config().
struct _config {
    // MAIN
    std::vector<network> networks;               // supernets (network= / stb_network=)
    std::vector<Tins::IPv4Address> nameservers;  // nameserver= entries
    uint32_t lease_time;                         // lease_time= value
    // STB DHCP Vendor Options
    std::vector<stboption> stboptions;           // vendor= entries
};
_config config;
// Fields of a received DHCP request needed to build the reply (see recv_dhcp_packet()).
struct request {
    int ifindex = 0;                          // receiving interface, from IP_PKTINFO
    bool broadcast = false;                   // client set the BOOTP broadcast flag
    uint32_t xid = 0;                         // transaction id, echoed in the reply
    Tins::IPv4Address req_ip = 0;             // requested address option, 0 when absent
    Tins::IPv4Address ciaddr = 0;             // client address from the BOOTP header
    Tins::HWAddress<ETHER_ADDR_LEN> chaddr;   // client hardware address
    uint8_t message_type = 0;                 // DHCP message type
    uint16_t secs = 0;                        // was left uninitialized, unlike the other scalars
    std::string vendor_class;                 // option 60 vendor class, if sent
};
void maintain_addr_map() {
int new_index = addr_table.active_index ^ 0x1;
FILE * routefile;
while(true) {
iface_addrs new_ipaddrs, new_stbaddrs;
logroot.debug("Refreshing address map");
char buf[256];
char ifname[IF_NAMESIZE];
uint32_t raw_ipaddr, prefix;
routefile = fopen("/proc/net/route","r");
fgets(buf, sizeof(buf), routefile);
while (fgets(buf, sizeof(buf), routefile) != NULL) {
if (sscanf(buf, "%s\t%x\t%*x\t%*x\t%*x\t%*x\t%*x\t%x", ifname, &raw_ipaddr, &prefix) == 3) {
if (prefix != 0xFFFFFFFF) continue; // Not a host route. Skip.
int ifindex = if_nametoindex(ifname);
if (!ifindex) {
logroot.error("Iface index not found: %s", ifname);
continue;
}
iface_addr new_addr;
Tins::IPv4Address ipaddr(raw_ipaddr);
for( network net : config.networks) {
if (net.adrange.contains(ipaddr)) {
new_addr.ipaddr = ipaddr;
new_addr.netmask = net.netmask;
new_addr.gateway = net.gateway;
iface_addrs &addrs = (!net.stb) ? new_ipaddrs : new_stbaddrs;
iface_addrs::const_iterator iter = addrs.find(ifindex);
if (iter == addrs.end()) {
addrs[ifindex] = new_addr;
} else {
addrs[ifindex].multiple_addrs = true;
}
break;
}
}
}
}
fclose(routefile);
addr_table.ipaddrs[new_index].swap(new_ipaddrs);
addr_table.stbaddrs[new_index].swap(new_stbaddrs);
addr_table.active_index = new_index;
logroot.debug("Addrs: %i",addr_table.ipaddrs[addr_table.active_index].size());
std::this_thread::sleep_for(std::chrono::seconds(60));
}
}
int parse_config() {
std::vector<std::regex> cfg_regexes = {
std::regex(R"(^(lease_time)=(\d+))"),
std::regex(R"(^(nameserver)=(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}))"),
std::regex(R"(^(network)=(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\/)"
R"((\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\/)"
R"((\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}))"),
std::regex(R"(^(stb_network)=(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\/)"
R"((\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\/)"
R"((\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}))"),
std::regex(R"(^(vendor)=\"([^\"]+)\":([0-9A-Fa-f]+))"),
};
std::smatch m;
struct stat statbuf;
std::string cfgfile = "/etc/nasdhcp.conf";
std::ifstream cfg;
if (stat(cfgfile.c_str(), &statbuf) != 0) {
cfgfile = "vpcdhcp.conf";
if (stat(cfgfile.c_str(), &statbuf) != 0) {
logroot.fatal("No config file found!");
return -1; /* No config file found */
}
}
cfg.open(cfgfile, std::ios::in);
std::string line;
while (std::getline(cfg, line)) {
for (std::regex re : cfg_regexes) {
std::regex_search(line, m, re);
if (m.size() > 1) {
if (m[1] == "lease_time") {
config.lease_time = (uint32_t) std::stoul(m[2]);
}
else if (m[1] == "nameserver") {
config.nameservers.push_back(Tins::IPv4Address(m[2]));
}
else if (m[1] == "network") {
network net;
net.adrange = Tins::IPv4Range::from_mask(Tins::IPv4Address(m[2]), Tins::IPv4Address(m[3]));
net.netmask = Tins::IPv4Address(m[3]);
net.gateway = Tins::IPv4Address(m[4]);
config.networks.push_back(net);
}
else if (m[1] == "stb_network") {
network net;
net.adrange = Tins::IPv4Range::from_mask(Tins::IPv4Address(m[2]), Tins::IPv4Address(m[3]));
net.netmask = Tins::IPv4Address(m[3]);
net.gateway = Tins::IPv4Address(m[4]);
net.stb = true;
config.networks.push_back(net);
}
else if (m[1] == "vendor") {
if (m[3].length() % 2) {
logroot.error("Vendor data: number of bytes must be even");
return -1;
}
stboption sopt;
sopt.vendor_class = m[2];
char hexbyte[2];
uint8_t temp;
std::stringstream ss(m[3]);
while (ss.read(hexbyte,2)) {
sscanf(hexbyte,"%02hhx", &temp);
sopt.opt43.push_back(temp);
}
config.stboptions.push_back(sopt);
}
break;
}
}
}
cfg.close();
return 0;
}
// Printable name of a DHCP message type code, for log lines.
std::string msgtype(int i) {
    const char *name;
    switch (i) {
    case Tins::DHCP::DISCOVER: name = "DISCOVER"; break;
    case Tins::DHCP::REQUEST:  name = "REQUEST";  break;
    case Tins::DHCP::OFFER:    name = "OFFER";    break;
    case Tins::DHCP::ACK:      name = "ACK";      break;
    case Tins::DHCP::NAK:      name = "NAK";      break;
    default:                   name = "UNKNOWN";  break;
    }
    return name;
}
int recv_dhcp_packet(int sockrcv, request &req) {
uint8_t buf[65536];
uint8_t aux_buf[CMSG_SPACE(2048)];
struct iovec riov[1];
riov[0].iov_base = buf;
riov[0].iov_len = sizeof(buf);
struct msghdr msg;
memset((void *) &msg, 0, sizeof(msg));
msg.msg_iov = riov;
msg.msg_iovlen = 1;
msg.msg_control = aux_buf;
msg.msg_controllen = sizeof(aux_buf);
int bytesrcv = recvmsg(sockrcv, &msg, 0);
if (bytesrcv > 0) {
// Retrieve ifindex from ancillary data
struct cmsghdr *cmsg;
for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if ((cmsg->cmsg_level == IPPROTO_IP) && (cmsg->cmsg_type == IP_PKTINFO)) {
struct in_pktinfo *pinfo = (struct in_pktinfo *)CMSG_DATA(cmsg);
req.ifindex = pinfo->ipi_ifindex;
// req.broadcast = (pinfo->ipi_addr.s_addr == INADDR_BROADCAST);
}
}
// Process DHCP Request
Tins::DHCP dhcp_packet;
try {
dhcp_packet = Tins::DHCP((uint8_t *)riov[0].iov_base, bytesrcv);
}
catch (Tins::malformed_packet) {
logroot.error("Malformed packet received");
return -1;
}
catch (...) {
return -1; // DHCP server should not crash because of some garbage data sent to it
}
if (dhcp_packet.opcode() != Tins::BootP::BOOTREQUEST) {
return -1; // Packet is not request
}
// Get relevant field from request
try {
req.message_type = dhcp_packet.type();
}
catch (Tins::option_not_found) {
return -1; // Non-DHCP packet
}
if ((req.message_type != Tins::DHCP::DISCOVER) && (req.message_type != Tins::DHCP::REQUEST)) {
return -1; // Server should only reply to DISCOVER and REQUEST packets
}
req.xid = dhcp_packet.xid();
req.ciaddr = dhcp_packet.ciaddr();
req.secs = dhcp_packet.secs();
req.broadcast = (dhcp_packet.padding() == 0x8000);
try {
req.req_ip = dhcp_packet.requested_ip();
}
catch (Tins::option_not_found) {
req.req_ip = 0;
}
req.chaddr = dhcp_packet.chaddr();
const Tins::DHCP::option *opt60 = dhcp_packet.search_option(Tins::DHCP::VENDOR_CLASS_IDENTIFIER);
if (opt60 && (opt60->data_size() > 0)) {
req.vendor_class = std::string((char *)(opt60->data_ptr()), opt60->data_size());
}
}
logroot.infoStream() << "Received " << msgtype(req.message_type) << " on iface " << req.ifindex
<< ": xid: " << std::hex << req.xid << " chaddr: " << req.chaddr.to_string()
<< " req-ip: " << req.req_ip.to_string()
<< " ciaddr: " << req.ciaddr.to_string()
<< " vendor-class: " << req.vendor_class;
return 0;
}
// DHCP option code 0 (Pad). dhcp_send_reply() appends it repeatedly to grow replies
// to at least 300 bytes, counting each appended copy as 2 bytes.
const Tins::DHCP::option opt0(0);
int dhcp_send_reply(int sockraw, request &req) {
std::vector<stboption>::const_iterator opt43it;
bool is_stb = false;
for (opt43it = config.stboptions.cbegin(); opt43it != config.stboptions.cend(); ++opt43it) {
if (opt43it->vendor_class == req.vendor_class) {
is_stb = true;
break;
}
}
const iface_addrs &addr_map = (!is_stb) ? addr_table.ipaddrs[addr_table.active_index] :
addr_table.stbaddrs[addr_table.active_index];
iface_addrs::const_iterator iter = addr_map.find(req.ifindex);
if (iter == addr_map.end()) {
logroot.debug("No address found for iface %i", req.ifindex);
return -1; // No ip addresses found for this iface. Ignore packet.
}
iface_addr iaddr = iter->second;
if(iaddr.multiple_addrs) { // Multiple ip addresses on iface. Ignore packet.
logroot.debug("Multiple address found for iface %i", req.ifindex);
return -1;
}
Tins::IPv4Address &yiaddr = iaddr.ipaddr;
Tins::IPv4Address &siaddr = iaddr.gateway;
auto pkt = Tins::IP(iaddr.ipaddr, siaddr) / Tins::UDP(68, 67) / Tins::DHCP();
if (req.broadcast)
pkt.dst_addr(Tins::IPv4Address::broadcast);
// DHCP Header
Tins::DHCP &dhcpdata = pkt.rfind_pdu<Tins::DHCP>();
dhcpdata.opcode(Tins::BootP::BOOTREPLY);
dhcpdata.xid(req.xid);
dhcpdata.secs(req.secs);
dhcpdata.chaddr(req.chaddr);
// DHCP Options
if(req.message_type == Tins::DHCP::DISCOVER) {
dhcpdata.type(Tins::DHCP::OFFER);
}
else if(req.message_type == Tins::DHCP::REQUEST) {
if ((req.req_ip && (req.req_ip != iaddr.ipaddr)) || (req.ciaddr && (req.ciaddr != iaddr.ipaddr))) {
dhcpdata.type(Tins::DHCP::NAK);
pkt.dst_addr(Tins::IPv4Address::broadcast);
}
else {
dhcpdata.type(Tins::DHCP::ACK);
}
}
if (dhcpdata.type() != Tins::DHCP::NAK) {
dhcpdata.yiaddr(yiaddr);
dhcpdata.siaddr(siaddr);
dhcpdata.subnet_mask(iaddr.netmask);
dhcpdata.routers({siaddr});
dhcpdata.domain_name_servers(config.nameservers);
dhcpdata.lease_time(config.lease_time);
if(is_stb) {
Tins::DHCP::option opt43(Tins::DHCP::VENDOR_ENCAPSULATED_OPTIONS, opt43it->opt43.size(), opt43it->opt43.data());
dhcpdata.add_option(opt43);
}
}
dhcpdata.server_identifier(Tins::IPv4Address(siaddr));
dhcpdata.end();
// PAD DHCP packet to at least 300 bytes
int pad = 300 - dhcpdata.size();
while (pad > 0) {
dhcpdata.add_option(opt0);
pad -= 2;
}
Tins::byte_array buf = pkt.serialize();
// Send reply packet
struct sockaddr_ll rawaddr;
memset((void *)&rawaddr,0,sizeof(rawaddr));
rawaddr.sll_ifindex = req.ifindex;
rawaddr.sll_family = AF_PACKET;
rawaddr.sll_protocol = htons(ETH_P_IP);
rawaddr.sll_halen = ETHER_ADDR_LEN;
if (pkt.dst_addr() == Tins::IPv4Address::broadcast) {
memset(rawaddr.sll_addr, 0xFF, ETHER_ADDR_LEN);
}
else {
req.chaddr.copy(std::begin(rawaddr.sll_addr));
}
int bytes_sent = sendto(sockraw, buf.data(), buf.size(), 0, (struct sockaddr *) &rawaddr, sizeof(rawaddr));
if (bytes_sent) {
logroot.infoStream() << "Sent " << msgtype(dhcpdata.type()) << " on iface " << req.ifindex
<< ": xid: " << std::hex << dhcpdata.xid()
<< " yiaddr: " << Tins::IPv4Address(iaddr.ipaddr).to_string()
<<" chaddr: " << req.chaddr.to_string();
} else {
logroot.error("Failed to send DHCP response on iface %i", req.ifindex);
}
}
void process_dhcp() {
// Init listening socket
int res;
int sockrcv = socket(AF_INET, SOCK_DGRAM, 0);
struct sockaddr_in localaddr;
memset((void *)&localaddr,0,sizeof(localaddr));
localaddr.sin_family = AF_INET;
localaddr.sin_port = htons(67);
localaddr.sin_addr.s_addr = htonl(INADDR_ANY);
int optval = 1;
if (setsockopt(sockrcv, SOL_SOCKET, SO_BROADCAST, &optval, sizeof(optval)) < 0) {
perror("Failed to set SO_BROADCAST on listening socket: ");
exit(EXIT_FAILURE);
}
optval = 1;
if (setsockopt(sockrcv, SOL_IP, IP_PKTINFO, &optval, sizeof(optval)) < 0) {
perror("Failed to set SO_BROADCAST on listening socket: ");
exit(EXIT_FAILURE);
}
if (bind(sockrcv, (struct sockaddr *) &localaddr, sizeof(localaddr)) < 0) {
perror("Failed to bind recv socket on 0.0.0.0:67");
exit(EXIT_FAILURE);
}
int sockraw = socket(AF_PACKET, SOCK_DGRAM, 0);
// Receive loop
while(true) {
request req;
res = recv_dhcp_packet(sockrcv, req);
if (res < 0) {
continue;
}
dhcp_send_reply(sockraw, req);
// break;
}
close(sockrcv);
}
int main(int argc, char *argv[])
{
// Setup logging
std::cout.setf(std::ios::unitbuf);
log4cpp::Appender *appender_console = new log4cpp::OstreamAppender("console", &std::cout);
log4cpp::PatternLayout *layout_console = new log4cpp::PatternLayout();
layout_console->setConversionPattern("[%p] %m%n");
appender_console->setLayout(layout_console);
logroot.addAppender(appender_console);
logroot.setPriority(log4cpp::Priority::INFO);
logroot.infoStream() << "nasdhcp started";
addr_table.active_index = 0;
if (parse_config() < 0) {
logroot.critStream() << "Failed to parse config!";
return 1;
}
std::thread t_am(maintain_addr_map);
process_dhcp();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment