Skip to content

Commit

Permalink
Adding support for default_args, as well as on_failure_callbacks defi…
Browse files Browse the repository at this point in the history
…ned using file and name
  • Loading branch information
jroach-astronomer authored and pankajastro committed Oct 21, 2024
1 parent e0876b3 commit e17371e
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 38 deletions.
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,50 @@ consumer_dag:
bread_type: 'Sourdough'
```
![custom_operators.png](img/custom_operators.png)

### Callbacks
**dag-factory** also supports using "callbacks" at the DAG, Task, and TaskGroup level. These callbacks can be defined in
a few different ways. The first points directly to a Python function that has been defined in the `include/callbacks.py`
file.

```yaml
example_dag1:
on_failure_callback: include.callbacks.example_callback1
...
```

Here, the `on_success_callback` is defined by pointing first to a file, and then to a function name within that file. Notice that this
callback is defined using `default_args`, meaning it will be applied to all tasks.

```yaml
example_dag1:
...
default_args:
on_success_callback_file: include.callbacks
on_success_callback_name: example_callback1
```

**dag-factory** users can also leverage provider-built tools when configuring callbacks. In this example, the
`send_slack_notification` function from the Slack provider is used to dispatch a message when a DAG failure occurs. This
function is passed to the `callback` key under `on_failure_callback`. This pattern allows for callback definitions to
take parameters (such as `text`, `channel`, and `username`, as shown here).

**Note that this functionality is currently only supported for `on_failure_callback`s defined at the DAG level or in
`default_args`. Support for other callback types and for Task/TaskGroup-level definitions is coming soon.**

```yaml
example_dag1:
on_failure_callback:
callback: airflow.providers.slack.notifications.slack.send_slack_notification
text: |
:red_circle: Task Failed.
This task has failed and needs to be addressed.
Please remediate this issue ASAP.
channel: analytics-alerts
username: Airflow
...
```

## Notes

### HttpSensor (since 1.0.0)
Expand Down
60 changes: 35 additions & 25 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ def get_dag_params(self) -> Dict[str, Any]:
)

if utils.check_dict_key(dag_params["default_args"], "on_failure_callback"):
if isinstance(dag_params["default_args"]["on_failure_callback"], str):
dag_params["default_args"]["on_failure_callback"]: Callable = import_string(
dag_params["default_args"]["on_failure_callback"]
)
dag_params["default_args"]["on_failure_callback"]: Callable = self.set_callback(
parameters=dag_params["default_args"], callback_type="on_failure_callback"
)

if utils.check_dict_key(dag_params["default_args"], "on_retry_callback"):
if isinstance(dag_params["default_args"]["on_retry_callback"], str):
Expand All @@ -199,7 +198,9 @@ def get_dag_params(self) -> Dict[str, Any]:
dag_params["on_success_callback"]: Callable = import_string(dag_params["on_success_callback"])

if utils.check_dict_key(dag_params, "on_failure_callback"):
self.set_callback(parameters=dag_params, callback_type="on_failure_callback")
dag_params["on_failure_callback"]: Callable = self.set_callback(
parameters=dag_params, callback_type="on_failure_callback"
)

if utils.check_dict_key(dag_params, "on_success_callback_name") and utils.check_dict_key(
dag_params, "on_success_callback_file"
Expand All @@ -212,9 +213,8 @@ def get_dag_params(self) -> Dict[str, Any]:
if utils.check_dict_key(dag_params, "on_failure_callback_name") and utils.check_dict_key(
dag_params, "on_failure_callback_file"
):
dag_params["on_failure_callback"]: Callable = utils.get_python_callable(
dag_params["on_failure_callback_name"],
dag_params["on_failure_callback_file"],
dag_params["on_failure_callback"] = self.set_callback(
parameters=dag_params, callback_type="on_failure_callback", has_name_and_file=True
)

if utils.check_dict_key(dag_params["default_args"], "on_success_callback_name") and utils.check_dict_key(
Expand All @@ -229,10 +229,8 @@ def get_dag_params(self) -> Dict[str, Any]:
if utils.check_dict_key(dag_params["default_args"], "on_failure_callback_name") and utils.check_dict_key(
dag_params["default_args"], "on_failure_callback_file"
):

dag_params["default_args"]["on_failure_callback"]: Callable = utils.get_python_callable(
dag_params["default_args"]["on_failure_callback_name"],
dag_params["default_args"]["on_failure_callback_file"],
dag_params["default_args"]["on_failure_callback"] = self.set_callback(
parameters=dag_params["default_args"], callback_type="on_failure_callback", has_name_and_file=True
)

if utils.check_dict_key(dag_params, "template_searchpath"):
Expand Down Expand Up @@ -807,34 +805,46 @@ def build(self) -> Dict[str, Union[str, DAG]]:
return {"dag_id": dag_params["dag_id"], "dag": dag}

@staticmethod
def set_callback(parameters: Union[dict, str], callback_type: str) -> None:
def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable:
"""
Update the passed-in config with the callback.
:param parameters:
:param callback_type:
:returns: None
:param has_name_and_file:
:returns: Callable
"""
# There is a scenario where a callback is passed in via a file and a name. For the most part, this will be a
# Python callable that is treated similarly to a Python callable that the PythonOperator may leverage. That
# being said, what if this is not a Python callable? What if this is another type?
if has_name_and_file:
return utils.get_python_callable(
python_callable_name=parameters[f"{callback_type}_name"],
python_callable_file=parameters[f"{callback_type}_file"],
)

# If the value stored at parameters[callback_type] is a string, it should be imported under the assumption that
# it is a function that is "ready to be called"
# it is a function that is "ready to be called". If not returning the function, something like this could be
# used to update the config parameters[callback_type] = import_string(parameters[callback_type])
if isinstance(parameters[callback_type], str):
parameters[callback_type]: Callable = import_string(parameters[callback_type])
return import_string(parameters[callback_type])

# Otherwise, if the parameter[callback_type] is a dictionary, it should be treated similar to the Python
# callable
elif isinstance(parameters[callback_type], dict):
# Pull the on_failure_callback dictionary from dag_params
on_state_callback_params: dict = parameters[callback_type]

# Check to see if there is a "callable" key in the on_failure_callback dictionary. If there is, parse
# Check to see if there is a "callback" key in the on_failure_callback dictionary. If there is, parse
# out that callable, and add the parameters
if utils.check_dict_key(on_state_callback_params, "callable"):
if isinstance(on_state_callback_params["callable"], str):
on_state_callback_callable: Callable = import_string(on_state_callback_params["callable"])
del on_state_callback_params["callable"]
if utils.check_dict_key(on_state_callback_params, "callback"):
if isinstance(on_state_callback_params["callback"], str):
on_state_callback_callable: Callable = import_string(on_state_callback_params["callback"])
del on_state_callback_params["callback"]

# Return the callable, this time, using the params provided in the YAML file, rather than a .py
# file with a callable configured
parameters[callback_type]: Callable = partial(
on_state_callback_callable, **on_state_callback_params
)
# file with a callable configured. If not returning the partial, something like this could be used
# to update the config ... parameters[callback_type]: Callable = partial(...)
return partial(on_state_callback_callable, **on_state_callback_params)

raise DagFactoryConfigException(f"Invalid type passed to {callback_type}")
54 changes: 41 additions & 13 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,17 @@
"doc_md": "##here is a doc md string",
"default_args": {
"owner": "custom_owner",
"on_failure_callback": {
"callback": f"{__name__}.empty_callback_with_params",
"param_1": "value_1",
"param_2": "value_2",
},
},
"description": "this is an example dag",
"schedule_interval": "0 3 * * *",
"tags": ["tag1", "tag2"],
"on_failure_callback": {
"callable": f"{__name__}.empty_callback_with_params",
"callback": f"{__name__}.empty_callback_with_params",
"param_1": "value_1",
"param_2": "value_2",
},
Expand Down Expand Up @@ -763,6 +768,7 @@ def test_make_task_with_callback():
assert callable(actual.on_retry_callback)


@pytest.mark.callbacks
def test_dag_with_callback_name_and_file():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE, DEFAULT_CONFIG)
dag = td.build().get("dag")
Expand All @@ -783,6 +789,7 @@ def test_dag_with_callback_name_and_file():
assert not callable(td_task.on_failure_callback)


@pytest.mark.callbacks
def test_dag_with_callback_name_and_file_default_args():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS, DEFAULT_CONFIG)
dag = td.build().get("dag")
Expand Down Expand Up @@ -822,26 +829,47 @@ def test_make_dag_with_callback():
td.build()


def test_on_failure_callback():
@pytest.mark.callbacks
@pytest.mark.parametrize(
"callback_type,in_default_args", [("on_failure_callback", False), ("on_failure_callback", True)]
)
def test_dag_with_on_callback_str(callback_type, in_default_args):
# Using a different config (DAG_CONFIG_CALLBACK) than below
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK, DEFAULT_CONFIG)
td.build()

config_obj = td.dag_config.get("default_args") if in_default_args else td.dag_config

# Validate the .set_callback() method works as expected when importing a string.
assert callback_type in config_obj
assert callable(config_obj.get(callback_type))
assert config_obj.get(callback_type).__name__ == "print_context_callback"


@pytest.mark.callbacks
@pytest.mark.parametrize(
"callback_type,in_default_args", [("on_failure_callback", False), ("on_failure_callback", True)]
)
def test_dag_with_on_callback_and_params(callback_type, in_default_args):
# Import the DAG using the callback config that was built above
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_WITH_PARAMETERS, DEFAULT_CONFIG)
td.build()

# Check to see if on_failure_callback is in the DAG config, and the type of value that is returned
assert "on_failure_callback" in td.dag_config
config_obj = td.dag_config.get("default_args") if in_default_args else td.dag_config

# Pull the callback
on_failure_callback: functools.partial = td.dag_config.get("on_failure_callback")
# Check to see if callback_type is in the DAG config, and the type of value that is returned, pull the callback
assert callback_type in config_obj
on_callback: functools.partial = config_obj.get(callback_type)

assert isinstance(on_failure_callback, functools.partial)
assert callable(on_failure_callback)
assert on_failure_callback.func.__name__ == "empty_callback_with_params"
assert isinstance(on_callback, functools.partial)
assert callable(on_callback)
assert on_callback.func.__name__ == "empty_callback_with_params"

# Parameters
assert "param_1" in on_failure_callback.keywords
assert on_failure_callback.keywords.get("param_1") == "value_1"
assert "param_2" in on_failure_callback.keywords
assert on_failure_callback.keywords.get("param_2") == "value_2"
assert "param_1" in on_callback.keywords
assert on_callback.keywords.get("param_1") == "value_1"
assert "param_2" in on_callback.keywords
assert on_callback.keywords.get("param_2") == "value_2"


def test_get_dag_params_with_template_searchpath():
Expand Down

0 comments on commit e17371e

Please sign in to comment.