[FEATURE] Support remote workspace for RunJobTask (#140)
* run job

* task type override

* patch

* codegen tests

* tests

* docs

* todo

* add tests

---------

Co-authored-by: Maxim Mityutko <[email protected]>
maxim-mityutko and Maxim Mityutko authored Jul 22, 2024
1 parent d579492 commit 6141b5c
Showing 15 changed files with 320 additions and 8 deletions.
24 changes: 24 additions & 0 deletions brickflow/engine/workflow.py
@@ -327,6 +327,30 @@ def _add_task(
        else:
            ensure_plugins = ensure_brickflow_plugins

        # NOTE: REMOTE WORKSPACE RUN JOB OVERRIDE
        # This is a temporary override for the RunJobTask because Databricks does not natively support
        # triggering a job run in a remote workspace. By default, the Databricks SDK derives the workspace URL
        # from the runtime, so the RunJobTask does not require it. The assumption is that if the `host`
        # parameter is set, the user wants to trigger a remote job; in that case we set the task type to
        # BRICKFLOW_TASK to enforce notebook-type execution and replace the original callable with
        # RunJobInRemoteWorkspace.
        if task_type == TaskType.RUN_JOB_TASK:
            func = f()
            if func.host:
                from brickflow_plugins.databricks.run_job import RunJobInRemoteWorkspace

                task_type = TaskType.BRICKFLOW_TASK

                def run_job_func() -> None:
                    # Run the remote job using parameter values from the original RunJobTask
                    return RunJobInRemoteWorkspace(
                        job_name=func.job_name,
                        databricks_host=func.host,
                        databricks_token=func.token,
                    ).execute()

                f = run_job_func
        # NOTE: END REMOTE WORKSPACE RUN JOB OVERRIDE

        self.tasks[task_id] = Task(
            task_id=task_id,
            task_func=f,
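The dispatch above is easiest to read as two paths. A minimal sketch, assuming illustrative job names, URLs, and token values (the second form is what the override effectively registers, not literal committed code):

# host unset: remains a native RUN_JOB_TASK, executed by Databricks in the current workspace
RunJobTask(job_name="dev_object_raw_to_cleansed")

# host set: re-registered as a BRICKFLOW_TASK whose callable boils down to
RunJobInRemoteWorkspace(
    job_name="dev_object_raw_to_cleansed",
    databricks_host="https://foo.cloud.databricks.com",
    databricks_token="<pat>",  # illustrative placeholder token
).execute()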
55 changes: 55 additions & 0 deletions brickflow_plugins/databricks/run_job.py
@@ -0,0 +1,55 @@
from typing import Union
from pydantic import SecretStr

from databricks.sdk import WorkspaceClient
from brickflow.context import ctx
from brickflow.engine.utils import get_job_id


class RunJobInRemoteWorkspace:
    """
    Currently Databricks does not natively support running a job in a remote workspace via the RunJobTask.
    This plugin adds that functionality; however, it aims to be a temporary solution until Databricks
    supports it natively.
    The plugin supports neither passing parameters to the remote job nor waiting for the job to finish.

    Examples
    --------
    service_principle_pat = ctx.dbutils.secrets.get("scope", "service_principle_id")
    RunJobInRemoteWorkspace(
        databricks_host="https://your_workspace_url.cloud.databricks.com",
        databricks_token=service_principle_pat,
        job_name="foo",
    )
    In the snippet above, Databricks secrets are used as a secure store for the Databricks token.
    If you get your token from another secret management service, such as AWS Secrets Manager,
    GCP Secret Manager, or Azure Key Vault, just pass it in the databricks_token argument.
    """

    def __init__(
        self,
        databricks_host: str,
        databricks_token: Union[str, SecretStr],
        job_name: str,
    ):
        self.databricks_host = databricks_host
        self.databricks_token = (
            databricks_token
            if isinstance(databricks_token, SecretStr)
            else SecretStr(databricks_token)
        )
        self.job_name = job_name
        self._workspace_obj = WorkspaceClient(
            host=self.databricks_host, token=self.databricks_token.get_secret_value()
        )

    def execute(self):
        job_id = get_job_id(
            host=self.databricks_host,
            token=self.databricks_token,
            job_name=self.job_name,
        )
        # TODO: add support for passing parameters to the remote job
        # TODO: wait for the job to finish
        run = self._workspace_obj.jobs.run_now(job_id)
        ctx.log.info("Job run status: %s", run.response)
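Both TODOs map onto existing databricks-sdk facilities; a possible follow-up sketch, not part of this commit (the `job_parameters` keyword and the blocking `result()` call assume a databricks-sdk release that exposes them):

# hypothetical extension of execute(), not part of this commit
waiter = self._workspace_obj.jobs.run_now(
    job_id,
    job_parameters={"run_date": "2024-07-22"},  # illustrative remote-job parameter
)
run = waiter.result()  # block until the remote run reaches a terminal state
ctx.log.info("Remote job finished with state: %s", run.state)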
6 changes: 6 additions & 0 deletions docs/tasks.md
@@ -280,6 +280,12 @@ RunJobTask class can accept the following as inputs:<br />
&emsp;<b>host [Optional]</b>: The URL of the Databricks workspace.<br />
&emsp;<b>token [Optional]</b>: The Databricks API token.

!!! important

    Databricks does not natively support triggering a job run in a remote workspace. Only set the `host` and
    `token` parameters when a remote trigger is required; this invokes the RunJobInRemoteWorkspace plugin, which
    transparently substitutes for the native execution, as sketched below. No extra action is required from the user.
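For example, a remote trigger needs nothing beyond the two extra parameters (a minimal sketch; the workspace URL and secret lookup below are placeholders):

@wf.run_job_task()
def run_remote_job():
    return RunJobTask(
        job_name="foo",
        host="https://your_workspace_url.cloud.databricks.com",
        token=ctx.dbutils.secrets.get("scope", "service_principle_id"),
    )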


#### JAR Task

30 changes: 30 additions & 0 deletions tests/codegen/expected_bundles/dev_bundle_monorepo.yml
@@ -123,6 +123,36 @@ targets:
retry_on_timeout: null
task_key: run_job_task_a
timeout_seconds: null
- "depends_on":
- "outcome": null
"task_key": "notebook_task_a"
"email_notifications": {}
"existing_cluster_id": "existing_cluster_id"
"libraries": []
"max_retries": null
"min_retry_interval_millis": null
"notebook_task":
"base_parameters":
"all_tasks1": "test"
"all_tasks3": "123"
"brickflow_env": "dev"
"brickflow_internal_only_run_tasks": ""
"brickflow_internal_task_name": "{{task_key}}"
"brickflow_internal_workflow_name": "test"
"brickflow_internal_workflow_prefix": ""
"brickflow_internal_workflow_suffix": ""
"brickflow_job_id": "{{job_id}}"
"brickflow_parent_run_id": "{{parent_run_id}}"
"brickflow_run_id": "{{run_id}}"
"brickflow_start_date": "{{start_date}}"
"brickflow_start_time": "{{start_time}}"
"brickflow_task_key": "{{task_key}}"
"brickflow_task_retry_count": "{{task_retry_count}}"
"notebook_path": "some/path/to/root/test_databricks_bundle.py"
"source": "GIT"
"retry_on_timeout": null
"task_key": "run_job_task_b"
"timeout_seconds": null
- depends_on: []
email_notifications: {}
existing_cluster_id: existing_cluster_id
30 changes: 30 additions & 0 deletions tests/codegen/expected_bundles/dev_bundle_polyrepo.yml
@@ -123,6 +123,36 @@ targets:
retry_on_timeout: null
task_key: run_job_task_a
timeout_seconds: null
- "depends_on":
- "outcome": null
"task_key": "notebook_task_a"
"email_notifications": {}
"existing_cluster_id": "existing_cluster_id"
"libraries": []
"max_retries": null
"min_retry_interval_millis": null
"notebook_task":
"base_parameters":
"all_tasks1": "test"
"all_tasks3": "123"
"brickflow_env": "dev"
"brickflow_internal_only_run_tasks": ""
"brickflow_internal_task_name": "{{task_key}}"
"brickflow_internal_workflow_name": "test"
"brickflow_internal_workflow_prefix": ""
"brickflow_internal_workflow_suffix": ""
"brickflow_job_id": "{{job_id}}"
"brickflow_parent_run_id": "{{parent_run_id}}"
"brickflow_run_id": "{{run_id}}"
"brickflow_start_date": "{{start_date}}"
"brickflow_start_time": "{{start_time}}"
"brickflow_task_key": "{{task_key}}"
"brickflow_task_retry_count": "{{task_retry_count}}"
"notebook_path": "test_databricks_bundle.py"
"source": "GIT"
"retry_on_timeout": null
"task_key": "run_job_task_b"
"timeout_seconds": null
- depends_on: []
email_notifications: {}
existing_cluster_id: existing_cluster_id
@@ -197,6 +197,39 @@ targets:
retry_on_timeout: null
task_key: run_job_task_a
timeout_seconds: null
- "depends_on":
- "outcome": null
"task_key": "notebook_task_a"
"email_notifications": {}
"existing_cluster_id": "existing_cluster_id"
"libraries":
- "pypi":
"package": "brickflows==0.1.0"
"repo": null
"max_retries": null
"min_retry_interval_millis": null
"notebook_task":
"base_parameters":
"all_tasks1": "test"
"all_tasks3": "123"
"brickflow_env": "dev"
"brickflow_internal_only_run_tasks": ""
"brickflow_internal_task_name": "{{task_key}}"
"brickflow_internal_workflow_name": "test"
"brickflow_internal_workflow_prefix": ""
"brickflow_internal_workflow_suffix": ""
"brickflow_job_id": "{{job_id}}"
"brickflow_parent_run_id": "{{parent_run_id}}"
"brickflow_run_id": "{{run_id}}"
"brickflow_start_date": "{{start_date}}"
"brickflow_start_time": "{{start_time}}"
"brickflow_task_key": "{{task_key}}"
"brickflow_task_retry_count": "{{task_retry_count}}"
"notebook_path": "test_databricks_bundle.py"
"source": "GIT"
"retry_on_timeout": null
"task_key": "run_job_task_b"
"timeout_seconds": null
- depends_on: []
email_notifications: {}
existing_cluster_id: existing_cluster_id
30 changes: 30 additions & 0 deletions tests/codegen/expected_bundles/local_bundle.yml
@@ -119,6 +119,36 @@ targets:
retry_on_timeout: null
task_key: run_job_task_a
timeout_seconds: null
- "depends_on":
- "outcome": null
"task_key": "notebook_task_a"
"email_notifications": {}
"existing_cluster_id": "existing_cluster_id"
"libraries": []
"max_retries": null
"min_retry_interval_millis": null
"notebook_task":
"base_parameters":
"all_tasks1": "test"
"all_tasks3": "123"
"brickflow_env": "local"
"brickflow_internal_only_run_tasks": ""
"brickflow_internal_task_name": "{{task_key}}"
"brickflow_internal_workflow_name": "test"
"brickflow_internal_workflow_prefix": ""
"brickflow_internal_workflow_suffix": ""
"brickflow_job_id": "{{job_id}}"
"brickflow_parent_run_id": "{{parent_run_id}}"
"brickflow_run_id": "{{run_id}}"
"brickflow_start_date": "{{start_date}}"
"brickflow_start_time": "{{start_time}}"
"brickflow_task_key": "{{task_key}}"
"brickflow_task_retry_count": "{{task_retry_count}}"
"notebook_path": "test_databricks_bundle.py"
"source": "WORKSPACE"
"retry_on_timeout": null
"task_key": "run_job_task_b"
"timeout_seconds": null
- depends_on: []
email_notifications: {}
existing_cluster_id: existing_cluster_id
30 changes: 30 additions & 0 deletions tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml
@@ -120,6 +120,36 @@ targets:
retry_on_timeout: null
task_key: run_job_task_a
timeout_seconds: null
- "depends_on":
- "outcome": null
"task_key": "notebook_task_a"
"email_notifications": {}
"existing_cluster_id": "existing_cluster_id"
"libraries": []
"max_retries": null
"min_retry_interval_millis": null
"notebook_task":
"base_parameters":
"all_tasks1": "test"
"all_tasks3": "123"
"brickflow_env": "local"
"brickflow_internal_only_run_tasks": ""
"brickflow_internal_task_name": "{{task_key}}"
"brickflow_internal_workflow_name": "test"
"brickflow_internal_workflow_prefix": ""
"brickflow_internal_workflow_suffix": ""
"brickflow_job_id": "{{job_id}}"
"brickflow_parent_run_id": "{{parent_run_id}}"
"brickflow_run_id": "{{run_id}}"
"brickflow_start_date": "{{start_date}}"
"brickflow_start_time": "{{start_time}}"
"brickflow_task_key": "{{task_key}}"
"brickflow_task_retry_count": "{{task_retry_count}}"
"notebook_path": "test_databricks_bundle.py"
"source": "WORKSPACE"
"retry_on_timeout": null
"task_key": "run_job_task_b"
"timeout_seconds": null
- depends_on: []
email_notifications: {}
existing_cluster_id: existing_cluster_id
7 changes: 7 additions & 0 deletions tests/codegen/sample_workflow.py
@@ -81,6 +81,13 @@ def run_job_task_a():
    return RunJobTask(job_name="dev_object_raw_to_cleansed") # type: ignore


@wf.run_job_task(
    depends_on=notebook_task_a,
)
def run_job_task_b():
    return RunJobTask(job_name="dev_object_raw_to_cleansed", host="https://foo.cloud.databricks.com") # type: ignore


@wf.sql_task
def sample_sql_task_query() -> any:
    return SqlTask(
5 changes: 4 additions & 1 deletion tests/codegen/test_databricks_bundle.py
@@ -28,7 +28,10 @@
)
from brickflow.engine.project import Stage, Project
from brickflow.engine.task import NotebookTask
from tests.codegen.sample_workflow import wf

# `get_job_id` is being called during workflow init, hence the patch
with patch("brickflow.engine.task.get_job_id", return_value=12345678901234.0) as p:
    from tests.codegen.sample_workflow import wf

# BUNDLE_FILE_NAME = str(Path(__file__).parent / f"bundle.yml")
BUNDLE_FILE_NAME = "bundle.yml"
38 changes: 38 additions & 0 deletions tests/databricks_plugins/test_run_job.py
@@ -0,0 +1,38 @@
import re

import pytest
from requests_mock.mocker import Mocker as RequestsMocker

from brickflow.engine.utils import ctx
from brickflow_plugins.databricks.run_job import RunJobInRemoteWorkspace


class TestRunJob:
    workspace_url = "https://42.cloud.databricks.com"
    endpoint_url = f"{workspace_url}/api/.*/jobs/run-now"
    response = {"run_id": 37, "number_in_job": 42}

    ctx.log.propagate = True

    @pytest.fixture(autouse=True)
    def mock_get_job_id(self, mocker):
        mocker.patch(
            "brickflow_plugins.databricks.run_job.get_job_id",
            return_value=1,
        )

    @pytest.fixture(autouse=True, name="api")
    def mock_api(self):
        rm = RequestsMocker()
        rm.post(re.compile(self.endpoint_url), json=self.response, status_code=200)
        yield rm

    def test_run_job(self, api, caplog):
        with api:
            RunJobInRemoteWorkspace(
                databricks_host=self.workspace_url,
                databricks_token="token",
                job_name="foo",
            ).execute()

        assert "RunNowResponse(number_in_job=42, run_id=37)" in caplog.text
6 changes: 6 additions & 0 deletions tests/engine/sample_workflow.py
@@ -4,6 +4,7 @@
    TaskType,
    TaskResponse,
    DLTPipeline,
    RunJobTask,
)
from brickflow.engine.workflow import Workflow, WorkflowPermissions, User

@@ -102,3 +103,8 @@ def task_function_4():
)
def custom_python_task_push():
    pass


@wf.run_job_task()
def run_job_task():
    return RunJobTask(job_name="foo", host="https://foo.cloud.databricks.com")
4 changes: 3 additions & 1 deletion tests/engine/test_project.py
@@ -15,7 +15,9 @@
    ExecuteError,
)
from brickflow.engine.workflow import Workflow
from tests.engine.sample_workflow import wf, task_function

with patch("brickflow.engine.task.get_job_id", return_value=12345678901234):
    from tests.engine.sample_workflow import wf, task_function


def side_effect(a, _): # noqa