Skip to content

Commit

Permalink
Fix rare server hang under postfn
Browse files (browse the repository at this point in the history)
  • Loading branch information
danijar committed Sep 18, 2024
1 parent 4af847b commit f23e1d2
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 35 deletions.
2 changes: 1 addition & 1 deletion perf/server_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def fn(x):

def client(port):
data = bytearray(size)
client = portal.Client('localhost', port)
client = portal.Client(port)
futures = collections.deque()
durations = collections.deque(maxlen=50)
while True:
Expand Down
2 changes: 1 addition & 1 deletion perf/server_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fn(x):

def client(port):
data = bytearray(size)
client = portal.Client('localhost', port, maxinflight=prefetch + 1)
client = portal.Client(port, maxinflight=prefetch + 1)
futures = collections.deque()
for _ in range(prefetch):
futures.append(client.call('foo', data))
Expand Down
2 changes: 1 addition & 1 deletion perf/socket_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def server(port):

def client(port):
data = [bytearray(size // parts) for _ in range(parts)]
client = portal.ClientSocket('localhost', port)
client = portal.ClientSocket(port)
durations = collections.deque(maxlen=10)
while True:
start = time.perf_counter()
Expand Down
9 changes: 4 additions & 5 deletions perf/socket_proxy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import collections
import queue
import time

import portal
Expand Down Expand Up @@ -28,24 +27,24 @@ def server(port1):

def proxy(port1, port2):
server = portal.ServerSocket(port2)
client = portal.ClientSocket('localhost', port1)
client = portal.ClientSocket(port1)
addrs = collections.deque()
while True:
try:
addr, data = server.recv(timeout=0.0001)
addrs.append(addr)
client.send(data)
except queue.Empty:
except TimeoutError:
pass
try:
data = client.recv(timeout=0.0001)
server.send(addrs.popleft(), data)
except queue.Empty:
except TimeoutError:
pass

def client(port2):
data = [bytearray(size // parts) for _ in range(parts)]
client = portal.ClientSocket('localhost', port2)
client = portal.ClientSocket(port2)
for _ in range(prefetch):
client.send(*data)
while True:
Expand Down
2 changes: 1 addition & 1 deletion perf/socket_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def server(port):

def client(port):
data = [bytearray(size // parts) for _ in range(parts)]
client = portal.ClientSocket('localhost', port)
client = portal.ClientSocket(port)
for _ in range(prefetch):
client.send(*data)
durations = collections.deque(maxlen=50)
Expand Down
2 changes: 1 addition & 1 deletion portal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '3.4.1'
__version__ = '3.4.2'

import multiprocessing as mp
try:
Expand Down
60 changes: 35 additions & 25 deletions portal/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import concurrent.futures
import threading
import time
import types

from . import packlib
from . import poollib
Expand Down Expand Up @@ -34,11 +34,11 @@ def bind(self, name, workfn, postfn=None, workers=0):
self.pools.append(pool)
else:
pool = self.pool
active = threading.Semaphore((workers or self.workers) + 1)
def workfn2(*args):
active.acquire()
return workfn(*args)
self.methods[name] = (workfn2, postfn, pool, active)
requests = collections.deque()
available = (workers or self.workers) + 1
self.methods[name] = types.SimpleNamespace(
workfn=workfn, postfn=postfn, pool=pool,
requests=requests, available=available)

def start(self, block=True):
assert not self.running
Expand Down Expand Up @@ -67,9 +67,10 @@ def stats(self):
'numrecv': mets['recv'],
'sendrate': mets['send'] / dur,
'recvrate': mets['recv'] / dur,
'requests': sum(len(m.requests) for m in self.methods.values()),
'jobs': len(self.jobs),
}
if any(postfn for _, postfn, _, _ in self.methods.values()):
if any(method.postfn for method in self.methods.values()):
stats.update({
'post_iqueue': len(self.postfn_inp),
'post_oqueue': len(self.postfn_out),
Expand All @@ -84,6 +85,7 @@ def __exit__(self, *e):
self.close()

def _loop(self):
methods = list(self.methods.values())
while self.running or self.jobs or self.postfn_inp or self.postfn_out:
while True: # Loop syntax used to break on error.
if not self.running: # Do not accept further requests.
Expand All @@ -107,44 +109,52 @@ def _loop(self):
self._error(addr, reqnum, 3, f'Unknown method {name}')
break
self.metrics['recv'] += 1
workfn, postfn, pool, active = self.methods[name]
job = pool.submit(workfn, *data)
job.active = active
job.postfn = postfn
job.addr = addr
job.reqnum = reqnum
self.jobs.add(job)
if postfn:
self.postfn_inp.append(job)
method = self.methods[name]
method.requests.append((addr, reqnum, data))
break # We do not actually want to loop.

for method in methods:
if method.requests and method.available:
method.available -= 1
addr, reqnum, data = method.requests.popleft()
job = method.pool.submit(method.workfn, *data)
job.method = method
job.addr = addr
job.reqnum = reqnum
self.jobs.add(job)
if method.postfn:
self.postfn_inp.append(job)

completed, self.jobs = concurrent.futures.wait(
self.jobs, 0.0001, concurrent.futures.FIRST_COMPLETED)
for job in completed:
try:
data = job.result()
if job.postfn:
data, info = data
del info
if job.method.postfn:
data, _ = data
data = packlib.pack(data)
status = int(0).to_bytes(8, 'little', signed=False)
self.socket.send(job.addr, job.reqnum, status, *data)
self.metrics['send'] += 1
except Exception as e:
self._error(job.addr, job.reqnum, 4, f'Error in server method: {e}')
finally:
if not job.postfn:
job.active.release()
if not job.method.postfn:
job.method.available += 1

if completed:
# Call postfns in the order the requests were received.
while self.postfn_inp and self.postfn_inp[0].done():
job = self.postfn_inp.popleft()
data, info = job.result()
postjob = self.postfn_pool.submit(job.postfn, info)
postjob.active = job.active
_, info = job.result()
postjob = self.postfn_pool.submit(job.method.postfn, info)
postjob.method = job.method
self.postfn_out.append(postjob)

while self.postfn_out and self.postfn_out[0].done():
postjob = self.postfn_out.popleft()
postjob.active.release()
postjob.result() # Check if there was an error.
postjob.method.available += 1

def _error(self, addr, reqnum, status, message):
status = status.to_bytes(8, 'little', signed=False)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,25 @@ def postfn(x):
assert completed != list(range(10))
assert logged == list(range(10))

@pytest.mark.parametrize('repeat', range(10))
@pytest.mark.parametrize('Server', SERVERS)
def test_postfn_no_hang(self, repeat, Server):
  """Regression test: a server bound with a slow postfn must not hang.

  Sends 20 requests to a server created with `workers=4` whose postfn
  sleeps briefly, then waits for every response. The whole exchange runs
  in a separate `portal.Thread` so that a hang surfaces as a failed
  assertion after the 10 second join timeout rather than blocking the
  test run.

  Args:
    repeat: Unused by the body; the parametrization only reruns the test
      several times because the hang was rare.
    Server: Server class under test, taken from `SERVERS`.
  """
  def wrapper():
    port = portal.free_port()
    def workfn(x):
      # With a postfn bound, the work function returns a pair: the first
      # item is sent back to the client, the second is passed to postfn.
      return x, x
    def postfn(x):
      # Deliberately slow so that postfn calls lag behind the work calls.
      time.sleep(0.01)
    server = Server(port, workers=4)
    server.bind('fn', workfn, postfn)
    server.start(block=False)
    # Always shut down the client and server, even when a request fails,
    # so that a failing run does not leak them into later tests.
    try:
      client = portal.Client(port)
      try:
        futures = [client.fn(i) for i in range(20)]
        for future in futures:
          future.result()  # Used to hang here.
      finally:
        client.close()
    finally:
      server.close()
  assert portal.Thread(wrapper, start=True).join(timeout=10).exitcode == 0

@pytest.mark.parametrize('repeat', range(5))
@pytest.mark.parametrize('Server', SERVERS)
@pytest.mark.parametrize('workers', (1, 4))
Expand Down

0 comments on commit f23e1d2

Please sign in to comment.