Skip to content

Commit

Permalink
Flush error response before raise server error
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 12, 2024
1 parent 9bebe67 commit 242ebaa
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
2 changes: 1 addition & 1 deletion portal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '3.2.0'
__version__ = '3.2.1'

import multiprocessing as mp
try:
Expand Down
3 changes: 3 additions & 0 deletions portal/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,7 @@ def _error(self, addr, reqnum, status, message):
data = message.encode('utf-8')
self.socket.send(addr, reqnum, status, data)
if self.errors:
# Wait until the error is delivered to the client and then raise.
self.socket.shutdown()
self.socket.close()
raise RuntimeError(message)
1 change: 0 additions & 1 deletion portal/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class Thread:
"""

def __init__(self, fn, *args, name=None, start=False):
global TIDS
self.fn = fn
self.excode = None
name = name or getattr(fn, '__name__', 'thread')
Expand Down
25 changes: 23 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ def test_unknown_method(self, Server):
def test_server_errors(self, Server):
port = portal.free_port()

server = Server(port, errors=False)
def fn(x):
if x == 2:
raise ValueError(x)
return x
server.bind('fn', fn)
server.start(block=False)

client = portal.Client('localhost', port)
assert client.fn(1).result() == 1
with pytest.raises(RuntimeError):
client.fn(2).result()
assert client.fn(3).result() == 3

client.close()
server.close()

@pytest.mark.parametrize('Server', SERVERS)
def test_server_errors_raise(self, Server):
port = portal.free_port()

def server(port):
server = Server(port, errors=True)
def fn(x):
Expand All @@ -108,8 +129,8 @@ def fn(x):
client = portal.Client('localhost', port)
assert client.fn(1).result() == 1
assert server.running
with pytest.raises(RuntimeError):
client.fn(2).result()
with pytest.raises((RuntimeError, TimeoutError)):
client.fn(2).result(timeout=3)

client.close()
server.join()
Expand Down

0 comments on commit 242ebaa

Please sign in to comment.