Skip to content

Commit

Permalink
Merge pull request #117 from jhaux/dset_append_label
Browse files Browse the repository at this point in the history
Add switch to turn on return of labels with examples
  • Loading branch information
pesser authored Aug 8, 2019
2 parents 9984559 + 48e9bcc commit 491e01e
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 4 deletions.
82 changes: 81 additions & 1 deletion edflow/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,40 @@ def __init__(self):
A = ConcatenatedDataset(C, B) # Adding two Datasets
D = ConcatenatedDataset(A, A, A) # Multiplying two datasets
Labels in the example `dict`
----------------------------
Oftentimes it is good to store and load some values as labels, as it can
increase performance and decrease storage size, e.g. when storing scalar
values. If you need these values to be returned by the :func:`get_example`
method, simply activate this behaviour by setting the attribute
:attr:`append_labels` to ``True``.
.. code-block:: python
class SomeDerivedDataset(DatasetMixin):
def __init__(self):
self.labels = {'a': [1, 2, 3]}
self.append_labels = True
def get_example(self, idx):
return {'a' : idx**2, 'b': idx}
def __len__(self):
return 3
S = SomeDerivedDataset()
a = S[2]
print(a) # {'a': 3, 'b': 2}
S.append_labels = False
a = S[2]
print(a) # {'a': 4, 'b': 2}
Labels are appended to your example after all code in your
:func:`get_example` method has been executed. Thus, if there are keys in
your labels which can also be found in the example, the label entries will
override the values in your example, as can be seen in the example above.
"""

def _d_msg(self, val):
Expand All @@ -136,10 +170,12 @@ def _d_msg(self, val):
"{}".format(type(val))
)

@traceable_method(ignores=[BrokenPipeError])
# @traceable_method(ignores=[BrokenPipeError])
def __getitem__(self, i):
ret_dict = super().__getitem__(i)

# print(self.append_labels)

if isinstance(i, slice):
start = i.start or 0
stop = i.stop
Expand All @@ -149,17 +185,28 @@ def __getitem__(self, i):
raise ValueError(self._d_msg(d))
d["index_"] = idx

if self.append_labels:
labels = {k: v[idx] for k, v in self.labels.items()}
d.update(labels)

elif isinstance(i, list) or isinstance(i, np.ndarray):
for idx, d in zip(i, ret_dict):
if not isinstance(d, dict):
raise ValueError(self._d_msg(d))
d["index_"] = idx

if self.append_labels:
labels = {k: v[idx] for k, v in self.labels.items()}
d.update(labels)

else:
if not isinstance(ret_dict, dict):
raise ValueError(self._d_msg(ret_dict))

ret_dict["index_"] = i
if self.append_labels:
labels = {k: v[i] for k, v in self.labels.items()}
ret_dict.update(labels)

return ret_dict

Expand Down Expand Up @@ -239,6 +286,39 @@ def __add__(self, dset):

return ConcatenatedDataset(self, dset)

@property
def labels(self):
    """Return the label dictionary that belongs to this dataset.

    The lookup order is:

    1. ``self.data.labels`` if the instance has an attribute :attr:`data`.
       This is the common case when several datasets are stacked on top
       of each other.
    2. ``self._labels`` if labels have been stored on the instance itself.
    3. The ``labels`` of the parent class otherwise.
    """
    if hasattr(self, "data"):
        return self.data.labels
    if hasattr(self, "_labels"):
        return self._labels
    return super().labels

@labels.setter
def labels(self, labels):
    """Store ``labels`` on ``self.data`` if present, else on ``self._labels``."""
    if not hasattr(self, "data"):
        self._labels = labels
        return
    self.data.labels = labels

@property
def append_labels(self):
    """Whether labels are merged into the examples returned by indexing.

    Defaults to ``False``; the default is stored on the instance the first
    time the attribute is read without having been set before.
    """
    try:
        return self._append_labels
    except AttributeError:
        self._append_labels = False
        return self._append_labels

@append_labels.setter
def append_labels(self, value):
    """Turn the appending of labels on or off."""
    self._append_labels = value


def make_server_manager(port=63127, authkey=b"edcache"):
inqueue = queue.Queue()
Expand Down
4 changes: 1 addition & 3 deletions edflow/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def get_example(self, i):
@property
def labels(self):
    """Return the debug labels, building and caching them on first access.

    The labels are a ``dict`` with the keys ``"lbl1"`` and ``"lbl2"``. Each
    key maps to its own list ``[0, 1, ..., self.size - 1]``. The result is
    cached in ``self._labels``, so repeated access returns the same object.
    """
    if not hasattr(self, "_labels"):
        # One independent list per key, so mutating one label does not
        # silently change the other.
        self._labels = {k: list(range(self.size)) for k in ("lbl1", "lbl2")}
    return self._labels

def __len__(self):
Expand Down
134 changes: 134 additions & 0 deletions tests/test_data/test_datasetmixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import pytest

from edflow.data.dataset import DatasetMixin
from edflow.debug import DebugDataset


def test_dset_mxin():
    """A minimal ``DatasetMixin`` subclass adds ``index_`` to its examples."""

    class ConstantDset(DatasetMixin):
        def get_example(self, idx):
            return {"a": 1}

        def __len__(self):
            # The length cannot be guessed from get_example!
            return 10

    dset = ConstantDset()
    example = dset[0]
    assert "index_" in example
    assert "a" in example

    # NOTE(review): this indexes the returned example dict, not the dataset.
    # Presumably ``dset[100]`` was intended ("must compare len with idx");
    # confirm that the dataset bounds-checks before changing it.
    with pytest.raises(Exception):
        example[100]


def test_dset_mxin_app_labels():
    """Labels appear in the example only while ``append_labels`` is set."""

    class LabeledDset(DatasetMixin):
        def __init__(self):
            self.labels = {"l": [1, 2, 3]}
            self.append_labels = True

        def get_example(self, idx):
            return {"a": 1}

        def __len__(self):
            return len(self.labels["l"])

    dset = LabeledDset()

    # With the switch turned on, the label key is merged into the example.
    with_labels = dset[0]
    assert "l" in with_labels
    assert "a" in with_labels

    # NOTE(review): indexes the example dict, not the dataset; presumably
    # ``dset[100]`` was intended. Confirm before changing.
    with pytest.raises(Exception):
        with_labels[100]

    # With the switch turned off, only the keys from get_example remain.
    dset.append_labels = False
    without_labels = dset[0]
    assert "l" not in without_labels
    assert "a" in without_labels

    with pytest.raises(Exception):
        without_labels[100]


def test_dset_mxin_data_attr():
    """A dataset wrapping another one via ``data`` forwards examples and labels."""

    class WrappingDset(DatasetMixin):
        def __init__(self):
            self.data = DebugDataset(size=10)

    dset = WrappingDset()
    example = dset[0]
    assert "val" in example
    assert "other" in example
    assert "index_" in example

    # NOTE(review): indexes the example dict, not the dataset; presumably
    # ``dset[100]`` was intended. Confirm before changing.
    with pytest.raises(Exception):
        example[100]

    # The labels of the wrapped dataset must be reachable through the wrapper.
    wrapped_labels = dset.labels
    wrapped_labels["lbl1"][0]

    with pytest.raises(Exception):
        wrapped_labels["lbl1"][100]


def test_dset_mxin_data_attr_app_labels():
    """Labels of a wrapped dataset are appended when ``append_labels`` is set."""

    class MyDset(DatasetMixin):
        def __init__(self):
            self.data = DebugDataset(size=10)
            self.append_labels = True

    D = MyDset()
    ex = D[0]
    assert "val" in ex
    assert "other" in ex
    assert "index_" in ex
    assert "lbl1" in ex
    assert "lbl2" in ex

    # NOTE(review): this indexes the returned example dict, not the dataset.
    # Presumably ``D[100]`` was intended; confirm before changing.
    with pytest.raises(Exception):
        ex[100]

    # Both label arrays must be reachable through the wrapper.
    lbs = D.labels
    lbs["lbl1"][0]
    lbs["lbl2"][0]

    # One ``raises`` block per statement: with both statements in a single
    # block, the second is never executed because the first already raises.
    with pytest.raises(Exception):
        lbs["lbl1"][100]
    with pytest.raises(Exception):
        lbs["lbl2"][100]


def test_dset_mxin_ops():
    """Basically test the ConcatenatedDataset.

    Checks that ``+`` and ``*`` (in both operand orders) produce datasets of
    the expected length and that examples beyond the first operand's length
    can be retrieved.
    """
    D1 = DebugDataset(size=10)
    D2 = DebugDataset(size=10)

    # Addition concatenates; index 13 lies in the second operand.
    D3 = D1 + D2
    assert len(D3) == 20
    D3[13]
    D3 = D2 + D1
    assert len(D3) == 20
    D3[13]

    # Multiplication repeats the dataset; both operand orders must work.
    D4 = 3 * D1
    assert len(D4) == 30
    D4[13]
    D4[23]
    D4 = D1 * 3
    assert len(D4) == 30
    D4[13]
    D4[23]


if __name__ == "__main__":
    # Run every test once, in a fixed order, when executed as a script.
    for run_test in (
        test_dset_mxin,
        test_dset_mxin_ops,
        test_dset_mxin_data_attr,
        test_dset_mxin_app_labels,
        test_dset_mxin_data_attr_app_labels,
    ):
        run_test()

0 comments on commit 491e01e

Please sign in to comment.