Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/ref database #5

Merged
merged 4 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ WORKDIR /app
# Copy the dependencies file to the working directory
COPY /requirements ./requirements

RUN pip install uv

# Install any dependencies
RUN pip install --no-cache-dir -r requirements/prod.txt
ENV VIRTUAL_ENV=/usr/local
RUN pip install uv && uv pip install --no-cache -r requirements/prod.txt

# Copy the script to the container
COPY src/main.py .
Expand Down
86 changes: 53 additions & 33 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
import asyncio
import json
import logging
import os
import re
Expand All @@ -17,22 +18,6 @@
logger = logging.getLogger(__name__)


def str_to_bool(value: Union[str, bool]) -> bool:
if isinstance(value, bool):
return value
if value.lower() == "true":
return True
if value.lower() == "false":
return False

raise ValueError(f"Invalid value: {value}")


def extract_pr_number(s):
match = re.search(r"refs/pull/(\d+)/merge", s)
return int(match.group(1)) if match else None


# dbt Cloud Env Vars
ACCOUNT_ID = os.getenv("INPUT_DBT_CLOUD_ACCOUNT_ID", None)
TOKEN = os.getenv("INPUT_DBT_CLOUD_SERVICE_TOKEN", None)
Expand All @@ -46,13 +31,9 @@ def extract_pr_number(s):

# Optional Env Vars
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", None)
INCLUDE_DOWNSTREAM = str_to_bool(os.getenv("INPUT_INCLUDE_DOWNSTREAM", True))
INCLUDE_DOWNSTREAM = os.getenv("INPUT_INCLUDE_DOWNSTREAM", True)
DBT_COMMAND = os.getenv("INPUT_DBT_COMMAND", "build")

# Derived variables
PULL_REQUEST_ID = extract_pr_number(GITHUB_REF)
SCHEMA_OVERRIDE = f"dbt_cloud_pr_{JOB_ID}_{PULL_REQUEST_ID}"

# Run Status Indicators
SUCCESS = ":white_check_mark:"
FAILURE = ":x:"
Expand Down Expand Up @@ -104,6 +85,24 @@ def extract_pr_number(s):
"""


def str_to_bool(value: Union[str, bool]) -> bool:
if isinstance(value, bool):
return value

if value.lower() == "true":
return True

if value.lower() == "false":
return False

raise ValueError(f"Invalid value: {value}")


def extract_pr_number(s):
match = re.search(r"refs/pull/(\d+)/merge", s)
return int(match.group(1)) if match else None


def is_run_complete(run: Dict) -> bool:
return run["status"] in [10, 20, 30]

Expand All @@ -117,15 +116,22 @@ def get_run_status_emoji(status: int) -> str:
return status_dict[status]


def get_dbt_command(nodes: List[Dict]) -> List[str]:
def get_dbt_command(
nodes: List[Dict], database_override: str, schema_override: str
) -> List[str]:
command = f"dbt {DBT_COMMAND}"
string = "+ " if INCLUDE_DOWNSTREAM else " "
include_plus_operator = str_to_bool(INCLUDE_DOWNSTREAM)
string = "+ " if include_plus_operator else " "
command += f" -s {string.join([node['name'] for node in nodes])}"
if INCLUDE_DOWNSTREAM:
if include_plus_operator:
command += "+"

# TODO: Add database override to command
command += f" --vars '{{ref_schema_override: {SCHEMA_OVERRIDE}}}'"
variables = {
"ref_schema_override": schema_override,
"ref_database_override": database_override,
}
variables_str = json.dumps(variables)
command += f" --vars '{variables_str}'"
return [command]


Expand Down Expand Up @@ -267,14 +273,17 @@ async def get_downstream_nodes(project_dict: Dict):


async def main():
all_runs = []
pull_request_id = extract_pr_number(GITHUB_REF)
schema_override = f"dbt_cloud_pr_{JOB_ID}_{pull_request_id}"
payload = {
"cause": "Triggering CI Job from GH Action",
"git_branch": GIT_BRANCH,
"schema_override": SCHEMA_OVERRIDE,
"github_pull_request_id": PULL_REQUEST_ID,
"schema_override": schema_override,
"github_pull_request_id": pull_request_id,
}

all_jobs = [{"job_id": JOB_ID, "payload": payload}]
all_runs = []
while all_jobs:
# Trigger the CI jobs
job_tasks = [
Expand All @@ -295,12 +304,21 @@ async def main():
# Any public models updated in the run?
logger.info(f"Finding if any public models were updated in run {run['id']}")
public_models = await get_public_models_in_run(
run["job_id"], run["id"], SCHEMA_OVERRIDE
run["job_id"], run["id"], schema_override
)
if not public_models:
logger.info(f"No public models were updated in run {run['id']}.")
continue

databases = list(set([model["database"] for model in public_models]))
if len(databases) > 1:
logger.info(
f"Public models updated in run {run['id']} span multiple databases."
"This is not currently supported."
)
continue

database_override = databases[0]
logger.info(
f"Finding any downstream projects with public models updated in run {run['id']}"
)
Expand All @@ -320,16 +338,18 @@ async def main():
nodes = await get_downstream_nodes(project_dict)
if nodes:
logger.info(f"Found downstream nodes in project {project_id}")
steps_override = get_dbt_command(nodes)
job = await get_ci_job(project_id)
if job is not None:
logger.info(
f"CI job found in project {project_id} and will trigger shortly."
)
steps_override = get_dbt_command(
nodes, database_override, schema_override
)
job_payload = {
"cause": "Triggering downstream CI job",
"steps_override": steps_override,
"schema_override": SCHEMA_OVERRIDE,
"schema_override": schema_override,
}
all_jobs.append({"job_id": job["id"], "payload": job_payload})
else:
Expand Down Expand Up @@ -365,7 +385,7 @@ async def main():
headers={"Authorization": f"Bearer {GITHUB_TOKEN}"}
) as client:
url = (
f"https://api.github.com/repos/{REPO}/issues/{PULL_REQUEST_ID}/comments"
f"https://api.github.com/repos/{REPO}/issues/{pull_request_id}/comments"
)
response = client.post(url, json=payload)
if response.is_error:
Expand Down
Loading