Skip to content

Commit

Permalink
Tree support and fever memory copies.
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed May 11, 2024
1 parent c5b26f0 commit 7685ecc
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 43 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
cloudpickle
elements
elements>=3.3.0
msgpack
numpy
psutil
pyzmq
80 changes: 74 additions & 6 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

sys.path.append(str(pathlib.Path(__file__).parent.parent))

import elements
import numpy as np
import pytest
import zerofun
Expand Down Expand Up @@ -368,7 +369,8 @@ def function(data):
with server:
client = zerofun.Client(addr, pings=0, maxage=1)
client.connect(retry=False, timeout=1)
futures = [client.function({'foo': [i]}) for i in range(batch)]
futures = [
client.function({'foo': np.asarray([i])}) for i in range(batch)]
results = [future.result()['foo'][0] for future in futures]
assert calls[0] == 1
assert results == list(range(batch))
Expand Down Expand Up @@ -446,14 +448,80 @@ def test_proxy_batched(self, Server, inner_addr, outer_addr, workers):

@pytest.mark.parametrize('Server', SERVERS)
@pytest.mark.parametrize('addr', ADDRESSES)
def test_empty_dict(self, Server, addr):
@pytest.mark.parametrize('data', (
{'a': np.zeros((3, 2), np.float32), 'b': np.ones((1,), np.uint8)},
{'a': 12, 'b': [np.ones((1,), np.uint8), 13]},
{'a': 12, 'b': ['c', [1, 2, 3]]},
[],
{},
12,
[[{}, []]],
))
def test_tree_data(self, Server, addr, data):
data = elements.tree.map(np.asarray, data)
print(data)
def tree_equal(tree1, tree2):
try:
comps = elements.tree.map(lambda x, y: np.all(x == y), tree1, tree2)
comps, _ = elements.tree.flatten(comps)
return all(comps)
except TypeError:
return False
addr = addr.format(port=zerofun.get_free_port())
client = zerofun.Client(addr, pings=0, maxage=1)
server = Server(addr)
def workfn(data):
assert data == {}
return {}
def workfn(indata):
assert tree_equal(indata, data)
return indata
server.bind('function', workfn)
with server:
client.connect(retry=False, timeout=1)
outdata = client.function(data).result()
assert tree_equal(outdata, data)

@pytest.mark.parametrize('Server', SERVERS)
@pytest.mark.parametrize('addr', ADDRESSES)
@pytest.mark.parametrize('data', (
{'a': np.zeros((3, 2), np.float32), 'b': np.ones((1,), np.uint8)},
{'a': 12, 'b': [np.ones((1,), np.uint8), 13]},
{'a': 12, 'b': ['c', [1, 2, 3]]},
[],
{},
12,
[[{}, []]],
))
def test_tree_data_batched(self, Server, addr, data):
data = elements.tree.map(np.asarray, data)
print(data)
def tree_equal(tree1, tree2):
try:
comps = elements.tree.map(lambda x, y: np.all(x == y), tree1, tree2)
comps, _ = elements.tree.flatten(comps)
return all(comps)
except TypeError:
return False
addr = addr.format(port=zerofun.get_free_port())
client = zerofun.Client(addr, pings=0, maxage=1)
server = Server(addr)
def workfn(indata):
return indata
server.bind('function', workfn, batch=4)
with server:
client.connect(retry=False, timeout=1)
futures = [client.function(data) for _ in range(4)]
for future in futures:
assert tree_equal(future.result(), data)

@pytest.mark.parametrize('Server', SERVERS)
@pytest.mark.parametrize('addr', ADDRESSES)
def test_tree_none_result(self, Server, addr):
addr = addr.format(port=zerofun.get_free_port())
client = zerofun.Client(addr, pings=0, maxage=1)
server = Server(addr)
def workfn(indata):
pass # No return value
server.bind('function', workfn)
with server:
client.connect(retry=False, timeout=1)
assert client.function({}).result() == {}
result = client.function([]).result()
assert result == []
2 changes: 1 addition & 1 deletion zerofun/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '1.2.0'
__version__ = '2.0.5'

import multiprocessing as mp
try:
Expand Down
2 changes: 0 additions & 2 deletions zerofun/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def call(self, method, data):
self.queue.popleft().result()
except IndexError:
pass
assert isinstance(data, dict)
data = {k: np.asarray(v) for k, v in data.items()}
data = sockets.pack(data)
rid = self.socket.send_call(method, data)
self.send_per_sec.step(1)
Expand Down
34 changes: 18 additions & 16 deletions zerofun/proc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ def __init__(
self.name = name
self.ipv6 = ipv6
self.server = server.Server(self.inner, name, ipv6, workers, errors)
self.batches = {}
self.batchsizes = {}
self.batcher = None

def bind(self, name, workfn, logfn=None, workers=0, batch=0):
self.batches[name] = batch
self.batchsizes[name] = batch
self.server.bind(name, workfn, logfn, workers, batch=0)

def start(self):
self.batcher = process.StoppableProcess(
self._batcher, self.address, self.inner,
self.batches, self.name, self.ipv6, name='batcher', start=True)
self.batchsizes, self.name, self.ipv6, name='batcher', start=True)
self.server.start()

def check(self):
Expand Down Expand Up @@ -59,13 +59,13 @@ def __exit__(self, type, value, traceback):
self.close()

@staticmethod
def _batcher(context, address, inner, batches, name, ipv6):
def _batcher(context, address, inner, batchsizes, name, ipv6):

