diff --git a/brickflow/engine/task.py b/brickflow/engine/task.py index d5fe1b0d..b1dd9fef 100644 --- a/brickflow/engine/task.py +++ b/brickflow/engine/task.py @@ -4,6 +4,7 @@ import dataclasses import functools import inspect +import json import logging import textwrap from dataclasses import dataclass, field @@ -681,7 +682,26 @@ def _skip_because_not_selected(self) -> Tuple[bool, Optional[str]]: ) if selected_tasks is None or selected_tasks == "": return False, None - selected_task_list = selected_tasks.split(",") + + if selected_tasks.startswith("[") and selected_tasks.endswith("]"): + try: + selected_task_list = json.loads(selected_tasks) + except json.JSONDecodeError: + selected_task_list = [] + _ilog.info( + "Invalid JSON list in `brickflow_internal_only_run_tasks` parameter. Nothing will be skipped." + ) + except Exception as e: + selected_task_list = [] + _ilog.info( + "Error parsing `brickflow_internal_only_run_tasks` parameter as JSON, nothing to skip. Error: %s", + str(e), + ) + else: + selected_task_list = selected_tasks.split(",") + + selected_task_list = [task.strip() for task in selected_task_list] + if self.name not in selected_task_list: return ( True, diff --git a/tests/engine/test_task.py b/tests/engine/test_task.py index 9c3e54a6..821f5732 100644 --- a/tests/engine/test_task.py +++ b/tests/engine/test_task.py @@ -174,9 +174,16 @@ def test_should_skip_false(self, task_coms_mock: Mock): assert reason is None ctx._configure() + @pytest.mark.parametrize( + "tasks", + [ + "somethingelse", # other task + f"[{task_function_4.__name__}]", # invalid JSON list defaults to no skip + ], + ) @patch("brickflow.context.ctx.get_parameter") - def test_skip_not_selected_task(self, dbutils): - dbutils.value = "sometihngelse" + def test_skip_not_selected_task(self, dbutils, tasks): + dbutils.return_value = tasks skip, reason = wf.get_task( task_function_4.__name__ )._skip_because_not_selected() @@ -189,9 +196,18 @@ def test_skip_not_selected_task(self, dbutils): ) assert wf.get_task(task_function_4.__name__).execute() is None + @pytest.mark.parametrize( + "tasks", + [ + task_function_4.__name__, # clean string + f'["{task_function_4.__name__}"]', # clean JSON list + f'[" {task_function_4.__name__} "]', # spaced JSON list + f" {task_function_4.__name__} ", # spaced string + ], + ) @patch("brickflow.context.ctx.get_parameter") - def test_no_skip_selected_task(self, dbutils: Mock): - dbutils.return_value = task_function_4.__name__ + def test_no_skip_selected_task(self, dbutils, tasks): + dbutils.return_value = tasks skip, reason = wf.get_task( task_function_4.__name__ )._skip_because_not_selected()