Skip to content

Commit

Permalink
More flexible address specifications
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 17, 2024
1 parent 642dce6 commit f7a26d0
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 87 deletions.
2 changes: 1 addition & 1 deletion portal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '3.3.0'
__version__ = '3.4.0'

import multiprocessing as mp
try:
Expand Down
2 changes: 1 addition & 1 deletion portal/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def send_error(addr, reqnum, status, message):

try:
outer = server_socket.ServerSocket(outer_port, f'{name}Server', **kwargs)
inner = client.Client('localhost', inner_port, f'{name}Client', **kwargs)
inner = client.Client(inner_port, f'{name}Client', **kwargs)
batches = {} # {method: ([addr], [reqnum], structure, [array])}
jobs = []
shutdown = False
Expand Down
6 changes: 2 additions & 4 deletions portal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

class Client:

def __init__(
self, host, port=None, name='Client', maxinflight=16, **kwargs):
def __init__(self, addr, name='Client', maxinflight=16, **kwargs):
assert 1 <= maxinflight, maxinflight
self.maxinflight = maxinflight
self.reqnum = iter(itertools.count(0))
Expand All @@ -25,8 +24,7 @@ def __init__(
self.lock = threading.Lock()
# Socket is created after the above attributes because the callbacks access
# some of the attributes.
self.socket = client_socket.ClientSocket(
host, port, name, start=False, **kwargs)
self.socket = client_socket.ClientSocket(addr, name, start=False, **kwargs)
self.socket.callbacks_recv.append(self._recv)
self.socket.callbacks_disc.append(self._disc)
self.socket.callbacks_conn.append(self._conn)
Expand Down
46 changes: 25 additions & 21 deletions portal/client_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ class Options:
keepalive_fails: int = 10
logging: bool = True
logging_color: str = 'yellow'
connect_wait: float = 0.1


class ClientSocket:

def __init__(self, host, port=None, name='Client', start=True, **kwargs):
assert port or ':' in host, (host, port)
assert '://' not in host, (host, port)
if port is None:
host, port = host.rsplit(':', 1)
assert host and port, (host, port)
self.addr = (host, port)
def __init__(self, addr, name='Client', start=True, **kwargs):
addr = str(addr)
assert '://' not in addr, addr
host, port = addr.rsplit(':', 1) if ':' in addr else ('', addr)
self.name = name
self.options = Options(**{**contextlib.context.clientkw, **kwargs})
host = host or ('::1' if self.options.ipv6 else '127.0.0.1')
self.addr = (host, port)

self.callbacks_recv = []
self.callbacks_conn = []
Expand Down Expand Up @@ -190,6 +190,7 @@ def _connect(self):
port = int(port)
addr = (host, port, 0, 0) if self.options.ipv6 else (host, port)
sock = self._create()
start = time.time()
error = None
try:
sock.settimeout(10)
Expand All @@ -207,7 +208,7 @@ def _connect(self):
self._log(f'Still trying to connect... ({error})')
once = False
sock.close()
time.sleep(0.1)
time.sleep(self.options.connect_wait)
return None

def _create(self):
Expand All @@ -216,22 +217,25 @@ def _create(self):
else:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

after = self.options.keepalive_after
every = self.options.keepalive_every
fails = self.options.keepalive_fails
if sys.platform == 'linux':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, after)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, every)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, fails)
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT,
1000 * (after + every * fails))
if sys.platform == 'darwin':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, every)
if sys.platform == 'win32':
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, after * 1000, every * 1000))
if after and every and fails:
if sys.platform == 'linux':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, after)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, every)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, fails)
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT,
1000 * (after + every * fails))
if sys.platform == 'darwin':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, every)
if sys.platform == 'win32':
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, after * 1000, every * 1000))

return sock

def _log(self, *args):
Expand Down
10 changes: 5 additions & 5 deletions portal/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ def start(self, block=True):
if block:
self.loop.join(timeout=None)

def close(self, timeout=None):
def close(self, timeout=None, internal=False):
assert self.running
self.socket.shutdown()
self.running = False
self.loop.join(timeout)
self.loop.kill()
if not internal:
self.loop.join(timeout)
self.loop.kill()
[x.close() for x in self.pools]
self.socket.close()

Expand Down Expand Up @@ -149,6 +150,5 @@ def _error(self, addr, reqnum, status, message):
self.socket.send(addr, reqnum, status, data)
if self.errors:
# Wait until the error is delivered to the client and then raise.
self.socket.shutdown()
self.socket.close()
self.close(internal=True)
raise RuntimeError(message)
3 changes: 2 additions & 1 deletion portal/server_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class ServerSocket:

