From 7685ecc90bea53eebde2fefed953a4c35f8ac32e Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Fri, 10 May 2024 22:29:31 +0000 Subject: [PATCH] Tree support and fever memory copies. --- requirements.txt | 3 +- tests/test_server.py | 80 ++++++++++++++++++++++++++++++++++++++---- zerofun/__init__.py | 2 +- zerofun/client.py | 2 -- zerofun/proc_server.py | 34 +++++++++--------- zerofun/server.py | 10 +++--- zerofun/sockets.py | 28 +++++++++------ zerofun/thread.py | 1 - 8 files changed, 117 insertions(+), 43 deletions(-) diff --git a/requirements.txt b/requirements.txt index e0c832e..3d3f8c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ cloudpickle -elements +elements>=3.3.0 msgpack numpy +psutil pyzmq diff --git a/tests/test_server.py b/tests/test_server.py index 008cc35..e6f4964 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ sys.path.append(str(pathlib.Path(__file__).parent.parent)) +import elements import numpy as np import pytest import zerofun @@ -368,7 +369,8 @@ def function(data): with server: client = zerofun.Client(addr, pings=0, maxage=1) client.connect(retry=False, timeout=1) - futures = [client.function({'foo': [i]}) for i in range(batch)] + futures = [ + client.function({'foo': np.asarray([i])}) for i in range(batch)] results = [future.result()['foo'][0] for future in futures] assert calls[0] == 1 assert results == list(range(batch)) @@ -446,14 +448,80 @@ def test_proxy_batched(self, Server, inner_addr, outer_addr, workers): @pytest.mark.parametrize('Server', SERVERS) @pytest.mark.parametrize('addr', ADDRESSES) - def test_empty_dict(self, Server, addr): + @pytest.mark.parametrize('data', ( + {'a': np.zeros((3, 2), np.float32), 'b': np.ones((1,), np.uint8)}, + {'a': 12, 'b': [np.ones((1,), np.uint8), 13]}, + {'a': 12, 'b': ['c', [1, 2, 3]]}, + [], + {}, + 12, + [[{}, []]], + )) + def test_tree_data(self, Server, addr, data): + data = elements.tree.map(np.asarray, data) + print(data) + def tree_equal(tree1, tree2): + try: + comps = elements.tree.map(lambda x, y: np.all(x == y), tree1, tree2) + comps, _ = elements.tree.flatten(comps) + return all(comps) + except TypeError: + return False addr = addr.format(port=zerofun.get_free_port()) client = zerofun.Client(addr, pings=0, maxage=1) server = Server(addr) - def workfn(data): - assert data == {} - return {} + def workfn(indata): + assert tree_equal(indata, data) + return indata + server.bind('function', workfn) + with server: + client.connect(retry=False, timeout=1) + outdata = client.function(data).result() + assert tree_equal(outdata, data) + + @pytest.mark.parametrize('Server', SERVERS) + @pytest.mark.parametrize('addr', ADDRESSES) + @pytest.mark.parametrize('data', ( + {'a': np.zeros((3, 2), np.float32), 'b': np.ones((1,), np.uint8)}, + {'a': 12, 'b': [np.ones((1,), np.uint8), 13]}, + {'a': 12, 'b': ['c', [1, 2, 3]]}, + [], + {}, + 12, + [[{}, []]], + )) + def test_tree_data_batched(self, Server, addr, data): + data = elements.tree.map(np.asarray, data) + print(data) + def tree_equal(tree1, tree2): + try: + comps = elements.tree.map(lambda x, y: np.all(x == y), tree1, tree2) + comps, _ = elements.tree.flatten(comps) + return all(comps) + except TypeError: + return False + addr = addr.format(port=zerofun.get_free_port()) + client = zerofun.Client(addr, pings=0, maxage=1) + server = Server(addr) + def workfn(indata): + return indata + server.bind('function', workfn, batch=4) + with server: + client.connect(retry=False, timeout=1) + futures = [client.function(data) for _ in range(4)] + for future in futures: + assert tree_equal(future.result(), data) + + @pytest.mark.parametrize('Server', SERVERS) + @pytest.mark.parametrize('addr', ADDRESSES) + def test_tree_none_result(self, Server, addr): + addr = addr.format(port=zerofun.get_free_port()) + client = zerofun.Client(addr, pings=0, maxage=1) + server = Server(addr) + def workfn(indata): + pass # No return value server.bind('function', workfn) with server: client.connect(retry=False, timeout=1) - assert client.function({}).result() == {} + result = client.function([]).result() + assert result == [] diff --git a/zerofun/__init__.py b/zerofun/__init__.py index 3092d20..62c4e93 100644 --- a/zerofun/__init__.py +++ b/zerofun/__init__.py @@ -1,4 +1,4 @@ -__version__ = '1.2.0' +__version__ = '2.0.5' import multiprocessing as mp try: diff --git a/zerofun/client.py b/zerofun/client.py index 21d42f6..0511f9e 100644 --- a/zerofun/client.py +++ b/zerofun/client.py @@ -86,8 +86,6 @@ def call(self, method, data): self.queue.popleft().result() except IndexError: pass - assert isinstance(data, dict) - data = {k: np.asarray(v) for k, v in data.items()} data = sockets.pack(data) rid = self.socket.send_call(method, data) self.send_per_sec.step(1) diff --git a/zerofun/proc_server.py b/zerofun/proc_server.py index 30f5693..1200c9b 100644 --- a/zerofun/proc_server.py +++ b/zerofun/proc_server.py @@ -18,17 +18,17 @@ def __init__( self.name = name self.ipv6 = ipv6 self.server = server.Server(self.inner, name, ipv6, workers, errors) - self.batches = {} + self.batchsizes = {} self.batcher = None def bind(self, name, workfn, logfn=None, workers=0, batch=0): - self.batches[name] = batch + self.batchsizes[name] = batch self.server.bind(name, workfn, logfn, workers, batch=0) def start(self): self.batcher = process.StoppableProcess( self._batcher, self.address, self.inner, - self.batches, self.name, self.ipv6, name='batcher', start=True) + self.batchsizes, self.name, self.ipv6, name='batcher', start=True) self.server.start() def check(self): @@ -59,13 +59,13 @@ def __exit__(self, type, value, traceback): self.close() @staticmethod - def _batcher(context, address, inner, batches, name, ipv6): + def _batcher(context, address, inner, batchsizes, name, ipv6): socket = sockets.ServerSocket(address, ipv6) inbound = sockets.ClientSocket(identity=0, pings=0, maxage=0) inbound.connect(inner, timeout=120) queues = collections.defaultdict(list) - buffers = collections.defaultdict(dict) + buffers = {} pending = {} elements.print(f'[{name}] Listening at {address}') @@ -74,7 +74,7 @@ def _batcher(context, address, inner, batches, name, ipv6): result = socket.receive() if result: addr, rid, name, payload = result - batch = batches.get(name, None) + batch = batchsizes.get(name, None) if batch is not None: if batch: queue = queues[name] @@ -83,12 +83,14 @@ def _batcher(context, address, inner, batches, name, ipv6): addrs, rids, payloads = zip(*queue) queue.clear() datas = [sockets.unpack(x) for x in payloads] - idx = range(batch) - bufs = buffers[name] - for key, value in datas[0].items(): - bufs[key] = np.stack( - [datas[i][key] for i in idx], out=bufs.get(key, None)) - payload = sockets.pack(bufs) + if name not in buffers: + buffers[name] = buffer = elements.tree.map( + lambda *xs: np.stack(xs), *datas) + else: + buffers[name] = buffer = elements.tree.map( + lambda buf, *xs: np.stack(xs, out=buf), + buffers[name], *datas) + payload = sockets.pack(buffer) rid = inbound.send_call(name, payload) pending[rid] = (name, addrs, rids) else: @@ -102,12 +104,12 @@ def _batcher(context, address, inner, batches, name, ipv6): if result: inner_rid, payload = result name, addr, rid = pending.pop(inner_rid) - if batches[name]: + if batchsizes[name]: addrs, rids = addr, rid result = sockets.unpack(payload) results = [ - {k: v[i] for k, v in result.items()} - for i in range(batches[name])] + elements.tree.map(lambda x: x[i], result) + for i in range(batchsizes[name])] payloads = [sockets.pack(x) for x in results] for addr, rid, payload in zip(addrs, rids, payloads): socket.send_result(addr, rid, payload) @@ -116,7 +118,7 @@ def _batcher(context, address, inner, batches, name, ipv6): except sockets.RemoteError as e: inner_rid, msg = e.args[:2] name, addr, rid = pending.pop(inner_rid) - if batches[name]: + if batchsizes[name]: addrs, rids = addr, rid for addr, rid in zip(addrs, rids): socket.send_error(addr, rid, msg) diff --git a/zerofun/server.py b/zerofun/server.py index fb05d26..39f68ed 100644 --- a/zerofun/server.py +++ b/zerofun/server.py @@ -181,20 +181,20 @@ def _handle_dones(self): def _work(self, method, addr, rid, payload, recvd): if method.batched: data = [sockets.unpack(x) for x in payload] - data = { - k: np.stack([data[i][k] for i in range(method.insize)]) - for k, v in data[0].items()} + data = elements.tree.map(lambda *xs: np.stack(xs), *data) else: data = sockets.unpack(payload) if method.donefn: result, logs = method.workfn(data) else: result = method.workfn(data) - result = result or {} + if result is None: + result = [] logs = None if method.batched: results = [ - {k: v[i] for k, v in result.items()} for i in range(method.insize)] + elements.tree.map(lambda x: x[i], result) + for i in range(method.insize)] payload = [sockets.pack(x) for x in results] else: payload = sockets.pack(result) diff --git a/zerofun/sockets.py b/zerofun/sockets.py index b37f7b4..e373af3 100644 --- a/zerofun/sockets.py +++ b/zerofun/sockets.py @@ -4,9 +4,11 @@ import threading import time +import elements import numpy as np import zmq + DEBUG = False # DEBUG = True @@ -203,7 +205,8 @@ def send_ping(self, addr): def send_result(self, addr, rid, payload): with self.lock: - self.socket.send_multipart([addr, Type.RESULT.value, rid, *payload]) + self.socket.send_multipart( + [addr, Type.RESULT.value, rid, *payload], copy=False, track=True) def send_error(self, addr, rid, text): text = text.encode('utf-8') @@ -216,23 +219,26 @@ def close(self): def pack(data): - data = {k: np.asarray(v) for k, v in data.items()} + leaves, structure = elements.tree.flatten(data) dtypes, shapes, buffers = [], [], [] - items = sorted(data.items(), key=lambda x: x[0]) - keys, vals = zip(*items) if items else ((), ()) - dtypes = [v.dtype.str for v in vals] - shapes = [v.shape for v in vals] - buffers = [v.tobytes() for v in vals] - meta = (keys, dtypes, shapes) + for value in leaves: + value = np.asarray(value) + assert value.data.c_contiguous, ( + "Array is not contiguous in memory. Use np.asarray(arr, order='C') " + + "before passing the data into pack().") + dtypes.append(value.dtype.str) + shapes.append(value.shape) + buffers.append(value.data) + meta = (structure, dtypes, shapes) payload = [msgpack.packb(meta), *buffers] return payload def unpack(payload): meta, *buffers = payload - keys, dtypes, shapes = msgpack.unpackb(meta) - vals = [ + structure, dtypes, shapes = msgpack.unpackb(meta) + leaves = [ np.frombuffer(b, d).reshape(s) for i, (d, s, b) in enumerate(zip(dtypes, shapes, buffers))] - data = dict(zip(keys, vals)) + data = elements.tree.unflatten(leaves, structure) return data diff --git a/zerofun/thread.py b/zerofun/thread.py index 94007bb..b10f8da 100644 --- a/zerofun/thread.py +++ b/zerofun/thread.py @@ -11,7 +11,6 @@ def __init__(self, fn, *args, name=None, start=False): self._exitcode = None self.exception = None name = name or fn.__name__ - self.old_name = name[:] self.thread = threading.Thread( target=self._wrapper, args=args, name=name, daemon=True) self.started = False