diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index a03652de2496..43d4a6c20e94 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -33,6 +33,7 @@ from typing import Callable from typing import Iterable from typing import List +from typing import Optional from typing import Tuple from typing import TypeVar from typing import Union @@ -78,6 +79,7 @@ from apache_beam.utils import windowed_value from apache_beam.utils.annotations import deprecated from apache_beam.utils.sharded_key import ShardedKey +from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: from apache_beam.runners.pipeline_context import PipelineContext @@ -953,6 +955,10 @@ def restore_timestamps(element): window.GlobalWindows.windowed_value((key, value), timestamp) for (value, timestamp) in values ] + + ungrouped = pcoll | Map(reify_timestamps).with_input_types( + Tuple[K, V]).with_output_types( + Tuple[K, Tuple[V, Optional[Timestamp]]]) else: # typing: All conditional function variants must have identical signatures @@ -966,7 +972,9 @@ def restore_timestamps(element): key, windowed_values = element return [wv.with_value((key, wv.value)) for wv in windowed_values] - ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any) + # TODO(https://github.com/apache/beam/issues/33356): Support reshuffling + # unpicklable objects with a non-global window setting. + ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any) # TODO(https://github.com/apache/beam/issues/19785) Using global window as # one of the standard window. This is to mitigate the Dataflow Java Runner @@ -1018,7 +1026,8 @@ def expand(self, pcoll): pcoll | 'AddRandomKeys' >> Map(lambda t: (random.randrange(0, self.num_buckets), t) ).with_input_types(T).with_output_types(Tuple[int, T]) - | ReshufflePerKey() + | ReshufflePerKey().with_input_types(Tuple[int, T]).with_output_types( + Tuple[int, T]) | 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types( Tuple[int, T]).with_output_types(T)) diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index d86509c7dde3..7f166f78ef0a 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1010,6 +1010,45 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): equal_to(expected_data), label="formatted_after_reshuffle") + def test_reshuffle_unpicklable_in_global_window(self): + global _Unpicklable + + class _Unpicklable(object): + def __init__(self, value): + self.value = value + + def __getstate__(self): + raise NotImplementedError() + + def __setstate__(self, state): + raise NotImplementedError() + + class _UnpicklableCoder(beam.coders.Coder): + def encode(self, value): + return str(value.value).encode() + + def decode(self, encoded): + return _Unpicklable(int(encoded.decode())) + + def to_type_hint(self): + return _Unpicklable + + def is_deterministic(self): + return True + + beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) + + with TestPipeline() as pipeline: + data = [_Unpicklable(i) for i in range(5)] + expected_data = [0, 10, 20, 30, 40] + result = ( + pipeline + | beam.Create(data) + | beam.WindowInto(GlobalWindows()) + | beam.Reshuffle() + | beam.Map(lambda u: u.value * 10)) + assert_that(result, equal_to(expected_data)) + class WithKeysTest(unittest.TestCase): def setUp(self):