diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index a0c55da81800..724f268a8312 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -1402,6 +1402,14 @@ def __hash__(self): return hash( (self.wrapped_value_coder, self.timestamp_coder, self.window_coder)) + @classmethod + def from_type_hint(cls, typehint, registry): + # type: (Any, CoderRegistry) -> WindowedValueCoder + # Ideally this'd take two parameters so that one could hint at + # the window type as well instead of falling back to the + # pickle coders. + return cls(registry.get_coder(typehint.inner_type)) + Coder.register_structured_urn( common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder) diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 1667cb7a916a..892f508d0136 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -94,6 +94,8 @@ def register_standard_coders(self, fallback_coder): self._register_coder_internal(str, coders.StrUtf8Coder) self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder) self._register_coder_internal(typehints.DictConstraint, coders.MapCoder) + self._register_coder_internal( + typehints.WindowedTypeConstraint, coders.WindowedValueCoder) # Default fallback coders applied in that order until the first matching # coder found. default_fallback_coders = [ diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 43d4a6c20e94..c9fd2c76b0db 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -74,6 +74,7 @@ from apache_beam.transforms.window import TimestampedValue from apache_beam.typehints import trivial_inference from apache_beam.typehints.decorators import get_signature +from apache_beam.typehints.native_type_compatibility import TypedWindowedValue from apache_beam.typehints.sharded_key_type import ShardedKeyType from apache_beam.utils import shared from apache_beam.utils import windowed_value @@ -972,9 +973,8 @@ def restore_timestamps(element): key, windowed_values = element return [wv.with_value((key, wv.value)) for wv in windowed_values] - # TODO(https://github.com/apache/beam/issues/33356): Support reshuffling - # unpicklable objects with a non-global window setting. - ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any) + ungrouped = pcoll | Map(reify_timestamps).with_input_types( + Tuple[K, V]).with_output_types(Tuple[K, TypedWindowedValue[V]]) # TODO(https://github.com/apache/beam/issues/19785) Using global window as # one of the standard window. This is to mitigate the Dataflow Java Runner diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 7f166f78ef0a..db73310dfe25 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1010,32 +1010,33 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): equal_to(expected_data), label="formatted_after_reshuffle") - def test_reshuffle_unpicklable_in_global_window(self): - global _Unpicklable + global _Unpicklable + global _UnpicklableCoder - class _Unpicklable(object): - def __init__(self, value): - self.value = value + class _Unpicklable(object): + def __init__(self, value): + self.value = value - def __getstate__(self): - raise NotImplementedError() + def __getstate__(self): + raise NotImplementedError() - def __setstate__(self, state): - raise NotImplementedError() + def __setstate__(self, state): + raise NotImplementedError() - class _UnpicklableCoder(beam.coders.Coder): - def encode(self, value): - return str(value.value).encode() + class _UnpicklableCoder(beam.coders.Coder): + def encode(self, value): + return str(value.value).encode() - def decode(self, encoded): - return _Unpicklable(int(encoded.decode())) + def decode(self, encoded): + return _Unpicklable(int(encoded.decode())) - def to_type_hint(self): - return _Unpicklable + def to_type_hint(self): + return _Unpicklable - def is_deterministic(self): - return True + def is_deterministic(self): + return True + def test_reshuffle_unpicklable_in_global_window(self): beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) with TestPipeline() as pipeline: @@ -1049,6 +1050,20 @@ def is_deterministic(self): | beam.Map(lambda u: u.value * 10)) assert_that(result, equal_to(expected_data)) + def test_reshuffle_unpicklable_in_non_global_window(self): + beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) + + with TestPipeline() as pipeline: + data = [_Unpicklable(i) for i in range(5)] + expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40] + result = ( + pipeline + | beam.Create(data) + | beam.WindowInto(window.SlidingWindows(size=3, period=1)) + | beam.Reshuffle() + | beam.Map(lambda u: u.value * 10)) + assert_that(result, equal_to(expected_data)) + class WithKeysTest(unittest.TestCase): def setUp(self): diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 6f704b37a969..381d4f7aae2b 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -24,9 +24,13 @@ import sys import types import typing +from typing import Generic +from typing import TypeVar from apache_beam.typehints import typehints +T = TypeVar('T') + _LOGGER = logging.getLogger(__name__) # Describes an entry in the type map in convert_to_beam_type. @@ -216,6 +220,18 @@ def convert_collections_to_typing(typ): return typ +# During type inference of WindowedValue, we need to pass in the inner value +# type. This cannot be achieved immediately with WindowedValue class because it +# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T]) +# could work in theory. However, the class is cythonized and it seems that +# cython does not handle generic classes well. +# The workaround here is to create a separate class solely for the type +# inference purpose. This class should never be used for creating instances. +class TypedWindowedValue(Generic[T]): + def __init__(self, *args, **kwargs): + raise NotImplementedError("This class is solely for type inference") + + def convert_to_beam_type(typ): """Convert a given typing type to a Beam type. @@ -267,6 +283,12 @@ def convert_to_beam_type(typ): # TODO(https://github.com/apache/beam/issues/20076): Currently unhandled. _LOGGER.info('Converting NewType type hint to Any: "%s"', typ) return typehints.Any + elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \ + getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue': + # Need to pass through WindowedValue class so that it can be converted + # to the correct type constraint in Beam + # This is needed to fix https://github.com/apache/beam/issues/33356 + pass elif (typ_module != 'typing') and (typ_module != 'collections.abc'): # Only translate types from the typing and collections.abc modules. return typ @@ -324,6 +346,10 @@ def convert_to_beam_type(typ): match=_match_is_exactly_collection, arity=1, beam_type=typehints.Collection), + _TypeMapEntry( + match=_match_issubclass(TypedWindowedValue), + arity=1, + beam_type=typehints.WindowedValue), ] # Find the first matching entry. diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 0e18e887c2a0..a65a0f753826 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1213,6 +1213,15 @@ def type_check(self, instance): repr(self.inner_type), instance.value.__class__.__name__)) + def bind_type_variables(self, bindings): + bound_inner_type = bind_type_variables(self.inner_type, bindings) + if bound_inner_type == self.inner_type: + return self + return WindowedValue[bound_inner_type] + + def __repr__(self): + return 'WindowedValue[%s]' % repr(self.inner_type) + class GeneratorHint(IteratorHint): """A Generator type hint.