Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental BQ support to run dbt models with ExecutionMode.AIRFLOW_ASYNC #1230

Merged
merged 51 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
851564f
Draft: dbt compile task
pankajkoti Sep 25, 2024
9dc2c9c
Put compiled files under dag_id folder & refactor few snippets
pankajkoti Sep 29, 2024
0ce662e
Add tests & minor refactorings
pankajkoti Sep 29, 2024
1b6f57e
Apply suggestions from code review
pankajkoti Sep 29, 2024
cc48161
Install deps for the newly added example DAG
pankajkoti Sep 29, 2024
1068025
Add docs
pankajkoti Sep 30, 2024
faa706d
Add async run operator
pankajkoti Sep 25, 2024
0e155e4
Fix remote sql path and async args
pankajastro Sep 30, 2024
5f1ecaa
Fix query
pankajastro Sep 30, 2024
1278847
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
b3d6cf3
Use dbt node's filepath to construct remote path to fetch compiled SQ…
pankajkoti Sep 30, 2024
78bc069
Merge branch 'main' into execute-async-task
tatiana Sep 30, 2024
9ca5e85
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
99bf7c0
Fix unittests
tatiana Sep 30, 2024
3aaaf9e
Improve code
tatiana Sep 30, 2024
43158be
Working with deferrable=False, not working with deferrable=True
tatiana Oct 1, 2024
83b1010
Working with deferrable=False, not working with deferrable=True
tatiana Oct 1, 2024
bd6657a
Fix issue when using BQ deferrable operator - it requires location
tatiana Oct 1, 2024
1195955
Add limitation in docs
pankajastro Oct 1, 2024
2bdd9bb
Add full_refresh as templated field
pankajastro Oct 1, 2024
4a44603
Add more template fields
pankajastro Oct 1, 2024
c3c51cb
Construct & relay 'dbt dag-task group' identifier to upload & downloa…
pankajkoti Oct 1, 2024
72c6164
Fix model_name retrieval; get from dbt_node_config
pankajkoti Oct 1, 2024
e67098e
Fix unit tests
pankajkoti Oct 1, 2024
3e550bf
Fix subsequent failing unit tests
pankajkoti Oct 1, 2024
0730d0f
Fix type check failures
pankajkoti Oct 1, 2024
745768e
Add back the deleted sources.yml from jaffle_shop as it has dependenc…
pankajkoti Oct 1, 2024
43d62ea
Install dbt bigquery adapter for running simple_dag_async
pankajkoti Oct 1, 2024
9656248
Install dbt bigquery adapter in our CI setup scripts
pankajkoti Oct 1, 2024
a654f49
Update gcp conn in dev/dags/simple_dag_async.py
pankajkoti Oct 1, 2024
e60ace2
Refactor args in DbtRunAirflowAsyncOperator
tatiana Oct 1, 2024
7f055bc
Use GoogleCloudServiceAccountDictProfileMapping in profilemapping
pankajkoti Oct 1, 2024
ad057c8
set should_upload_compiled_sql to True
pankajkoti Oct 1, 2024
a70ca46
Remove async_op_args
tatiana Oct 1, 2024
7c6a1b2
remove install_deps from DAG
pankajkoti Oct 1, 2024
64a31d0
Merge branch 'main' into execute-async-task
tatiana Oct 1, 2024
c1aeff0
Fix test_build_airflow_graph_with_dbt_compile_task by passing needed …
pankajkoti Oct 1, 2024
02f7985
Specify required project id in the GoogleCloudServiceAccountDictProfi…
pankajkoti Oct 2, 2024
af454a9
Pass gcp_conn_id to super class init, otherwise it is lost & uses the…
pankajkoti Oct 2, 2024
9081e6a
Adapt manifest DAG to use & adapt to the newer GCP conn secret that i…
pankajkoti Oct 2, 2024
2dccf84
Release 1.7.0a1
tatiana Oct 2, 2024
7adeb99
Retrigger GH actions
tatiana Oct 2, 2024
7e6de30
temporarily move out simple_dag_async.py
tatiana Oct 2, 2024
16a87ea
Fix CI issue
tatiana Oct 2, 2024
05db6a0
Fix dbt-compile dependency by using Airflow tasks instead of dbt nodes
pankajkoti Oct 2, 2024
8fc4ae2
Apply suggestions from code review
pankajkoti Oct 2, 2024
ea5816b
Apply suggestions from code review
pankajkoti Oct 2, 2024
85f86a4
Add install instruction
pankajastro Oct 3, 2024
402f823
Add min airflow version in limitation
pankajastro Oct 3, 2024
621a4de
Ignore Async DAG for dbt <=1.5
pankajastro Oct 3, 2024
a0cb147
Ignore Async DAG for dbt <=1.5
pankajastro Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,3 @@ webserver_config.py

