diff --git a/qualang_tools/results/data_handler/data_processors.py b/qualang_tools/results/data_handler/data_processors.py index c9e51ce1..5a991e0e 100644 --- a/qualang_tools/results/data_handler/data_processors.py +++ b/qualang_tools/results/data_handler/data_processors.py @@ -30,7 +30,13 @@ def iterate_nested_dict( yield from iterate_nested_dict(v, parent_keys=keys) -def update_nested_dict(d, keys, value): +def update_nested_dict(d: dict, keys: List[Any], value: Any) -> None: + """Update a nested dictionary with a new value + + :param d: The dictionary to update + :param keys: The keys to the value to update + :param value: The new value to set + """ subdict = d for key in keys[:-1]: subdict = subdict[key] @@ -38,6 +44,25 @@ def update_nested_dict(d, keys, value): subdict[keys[-1]] = value +def copy_nested_dict(d: dict) -> dict: + """Copy a nested dictionary, but don't make copies of the values + + This function will copy a nested dictionary, but will not make copies of the values. This is useful if copying the + values may be an expensive operation (e.g. large arrays). 
+ If you also want to make copies of the values, use `copy.deepcopy` + + :param d: The dictionary to copy + :return: A new dictionary with the same structure as `d`, whose values are the same objects as in `d` (not copies) + """ + new_dict = {} + for key, val in d.items(): + if isinstance(val, dict): + new_dict[key] = copy_nested_dict(val) + else: + new_dict[key] = val + return new_dict + + class DataProcessor(ABC): def process(self, data): return data @@ -65,6 +90,8 @@ def file_suffix(self): def process(self, data): self.data_figures = {} + processed_data = copy_nested_dict(data) + for keys, val in iterate_nested_dict(data): if not isinstance(val, plt.Figure): continue @@ -73,9 +100,9 @@ def process(self, data): path = Path(self.nested_separator.join(keys[:-1] + [str(file_end)])) self.data_figures[path] = val - update_nested_dict(data, keys, f"./{path}") + update_nested_dict(processed_data, keys, f"./{path}") - return data + return processed_data def post_process(self, data_folder: Path): for path, fig in self.data_figures.items(): @@ -100,6 +127,7 @@ def __init__(self, merge_arrays=None, merged_array_name=None): def process(self, data): self.data_arrays = {} + processed_data = copy_nested_dict(data) for keys, val in iterate_nested_dict(data): if not isinstance(val, np.ndarray): @@ -108,10 +136,10 @@ def process(self, data): path = Path(self.nested_separator.join(keys)) self.data_arrays[path] = val if self.merge_arrays: - update_nested_dict(data, keys, f"./{self.merged_array_name}#{path}") + update_nested_dict(processed_data, keys, f"./{self.merged_array_name}#{path}") else: - update_nested_dict(data, keys, f"./{path}.npy") - return data + update_nested_dict(processed_data, keys, f"./{path}.npy") + return processed_data def post_process(self, data_folder: Path): if self.merge_arrays: @@ -151,6 +179,7 @@ def process(self, data): import xarray as xr self.data_arrays = {} + processed_data = copy_nested_dict(data) for keys, val in iterate_nested_dict(data): if not isinstance(val, xr.Dataset): @@ 
-160,10 +189,10 @@ def process(self, data): self.data_arrays[path] = val if self.merge_arrays: merged_array_name = Path(self.merged_array_name).with_suffix(self.file_suffix) - update_nested_dict(data, keys, f"./{merged_array_name}#{path}") + update_nested_dict(processed_data, keys, f"./{merged_array_name}#{path}") else: - update_nested_dict(data, keys, f"./{path}{self.file_suffix}") - return data + update_nested_dict(processed_data, keys, f"./{path}{self.file_suffix}") + return processed_data def save_merged_netcdf_arrays(self, path: Path, arrays: dict): for array_path, array in self.data_arrays.items(): diff --git a/tests/data_handler/test_matplotlib_plot_saver.py b/tests/data_handler/test_matplotlib_plot_saver.py index e27b7a81..370ddd62 100644 --- a/tests/data_handler/test_matplotlib_plot_saver.py +++ b/tests/data_handler/test_matplotlib_plot_saver.py @@ -48,3 +48,12 @@ def test_matplotlib_nested_save(tmp_path, fig): file_data = json.loads((tmp_path / "data.json").read_text()) assert file_data == {"q0": {"fig": "./q0.fig.png", "value": 42}} + + +def test_matplotlib_save_does_not_affect_data(fig): + matplotlib_plot_saver = MatplotlibPlotSaver() + data = {"a": 1, "b": 2, "c": fig} + processed_data = matplotlib_plot_saver.process(data) + + assert data == {"a": 1, "b": 2, "c": fig} + assert processed_data == {"a": 1, "b": 2, "c": "./c.png"} diff --git a/tests/data_handler/test_numpy_array_saver.py b/tests/data_handler/test_numpy_array_saver.py index c6f38910..443fbcb0 100644 --- a/tests/data_handler/test_numpy_array_saver.py +++ b/tests/data_handler/test_numpy_array_saver.py @@ -1,15 +1,31 @@ +import copy import numpy as np from qualang_tools.results.data_handler.data_processors import DEFAULT_DATA_PROCESSORS, NumpyArraySaver +def dicts_equal(d1, d2): + if d1.keys() != d2.keys(): + return False + for key in d1: + if isinstance(d1[key], dict): + if not dicts_equal(d1[key], d2[key]): + return False + elif isinstance(d1[key], np.ndarray): + if not 
np.array_equal(d1[key], d2[key]): + return False + else: + if d1[key] != d2[key]: + return False + return True + + def test_numpy_array_saver_process_merged(): data = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6]), "c": 3} data_processor = NumpyArraySaver() - processed_data = data.copy() - processed_data = data_processor.process(processed_data) + processed_data = data_processor.process(data) assert processed_data == { "a": "./arrays.npz#a", @@ -51,7 +67,7 @@ def test_numpy_array_saver_post_process_separate(tmp_path): data = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6]), "c": 3} data_processor = NumpyArraySaver(merge_arrays=False) - data_processor.process(data.copy()) + data_processor.process(data) data_processor.post_process(data_folder=tmp_path) @@ -59,3 +75,34 @@ def test_numpy_array_saver_post_process_separate(tmp_path): assert (tmp_path / "b.npy").exists() assert np.array_equal(np.load(tmp_path / "a.npy"), data["a"]) assert np.array_equal(np.load(tmp_path / "b.npy"), data["b"]) + + +def test_numpy_array_saver_nested_no_merge(tmp_path): + data = {"q0": {"a": np.array([1, 2, 3]), "b": 3}, "c": np.array([4, 5, 6])} + + data_processor = NumpyArraySaver(merge_arrays=False) + processed_data = data_processor.process(data) + assert processed_data == { + "q0": {"a": "./q0.a.npy", "b": 3}, + "c": "./c.npy", + } + + data_processor.post_process(data_folder=tmp_path) + + assert (tmp_path / "q0.a.npy").exists() + assert not (tmp_path / "q0.b.npy").exists() + assert (tmp_path / "c.npy").exists() + + assert np.array_equal(np.load(tmp_path / "q0.a.npy"), data["q0"]["a"]) + assert np.array_equal(np.load(tmp_path / "c.npy"), data["c"]) + + +def test_numpy_array_saver_process_does_not_affect_data(): + data = {"q0": {"a": np.array([1, 2, 3]), "b": 3}, "c": np.array([4, 5, 6])} + deepcopied_data = copy.deepcopy(data) + + data_processor = NumpyArraySaver(merge_arrays=False) + processed_data = data_processor.process(data) + + assert dicts_equal(data, deepcopied_data) + 
assert not dicts_equal(processed_data, data) diff --git a/tests/data_handler/test_xarray_saver.py b/tests/data_handler/test_xarray_saver.py index 9d6ed875..b4380e8a 100644 --- a/tests/data_handler/test_xarray_saver.py +++ b/tests/data_handler/test_xarray_saver.py @@ -1,5 +1,7 @@ +import copy import pytest -import sys +import numpy as np +import xarray as xr from qualang_tools.results.data_handler.data_processors import XarraySaver @@ -11,6 +13,27 @@ def module_installed(module_name): return True +def dicts_equal(d1, d2): + if d1.keys() != d2.keys(): + return False + for key in d1: + if key not in d2: + return False + elif isinstance(d1[key], dict): + if not dicts_equal(d1[key], d2[key]): + return False + elif isinstance(d1[key], np.ndarray): + if not np.array_equal(d1[key], d2[key]): + return False + elif isinstance(d1[key], xr.Dataset): + if not d1[key].identical(d2[key]): + return False + else: + if not bool(d1[key] == d2[key]): + return False + return True + + @pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed") def test_xarray_saver_no_xarrays(): xarray_saver = XarraySaver() @@ -44,7 +67,7 @@ def test_xarray_saver_merge_netcdf(tmp_path): data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": xr.Dataset()} xarray_saver = XarraySaver(merge_arrays=True, file_format="nc") - processed_data = xarray_saver.process(data.copy()) + processed_data = xarray_saver.process(data) assert processed_data == {"a": 1, "b": 2, "c": "./xarrays.nc#c", "d": "./xarrays.nc#d"} @@ -63,7 +86,7 @@ def test_xarray_saver_merge_hdf5(tmp_path): data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": xr.Dataset()} xarray_saver = XarraySaver(merge_arrays=True, file_format="h5") - processed_data = xarray_saver.process(data.copy()) + processed_data = xarray_saver.process(data) assert processed_data == {"a": 1, "b": 2, "c": "./xarrays.h5#c", "d": "./xarrays.h5#d"} @@ -82,7 +105,7 @@ def test_xarray_saver_no_merge_netcdf(tmp_path): data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": 
xr.Dataset()} xarray_saver = XarraySaver(merge_arrays=False) - processed_data = xarray_saver.process(data.copy()) + processed_data = xarray_saver.process(data) assert processed_data == {"a": 1, "b": 2, "c": "./c.h5", "d": "./d.h5"} @@ -93,3 +116,55 @@ def test_xarray_saver_no_merge_netcdf(tmp_path): xr.load_dataset(tmp_path / "c.h5") xr.load_dataset(tmp_path / "d.h5") + + +@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed") +def test_xarray_saver_no_merge_hdf5_nested(tmp_path): + import xarray as xr + + data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": {"d1": xr.Dataset()}} + + xarray_saver = XarraySaver(merge_arrays=False, file_format="nc") + processed_data = xarray_saver.process(data) + + assert processed_data == {"a": 1, "b": 2, "c": "./c.nc", "d": {"d1": "./d.d1.nc"}} + + xarray_saver.post_process(data_folder=tmp_path) + + assert (tmp_path / "c.nc").exists() + assert (tmp_path / "d.d1.nc").exists() + + xr.load_dataset(tmp_path / "c.nc") + xr.load_dataset(tmp_path / "d.d1.nc") + + +@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed") +def test_xarray_saver_merge_hdf5_nested(tmp_path): + import xarray as xr + + data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": {"d1": xr.Dataset()}} + + xarray_saver = XarraySaver(merge_arrays=True, file_format="nc") + processed_data = xarray_saver.process(data) + + assert processed_data == {"a": 1, "b": 2, "c": "./xarrays.nc#c", "d": {"d1": "./xarrays.nc#d.d1"}} + xarray_saver.post_process(data_folder=tmp_path) + + assert (tmp_path / "xarrays.nc").exists() + + xr.load_dataset(tmp_path / "xarrays.nc", group="c") + xr.load_dataset(tmp_path / "xarrays.nc", group="d.d1") + + +@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed") +def test_xarray_saver_does_not_affect_data(): + import xarray as xr + + data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": {"d1": xr.Dataset()}} + deepcopied_data = copy.deepcopy(data) + + xarray_saver = 
XarraySaver(merge_arrays=False, file_format="nc") + processed_data = xarray_saver.process(data) + + assert dicts_equal(data, deepcopied_data) + assert not dicts_equal(processed_data, data)