From 82274c8c41ced53fef8d2c1b0a9cc64452393b8a Mon Sep 17 00:00:00 2001 From: Joseph Gonzalez Date: Sat, 11 Jan 2025 14:30:25 -0400 Subject: [PATCH] Fix the push method when the limit parameter is bigger than the chunksize --- xarray/core/dask_array_ops.py | 49 ++++++++--------------------- xarray/tests/test_duck_array_ops.py | 32 +++++++++---------- 2 files changed, 29 insertions(+), 52 deletions(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 2dca38538e1..83ac787b6d1 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -92,31 +92,6 @@ def _dtype_push(a, axis, dtype=None): return _push(a, axis=axis) -def _reset_cumsum(a, axis, dtype=None): - import numpy as np - - cumsum = np.cumsum(a, axis=axis) - reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis) - return cumsum - reset_points - - -def _last_reset_cumsum(a, axis, keepdims=None): - import numpy as np - - # Take the last cumulative sum taking into account the reset - # This is useful for blelloch method - return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1]) - - -def _combine_reset_cumsum(a, b, axis): - import numpy as np - - # It is going to sum the previous result until the first - # non nan value - bitmask = np.cumprod(b != 0, axis=axis) - return np.where(bitmask, b + a, b) - - def push(array, n, axis, method="blelloch"): """ Dask-aware bottleneck.push @@ -145,16 +120,18 @@ def push(array, n, axis, method="blelloch"): ) if n is not None and 0 < n < array.shape[axis] - 1: - valid_positions = da.reductions.cumreduction( - func=_reset_cumsum, - binop=partial(_combine_reset_cumsum, axis=axis), - ident=0, - x=da.isnan(array, dtype=int), - axis=axis, - dtype=int, - method=method, - preop=_last_reset_cumsum, - ) - pushed_array = da.where(valid_positions <= n, pushed_array, np.nan) + # The idea is to calculate a cumulative sum of a bitmask + # created from the isnan method, but every time a False is found the sum + # must be restarted, and the final result indicates the amount of contiguous + # nan values found in the original array on every position + nan_bitmask = da.isnan(array, dtype=int) + cumsum_nan = nan_bitmask.cumsum(axis=axis, method=method) + valid_positions = da.where(nan_bitmask == 0, cumsum_nan, np.nan) + valid_positions = push(valid_positions, None, axis, method=method) + # All the NaNs at the beginning are converted to 0 + valid_positions = da.nan_to_num(valid_positions) + valid_positions = cumsum_nan - valid_positions + valid_positions = valid_positions <= n + pushed_array = da.where(valid_positions, pushed_array, np.nan) return pushed_array diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index e1306964757..a3c05cd3db0 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1009,31 +1009,31 @@ def test_least_squares(use_dask, skipna): @requires_dask @requires_bottleneck @pytest.mark.parametrize("method", ["sequential", "blelloch"]) -def test_push_dask(method): +@pytest.mark.parametrize( + "arr", [ + [np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6], + [ + np.nan, np.nan, np.nan, 2, np.nan, np.nan, np.nan, 9, np.nan, + np.nan, np.nan, np.nan + ] + ] +) +def test_push_dask(method, arr): import bottleneck - import dask.array + import dask.array as da - array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6]) + arr = np.array(arr) + chunks = list(range(1, 11)) + [(1, 2, 3, 2, 2, 1, 1)] for n in [None, 1, 2, 3, 4, 5, 11]: - expected = bottleneck.push(array, axis=0, n=n) - for c in range(1, 11): + expected = bottleneck.push(arr, axis=0, n=n) + for c in chunks: with raise_if_dask_computes(): actual = push( - dask.array.from_array(array, chunks=c), axis=0, n=n, method=method + da.from_array(arr, chunks=c), axis=0, n=n, method=method ) np.testing.assert_equal(actual, expected) - # some chunks of size-1 with NaN - with raise_if_dask_computes(): - actual = push( - dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), - axis=0, - n=n, - method=method, - ) - np.testing.assert_equal(actual, expected) - def test_extension_array_equality(categorical1, int1): int_duck_array = PandasExtensionArray(int1)