Skip to content

Commit

Permalink
Hybrid-ize tasks run_task executor entrypoint (apache#40762)
Browse files Browse the repository at this point in the history
There is an entrypoint to running tasks on executors other than the
backfill and scheduler jobs, that is the cli command tasks run_task. If
neither --local or --raw are provided, an executor instance is created
to run the task. Before this change, that was always the default
executor. This change updates that logic to check if the task instance
has been configured to run on a specific executor, if so, load that
executor to run the task instead of the default.
  • Loading branch information
o-nikolas authored Jul 13, 2024
1 parent fded2d8 commit 5ae4e37
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
5 changes: 4 additions & 1 deletion airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
print("Could not pickle the DAG")
print(e)
raise e
executor = ExecutorLoader.get_default_executor()
if ti.executor:
executor = ExecutorLoader.load_executor(ti.executor)
else:
executor = ExecutorLoader.get_default_executor()
executor.job_id = None
executor.start()
print("Sending to executor.")
Expand Down
60 changes: 59 additions & 1 deletion tests/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import sys
from argparse import ArgumentParser
from contextlib import contextmanager, redirect_stdout
from importlib import reload
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -41,9 +42,11 @@
from airflow.cli.commands.task_command import LoggerMutationHelper
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound
from airflow.executors.local_executor import LocalExecutor
from airflow.models import DagBag, DagRun, Pool, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
Expand Down Expand Up @@ -179,7 +182,7 @@ def test_test_filters_secrets(self, capsys):

def test_cli_test_different_path(self, session, tmp_path):
"""
When thedag processor has a different dags folder
When the dag processor has a different dags folder
from the worker, ``airflow tasks run --local`` should still work.
"""
repo_root = Path(__file__).parents[3]
Expand Down Expand Up @@ -452,6 +455,61 @@ def test_cli_run_mutually_exclusive(self):
)
)

def test_cli_run_no_local_no_raw_runs_executor(self, dag_maker):
from airflow.cli.commands import task_command

with dag_maker(dag_id="test_executor", schedule="@daily") as dag:
with mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.load_executor"
) as loader_mock, mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
) as get_default_mock:
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2", executor="foo_executor_alias")

dag_maker.create_dagrun()

# Reload module to consume newly mocked executor loader
reload(task_command)

loader_mock.return_value = LocalExecutor()
get_default_mock.return_value = LocalExecutor()

# In the task1 case we will use the default executor
task_command.task_run(
self.parser.parse_args(
[
"tasks",
"run",
"test_executor",
"task1",
DEFAULT_DATE.isoformat(),
]
),
dag,
)
get_default_mock.assert_called_once()
loader_mock.assert_not_called()

# In the task2 case we will use the executor configured on the task
task_command.task_run(
self.parser.parse_args(
[
"tasks",
"run",
"test_executor",
"task2",
DEFAULT_DATE.isoformat(),
]
),
dag,
)
get_default_mock.assert_called_once() # Call from previous task
loader_mock.assert_called_once_with("foo_executor_alias")

# Reload module to remove mocked version of executor loader
reload(task_command)

def test_task_render(self):
"""
tasks render should render and displays templated fields for a given task
Expand Down

0 comments on commit 5ae4e37

Please sign in to comment.