fixed documentation and refactored external task dependency sensor. (#21)

* fixed documentation and refactored external task dependency sensor.

* fixed oauth connection to make it lazily evaluated rather than evaluated during construction

* updated docs and cleaned up arguments in airflow external task operator

* moved Pari to CONTRIBUTORS.md
stikkireddy authored Aug 25, 2023
1 parent e925c58 commit 3c60c2c
Showing 29 changed files with 742 additions and 357 deletions.
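The refactor replaces the sensor's old `okta_conn_id` / `cluster_id` arguments with an injected cluster-auth object and defers the OAuth connection lookup until a token is actually requested. A minimal usage sketch of the new API, assuming hypothetical connection ids, URLs, and DAG/task names:

```python
from brickflow_plugins import AirflowProxyOktaClusterAuth, TaskDependencySensor

# The Okta connection is now resolved lazily inside get_okta_conn(), not in
# __init__, so constructing this object does not require secrets to be available.
airflow_auth = AirflowProxyOktaClusterAuth(
    oauth2_conn_id="okta_conn_id",  # hypothetical Airflow connection id
    airflow_cluster_url="https://airflow.example.com/cluster-1",  # hypothetical URL
    airflow_version="2.0.2",  # pinning a version skips get_airflow_version_callback
)

# The auth object is injected in place of the removed okta_conn_id / cluster_id args.
wait_for_upstream = TaskDependencySensor(
    task_id="wait_for_upstream",  # hypothetical sensor task name
    external_dag_id="upstream_dag",  # hypothetical upstream DAG
    external_task_id="final_task",  # hypothetical upstream task
    airflow_cluster_auth=airflow_auth,
    allowed_states=["success"],
    poke_interval=60,
)
```

Both names are exported from `brickflow_plugins`, per the `__init__.py` change below.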
2 changes: 1 addition & 1 deletion CONTRIBUTORS.md
@@ -5,6 +5,7 @@
# Contributors
Thanks to the contributors who helped on this project apart from the authors
* [Danny Meijer](https://www.linkedin.com/in/dannydatascientist/)
* [Pariksheet Marotrao Barapatre](https://www.linkedin.com/in/pari-data-products/)

# Honorary Mentions
Thanks to the team below for invaluable insights and support throughout the initial release of this project
@@ -13,4 +14,3 @@ Thanks to the team below for invaluable insights and support throughout the init
* [Aditya Chaturvedi](https://www.linkedin.com/in/chaturvediaditya/)
* [Scott Haines](https://www.linkedin.com/in/scotthaines/)
* [Arijit Banerjee](https://www.linkedin.com/in/massborn/)
* [Pariksheet Marotrao Barapatre](https://www.linkedin.com/in/pari-data-products/)
22 changes: 8 additions & 14 deletions brickflow/cli/projects.py
@@ -9,7 +9,6 @@
from pydantic import BaseModel, Field

from brickflow import (
get_brickflow_version,
BrickflowProjectConstants,
BrickflowDefaultEnvs,
BrickflowEnvVars,
@@ -322,13 +321,6 @@ def use_project(
type=str,
help="Path from project root to workflows dir",
)
@click.option(
"--deployment-mode",
prompt="Deployment mode",
type=click.Choice([BrickflowDeployMode.BUNDLE.value]),
default=BrickflowDeployMode.BUNDLE.value,
help="Deployment mode",
)
@click.option(
"-g",
"--git-https-url",
@@ -339,7 +331,7 @@ def use_project(
@click.option(
"-bfv",
"--brickflow-version",
default=get_brickflow_version(),
default=DEFAULT_BRICKFLOW_VERSION_MODE,
type=str,
prompt=INTERACTIVE_MODE,
)
@@ -362,7 +354,6 @@ def add(
name: str,
path_from_repo_root_to_project_root: str,
path_project_root_to_workflows_dir: str,
deployment_mode: str,
git_https_url: str,
brickflow_version: str,
spark_expectations_version: str,
@@ -373,7 +364,7 @@
name=name,
path_from_repo_root_to_project_root=path_from_repo_root_to_project_root,
path_project_root_to_workflows_dir=path_project_root_to_workflows_dir,
deployment_mode=deployment_mode,
deployment_mode=BrickflowDeployMode.BUNDLE.value,
)
multi_project_manager.add_project(project)

@@ -507,7 +498,7 @@ def handle_libraries(skip_libraries: Optional[bool] = None, **_: Any) -> None:
@projects.command(name="destroy")
@apply_bundles_deployment_options
def destroy_project(project: str, **kwargs: Any) -> None:
"""Destroy projects in the brickflow-multi-project.yml file"""
"""Destroys the deployed resources and workflows in databricks for the project"""
bf_project = multi_project_manager.get_project(project)
dir_to_change = multi_project_manager.get_project_ref(project).root_yaml_rel_path
handle_libraries(**kwargs)
@@ -557,7 +548,8 @@ def project_synth(**_: Any) -> None:
)
@apply_bundles_deployment_options
def sync_project(project: str, **kwargs: Any) -> None:
"""Sync project into databricks workspace from local. It is only unidirectional."""
"""Sync project file tree into databricks workspace from local.
It is only one way from local to databricks workspace."""
bf_project = multi_project_manager.get_project(project)
dir_to_change = multi_project_manager.get_project_ref(project).root_yaml_rel_path
handle_libraries(**kwargs)
@@ -584,12 +576,14 @@ def synth_bundles_for_project(project: str, **kwargs: Any) -> None:
project_synth(
workflows_dir=bf_project.path_project_root_to_workflows_dir, **kwargs
)
_ilog.info("SUCCESSFULLY SYNTHESIZED BUNDLE.YML FOR PROJECT %s", project)


@projects.command(name="deploy")
@apply_bundles_deployment_options
def deploy_project(project: str, **kwargs: Any) -> None:
"""Deploy projects in the brickflow-multi-project.yml file"""
"""Deploy the resources and workflows to databricks for the project
configured in the brickflow-project-root.yml file"""
bf_project = multi_project_manager.get_project(project)
dir_to_change = multi_project_manager.get_project_ref(project).root_yaml_rel_path
handle_libraries(**kwargs)
6 changes: 5 additions & 1 deletion brickflow_plugins/__init__.py
@@ -17,7 +17,10 @@ def setup_logger():

log = setup_logger()

from brickflow_plugins.airflow.operators.external_tasks import TaskDependencySensor
from brickflow_plugins.airflow.operators.external_tasks import (
TaskDependencySensor,
AirflowProxyOktaClusterAuth,
)
from brickflow_plugins.airflow.operators.native_operators import (
BashOperator,
BranchPythonOperator,
@@ -51,6 +54,7 @@ def ensure_installation():

__all__: List[str] = [
"TaskDependencySensor",
"AirflowProxyOktaClusterAuth",
"BashOperator",
"BranchPythonOperator",
"ShortCircuitOperator",
127 changes: 67 additions & 60 deletions brickflow_plugins/airflow/operators/external_tasks.py
@@ -1,27 +1,27 @@
import abc
import json
import logging
import os
from http import HTTPStatus
from typing import Callable

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from airflow.models import Connection
from airflow.sensors.base import BaseSensorOperator
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

from brickflow_plugins import log


class MapDagSchedule:
class DagSchedule:
def get_schedule(self, wf_id: str, **args):
"""
Function that the sensors defined while deriving this class should
override.
"""
raise Exception("Override me.")

def get_task_run_status(
self, wf_id: str, task_id: str, run_date=None, cluster_id=None, **args
):
def get_task_run_status(self, wf_id: str, task_id: str, run_date=None, **args):
"""
Function that the sensors defined while deriving this class should
override.
@@ -32,30 +32,57 @@ def get_task_run_status(
# TODO: implement Delta Json


class MapDagScheduleHelper(MapDagSchedule):
def __init__(self, okta_conn_id: str):
self._okta_conn: Connection = Connection.get_connection_from_secrets(
okta_conn_id
)
class AirflowClusterAuth(abc.ABC):
@abc.abstractmethod
def get_access_token(self) -> str:
pass

@abc.abstractmethod
def get_airflow_api_url(self) -> str:
pass

@abc.abstractmethod
def get_version(self) -> str:
pass


class AirflowProxyOktaClusterAuth(AirflowClusterAuth):
def __init__(
self,
oauth2_conn_id: str,
airflow_cluster_url: str,
airflow_version: str = None,
get_airflow_version_callback: Callable[[str, str], str] = None,
):
self._airflow_version = airflow_version
self._get_airflow_version_callback = get_airflow_version_callback
self._oauth2_conn_id = oauth2_conn_id
self._airflow_url = airflow_cluster_url.rstrip("/")
if airflow_version is None and get_airflow_version_callback is None:
raise Exception(
"Either airflow_version or get_airflow_version_callback must be provided"
)

def get_okta_conn(self):
return Connection.get_connection_from_secrets(self._oauth2_conn_id)

def get_okta_url(self) -> str:
conn_type = self._okta_conn.conn_type
host = self._okta_conn.host
schema = self._okta_conn.schema
conn_type = self.get_okta_conn().conn_type
host = self.get_okta_conn().host
schema = self.get_okta_conn().schema
return f"{conn_type}://{host}/{schema}"

def get_okta_client_id(self) -> str:
return self._okta_conn.login
return self.get_okta_conn().login

def get_okta_client_secret(self) -> str:
return self._okta_conn.get_password()
return self.get_okta_conn().get_password()

def get_access_token(self) -> str:
okta_url = self.get_okta_url()
client_id = self.get_okta_client_id()
client_secret = self.get_okta_client_secret()

okta_url = os.getenv("OKTA_URL", okta_url)
payload = (
"client_id="
+ client_id
@@ -80,37 +107,27 @@ def get_access_token(self) -> str:
token_data = response.json()["access_token"]
return token_data

def get_airflow_api_url(self, cluster_id: str) -> str:
def get_airflow_api_url(self) -> str:
# TODO: templatize this to a env variable
base_api_url = f"https://proxy.us-east-1.map.nike.com/{cluster_id}"
return base_api_url
return self._airflow_url

def get_version(self, cluster_id: str) -> str:
session = requests.Session()
retries = Retry(
total=10, backoff_factor=1, status_forcelist=[502, 503, 504, 500]
)
session.mount("https://", HTTPAdapter(max_retries=retries))
version_check_url = (
self.get_airflow_api_url(cluster_id) + "/admin/rest_api/api?api=version"
)
logging.info(version_check_url)
otoken = self.get_access_token()
headers = {"Authorization": "Bearer " + otoken, "Accept": "application/json"}
out_version = "UKN"
response = session.get(version_check_url, headers=headers, verify=False)
if response.status_code == HTTPStatus.OK:
out_version = response.json()["output"]
log.info(response.text.encode("utf8"))
session.close()
return out_version

def get_task_run_status(
self, wf_id: str, task_id: str, run_date=None, cluster_id=None, **args
):
token_data = self.get_access_token()
api_url = self.get_airflow_api_url(cluster_id)
version_nr = self.get_version(cluster_id)
def get_version(self) -> str:
if self._airflow_version is not None:
return self._airflow_version
else:
return self._get_airflow_version_callback(
self._airflow_url, self.get_access_token()
)


class AirflowScheduleHelper(DagSchedule):
def __init__(self, airflow_auth: AirflowClusterAuth):
self._airflow_auth = airflow_auth

def get_task_run_status(self, wf_id: str, task_id: str, run_date=None, **kwargs):
token_data = self._airflow_auth.get_access_token()
api_url = self._airflow_auth.get_airflow_api_url()
version_nr = self._airflow_auth.get_version()
dag_id = wf_id
headers = {
"Content-Type": "application/json",
@@ -162,28 +179,21 @@ def get_task_run_status(

return o_task_status

def get_schedule(self, wf_id: str, **kwargs):
"""
get work flow schedule cron syntax
"""
raise Exception("Do not have implementation")


class TaskDependencySensor(BaseSensorOperator):
def __init__(
self,
external_dag_id,
external_task_id,
okta_conn_id,
airflow_cluster_auth: AirflowClusterAuth,
allowed_states=None,
execution_delta=None,
execution_delta_json=None,
cluster_id=None,
*args,
**kwargs,
):
super(TaskDependencySensor, self).__init__(*args, **kwargs)
self.okta_conn_id = okta_conn_id
self.airflow_auth = airflow_cluster_auth
self.allowed_states = allowed_states or ["success"]

if execution_delta_json and execution_delta:
@@ -196,10 +206,8 @@ def __init__(
self.allowed_states = allowed_states
self.execution_delta = execution_delta
self.execution_delta_json = execution_delta_json
self.cluster_id = cluster_id

self._poke_count = 0
self.dbx_wf_id = kwargs.get("dbx_wf_id")

def poke(self, context):
log.info(f"executing poke.. {self._poke_count}")
@@ -208,11 +216,10 @@ def poke(self, context):

exec_time = context["execution_date"]

task_status = MapDagScheduleHelper(self.okta_conn_id).get_task_run_status(
task_status = AirflowScheduleHelper(self.airflow_auth).get_task_run_status(
wf_id=self.external_dag_id,
task_id=self.external_task_id,
run_date=exec_time,
cluster_id=self.cluster_id,
)
log.info(f"task_status= {task_status}")

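The new `AirflowClusterAuth` abstract base decouples `AirflowScheduleHelper` and `TaskDependencySensor` from Okta specifics: only `get_access_token()`, `get_airflow_api_url()`, and `get_version()` are called. A minimal sketch of a custom implementation, assuming a pre-issued static bearer token (this scheme is illustrative, not part of the commit):

```python
from brickflow_plugins.airflow.operators.external_tasks import AirflowClusterAuth


class StaticTokenClusterAuth(AirflowClusterAuth):
    """Hypothetical auth that serves a pre-issued bearer token."""

    def __init__(self, token: str, airflow_url: str, airflow_version: str):
        self._token = token
        self._airflow_url = airflow_url.rstrip("/")  # mirror the Okta class's normalization
        self._airflow_version = airflow_version

    def get_access_token(self) -> str:
        return self._token

    def get_airflow_api_url(self) -> str:
        return self._airflow_url

    def get_version(self) -> str:
        return self._airflow_version
```

Any implementation of these three methods can be passed as `airflow_cluster_auth` to `TaskDependencySensor`.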
4 changes: 4 additions & 0 deletions docs/api/airflow_external_task_dependency.md
@@ -1,3 +1,7 @@
---
search:
exclude: true
---

::: brickflow_plugins.airflow.operators.external_tasks
handler: python
4 changes: 4 additions & 0 deletions docs/api/airflow_native_operators.md
@@ -1,3 +1,7 @@
---
search:
exclude: true
---

::: brickflow_plugins.airflow.operators.native_operators
handler: python
4 changes: 4 additions & 0 deletions docs/api/cli.md
@@ -1,3 +1,7 @@
---
search:
exclude: true
---

::: brickflow.cli
handler: python
4 changes: 4 additions & 0 deletions docs/api/compute.md
@@ -1,3 +1,7 @@
---
search:
exclude: true
---

::: brickflow.engine.compute
handler: python
4 changes: 4 additions & 0 deletions docs/api/context.md
@@ -1,3 +1,7 @@
---
search:
exclude: true
---

::: brickflow.context.context
handler: python
1 change: 0 additions & 1 deletion docs/api/misc.md

This file was deleted.

5 changes: 4 additions & 1 deletion docs/api/project.md
@@ -1,4 +1,7 @@

---
search:
exclude: true
---

::: brickflow.engine.project
handler: python
4 changes: 4 additions & 0 deletions docs/api/secrets.md
@@ -1,3 +1,7 @@
---
search:
exclude: true
---

::: brickflow_plugins.secrets
handler: python