Skip to content

Commit

Permalink
Refactor configuration implementation to use dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
blechschmidt committed Apr 13, 2024
1 parent 11e420b commit f598079
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 145 deletions.
125 changes: 117 additions & 8 deletions pallium/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,115 @@

import typing

from pallium import typeinfo
from pallium.exceptions import ConfigurationError
from pallium.sandbox import Sandbox
from . import typeinfo, hops
from .exceptions import ConfigurationError
from .sandbox import Sandbox

_T = typing.TypeVar('_T')
_primitive_types = [int, bool, str, float]


def json_serializable(cls: typing.Type[_T]) -> typing.Type[_T]:
"""
There are libs for this type of functionality, but we want to keep
dependencies low for security reasons.
"""

def json_value_to_instance(value, tp=None):
assert tp is not None, "Implementation requires type hinting."

if typing.get_origin(tp) == typing.Optional \
or typing.get_origin(tp) == typing.Union and len(typing.get_args(tp)) == 2 \
and type(None) in set(typing.get_args(tp)):
if value is None:
return None
# Unpack optional type
tp = typing.get_args(tp)[0]

if typing.get_origin(tp) == typing.Any:
return value

if tp in tuple(_primitive_types):
if not isinstance(value, tp):
raise ConfigurationError('Expected a %s' % tp)
return value

if typing.get_origin(tp) == list:
if not isinstance(value, list):
raise ConfigurationError('Expected a list')
return [json_value_to_instance(v, typing.get_args(tp)[0]) for v in value]

if not isinstance(value, dict):
raise ConfigurationError('Complex classes need to be deserialized from dict')

if hasattr(tp, 'from_json') and callable(tp.from_json):
return tp.from_json(value)

return from_json(tp, value)

def from_json(cls, json_data: typing.Dict[str, typing.Any]) -> _T:
constructor = {}
for key, value in json_data.items():
if key not in cls.__annotations__:
continue
attr_type = cls.__annotations__[key]
instance = json_value_to_instance(value, attr_type)
constructor[key] = instance
instance = cls(**constructor)

return instance

if not hasattr(cls, 'from_json'):
setattr(cls, 'from_json', classmethod(from_json))
return cls


class EthernetBridge:
def __init__(self, devices: typing.List[str], name: typing.Optional[str] = None):
self.name = name
self.devices = devices

@classmethod
def from_json(cls, obj):
return cls(**obj)


class Bridge:
def __init__(self, name: typing.Optional[str] = None,
routes: typing.List[typing.Union[ipaddress.ip_network, str]] = None,
dhcp: bool = False,
eth_bridge: typing.Optional[EthernetBridge] = None,
reserved_bypass: bool = True):
"""
A descriptor for building a bridge inside the main network namespace.
@param name: Bridge name. If unspecified, an automatically generated deterministic name is used.
@param routes: IP networks that should pass through the bridge.
@param dhcp: Whether a DHCP server should be started, providing clients with IP addresses.
@param eth_bridge: TODO.
@param reserved_bypass: Whether reserved addresses bypass the bridge.
"""
if routes is None:
routes = []
routes = list(map(ipaddress.ip_network, routes))
self.name = name
self.routes = routes
self.dhcp = dhcp
self.eth_bridge = eth_bridge
self.reserved_bypass = reserved_bypass

class FromJSON:
@classmethod
def from_json(cls, obj):
pass
if 'eth_bridge' in obj:
obj['eth_bridge'] = EthernetBridge.from_json(obj['eth_bridge'])
return cls(**obj)


@json_serializable
class LocalPortForwarding:
protocol: str # Either "tcp" or "udp"
host: (typeinfo.IPAddress, int)
guest: (typeinfo.IPAddress, int)
host: typing.Tuple[typeinfo.IPAddress, int]
guest: typing.Tuple[typeinfo.IPAddress, int]

