From 2f849ab7bd1dbc83797bf803e2e4373c1790e92c Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Wed, 18 Dec 2024 11:43:31 -0800 Subject: [PATCH] slice_dataset now casts --sel "start/stop" flags to string (to agree with API documentation) and "step" to int (because step must be an int in xarray). PiperOrigin-RevId: 707620263 --- scripts/slice_dataset.py | 36 ++++++++++++++++++++++------------- scripts/slice_dataset_test.py | 27 ++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/scripts/slice_dataset.py b/scripts/slice_dataset.py index 88d71f9..3f848d1 100644 --- a/scripts/slice_dataset.py +++ b/scripts/slice_dataset.py @@ -60,8 +60,9 @@ help=( 'Selection criteria, to pass to xarray.Dataset.sel. Passed as' ' key=value pairs, with key = VARNAME_{start,stop,step,list}. ' - 'If key ends with start, stop, or step, the value should be strings ' - '(defaulting to None). If key ends with "list", the value should be ' + 'If key ends with start, stop, or step, the values are used in a slice ' + 'as slice(str(start), str(stop), int(step)). start/stop/step default to' + ' None. If key ends with "list", the value should be ' 'a list of "+" delimited ints/floats/strings.' ), ) @@ -84,8 +85,9 @@ help=( 'Selection criteria, to pass to xarray.Dataset.drop_sel. Passed as' ' key=value pairs, with key = VARNAME_{start,stop,step,list}. ' - 'If key ends with start, stop, or step, the value should be strings ' - '(defaulting to None). If key ends with "list", the value should be ' + 'If key ends with start, stop, or step, the values are used in a slice ' + 'as slice(str(start), str(stop), int(step)). start/stop/step default to' + ' None. If key ends with "list", the value should be ' 'a list of "+" delimited ints/floats/strings.' ), ) @@ -138,9 +140,15 @@ def _get_selections( flag_values: dict[str, flag_utils.DimValueType], + is_sel_or_dropsel: bool, ) -> list[dict[str, t.Union[str, int, list[int], slice]]]: """Gets parts used to select based on flags.""" + def maybe_tostr(v): + if is_sel_or_dropsel: + return str(v) + return v + list_selectors = {} value_selectors = {} for k, v in flag_values.items(): @@ -157,18 +165,20 @@ def _get_selections( if '++' in v: raise ValueError(f'Found ambiguous "++" in {dim=} flag value {v}') list_selectors[dim] = [ - flag_utils.get_dim_value(v_i) for v_i in v.split('+') + maybe_tostr(flag_utils.get_dim_value(v_i)) for v_i in v.split('+') ] else: # Else handle non-list types v = flag_utils.get_dim_value(v) if dim not in value_selectors: value_selectors[dim] = [None, None, None] if placement == 'start': - value_selectors[dim][0] = v + value_selectors[dim][0] = maybe_tostr(v) elif placement == 'stop': - value_selectors[dim][1] = v - else: - value_selectors[dim][2] = v + value_selectors[dim][1] = maybe_tostr(v) + else: # Else 'step' + # In Xarray, step must be an int. + # https://github.com/pydata/xarray/issues/5228 + value_selectors[dim][2] = int(v) selections = [] for dim, selector in list_selectors.items(): @@ -191,13 +201,13 @@ def main(argv: abc.Sequence[str]) -> None: ds = ds[KEEP_VARIABLES.value] input_chunks = {k: v for k, v in input_chunks.items() if k in ds.dims} - for selection in _get_selections(ISEL.value): + for selection in _get_selections(ISEL.value, is_sel_or_dropsel=False): ds = ds.isel(selection) - for selection in _get_selections(SEL.value): + for selection in _get_selections(SEL.value, is_sel_or_dropsel=True): ds = ds.sel(selection) - for selection in _get_selections(DROP_ISEL.value): + for selection in _get_selections(DROP_ISEL.value, is_sel_or_dropsel=False): ds = ds.drop_isel(selection) - for selection in _get_selections(DROP_SEL.value): + for selection in _get_selections(DROP_SEL.value, is_sel_or_dropsel=True): ds = ds.drop_sel(selection) template = xbeam.make_template(ds) diff --git a/scripts/slice_dataset_test.py b/scripts/slice_dataset_test.py index f30adf2..b2e8fc7 100644 --- a/scripts/slice_dataset_test.py +++ b/scripts/slice_dataset_test.py @@ -30,20 +30,39 @@ def test_valid_selections(self): flag_values={ 'A_start': '1 day', 'A_stop': '10 days', - 'A_step': '2 days', + 'A_step': 2, 'B_stop': 2.2, 'C_step': 3, 'D_list': 'planes+trains+automobiles', }, + is_sel_or_dropsel=False, ) expected_sel = [ - {'A': slice('1 day', '10 days', '2 days')}, + {'A': slice('1 day', '10 days', 2)}, {'B': slice(None, 2.2, None)}, {'C': slice(None, None, 3)}, {'D': ['planes', 'trains', 'automobiles']}, ] self.assertCountEqual(expected_sel, sel) + def test_valid_selections_is_sel_or_dropsel(self): + sel = slice_dataset._get_selections( + flag_values={ + 'A_start': '1 day', + 'A_stop': '10 days', + 'A_step': 2, + 'B_stop': 2020, # As in the year 2020 for a date + 'D_list': 'planes+trains+automobiles', + }, + is_sel_or_dropsel=True, + ) + expected_sel = [ + {'A': slice('1 day', '10 days', 2)}, + {'B': slice(None, '2020', None)}, + {'D': ['planes', 'trains', 'automobiles']}, + ] + self.assertCountEqual(expected_sel, sel) + def test_valid_index_selections(self): isel = slice_dataset._get_selections( flag_values={ @@ -56,6 +75,7 @@ def test_valid_index_selections(self): 'Z_start': 1, 'W_step': 2, }, + is_sel_or_dropsel=False, ) expected_isel = [ {'A': [9, -1, 0]}, @@ -76,6 +96,7 @@ def test_invalid_placement_raises(self): 'X_stop': 10, 'X_bad': 2, }, + is_sel_or_dropsel=False, ) with self.subTest('Not ending in (start|stop|step|list) raises 2'): @@ -86,6 +107,7 @@ def test_invalid_placement_raises(self): 'X_stop': 10, 'X_step_and_more': 2, }, + is_sel_or_dropsel=False, ) with self.subTest('Not ending in (start|stop|step|list) raises 2'): @@ -96,6 +118,7 @@ def test_invalid_placement_raises(self): 'X_stop': 10, 'X_step_': 2, }, + is_sel_or_dropsel=False, )