diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py
index cdb9005..e50b3ab 100644
--- a/dagfactory/dagbuilder.py
+++ b/dagfactory/dagbuilder.py
@@ -36,12 +36,7 @@
 except ImportError:
     from airflow.providers.common.sql.sensors.sql import SqlSensor
 
-# python sensor was moved in Airflow 2.0.0
-try:
-    from airflow.sensors.python import PythonSensor
-except ImportError:
-    from airflow.contrib.sensors.python_sensor import PythonSensor
-
+from airflow.sensors.python import PythonSensor
 
 # k8s libraries are moved in v5.0.0
 try:
@@ -69,7 +64,7 @@
     )
     from airflow.kubernetes.secret import Secret
     from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
-except ImportError:
+except ImportError:  # pragma: no cover
     from airflow.contrib.kubernetes.pod import Port
     from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
     from airflow.contrib.kubernetes.secret import Secret
@@ -77,21 +72,12 @@
     from airflow.contrib.kubernetes.volume_mount import VolumeMount
     from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
 
+from airflow.utils.task_group import TaskGroup
 from kubernetes.client.models import V1Container, V1Pod
 
 from dagfactory import utils
 from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException
 
-# pylint: disable=ungrouped-imports,invalid-name
-# Disabling pylint's ungrouped-imports warning because this is a
-# conditional import and cannot be done within the import group above
-# TaskGroup is introduced in Airflow 2.0.0
-if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-    from airflow.utils.task_group import TaskGroup
-else:
-    TaskGroup = None
-# pylint: disable=ungrouped-imports,invalid-name
-
 # TimeTable is introduced in Airflow 2.2.0
 if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
     from airflow.timetables.base import Timetable
@@ -104,12 +90,7 @@
 else:
     MappedOperator = None
 
-# XComArg is introduced in Airflow 2.0.0
-if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-    from airflow.models.xcom_arg import XComArg
-else:
-    XComArg = None
-# pylint: disable=ungrouped-imports,invalid-name
+from airflow.models.xcom_arg import XComArg
 
 if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
     from airflow.datasets import Dataset
@@ -151,9 +132,6 @@ def get_dag_params(self) -> Dict[str, Any]:
             raise DagFactoryConfigException("Failed to merge config with default config") from err
         dag_params["dag_id"]: str = self.dag_name
 
-        if dag_params.get("task_groups") and version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
-            raise DagFactoryConfigException("`task_groups` key can only be used with Airflow 2.x.x")
-
         if utils.check_dict_key(dag_params, "schedule_interval") and dag_params["schedule_interval"] == "None":
             dag_params["schedule_interval"] = None
 
@@ -693,10 +671,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:
         if not dag_params.get("timetable") and not utils.check_dict_key(dag_params, "schedule"):
             dag_kwargs["schedule_interval"] = dag_params.get("schedule_interval", timedelta(days=1))
 
-        if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.11"):
-            dag_kwargs["description"] = dag_params.get("description", None)
-        else:
-            dag_kwargs["description"] = dag_params.get("description", "")
+        dag_kwargs["description"] = dag_params.get("description", None)
 
         if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
             dag_kwargs["max_active_tasks"] = dag_params.get(
diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py
index 4ef6aa2..f777aa6 100644
--- a/tests/test_dagbuilder.py
+++ b/tests/test_dagbuilder.py
@@ -1,13 +1,15 @@
 import datetime
 import os
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import mock_open, patch
 
 import pendulum
 import pytest
 from airflow import DAG
 from packaging import version
 
+from dagfactory.dagbuilder import Dataset
+
 try:
     from airflow.providers.http.sensors.http import HttpSensor
 except ImportError:
@@ -879,3 +881,31 @@ def test_replace_expand_string_with_xcom():
     updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(task_conf_xcomarg, tasks_dict)
     assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
     assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
+
+
+@pytest.mark.skipif(
+    version.parse(AIRFLOW_VERSION) <= version.parse("2.4.0"), reason="Requires Airflow version greater than 2.4.0"
+)
+@pytest.mark.parametrize(
+    "outlets,output",
+    [
+        (
+            {"datasets": "s3://test/test.txt", "file": "file://path/to/my_file.txt"},
+            ["s3://test/test.txt", "file://path/to/my_file.txt"],
+        ),
+        (["s3://test/test.txt"], ["s3://test/test.txt"]),
+    ],
+)
+@patch("dagfactory.dagbuilder.utils.get_datasets_uri_yaml_file", new_callable=mock_open)
+def test_make_task_outlets(mock_read_file, outlets, output):
+    td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
+    task_params = {
+        "task_id": "process",
+        "python_callable_name": "expand_task",
+        "python_callable_file": os.path.realpath(__file__),
+        "outlets": outlets,
+    }
+    mock_read_file.return_value = output
+    operator = "airflow.operators.python_operator.PythonOperator"
+    actual = td.make_task(operator, task_params)
+    assert actual.outlets == [Dataset(uri) for uri in output]
diff --git a/tests/test_dagfactory.py b/tests/test_dagfactory.py
index 6c44f6c..53dc78a 100644
--- a/tests/test_dagfactory.py
+++ b/tests/test_dagfactory.py
@@ -429,7 +429,17 @@ def test_load_yaml_dags_succeed(mock_emit_usage_metrics_if_enabled):
         dags_folder="tests/fixtures",
         suffix=["dag_factory_variables_as_arguments.yml"],
     )
 
+    # Confirm the telemetry for all the DAGs defined in the desired YAML is being sent
     args = mock_emit_usage_metrics_if_enabled.call_args.args
     assert args[0] == "load_yaml_dags"
     assert args[1] == {"dags_count": 2, "tasks_count": 4, "taskgroups_count": 0}
+
+
+def test_load_yaml_dags_default_suffix_succeed(caplog):
+    caplog.set_level(logging.INFO)
+    load_yaml_dags(
+        globals_dict=globals(),
+        dags_folder="tests/fixtures",
+    )
+    assert "Loading DAGs from tests/fixtures" in caplog.messages