# VI
*.sw[a-z]

# Ignore possibly created symlink to `dev/dags` for running `airflow dags test` command.
dags
3 changes: 2 additions & 1 deletion cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

Contains dags, task groups, and operators.
"""
__version__ = "1.6.0"

__version__ = "1.7.0a1"


from cosmos.airflow.dag import DbtDag
Expand Down
33 changes: 26 additions & 7 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def create_task_metadata(
node: DbtNode,
execution_mode: ExecutionMode,
args: dict[str, Any],
dbt_dag_task_group_identifier: str,
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
) -> TaskMetadata | None:
Expand All @@ -142,6 +143,7 @@ def create_task_metadata(
:param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES).
Default is ExecutionMode.LOCAL.
:param args: Arguments to be used to instantiate an Airflow Task
:param dbt_dag_task_group_identifier: Identifier to refer to the DbtDAG or DbtTaskGroup in the DAG.
:param use_task_group: It determines whether to use the name as a prefix for the task id or not.
If it is False, then use the name as a prefix for the task id, otherwise do not.
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
Expand All @@ -156,7 +158,10 @@ def create_task_metadata(
args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {"dbt_node_config": node.context_dict}
extra_context = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"
if use_task_group is True:
Expand Down Expand Up @@ -226,6 +231,7 @@ def generate_task_or_group(
node=node,
execution_mode=execution_mode,
args=task_args,
dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group),
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
)
Expand Down Expand Up @@ -268,14 +274,28 @@ def _add_dbt_compile_task(
id=DBT_COMPILE_TASK_ID,
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)

for task_id, task in tasks_map.items():
if not task.upstream_list:
compile_airflow_task >> task

tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task

for node_id, node in nodes.items():
if not node.depends_on and node_id in tasks_map:
tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id]

def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) -> str:
dag_id = dag.dag_id
task_group_id = task_group.group_id if task_group else None
identifiers_list = []
if dag_id:
identifiers_list.append(dag_id)
if task_group_id:
identifiers_list.append(task_group_id)
dag_task_group_identifier = "__".join(identifiers_list)

return dag_task_group_identifier


def build_airflow_graph(
Expand Down Expand Up @@ -358,9 +378,8 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)

create_airflow_task_dependencies(nodes, tasks_map)
_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)


def create_airflow_task_dependencies(
Expand Down
16 changes: 16 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
Expand Down Expand Up @@ -286,6 +287,21 @@
if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def get_profile_type(self) -> str:
if isinstance(self.profile_mapping, BaseProfileMapping):
return str(self.profile_mapping.dbt_profile_type)

profile_path = self._get_profile_path()

Check warning on line 294 in cosmos/config.py

View check run for this annotation

Codecov / codecov/patch

cosmos/config.py#L294

Added line #L294 was not covered by tests

with open(profile_path) as file:
profiles = yaml.safe_load(file)

Check warning on line 297 in cosmos/config.py

View check run for this annotation

Codecov / codecov/patch

cosmos/config.py#L296-L297

Added lines #L296 - L297 were not covered by tests

profile = profiles[self.profile_name]
target_type = profile["outputs"][self.target_name]["type"]
return str(target_type)

Check warning on line 301 in cosmos/config.py

View check run for this annotation

Codecov / codecov/patch

cosmos/config.py#L299-L301

Added lines #L299 - L301 were not covered by tests

return "undefined"

Check warning on line 303 in cosmos/config.py

View check run for this annotation

Codecov / codecov/patch

cosmos/config.py#L303

Added line #L303 was not covered by tests

def _get_profile_path(self, use_mock_values: bool = False) -> Path:
"""
Handle the profile caching mechanism.
Expand Down
3 changes: 2 additions & 1 deletion cosmos/core/airflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
from typing import Any

from airflow.models import BaseOperator
from airflow.models.dag import DAG
Expand All @@ -27,7 +28,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
module = importlib.import_module(module_name)
Operator = getattr(module, class_name)

task_kwargs = {}
task_kwargs: dict[str, Any] = {}
if task.owner != "":
task_kwargs["owner"] = task.owner

