From 53a14f2c6ccee39beab965e81a0d8ed806340b02 Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Fri, 10 May 2024 22:29:31 +0000 Subject: [PATCH] Tree support and fever memory copies. --- requirements.txt | 3 ++- tests/test_server.py | 44 +++++++++++++++++++++++++++++++++++++++----- zerofun/__init__.py | 2 +- zerofun/client.py | 2 -- zerofun/server.py | 3 ++- zerofun/sockets.py | 28 +++++++++++++++++----------- zerofun/thread.py | 1 - 7 files changed, 61 insertions(+), 22 deletions(-) diff --git a/requirements.txt b/requirements.txt index e0c832e..3d3f8c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ cloudpickle -elements +elements>=3.3.0 msgpack numpy +psutil pyzmq diff --git a/tests/test_server.py b/tests/test_server.py index 008cc35..291c7aa 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ sys.path.append(str(pathlib.Path(__file__).parent.parent)) +import elements import numpy as np import pytest import zerofun @@ -446,14 +447,47 @@ def test_proxy_batched(self, Server, inner_addr, outer_addr, workers): @pytest.mark.parametrize('Server', SERVERS) @pytest.mark.parametrize('addr', ADDRESSES) - def test_empty_dict(self, Server, addr): + @pytest.mark.parametrize('data', ( + {'a': np.zeros((3, 2), np.float32), 'b': np.ones((1,), np.uint8)}, + {'a': 12, 'b': [np.ones((1,), np.uint8), 13]}, + {'a': 12, 'b': ['c', [1, 2, 3]]}, + [], + {}, + 12, + [[{}, []]], + )) + def test_tree_data(self, Server, addr, data): + data = elements.tree.map(np.asarray, data) + print(data) + def tree_equal(tree1, tree2): + try: + comps = elements.tree.map(lambda x, y: np.all(x == y), tree1, tree2) + comps, _ = elements.tree.flatten(comps) + return all(comps) + except TypeError: + return False addr = addr.format(port=zerofun.get_free_port()) client = zerofun.Client(addr, pings=0, maxage=1) server = Server(addr) - def workfn(data): - assert data == {} - return {} + def workfn(indata): + assert tree_equal(indata, data) + return indata + server.bind('function', workfn) + with server: + client.connect(retry=False, timeout=1) + outdata = client.function(data).result() + assert tree_equal(outdata, data) + + @pytest.mark.parametrize('Server', SERVERS) + @pytest.mark.parametrize('addr', ADDRESSES) + def test_tree_none_result(self, Server, addr): + addr = addr.format(port=zerofun.get_free_port()) + client = zerofun.Client(addr, pings=0, maxage=1) + server = Server(addr) + def workfn(indata): + pass # No return value server.bind('function', workfn) with server: client.connect(retry=False, timeout=1) - assert client.function({}).result() == {} + result = client.function([]).result() + assert result == [] diff --git a/zerofun/__init__.py b/zerofun/__init__.py index 3092d20..185f47a 100644 --- a/zerofun/__init__.py +++ b/zerofun/__init__.py @@ -1,4 +1,4 @@ -__version__ = '1.2.0' +__version__ = '2.0.3' import multiprocessing as mp try: diff --git a/zerofun/client.py b/zerofun/client.py index 21d42f6..0511f9e 100644 --- a/zerofun/client.py +++ b/zerofun/client.py @@ -86,8 +86,6 @@ def call(self, method, data): self.queue.popleft().result() except IndexError: pass - assert isinstance(data, dict) - data = {k: np.asarray(v) for k, v in data.items()} data = sockets.pack(data) rid = self.socket.send_call(method, data) self.send_per_sec.step(1) diff --git a/zerofun/server.py b/zerofun/server.py index fb05d26..95f6eab 100644 --- a/zerofun/server.py +++ b/zerofun/server.py @@ -190,7 +190,8 @@ def _work(self, method, addr, rid, payload, recvd): result, logs = method.workfn(data) else: result = method.workfn(data) - result = result or {} + if result is None: + result = [] logs = None if method.batched: results = [ diff --git a/zerofun/sockets.py b/zerofun/sockets.py index b37f7b4..e373af3 100644 --- a/zerofun/sockets.py +++ b/zerofun/sockets.py @@ -4,9 +4,11 @@ import threading import time +import elements import numpy as np import zmq + DEBUG = False # DEBUG = True @@ -203,7 +205,8 @@ def send_ping(self, addr): def send_result(self, addr, rid, payload): with self.lock: - self.socket.send_multipart([addr, Type.RESULT.value, rid, *payload]) + self.socket.send_multipart( + [addr, Type.RESULT.value, rid, *payload], copy=False, track=True) def send_error(self, addr, rid, text): text = text.encode('utf-8') @@ -216,23 +219,26 @@ def close(self): def pack(data): - data = {k: np.asarray(v) for k, v in data.items()} + leaves, structure = elements.tree.flatten(data) dtypes, shapes, buffers = [], [], [] - items = sorted(data.items(), key=lambda x: x[0]) - keys, vals = zip(*items) if items else ((), ()) - dtypes = [v.dtype.str for v in vals] - shapes = [v.shape for v in vals] - buffers = [v.tobytes() for v in vals] - meta = (keys, dtypes, shapes) + for value in leaves: + value = np.asarray(value) + assert value.data.c_contiguous, ( + "Array is not contiguous in memory. Use np.asarray(arr, order='C') " + + "before passing the data into pack().") + dtypes.append(value.dtype.str) + shapes.append(value.shape) + buffers.append(value.data) + meta = (structure, dtypes, shapes) payload = [msgpack.packb(meta), *buffers] return payload def unpack(payload): meta, *buffers = payload - keys, dtypes, shapes = msgpack.unpackb(meta) - vals = [ + structure, dtypes, shapes = msgpack.unpackb(meta) + leaves = [ np.frombuffer(b, d).reshape(s) for i, (d, s, b) in enumerate(zip(dtypes, shapes, buffers))] - data = dict(zip(keys, vals)) + data = elements.tree.unflatten(leaves, structure) return data diff --git a/zerofun/thread.py b/zerofun/thread.py index 94007bb..b10f8da 100644 --- a/zerofun/thread.py +++ b/zerofun/thread.py @@ -11,7 +11,6 @@ def __init__(self, fn, *args, name=None, start=False): self._exitcode = None self.exception = None name = name or fn.__name__ - self.old_name = name[:] self.thread = threading.Thread( target=self._wrapper, args=args, name=name, daemon=True) self.started = False