Skip to content

Commit

Permalink
[EXPERIMENTAL] Add transfer nodes
Browse files Browse the repository at this point in the history
- add experimental decorator
  • Loading branch information
feluelle committed Apr 3, 2022
1 parent 6d6e2ed commit 1c9d005
Show file tree
Hide file tree
Showing 13 changed files with 266 additions and 18 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ jobs:
matrix:
python: ['3.9', '3.10']
experimental: [false]
experimental-features: [false, true]
include:
- python: '3.11.0-alpha - 3.11.0'
experimental: true
experimental-features: false
- python: '3.11.0-alpha - 3.11.0'
experimental: true
experimental-features: true
continue-on-error: ${{ matrix.experimental }}
steps:
- name: Check out repository
Expand Down Expand Up @@ -44,6 +49,8 @@ jobs:
- name: Install library
run: poetry install --no-interaction
- name: Run tests
env:
AIRFLOW_DIAGRAMS__EXPERIMENTAL: ${{ matrix.experimental-features }}
run: poetry run pytest --cov-report=xml --cov=airflow_diagrams tests/
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ Then just call it like this:

_Examples of generated diagrams can be found in the [examples](examples) directory._

## 🧪 Experimental Features

* **Transfer Nodes**: Convert Airflow transfer operators into two tasks i.e. source & destination grouped in a cluster

## 🤔 How it Works

1. ℹ️ It connects, by using the official [Apache Airflow Python Client](https://github.com/apache/airflow-client-python), to your Airflow installation to retrieve all DAGs (in case you don't specify any `dag_id`) and all Tasks for the DAG(s).
Expand Down
2 changes: 2 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ tasks:
# 4. Render diagram
- cd examples && python3 dbt_diagrams.py
fake-dag:
env:
AIRFLOW_DIAGRAMS__EXPERIMENTAL: true
cmds:
# 1. Create fake dag
- python3 dev/airflow/airflow_dags_creator.py
Expand Down
7 changes: 6 additions & 1 deletion airflow_diagrams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Top-level package for airflow-diagrams."""
from importlib.metadata import version
from os import getcwd
from os import getcwd, getenv
from os.path import dirname, join, realpath

__app_name__ = "airflow-diagrams"
__version__ = version(__name__)
__location__ = realpath(join(getcwd(), dirname(__file__)))
__experimental__ = getenv("AIRFLOW_DIAGRAMS__EXPERIMENTAL", "False").lower() in (
"true",
"1",
"t",
)
63 changes: 63 additions & 0 deletions airflow_diagrams/airflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import re
from dataclasses import dataclass
from typing import Generator, Optional

from airflow_client.client.api.dag_api import DAGApi
from airflow_client.client.api_client import ApiClient, Configuration

from airflow_diagrams.class_ref import ClassRef
from airflow_diagrams.utils import experimental


@dataclass
Expand All @@ -16,6 +18,21 @@ class AirflowTask:
downstream_task_ids: list[str]
group_name: Optional[str]

def __hash__(self) -> int:
"""
Build a hash based on all attributes.
:returns: a hash of all attributes.
"""
return (
hash(self.class_ref)
^ hash(self.task_id)
^ hash(
downstream_task_id for downstream_task_id in self.downstream_task_ids
)
^ hash(self.group_name)
)

def __str__(self) -> str:
"""
Define pretty string reprenstation.
Expand Down Expand Up @@ -64,6 +81,52 @@ def get_tasks(self) -> list[AirflowTask]:
]


@experimental
def transfer_nodes(tasks: list[AirflowTask]) -> None:
"""
Transfer Nodes replaces an Airflow transfer task by two tasks i.e. source & destination clustered.
:param tasks: The tasks to modify.
"""
transfer_tasks = [
(task, match.groups())
for task in tasks
if task.class_ref.module_path
and ".transfers." in task.class_ref.module_path
and (match := re.search(r"(\w+)To(\w+)", task.class_ref.class_name))
]

