diff --git a/pallium/config.py b/pallium/config.py index 703f6bb..f907675 100644 --- a/pallium/config.py +++ b/pallium/config.py @@ -4,21 +4,115 @@ import typing -from pallium import typeinfo -from pallium.exceptions import ConfigurationError -from pallium.sandbox import Sandbox +from . import typeinfo, hops +from .exceptions import ConfigurationError +from .sandbox import Sandbox +_T = typing.TypeVar('_T') +_primitive_types = [int, bool, str, float] + + +def json_serializable(cls: typing.Type[_T]) -> typing.Type[_T]: + """ + There are libs for this type of functionality, but we want to keep + dependencies low for security reasons. + """ + + def json_value_to_instance(value, tp=None): + assert tp is not None, "Implementation requires type hinting." + + if typing.get_origin(tp) == typing.Optional \ + or typing.get_origin(tp) == typing.Union and len(typing.get_args(tp)) == 2 \ + and type(None) in set(typing.get_args(tp)): + if value is None: + return None + # Unpack optional type + tp = typing.get_args(tp)[0] + + if typing.get_origin(tp) == typing.Any: + return value + + if tp in tuple(_primitive_types): + if not isinstance(value, tp): + raise ConfigurationError('Expected a %s' % tp) + return value + + if typing.get_origin(tp) == list: + if not isinstance(value, list): + raise ConfigurationError('Expected a list') + return [json_value_to_instance(v, typing.get_args(tp)[0]) for v in value] + + if not isinstance(value, dict): + raise ConfigurationError('Complex classes need to be deserialized from dict') + + if hasattr(tp, 'from_json') and callable(tp.from_json): + return tp.from_json(value) + + return from_json(tp, value) + + def from_json(cls, json_data: typing.Dict[str, typing.Any]) -> _T: + constructor = {} + for key, value in json_data.items(): + if key not in cls.__annotations__: + continue + attr_type = cls.__annotations__[key] + instance = json_value_to_instance(value, attr_type) + constructor[key] = instance + instance = cls(**constructor) + + return instance + + if not hasattr(cls, 'from_json'): + setattr(cls, 'from_json', classmethod(from_json)) + return cls + + +class EthernetBridge: + def __init__(self, devices: typing.List[str], name: typing.Optional[str] = None): + self.name = name + self.devices = devices + + @classmethod + def from_json(cls, obj): + return cls(**obj) + + +class Bridge: + def __init__(self, name: typing.Optional[str] = None, + routes: typing.List[typing.Union[ipaddress.ip_network, str]] = None, + dhcp: bool = False, + eth_bridge: typing.Optional[EthernetBridge] = None, + reserved_bypass: bool = True): + """ + A descriptor for building a bridge inside the main network namespace. + + @param name: Bridge name. If unspecified, an automatically generated deterministic name is used. + @param routes: IP networks that should pass through the bridge. + @param dhcp: Whether a DHCP server should be started, providing clients with IP addresses. + @param eth_bridge: TODO. + @param reserved_bypass: Whether reserved addresses bypass the bridge. + """ + if routes is None: + routes = [] + routes = list(map(ipaddress.ip_network, routes)) + self.name = name + self.routes = routes + self.dhcp = dhcp + self.eth_bridge = eth_bridge + self.reserved_bypass = reserved_bypass -class FromJSON: @classmethod def from_json(cls, obj): - pass + if 'eth_bridge' in obj: + obj['eth_bridge'] = EthernetBridge.from_json(obj['eth_bridge']) + return cls(**obj) +@json_serializable class LocalPortForwarding: protocol: str # Either "tcp" or "udp" - host: (typeinfo.IPAddress, int) - guest: (typeinfo.IPAddress, int) + host: typing.Tuple[typeinfo.IPAddress, int] + guest: typing.Tuple[typeinfo.IPAddress, int] def __init__(self, spec): scheme, rest = spec.split('://', 1) @@ -52,6 +146,7 @@ def from_json(cls, obj): return cls(obj) +@json_serializable @dataclasses.dataclass class PortForwarding: local: typing.List[LocalPortForwarding] = dataclasses.field(default_factory=list) @@ -68,12 +163,26 @@ def from_json(cls, obj): return result +@json_serializable +@dataclasses.dataclass +class Run: + command: typing.Optional[typing.List[str]] = dataclasses.field(default=None) + quiet: bool = dataclasses.field(default=False) # Whether to suppress status information of pallium and its helpers + + +@json_serializable @dataclasses.dataclass class Networking: port_forwarding: PortForwarding = dataclasses.field(default_factory=PortForwarding) + chain: typing.List[hops.Hop] = dataclasses.field(default_factory=list) + bridge: typing.Optional[Bridge] = dataclasses.field(default=None) + routes: typing.Optional[typing.List[str]] = dataclasses.field(default=None) + kill_switch: bool = dataclasses.field(default=True) +@json_serializable @dataclasses.dataclass class Configuration: networking: Networking = dataclasses.field(default_factory=Networking) - sandbox: Sandbox = None + sandbox: typing.Optional[Sandbox] = dataclasses.field(default_factory=Sandbox) + run: Run = dataclasses.field(default_factory=Run) diff --git a/pallium/hops/hop.py b/pallium/hops/hop.py index 8329fa6..72f6f97 100644 --- a/pallium/hops/hop.py +++ b/pallium/hops/hop.py @@ -4,6 +4,7 @@ import shutil import signal import subprocess +import typing from typing import Optional, Union, List from pyroute2.iproute import IPRoute @@ -109,6 +110,22 @@ def __init__(self, quiet=None, dns=None, **kwargs): if 'required_routes' not in dir(self): self.required_routes = [] + @classmethod + def from_json(cls, value: typing.Dict[str, typing.Any]) -> 'Hop': + value = dict(value) + type2class = dict() + for hop_class in util.get_subclasses(cls): + class_name = hop_class.__name__ + if hop_class.__name__.endswith('Hop'): + class_name = class_name[:-len('Hop')] + type2class[class_name.lower()] = hop_class + + hop_type = value.pop('type') + hop_class = type2class.get(hop_type.lower()) + if hop_class is None: + raise "" + return hop_class(**value) + def popen(self, *args, **kwargs): """Popen wrapper that keeps track of the started processes and handles command output. diff --git a/pallium/hops/socks.py b/pallium/hops/socks.py index 4392781..5c39a46 100644 --- a/pallium/hops/socks.py +++ b/pallium/hops/socks.py @@ -18,9 +18,9 @@ class Tun2socksHop(hop.Hop): def __init__(self, protocol, address, username=None, password=None, **kwargs): super().__init__(**kwargs) self._pid = None - self._address, self._port = util.convert2addr(address, self.default_port) - self._username, self._password = username, password - self.required_routes = [ipaddress.ip_network(self._address)] + self.address, self.port = util.convert2addr(address, self.default_port) + self.username, self.password = username, password + self.required_routes = [ipaddress.ip_network(self.address)] self.protocol = protocol def free(self): @@ -55,14 +55,14 @@ def connect(self): # https://github.com/xjasonlyu/tun2socks url_credentials = '' - if self._username is not None: - url_credentials += urllib.parse.quote(self._username) - if self._password is not None: - url_credentials += ':' + urllib.parse.quote(self._password) + if self.username is not None: + url_credentials += urllib.parse.quote(self.username) + if self.password is not None: + url_credentials += ':' + urllib.parse.quote(self.password) url_credentials += '@' cmd = [self.get_tool_path('tun2socks'), '-device', 'tun0', - '-proxy', self.protocol + '://%s%s:%d' % (url_credentials, str(self._address), self._port)] + '-proxy', self.protocol + '://%s%s:%d' % (url_credentials, str(self.address), self.port)] proc = self.popen(cmd) self._pid = proc.pid diff --git a/pallium/profiles.py b/pallium/profiles.py index b8300a7..a065036 100644 --- a/pallium/profiles.py +++ b/pallium/profiles.py @@ -20,6 +20,7 @@ from pyroute2.netlink.rtnl.ifaddrmsg import IFA_F_NODAD, IFA_F_NOPREFIXROUTE from pyroute2.netlink.exceptions import NetlinkError +import pallium.config from .nftables import NFTables from . import audio, debugging, resolvconf, runtime @@ -59,47 +60,6 @@ def list(running_only=False): return result -class EthernetBridge: - def __init__(self, devices: List[str], name: Optional[str] = None): - self.name = name - self.devices = devices - - @classmethod - def from_json(cls, obj): - return cls(**obj) - - -class Bridge: - def __init__(self, name: Optional[str] = None, - routes: List[Union[ipaddress.ip_network, str]] = None, - dhcp: bool = False, - eth_bridge: Optional[EthernetBridge] = None, - reserved_bypass: bool = True): - """ - A descriptor for building a bridge inside the main network namespace. - - @param name: Bridge name. If unspecified, an automatically generated deterministic name is used. - @param routes: IP networks that should pass through the bridge. - @param dhcp: Whether a DHCP server should be started, providing clients with IP addresses. - @param eth_bridge: TODO. - @param reserved_bypass: Whether reserved addresses bypass the bridge. - """ - if routes is None: - routes = [] - routes = list(map(ipaddress.ip_network, routes)) - self.name = name - self.routes = routes - self.dhcp = dhcp - self.eth_bridge = eth_bridge - self.reserved_bypass = reserved_bypass - - @classmethod - def from_json(cls, obj): - if 'eth_bridge' in obj: - obj['eth_bridge'] = EthernetBridge.from_json(obj['eth_bridge']) - return cls(**obj) - - class NetPool: """ This class keeps track of used IP addresses such that duplicate assignments are avoided. @@ -118,20 +78,10 @@ class Profile: debug = False quiet = None - def __init__(self, chain: List[hops.Hop], *, # The API is not stable yet. Prevent positional args. - start_networks: List[IPNetworkLike] = None, - user: Union[str, int] = None, - quiet: Optional[bool] = None, - bridge: Optional[Bridge] = None, - routes: List[IPNetworkLike] = None, - preexec_fn: Optional[List[Callable]] = None, - postexec_fn: Optional[List[Callable]] = None, - enter: bool = False, - kill_switch: bool = True, - mounts: Optional[List[MountInstruction]] = None, - sandbox: Sandbox = None, - command: Union[List[str], str, None] = None, - configuration: config.Configuration = None): + def _prepare_for_execution(self, configuration: config.Configuration): + pass + + def __init__(self, conf: config.Configuration): """ Initialize a pallium profile. @@ -148,26 +98,29 @@ def __init__(self, chain: List[hops.Hop], *, # The API is not stable yet. Preve @param kill_switch: When enabled, traffic is not allowed to bypass hops. """ self._filepath = None + chain = conf.networking.chain if len(chain) == 0 or not isinstance(chain[-1], DummyHop): chain.append(DummyHop()) self.chain = chain - if start_networks is not None: - raise NotImplementedError('start_networks is currently unsupported.') - else: - start_networks = [] + + # TODO: Include this in config + start_networks = [] self.start_networks = list(map(ipaddress.ip_network, start_networks)) self.netinfo = None - self.bridge = bridge - self._preexec_fn = [] if preexec_fn is None else preexec_fn - self._postexec_fn = [] if postexec_fn is None else postexec_fn - self.user = user + self.bridge = conf.networking.bridge + self._preexec_fn = [] + self._postexec_fn = [] + # TODO: Support + self.user = None self.netpool = NetPool() - self.command = command + self.command = conf.run.command + routes = conf.networking.routes if routes is not None: chain[-1].required_routes = list(map(ipaddress.ip_network, routes)) # If Profile.quiet is a boolean class property and no quiet value has been supplied to the constructor, # we use Profile.quiet as a default value. + quiet = conf.run.quiet if quiet is None and isinstance(Profile.quiet, bool): quiet = Profile.quiet self._set_quiet(quiet) @@ -175,19 +128,20 @@ def __init__(self, chain: List[hops.Hop], *, # The API is not stable yet. Preve for hop in self.chain: if self.debug: hop.debug = True - self._enter = enter + # TODO: Make use of this? + self._enter = False self._context_sessions = [] - self.kill_switch = kill_switch - self._mounts = mounts if mounts is not None else [] + self.kill_switch = conf.networking.kill_switch + self._mounts = [] self.has_connected_functions = False - self.sandbox = sandbox - if sandbox is not None: + self.sandbox = conf.sandbox + if self.sandbox is not None: self._mounts.extend(self.sandbox.get_mounts()) - if configuration is None: - configuration = config.Configuration() - self.config = configuration + if conf is None: + conf = config.Configuration() + self.config = conf # noinspection PyRedeclaration @property @@ -218,6 +172,9 @@ def from_config(cls, settings: dict) -> 'Profile': @param settings: The settings as dictionary. @return: A profile which was constructed according to the settings. """ + + return Profile(pallium.config.Configuration.from_json(settings)) + if 'chain' not in settings: settings['chain'] = [] @@ -254,13 +211,6 @@ def from_config(cls, settings: dict) -> 'Profile': ) ) - defaults = {} - for k in settings: - if not k.startswith('default_'): - continue - argname = k[len('default_'):] - defaults[argname] = settings[k] - type2class = dict() for hop_class in util.get_subclasses(hops.Hop): class_name = hop_class.__name__ @@ -279,10 +229,6 @@ def from_config(cls, settings: dict) -> 'Profile': if tp not in type2class: raise ConfigurationError('Invalid hop type: "%s"' % tp) - for k in defaults: - if util.supports_named_arg(type2class[tp], k): - hop_option.set_default(k, defaults[k]) - remove = [] for k in hop_option: if k == 'dns': @@ -311,8 +257,9 @@ def from_config(cls, settings: dict) -> 'Profile': del hop_option[r] # noinspection PyArgumentList - hop = type2class[tp](**hop_option) - chain.append(hop) + # hop = type2class[tp](**hop_option) + # chain.append(hop) + chain = [hops.Hop.from_json(h) for h in settings['chain']] profile_args['routes'] = settings.get('routes', None) profile_args['mounts'] = [] diff --git a/pallium/sandbox.py b/pallium/sandbox.py index eb2cae7..bd106a5 100644 --- a/pallium/sandbox.py +++ b/pallium/sandbox.py @@ -438,6 +438,10 @@ def _setup_audio(self): self.paths.append(BindMountExternal('/run/user/%d/pulse/' % security.EUID)) def _setup_gui(self, method): + original_display = os.environ.get('DISPLAY', None) + if original_display is not None and original_display.startswith(':'): + original_display = int(original_display[1:]) + if method == 'xpra' or method is True: display_no = xpra.start_xpra() display = ':%d' % display_no @@ -447,7 +451,7 @@ def _setup_gui(self, method): display = os.environ.get('DISPLAY', None) if display is not None and display.startswith(':'): display = int(display[1:]) - self.paths.append(BindMountExternal('/tmp/.X11-unix/X1', '/tmp/.X11-unix/X%d' % display)) + self.paths.append(BindMountExternal('/tmp/.X11-unix/X%s' % original_display, '/tmp/.X11-unix/X%d' % display)) def _etc_virtuser_mounts(self, virtual_user): pwd_struct = pwd.getpwuid(security.RUID) diff --git a/tests/test_api.py b/tests/test_api.py index aa8ca15..1f1689b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,6 +23,8 @@ from pallium.nftables import NFTables, NFPROTO_INET from pallium.profiles import Profile from provision import DigitalOceanProvisioner, Machine +from pallium.config import Configuration +import pallium.config as config class PalliumTestCase(unittest.TestCase): @@ -69,6 +71,14 @@ def wrapper(*args, **kwargs): return wrapper +def config_from_chain(chain): + return config.Configuration( + networking=config.Networking( + chain=chain + ) + ) + + class TestPythonInterface(PalliumTestCase): machines = [] provisioner = None @@ -161,7 +171,7 @@ def run_in_ns(): with NFTables(nfgen_family=NFPROTO_INET) as nft: nft.table('del', name='pyroute_nftables_test') - with Profile([]) as session: + with Profile(config.Configuration()) as session: session.execute(run_in_ns) @fresh_resolvconf @@ -169,7 +179,7 @@ def test_no_chain_no_dns(self): def run_in_ns(): return requests.get('https://1.1.1.1', allow_redirects=False) - with Profile([]) as session: + with Profile(Configuration()) as session: result = session.execute(run_in_ns) assert result.ok @@ -184,7 +194,7 @@ def test_no_chain(self): def run_in_ns(): return self.get_ip() - with Profile([]) as session: + with Profile(Configuration()) as session: result = session.execute(run_in_ns) ipaddress.ip_address(result) assert True @@ -198,7 +208,11 @@ def test_tor_simple(self): def check_tor(): return requests.get('https://check.torproject.org/api/ip').json() - with Profile([TorHop()]) as session: + with Profile(config.Configuration( + networking=config.Networking( + chain=[TorHop()] + ) + )) as session: result = session.execute(check_tor) assert result['IsTor'] @@ -222,7 +236,11 @@ def check_tor(): continue return False - with Profile([TorHop()]) as session: + with Profile(config.Configuration( + networking=config.Networking( + chain=[TorHop()] + ) + )) as session: assert session.execute(check_tor) def _test_ssh_url(self, url, dns=None, allow_redirects=True, return_response=False): @@ -236,8 +254,16 @@ def check_connectivity(): # We are seeing this server for the first time ssh_args = ['-o', 'StrictHostKeyChecking=no'] - with Profile([pallium.hops.ssh.SshHop(machine.get_ssh_destination(), ssh_args=ssh_args, dns=dns)], - quiet=True) as session: + with Profile( + config.Configuration( + networking=config.Networking( + chain=[ + pallium.hops.ssh.SshHop(machine.get_ssh_destination(), ssh_args=ssh_args, dns=dns) + ] + ), + run=config.Run(quiet=True) + ) + ) as session: result = session.execute(check_connectivity) return result @@ -254,20 +280,20 @@ def test_ssh_dns(self): @fresh_resolvconf def test_openvpn(self): - with Profile([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)]) as session: + with Profile(config_from_chain([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)])) as session: result = session.execute(self.get_ip) assert ipaddress.ip_address(result) in self.machines[0].get_public_ips() @fresh_resolvconf def test_openvpn_ipv6(self): - with Profile([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)]) as session: + with Profile(config_from_chain([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)])) as session: result = session.execute(self.get_ipv6) assert ipaddress.ip_address(result) in self.machines[0].get_public_ips() @fresh_resolvconf def test_tor_openvpn_ipv4(self): - with Profile([TorHop(), pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)]) as session: + with Profile(config_from_chain([TorHop(), pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)])) as session: result = session.execute(self.get_ipv4) assert ipaddress.ip_address(result) in self.machines[0].get_public_ips() @@ -276,7 +302,7 @@ def test_tor_openvpn_ipv4(self): def test_socks5_dns_udp(self): machine_ip = self.machines[0].get_public_ips()[0] socks5 = pallium.hops.socks.SocksHop((machine_ip, 1080), 'pmtest', self.password, dns=['1.1.1.1']) - with Profile([socks5]) as session: + with Profile(config_from_chain([socks5])) as session: result = session.execute(self.get_ip) assert ipaddress.ip_address(result) in self.machines[0].get_public_ips() @@ -287,7 +313,7 @@ def check_connectivity(): machine_ip = self.machines[0].get_public_ips()[0] http = pallium.hops.socks.HttpHop((machine_ip, 3128), 'pmtest', self.password) - with Profile([http]) as session: + with Profile(config_from_chain([http])) as session: result = session.execute(check_connectivity) assert result.ok @@ -296,7 +322,7 @@ def test_exception(self): def run(): raise TestException - with Profile([]) as session: + with Profile(config_from_chain([])) as session: try: session.execute(run) assert False @@ -309,27 +335,25 @@ def inside_ns(): with open('/etc/resolv.conf') as f: return f.read() - with Profile([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file, dns=['1.2.3.4'])]) as session: + with Profile(config_from_chain([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file, dns=['1.2.3.4'])])) as session: result = session.execute(inside_ns) assert '1.2.3.4' in result @fresh_resolvconf def test_bridge_ipv4(self): self.require_net_admin() - chain = [pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)] - bridge = pallium.profiles.Bridge(routes=['0.0.0.0/0']) - bridge.dhcp = False - with Profile(chain, bridge=bridge) as session: + conf = config_from_chain([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)]) + conf.networking.bridge = config.Bridge(routes=['0.0.0.0/0'], dhcp=False) + with Profile(conf) as session: result = session.execute(lambda: self.get_ip(4)) assert ipaddress.ip_address(result) in self.machines[0].get_public_ips() @fresh_resolvconf def test_bridge_ipv6(self): self.require_net_admin() - chain = [pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)] - bridge = pallium.profiles.Bridge(routes=['::/0']) - bridge.dhcp = False - with Profile(chain, bridge=bridge) as session: + conf = config_from_chain([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)]) + conf.networking.bridge = config.Bridge(routes=['::/0'], dhcp=False) + with Profile(conf) as session: result = session.execute(lambda: self.get_ip(6)) assert ipaddress.ip_address(result) in self.machines[0].get_public_ips() @@ -346,8 +370,9 @@ def inside_ns(): raised = True assert raised - chain = [pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)] - with Profile(chain, routes=['0.0.0.0/0']) as session: + conf = config_from_chain([pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file)]) + conf.networking.routes = ['0.0.0.0/0'] + with Profile(conf) as session: session.execute(inside_ns) @fresh_resolvconf @@ -369,14 +394,15 @@ def _test_bridge_dhcp(self, test_dns, test_eth_bridge): bridge_name = None else: bridge_name = 'pmtestbri' - bridge = pallium.profiles.Bridge(name=bridge_name) + conf = config_from_chain(chain) + conf.networking.bridge = config.Bridge(name=bridge_name) if test_eth_bridge: with IPRoute() as ip: ip.link('add', ifname='pmtestbri', kind='dummy') - bridge.eth_bridge = pallium.profiles.EthernetBridge(devices=['pmtestbri']) - bridge.dhcp = True + conf.networking.bridge.eth_bridge = config.EthernetBridge(devices=['pmtestbri']) + conf.networking.bridge.dhcp = True try: - with Profile(chain, bridge=bridge): + with Profile(conf): read_from_child, write_to_parent = os.pipe() read_from_parent, write_to_child = os.pipe() pid = os.fork() @@ -437,7 +463,9 @@ def test_openvpn_kill_switch(self): """ my_ip = ipaddress.ip_address(self.get_ipv4()) ovpn_hop = pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file, dns=['1.1.1.1']) - with Profile([ovpn_hop], kill_switch=False) as session: + conf = config_from_chain([ovpn_hop]) + conf.networking.kill_switch = False + with Profile(conf) as session: vpn_ip = ipaddress.ip_address(session.execute(self.get_ipv4)) assert vpn_ip in self.machines[0].get_public_ips() @@ -455,7 +483,9 @@ def test_openvpn_kill_switch(self): assert my_ip == bypass_vpn_ip ovpn_hop = pallium.hops.openvpn.OpenVpnHop(config=self.ovpn_config_file, dns=['1.1.1.1']) - with Profile([ovpn_hop]) as session: + conf2 = config_from_chain([ovpn_hop]) + conf2.networking.kill_switch = True + with Profile(conf2) as session: vpn_ip = ipaddress.ip_address(session.execute(self.get_ipv4)) assert vpn_ip in self.machines[0].get_public_ips() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..16ad42a --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,57 @@ +import dataclasses +import typing +import unittest + +import pallium.config +import pallium.hops.tor + + +class PalliumTestCase(unittest.TestCase): + @staticmethod + def test_json_serializable(): + @pallium.config.json_serializable + @dataclasses.dataclass + class TestClass: + a: int + b: typing.List[int] + + instance = TestClass.from_json({'a': 1, 'b': [1, 2, 3]}) + assert instance.a == 1 + assert instance.b == [1, 2, 3] + + @staticmethod + def test_sandbox_config(): + conf = pallium.config.Configuration.from_json({ + 'sandbox': { + 'gui': True, + 'virtuser': '$tmp' + } + }) + assert isinstance(conf.sandbox, pallium.config.Sandbox) + + @staticmethod + def test_config(): + json = { + 'networking': { + 'chain': [ + { + 'type': 'socks', + 'address': '127.0.0.1:1080', + 'username': 'johndoe', + 'password': 'secret' + } + ] + } + } + config = pallium.config.Configuration.from_json(json) + + assert config.run.quiet is False + assert len(config.networking.chain) == 1 + assert isinstance(config.networking.chain[0], pallium.hops.socks.SocksHop) + assert config.networking.chain[0].username == 'johndoe' + assert config.networking.chain[0].password == 'secret' + assert config.networking.bridge is None + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py index d181c42..c9cd2ef 100644 --- a/tests/test_sandbox.py +++ b/tests/test_sandbox.py @@ -46,8 +46,8 @@ def test_virtuser(self): "virtuser": "johndoe" } } - - assert get_output(profile, 'whoami') == 'johndoe' + whoami_output = get_output(profile, 'whoami') + assert whoami_output == 'johndoe' assert 'johndoe' in get_output(profile, ['sh', '-c', 'echo "$HOME"']) def test_minimal_filesystem(self): diff --git a/tests/test_unprivileged.py b/tests/test_unprivileged.py index 7c890a8..d2dda9b 100644 --- a/tests/test_unprivileged.py +++ b/tests/test_unprivileged.py @@ -10,6 +10,7 @@ import pallium.sysutil from pallium import sysutil, onexit from pallium.profiles import Profile +import pallium.config as config class PalliumTestCase(unittest.TestCase): @@ -23,13 +24,13 @@ def do_GET(self): self.send_response(204) self.end_headers() - profile = Profile.from_config({ + conf = config.Configuration.from_json({ 'sandbox': { 'gui': True, 'virtuser': '$tmp' - }, - 'chain': [] + } }) + profile = Profile(conf) session = profile.run() def run(): @@ -54,6 +55,7 @@ def run(): home = tempfile.mkdtemp(prefix='pallium_cli_test_gui_') os.chown(home, UID, UID)""" os.environ['HOME'] = '/tmp' + # p = subprocess.Popen(['id']) p = subprocess.Popen(['firefox', url]) onexit.register(lambda: os.kill(p.pid, signal.SIGTERM)) rlist, _, _ = select.select([read], [], [], 30)