Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

slice_dataset now casts --sel "start/stop" flags to string (to agree with API documentation) and "step" to int (because step must be an int in xarray). #220

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions scripts/slice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@
help=(
'Selection criteria, to pass to xarray.Dataset.sel. Passed as'
' key=value pairs, with key = VARNAME_{start,stop,step,list}. '
'If key ends with start, stop, or step, the value should be strings '
'(defaulting to None). If key ends with "list", the value should be '
'If key ends with start, stop, or step, the values are used in a slice '
'as slice(str(start), str(stop), int(step)). start/stop/step default to'
' None. If key ends with "list", the value should be '
'a list of "+" delimited ints/floats/strings.'
),
)
Expand All @@ -84,8 +85,9 @@
help=(
'Selection criteria, to pass to xarray.Dataset.drop_sel. Passed as'
' key=value pairs, with key = VARNAME_{start,stop,step,list}. '
'If key ends with start, stop, or step, the value should be strings '
'(defaulting to None). If key ends with "list", the value should be '
'If key ends with start, stop, or step, the values are used in a slice '
'as slice(str(start), str(stop), int(step)). start/stop/step default to'
' None. If key ends with "list", the value should be '
'a list of "+" delimited ints/floats/strings.'
),
)
Expand Down Expand Up @@ -138,9 +140,15 @@

def _get_selections(
flag_values: dict[str, flag_utils.DimValueType],
is_sel_or_dropsel: bool,
) -> list[dict[str, t.Union[str, int, list[int], slice]]]:
"""Gets parts used to select based on flags."""

def maybe_tostr(v):
if is_sel_or_dropsel:
return str(v)
return v

list_selectors = {}
value_selectors = {}
for k, v in flag_values.items():
Expand All @@ -157,18 +165,20 @@ def _get_selections(
if '++' in v:
raise ValueError(f'Found ambiguous "++" in {dim=} flag value {v}')
list_selectors[dim] = [
flag_utils.get_dim_value(v_i) for v_i in v.split('+')
maybe_tostr(flag_utils.get_dim_value(v_i)) for v_i in v.split('+')
]
else: # Else handle non-list types
v = flag_utils.get_dim_value(v)
if dim not in value_selectors:
value_selectors[dim] = [None, None, None]
if placement == 'start':
value_selectors[dim][0] = v
value_selectors[dim][0] = maybe_tostr(v)
elif placement == 'stop':
value_selectors[dim][1] = v
else:
value_selectors[dim][2] = v
value_selectors[dim][1] = maybe_tostr(v)
else: # Else 'step'
# In Xarray, step must be an int.
# https://github.com/pydata/xarray/issues/5228
value_selectors[dim][2] = int(v)

selections = []
for dim, selector in list_selectors.items():
Expand All @@ -191,13 +201,13 @@ def main(argv: abc.Sequence[str]) -> None:
ds = ds[KEEP_VARIABLES.value]
input_chunks = {k: v for k, v in input_chunks.items() if k in ds.dims}

for selection in _get_selections(ISEL.value):
for selection in _get_selections(ISEL.value, is_sel_or_dropsel=False):
ds = ds.isel(selection)
for selection in _get_selections(SEL.value):
for selection in _get_selections(SEL.value, is_sel_or_dropsel=True):
ds = ds.sel(selection)
for selection in _get_selections(DROP_ISEL.value):
for selection in _get_selections(DROP_ISEL.value, is_sel_or_dropsel=False):
ds = ds.drop_isel(selection)
for selection in _get_selections(DROP_SEL.value):
for selection in _get_selections(DROP_SEL.value, is_sel_or_dropsel=True):
ds = ds.drop_sel(selection)

template = xbeam.make_template(ds)
Expand Down
27 changes: 25 additions & 2 deletions scripts/slice_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,39 @@ def test_valid_selections(self):
flag_values={
'A_start': '1 day',
'A_stop': '10 days',
'A_step': '2 days',
'A_step': 2,
'B_stop': 2.2,
'C_step': 3,
'D_list': 'planes+trains+automobiles',
},
is_sel_or_dropsel=False,
)
expected_sel = [
{'A': slice('1 day', '10 days', '2 days')},
{'A': slice('1 day', '10 days', 2)},
{'B': slice(None, 2.2, None)},
{'C': slice(None, None, 3)},
{'D': ['planes', 'trains', 'automobiles']},
]
self.assertCountEqual(expected_sel, sel)

def test_valid_selections_is_sel_or_dropsel(self):
# Checks _get_selections with is_sel_or_dropsel=True, the mode main() uses
# for the SEL and DROP_SEL flags: start/stop values are expected back as
# strings, while step is expected back as an int.
sel = slice_dataset._get_selections(
flag_values={
'A_start': '1 day',
'A_stop': '10 days',
'A_step': 2,
'B_stop': 2020,  # As in the year 2020 for a date
'D_list': 'planes+trains+automobiles',
},
is_sel_or_dropsel=True,
)
# The int 2020 passed for B_stop must come back as the string '2020';
# A_step stays the int 2; the "+"-delimited D_list becomes a list of strings.
expected_sel = [
{'A': slice('1 day', '10 days', 2)},
{'B': slice(None, '2020', None)},
{'D': ['planes', 'trains', 'automobiles']},
]
# assertCountEqual compares the selections without regard to their order.
self.assertCountEqual(expected_sel, sel)

def test_valid_index_selections(self):
isel = slice_dataset._get_selections(
flag_values={
Expand All @@ -56,6 +75,7 @@ def test_valid_index_selections(self):
'Z_start': 1,
'W_step': 2,
},
is_sel_or_dropsel=False,
)
expected_isel = [
{'A': [9, -1, 0]},
Expand All @@ -76,6 +96,7 @@ def test_invalid_placement_raises(self):
'X_stop': 10,
'X_bad': 2,
},
is_sel_or_dropsel=False,
)

with self.subTest('Not ending in (start|stop|step|list) raises 2'):
Expand All @@ -86,6 +107,7 @@ def test_invalid_placement_raises(self):
'X_stop': 10,
'X_step_and_more': 2,
},
is_sel_or_dropsel=False,
)

with self.subTest('Not ending in (start|stop|step|list) raises 2'):
Expand All @@ -96,6 +118,7 @@ def test_invalid_placement_raises(self):
'X_stop': 10,
'X_step_': 2,
},
is_sel_or_dropsel=False,
)


Expand Down
Loading