for task, (source_class_name, destination_class_name) in transfer_tasks:
source_task_id = f"[SOURCE] {task.task_id}"
destination_task_id = f"[DESTINATION] {task.task_id}"
source = AirflowTask(
class_ref=ClassRef(
module_path=None, # We don't know if the original module_path belongs to source or destination
class_name=source_class_name,
),
task_id=source_task_id,
downstream_task_ids=[destination_task_id],
group_name=task.task_id,
)
destination = AirflowTask(
class_ref=ClassRef(
module_path=None, # We don't know if the original module_path belongs to source or destination
class_name=destination_class_name,
),
task_id=destination_task_id,
downstream_task_ids=task.downstream_task_ids,
group_name=task.task_id,
)
tasks.extend([source, destination])
tasks.remove(task)

transfer_task_ids = list(map(lambda task: task[0].task_id, transfer_tasks))
for t_idx, t in enumerate(tasks):
for dt_idx, dt_id in enumerate(t.downstream_task_ids):
if dt_id in transfer_task_ids:
tasks[t_idx].downstream_task_ids[dt_idx] = f"[SOURCE] {dt_id}"


class AirflowApiTree:
"""Retrieve Airflow Api information as a Tree."""

Expand Down
11 changes: 7 additions & 4 deletions airflow_diagrams/class_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class ClassRef:
"""A unique reference to a python class."""

module_path: str
module_path: Optional[str]
class_name: str

def __hash__(self) -> int:
Expand All @@ -29,7 +29,9 @@ def __str__(self) -> str:
:returns: the string representation of the class ref.
"""
return f"{self.module_path}.{self.class_name}"
if self.module_path:
return f"{self.module_path}.{self.class_name}"
return self.class_name

@staticmethod
def from_string(string: str) -> "ClassRef":
Expand All @@ -40,8 +42,9 @@ def from_string(string: str) -> "ClassRef":
:returns: the ClassRef object.
"""
module_path, class_name = string.rsplit(".", 1)
return ClassRef(module_path, class_name)
if "." in string:
return ClassRef(*string.rsplit(".", 1))
return ClassRef(module_path=None, class_name=string)


@dataclass
Expand Down
12 changes: 11 additions & 1 deletion airflow_diagrams/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typer import Argument, Exit, Option

from airflow_diagrams import __app_name__, __version__
from airflow_diagrams.airflow import retrieve_airflow_info
from airflow_diagrams.airflow import retrieve_airflow_info, transfer_nodes
from airflow_diagrams.class_ref import (
ClassRef,
ClassRefMatcher,
Expand Down Expand Up @@ -119,6 +119,11 @@ def generate( # dead: disable
exists=True,
dir_okay=False,
),
experimental: bool = Argument(
False,
envvar="AIRFLOW_DIAGRAMS__EXPERIMENTAL",
help="Enable experimental features by setting the variable to 'true'.",
),
) -> None:
if verbose:
rprint("💬Running with verbose output...")
Expand All @@ -130,6 +135,9 @@ def generate( # dead: disable
)
install(max_frames=0)

if experimental:
rprint("🧪Running with experimental features...")

mappings: dict = load_mappings(mapping_file) if mapping_file else {}