def __init__(self, spec):
scheme, rest = spec.split('://', 1)
Expand Down Expand Up @@ -52,6 +146,7 @@ def from_json(cls, obj):
return cls(obj)


@json_serializable
@dataclasses.dataclass
class PortForwarding:
local: typing.List[LocalPortForwarding] = dataclasses.field(default_factory=list)
Expand All @@ -68,12 +163,26 @@ def from_json(cls, obj):
return result


@json_serializable
@dataclasses.dataclass
class Run:
command: typing.Optional[typing.List[str]] = dataclasses.field(default=None)
quiet: bool = dataclasses.field(default=False) # Whether to suppress status information of pallium and its helpers


@json_serializable
@dataclasses.dataclass
class Networking:
port_forwarding: PortForwarding = dataclasses.field(default_factory=PortForwarding)
chain: typing.List[hops.Hop] = dataclasses.field(default_factory=list)
bridge: typing.Optional[Bridge] = dataclasses.field(default=None)
routes: typing.Optional[typing.List[str]] = dataclasses.field(default=None)
kill_switch: bool = dataclasses.field(default=True)


@json_serializable
@dataclasses.dataclass
class Configuration:
networking: Networking = dataclasses.field(default_factory=Networking)
sandbox: Sandbox = None
sandbox: typing.Optional[Sandbox] = dataclasses.field(default_factory=Sandbox)
run: Run = dataclasses.field(default_factory=Run)
17 changes: 17 additions & 0 deletions pallium/hops/hop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import signal
import subprocess
import typing
from typing import Optional, Union, List

from pyroute2.iproute import IPRoute
Expand Down Expand Up @@ -109,6 +110,22 @@ def __init__(self, quiet=None, dns=None, **kwargs):
if 'required_routes' not in dir(self):
self.required_routes = []

@classmethod
def from_json(cls, value: typing.Dict[str, typing.Any]) -> 'Hop':
value = dict(value)
type2class = dict()
for hop_class in util.get_subclasses(cls):
class_name = hop_class.__name__
if hop_class.__name__.endswith('Hop'):
class_name = class_name[:-len('Hop')]
type2class[class_name.lower()] = hop_class

hop_type = value.pop('type')
hop_class = type2class.get(hop_type.lower())
if hop_class is None:
raise ""
return hop_class(**value)

def popen(self, *args, **kwargs):
"""Popen wrapper that keeps track of the started processes and handles command output.
Expand Down
16 changes: 8 additions & 8 deletions pallium/hops/socks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class Tun2socksHop(hop.Hop):
def __init__(self, protocol, address, username=None, password=None, **kwargs):
super().__init__(**kwargs)
self._pid = None
self._address, self._port = util.convert2addr(address, self.default_port)
self._username, self._password = username, password
self.required_routes = [ipaddress.ip_network(self._address)]
self.address, self.port = util.convert2addr(address, self.default_port)
self.username, self.password = username, password
self.required_routes = [ipaddress.ip_network(self.address)]
self.protocol = protocol

def free(self):
Expand Down Expand Up @@ -55,14 +55,14 @@ def connect(self):

# https://github.com/xjasonlyu/tun2socks
url_credentials = ''
if self._username is not None:
url_credentials += urllib.parse.quote(self._username)
if self._password is not None:
url_credentials += ':' + urllib.parse.quote(self._password)
if self.username is not None:
url_credentials += urllib.parse.quote(self.username)
if self.password is not None:
url_credentials += ':' + urllib.parse.quote(self.password)
url_credentials += '@'

cmd = [self.get_tool_path('tun2socks'), '-device', 'tun0',
'-proxy', self.protocol + '://%s%s:%d' % (url_credentials, str(self._address), self._port)]
'-proxy', self.protocol + '://%s%s:%d' % (url_credentials, str(self.address), self.port)]

proc = self.popen(cmd)
self._pid = proc.pid
Expand Down
Loading

0 comments on commit f598079

Please sign in to comment.