Expand Down
1 change: 0 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ def should_use_dbt_ls_cache(self) -> bool:

def load_via_dbt_ls_cache(self) -> bool:
"""(Try to) load dbt ls cache from an Airflow Variable"""

logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
if self.should_use_dbt_ls_cache():
project_path = self.project_path
Expand Down
179 changes: 151 additions & 28 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,190 @@
from __future__ import annotations

import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.operators.base import AbstractDbtBaseOperator
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]

class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass
from abc import ABCMeta


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass
from airflow.models.baseoperator import BaseOperator


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass
class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta):
def __init__(self, **kwargs) -> None: # type: ignore
self.location = kwargs.pop("location")
self.configuration = kwargs.pop("configuration", {})
super().__init__(**kwargs)


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
class DbtBuildAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtBuildLocalOperator): # type: ignore
pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator): # type: ignore
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator): # type: ignore
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator): # type: ignore
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalOperator): # type: ignore
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator): # type: ignore

template_fields: Sequence[str] = (
"full_refresh",
"project_dir",
"gcp_project",
"dataset",
"location",
)

def __init__( # type: ignore
self,
project_dir: str,
profile_config: ProfileConfig,
location: str, # This is a mandatory parameter when using BigQueryInsertJobOperator with deferrable=True
full_refresh: bool = False,
extra_context: dict[str, object] | None = None,
configuration: dict[str, object] | None = None,
**kwargs,
) -> None:
# dbt task param
self.project_dir = project_dir
self.extra_context = extra_context or {}
self.full_refresh = full_refresh
self.profile_config = profile_config
if not self.profile_config or not self.profile_config.profile_mapping:
raise CosmosValueError(f"Cosmos async support is only available when using ProfileMapping")

Check warning on line 88 in cosmos/operators/airflow_async.py

View check run for this annotation

Codecov / codecov/patch

cosmos/operators/airflow_async.py#L88

Added line #L88 was not covered by tests

self.profile_type: str = profile_config.get_profile_type() # type: ignore
if self.profile_type not in _SUPPORTED_DATABASES:
raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")

Check warning on line 92 in cosmos/operators/airflow_async.py

View check run for this annotation

Codecov / codecov/patch

cosmos/operators/airflow_async.py#L92

Added line #L92 was not covered by tests

# airflow task param
self.location = location
self.configuration = configuration or {}
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
profile = self.profile_config.profile_mapping.profile
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]

# Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.
# We need to pop them.
clean_kwargs = {}
non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys())
non_async_args -= {"task_id"}

for arg_key, arg_value in kwargs.items():
if arg_key not in non_async_args:
clean_kwargs[arg_key] = arg_value

# The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
super().__init__(
gcp_conn_id=self.gcp_conn_id,
configuration=self.configuration,
location=self.location,
deferrable=True,
**clean_kwargs,
)

def get_remote_sql(self) -> str:
if not settings.AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")

Check warning on line 124 in cosmos/operators/airflow_async.py

View check run for this annotation

Codecov / codecov/patch

cosmos/operators/airflow_async.py#L124

Added line #L124 was not covered by tests
from airflow.io.path import ObjectStoragePath

file_path = self.extra_context["dbt_node_config"]["file_path"] # type: ignore
dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

remote_target_path_str = str(remote_target_path).rstrip("/")

if TYPE_CHECKING:
assert self.project_dir is not None

Check warning on line 133 in cosmos/operators/airflow_async.py

View check run for this annotation

Codecov / codecov/patch

cosmos/operators/airflow_async.py#L133

Added line #L133 was not covered by tests

project_dir_parent = str(Path(self.project_dir).parent)
relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp: # type: ignore
return fp.read() # type: ignore

def drop_table_sql(self) -> None:
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"
pankajkoti marked this conversation as resolved.
Show resolved Hide resolved

hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

def execute(self, context: Context) -> Any | None:
if not self.full_refresh:
raise CosmosValueError("The async execution only supported for full_refresh")

Check warning on line 161 in cosmos/operators/airflow_async.py

View check run for this annotation

Codecov / codecov/patch

cosmos/operators/airflow_async.py#L161

Added line #L161 was not covered by tests
else:
# It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
# https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
# https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
# We're emulating this behaviour here
self.drop_table_sql()
sql = self.get_remote_sql()
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
# prefix explicit create command to create table
sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return super().execute(context)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator): # type: ignore
pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
class DbtCompileAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCompileLocalOperator): # type: ignore
pass
Loading
Loading