From 933d1b0c5fe6d181e064eabb86a285e94528dfd3 Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Wed, 18 Sep 2024 08:19:25 +0000 Subject: [PATCH] Fix rare server hang under postfn --- perf/server_latency.py | 2 +- perf/server_throughput.py | 2 +- perf/socket_latency.py | 2 +- perf/socket_proxy.py | 9 +++--- perf/socket_throughput.py | 2 +- portal/__init__.py | 2 +- portal/server.py | 66 ++++++++++++++++++++++++--------------- tests/test_server.py | 19 +++++++++++ 8 files changed, 68 insertions(+), 36 deletions(-) diff --git a/perf/server_latency.py b/perf/server_latency.py index 533bb5a..2b11d7e 100644 --- a/perf/server_latency.py +++ b/perf/server_latency.py @@ -18,7 +18,7 @@ def fn(x): def client(port): data = bytearray(size) - client = portal.Client('localhost', port) + client = portal.Client(port) futures = collections.deque() durations = collections.deque(maxlen=50) while True: diff --git a/perf/server_throughput.py b/perf/server_throughput.py index 2023ee7..95086f0 100644 --- a/perf/server_throughput.py +++ b/perf/server_throughput.py @@ -23,7 +23,7 @@ def fn(x): def client(port): data = bytearray(size) - client = portal.Client('localhost', port, maxinflight=prefetch + 1) + client = portal.Client(port, maxinflight=prefetch + 1) futures = collections.deque() for _ in range(prefetch): futures.append(client.call('foo', data)) diff --git a/perf/socket_latency.py b/perf/socket_latency.py index 8aa4647..b201d1e 100644 --- a/perf/socket_latency.py +++ b/perf/socket_latency.py @@ -18,7 +18,7 @@ def server(port): def client(port): data = [bytearray(size // parts) for _ in range(parts)] - client = portal.ClientSocket('localhost', port) + client = portal.ClientSocket(port) durations = collections.deque(maxlen=10) while True: start = time.perf_counter() diff --git a/perf/socket_proxy.py b/perf/socket_proxy.py index 6a54856..b7cd080 100644 --- a/perf/socket_proxy.py +++ b/perf/socket_proxy.py @@ -1,5 +1,4 @@ import collections -import queue import time import portal @@ -28,24 +27,24 @@ def server(port1): def proxy(port1, port2): server = portal.ServerSocket(port2) - client = portal.ClientSocket('localhost', port1) + client = portal.ClientSocket(port1) addrs = collections.deque() while True: try: addr, data = server.recv(timeout=0.0001) addrs.append(addr) client.send(data) - except queue.Empty: + except TimeoutError: pass try: data = client.recv(timeout=0.0001) server.send(addrs.popleft(), data) - except queue.Empty: + except TimeoutError: pass def client(port2): data = [bytearray(size // parts) for _ in range(parts)] - client = portal.ClientSocket('localhost', port2) + client = portal.ClientSocket(port2) for _ in range(prefetch): client.send(*data) while True: diff --git a/perf/socket_throughput.py b/perf/socket_throughput.py index b80a38e..ead5758 100644 --- a/perf/socket_throughput.py +++ b/perf/socket_throughput.py @@ -24,7 +24,7 @@ def server(port): def client(port): data = [bytearray(size // parts) for _ in range(parts)] - client = portal.ClientSocket('localhost', port) + client = portal.ClientSocket(port) for _ in range(prefetch): client.send(*data) durations = collections.deque(maxlen=50) diff --git a/portal/__init__.py b/portal/__init__.py index 3734bea..b820cba 100644 --- a/portal/__init__.py +++ b/portal/__init__.py @@ -1,4 +1,4 @@ -__version__ = '3.4.1' +__version__ = '3.4.2' import multiprocessing as mp try: diff --git a/portal/server.py b/portal/server.py index 45a977c..82400ca 100644 --- a/portal/server.py +++ b/portal/server.py @@ -1,7 +1,7 @@ import collections import concurrent.futures -import threading import time +import types from . import packlib from . import poollib @@ -34,11 +34,11 @@ def bind(self, name, workfn, postfn=None, workers=0): self.pools.append(pool) else: pool = self.pool - active = threading.Semaphore((workers or self.workers) + 1) - def workfn2(*args): - active.acquire() - return workfn(*args) - self.methods[name] = (workfn2, postfn, pool, active) + requests = collections.deque() + available = (workers or self.workers) + 1 + self.methods[name] = types.SimpleNamespace( + workfn=workfn, postfn=postfn, pool=pool, + requests=requests, available=available) def start(self, block=True): assert not self.running @@ -67,9 +67,10 @@ def stats(self): 'numrecv': mets['recv'], 'sendrate': mets['send'] / dur, 'recvrate': mets['recv'] / dur, + 'requests': sum(len(m.requests) for m in self.methods.values()), 'jobs': len(self.jobs), } - if any(postfn for _, postfn, _, _ in self.methods.values()): + if any(method.postfn for method in self.methods.values()): stats.update({ 'post_iqueue': len(self.postfn_inp), 'post_oqueue': len(self.postfn_out), @@ -84,7 +85,9 @@ def __exit__(self, *e): self.close() def _loop(self): - while self.running or self.jobs or self.postfn_inp or self.postfn_out: + methods = list(self.methods.values()) + pending = 0 + while self.running or pending: while True: # Loop syntax used to break on error. if not self.running: # Do not accept further requests. break @@ -107,24 +110,30 @@ def _loop(self): self._error(addr, reqnum, 3, f'Unknown method {name}') break self.metrics['recv'] += 1 - workfn, postfn, pool, active = self.methods[name] - job = pool.submit(workfn, *data) - job.active = active - job.postfn = postfn - job.addr = addr - job.reqnum = reqnum - self.jobs.add(job) - if postfn: - self.postfn_inp.append(job) + method = self.methods[name] + method.requests.append((addr, reqnum, data)) + pending += 1 break # We do not actually want to loop. + + for method in methods: + if method.requests and method.available: + method.available -= 1 + addr, reqnum, data = method.requests.popleft() + job = method.pool.submit(method.workfn, *data) + job.method = method + job.addr = addr + job.reqnum = reqnum + self.jobs.add(job) + if method.postfn: + self.postfn_inp.append(job) + completed, self.jobs = concurrent.futures.wait( self.jobs, 0.0001, concurrent.futures.FIRST_COMPLETED) for job in completed: try: data = job.result() - if job.postfn: - data, info = data - del info + if job.method.postfn: + data, _ = data data = packlib.pack(data) status = int(0).to_bytes(8, 'little', signed=False) self.socket.send(job.addr, job.reqnum, status, *data) @@ -132,19 +141,24 @@ def _loop(self): except Exception as e: self._error(job.addr, job.reqnum, 4, f'Error in server method: {e}') finally: - if not job.postfn: - job.active.release() + if not job.method.postfn: + job.method.available += 1 + pending -= 1 + if completed: + # Call postfns in the order the requests were received. while self.postfn_inp and self.postfn_inp[0].done(): job = self.postfn_inp.popleft() - data, info = job.result() - postjob = self.postfn_pool.submit(job.postfn, info) - postjob.active = job.active + _, info = job.result() + postjob = self.postfn_pool.submit(job.method.postfn, info) + postjob.method = job.method self.postfn_out.append(postjob) + while self.postfn_out and self.postfn_out[0].done(): postjob = self.postfn_out.popleft() - postjob.active.release() postjob.result() # Check if there was an error. + postjob.method.available += 1 + pending -= 1 def _error(self, addr, reqnum, status, message): status = status.to_bytes(8, 'little', signed=False) diff --git a/tests/test_server.py b/tests/test_server.py index 4b83b82..c83b385 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -165,6 +165,25 @@ def postfn(x): assert completed != list(range(10)) assert logged == list(range(10)) + @pytest.mark.parametrize('repeat', range(10)) + @pytest.mark.parametrize('Server', SERVERS) + def test_postfn_no_hang(self, repeat, Server): + def wrapper(): + port = portal.free_port() + def workfn(x): + return x, x + def postfn(x): + time.sleep(0.01) + server = Server(port, workers=4) + server.bind('fn', workfn, postfn) + server.start(block=False) + client = portal.Client(port) + futures = [client.fn(i) for i in range(20)] + [future.result() for future in futures] # Used to hang here. + client.close() + server.close() + assert portal.Thread(wrapper, start=True).join(timeout=10).exitcode == 0 + @pytest.mark.parametrize('repeat', range(5)) @pytest.mark.parametrize('Server', SERVERS) @pytest.mark.parametrize('workers', (1, 4))