Skip to content

Commit

Permalink
allow dataset alias to add more than one dataset events (apache#42189) (
Browse files Browse the repository at this point in the history
apache#42247)

(cherry picked from commit a5d0a63)

Co-authored-by: Wei Lee <[email protected]>
  • Loading branch information
utkarsharma2 and Lee-W authored Sep 16, 2024
1 parent 50445e3 commit 957856f
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 36 deletions.
1 change: 1 addition & 0 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class DatasetAliasEvent(TypedDict):

source_alias_name: str
dest_dataset_uri: str
extra: dict[str, Any]


@attr.define()
Expand Down
6 changes: 3 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3030,11 +3030,11 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
session=session,
)
elif isinstance(obj, DatasetAlias):
if dataset_alias_event := events[obj].dataset_alias_event:
for dataset_alias_event in events[obj].dataset_alias_events:
dataset_alias_name = dataset_alias_event["source_alias_name"]
dataset_uri = dataset_alias_event["dest_dataset_uri"]
extra = events[obj].extra
extra = dataset_alias_event["extra"]
frozen_extra = frozenset(extra.items())
dataset_alias_name = dataset_alias_event["source_alias_name"]

dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)

Expand Down
17 changes: 13 additions & 4 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,24 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
raw_key = var.raw_key
return {
"extra": var.extra,
"dataset_alias_event": var.dataset_alias_event,
"dataset_alias_events": var.dataset_alias_events,
"raw_key": BaseSerialization.serialize(raw_key),
}


def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
raw_key = BaseSerialization.deserialize(var["raw_key"])
outlet_event_accessor = OutletEventAccessor(extra=var["extra"], raw_key=raw_key)
outlet_event_accessor.dataset_alias_event = var["dataset_alias_event"]
# This is added for compatibility. The attribute used to be dataset_alias_event and
# is now dataset_alias_events.
if dataset_alias_event := var.get("dataset_alias_event", None):
dataset_alias_events = [dataset_alias_event]
else:
dataset_alias_events = var.get("dataset_alias_events", [])

outlet_event_accessor = OutletEventAccessor(
extra=var["extra"],
raw_key=BaseSerialization.deserialize(var["raw_key"]),
dataset_alias_events=dataset_alias_events,
)
return outlet_event_accessor


Expand Down
10 changes: 4 additions & 6 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class OutletEventAccessor:

raw_key: str | Dataset | DatasetAlias
extra: dict[str, Any] = attrs.Factory(dict)
dataset_alias_event: DatasetAliasEvent | None = None
dataset_alias_events: list[DatasetAliasEvent] = attrs.field(factory=list)

def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> None:
"""Add a DatasetEvent to an existing Dataset."""
Expand All @@ -190,12 +190,10 @@ def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> No
else:
return

if extra:
self.extra = extra

self.dataset_alias_event = DatasetAliasEvent(
source_alias_name=dataset_alias_name, dest_dataset_uri=dataset_uri
event = DatasetAliasEvent(
source_alias_name=dataset_alias_name, dest_dataset_uri=dataset_uri, extra=extra or {}
)
self.dataset_alias_events.append(event)


class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ class OutletEventAccessor:
*,
extra: dict[str, Any],
raw_key: str | Dataset | DatasetAlias,
dataset_alias_event: DatasetAliasEvent | None = None,
dataset_alias_events: list[DatasetAliasEvent],
) -> None: ...
def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> None: ...
extra: dict[str, Any]
raw_key: str | Dataset | DatasetAlias
dataset_alias_event: DatasetAliasEvent | None
dataset_alias_events: list[DatasetAliasEvent]

class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
def __iter__(self) -> Iterator[str]: ...
Expand Down
14 changes: 6 additions & 8 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool:


def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool:
return a.raw_key == b.raw_key and a.extra == b.extra and a.dataset_alias_event == b.dataset_alias_event
return a.raw_key == b.raw_key and a.extra == b.extra and a.dataset_alias_events == b.dataset_alias_events


class MockLazySelectSequence(LazySelectSequence):
Expand Down Expand Up @@ -240,25 +240,23 @@ def __len__(self) -> int:
lambda a, b: a.get_uri() == b.get_uri(),
),
(
OutletEventAccessor(
raw_key=Dataset(uri="test"), extra={"key": "value"}, dataset_alias_event=None
),
OutletEventAccessor(raw_key=Dataset(uri="test"), extra={"key": "value"}, dataset_alias_events=[]),
DAT.DATASET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
(
OutletEventAccessor(
raw_key=DatasetAlias(name="test_alias"),
extra={"key": "value"},
dataset_alias_event=DatasetAliasEvent(
source_alias_name="test_alias", dest_dataset_uri="test_uri"
),
dataset_alias_events=[
DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})
],
),
DAT.DATASET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
(
OutletEventAccessor(raw_key="test", extra={"key": "value"}),
OutletEventAccessor(raw_key="test", extra={"key": "value"}, dataset_alias_events=[]),
DAT.DATASET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
Expand Down
29 changes: 16 additions & 13 deletions tests/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,44 @@

class TestOutletEventAccessor:
@pytest.mark.parametrize(
"raw_key, dataset_alias_event",
"raw_key, dataset_alias_events",
(
(
DatasetAlias("test_alias"),
DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri"),
[DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})],
),
(Dataset("test_uri"), None),
(Dataset("test_uri"), []),
),
)
def test_add(self, raw_key, dataset_alias_event):
def test_add(self, raw_key, dataset_alias_events):
outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={})
outlet_event_accessor.add(Dataset("test_uri"))
assert outlet_event_accessor.dataset_alias_event == dataset_alias_event
assert outlet_event_accessor.dataset_alias_events == dataset_alias_events

@pytest.mark.db_test
@pytest.mark.parametrize(
"raw_key, dataset_alias_event",
"raw_key, dataset_alias_events",
(
(
DatasetAlias("test_alias"),
DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri"),
[DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})],
),
("test_alias", DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri")),
(Dataset("test_uri"), None),
(
"test_alias",
[DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})],
),
(Dataset("test_uri"), []),
),
)
def test_add_with_db(self, raw_key, dataset_alias_event, session):
def test_add_with_db(self, raw_key, dataset_alias_events, session):
dsm = DatasetModel(uri="test_uri")
dsam = DatasetAliasModel(name="test_alias")
session.add_all([dsm, dsam])
session.flush()

outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={})
outlet_event_accessor.add("test_uri")
assert outlet_event_accessor.dataset_alias_event == dataset_alias_event
outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={"not": ""})
outlet_event_accessor.add("test_uri", extra={})
assert outlet_event_accessor.dataset_alias_events == dataset_alias_events


class TestOutletEventAccessors:
Expand Down

0 comments on commit 957856f

Please sign in to comment.