Skip to content

Commit

Permalink
Fix the push method when the limit parameter is bigger than the chunk size
Browse files Browse the repository at this point in the history
  • Loading branch information
josephnowak committed Jan 11, 2025
1 parent 5279bd1 commit 82274c8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 52 deletions.
49 changes: 13 additions & 36 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,31 +92,6 @@ def _dtype_push(a, axis, dtype=None):
return _push(a, axis=axis)


def _reset_cumsum(a, axis, dtype=None):
    """Cumulative sum of ``a`` along ``axis`` that drops back to zero at zeros.

    The plain running total is computed first; at every position where ``a``
    is zero that total is recorded, the running maximum of those recorded
    totals is taken along ``axis``, and the result is subtracted from the
    running total.

    NOTE(review): using a running maximum to pick "the most recent" recorded
    total presumably relies on ``a`` being non-negative (e.g. a 0/1 mask) --
    confirm against callers.

    ``dtype`` is accepted but not used by the body.
    """
    import numpy as np

    running_total = np.cumsum(a, axis=axis)
    # Running total sampled where ``a`` is zero, 0 everywhere else.
    totals_at_zeros = np.where(a == 0, running_total, 0)
    # Carry the largest sampled total seen so far forward along the axis.
    baseline = np.maximum.accumulate(totals_at_zeros, axis=axis)
    return running_total - baseline


def _last_reset_cumsum(a, axis, keepdims=None):
    """Return the final entry along ``axis`` of ``_reset_cumsum(a, axis)``.

    The original comment notes this is useful for the "blelloch" method: it
    yields one summary value per block, taking the resets into account.

    ``keepdims`` is accepted but not used by the body.
    """
    import numpy as np

    restarted = _reset_cumsum(a, axis=axis)
    # ``indices`` is a one-element list (not a scalar), so ``axis`` is kept
    # in the output with length 1.
    return np.take(restarted, indices=[-1], axis=axis)


def _combine_reset_cumsum(a, b, axis):
    """Add ``a`` to the leading run of non-zero entries of ``b`` along ``axis``.

    Entries of ``b`` before its first zero along ``axis`` become ``b + a``;
    from the first zero onwards ``b`` is returned unchanged. The original
    comment describes this as summing the previous result until the first
    non-NaN value (a zero in ``b`` marking such a value).
    """
    import numpy as np

    # Cumulative product of the "non-zero" flags: stays truthy only while
    # every entry seen so far along ``axis`` has been non-zero.
    still_in_leading_run = np.cumprod(b != 0, axis=axis)
    carried = b + a
    return np.where(still_in_leading_run, carried, b)


def push(array, n, axis, method="blelloch"):
"""
Dask-aware bottleneck.push
Expand Down Expand Up @@ -145,16 +120,18 @@ def push(array, n, axis, method="blelloch"):
)

if n is not None and 0 < n < array.shape[axis] - 1:
valid_positions = da.reductions.cumreduction(
func=_reset_cumsum,
binop=partial(_combine_reset_cumsum, axis=axis),
ident=0,
x=da.isnan(array, dtype=int),
axis=axis,
dtype=int,
method=method,
preop=_last_reset_cumsum,
)
pushed_array = da.where(valid_positions <= n, pushed_array, np.nan)
# The idea is to calculate a cumulative sum of a bitmask
# created from the isnan method, but every time a False is found the sum
# must be restarted, and the final result indicates the amount of contiguous
# nan values found in the original array on every position
nan_bitmask = da.isnan(array, dtype=int)
cumsum_nan = nan_bitmask.cumsum(axis=axis, method=method)
valid_positions = da.where(nan_bitmask == 0, cumsum_nan, np.nan)
valid_positions = push(valid_positions, None, axis, method=method)
# All the NaNs at the beginning are converted to 0
valid_positions = da.nan_to_num(valid_positions)
valid_positions = cumsum_nan - valid_positions
valid_positions = valid_positions <= n
pushed_array = da.where(valid_positions, pushed_array, np.nan)

return pushed_array
32 changes: 16 additions & 16 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,31 +1009,31 @@ def test_least_squares(use_dask, skipna):
@requires_dask
@requires_bottleneck
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
def test_push_dask(method):
@pytest.mark.parametrize(
"arr", [
[np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6],
[
np.nan, np.nan, np.nan, 2, np.nan, np.nan, np.nan, 9, np.nan,
np.nan, np.nan, np.nan
]
]
)
def test_push_dask(method, arr):
import bottleneck
import dask.array
import dask.array as da

array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
arr = np.array(arr)
chunks = list(range(1, 11)) + [(1, 2, 3, 2, 2, 1, 1)]

for n in [None, 1, 2, 3, 4, 5, 11]:
expected = bottleneck.push(array, axis=0, n=n)
for c in range(1, 11):
expected = bottleneck.push(arr, axis=0, n=n)
for c in chunks:
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=c), axis=0, n=n, method=method
da.from_array(arr, chunks=c), axis=0, n=n, method=method
)
np.testing.assert_equal(actual, expected)

# some chunks of size-1 with NaN
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
axis=0,
n=n,
method=method,
)
np.testing.assert_equal(actual, expected)


def test_extension_array_equality(categorical1, int1):
int_duck_array = PandasExtensionArray(int1)
Expand Down

0 comments on commit 82274c8

Please sign in to comment.