diff --git a/edflow/util.py b/edflow/util.py index babe9c8..6d03bd2 100644 --- a/edflow/util.py +++ b/edflow/util.py @@ -202,6 +202,10 @@ def retrieve( raise KeyNotFoundError(e) visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict except KeyNotFoundError as e: if default is None: print("Key not found: {}, seen: {}".format(keys, visited)) diff --git a/tests/test_util.py b/tests/test_util.py index 6ff736b..99e6b6e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -513,6 +513,28 @@ def test_retrieve_propagates_exception(): val = retrieve(dol, "b/c/d", default=0) +def test_retrieve_callable_leaves(): + dol = {"a": [1, 2], "b": callable_leave, "e": 2} + val = retrieve(dol, "b") + + # make sure expansion is returned + assert val == callable_leave() + + # make sure expansion was done in-place + assert dol["b"] == callable_leave() + + dol = {"a": [1, 2], "b": callable_leave, "e": 2} + val = retrieve(dol, "b/c") + # make sure expansion is returned + assert val == nested_leave() + # make sure expansion was done in-place + assert dol["b"]["c"] == nested_leave() + + dol = {"a": [1, 2], "b": callable_leave, "e": 2} + val = retrieve(dol, "b/c/d") + assert val == 1 + + # ====================== walk ====================