Skip to content

Commit

Permalink
Feat: Allow
Browse files Browse the repository at this point in the history
  • Loading branch information
nulinspiratie committed May 17, 2024
1 parent c5517fd commit 11a9a61
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- unit - ``to_clock_cycles()`` now always returns an integer.
### Added
- octave_tools - Added the possibility to pass the AutoCalibrationParams to ``get_correction_for_each_LO_and_IF()`` to customize the calibration parameters (IF_amplitude for instance).
- data_handler - Added support for nested figures and arrays

## [0.17.4] - 2024-05-07
### Fixed
Expand Down
26 changes: 16 additions & 10 deletions qualang_tools/results/data_handler/data_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def post_process(self, data_folder: Path):

class MatplotlibPlotSaver(DataProcessor):
file_format: str = "png"
nested_separator: str = "."

def __init__(self, file_format=None):
if file_format is not None:
Expand All @@ -65,11 +66,14 @@ def process(self, data):
self.data_figures = {}

for keys, val in iterate_nested_dict(data):
if isinstance(val, plt.Figure):
path = Path("/".join(keys)).with_suffix(self.file_suffix)
if not isinstance(val, plt.Figure):
continue

file_end = Path(keys[-1]).with_suffix(self.file_suffix)
path = Path(self.nested_separator.join(keys[:-1] + [str(file_end)]))

self.data_figures[path] = val
update_nested_dict(data, keys, f"./{path}")
self.data_figures[path] = val
update_nested_dict(data, keys, f"./{path}")

return data

Expand All @@ -84,6 +88,7 @@ def post_process(self, data_folder: Path):
class NumpyArraySaver(DataProcessor):
merge_arrays: bool = True
merged_array_name: str = "arrays.npz"
nested_separator: str = "."

def __init__(self, merge_arrays=None, merged_array_name=None):
if merge_arrays is not None:
Expand All @@ -100,12 +105,12 @@ def process(self, data):
if not isinstance(val, np.ndarray):
continue

path = Path("/".join(keys))
path = Path(self.nested_separator.join(keys))
self.data_arrays[path] = val
if self.merge_arrays:
update_nested_dict(data, keys, f"./{self.merged_array_name}#{path}")
else:
update_nested_dict(data, keys, f"./{path.with_suffix('.npy')}")
update_nested_dict(data, keys, f"./{path}.npy")
return data

def post_process(self, data_folder: Path):
Expand All @@ -115,7 +120,7 @@ def post_process(self, data_folder: Path):
np.savez(data_folder / self.merged_array_name, **arrays)
else:
for path, arr in self.data_arrays.items():
np.save(data_folder / path.with_suffix(".npy"), arr)
np.save(data_folder / f"{path}.npy", arr)


DEFAULT_DATA_PROCESSORS.append(NumpyArraySaver)
Expand All @@ -125,6 +130,7 @@ class XarraySaver(DataProcessor):
merge_arrays: bool = False
merged_array_name: str = "xarrays"
file_format: str = "hdf5"
nested_separator: str = "."

def __init__(self, merge_arrays=None, merged_array_name=None, file_format=None):
if merge_arrays is not None:
Expand All @@ -150,13 +156,13 @@ def process(self, data):
if not isinstance(val, xr.Dataset):
continue

path = Path("/".join(keys))
path = Path(self.nested_separator.join(keys))
self.data_arrays[path] = val
if self.merge_arrays:
merged_array_name = Path(self.merged_array_name).with_suffix(self.file_suffix)
update_nested_dict(data, keys, f"./{merged_array_name}#{path}")
else:
update_nested_dict(data, keys, f"./{path.with_suffix(self.file_suffix)}")
update_nested_dict(data, keys, f"./{path}{self.file_suffix}")
return data

def save_merged_netcdf_arrays(self, path: Path, arrays: dict):
Expand All @@ -183,7 +189,7 @@ def post_process(self, data_folder: Path):
) from e
else:
for path, array in self.data_arrays.items():
array.to_netcdf(data_folder / path.with_suffix(self.file_suffix))
array.to_netcdf(data_folder / f"{path}{self.file_suffix}")


try:
Expand Down
11 changes: 11 additions & 0 deletions tests/data_handler/test_matplotlib_plot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,14 @@ def test_save_plot_basic(tmp_path, fig):
file_data = json.loads((tmp_path / "data.json").read_text())

assert file_data == {"a": 1, "b": 2, "c": "./c.png"}


def test_matplotlib_nested_save(tmp_path, fig):
data = {"q0": {"fig": fig, "value": 42}}

save_data(data_folder=tmp_path, data=data, node_contents={}, data_processors=[MatplotlibPlotSaver()])

assert set(f.name for f in tmp_path.iterdir()) == set(["data.json", "node.json", "q0.fig.png"])

file_data = json.loads((tmp_path / "data.json").read_text())
assert file_data == {"q0": {"fig": "./q0.fig.png", "value": 42}}

0 comments on commit 11a9a61

Please sign in to comment.