diff --git a/CHANGELOG.md b/CHANGELOG.md index 2aad2bcd..90c535d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - unit - ``to_clock_cycles()`` now always returns an integer. ### Added - octave_tools - Added the possibility to pass the AutoCalibrationParams to ``get_correction_for_each_LO_and_IF()`` to customize the calibration parameters (IF_amplitude for instance). +- data_handler - Added support for nested figures and arrays ## [0.17.4] - 2024-05-07 ### Fixed diff --git a/qualang_tools/results/data_handler/data_processors.py b/qualang_tools/results/data_handler/data_processors.py index 5807c5a3..c9e51ce1 100644 --- a/qualang_tools/results/data_handler/data_processors.py +++ b/qualang_tools/results/data_handler/data_processors.py @@ -48,6 +48,7 @@ def post_process(self, data_folder: Path): class MatplotlibPlotSaver(DataProcessor): file_format: str = "png" + nested_separator: str = "." def __init__(self, file_format=None): if file_format is not None: @@ -65,11 +66,14 @@ def process(self, data): self.data_figures = {} for keys, val in iterate_nested_dict(data): - if isinstance(val, plt.Figure): - path = Path("/".join(keys)).with_suffix(self.file_suffix) + if not isinstance(val, plt.Figure): + continue + + file_end = Path(keys[-1]).with_suffix(self.file_suffix) + path = Path(self.nested_separator.join(keys[:-1] + [str(file_end)])) - self.data_figures[path] = val - update_nested_dict(data, keys, f"./{path}") + self.data_figures[path] = val + update_nested_dict(data, keys, f"./{path}") return data @@ -84,6 +88,7 @@ def post_process(self, data_folder: Path): class NumpyArraySaver(DataProcessor): merge_arrays: bool = True merged_array_name: str = "arrays.npz" + nested_separator: str = "." def __init__(self, merge_arrays=None, merged_array_name=None): if merge_arrays is not None: @@ -100,12 +105,12 @@ def process(self, data): if not isinstance(val, np.ndarray): continue - path = Path("/".join(keys)) + path = Path(self.nested_separator.join(keys)) self.data_arrays[path] = val if self.merge_arrays: update_nested_dict(data, keys, f"./{self.merged_array_name}#{path}") else: - update_nested_dict(data, keys, f"./{path.with_suffix('.npy')}") + update_nested_dict(data, keys, f"./{path}.npy") return data def post_process(self, data_folder: Path): @@ -115,7 +120,7 @@ def post_process(self, data_folder: Path): np.savez(data_folder / self.merged_array_name, **arrays) else: for path, arr in self.data_arrays.items(): - np.save(data_folder / path.with_suffix(".npy"), arr) + np.save(data_folder / f"{path}.npy", arr) DEFAULT_DATA_PROCESSORS.append(NumpyArraySaver) @@ -125,6 +130,7 @@ class XarraySaver(DataProcessor): merge_arrays: bool = False merged_array_name: str = "xarrays" file_format: str = "hdf5" + nested_separator: str = "." def __init__(self, merge_arrays=None, merged_array_name=None, file_format=None): if merge_arrays is not None: @@ -150,13 +156,13 @@ def process(self, data): if not isinstance(val, xr.Dataset): continue - path = Path("/".join(keys)) + path = Path(self.nested_separator.join(keys)) self.data_arrays[path] = val if self.merge_arrays: merged_array_name = Path(self.merged_array_name).with_suffix(self.file_suffix) update_nested_dict(data, keys, f"./{merged_array_name}#{path}") else: - update_nested_dict(data, keys, f"./{path.with_suffix(self.file_suffix)}") + update_nested_dict(data, keys, f"./{path}{self.file_suffix}") return data def save_merged_netcdf_arrays(self, path: Path, arrays: dict): @@ -183,7 +189,7 @@ def post_process(self, data_folder: Path): ) from e else: for path, array in self.data_arrays.items(): - array.to_netcdf(data_folder / path.with_suffix(self.file_suffix)) + array.to_netcdf(data_folder / f"{path}{self.file_suffix}") try: diff --git a/tests/data_handler/test_matplotlib_plot_saver.py b/tests/data_handler/test_matplotlib_plot_saver.py index d1df59df..e27b7a81 100644 --- a/tests/data_handler/test_matplotlib_plot_saver.py +++ b/tests/data_handler/test_matplotlib_plot_saver.py @@ -37,3 +37,14 @@ def test_save_plot_basic(tmp_path, fig): file_data = json.loads((tmp_path / "data.json").read_text()) assert file_data == {"a": 1, "b": 2, "c": "./c.png"} + + +def test_matplotlib_nested_save(tmp_path, fig): + data = {"q0": {"fig": fig, "value": 42}} + + save_data(data_folder=tmp_path, data=data, node_contents={}, data_processors=[MatplotlibPlotSaver()]) + + assert set(f.name for f in tmp_path.iterdir()) == set(["data.json", "node.json", "q0.fig.png"]) + + file_data = json.loads((tmp_path / "data.json").read_text()) + assert file_data == {"q0": {"fig": "./q0.fig.png", "value": 42}}