Skip to content

Commit

Permalink
Use set instead of list for dags' tags (apache#41695)
Browse files Browse the repository at this point in the history
* Started working on DAG tags, moved the tags to a set, and added a test to check for duplicates.

* Fixed more areas of the c'tor.

* Fixed test of dag tags.

* Added a check to see if the tags are mutable.

* Added newsfragment.

* Removed unnecessary check.

* Removed str specification from the type, for compatibility with Python 3.8.

* Removed more type specifications as part of compatibility with Python 3.8.

* Fixed the newsfragment.

* Added missing word.

* Used `` for code segments in the rst file.

* Reformatted the file.

* Fixed wrong method for adding tag.

* Added type hinting at the dag bag.

* Deserialized the tags to set.

* Adjusted the tests for the set type.

* Added type hinting.

* Sorting the tags by name.

* Changed to typing.

* Update newsfragments/41420.significant.rst

Co-authored-by: Jens Scheffler <[email protected]>

* Update newsfragments/41420.significant.rst

Co-authored-by: Jens Scheffler <[email protected]>

* Removed the generic specification at the dag args expected types, as it raises the error: Subscripted generics cannot be used with class and instance checks.

* Added tags to the expected serialized DAG.

* Added sorting the tags keys by the name key.

* Fixed sorting tags by name to use `sorted` instead of `.sort`

* Fixed tags comparison, as it's now a set, and not a list.

---------

Co-authored-by: Jens Scheffler <[email protected]>
  • Loading branch information
Avihais12344 and jscheffl authored Sep 21, 2024
1 parent ba1c602 commit bdab7dc
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 17 deletions.
9 changes: 5 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Iterable,
Iterator,
List,
MutableSet,
Pattern,
Sequence,
Union,
Expand Down Expand Up @@ -351,7 +352,7 @@ def _create_orm_dagrun(
"doc_md": str,
"is_paused_upon_creation": bool,
"render_template_as_native_obj": bool,
"tags": list,
"tags": Collection,
"auto_register": bool,
"fail_stop": bool,
"dag_display_name": str,
Expand Down Expand Up @@ -528,7 +529,7 @@ def __init__(
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
render_template_as_native_obj: bool = False,
tags: list[str] | None = None,
tags: Collection[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
fail_stop: bool = False,
Expand Down Expand Up @@ -678,7 +679,7 @@ def __init__(

self.doc_md = self.get_doc_md(doc_md)

self.tags = tags or []
self.tags: MutableSet[str] = set(tags or [])
self._task_group = TaskGroup.create_root(self)
self.validate_schedule_and_params()
wrong_links = dict(self.iter_invalid_owner_links())
Expand Down Expand Up @@ -3311,7 +3312,7 @@ def dag(
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
render_template_as_native_obj: bool = False,
tags: list[str] | None = None,
tags: Collection[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
fail_stop: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _add_dag_from_db(self, dag_id: str, session: Session):
"""Add DAG to DagBag from DB."""
from airflow.models.serialized_dag import SerializedDagModel

row = SerializedDagModel.get(dag_id, session)
row: SerializedDagModel | None = SerializedDagModel.get(dag_id, session)
if not row:
return None

Expand Down Expand Up @@ -457,7 +457,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
found_dags.append(dag)
return found_dags

def bag_dag(self, dag):
def bag_dag(self, dag: DAG):
"""
Add the DAG into the bag.
Expand Down
2 changes: 2 additions & 0 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,8 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
v = cls.deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k == "tags":
v = set(v)
# else use v as it is

setattr(dag, k, v)
Expand Down
11 changes: 11 additions & 0 deletions newsfragments/41420.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
**Breaking Change**

Replaced Python's ``list`` with ``MutableSet`` for the property ``DAG.tags``.

In the constructor you can still pass a ``list``;
in fact, you can pass any data structure that implements the
``Collection`` interface.

The ``tags`` property of the ``DAG`` model is now of type
``MutableSet`` instead of ``list``,
as tags cannot contain duplicates.
18 changes: 16 additions & 2 deletions tests/api_connexion/schemas/test_dag_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer):
}
},
"start_date": "2020-06-19T00:00:00+00:00",
"tags": [{"name": "example1"}, {"name": "example2"}],
"tags": sorted(
[{"name": "example1"}, {"name": "example2"}],
key=lambda val: val["name"],
),
"template_searchpath": None,
"timetable_summary": "1 day, 0:00:00",
"timezone": UTC_JSON_REPR,
Expand All @@ -198,6 +201,10 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer):
}
obj = schema.dump(dag)
expected.update({"last_parsed": obj["last_parsed"]})
obj["tags"] = sorted(
obj["tags"],
key=lambda val: val["name"],
)
assert obj == expected


Expand Down Expand Up @@ -243,7 +250,10 @@ def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali
}
},
"start_date": "2020-06-19T00:00:00+00:00",
"tags": [{"name": "example1"}, {"name": "example2"}],
"tags": sorted(
[{"name": "example1"}, {"name": "example2"}],
key=lambda val: val["name"],
),
"template_searchpath": None,
"timetable_summary": "Dataset",
"timezone": UTC_JSON_REPR,
Expand All @@ -256,4 +266,8 @@ def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali
}
obj = schema.dump(dag)
expected.update({"last_parsed": obj["last_parsed"]})
obj["tags"] = sorted(
obj["tags"],
key=lambda val: val["name"],
)
assert obj == expected
40 changes: 38 additions & 2 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def test_bulk_write_to_db(self):
DAG.bulk_write_to_db(dags)
# Adding tags
for dag in dags:
dag.tags.append("test-dag2")
dag.tags.add("test-dag2")
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
Expand Down Expand Up @@ -843,7 +843,7 @@ def test_bulk_write_to_db(self):

# Removing all tags
for dag in dags:
dag.tags = None
dag.tags = set()
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
Expand Down Expand Up @@ -3383,6 +3383,42 @@ def test__tags_length(tags: list[str], should_pass: bool):
DAG("test-dag", schedule=None, tags=tags)


@pytest.mark.parametrize(
    "input_tags, expected_result",
    [
        ([], set()),
        (["a normal tag"], {"a normal tag"}),
        (
            ["a normal tag", "another normal tag"],
            {"a normal tag", "another normal tag"},
        ),
        (["a", "a"], {"a"}),
    ],
    ids=[
        "empty tags",
        "one tag",
        "two different tags",
        "two same tags",
    ],
)
def test__tags_duplicates(input_tags: list[str], expected_result: set[str]):
    """Check that tags given to the ``DAG`` constructor compare equal to the de-duplicated set."""
    dag = DAG("test-dag", tags=input_tags)
    assert dag.tags == expected_result


def test__tags_mutable():
    """Check that ``DAG.tags`` can be modified in place with ``add`` and ``remove``."""
    test_dag = DAG("test-dag")
    # Add three tags, then drop the last one again so both mutations are exercised.
    for tag in ("6", "7", "8"):
        test_dag.tags.add(tag)
    test_dag.tags.remove("8")
    assert test_dag.tags == {"6", "7"}


@pytest.mark.need_serialized_dag
def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets):
dataset1 = Dataset(uri="ds1")
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,11 +830,11 @@ def test_get_dag_with_dag_serialization(self):
# from DB
with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 4)), tick=False):
with assert_queries_count(0):
assert dag_bag.get_dag("example_bash_operator").tags == ["example", "example2"]
assert dag_bag.get_dag("example_bash_operator").tags == {"example", "example2"}

# Make a change in the DAG and write Serialized DAG to the DB
with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 6)), tick=False):
example_bash_op_dag.tags += ["new_tag"]
example_bash_op_dag.tags.add("new_tag")
SerializedDagModel.write_dag(dag=example_bash_op_dag)

# Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
Expand Down Expand Up @@ -869,7 +869,7 @@ def test_get_dag_refresh_race_condition(self):
ser_dag = dag_bag.get_dag("example_bash_operator")

ser_dag_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
assert ser_dag.tags == ["example", "example2"]
assert ser_dag.tags == {"example", "example2"}
assert ser_dag_update_time == tz.datetime(2020, 1, 5, 1, 0, 10)

with create_session() as session:
Expand All @@ -883,7 +883,7 @@ def test_get_dag_refresh_race_condition(self):
# Note the date *before* the deserialize step above, simulating a serialization happening
# long before the transaction is committed
with time_machine.travel((tz.datetime(2020, 1, 5, 1, 0, 0)), tick=False):
example_bash_op_dag.tags += ["new_tag"]
example_bash_op_dag.tags.add("new_tag")
SerializedDagModel.write_dag(dag=example_bash_op_dag)

# Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def test_serialized_dag_is_updated_if_dag_is_changed(self):
assert dag_updated is False

# Update DAG
example_bash_op_dag.tags += ["new_tag"]
assert set(example_bash_op_dag.tags) == {"example", "example2", "new_tag"}
example_bash_op_dag.tags.add("new_tag")
assert example_bash_op_dag.tags == {"example", "example2", "new_tag"}

dag_updated = SDM.write_dag(dag=example_bash_op_dag)
s_dag_2 = session.get(SDM, example_bash_op_dag.dag_id)
Expand Down
3 changes: 2 additions & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
"edge_info": {},
"dag_dependencies": [],
"params": [],
"tags": [],
},
}

Expand Down Expand Up @@ -587,7 +588,7 @@ def test_dag_roundtrip_from_timetable(self, timetable):
roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(roundtripped, dag)

def validate_deserialized_dag(self, serialized_dag, dag):
def validate_deserialized_dag(self, serialized_dag: DAG, dag: DAG):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags
Expand Down

0 comments on commit bdab7dc

Please sign in to comment.