Skip to content

Commit

Permalink
Fix race condition between storing futures and server response
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 5, 2024
1 parent 72c1945 commit a833bcb
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 21 deletions.
2 changes: 1 addition & 1 deletion portal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '3.1.6'
__version__ = '3.1.8'

import multiprocessing as mp
try:
Expand Down
32 changes: 20 additions & 12 deletions portal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,18 @@ def call(self, method, *data):
strlen = len(name).to_bytes(8, 'little', signed=False)
sendargs = (reqnum, strlen, name, *packlib.pack(data))
# self.socket.send(reqnum, strlen, name, *packlib.pack(data))
self.socket.send(*sendargs)
future = Future()
rai = [False]
future = Future(rai)
future.sendargs = sendargs
self.futures[reqnum] = future
# Store future before sending request because the response may come fast
# and the response handler runs in the socket's background thread.
try:
self.socket.send(*sendargs)
except client_socket.Disconnected:
future = self.futures.pop(reqnum)
future.rai[0] = True
raise
return future

def close(self, timeout=None):
Expand Down Expand Up @@ -113,10 +121,10 @@ def _recv(self, data):

def _disc(self):
if self.socket.options.autoconn:
for future in self.futures.values():
for future in list(self.futures.values()):
future.resend = True
else:
for future in self.futures.values():
for future in list(self.futures.values()):
self._seterr(future, client_socket.Disconnected)
self.futures.clear()

Expand All @@ -127,17 +135,17 @@ def _conn(self):
self.socket.send(*future.sendargs)

def _seterr(self, future, e):
raised = [False]
future.raised = raised
future.set_error(e)
rai = future.rai
weakref.finalize(future, lambda: (
None if raised[0] else self.errors.append(e)))
None if rai[0] else self.errors.append(e)))


class Future:

def __init__(self):
self.raised = [False]
def __init__(self, rai):
assert rai == [False]
self.rai = rai
self.con = threading.Condition()
self.don = False
self.res = None
Expand All @@ -147,7 +155,7 @@ def __repr__(self):
if not self.done:
return 'Future(done=False)'
elif self.err:
return f"Future(done=True, error='{self.err}', raised={self.raised[0]})"
return f"Future(done=True, error='{self.err}', raised={self.rai[0]})"
else:
return 'Future(done=True)'

Expand All @@ -165,8 +173,8 @@ def result(self, timeout=None):
assert self.don
if self.err is None:
return self.res
if not self.raised[0]:
self.raised[0] = True
if not self.rai[0]:
self.rai[0] = True
raise self.err

def set_result(self, result):
Expand Down
15 changes: 12 additions & 3 deletions portal/client_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def require_connection(self, timeout):
def _loop(self):
recvbuf = buffers.RecvBuffer(maxsize=self.options.max_msg_size)
sock = None
poll = select.poll()
isconn = False # Local mirror of self.isconn without the lock.

while self.running or (self.sendq and isconn):
Expand All @@ -120,6 +121,7 @@ def _loop(self):
sock = self._connect()
if not sock:
break
poll.register(sock, select.POLLIN | select.POLLOUT)
self.isconn.set()
isconn = True
if not self.options.autoconn:
Expand All @@ -128,9 +130,15 @@ def _loop(self):

try:

readable, writable, _ = select.select([sock], [sock], [], 0.2)
# TODO: According to the py-spy profiler, the GIL is held during
# polling. Is there a way to avoid that?
pairs = poll.poll(0.2)
if not pairs:
continue
_, mask = pairs[0]


if readable:
if mask & select.POLLIN:
try:
recvbuf.recv(sock)
if recvbuf.done():
Expand All @@ -143,7 +151,7 @@ def _loop(self):
except BlockingIOError:
pass

if self.sendq and writable:
if self.sendq and mask & select.POLLOUT:
try:
self.sendq[0].send(sock)
if self.sendq[0].done():
Expand All @@ -157,6 +165,7 @@ def _loop(self):
self._log(f'Connection to server lost ({detail})')
self.isconn.clear()
isconn = False
poll.unregister(sock)
sock.close()
# Clear message queue on disconnect. There is no meaningful concept of
# sucessful delivery of a message at this level. For example, the
Expand Down
8 changes: 5 additions & 3 deletions portal/contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import multiprocessing as mp
import os
import pathlib
import sys
import threading
import traceback

Expand Down Expand Up @@ -114,10 +115,10 @@ def error(self, e, name=None):
with self.printlock:
style = utils.style(color='red')
reset = utils.style(reset=True)
print(style + '\n---\n' + message + reset)
print(style + '\n---\n' + message + reset, file=sys.stderr)
if self.errfile:
self.errfile.write_text(message)
print(f'Wrote errorfile: {self.errfile}')
print(f'Wrote errorfile: {self.errfile}', file=sys.stderr)

def shutdown(self, exitcode):
# This kills the process tree forcefully to prevent hangs but results in
Expand Down Expand Up @@ -145,7 +146,8 @@ def get_children(self, ident=None):
def _watcher(self):
while True:
if self.errfile and self.errfile.exists():
print(f'Shutting down due to error file: {self.errfile}')
message = f'Shutting down due to error file: {self.errfile}',
print(message, file=sys.stderr)
self.shutdown(2)
break
if self.done.wait(self.interval):
Expand Down
2 changes: 2 additions & 0 deletions portal/server_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def _loop(self):
try:
while self.running or self._numsending():
writeable = []
# TODO: According to the py-spy profiler, the GIL is held during
# polling. Is there a way to avoid that?
for key, mask in self.sel.select(timeout=0.2):
if key.data is None and self.reading:
assert mask & selectors.EVENT_READ
Expand Down
18 changes: 16 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,22 @@ def test_manual_reconnect(self, repeat):
client.connect()
assert client.fn(1).result() == 1
server.close()
with pytest.raises(portal.Disconnected):
client.fn(2).result()

assert len(client.futures) == 0
assert len(client.errors) == 0
try:
future = client.fn(2)
try:
future.result()
assert False
except portal.Disconnected:
assert True
except portal.Disconnected:
time.sleep(1)
future = None
assert len(client.futures) == 0
assert len(client.errors) == 0

server = portal.Server(port)
server.bind('fn', lambda x: x)
server.start(block=False)
Expand Down

0 comments on commit a833bcb

Please sign in to comment.