diagrams_class_refs: list[ClassRef] = retrieve_class_refs(
Expand Down Expand Up @@ -189,6 +197,8 @@ def generate( # dead: disable
rprint(f"[blue]🪄 Processing Airflow DAG {airflow_dag_id}...")
diagram_context = DiagramContext(airflow_dag_id)

transfer_nodes(airflow_tasks)

for airflow_task in airflow_tasks:
rprint(f"[blue dim] 🪄 Processing {airflow_task}...")
class_ref_matcher = ClassRefMatcher(
Expand Down
2 changes: 1 addition & 1 deletion airflow_diagrams/diagram.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ with Diagram("{{ name }}", show=False):
{% for node in nodes -%}
{% if node.cluster -%}
with {{ node.cluster.get_variable() }}:
{{ node.get_variable() }} = {{ node.class_name }}("{{ node.get_label(label_wrap) }}")
{{ node.get_variable() }} = {{ node.class_name }}()
{% else -%}
{{ node.get_variable() }} = {{ node.class_name }}("{{ node.get_label(label_wrap) }}")
{% endif -%}
Expand Down
14 changes: 13 additions & 1 deletion airflow_diagrams/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import os
from pathlib import Path

import yaml

from airflow_diagrams import __location__
from airflow_diagrams import __experimental__, __location__


def load_abbreviations() -> dict:
Expand Down Expand Up @@ -31,3 +32,14 @@ def load_mappings(file: Path) -> dict:
"r",
) as mapping_yaml:
return yaml.safe_load(mapping_yaml)


def experimental(func):
"""Decorate experimental features."""

def wrapper(*args, **kwargs):
if __experimental__:
logging.debug("Calling experimental feature: %s", func.__name__)
func(*args, **kwargs)

return wrapper
Binary file modified assets/images/usage.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
93 changes: 85 additions & 8 deletions tests/test_airflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from airflow_diagrams.airflow import AirflowDag, AirflowTask
import pytest

from airflow_diagrams import __experimental__
from airflow_diagrams.airflow import AirflowDag, AirflowTask, transfer_nodes
from airflow_diagrams.class_ref import ClassRef


def test_airflow_dag_eq(airflow_api_tree):
"""Test Airflow DAG equality"""
kwargs = dict(dag_id="test_dag", dag_api=airflow_api_tree.dag_api)
airflow_dag = AirflowDag(**kwargs)
assert airflow_dag == AirflowDag(**kwargs)
assert airflow_dag != dict(**kwargs)


def test_airflow_dag_get_tasks(airflow_api_tree):
"""Test getting tasks from Airflow DAG"""
dag_id = "test_dag"
Expand Down Expand Up @@ -38,6 +49,79 @@ def test_airflow_dag_get_tasks(airflow_api_tree):
]


@pytest.mark.skipif(not __experimental__, reason="Transfer nodes are experimental.")
def test_transfer_nodes():
"""Test getting tasks from Airflow DAG"""
tasks = [
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_0",
downstream_task_ids=["test_task_1"],
group_name=None,
),
AirflowTask(
class_ref=ClassRef(
module_path="foo.transfers.bar",
class_name="FooToBar",
),
task_id="test_task_1",
downstream_task_ids=["test_task_2"],
group_name=None,
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_2",
downstream_task_ids=[],
group_name=None,
),
]
transfer_nodes(tasks)
assert set(tasks) == {
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_0",
downstream_task_ids=["[SOURCE] test_task_1"],
group_name=None,
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Foo",
),
task_id="[SOURCE] test_task_1",
downstream_task_ids=["[DESTINATION] test_task_1"],
group_name="test_task_1",
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Bar",
),
task_id="[DESTINATION] test_task_1",
downstream_task_ids=["test_task_2"],
group_name="test_task_1",
),
AirflowTask(
class_ref=ClassRef(
module_path=None,
class_name="Fizz",
),
task_id="test_task_2",
downstream_task_ids=[],
group_name=None,
),
}


def test_airflow_api_tree_get_dags(airflow_api_tree):
"""Test getting dags from Airflow API Tree"""
airflow_api_tree.dag_api.get_dags.return_value = dict(
Expand Down Expand Up @@ -66,10 +150,3 @@ def test_airflow_api_tree_get_dags_with_dag_id(airflow_api_tree):
AirflowDag(dag_id=dag_id, dag_api=airflow_api_tree.dag_api),
]
airflow_api_tree.dag_api.assert_not_called()


def test_airflow_dag_eq(airflow_api_tree):
"""Test Airflow DAG equality"""
airflow_dag_kwargs = dict(dag_id="test_dag", dag_api=airflow_api_tree.dag_api)
assert AirflowDag(**airflow_dag_kwargs) == AirflowDag(**airflow_dag_kwargs)
assert AirflowDag(**airflow_dag_kwargs) != dict(**airflow_dag_kwargs)
18 changes: 18 additions & 0 deletions tests/test_class_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ def class_ref():
)


@pytest.fixture()
def class_ref_without_module_path():
return ClassRef(
module_path=None,
class_name="ClassNameOperator",
)


@pytest.fixture()
def class_ref_matcher(class_ref):
return ClassRefMatcher(
Expand All @@ -39,6 +47,16 @@ def test_class_ref_str_and_from_string(class_ref):
assert ClassRef.from_string(str(class_ref)) == class_ref


def test_class_ref_str_and_from_string_without_module_path(
class_ref_without_module_path,
):
"""Test converting a ClassRef to str & creating a ClassRef from a string"""
assert (
ClassRef.from_string(str(class_ref_without_module_path))
== class_ref_without_module_path
)


def test_class_ref_matcher_match(class_ref_matcher):
"""Test matching"""
assert class_ref_matcher.match() == class_ref_matcher.choices[0]
Expand Down
Loading

0 comments on commit 1c9d005

Please sign in to comment.