diff --git a/portal/__init__.py b/portal/__init__.py index 0453662..8755129 100644 --- a/portal/__init__.py +++ b/portal/__init__.py @@ -1,4 +1,4 @@ -__version__ = '3.1.6' +__version__ = '3.1.7' import multiprocessing as mp try: diff --git a/portal/client.py b/portal/client.py index 64408ff..646d55d 100644 --- a/portal/client.py +++ b/portal/client.py @@ -81,10 +81,18 @@ def call(self, method, *data): strlen = len(name).to_bytes(8, 'little', signed=False) sendargs = (reqnum, strlen, name, *packlib.pack(data)) # self.socket.send(reqnum, strlen, name, *packlib.pack(data)) - self.socket.send(*sendargs) - future = Future() + rai = [False] + future = Future(rai) future.sendargs = sendargs self.futures[reqnum] = future + # Store future before sending request because the response may come fast + # and the response handler runs in the socket's background thread. + try: + self.socket.send(*sendargs) + except client_socket.Disconnected: + future = self.futures.pop(reqnum) + future.rai[0] = True + raise return future def close(self, timeout=None): @@ -113,10 +121,10 @@ def _recv(self, data): def _disc(self): if self.socket.options.autoconn: - for future in self.futures.values(): + for future in list(self.futures.values()): future.resend = True else: - for future in self.futures.values(): + for future in list(self.futures.values()): self._seterr(future, client_socket.Disconnected) self.futures.clear() @@ -127,17 +135,17 @@ def _conn(self): self.socket.send(*future.sendargs) def _seterr(self, future, e): - raised = [False] - future.raised = raised future.set_error(e) + rai = future.rai weakref.finalize(future, lambda: ( - None if raised[0] else self.errors.append(e))) + None if rai[0] else self.errors.append(e))) class Future: - def __init__(self): - self.raised = [False] + def __init__(self, rai): + assert rai == [False] + self.rai = rai self.con = threading.Condition() self.don = False self.res = None @@ -147,7 +155,7 @@ def __repr__(self): if not self.done: return 'Future(done=False)' elif self.err: - return f"Future(done=True, error='{self.err}', raised={self.raised[0]})" + return f"Future(done=True, error='{self.err}', raised={self.rai[0]})" else: return 'Future(done=True)' @@ -165,8 +173,8 @@ def result(self, timeout=None): assert self.don if self.err is None: return self.res - if not self.raised[0]: - self.raised[0] = True + if not self.rai[0]: + self.rai[0] = True raise self.err def set_result(self, result): diff --git a/portal/client_socket.py b/portal/client_socket.py index 3c94827..047f05f 100644 --- a/portal/client_socket.py +++ b/portal/client_socket.py @@ -110,6 +110,7 @@ def require_connection(self, timeout): def _loop(self): recvbuf = buffers.RecvBuffer(maxsize=self.options.max_msg_size) sock = None + poll = select.poll() isconn = False # Local mirror of self.isconn without the lock. while self.running or (self.sendq and isconn): @@ -120,6 +121,7 @@ def _loop(self): sock = self._connect() if not sock: break + poll.register(sock, select.POLLIN | select.POLLOUT) self.isconn.set() isconn = True if not self.options.autoconn: @@ -128,9 +130,15 @@ def _loop(self): try: - readable, writable, _ = select.select([sock], [sock], [], 0.2) + # TODO: According to the py-spy profiler, the GIL is held during + # polling. Is there a way to avoid that? + pairs = poll.poll(0.2) + if not pairs: + continue + _, mask = pairs[0] + - if readable: + if mask & select.POLLIN: try: recvbuf.recv(sock) if recvbuf.done(): @@ -143,7 +151,7 @@ def _loop(self): except BlockingIOError: pass - if self.sendq and writable: + if self.sendq and mask & select.POLLOUT: try: self.sendq[0].send(sock) if self.sendq[0].done(): @@ -157,6 +165,7 @@ def _loop(self): self._log(f'Connection to server lost ({detail})') self.isconn.clear() isconn = False + poll.unregister(sock) sock.close() # Clear message queue on disconnect. There is no meaningful concept of # sucessful delivery of a message at this level. For example, the diff --git a/portal/contextlib.py b/portal/contextlib.py index 1aa7028..508a5da 100644 --- a/portal/contextlib.py +++ b/portal/contextlib.py @@ -2,6 +2,7 @@ import multiprocessing as mp import os import pathlib +import sys import threading import traceback @@ -114,10 +115,10 @@ def error(self, e, name=None): with self.printlock: style = utils.style(color='red') reset = utils.style(reset=True) - print(style + '\n---\n' + message + reset) + print(style + '\n---\n' + message + reset, file=sys.stderr) if self.errfile: self.errfile.write_text(message) - print(f'Wrote errorfile: {self.errfile}') + print(f'Wrote errorfile: {self.errfile}', file=sys.stderr) def shutdown(self, exitcode): # This kills the process tree forcefully to prevent hangs but results in @@ -145,7 +146,8 @@ def get_children(self, ident=None): def _watcher(self): while True: if self.errfile and self.errfile.exists(): - print(f'Shutting down due to error file: {self.errfile}') + message = f'Shutting down due to error file: {self.errfile}', + print(message, file=sys.stderr) self.shutdown(2) break if self.done.wait(self.interval): diff --git a/portal/server_socket.py b/portal/server_socket.py index d4f7734..b2e5d72 100644 --- a/portal/server_socket.py +++ b/portal/server_socket.py @@ -101,6 +101,8 @@ def _loop(self): try: while self.running or self._numsending(): writeable = [] + # TODO: According to the py-spy profiler, the GIL is held during + # polling. Is there a way to avoid that? for key, mask in self.sel.select(timeout=0.2): if key.data is None and self.reading: assert mask & selectors.EVENT_READ diff --git a/tests/test_client.py b/tests/test_client.py index 6eee12f..8a306c4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,8 +48,22 @@ def test_manual_reconnect(self, repeat): client.connect() assert client.fn(1).result() == 1 server.close() - with pytest.raises(portal.Disconnected): - client.fn(2).result() + + assert len(client.futures) == 0 + assert len(client.errors) == 0 + try: + future = client.fn(2) + try: + future.result() + assert False + except portal.Disconnected: + assert True + except portal.Disconnected: + time.sleep(1) + future = None + assert len(client.futures) == 0 + assert len(client.errors) == 0 + server = portal.Server(port) server.bind('fn', lambda x: x) server.start(block=False)