From 55652317e9c0a87778ecaaf0cda7bb0aa9c4c0ec Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Sun, 15 Sep 2024 00:03:05 +0000 Subject: [PATCH] Allow string ports for resolver --- portal/__init__.py | 2 +- portal/client_socket.py | 25 +++++++++++++------------ tests/test_client.py | 17 +++++++++++++++++ 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/portal/__init__.py b/portal/__init__.py index 0cce755..5820ef7 100644 --- a/portal/__init__.py +++ b/portal/__init__.py @@ -1,4 +1,4 @@ -__version__ = '3.2.2' +__version__ = '3.2.3' import multiprocessing as mp try: diff --git a/portal/client_socket.py b/portal/client_socket.py index 3f3ddf1..116b525 100644 --- a/portal/client_socket.py +++ b/portal/client_socket.py @@ -34,11 +34,11 @@ class Options: class ClientSocket: def __init__(self, host, port=None, name='Client', start=True, **kwargs): - assert (port or ':' in host) and '://' not in host, host + assert port or ':' in host, (host, port) + assert '://' not in host, (host, port) if port is None: host, port = host.rsplit(':', 1) - port = int(port) - assert host, host + assert host and port, (host, port) self.addr = (host, port) self.name = name self.options = Options(**{**contextlib.context.clientkw, **kwargs}) @@ -179,16 +179,19 @@ def _loop(self): sock.close() def _connect(self): - host, port = self.addr - self._log(f'Connecting to {host}:{port}') + self._log(f'Connecting to {self.addr[0]}:{self.addr[1]}') once = True while self.running: - sock, addr = self._create() + # We need to resolve the address regularly. + host, port = self.addr + if contextlib.context.resolver: + host, port = contextlib.context.resolver((host, port)) + assert isinstance(host, str), (host, port) + assert isinstance(port, int), (host, port) + addr = (host, port, 0, 0) if self.options.ipv6 else (host, port) + sock = self._create() error = None try: - # We need to resolve the address regularly. - if contextlib.context.resolver: - addr = contextlib.context.resolver(addr) sock.settimeout(10) sock.connect(addr) sock.settimeout(0) @@ -210,10 +213,8 @@ def _connect(self): def _create(self): if self.options.ipv6: sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - addr = (*self.addr, 0, 0) else: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - addr = self.addr # sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) after = self.options.keepalive_after every = self.options.keepalive_every @@ -231,7 +232,7 @@ def _create(self): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, every) if sys.platform == 'win32': sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, after * 1000, every * 1000)) - return sock, addr + return sock def _log(self, *args): if not self.options.logging: diff --git a/tests/test_client.py b/tests/test_client.py index e6cdf6f..764a6f7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -321,3 +321,20 @@ def client(): portal.Thread(server), portal.Thread(client), ]) + + def test_resolver(self): + portnum = portal.free_port() + + def client(portnum): + def resolver(host, portstr): + assert portstr == 'name' + return host, portnum + portal.setup(resolver=resolver) + client = portal.Client('localhost:name') + assert client.fn(42).result() == 42 + + server = portal.Server(portnum) + server.bind('fn', lambda x: x) + server.start(block=False) + portal.Process(client, portnum, start=True).join() + server.close()