Add more tests (#258)
closes: #199

- Add tests for the `outlets` param and for loading YAML DAGs with the default suffix
- Clean up imports that only existed to support Airflow < 2.0.0 (before/after sketched just below)
pankajastro authored Oct 18, 2024
1 parent 79bc648 commit 7f6bdda
Showing 3 changed files with 47 additions and 31 deletions.
36 changes: 6 additions & 30 deletions dagfactory/dagbuilder.py
@@ -36,12 +36,7 @@
 except ImportError:
     from airflow.providers.common.sql.sensors.sql import SqlSensor
 
-# python sensor was moved in Airflow 2.0.0
-try:
-    from airflow.sensors.python import PythonSensor
-except ImportError:
-    from airflow.contrib.sensors.python_sensor import PythonSensor
-
+from airflow.sensors.python import PythonSensor
 
 # k8s libraries are moved in v5.0.0
 try:
@@ -69,29 +64,20 @@
     )
     from airflow.kubernetes.secret import Secret
     from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
-except ImportError:
+except ImportError:  # pragma: no cover
     from airflow.contrib.kubernetes.pod import Port
     from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
     from airflow.contrib.kubernetes.secret import Secret
     from airflow.contrib.kubernetes.volume import Volume
     from airflow.contrib.kubernetes.volume_mount import VolumeMount
     from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
 
+from airflow.utils.task_group import TaskGroup
 from kubernetes.client.models import V1Container, V1Pod
 
 from dagfactory import utils
 from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException
 
-# pylint: disable=ungrouped-imports,invalid-name
-# Disabling pylint's ungrouped-imports warning because this is a
-# conditional import and cannot be done within the import group above
-# TaskGroup is introduced in Airflow 2.0.0
-if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-    from airflow.utils.task_group import TaskGroup
-else:
-    TaskGroup = None
 # pylint: disable=ungrouped-imports,invalid-name
 
 # TimeTable is introduced in Airflow 2.2.0
 if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
     from airflow.timetables.base import Timetable
@@ -104,12 +90,7 @@
 else:
     MappedOperator = None
 
-# XComArg is introduced in Airflow 2.0.0
-if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-    from airflow.models.xcom_arg import XComArg
-else:
-    XComArg = None
-# pylint: disable=ungrouped-imports,invalid-name
+from airflow.models.xcom_arg import XComArg
 
 if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
     from airflow.datasets import Dataset
@@ -151,9 +132,6 @@ def get_dag_params(self) -> Dict[str, Any]:
             raise DagFactoryConfigException("Failed to merge config with default config") from err
         dag_params["dag_id"]: str = self.dag_name
 
-        if dag_params.get("task_groups") and version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
-            raise DagFactoryConfigException("`task_groups` key can only be used with Airflow 2.x.x")
-
         if utils.check_dict_key(dag_params, "schedule_interval") and dag_params["schedule_interval"] == "None":
             dag_params["schedule_interval"] = None
 
@@ -559,6 +537,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
             else operator_obj.partial(**task_params).expand(**expand_kwargs)
         )
         except Exception as err:
+            raise err
             raise DagFactoryException(f"Failed to create {operator_obj} task") from err
         return task
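
A note on the hunk above: as the diff reads, the added bare `raise err` re-raises the original operator-construction error and makes the wrapped DagFactoryException line below it unreachable, so failures (for example in the new outlets test) surface with their real traceback. A minimal sketch of the two behaviors, using a stand-in exception class rather than dag-factory's own:

class DagFactoryException(Exception):
    """Stand-in for dagfactory.exceptions.DagFactoryException."""

def create_wrapped():
    try:
        raise ValueError("bad task param")
    except Exception as err:
        # Pre-commit behavior: wrap the error, chaining the cause via `from err`.
        raise DagFactoryException("Failed to create task") from err

def create_reraised():
    try:
        raise ValueError("bad task param")
    except Exception as err:
        # Post-commit behavior: propagate the original error unchanged;
        # the wrapping raise below is now unreachable.
        raise err
        raise DagFactoryException("Failed to create task") from err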

@@ -693,10 +672,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:
         if not dag_params.get("timetable") and not utils.check_dict_key(dag_params, "schedule"):
             dag_kwargs["schedule_interval"] = dag_params.get("schedule_interval", timedelta(days=1))
 
-        if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.11"):
-            dag_kwargs["description"] = dag_params.get("description", None)
-        else:
-            dag_kwargs["description"] = dag_params.get("description", "")
+        dag_kwargs["description"] = dag_params.get("description", None)
 
         if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
             dag_kwargs["max_active_tasks"] = dag_params.get(
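The conditionals that remain in dagbuilder.py keep gating features newer than the Airflow 2.0 floor (Timetable at 2.2, MappedOperator at 2.3, Dataset at 2.4), binding the name to None on older versions so later checks can test for it. A minimal sketch of that pattern, with the version string hard-coded for illustration (the real module derives AIRFLOW_VERSION from the installed Airflow):

from packaging import version

AIRFLOW_VERSION = "2.3.0"  # illustrative; dagbuilder reads the installed Airflow's version

# Dataset only exists from Airflow 2.4.0, so fall back to None below that,
# mirroring the MappedOperator/Timetable gates visible in the diff above.
if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
    from airflow.datasets import Dataset
else:
    Dataset = None  # assumption: this else branch is not shown in the diff

print(Dataset)  # -> None on 2.3.0, the Dataset class on >= 2.4.0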
32 changes: 31 additions & 1 deletion tests/test_dagbuilder.py
@@ -1,13 +1,15 @@
 import datetime
 import os
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import mock_open, patch
 
 import pendulum
 import pytest
 from airflow import DAG
 from packaging import version
 
+from dagfactory.dagbuilder import Dataset
+
 try:
     from airflow.providers.http.sensors.http import HttpSensor
 except ImportError:
@@ -879,3 +881,31 @@ def test_replace_expand_string_with_xcom():
     updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(task_conf_xcomarg, tasks_dict)
     assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
     assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
+
+
+@pytest.mark.skipif(
+    version.parse(AIRFLOW_VERSION) <= version.parse("2.4.0"), reason="Requires Airflow version greater than 2.4.0"
+)
+@pytest.mark.parametrize(
+    "outlets,output",
+    [
+        (
+            {"datasets": "s3://test/test.txt", "file": "file://path/to/my_file.txt"},
+            ["s3://test/test.txt", "file://path/to/my_file.txt"],
+        ),
+        (["s3://test/test.txt"], ["s3://test/test.txt"]),
+    ],
+)
+@patch("dagfactory.dagbuilder.utils.get_datasets_uri_yaml_file", new_callable=mock_open)
+def test_make_task_outlets(mock_read_file, outlets, output):
+    td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
+    task_params = {
+        "task_id": "process",
+        "python_callable_name": "expand_task",
+        "python_callable_file": os.path.realpath(__file__),
+        "outlets": outlets,
+    }
+    mock_read_file.return_value = output
+    operator = "airflow.operators.python_operator.PythonOperator"
+    actual = td.make_task(operator, task_params)
+    assert actual.outlets == [Dataset(uri) for uri in output]
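
The final assertion in test_make_task_outlets relies on Airflow's Dataset equality being URI-based: make_task resolves the outlets config into Dataset objects on the operator, roughly as if the task had been declared directly like this (a hand-written sketch, not the factory's actual code path):

from airflow.datasets import Dataset  # Airflow >= 2.4
from airflow.operators.python import PythonOperator

def expand_task():
    pass

process = PythonOperator(
    task_id="process",
    python_callable=expand_task,
    outlets=[Dataset("s3://test/test.txt"), Dataset("file://path/to/my_file.txt")],
)
# Dataset instances with the same URI compare equal, which is what the test asserts.
assert process.outlets == [Dataset("s3://test/test.txt"), Dataset("file://path/to/my_file.txt")]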
10 changes: 10 additions & 0 deletions tests/test_dagfactory.py
@@ -429,7 +429,17 @@ def test_load_yaml_dags_succeed(mock_emit_usage_metrics_if_enabled):
         dags_folder="tests/fixtures",
         suffix=["dag_factory_variables_as_arguments.yml"],
     )
+
     # Confirm the representative telemetry for all the DAGs defined in the desired YAML is being sent
     args = mock_emit_usage_metrics_if_enabled.call_args.args
     assert args[0] == "load_yaml_dags"
     assert args[1] == {"dags_count": 2, "tasks_count": 4, "taskgroups_count": 0}
+
+
+def test_load_yaml_dags_default_suffix_succeed(caplog):
+    caplog.set_level(logging.INFO)
+    load_yaml_dags(
+        globals_dict=globals(),
+        dags_folder="tests/fixtures",
+    )
+    assert "Loading DAGs from tests/fixtures" in caplog.messages