def __init__(self, port, name='Server', **kwargs):
if isinstance(port, str):
port = int(port.rsplit(':', 1)[1])
assert '://' not in port, port
port = int(port.rsplit(':', 1)[-1])
self.name = name
self.options = Options(**{**contextlib.context.serverkw, **kwargs})
if self.options.ipv6:
Expand Down
20 changes: 10 additions & 10 deletions tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fn(x):
return 2 * x
server.bind('fn', fn, batch=4)
server.start(block=False)
client = portal.Client('localhost', port)
client = portal.Client(port)
futures = [client.fn(x) for x in range(8)]
results = [x.result() for x in futures]
assert (results == 2 * np.arange(8)).all()
Expand All @@ -38,7 +38,7 @@ def fn(x):
return 2 * x
server.bind('fn', fn, batch=4)
server.start(block=False)
clients = [portal.Client('localhost', port) for _ in range(8)]
clients = [portal.Client(port) for _ in range(8)]
futures = [x.fn(i) for i, x in enumerate(clients)]
results = [x.result() for x in futures]
assert (results == 2 * np.arange(8)).all()
Expand All @@ -53,7 +53,7 @@ def fn(x):
return 2 * x
server.bind('fn', fn, workers=4, batch=4)
server.start(block=False)
clients = [portal.Client('localhost', port) for _ in range(32)]
clients = [portal.Client(port) for _ in range(32)]
futures = [x.fn(i) for i, x in enumerate(clients)]
results = [x.result() for x in futures]
assert (results == 2 * np.arange(32)).all()
Expand All @@ -70,14 +70,14 @@ def test_proxy(self, workers):
server.start(block=False)

kwargs = dict(name='ProxyClient', maxinflight=4)
proxy_client = portal.Client('localhost', inner_port, **kwargs)
proxy_client = portal.Client(inner_port, **kwargs)
proxy_server = portal.BatchServer(
outer_port, 'ProxyServer', workers=workers)
proxy_server.bind(
'fn2', lambda x: proxy_client.fn(x).result(), batch=2)
proxy_server.start(block=False)

client = portal.Client('localhost', outer_port, 'OuterClient')
client = portal.Client(outer_port, 'OuterClient')
futures = [client.fn2(x) for x in range(16)]
results = [future.result() for future in futures]
assert (results == 2 * np.arange(16)).all()
Expand All @@ -102,7 +102,7 @@ def test_tree(self, data):
server = portal.BatchServer(port)
server.bind('fn', lambda x: x, batch=4)
server.start(block=False)
client = portal.Client('localhost', port)
client = portal.Client(port)
futures = [client.fn(data) for _ in range(4)]
results = [x.result() for x in futures]
for result in results:
Expand All @@ -116,7 +116,7 @@ def test_shape_mismatch(self, repeat):
server = portal.BatchServer(port, errors=False)
server.bind('fn', lambda x: x, batch=2)
server.start(block=False)
client = portal.Client('localhost', port)
client = portal.Client(port)
future1 = client.fn({'a': np.array(12)})
future2 = client.fn(42)
with pytest.raises(RuntimeError):
Expand All @@ -133,12 +133,12 @@ def test_client_drops(self, repeat):
server = portal.BatchServer(port)
server.bind('fn', lambda x: 2 * x, batch=4)
server.start(block=False)
client = portal.Client('localhost', port, name='Client1', autoconn=False)
client = portal.Client(port, 'Client1', autoconn=False)
client.connect()
future1 = client.fn(1)
future2 = client.fn(2)
client.close()
client = portal.Client('localhost', port, name='Client2')
client = portal.Client(port, 'Client2')
future3 = client.fn(3)
future4 = client.fn(4)
with pytest.raises(portal.Disconnected):
Expand All @@ -156,7 +156,7 @@ def test_server_drops(self, repeat):
server = portal.BatchServer(port)
server.bind('fn', lambda x: 2 * x, batch=2)
server.start(block=False)
client = portal.Client('localhost', port, autoconn=False)
client = portal.Client(port, autoconn=False)
client.connect()
future1 = client.fn(1)
server.close()
Expand Down
48 changes: 33 additions & 15 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def fn():
pass
server.bind('fn', fn)
server.start(block=False)
client = portal.Client('localhost', port)
client = portal.Client(port)
assert client.fn().result() is None
client.close()
server.close()

