diff --git a/docs/modules/ROOT/examples/merging/tutorial001.py b/docs/modules/ROOT/examples/merging/tutorial001.py index e7758bdc..8fc962d2 100644 --- a/docs/modules/ROOT/examples/merging/tutorial001.py +++ b/docs/modules/ROOT/examples/merging/tutorial001.py @@ -6,5 +6,5 @@ def app(event: StreamTimeEvent, api: Api, cache: Cache): # since we passed merge_events=True all 3 incoming events # and their records will be merged into a single event with 9 records - assert len(event.records) != 9 # this will not fail + assert len(event.records) == 9 # this will not fail return event diff --git a/src/corva/handlers.py b/src/corva/handlers.py index 04c2d4e2..37756e2b 100644 --- a/src/corva/handlers.py +++ b/src/corva/handlers.py @@ -19,7 +19,6 @@ import pydantic import redis -from typing_extensions import assert_never from corva.api import Api from corva.configuration import SETTINGS @@ -89,7 +88,13 @@ def wrapper(aws_event: Any, aws_context: Any) -> List[Any]: ) data_transformation_type = raw_custom_event_type or raw_event_type if merge_events: - aws_event = _merge_events(aws_event, data_transformation_type) + aws_event = _merge_events( + aws_event, + cast( + Union[Type[RawScheduledEvent], Type[RawStreamEvent]], + data_transformation_type, + ), + ) raw_events = data_transformation_type.from_raw_event(event=aws_event) if ( @@ -567,7 +572,10 @@ def _get_custom_event_type_by_raw_aws_event( return None, None -def _merge_events(aws_event: Any, data_transformation_type: Type[RawBaseEvent]) -> Any: +def _merge_events( + aws_event: Any, + data_transformation_type: Union[Type[RawScheduledEvent], Type[RawStreamEvent]], +) -> Any: """ Merges incoming aws_events into one. Merge happens differently, depending on app type. @@ -578,7 +586,10 @@ def _merge_events(aws_event: Any, data_transformation_type: Type[RawBaseEvent]) # scheduled event if not isinstance(aws_event[0], dict): aws_event = list(itertools.chain(*aws_event)) - is_depth = aws_event[0]["scheduler_type"] == SchedulerType.data_depth_milestone + scheduler_type = aws_event[0]["scheduler_type"] + if isinstance(scheduler_type, SchedulerType): + scheduler_type = scheduler_type.value + is_depth = scheduler_type == SchedulerType.data_depth_milestone.value event_start, event_end = ( ("top_depth", "bottom_depth") if is_depth @@ -595,13 +606,8 @@ def _merge_events(aws_event: Any, data_transformation_type: Type[RawBaseEvent]) aws_event = aws_event[0] return aws_event - elif data_transformation_type == RawStreamEvent: - # stream event - for event in aws_event[1:]: - aws_event[0]["records"].extend(event["records"]) - aws_event = [aws_event[0]] - return aws_event - - else: - # unexpected event type, raise an exception - assert_never(data_transformation_type) # type: ignore + # stream event + for event in aws_event[1:]: + aws_event[0]["records"].extend(event["records"]) + aws_event = [aws_event[0]] + return aws_event diff --git a/tests/unit/test_docs/test_merging.py b/tests/unit/test_docs/test_merging.py index 91166bd5..f5713f7e 100644 --- a/tests/unit/test_docs/test_merging.py +++ b/tests/unit/test_docs/test_merging.py @@ -46,6 +46,12 @@ def test_tutorial001(context): asset_id=1, company_id=1, ), + RawTimeRecord( + collection=str(), + timestamp=timestamp + 2, + asset_id=1, + company_id=1, + ), ], metadata=RawMetadata( app_stream_id=1, @@ -55,10 +61,10 @@ def test_tutorial001(context): ).dict() ] ) - timestamp += 2 + timestamp += 3 result_event: StreamTimeEvent = tutorial001.app(event, context)[0] - assert len(result_event.records) == 6, "records were not merged into a single event" + assert len(result_event.records) == 9, "records were not merged into a single event" @pytest.mark.parametrize( diff --git a/tests/unit/test_merge_events.py b/tests/unit/test_merge_events.py deleted file mode 100644 index a468ce48..00000000 --- a/tests/unit/test_merge_events.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest - -from corva.handlers import _merge_events -from corva.models.task import RawTaskEvent - - -def test_events_not_merged_on_unexpected_event_type(): - """ - when unexpected event type(in our test - raw task event) is - passed - fail with RuntimeError - """ - aws_event = [{"sample": 1}, {"sample2": 2}] - with pytest.raises(AssertionError): - _merge_events(aws_event, RawTaskEvent)