Skip to content

Commit

Permalink
[EXPERIMENTAL] Add transfer nodes
Browse files Browse the repository at this point in the history
- improve test coverage
- add experimental decorator
  • Loading branch information
feluelle committed Mar 10, 2022
1 parent 555b8d4 commit 48e0c2f
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 20 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ jobs:
strategy:
fail-fast: false
matrix:
experimental-features: [false, true]
python: ['3.9', '3.10']
include:
- python: '3.11.0-alpha - 3.11.0'
experimental: true
experimental-features: false
- python: '3.11.0-alpha - 3.11.0'
experimental: true
experimental-features: true
steps:
- name: Check out repository
uses: actions/checkout@v2
Expand All @@ -42,6 +47,8 @@ jobs:
- name: Install library
run: poetry install --no-interaction
- name: Run tests
env:
AIRFLOW_DIAGRAMS__EXPERIMENTAL: ${{ matrix.experimental-features }}
run: poetry run pytest --cov-report=xml --cov=airflow_diagrams tests/
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
Expand Down
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@ pip install airflow-diagrams
Then just call it like this:

```console
Usage: airflow-diagrams generate [OPTIONS]
Usage: airflow-diagrams generate [OPTIONS] [EXPERIMENTAL]

Generates <airflow-dag-id>_diagrams.py in <output-path> directory which
contains the definition to create a diagram. Run this file and you will get
a rendered diagram.

Arguments:
[EXPERIMENTAL] Enable experimental features by setting the variable to
'true'. [env var: AIRFLOW_DIAGRAMS__EXPERIMENTAL;default:
False]

Options:
-d, --airflow-dag-id TEXT The dag id from which to generate the diagram.
By default it generates for all.
Expand Down Expand Up @@ -66,6 +71,10 @@ Options:

_Examples of generated diagrams can be found in the [examples](examples) directory._

## 🧪 Experimental Features

* **Transfer Nodes**: Convert Airflow transfer operators into two tasks i.e. source & destination grouped in a cluster

## 🤔 How it Works

1. ℹ️ It connects, by using the official [Apache Airflow Python Client](https://github.com/apache/airflow-client-python), to your Airflow installation to retrieve all DAGs (in case you don't specify any `dag_id`) and all Tasks for the DAG(s).
Expand Down
7 changes: 6 additions & 1 deletion airflow_diagrams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Top-level package for airflow-diagrams."""
from importlib.metadata import version
from os import getcwd
from os import getcwd, getenv
from os.path import dirname, join, realpath

__app_name__ = "airflow-diagrams"
__version__ = version(__name__)
__location__ = realpath(join(getcwd(), dirname(__file__)))
__experimental__ = getenv("AIRFLOW_DIAGRAMS__EXPERIMENTAL", "False").lower() in (
"true",
"1",
"t",
)
62 changes: 57 additions & 5 deletions airflow_diagrams/airflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass
from re import findall
from typing import Generator, Optional

from airflow_client.client.api.dag_api import DAGApi
from airflow_client.client.api_client import ApiClient, Configuration

from airflow_diagrams.class_ref import ClassRef
from airflow_diagrams.utils import experimental


@dataclass
Expand Down Expand Up @@ -50,19 +52,69 @@ def get_tasks(self) -> list[AirflowTask]:
:returns: a list of Airflow Tasks
"""
return [
# TODO: Enable type checking when https://github.com/apache/airflow-client-python/issues/20 is fixed.
response = self.dag_api.get_tasks(self.dag_id, _check_return_type=False)

tasks = [
AirflowTask(
class_ref=ClassRef(**task["class_ref"]),
task_id=task["task_id"],
downstream_task_ids=task["downstream_task_ids"],
group_name=None,
)
# TODO: Enable type checking when https://github.com/apache/airflow-client-python/issues/20 is fixed.
for task in self.dag_api.get_tasks(self.dag_id, _check_return_type=False)[
"tasks"
]
for task in response["tasks"]
]

transfer_nodes(tasks)

return tasks


@experimental
def transfer_nodes(tasks: list[AirflowTask]) -> None:
    """
    Replace each Airflow transfer task by two clustered tasks i.e. source & destination.

    A task is considered a transfer task when its module path contains
    ``.transfers.`` and its class name follows the ``<Source>To<Destination>``
    naming convention. The original task is removed and replaced by a
    ``[SOURCE]``/``[DESTINATION]`` pair grouped under the original task id,
    and all references to it in other tasks' downstream ids are rewired to
    point at the new ``[SOURCE]`` task.

    :param tasks: The tasks to modify in place.
    """
    transfer_tasks = [
        task for task in tasks if ".transfers." in task.class_ref.module_path
    ]

    # Task ids that were actually replaced; only these get their incoming
    # downstream references rewritten below.
    replaced_task_ids = set()

    for task in transfer_tasks:
        # Split the CamelCase class name into words, e.g. "FooToBar" -> ["Foo", "To", "Bar"].
        class_name_words = findall("[A-Z][^A-Z]*", task.class_ref.class_name)
        if "To" not in class_name_words:
            # Not a "<Source>To<Destination>" operator; leave the task untouched
            # instead of crashing on list.index("To").
            continue
        to_index = class_name_words.index("To")
        source_class_name = "".join(class_name_words[:to_index])
        destination_class_name = "".join(class_name_words[to_index + 1 :])
        source_task_id = f"[SOURCE] {task.task_id}"
        destination_task_id = f"[DESTINATION] {task.task_id}"
        source = AirflowTask(
            class_ref=ClassRef(
                module_path="<unknown>",
                class_name=source_class_name,
            ),
            task_id=source_task_id,
            downstream_task_ids=[destination_task_id],
            group_name=task.task_id,
        )
        destination = AirflowTask(
            class_ref=ClassRef(
                module_path="<unknown>",
                class_name=destination_class_name,
            ),
            task_id=destination_task_id,
            # The destination inherits the original task's downstream edges.
            downstream_task_ids=task.downstream_task_ids,
            group_name=task.task_id,
        )
        tasks.extend([source, destination])
        tasks.remove(task)
        replaced_task_ids.add(task.task_id)

    # Rewire every edge that pointed at a replaced task to its [SOURCE] half.
    for t in tasks:
        for dt_idx, dt_id in enumerate(t.downstream_task_ids):
            if dt_id in replaced_task_ids:
                t.downstream_task_ids[dt_idx] = f"[SOURCE] {dt_id}"


class AirflowApiTree:
"""Retrieve Airflow Api information as a Tree."""
Expand Down
10 changes: 9 additions & 1 deletion airflow_diagrams/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,24 @@ def generate( # dead: disable
exists=True,
dir_okay=False,
),
experimental: bool = Argument(
False,
envvar="AIRFLOW_DIAGRAMS__EXPERIMENTAL",
help="Enable experimental features by setting the variable to 'true'.",
),
) -> None:
if verbose:
rprint("💬 Running with verbose output...")
rprint("💬Running with verbose output...")
logging.basicConfig(
level=logging.DEBUG,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler()],
)

if experimental:
rprint("🧪Running with experimental features..")

mappings: dict = load_mappings(mapping_file) if mapping_file else {}

diagrams_class_refs: list[ClassRef] = retrieve_class_refs(
Expand Down
2 changes: 1 addition & 1 deletion airflow_diagrams/diagram.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ with Diagram("{{ name }}", show=False):
{% for node in nodes -%}
{% if node.cluster -%}
with {{ node.cluster.get_variable() }}:
{{ node.get_variable() }} = {{ node.class_name }}("{{ node.get_label(label_wrap) }}")
{{ node.get_variable() }} = {{ node.class_name }}()
{% else -%}
{{ node.get_variable() }} = {{ node.class_name }}("{{ node.get_label(label_wrap) }}")
{% endif -%}
Expand Down
14 changes: 13 additions & 1 deletion airflow_diagrams/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import os
from pathlib import Path

import yaml

from airflow_diagrams import __location__
from airflow_diagrams import __experimental__, __location__


def load_abbreviations() -> dict:
Expand Down Expand Up @@ -31,3 +32,14 @@ def load_mappings(file: Path) -> dict:
"r",
) as mapping_yaml:
return yaml.safe_load(mapping_yaml)


def experimental(func):
    """
    Decorate experimental features.

    The wrapped function only runs when experimental features are enabled
    (the ``AIRFLOW_DIAGRAMS__EXPERIMENTAL`` flag read at import time as
    ``__experimental__``); otherwise the call is a silent no-op returning None.

    :param func: The function implementing an experimental feature.
    :returns: The wrapping function, preserving *func*'s metadata.
    """
    # Local import keeps this module's top-level imports untouched.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if __experimental__:
            logging.debug("Calling experimental feature: %s", func.__name__)
            # Propagate the wrapped function's result instead of discarding it.
            return func(*args, **kwargs)
        return None

    return wrapper
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ pytest-order = "^1.0.1"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.coverage.run]
omit = ["airflow_diagrams/__main__.py"]
79 changes: 72 additions & 7 deletions tests/test_airflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import pytest

from airflow_diagrams import __experimental__
from airflow_diagrams.airflow import AirflowDag, AirflowTask
from airflow_diagrams.class_ref import ClassRef


def test_airflow_dag_eq(airflow_api_tree):
    """Two AirflowDags built from the same fields compare equal; a dict does not."""
    params = {"dag_id": "test_dag", "dag_api": airflow_api_tree.dag_api}
    dag = AirflowDag(**params)
    assert dag == AirflowDag(**params)
    assert dag != dict(params)


def test_airflow_dag_get_tasks(airflow_api_tree):
"""Test getting tasks from Airflow DAG"""
dag_id = "test_dag"
Expand Down Expand Up @@ -38,6 +49,67 @@ def test_airflow_dag_get_tasks(airflow_api_tree):
]


@pytest.mark.skipif(not __experimental__, reason="Transfer nodes are experimental.")
def test_airflow_dag_get_tasks_with_transfers(airflow_api_tree):
    """Test getting tasks from Airflow DAG"""
    dag_id = "test_dag"
    airflow_api_tree.dag_api.get_tasks.return_value = {
        "tasks": [
            {
                "class_ref": {"module_path": "fizz", "class_name": "Fizz"},
                "task_id": "test_task_0",
                "downstream_task_ids": ["test_task_1"],
            },
            {
                "class_ref": {
                    "module_path": "foo.transfers.bar",
                    "class_name": "FooToBar",
                },
                "task_id": "test_task_1",
                "downstream_task_ids": [],
            },
        ],
    }

    dag = airflow_api_tree.get_dags(dag_id=dag_id)[0]

    # The transfer task is replaced by a [SOURCE]/[DESTINATION] pair grouped
    # under the original task id, and upstream edges point at the source half.
    expected = [
        AirflowTask(
            class_ref=ClassRef(module_path="fizz", class_name="Fizz"),
            task_id="test_task_0",
            downstream_task_ids=["[SOURCE] test_task_1"],
            group_name=None,
        ),
        AirflowTask(
            class_ref=ClassRef(module_path="<unknown>", class_name="Foo"),
            task_id="[SOURCE] test_task_1",
            downstream_task_ids=["[DESTINATION] test_task_1"],
            group_name="test_task_1",
        ),
        AirflowTask(
            class_ref=ClassRef(module_path="<unknown>", class_name="Bar"),
            task_id="[DESTINATION] test_task_1",
            downstream_task_ids=[],
            group_name="test_task_1",
        ),
    ]
    assert dag.get_tasks() == expected


def test_airflow_api_tree_get_dags(airflow_api_tree):
"""Test getting dags from Airflow API Tree"""
airflow_api_tree.dag_api.get_dags.return_value = dict(
Expand Down Expand Up @@ -66,10 +138,3 @@ def test_airflow_api_tree_get_dags_with_dag_id(airflow_api_tree):
AirflowDag(dag_id=dag_id, dag_api=airflow_api_tree.dag_api),
]
airflow_api_tree.dag_api.assert_not_called()


def test_airflow_dag_eq(airflow_api_tree):
"""Test Airflow DAG equality"""
airflow_dag_kwargs = dict(dag_id="test_dag", dag_api=airflow_api_tree.dag_api)
assert AirflowDag(**airflow_dag_kwargs) == AirflowDag(**airflow_dag_kwargs)
assert AirflowDag(**airflow_dag_kwargs) != dict(**airflow_dag_kwargs)
Loading

0 comments on commit 48e0c2f

Please sign in to comment.