Commit: Add tests

nulinspiratie committed May 18, 2024
1 parent 11a9a61 commit 4eff283
Showing 4 changed files with 176 additions and 16 deletions.
47 changes: 38 additions & 9 deletions qualang_tools/results/data_handler/data_processors.py
@@ -30,14 +30,39 @@ def iterate_nested_dict(
            yield from iterate_nested_dict(v, parent_keys=keys)


-def update_nested_dict(d, keys, value):
+def update_nested_dict(d: dict, keys: List[Any], value: Any) -> None:
    """Update a nested dictionary with a new value

    :param d: The dictionary to update
    :param keys: The keys to the value to update
    :param value: The new value to set
    """
    subdict = d
    for key in keys[:-1]:
        subdict = subdict[key]

    subdict[keys[-1]] = value
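
A minimal usage sketch of update_nested_dict (hypothetical values, not part of the diff): the keys form a path into the nested dictionary, and the assignment happens in place.

    config = {"q0": {"readout": {"amplitude": 0.1}}}
    update_nested_dict(config, ["q0", "readout", "amplitude"], 0.2)
    assert config == {"q0": {"readout": {"amplitude": 0.2}}}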


def copy_nested_dict(d: dict) -> dict:
    """Copy a nested dictionary, but don't make copies of the values

    This function copies the nested dictionary structure without copying the values, which is useful
    when copying the values would be expensive (e.g. large arrays).
    If you also want to make copies of the values, use `copy.deepcopy`.

    :param d: The dictionary to copy
    :return: A new dictionary with the same structure as `d`, whose leaf values are the same objects as in `d`
    """
    new_dict = {}
    for key, val in d.items():
        if isinstance(val, dict):
            new_dict[key] = copy_nested_dict(val)
        else:
            new_dict[key] = val
    return new_dict
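
A quick illustration of the shallow-copy semantics (hypothetical values, not part of the diff): the dict structure is duplicated, but the leaves are shared with the original, so rebinding a leaf in the copy never touches the original, and large arrays are never duplicated.

    arr = np.array([1, 2, 3])
    original = {"q0": {"trace": arr, "n": 1}}
    copied = copy_nested_dict(original)
    copied["q0"]["n"] = 2                  # rebinding a leaf in the copy...
    assert original["q0"]["n"] == 1        # ...does not affect the original
    assert copied["q0"]["trace"] is arr    # leaf objects are shared, not copied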


class DataProcessor(ABC):
    def process(self, data):
        return data
@@ -65,6 +90,8 @@ def file_suffix(self):
    def process(self, data):
        self.data_figures = {}

        processed_data = copy_nested_dict(data)

        for keys, val in iterate_nested_dict(data):
            if not isinstance(val, plt.Figure):
                continue
@@ -73,9 +100,9 @@ def process(self, data):
            path = Path(self.nested_separator.join(keys[:-1] + [str(file_end)]))

            self.data_figures[path] = val
-           update_nested_dict(data, keys, f"./{path}")
+           update_nested_dict(processed_data, keys, f"./{path}")

-       return data
+       return processed_data

    def post_process(self, data_folder: Path):
        for path, fig in self.data_figures.items():
@@ -100,6 +127,7 @@ def __init__(self, merge_arrays=None, merged_array_name=None):

    def process(self, data):
        self.data_arrays = {}
        processed_data = copy_nested_dict(data)

        for keys, val in iterate_nested_dict(data):
            if not isinstance(val, np.ndarray):
@@ -108,10 +136,10 @@ def process(self, data):
            path = Path(self.nested_separator.join(keys))
            self.data_arrays[path] = val
            if self.merge_arrays:
-               update_nested_dict(data, keys, f"./{self.merged_array_name}#{path}")
+               update_nested_dict(processed_data, keys, f"./{self.merged_array_name}#{path}")
            else:
-               update_nested_dict(data, keys, f"./{path}.npy")
-       return data
+               update_nested_dict(processed_data, keys, f"./{path}.npy")
+       return processed_data
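
The reference format written into processed_data (as exercised by the tests below): with merge_arrays=True each array value is replaced by "./<merged_array_name>#<path>", a fragment pointing into a single merged file, while with merge_arrays=False it becomes "./<path>.npy", one file per array, with nested keys joined by the separator (e.g. "./q0.a.npy").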

    def post_process(self, data_folder: Path):
        if self.merge_arrays:
@@ -151,6 +179,7 @@ def process(self, data):
        import xarray as xr

        self.data_arrays = {}
        processed_data = copy_nested_dict(data)

        for keys, val in iterate_nested_dict(data):
            if not isinstance(val, xr.Dataset):
@@ -160,10 +189,10 @@ def process(self, data):
            self.data_arrays[path] = val
            if self.merge_arrays:
                merged_array_name = Path(self.merged_array_name).with_suffix(self.file_suffix)
-               update_nested_dict(data, keys, f"./{merged_array_name}#{path}")
+               update_nested_dict(processed_data, keys, f"./{merged_array_name}#{path}")
            else:
-               update_nested_dict(data, keys, f"./{path}{self.file_suffix}")
-       return data
+               update_nested_dict(processed_data, keys, f"./{path}{self.file_suffix}")
+       return processed_data

    def save_merged_netcdf_arrays(self, path: Path, arrays: dict):
        for array_path, array in self.data_arrays.items():
9 changes: 9 additions & 0 deletions tests/data_handler/test_matplotlib_plot_saver.py
@@ -48,3 +48,12 @@ def test_matplotlib_nested_save(tmp_path, fig):

    file_data = json.loads((tmp_path / "data.json").read_text())
    assert file_data == {"q0": {"fig": "./q0.fig.png", "value": 42}}


def test_matplotlib_save_does_not_affect_data(fig):
    matplotlib_plot_saver = MatplotlibPlotSaver()
    data = {"a": 1, "b": 2, "c": fig}
    processed_data = matplotlib_plot_saver.process(data)

    assert data == {"a": 1, "b": 2, "c": fig}
    assert processed_data == {"a": 1, "b": 2, "c": "./c.png"}
53 changes: 50 additions & 3 deletions tests/data_handler/test_numpy_array_saver.py
@@ -1,15 +1,31 @@
import copy
import numpy as np

from qualang_tools.results.data_handler.data_processors import DEFAULT_DATA_PROCESSORS, NumpyArraySaver


def dicts_equal(d1, d2):
    if d1.keys() != d2.keys():
        return False
    for key in d1:
        if isinstance(d1[key], dict):
            if not dicts_equal(d1[key], d2[key]):
                return False
        elif isinstance(d1[key], np.ndarray):
            if not np.array_equal(d1[key], d2[key]):
                return False
        else:
            if d1[key] != d2[key]:
                return False
    return True
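
This helper is needed because plain == can raise for dicts holding numpy arrays: array equality is elementwise, so Python cannot reduce the comparison to a single bool. A minimal sketch (hypothetical values, not part of the commit):

    import numpy as np
    import pytest

    with pytest.raises(ValueError):
        # "The truth value of an array with more than one element is ambiguous"
        {"a": np.array([1, 2])} == {"a": np.array([1, 2])}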


def test_numpy_array_saver_process_merged():
    data = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6]), "c": 3}

    data_processor = NumpyArraySaver()

-   processed_data = data.copy()
-   processed_data = data_processor.process(processed_data)
+   processed_data = data_processor.process(data)

    assert processed_data == {
        "a": "./arrays.npz#a",
@@ -51,11 +67,42 @@ def test_numpy_array_saver_post_process_separate(tmp_path):
    data = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6]), "c": 3}

    data_processor = NumpyArraySaver(merge_arrays=False)
-   data_processor.process(data.copy())
+   data_processor.process(data)

    data_processor.post_process(data_folder=tmp_path)

    assert (tmp_path / "a.npy").exists()
    assert (tmp_path / "b.npy").exists()
    assert np.array_equal(np.load(tmp_path / "a.npy"), data["a"])
    assert np.array_equal(np.load(tmp_path / "b.npy"), data["b"])


def test_numpy_array_saver_nested_no_merge(tmp_path):
    data = {"q0": {"a": np.array([1, 2, 3]), "b": 3}, "c": np.array([4, 5, 6])}

    data_processor = NumpyArraySaver(merge_arrays=False)
    processed_data = data_processor.process(data)
    assert processed_data == {
        "q0": {"a": "./q0.a.npy", "b": 3},
        "c": "./c.npy",
    }

    data_processor.post_process(data_folder=tmp_path)

    assert (tmp_path / "q0.a.npy").exists()
    assert not (tmp_path / "q0.b.npy").exists()
    assert (tmp_path / "c.npy").exists()

    assert np.array_equal(np.load(tmp_path / "q0.a.npy"), data["q0"]["a"])
    assert np.array_equal(np.load(tmp_path / "c.npy"), data["c"])


def test_numpy_array_saver_process_does_not_affect_data():
    data = {"q0": {"a": np.array([1, 2, 3]), "b": 3}, "c": np.array([4, 5, 6])}
    deepcopied_data = copy.deepcopy(data)

    data_processor = NumpyArraySaver(merge_arrays=False)
    processed_data = data_processor.process(data)

    assert dicts_equal(data, deepcopied_data)
    assert not dicts_equal(processed_data, data)
83 changes: 79 additions & 4 deletions tests/data_handler/test_xarray_saver.py
@@ -1,5 +1,7 @@
import copy
import pytest
import sys
import numpy as np
import xarray as xr
from qualang_tools.results.data_handler.data_processors import XarraySaver


@@ -11,6 +13,27 @@ def module_installed(module_name):
    return True


def dicts_equal(d1, d2):
    if d1.keys() != d2.keys():
        return False
    for key in d1:
        if key not in d2:
            return False
        elif isinstance(d1[key], dict):
            if not dicts_equal(d1[key], d2[key]):
                return False
        elif isinstance(d1[key], np.ndarray):
            if not np.array_equal(d1[key], d2[key]):
                return False
        elif isinstance(d1[key], xr.Dataset):
            if not d1[key].identical(d2[key]):
                return False
        else:
            if not bool(d1[key] == d2[key]):
                return False
    return True
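
A note on the xr.Dataset branch: Dataset.identical is used rather than Dataset.equals because identical also compares dataset and variable attributes, not just variables and coordinate values.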


@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed")
def test_xarray_saver_no_xarrays():
    xarray_saver = XarraySaver()
@@ -44,7 +67,7 @@ def test_xarray_saver_merge_netcdf(tmp_path):
    data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": xr.Dataset()}

    xarray_saver = XarraySaver(merge_arrays=True, file_format="nc")
-   processed_data = xarray_saver.process(data.copy())
+   processed_data = xarray_saver.process(data)

    assert processed_data == {"a": 1, "b": 2, "c": "./xarrays.nc#c", "d": "./xarrays.nc#d"}

@@ -63,7 +86,7 @@ def test_xarray_saver_merge_hdf5(tmp_path):
    data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": xr.Dataset()}

    xarray_saver = XarraySaver(merge_arrays=True, file_format="h5")
-   processed_data = xarray_saver.process(data.copy())
+   processed_data = xarray_saver.process(data)

    assert processed_data == {"a": 1, "b": 2, "c": "./xarrays.h5#c", "d": "./xarrays.h5#d"}

@@ -82,7 +105,7 @@ def test_xarray_saver_no_merge_netcdf(tmp_path):
    data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": xr.Dataset()}

    xarray_saver = XarraySaver(merge_arrays=False)
-   processed_data = xarray_saver.process(data.copy())
+   processed_data = xarray_saver.process(data)

    assert processed_data == {"a": 1, "b": 2, "c": "./c.h5", "d": "./d.h5"}

@@ -93,3 +116,55 @@ def test_xarray_saver_no_merge_netcdf(tmp_path):

    xr.load_dataset(tmp_path / "c.h5")
    xr.load_dataset(tmp_path / "d.h5")


@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed")
def test_xarray_saver_no_merge_hdf5_nested(tmp_path):
    import xarray as xr

    data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": {"d1": xr.Dataset()}}

    xarray_saver = XarraySaver(merge_arrays=False, file_format="nc")
    processed_data = xarray_saver.process(data)

    assert processed_data == {"a": 1, "b": 2, "c": "./c.nc", "d": {"d1": "./d.d1.nc"}}

    xarray_saver.post_process(data_folder=tmp_path)

    assert (tmp_path / "c.nc").exists()
    assert (tmp_path / "d.d1.nc").exists()

    xr.load_dataset(tmp_path / "c.nc")
    xr.load_dataset(tmp_path / "d.d1.nc")


@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed")
def test_xarray_saver_merge_hdf5_nested(tmp_path):
    import xarray as xr

    data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": {"d1": xr.Dataset()}}

    xarray_saver = XarraySaver(merge_arrays=True, file_format="nc")
    processed_data = xarray_saver.process(data)

    assert processed_data == {"a": 1, "b": 2, "c": "./xarrays.nc#c", "d": {"d1": "./xarrays.nc#d.d1"}}

    xarray_saver.post_process(data_folder=tmp_path)

    assert (tmp_path / "xarrays.nc").exists()

    xr.load_dataset(tmp_path / "xarrays.nc", group="c")
    xr.load_dataset(tmp_path / "xarrays.nc", group="d.d1")


@pytest.mark.skipif(not module_installed("xarray"), reason="xarray not installed")
def test_xarray_saver_does_not_affect_data():
    import xarray as xr

    data = {"a": 1, "b": 2, "c": xr.Dataset(), "d": {"d1": xr.Dataset()}}
    deepcopied_data = copy.deepcopy(data)

    xarray_saver = XarraySaver(merge_arrays=False, file_format="nc")
    processed_data = xarray_saver.process(data)

    assert dicts_equal(data, deepcopied_data)
    assert not dicts_equal(processed_data, data)