def test_manual_connect(self):
port = portal.free_port()
client = portal.Client('localhost', port, autoconn=False)
client = portal.Client(port, autoconn=False)
assert not client.connected
server = portal.Server(port)
server.bind('fn', lambda x: x)
Expand All @@ -44,7 +44,7 @@ def test_manual_reconnect(self, repeat):
server = portal.Server(port)
server.bind('fn', lambda x: x)
server.start(block=False)
client = portal.Client('localhost', port, autoconn=False)
client = portal.Client(port, autoconn=False)
client.connect()
assert client.fn(1).result() == 1
server.close()
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_connect_before_server(self):
results = []

def client():
client = portal.Client('localhost', port)
client = portal.Client(port)
results.append(client.fn(12).result())
client.close()

Expand All @@ -95,7 +95,7 @@ def test_future_order(self):
server = portal.Server(port)
server.bind('fn', lambda x: x)
server.start(block=False)
client = portal.Client('localhost', port)
client = portal.Client(port)
future1 = client.fn(1)
future2 = client.fn(2)
future3 = client.fn(3)
Expand All @@ -113,7 +113,7 @@ def fn(x):
return x
server.bind('fn', fn)
server.start(block=False)
client = portal.Client('localhost', port)
client = portal.Client(port)
future = client.fn(42)
with pytest.raises(TimeoutError):
future.result(timeout=0)
Expand Down Expand Up @@ -142,7 +142,7 @@ def fn(data):
server.bind('fn', fn, workers=4)
server.start(block=False)

client = portal.Client('localhost', port, maxinflight=2)
client = portal.Client(port, maxinflight=2)
futures = [client.fn(i) for i in range(16)]
results = [x.result() for x in futures]
assert results == list(range(16))
Expand All @@ -155,7 +155,7 @@ def test_future_cleanup(self, repeat):
server = portal.Server(port)
server.bind('fn', lambda x: x)
server.start(block=False)
client = portal.Client('localhost', port, maxinflight=1)
client = portal.Client(port, maxinflight=1)
client.fn(1)
client.fn(2)
future3 = client.fn(3)
Expand All @@ -175,7 +175,7 @@ def fn(x):
return x
server.bind('fn', fn)
server.start(block=False)
client = portal.Client('localhost', port, maxinflight=1)
client = portal.Client(port, maxinflight=1)
client.fn(1)
client.fn(2)
time.sleep(0.2)
Expand All @@ -191,7 +191,7 @@ def test_client_threadsafe(self, repeat, users=16):
server = portal.Server(port)
server.bind('fn', lambda x: x, workers=4)
server.start(block=False)
client = portal.Client('localhost', port, maxinflight=8)
client = portal.Client(port, maxinflight=8)
barrier = threading.Barrier(users)

def user():
Expand Down Expand Up @@ -228,7 +228,7 @@ def fn(x):
server.close()

def client():
client = portal.Client('localhost', port, maxinflight=2)
client = portal.Client(port, maxinflight=2)
futures = [client.fn(x) for x in range(5)]
results = [x.result() for x in futures]
assert results == list(range(5))
Expand Down Expand Up @@ -265,8 +265,7 @@ def server():
server.close()

def client():
client = portal.Client(
'localhost', port, maxinflight=1, autoconn=True)
client = portal.Client(port, maxinflight=1, autoconn=True)
assert client.fn(1).result() == 1
a.wait()
b.wait()
Expand Down Expand Up @@ -304,8 +303,7 @@ def server():
server.close()

def client():
client = portal.Client(
'localhost', port, maxinflight=1, autoconn=False)
client = portal.Client(port, maxinflight=1, autoconn=False)
client.connect()
assert client.fn(1).result() == 1
a.wait()
Expand All @@ -322,6 +320,26 @@ def client():
portal.Thread(client),
])

@pytest.mark.parametrize('ipv6', (False, True))
@pytest.mark.parametrize('fmt,typ', (
('{port}', int),
('{port}', str),
(':{port}', str),
('localhost:{port}', str),
('{localhost}:{port}', str),
))
def test_address_formats(self, fmt, typ, ipv6):
port = portal.free_port()
server = portal.Server(port, ipv6=ipv6)
server.bind('fn', lambda x: x)
server.start(block=False)
localhost = '::1' if ipv6 else '127.0.0.1'
addr = typ(fmt.format(port=port, localhost=localhost))
client = portal.Client(addr, ipv6=ipv6)
assert client.fn(42).result() == 42
client.close()
server.close()

def test_resolver(self):
portnum = portal.free_port()

Expand Down
Loading

0 comments on commit f7a26d0

Please sign in to comment.