From 90283a3ab4842177861b229bc7dde03712276077 Mon Sep 17 00:00:00 2001 From: Johannes Haux Date: Tue, 6 Aug 2019 14:54:47 +0200 Subject: [PATCH] :white_check_mark: :tada: Add switch to turn on return of labels with examples --- edflow/data/dataset.py | 82 +++++++++++++++- edflow/debug.py | 4 +- tests/test_data/test_datasetmixin.py | 134 +++++++++++++++++++++++++++ 3 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 tests/test_data/test_datasetmixin.py diff --git a/edflow/data/dataset.py b/edflow/data/dataset.py index 5ab38f6..a3394c1 100644 --- a/edflow/data/dataset.py +++ b/edflow/data/dataset.py @@ -125,6 +125,40 @@ def __init__(self): A = ConcatenatedDataset(C, B) # Adding two Datasets D = ConcatenatedDataset(A, A, A) # Multiplying two datasets + Labels in the example `dict` + ---------------------------- + + Oftentimes it is good to store and load some values as lables as it can + increase performance and decrease storage size, e.g. when storing scalar + values. If you need these values to be returned by the :func:`get_example` + method, simply activate this behaviour by setting the attribute + :attr:`append_labels` to ``True``. + + .. code-block:: python + + SomeDerivedDataset(DatasetMixin): + def __init__(self): + self.labels = {'a': [1, 2, 3]} + self.append_labels = True + + def get_example(self, idx): + return {'a' : idx**2, 'b': idx} + + def __len__(self): + return 3 + + S = SomeDerivedDataset() + a = S[2] + print(a) # {'a': 3, 'b': 2} + + S.append_labels = False + a = S[2] + print(a) # {'a': 4, 'b': 2} + + Labels are appended to your example, after all code is executed from your + :attr:`get_example` method. Thus, if there are keys in your labels, which + can also be found in the examples, the label entries will override the + values in you example, as can be seen in the example above. """ def _d_msg(self, val): @@ -136,10 +170,12 @@ def _d_msg(self, val): "{}".format(type(val)) ) - @traceable_method(ignores=[BrokenPipeError]) + # @traceable_method(ignores=[BrokenPipeError]) def __getitem__(self, i): ret_dict = super().__getitem__(i) + # print(self.append_labels) + if isinstance(i, slice): start = i.start or 0 stop = i.stop @@ -149,17 +185,28 @@ def __getitem__(self, i): raise ValueError(self._d_msg(d)) d["index_"] = idx + if self.append_labels: + labels = {k: v[idx] for k, v in self.labels.items()} + d.update(labels) + elif isinstance(i, list) or isinstance(i, np.ndarray): for idx, d in zip(i, ret_dict): if not isinstance(d, dict): raise ValueError(self._d_msg(d)) d["index_"] = idx + if self.append_labels: + labels = {k: v[idx] for k, v in self.labels.items()} + d.update(labels) + else: if not isinstance(ret_dict, dict): raise ValueError(self._d_msg(ret_dict)) ret_dict["index_"] = i + if self.append_labels: + labels = {k: v[i] for k, v in self.labels.items()} + ret_dict.update(labels) return ret_dict @@ -232,6 +279,39 @@ def __add__(self, dset): return ConcatenatedDataset(self, dset) + @property + def labels(self): + """Add default behaviour for datasets defining an attribute + :attr:`data`, which in turn is a dataset. This happens often when + stacking several datasets on top of each other. + + The default behaviour now is to return ``self.data.labels`` + if possible, and otherwise revert to the original behaviour. + """ + if hasattr(self, "data"): + return self.data.labels + elif hasattr(self, "_labels"): + return self._labels + else: + return super().labels + + @labels.setter + def labels(self, labels): + if hasattr(self, "data"): + self.data.labels = labels + else: + self._labels = labels + + @property + def append_labels(self): + if not hasattr(self, "_append_labels"): + self._append_labels = False + return self._append_labels + + @append_labels.setter + def append_labels(self, value): + self._append_labels = value + def make_server_manager(port=63127, authkey=b"edcache"): inqueue = queue.Queue() diff --git a/edflow/debug.py b/edflow/debug.py index ad13373..eea9bc1 100644 --- a/edflow/debug.py +++ b/edflow/debug.py @@ -39,9 +39,7 @@ def get_example(self, i): @property def labels(self): if not hasattr(self, "_labels"): - self._labels = { - k: [i for i in range(self.size)] for k in ["index_", "other"] - } + self._labels = {k: [i for i in range(self.size)] for k in ["lbl1", "lbl2"]} return self._labels def __len__(self): diff --git a/tests/test_data/test_datasetmixin.py b/tests/test_data/test_datasetmixin.py new file mode 100644 index 0000000..efafce0 --- /dev/null +++ b/tests/test_data/test_datasetmixin.py @@ -0,0 +1,134 @@ +import pytest + +from edflow.data.dataset import DatasetMixin +from edflow.debug import DebugDataset + + +def test_dset_mxin(): + class MyDset(DatasetMixin): + def get_example(self, idx): + return {"a": 1} + + def __len__(self): + # Cannot be guessed from get_example! + return 10 + + D = MyDset() + ex = D[0] + assert "index_" in ex + assert "a" in ex + + with pytest.raises(Exception): + # Must compare len with idx + ex[100] + + +def test_dset_mxin_app_labels(): + class MyDset(DatasetMixin): + def __init__(self): + self.labels = {"l": [1, 2, 3]} + self.append_labels = True + + def get_example(self, idx): + return {"a": 1} + + def __len__(self): + return len(self.labels["l"]) + + D = MyDset() + ex = D[0] + assert "l" in ex + assert "a" in ex + + with pytest.raises(Exception): + ex[100] + + D.append_labels = False + ex = D[0] + assert "l" not in ex + assert "a" in ex + + with pytest.raises(Exception): + ex[100] + + +def test_dset_mxin_data_attr(): + class MyDset(DatasetMixin): + def __init__(self): + self.data = DebugDataset(size=10) + + D = MyDset() + ex = D[0] + assert "val" in ex + assert "other" in ex + assert "index_" in ex + + with pytest.raises(Exception): + ex[100] + + lbs = D.labels + l = lbs["lbl1"][0] + + with pytest.raises(Exception): + lbs["lbl1"][100] + + +def test_dset_mxin_data_attr_app_labels(): + class MyDset(DatasetMixin): + def __init__(self): + self.data = DebugDataset(size=10) + self.append_labels = True + + D = MyDset() + ex = D[0] + assert "val" in ex + assert "other" in ex + assert "index_" in ex + assert "lbl1" in ex + assert "lbl2" in ex + + with pytest.raises(Exception): + ex[100] + + lbs = D.labels + l = lbs["lbl1"][0] + l = lbs["lbl2"][0] + + with pytest.raises(Exception): + lbs["lbl1"][100] + lbs["lbl2"][100] + + +def test_dset_mxin_ops(): + """Basically test the ConcatenatedDataset""" + + class MyDset(DatasetMixin): + def __init__(self): + self.data = DebugDataset(size=10) + + D1 = DebugDataset(size=10) + D2 = DebugDataset(size=10) + + D3 = D1 + D2 + assert len(D3) == 20 + D3[13] + D3 = D2 + D1 + assert len(D3) == 20 + D3[13] + + D4 = 3 * D1 + assert len(D4) == 30 + D4[13] + D4[23] + D4 = D1 * 3 + assert len(D4) == 30 + D4[13] + D4[23] + + +if __name__ == "__main__": + test_dset_mxin() + test_dset_mxin_ops() + test_dset_mxin_data_attr() + test_dset_mxin_app_labels() + test_dset_mxin_data_attr_app_labels()