socket = sockets.ServerSocket(address, ipv6)
inbound = sockets.ClientSocket(identity=0, pings=0, maxage=0)
inbound.connect(inner, timeout=120)
queues = collections.defaultdict(list)
buffers = collections.defaultdict(dict)
buffers = {}
pending = {}
elements.print(f'[{name}] Listening at {address}')

Expand All @@ -74,7 +74,7 @@ def _batcher(context, address, inner, batches, name, ipv6):
result = socket.receive()
if result:
addr, rid, name, payload = result
batch = batches.get(name, None)
batch = batchsizes.get(name, None)
if batch is not None:
if batch:
queue = queues[name]
Expand All @@ -83,12 +83,14 @@ def _batcher(context, address, inner, batches, name, ipv6):
addrs, rids, payloads = zip(*queue)
queue.clear()
datas = [sockets.unpack(x) for x in payloads]
idx = range(batch)
bufs = buffers[name]
for key, value in datas[0].items():
bufs[key] = np.stack(
[datas[i][key] for i in idx], out=bufs.get(key, None))
payload = sockets.pack(bufs)
if name not in buffers:
buffers[name] = buffer = elements.tree.map(
lambda *xs: np.stack(xs), *datas)
else:
buffers[name] = buffer = elements.tree.map(
lambda buf, *xs: np.stack(xs, out=buf),
buffers[name], *datas)
payload = sockets.pack(buffer)
rid = inbound.send_call(name, payload)
pending[rid] = (name, addrs, rids)
else:
Expand All @@ -102,12 +104,12 @@ def _batcher(context, address, inner, batches, name, ipv6):
if result:
inner_rid, payload = result
name, addr, rid = pending.pop(inner_rid)
if batches[name]:
if batchsizes[name]:
addrs, rids = addr, rid
result = sockets.unpack(payload)
results = [
{k: v[i] for k, v in result.items()}
for i in range(batches[name])]
elements.tree.map(lambda x: x[i], result)
for i in range(batchsizes[name])]
payloads = [sockets.pack(x) for x in results]
for addr, rid, payload in zip(addrs, rids, payloads):
socket.send_result(addr, rid, payload)
Expand All @@ -116,7 +118,7 @@ def _batcher(context, address, inner, batches, name, ipv6):
except sockets.RemoteError as e:
inner_rid, msg = e.args[:2]
name, addr, rid = pending.pop(inner_rid)
if batches[name]:
if batchsizes[name]:
addrs, rids = addr, rid
for addr, rid in zip(addrs, rids):
socket.send_error(addr, rid, msg)
Expand Down
10 changes: 5 additions & 5 deletions zerofun/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,20 +181,20 @@ def _handle_dones(self):
def _work(self, method, addr, rid, payload, recvd):
if method.batched:
data = [sockets.unpack(x) for x in payload]
data = {
k: np.stack([data[i][k] for i in range(method.insize)])
for k, v in data[0].items()}
data = elements.tree.map(lambda *xs: np.stack(xs), *data)
else:
data = sockets.unpack(payload)
if method.donefn:
result, logs = method.workfn(data)
else:
result = method.workfn(data)
result = result or {}
if result is None:
result = []
logs = None
if method.batched:
results = [
{k: v[i] for k, v in result.items()} for i in range(method.insize)]
elements.tree.map(lambda x: x[i], result)
for i in range(method.insize)]
payload = [sockets.pack(x) for x in results]
else:
payload = sockets.pack(result)
Expand Down
28 changes: 17 additions & 11 deletions zerofun/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import threading
import time

import elements
import numpy as np
import zmq


DEBUG = False
# DEBUG = True

Expand Down Expand Up @@ -203,7 +205,8 @@ def send_ping(self, addr):

def send_result(self, addr, rid, payload):
with self.lock:
self.socket.send_multipart([addr, Type.RESULT.value, rid, *payload])
self.socket.send_multipart(
[addr, Type.RESULT.value, rid, *payload], copy=False, track=True)

def send_error(self, addr, rid, text):
text = text.encode('utf-8')
Expand All @@ -216,23 +219,26 @@ def close(self):


def pack(data):
data = {k: np.asarray(v) for k, v in data.items()}
leaves, structure = elements.tree.flatten(data)
dtypes, shapes, buffers = [], [], []
items = sorted(data.items(), key=lambda x: x[0])
keys, vals = zip(*items) if items else ((), ())
dtypes = [v.dtype.str for v in vals]
shapes = [v.shape for v in vals]
buffers = [v.tobytes() for v in vals]
meta = (keys, dtypes, shapes)
for value in leaves:
value = np.asarray(value)
assert value.data.c_contiguous, (
"Array is not contiguous in memory. Use np.asarray(arr, order='C') " +
"before passing the data into pack().")
dtypes.append(value.dtype.str)
shapes.append(value.shape)
buffers.append(value.data)
meta = (structure, dtypes, shapes)
payload = [msgpack.packb(meta), *buffers]
return payload


def unpack(payload):
meta, *buffers = payload
keys, dtypes, shapes = msgpack.unpackb(meta)
vals = [
structure, dtypes, shapes = msgpack.unpackb(meta)
leaves = [
np.frombuffer(b, d).reshape(s)
for i, (d, s, b) in enumerate(zip(dtypes, shapes, buffers))]
data = dict(zip(keys, vals))
data = elements.tree.unflatten(leaves, structure)
return data
1 change: 0 additions & 1 deletion zerofun/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(self, fn, *args, name=None, start=False):
self._exitcode = None
self.exception = None
name = name or fn.__name__
self.old_name = name[:]
self.thread = threading.Thread(
target=self._wrapper, args=args, name=name, daemon=True)
self.started = False
Expand Down

0 comments on commit 7685ecc

Please sign in to comment.