Merge pull request #5 from dpguthrie/feat/ref-database

Feat/ref database

dpguthrie authored Feb 17, 2024
2 parents 7c65a9e + 76ddd95 commit 803e4bf
Showing 2 changed files with 57 additions and 34 deletions.
Dockerfile (5 changes: 4 additions & 1 deletion)

@@ -8,8 +8,11 @@ WORKDIR /app
 # Copy the dependencies file to the working directory
 COPY /requirements ./requirements

+RUN pip install uv
+
 # Install any dependencies
-RUN pip install --no-cache-dir -r requirements/prod.txt
+ENV VIRTUAL_ENV=/usr/local
+RUN pip install uv && uv pip install --no-cache -r requirements/prod.txt

 # Copy the script to the container
 COPY src/main.py .
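Note on the Dockerfile change above: uv pip install does not install into the system interpreter by default; it targets the environment named by the VIRTUAL_ENV variable. Setting ENV VIRTUAL_ENV=/usr/local appears intended to point uv at the image's system Python prefix, so the production requirements install without a separate virtual environment.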
src/main.py (86 changes: 53 additions & 33 deletions)

@@ -1,5 +1,6 @@
 # stdlib
 import asyncio
+import json
 import logging
 import os
 import re
@@ -17,22 +18,6 @@
 logger = logging.getLogger(__name__)


-def str_to_bool(value: Union[str, bool]) -> bool:
-    if isinstance(value, bool):
-        return value
-    if value.lower() == "true":
-        return True
-    if value.lower() == "false":
-        return False
-
-    raise ValueError(f"Invalid value: {value}")
-
-
-def extract_pr_number(s):
-    match = re.search(r"refs/pull/(\d+)/merge", s)
-    return int(match.group(1)) if match else None
-
-
 # dbt Cloud Env Vars
 ACCOUNT_ID = os.getenv("INPUT_DBT_CLOUD_ACCOUNT_ID", None)
 TOKEN = os.getenv("INPUT_DBT_CLOUD_SERVICE_TOKEN", None)
@@ -46,13 +31,9 @@ def extract_pr_number(s):

 # Optional Env Vars
 GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", None)
-INCLUDE_DOWNSTREAM = str_to_bool(os.getenv("INPUT_INCLUDE_DOWNSTREAM", True))
+INCLUDE_DOWNSTREAM = os.getenv("INPUT_INCLUDE_DOWNSTREAM", True)
 DBT_COMMAND = os.getenv("INPUT_DBT_COMMAND", "build")

-# Derived variables
-PULL_REQUEST_ID = extract_pr_number(GITHUB_REF)
-SCHEMA_OVERRIDE = f"dbt_cloud_pr_{JOB_ID}_{PULL_REQUEST_ID}"
-
 # Run Status Indicators
 SUCCESS = ":white_check_mark:"
 FAILURE = ":x:"
@@ -104,6 +85,24 @@ def extract_pr_number(s):
"""


def str_to_bool(value: Union[str, bool]) -> bool:
if isinstance(value, bool):
return value

if value.lower() == "true":
return True

if value.lower() == "false":
return False

raise ValueError(f"Invalid value: {value}")


def extract_pr_number(s):
match = re.search(r"refs/pull/(\d+)/merge", s)
return int(match.group(1)) if match else None


def is_run_complete(run: Dict) -> bool:
return run["status"] in [10, 20, 30]

@@ -117,15 +116,22 @@ def get_run_status_emoji(status: int) -> str:
     return status_dict[status]


-def get_dbt_command(nodes: List[Dict]) -> List[str]:
+def get_dbt_command(
+    nodes: List[Dict], database_override: str, schema_override: str
+) -> List[str]:
     command = f"dbt {DBT_COMMAND}"
-    string = "+ " if INCLUDE_DOWNSTREAM else " "
+    include_plus_operator = str_to_bool(INCLUDE_DOWNSTREAM)
+    string = "+ " if include_plus_operator else " "
     command += f" -s {string.join([node['name'] for node in nodes])}"
-    if INCLUDE_DOWNSTREAM:
+    if include_plus_operator:
         command += "+"

-    # TODO: Add database override to command
-    command += f" --vars '{{ref_schema_override: {SCHEMA_OVERRIDE}}}'"
+    variables = {
+        "ref_schema_override": schema_override,
+        "ref_database_override": database_override,
+    }
+    variables_str = json.dumps(variables)
+    command += f" --vars '{variables_str}'"
     return [command]


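For reference, a minimal sketch of the command string the updated get_dbt_command builds. The node names and override values below are hypothetical:

    # Hypothetical inputs, purely for illustration
    nodes = [{"name": "dim_customers"}, {"name": "fct_orders"}]
    steps = get_dbt_command(nodes, "analytics_ci", "dbt_cloud_pr_123_45")
    # With INPUT_INCLUDE_DOWNSTREAM left at its truthy default, steps is roughly:
    # ['dbt build -s dim_customers+ fct_orders+ --vars \'{"ref_schema_override": "dbt_cloud_pr_123_45", "ref_database_override": "analytics_ci"}\'']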
@@ -267,14 +273,17 @@ async def get_downstream_nodes(project_dict: Dict):


 async def main():
-    all_runs = []
+    pull_request_id = extract_pr_number(GITHUB_REF)
+    schema_override = f"dbt_cloud_pr_{JOB_ID}_{pull_request_id}"
     payload = {
         "cause": "Triggering CI Job from GH Action",
         "git_branch": GIT_BRANCH,
-        "schema_override": SCHEMA_OVERRIDE,
-        "github_pull_request_id": PULL_REQUEST_ID,
+        "schema_override": schema_override,
+        "github_pull_request_id": pull_request_id,
     }

     all_jobs = [{"job_id": JOB_ID, "payload": payload}]
+    all_runs = []
     while all_jobs:
         # Trigger the CI jobs
         job_tasks = [
@@ -295,12 +304,21 @@
             # Any public models updated in the run?
             logger.info(f"Finding if any public models were updated in run {run['id']}")
             public_models = await get_public_models_in_run(
-                run["job_id"], run["id"], SCHEMA_OVERRIDE
+                run["job_id"], run["id"], schema_override
             )
             if not public_models:
                 logger.info(f"No public models were updated in run {run['id']}.")
                 continue

+            databases = list(set([model["database"] for model in public_models]))
+            if len(databases) > 1:
+                logger.info(
+                    f"Public models updated in run {run['id']} span multiple databases. "
+                    "This is not currently supported."
+                )
+                continue
+
+            database_override = databases[0]
             logger.info(
                 f"Finding any downstream projects with public models updated in run {run['id']}"
             )
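A quick sketch of the new multiple-database guard above, using hypothetical public model rows:

    # Hypothetical rows returned by get_public_models_in_run
    public_models = [{"database": "analytics"}, {"database": "analytics"}]
    databases = list(set([model["database"] for model in public_models]))
    # databases == ["analytics"], so the guard passes and
    # database_override = databases[0]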
@@ -320,16 +338,18 @@
                 nodes = await get_downstream_nodes(project_dict)
                 if nodes:
                     logger.info(f"Found downstream nodes in project {project_id}")
-                    steps_override = get_dbt_command(nodes)
                     job = await get_ci_job(project_id)
                     if job is not None:
                         logger.info(
                             f"CI job found in project {project_id} and will trigger shortly."
                         )
+                        steps_override = get_dbt_command(
+                            nodes, database_override, schema_override
+                        )
                         job_payload = {
                             "cause": "Triggering downstream CI job",
                             "steps_override": steps_override,
-                            "schema_override": SCHEMA_OVERRIDE,
+                            "schema_override": schema_override,
                         }
                         all_jobs.append({"job_id": job["id"], "payload": job_payload})
                     else:
@@ -365,7 +385,7 @@ async def main():
         headers={"Authorization": f"Bearer {GITHUB_TOKEN}"}
     ) as client:
         url = (
-            f"https://api.github.com/repos/{REPO}/issues/{PULL_REQUEST_ID}/comments"
+            f"https://api.github.com/repos/{REPO}/issues/{pull_request_id}/comments"
         )
         response = client.post(url, json=payload)
         if response.is_error:
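As a final reference, a short sketch of the per-pull-request values main() now derives. The ref and job id here are made up:

    # Hypothetical GITHUB_REF for a pull request merge ref
    pull_request_id = extract_pr_number("refs/pull/42/merge")  # -> 42
    # With JOB_ID = "101", the schema override becomes:
    # f"dbt_cloud_pr_{JOB_ID}_{pull_request_id}" -> "dbt_cloud_pr_101_42"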
