Use lithops to parallelize open_mfdataset #9932

Draft: wants to merge 17 commits into base: main
Changes from all commits
1 change: 1 addition & 0 deletions ci/requirements/all-but-numba.yml
@@ -22,6 +22,7 @@ dependencies:
- hypothesis
- iris
- lxml # Optional dep of pydap
- lithops
- matplotlib-base
- nc-time-axis
- netcdf4
1 change: 1 addition & 0 deletions ci/requirements/environment-3.13.yml
@@ -19,6 +19,7 @@ dependencies:
- hdf5
- hypothesis
- iris
- lithops
- lxml # Optional dep of pydap
- matplotlib-base
- nc-time-axis
1 change: 1 addition & 0 deletions ci/requirements/environment-windows-3.13.yml
@@ -18,6 +18,7 @@ dependencies:
- hypothesis
- iris
- lxml # Optional dep of pydap
- lithops
- matplotlib-base
- nc-time-axis
- netcdf4
1 change: 1 addition & 0 deletions ci/requirements/environment-windows.yml
@@ -18,6 +18,7 @@ dependencies:
- hypothesis
- iris
- lxml # Optional dep of pydap
- lithops
- matplotlib-base
- nc-time-axis
- netcdf4
1 change: 1 addition & 0 deletions ci/requirements/environment.yml
@@ -19,6 +19,7 @@ dependencies:
- hdf5
- hypothesis
- iris
- lithops
- lxml # Optional dep of pydap
- matplotlib-base
- nc-time-axis
1 change: 1 addition & 0 deletions ci/requirements/min-all-deps.yml
@@ -30,6 +30,7 @@ dependencies:
- hypothesis
- iris=3.7
- lxml=4.9 # Optional dep of pydap
- lithops=3.5.1
- matplotlib-base=3.8
- nc-time-axis=1.4
# netcdf follows a 1.major.minor[.patch] convention
68 changes: 57 additions & 11 deletions xarray/backends/api.py
@@ -1365,6 +1365,9 @@ def open_groups(
return groups


import warnings


def open_mfdataset(
paths: str
| os.PathLike
@@ -1480,9 +1483,10 @@ def open_mfdataset(
those corresponding to other dimensions.
* list of str: The listed coordinate variables will be concatenated,
in addition the "minimal" coordinates.
parallel : bool, default: False
If True, the open and preprocess steps of this function will be
performed in parallel using ``dask.delayed``. Default is False.
parallel : 'dask', 'lithops', or False
Specify whether the open and preprocess steps of this function will be
performed in parallel using ``dask.delayed``, in parallel using ``lithops.map``, or in serial.
Default is False. Passing True is now a deprecated alias for passing 'dask'.
join : {"outer", "inner", "left", "right", "exact", "override"}, default: "outer"
String indicating how to combine differing indexes
(excluding concat_dim) in objects
@@ -1596,27 +1600,67 @@ def open_mfdataset(

open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs)

if parallel:
if parallel is True:
warnings.warn(
"Passing ``parallel=True`` is deprecated, instead please pass ``parallel='dask'`` explicitly",
PendingDeprecationWarning,
stacklevel=2,
)
parallel = "dask"

if parallel == "dask":
import dask

# wrap the open_dataset, getattr, and preprocess with delayed
open_ = dask.delayed(open_dataset)
getattr_ = dask.delayed(getattr)
if preprocess is not None:
preprocess = dask.delayed(preprocess)
else:
elif parallel == "lithops":
import lithops

# TODO use RetryingFunctionExecutor instead?
fn_exec = lithops.FunctionExecutor()

# lithops doesn't have a delayed primitive
open_ = open_dataset
# TODO I don't know how best to chain this with the getattr
# getattr_ = getattr
elif parallel is False:
open_ = open_dataset
getattr_ = getattr
else:
raise ValueError(
f"{parallel} is an invalid option for the keyword argument ``parallel``"
)

datasets = [open_(p, **open_kwargs) for p in paths1d]
closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]
if parallel == "dask":
datasets = [open_(p, **open_kwargs) for p in paths1d]
closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]

if parallel:
# calling compute here will return the datasets/file_objs lists,
# the underlying datasets will still be stored as dask arrays
datasets, closers = dask.compute(datasets, closers)
elif parallel == "lithops":

def generate_lazy_ds(path):
# allows passing the open_dataset function to lithops without evaluating it
ds = open_(path, **kwargs)
return ds
Comment on lines +1648 to +1651 (Collaborator):

> this looks potentially like a functools.partial with **kwargs?


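If the reviewer's suggestion were taken up, the wrapper could indeed become a functools.partial. A minimal, hypothetical sketch (not part of this diff; the open_kwargs values are placeholders):

```python
from functools import partial

import xarray as xr

# Bind open_dataset's keyword arguments once, so that each worker
# invocation only needs the file path (same role as generate_lazy_ds).
open_kwargs = {"engine": "netcdf4", "chunks": {}}  # placeholder kwargs
open_ = partial(xr.open_dataset, **open_kwargs)

# lithops could then map the partial directly over the paths:
# futures = fn_exec.map(open_, paths1d)
```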
futures = fn_exec.map(generate_lazy_ds, paths1d)

# wait for all the serverless workers to finish, and send their resulting lazy datasets back to the client
# TODO do we need download_results?
completed_futures, _ = fn_exec.wait(futures, download_results=True)
datasets = completed_futures.get_result()
Comment on lines +1653 to +1658 (@keewis, Collaborator, Jan 8, 2025):

> I wonder if we can find an abstraction that works for both this (which is kinda like concurrent.futures' pool executors) and dask.
>
> For example, maybe we can use functools.partial to mimic dask.delayed. The result would be a bunch of function objects without parameters, which would then be evaluated in fn_exec.map using operator.call.
>
> (But I guess if we refactor the dask code as well we don't really need that idea)

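To make that idea concrete, here is a rough, hypothetical sketch (not in the diff): build zero-argument callables with functools.partial and evaluate them with operator.call (available since Python 3.11), which works identically whether the evaluator is fn_exec.map, a concurrent.futures pool, or a plain map for the serial case:

```python
import operator
from functools import partial


def open_stub(path, **kwargs):
    # Stand-in for open_dataset in this sketch.
    return f"dataset from {path} opened with {kwargs}"


paths1d = ["a.nc", "b.nc"]  # placeholder paths
open_kwargs = {"chunks": {}}

# Parameterless callables, playing the role of dask.delayed objects.
tasks = [partial(open_stub, p, **open_kwargs) for p in paths1d]

# Any executor can evaluate them uniformly, e.g.
#   futures = fn_exec.map(operator.call, tasks)   # lithops
# or, for the serial path:
datasets = list(map(operator.call, tasks))
```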
elif parallel is False:
datasets = [open_(p, **open_kwargs) for p in paths1d]
closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]

# Combine all datasets, closing them in case of a ValueError
try:
@@ -1654,7 +1698,9 @@ def open_mfdataset(
ds.close()
raise

combined.set_close(partial(_multi_file_closer, closers))
# TODO remove if once closers added above
if parallel != "lithops":
combined.set_close(partial(_multi_file_closer, closers))

# read global attributes from the attrs_file or from the first dataset
if attrs_file is not None:
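For reference, with the change above a call site would look roughly like this (a sketch only; the file paths and concat_dim are placeholders, and lithops falls back to its local executor when no configuration is supplied):

```python
import xarray as xr

# Hypothetical usage of the new keyword: let lithops workers perform the
# open/preprocess steps instead of dask.delayed.
ds = xr.open_mfdataset(
    ["file1.nc", "file2.nc"],  # placeholder paths
    combine="nested",
    concat_dim="time",
    parallel="lithops",
)
```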
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
@@ -118,6 +118,7 @@ def _importorskip(
category=DeprecationWarning,
)
has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
has_lithops, requires_lithops = _importorskip("lithops")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
54 changes: 54 additions & 0 deletions xarray/tests/test_backends.py
@@ -80,6 +80,7 @@
requires_h5netcdf_1_4_0_or_above,
requires_h5netcdf_ros3,
requires_iris,
requires_lithops,
requires_netcdf,
requires_netCDF4,
requires_netCDF4_1_6_2_or_above,
@@ -4410,6 +4411,59 @@
assert_identical(original, actual)


@requires_netCDF4
class TestParallel:
def test_validate_parallel_kwarg(self) -> None:
original = Dataset({"foo": ("x", np.random.randn(10))})
datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
with create_tmp_file() as tmp1:
with create_tmp_file() as tmp2:
save_mfdataset(datasets, [tmp1, tmp2])

with pytest.raises(ValueError, match="garbage is an invalid option"):
open_mfdataset(
[tmp1, tmp2],
concat_dim="x",
combine="nested",
parallel="garbage",
)

def test_deprecation_warning(self) -> None:
original = Dataset({"foo": ("x", np.random.randn(10))})
datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
with create_tmp_file() as tmp1:
with create_tmp_file() as tmp2:
save_mfdataset(datasets, [tmp1, tmp2])

with pytest.warns(
PendingDeprecationWarning,
match="please pass ``parallel='dask'`` explicitly",
):
open_mfdataset(

[Check failure on line 4442 in xarray/tests/test_backends.py (GitHub Actions / ubuntu-latest py3.11 all-but-dask): TestParallel.test_deprecation_warning ModuleNotFoundError: No module named 'dask']
[tmp1, tmp2],
concat_dim="x",
combine="nested",
parallel=True,
)

@requires_lithops
def test_lithops_parallel(self) -> None:
# default configuration of lithops will use local executor

original = Dataset({"foo": ("x", np.random.randn(10))})
datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
with create_tmp_file() as tmp1:
with create_tmp_file() as tmp2:
save_mfdataset(datasets, [tmp1, tmp2])
with open_mfdataset(
[tmp1, tmp2],
concat_dim="x",
combine="nested",
parallel="lithops",
) as actual:
assert_identical(actual, original)


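As a side note on the "default configuration of lithops will use local executor" comment above: the localhost backend can also be selected explicitly via an in-code config dict. A sketch based on lithops' documented configuration, not part of this diff:

```python
import lithops

# Explicit localhost configuration; with no config file present, lithops
# falls back to this same local executor, which is what the test relies on.
config = {"lithops": {"backend": "localhost", "storage": "localhost"}}
fn_exec = lithops.FunctionExecutor(config=config)

futures = fn_exec.map(lambda x: x * 2, [1, 2, 3])
print(fn_exec.get_result(futures))  # [2, 4, 6]
```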
@requires_netCDF4
@requires_dask
def test_open_mfdataset_can_open_path_objects() -> None: