From bdab7dc28d15aec230b6aa91f5b43d45a9a1dcea Mon Sep 17 00:00:00 2001
From: Avihais12344 <69143592+Avihais12344@users.noreply.github.com>
Date: Sat, 21 Sep 2024 12:30:42 +0300
Subject: [PATCH] Use set instead of list for dags' tags (#41695)

* Started working on dag tags, moved the tags to a set, and added a test to check for duplicates.
* Fixed more areas of the c'tor.
* Fixed test of dag tags.
* Added a check to see if the tags are mutable.
* Added newsfragment.
* Removed unnecessary check.
* Removed str specification from the type, for compatibility with Python 3.8.
* Removed more type specifications for compatibility with Python 3.8.
* Fixed the newsfragment.
* Added missing word.
* Used `` for code segments in the rst file.
* Reformatted the file.
* Fixed wrong method for adding a tag.
* Added type hinting in the dag bag.
* Deserialized the tags to a set.
* Adjusted the tests for the set type.
* Added type hinting.
* Sorting the tags by name.
* Changed to typing.
* Update newsfragments/41420.significant.rst

Co-authored-by: Jens Scheffler <95105677+jscheffl@users.noreply.github.com>

* Update newsfragments/41420.significant.rst

Co-authored-by: Jens Scheffler <95105677+jscheffl@users.noreply.github.com>

* Removed the generic specification from the dag args expected types, as it raises the error: Subscripted generics cannot be used with class and instance checks.
* Added tags to the expected serialized DAG.
* Added sorting of the tags by the name key.
* Fixed sorting tags by name to use `sorted` instead of `.sort`
* Fixed tags comparison, as it's now a set and not a list.

---------

Co-authored-by: Jens Scheffler <95105677+jscheffl@users.noreply.github.com>
---
 airflow/models/dag.py                          |  9 +++--
 airflow/models/dagbag.py                       |  4 +-
 airflow/serialization/serialized_objects.py    |  2 +
 newsfragments/41420.significant.rst            | 11 +++++
 .../api_connexion/schemas/test_dag_schema.py   | 18 ++++++++-
 tests/models/test_dag.py                       | 40 ++++++++++++++++++-
 tests/models/test_dagbag.py                    |  8 ++--
 tests/models/test_serialized_dag.py            |  4 +-
 tests/serialization/test_dag_serialization.py  |  3 +-
 9 files changed, 82 insertions(+), 17 deletions(-)
 create mode 100644 newsfragments/41420.significant.rst

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 388f322d2d55..00820585b68a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -42,6 +42,7 @@
     Iterable,
     Iterator,
     List,
+    MutableSet,
     Pattern,
     Sequence,
     Union,
@@ -351,7 +352,7 @@ def _create_orm_dagrun(
     "doc_md": str,
     "is_paused_upon_creation": bool,
     "render_template_as_native_obj": bool,
-    "tags": list,
+    "tags": Collection,
     "auto_register": bool,
     "fail_stop": bool,
     "dag_display_name": str,
@@ -528,7 +529,7 @@ def __init__(
         is_paused_upon_creation: bool | None = None,
         jinja_environment_kwargs: dict | None = None,
         render_template_as_native_obj: bool = False,
-        tags: list[str] | None = None,
+        tags: Collection[str] | None = None,
         owner_links: dict[str, str] | None = None,
         auto_register: bool = True,
         fail_stop: bool = False,
@@ -678,7 +679,7 @@ def __init__(
 
         self.doc_md = self.get_doc_md(doc_md)
 
-        self.tags = tags or []
+        self.tags: MutableSet[str] = set(tags or [])
         self._task_group = TaskGroup.create_root(self)
         self.validate_schedule_and_params()
         wrong_links = dict(self.iter_invalid_owner_links())
@@ -3311,7 +3312,7 @@ def dag(
     is_paused_upon_creation: bool | None = None,
     jinja_environment_kwargs: dict | None = None,
     render_template_as_native_obj: bool = False,
-    tags: list[str] | None = None,
+    tags: Collection[str] | None = None,
     owner_links: dict[str, str] | None = None,
     auto_register: bool = True,
     fail_stop: bool = False,
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 8b155e7b526a..b2d45a133187 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -266,7 +266,7 @@ def _add_dag_from_db(self, dag_id: str, session: Session):
         """Add DAG to DagBag from DB."""
         from airflow.models.serialized_dag import SerializedDagModel
 
-        row = SerializedDagModel.get(dag_id, session)
+        row: SerializedDagModel | None = SerializedDagModel.get(dag_id, session)
         if not row:
             return None
 
@@ -457,7 +457,7 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
             found_dags.append(dag)
         return found_dags
 
-    def bag_dag(self, dag):
+    def bag_dag(self, dag: DAG):
         """
         Add the DAG into the bag.
 
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 998b5ba3f422..12310685ec69 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1673,6 +1673,8 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
                 v = cls.deserialize(v)
             elif k == "params":
                 v = cls._deserialize_params_dict(v)
+            elif k == "tags":
+                v = set(v)
             # else use v as it is
 
             setattr(dag, k, v)
diff --git a/newsfragments/41420.significant.rst b/newsfragments/41420.significant.rst
new file mode 100644
index 000000000000..361b8c7ea9c4
--- /dev/null
+++ b/newsfragments/41420.significant.rst
@@ -0,0 +1,11 @@
+**Breaking Change**
+
+Replaced Python's ``list`` with ``MutableSet`` for the property ``DAG.tags``.
+
+In the constructor you can still pass a list;
+in fact, any data structure that implements the
+``Collection`` interface is accepted.
+
+The ``tags`` property of the ``DAG`` model is now of type
+``MutableSet`` instead of ``list``,
+since tags never contain actual duplicates.
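Illustrative usage (an editor's sketch, not part of the patch): the snippet below shows the behaviour described in the newsfragment above, assuming this change is applied. The DAG id and tag values are made up for the example; ``schedule=None`` mirrors the keyword used in the tests elsewhere in this patch.

    from airflow.models.dag import DAG

    # The constructor still accepts any Collection[str]; duplicates collapse
    # because the tags are stored as a set.
    dag = DAG("example_tags_dag", schedule=None, tags=["etl", "etl", "daily"])
    assert dag.tags == {"daily", "etl"}

    # DAG.tags is now a MutableSet, so mutate it with set methods
    # instead of list methods such as append() or +=.
    dag.tags.add("backfill")
    dag.tags.discard("daily")
    assert dag.tags == {"backfill", "etl"}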
diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
index 858c62815f29..a4a86bc05cc9 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -185,7 +185,10 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer):
             }
         },
         "start_date": "2020-06-19T00:00:00+00:00",
-        "tags": [{"name": "example1"}, {"name": "example2"}],
+        "tags": sorted(
+            [{"name": "example1"}, {"name": "example2"}],
+            key=lambda val: val["name"],
+        ),
         "template_searchpath": None,
         "timetable_summary": "1 day, 0:00:00",
         "timezone": UTC_JSON_REPR,
@@ -198,6 +201,10 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer):
     }
     obj = schema.dump(dag)
     expected.update({"last_parsed": obj["last_parsed"]})
+    obj["tags"] = sorted(
+        obj["tags"],
+        key=lambda val: val["name"],
+    )
     assert obj == expected
 
 
@@ -243,7 +250,10 @@ def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali
             }
         },
         "start_date": "2020-06-19T00:00:00+00:00",
-        "tags": [{"name": "example1"}, {"name": "example2"}],
+        "tags": sorted(
+            [{"name": "example1"}, {"name": "example2"}],
+            key=lambda val: val["name"],
+        ),
         "template_searchpath": None,
         "timetable_summary": "Dataset",
         "timezone": UTC_JSON_REPR,
@@ -256,4 +266,8 @@ def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali
     }
     obj = schema.dump(dag)
     expected.update({"last_parsed": obj["last_parsed"]})
+    obj["tags"] = sorted(
+        obj["tags"],
+        key=lambda val: val["name"],
+    )
     assert obj == expected
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index d5fd2ad7291c..90d956caeb7d 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -805,7 +805,7 @@ def test_bulk_write_to_db(self):
         DAG.bulk_write_to_db(dags)
         # Adding tags
         for dag in dags:
-            dag.tags.append("test-dag2")
+            dag.tags.add("test-dag2")
         with assert_queries_count(9):
             DAG.bulk_write_to_db(dags)
         with create_session() as session:
@@ -843,7 +843,7 @@ def test_bulk_write_to_db(self):
 
         # Removing all tags
         for dag in dags:
-            dag.tags = None
+            dag.tags = set()
         with assert_queries_count(9):
             DAG.bulk_write_to_db(dags)
         with create_session() as session:
@@ -3383,6 +3383,42 @@ def test__tags_length(tags: list[str], should_pass: bool):
             DAG("test-dag", schedule=None, tags=tags)
 
 
+@pytest.mark.parametrize(
+    "input_tags, expected_result",
+    [
+        pytest.param([], set(), id="empty tags"),
+        pytest.param(
+            ["a normal tag"],
+            {"a normal tag"},
+            id="one tag",
+        ),
+        pytest.param(
+            ["a normal tag", "another normal tag"],
+            {"a normal tag", "another normal tag"},
+            id="two different tags",
+        ),
+        pytest.param(
+            ["a", "a"],
+            {"a"},
+            id="two same tags",
+        ),
+    ],
+)
+def test__tags_duplicates(input_tags: list[str], expected_result: set[str]):
+    result = DAG("test-dag", tags=input_tags)
+    assert result.tags == expected_result
+
+
+def test__tags_mutable():
+    expected_tags = {"6", "7"}
+    test_dag = DAG("test-dag")
+    test_dag.tags.add("6")
+    test_dag.tags.add("7")
+    test_dag.tags.add("8")
+    test_dag.tags.remove("8")
+    assert test_dag.tags == expected_tags
+
+
 @pytest.mark.need_serialized_dag
 def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets):
     dataset1 = Dataset(uri="ds1")
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index a5ec740df285..0179fa865291 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -830,11 +830,11 @@ def test_get_dag_with_dag_serialization(self):
         # from DB
         with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 4)), tick=False):
             with assert_queries_count(0):
-                assert dag_bag.get_dag("example_bash_operator").tags == ["example", "example2"]
+                assert dag_bag.get_dag("example_bash_operator").tags == {"example", "example2"}
 
         # Make a change in the DAG and write Serialized DAG to the DB
         with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 6)), tick=False):
-            example_bash_op_dag.tags += ["new_tag"]
+            example_bash_op_dag.tags.add("new_tag")
             SerializedDagModel.write_dag(dag=example_bash_op_dag)
 
         # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
@@ -869,7 +869,7 @@ def test_get_dag_refresh_race_condition(self):
 
             ser_dag = dag_bag.get_dag("example_bash_operator")
             ser_dag_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
-        assert ser_dag.tags == ["example", "example2"]
+        assert ser_dag.tags == {"example", "example2"}
         assert ser_dag_update_time == tz.datetime(2020, 1, 5, 1, 0, 10)
 
         with create_session() as session:
@@ -883,7 +883,7 @@ def test_get_dag_refresh_race_condition(self):
         # Note the date *before* the deserialize step above, simulating a serialization happening
         # long before the transaction is committed
        with time_machine.travel((tz.datetime(2020, 1, 5, 1, 0, 0)), tick=False):
-            example_bash_op_dag.tags += ["new_tag"]
+            example_bash_op_dag.tags.add("new_tag")
             SerializedDagModel.write_dag(dag=example_bash_op_dag)
 
         # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index f86aa1b90467..9f83280f8eb3 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -109,8 +109,8 @@ def test_serialized_dag_is_updated_if_dag_is_changed(self):
         assert dag_updated is False
 
         # Update DAG
-        example_bash_op_dag.tags += ["new_tag"]
-        assert set(example_bash_op_dag.tags) == {"example", "example2", "new_tag"}
+        example_bash_op_dag.tags.add("new_tag")
+        assert example_bash_op_dag.tags == {"example", "example2", "new_tag"}
 
         dag_updated = SDM.write_dag(dag=example_bash_op_dag)
         s_dag_2 = session.get(SDM, example_bash_op_dag.dag_id)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 8311aa77e7e7..7dfe57054c60 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -260,6 +260,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
         "edge_info": {},
         "dag_dependencies": [],
         "params": [],
+        "tags": [],
     },
 }
 
@@ -587,7 +588,7 @@ def test_dag_roundtrip_from_timetable(self, timetable):
         roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag))
         self.validate_deserialized_dag(roundtripped, dag)
 
-    def validate_deserialized_dag(self, serialized_dag, dag):
+    def validate_deserialized_dag(self, serialized_dag: DAG, dag: DAG):
         """
         Verify that all example DAGs work with DAG Serialization by
         checking fields between Serialized Dags & non-Serialized Dags