Goldsky direct gcs load + dedupe (#1347)
Goldsky direct gcs load working
ravenac95 authored May 2, 2024
1 parent de40d65 commit 84f88dc
Showing 4 changed files with 165 additions and 22 deletions.
5 changes: 3 additions & 2 deletions warehouse/oso_dagster/assets.py
@@ -71,10 +71,11 @@ def random_cbt(cbt: CBTResource):
destination_dataset_name="deleteme_oso_sources_test",
partition_column_name="block_timestamp",
partition_column_transform=lambda c: f"TIMESTAMP_SECONDS(`{c}`)",
pointer_size=int(os.environ.get("GOLDSKY_CHECKPOINT_SIZE", "500")),
pointer_size=int(os.environ.get("GOLDSKY_CHECKPOINT_SIZE", "20000")),
bucket_key_id=os.environ.get("DUCKDB_GCS_KEY_ID"),
bucket_secret=os.environ.get("DUCKDB_GCS_SECRET"),
max_objects_to_load=2,
# uncomment the following value to test
# max_objects_to_load=2,
),
)

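Note on the assets.py change above: the Goldsky checkpoint size now defaults to 20000 but stays overridable through the GOLDSKY_CHECKPOINT_SIZE environment variable, and the max_objects_to_load cap is left commented out so full loads run by default. A minimal sketch of that override pattern for a local test run (the variable name is the one shown in the diff; everything else here is illustrative):

import os

# Shrink the checkpoint batches for a quick test; unset, the asset falls back to 20000.
os.environ.setdefault("GOLDSKY_CHECKPOINT_SIZE", "500")
pointer_size = int(os.environ.get("GOLDSKY_CHECKPOINT_SIZE", "20000"))
print(pointer_size)  # -> 500 when the override is set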
9 changes: 7 additions & 2 deletions warehouse/oso_dagster/cbt/__init__.py
@@ -65,6 +65,7 @@ def transform(
update_strategy: UpdateStrategy = UpdateStrategy.REPLACE,
time_partitioning: Optional[TimePartitioning] = None,
unique_column: Optional[str] = None,
timeout: float = 300,
**vars,
):
with self.bigquery.get_client() as client:
@@ -81,6 +82,7 @@ def transform(
destination_table,
time_partitioning=time_partitioning,
unique_column=unique_column,
timeout=timeout,
**vars,
)
return self._transform_existing(
@@ -89,6 +91,7 @@ def transform(
destination_table,
update_strategy,
unique_column=unique_column,
timeout=timeout,
**vars,
)

@@ -99,6 +102,7 @@ def _transform_existing(
destination_table: TableReference,
update_strategy: UpdateStrategy,
unique_column: Optional[str] = None,
timeout: float = 300,
**vars,
):
select_query = self.render_model(
@@ -122,7 +126,7 @@
)

self.log.debug({"message": "updating", "query": update_query})
job = client.query(update_query)
job = client.query(update_query, timeout=timeout)
job.result()

def _transform_replace(
@@ -132,6 +136,7 @@
destination_table: TableReference,
time_partitioning: Optional[TimePartitioning] = None,
unique_column: Optional[str] = None,
timeout: float = 300,
**vars,
):
select_query = self.render_model(
@@ -146,7 +151,7 @@
unique_column=unique_column,
select_query=select_query,
)
job = client.query(create_or_replace_query)
job = client.query(create_or_replace_query, timeout=timeout)
self.log.debug(
{"message": "replacing with query", "query": create_or_replace_query}
)
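Note on the cbt changes above: transform now accepts a timeout (defaulting to 300 seconds) and threads it through _transform_existing and _transform_replace down to each client.query call. A minimal sketch of the same pattern against a bare google-cloud-bigquery client, outside the CBT wrapper (the function and variable names here are illustrative):

from google.cloud import bigquery

def run_query_with_timeout(client: bigquery.Client, sql: str, timeout: float = 300.0) -> None:
    # timeout bounds the API request that starts the job; result() then blocks
    # until the query job itself completes.
    job = client.query(sql, timeout=timeout)
    job.result()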
169 changes: 153 additions & 16 deletions warehouse/oso_dagster/goldsky.py
@@ -12,7 +12,12 @@
from dagster import asset, AssetExecutionContext
from dagster_gcp import BigQueryResource, GCSResource
from google.api_core.exceptions import NotFound
from google.cloud.bigquery import TableReference
from google.cloud.bigquery import (
TableReference,
LoadJobConfig,
SourceFormat,
Client as BQClient,
)
from .goldsky_dask import setup_kube_cluster_client, DuckDBGCSPlugin, RetryTaskManager
from .cbt import CBTResource, UpdateStrategy, TimePartitioning
from .factories import AssetFactoryResponse
@@ -39,6 +44,11 @@ class GoldskyConfig:
dask_worker_memory: str = "4096Mi"
dask_scheduler_memory: str = "2560Mi"
dask_image: str = "ghcr.io/opensource-observer/dagster-dask:distributed-test-10"
dask_is_enabled: bool = False

# Allow load table jobs to run for up to 15 minutes
load_table_timeout_seconds: float = 900
transform_timeout_seconds: float = 900

working_destination_dataset_name: str = "oso_raw_sources"
working_destination_preload_path: str = "_temp"
@@ -202,25 +212,23 @@ def process_goldsky_file(item: GoldskyProcessItem):
class GoldskyWorker:
def __init__(
self,
worker: str,
name: str,
job_id: str,
pointer_table: str,
latest_checkpoint: GoldskyCheckpoint | None,
gcs: GCSResource,
bigquery: BigQueryResource,
config: GoldskyConfig,
queue: GoldskyQueue,
task_manager: RetryTaskManager,
):
self.name = worker
self.config = config
self.gcs = gcs
self.queue = queue
self.bigquery = bigquery
self.task_manager = task_manager
self.name = name
self.job_id = job_id
self.latest_checkpoint = latest_checkpoint
self.pointer_table = pointer_table
self.latest_checkpoint = latest_checkpoint
self.gcs = gcs
self.bigquery = bigquery
self.config = config
self.queue = queue

def worker_destination_uri(self, filename: str):
return f"gs://{self.config.source_bucket_name}/{self.worker_destination_path(filename)}"
@@ -248,6 +256,102 @@ def deduped_table(self) -> TableReference:
def worker_wildcard_uri(self):
return self.worker_destination_uri("table_*.parquet")

async def process(self, context: AssetExecutionContext):
raise NotImplementedError("process not implemented on the base class")


class DirectGoldskyWorker(GoldskyWorker):
async def process(self, context: AssetExecutionContext):
await asyncio.to_thread(
self.run_load_bigquery_load,
context,
)
return self

def run_load_bigquery_load(self, context: AssetExecutionContext):
to_load: List[str] = []
with self.bigquery.get_client() as client:
item = self.queue.dequeue()
latest_checkpoint = item.checkpoint
while item is not None:
# For our own convenience we have the option to do piecemeal
# loading. However, for direct loading this shouldn't be
# necessary.
source = f"gs://{self.config.source_bucket_name}/{item.blob_name}"
to_load.append(source)
if len(to_load) >= self.config.pointer_size:
job_config = LoadJobConfig(
source_format=SourceFormat.PARQUET,
)
load_job = client.load_table_from_uri(
to_load,
self.raw_table,
job_config=job_config,
timeout=self.config.load_table_timeout_seconds,
)
self.update_pointer_table(client, context, item.checkpoint)
load_job.result()
to_load = []
latest_checkpoint = item.checkpoint
item = self.queue.dequeue()

if len(to_load) > 0:
job_config = LoadJobConfig(
source_format=SourceFormat.PARQUET,
)
load_job = client.load_table_from_uri(
to_load,
self.raw_table,
job_config=job_config,
timeout=self.config.load_table_timeout_seconds,
)
self.update_pointer_table(client, context, latest_checkpoint)
load_job.result()
to_load = []

def update_pointer_table(
self,
client: BQClient,
context: AssetExecutionContext,
new_checkpoint: GoldskyCheckpoint,
):
pointer_table = self.pointer_table
tx_query = f"""
BEGIN
BEGIN TRANSACTION;
DELETE FROM `{pointer_table}` WHERE worker = '{self.name}';
INSERT INTO `{pointer_table}` (worker, job_id, timestamp, checkpoint)
VALUES ('{self.name}', '{new_checkpoint.job_id}', {new_checkpoint.timestamp}, {new_checkpoint.worker_checkpoint});
COMMIT TRANSACTION;
EXCEPTION WHEN ERROR THEN
-- Roll back the transaction inside the exception handler.
SELECT @@error.message;
ROLLBACK TRANSACTION;
END;
"""
context.log.debug(f"query: {tx_query}")
return client.query_and_wait(tx_query)


class DaskGoldskyWorker(GoldskyWorker):
def __init__(
self,
name: str,
job_id: str,
pointer_table: str,
latest_checkpoint: GoldskyCheckpoint | None,
gcs: GCSResource,
bigquery: BigQueryResource,
config: GoldskyConfig,
queue: GoldskyQueue,
task_manager: RetryTaskManager,
):
super().__init__(
name, job_id, pointer_table, latest_checkpoint, gcs, bigquery, config, queue
)
self.task_manager = task_manager

async def process(self, context: AssetExecutionContext):
try:
await self.process_all_files(context)
@@ -258,7 +362,7 @@ async def process(self, context: AssetExecutionContext):
async def process_all_files(self, context: AssetExecutionContext):
count = 0
item = self.queue.dequeue()
current_checkpoint = item.checkpoint
latest_checkpoint = item.checkpoint
in_flight = []
while item is not None:
source = f"gs://{self.config.source_bucket_name}/{item.blob_name}"
@@ -301,24 +405,24 @@ async def process_all_files(self, context: AssetExecutionContext):
context.log.debug(f"Worker[{self.name}] done waiting for blobs")

# Update the pointer table to the latest item's checkpoint
await self.update_pointer_table(context, current_checkpoint)
await self.update_pointer_table(context, item.checkpoint)

in_flight = []
count = 0

current_checkpoint = item.checkpoint
latest_checkpoint = item.checkpoint
item = self.queue.dequeue()

if len(in_flight) > 0:
context.log.debug(
f"Finalizing worker {self.name} waiting for {len(in_flight)} blobs to process. Last checkpoint {current_checkpoint.worker_checkpoint}",
f"Finalizing worker {self.name} waiting for {len(in_flight)} blobs to process. Last checkpoint {latest_checkpoint.worker_checkpoint}",
)
progress = 0
for coro in asyncio.as_completed(in_flight):
await coro
progress += 1
context.log.debug(f"Worker[{self.name}] progress: {progress}/{count}")
await self.update_pointer_table(context, current_checkpoint)
await self.update_pointer_table(context, latest_checkpoint)

return self.name

@@ -506,6 +610,37 @@ def ensure_dataset(self, context: AssetExecutionContext, dataset_id: str):

async def load_worker_tables(
self, loop: asyncio.AbstractEventLoop, context: AssetExecutionContext
):
if self.config.dask_is_enabled:
return await self.dask_load_worker_tables(loop, context)
return await self.direct_load_worker_tables(context)

async def direct_load_worker_tables(
self, context: AssetExecutionContext
) -> List[GoldskyWorker]:
worker_coroutines = []
workers: List[GoldskyWorker] = []
worker_status, queues = self.load_queues(context)
for worker_name, queue in queues.worker_queues():
worker = DirectGoldskyWorker(
worker_name,
self._job_id,
self.pointer_table,
worker_status.get(worker_name, None),
self.gcs,
self.bigquery,
self.config,
queue,
)
worker_coroutines.append(worker.process(context))
workers.append(worker)
for coro in asyncio.as_completed(worker_coroutines):
worker: GoldskyWorker = await coro
context.log.info(f"Worker[{worker.name}] completed latest data load")
return workers

async def dask_load_worker_tables(
self, loop: asyncio.AbstractEventLoop, context: AssetExecutionContext
) -> List[GoldskyWorker]:
context.log.info("loading worker tables for goldsky asset")
last_restart = time.time()
@@ -550,7 +685,7 @@ async def parallel_load_worker_tables(
worker_coroutines = []
workers: List[GoldskyWorker] = []
for worker_name, queue in queues.worker_queues():
worker = GoldskyWorker(
worker = DaskGoldskyWorker(
worker_name,
job_id,
self.pointer_table,
@@ -591,6 +726,7 @@ async def dedupe_worker_tables(
partition_column_name=self.config.partition_column_name,
partition_column_transform=self.config.partition_column_transform,
raw_table=worker.raw_table,
timeout=self.config.transform_timeout_seconds,
)
)
completed = 0
@@ -623,6 +759,7 @@ async def merge_worker_tables(
unique_column=self.config.dedupe_unique_column,
order_column=self.config.dedupe_order_column,
workers=workers,
timeout=self.config.transform_timeout_seconds,
)

async def clean_working_destintation(
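Note on the goldsky.py changes above: the new DirectGoldskyWorker path skips Dask entirely. It batches GCS parquet URIs into BigQuery load jobs and then advances the per-worker row in the pointer table inside a single multi-statement transaction. A condensed sketch of those two steps, assuming a plain google-cloud-bigquery client (table names, helper names, and checkpoint fields here are illustrative, not the module's API):

from typing import List
from google.cloud import bigquery

def load_batch(client: bigquery.Client, uris: List[str], raw_table: str, timeout: float = 900.0) -> None:
    # Append the parquet blobs directly into the raw worker table.
    job_config = bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.PARQUET)
    job = client.load_table_from_uri(uris, raw_table, job_config=job_config, timeout=timeout)
    job.result()

def advance_pointer(client: bigquery.Client, pointer_table: str, worker: str, job_id: str, timestamp: int, checkpoint: int) -> None:
    # Delete-then-insert inside one transaction so the worker's row is replaced atomically.
    client.query_and_wait(
        f"""
        BEGIN
          BEGIN TRANSACTION;
          DELETE FROM `{pointer_table}` WHERE worker = '{worker}';
          INSERT INTO `{pointer_table}` (worker, job_id, timestamp, checkpoint)
          VALUES ('{worker}', '{job_id}', {timestamp}, {checkpoint});
          COMMIT TRANSACTION;
        EXCEPTION WHEN ERROR THEN
          ROLLBACK TRANSACTION;
        END;
        """
    )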
4 changes: 2 additions & 2 deletions warehouse/oso_dagster/models/goldsky_dedupe.sql
@@ -1,9 +1,9 @@
SELECT
{% if partition_column_name %}
{{ source(raw_table).select_columns(exclude=["_checkpoint", partition_column_name]) }},
{{ source(raw_table).select_columns(exclude=[partition_column_name]) }},
{{ partition_column_transform(partition_column_name) }} AS `{{ partition_column_name }}`
{% else %}
{{ source(raw_table).select_columns(exclude=["_checkpoint"]) }}
{{ source(raw_table).select_columns() }}
{% endif %}
FROM {{ source(raw_table).fqdn }} AS worker
QUALIFY ROW_NUMBER() OVER (PARTITION BY `{{ unique_column }}` ORDER BY `{{ order_column }}` DESC) = 1
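Note on the model change above: goldsky_dedupe.sql no longer excludes a _checkpoint column and simply keeps one row per unique_column via QUALIFY ROW_NUMBER(), ordered by order_column descending. Roughly, the rendered query resembles the string built below when no partition column is configured (select_columns() expands to an explicit column list in the real template; the table and column names here are placeholders):

def dedupe_query(raw_table: str, unique_column: str, order_column: str) -> str:
    # Keep only the newest row for each unique key in the raw worker table.
    return f"""
    SELECT *
    FROM `{raw_table}` AS worker
    QUALIFY ROW_NUMBER() OVER (
      PARTITION BY `{unique_column}` ORDER BY `{order_column}` DESC
    ) = 1
    """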
