From dda08bbbd17887a3bb61c035ad23d921f8eb0afa Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Sun, 8 Sep 2024 01:49:33 +0000 Subject: [PATCH] Catch rare connection timeout in server recv() --- portal/__init__.py | 2 +- portal/server_socket.py | 13 +++++++++---- tests/test_batching.py | 3 ++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/portal/__init__.py b/portal/__init__.py index 30c847a..54a08d9 100644 --- a/portal/__init__.py +++ b/portal/__init__.py @@ -1,4 +1,4 @@ -__version__ = '3.1.9' +__version__ = '3.1.10' import multiprocessing as mp try: diff --git a/portal/server_socket.py b/portal/server_socket.py index b2e5d72..ceb45dc 100644 --- a/portal/server_socket.py +++ b/portal/server_socket.py @@ -141,8 +141,11 @@ def _recv(self, conn): conn.recvbuf = buffers.RecvBuffer(maxsize=self.options.max_msg_size) try: conn.recvbuf.recv(conn.sock) - except ConnectionResetError: - self._disconnect(conn) + except OSError as e: + # For example: + # - ConnectionResetError + # - TimeoutError: [Errno 110] Connection timed out + self._disconnect(conn, e) return if conn.recvbuf.done(): if self.recvq.qsize() > self.options.max_recv_queue: @@ -150,8 +153,10 @@ def _recv(self, conn): self.recvq.put((conn.addr, conn.recvbuf.result())) conn.recvbuf = None - def _disconnect(self, conn): - self._log(f'Closed connection to {conn.addr[0]}:{conn.addr[1]}') + def _disconnect(self, conn, e): + detail = f'{type(e).__name__}' + detail = f'{detail}: {e}' if str(e) else detail + self._log(f'Closed connection to {conn.addr[0]}:{conn.addr[1]} ({detail})') conn = self.conns.pop(conn.addr) if conn.sendbufs: count = len(conn.sendbufs) diff --git a/tests/test_batching.py b/tests/test_batching.py index 3b27f43..0751c08 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -110,7 +110,8 @@ def test_tree(self, data): client.close() server.close() - def test_shape_mismatch(self): + @pytest.mark.parametrize('repeat', range(5)) + def test_shape_mismatch(self, repeat): port = portal.free_port() server = portal.BatchServer(port, errors=False) server.bind('fn', lambda x: x, batch=2)