From 957856fe77ccd3b359bff49b4fbcfa4c6fbcbbc4 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Mon, 16 Sep 2024 18:11:55 +0530 Subject: [PATCH] allow dataset alias to add more than one dataset events (#42189) (#42247) (cherry picked from commit a5d0a63d8784d7f4100a4770748c783261968e3c) Co-authored-by: Wei Lee --- airflow/datasets/__init__.py | 1 + airflow/models/taskinstance.py | 6 ++-- airflow/serialization/serialized_objects.py | 17 ++++++++--- airflow/utils/context.py | 10 +++---- airflow/utils/context.pyi | 4 +-- .../serialization/test_serialized_objects.py | 14 ++++----- tests/utils/test_context.py | 29 ++++++++++--------- 7 files changed, 45 insertions(+), 36 deletions(-) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 55d947544c1d..80ed083c4b9d 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -239,6 +239,7 @@ class DatasetAliasEvent(TypedDict): source_alias_name: str dest_dataset_uri: str + extra: dict[str, Any] @attr.define() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 0c34d350247f..b34c71bc9fda 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3030,11 +3030,11 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se session=session, ) elif isinstance(obj, DatasetAlias): - if dataset_alias_event := events[obj].dataset_alias_event: + for dataset_alias_event in events[obj].dataset_alias_events: + dataset_alias_name = dataset_alias_event["source_alias_name"] dataset_uri = dataset_alias_event["dest_dataset_uri"] - extra = events[obj].extra + extra = dataset_alias_event["extra"] frozen_extra = frozenset(extra.items()) - dataset_alias_name = dataset_alias_event["source_alias_name"] dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 84ad5679182b..9eb3332a6b6a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -286,15 +286,24 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]: raw_key = var.raw_key return { "extra": var.extra, - "dataset_alias_event": var.dataset_alias_event, + "dataset_alias_events": var.dataset_alias_events, "raw_key": BaseSerialization.serialize(raw_key), } def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: - raw_key = BaseSerialization.deserialize(var["raw_key"]) - outlet_event_accessor = OutletEventAccessor(extra=var["extra"], raw_key=raw_key) - outlet_event_accessor.dataset_alias_event = var["dataset_alias_event"] + # This is added for compatibility. The attribute used to be dataset_alias_event and + # is now dataset_alias_events. + if dataset_alias_event := var.get("dataset_alias_event", None): + dataset_alias_events = [dataset_alias_event] + else: + dataset_alias_events = var.get("dataset_alias_events", []) + + outlet_event_accessor = OutletEventAccessor( + extra=var["extra"], + raw_key=BaseSerialization.deserialize(var["raw_key"]), + dataset_alias_events=dataset_alias_events, + ) return outlet_event_accessor diff --git a/airflow/utils/context.py b/airflow/utils/context.py index c2a0ad7052ea..a72885401f7b 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -172,7 +172,7 @@ class OutletEventAccessor: raw_key: str | Dataset | DatasetAlias extra: dict[str, Any] = attrs.Factory(dict) - dataset_alias_event: DatasetAliasEvent | None = None + dataset_alias_events: list[DatasetAliasEvent] = attrs.field(factory=list) def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> None: """Add a DatasetEvent to an existing Dataset.""" @@ -190,12 +190,10 @@ def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> No else: return - if extra: - self.extra = extra - - self.dataset_alias_event = DatasetAliasEvent( - source_alias_name=dataset_alias_name, dest_dataset_uri=dataset_uri + event = DatasetAliasEvent( + source_alias_name=dataset_alias_name, dest_dataset_uri=dataset_uri, extra=extra or {} ) + self.dataset_alias_events.append(event) class OutletEventAccessors(Mapping[str, OutletEventAccessor]): diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index d3546286cf7b..658aac5839ec 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -63,12 +63,12 @@ class OutletEventAccessor: *, extra: dict[str, Any], raw_key: str | Dataset | DatasetAlias, - dataset_alias_event: DatasetAliasEvent | None = None, + dataset_alias_events: list[DatasetAliasEvent], ) -> None: ... def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> None: ... extra: dict[str, Any] raw_key: str | Dataset | DatasetAlias - dataset_alias_event: DatasetAliasEvent | None + dataset_alias_events: list[DatasetAliasEvent] class OutletEventAccessors(Mapping[str, OutletEventAccessor]): def __iter__(self) -> Iterator[str]: ... diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 661ecbf5dcb7..82d8c16f3fda 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -163,7 +163,7 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool: def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool: - return a.raw_key == b.raw_key and a.extra == b.extra and a.dataset_alias_event == b.dataset_alias_event + return a.raw_key == b.raw_key and a.extra == b.extra and a.dataset_alias_events == b.dataset_alias_events class MockLazySelectSequence(LazySelectSequence): @@ -240,9 +240,7 @@ def __len__(self) -> int: lambda a, b: a.get_uri() == b.get_uri(), ), ( - OutletEventAccessor( - raw_key=Dataset(uri="test"), extra={"key": "value"}, dataset_alias_event=None - ), + OutletEventAccessor(raw_key=Dataset(uri="test"), extra={"key": "value"}, dataset_alias_events=[]), DAT.DATASET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), @@ -250,15 +248,15 @@ def __len__(self) -> int: OutletEventAccessor( raw_key=DatasetAlias(name="test_alias"), extra={"key": "value"}, - dataset_alias_event=DatasetAliasEvent( - source_alias_name="test_alias", dest_dataset_uri="test_uri" - ), + dataset_alias_events=[ + DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={}) + ], ), DAT.DATASET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), ( - OutletEventAccessor(raw_key="test", extra={"key": "value"}), + OutletEventAccessor(raw_key="test", extra={"key": "value"}, dataset_alias_events=[]), DAT.DATASET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 1237be2f8d34..0f4f80f36504 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -27,41 +27,44 @@ class TestOutletEventAccessor: @pytest.mark.parametrize( - "raw_key, dataset_alias_event", + "raw_key, dataset_alias_events", ( ( DatasetAlias("test_alias"), - DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri"), + [DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})], ), - (Dataset("test_uri"), None), + (Dataset("test_uri"), []), ), ) - def test_add(self, raw_key, dataset_alias_event): + def test_add(self, raw_key, dataset_alias_events): outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={}) outlet_event_accessor.add(Dataset("test_uri")) - assert outlet_event_accessor.dataset_alias_event == dataset_alias_event + assert outlet_event_accessor.dataset_alias_events == dataset_alias_events @pytest.mark.db_test @pytest.mark.parametrize( - "raw_key, dataset_alias_event", + "raw_key, dataset_alias_events", ( ( DatasetAlias("test_alias"), - DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri"), + [DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})], ), - ("test_alias", DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri")), - (Dataset("test_uri"), None), + ( + "test_alias", + [DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})], + ), + (Dataset("test_uri"), []), ), ) - def test_add_with_db(self, raw_key, dataset_alias_event, session): + def test_add_with_db(self, raw_key, dataset_alias_events, session): dsm = DatasetModel(uri="test_uri") dsam = DatasetAliasModel(name="test_alias") session.add_all([dsm, dsam]) session.flush() - outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={}) - outlet_event_accessor.add("test_uri") - assert outlet_event_accessor.dataset_alias_event == dataset_alias_event + outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={"not": ""}) + outlet_event_accessor.add("test_uri", extra={}) + assert outlet_event_accessor.dataset_alias_events == dataset_alias_events class TestOutletEventAccessors: