diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py
index e812afdaf1..e047778e3d 100644
--- a/dlt/common/data_writers/escape.py
+++ b/dlt/common/data_writers/escape.py
@@ -167,4 +167,12 @@ def format_datetime_literal(v: pendulum.DateTime, precision: int = 6, no_tz: boo
         timespec = "milliseconds"
     elif precision < 3:
         timespec = "seconds"
-    return v.isoformat(sep=" ", timespec=timespec)
+    return "'" + v.isoformat(sep=" ", timespec=timespec) + "'"
+
+
+def format_bigquery_datetime_literal(
+    v: pendulum.DateTime, precision: int = 6, no_tz: bool = False
+) -> str:
+    """Returns BigQuery-adjusted datetime literal by prefixing required `TIMESTAMP` indicator."""
+    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#timestamp_literals
+    return "TIMESTAMP " + format_datetime_literal(v, precision, no_tz)
diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py
index 286a295e93..8432f8b544 100644
--- a/dlt/common/destination/capabilities.py
+++ b/dlt/common/destination/capabilities.py
@@ -9,6 +9,7 @@
     DestinationLoadingWithoutStagingNotSupported,
 )
 from dlt.common.utils import identity
+from dlt.common.pendulum import pendulum
 from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
 from dlt.common.wei import EVM_DECIMAL_PRECISION
 
@@ -32,6 +33,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
     supported_staging_file_formats: Sequence[TLoaderFileFormat] = None
     escape_identifier: Callable[[str], str] = None
     escape_literal: Callable[[Any], Any] = None
+    format_datetime_literal: Callable[..., str] = None
     decimal_precision: Tuple[int, int] = None
     wei_precision: Tuple[int, int] = None
     max_identifier_length: int = None
@@ -61,6 +63,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
     def generic_capabilities(
         preferred_loader_file_format: TLoaderFileFormat = None,
     ) -> "DestinationCapabilitiesContext":
+        from dlt.common.data_writers.escape import format_datetime_literal
+
         caps = DestinationCapabilitiesContext()
         caps.preferred_loader_file_format = preferred_loader_file_format
         caps.supported_loader_file_formats = ["jsonl", "insert_values", "parquet", "csv"]
@@ -68,6 +72,7 @@ def generic_capabilities(
         caps.supported_staging_file_formats = []
         caps.escape_identifier = identity
         caps.escape_literal = serialize_value
+        caps.format_datetime_literal = format_datetime_literal
         caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE)
         caps.wei_precision = (EVM_DECIMAL_PRECISION, 0)
         caps.max_identifier_length = 65536
diff --git a/dlt/destinations/impl/bigquery/__init__.py b/dlt/destinations/impl/bigquery/__init__.py
index 6d1491817a..d33466ed5e 100644
--- a/dlt/destinations/impl/bigquery/__init__.py
+++ b/dlt/destinations/impl/bigquery/__init__.py
@@ -1,4 +1,7 @@
-from dlt.common.data_writers.escape import escape_bigquery_identifier
+from dlt.common.data_writers.escape import (
+    escape_bigquery_identifier,
+    format_bigquery_datetime_literal,
+)
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
 
@@ -11,6 +14,7 @@ def capabilities() -> DestinationCapabilitiesContext:
     caps.supported_staging_file_formats = ["parquet", "jsonl"]
     caps.escape_identifier = escape_bigquery_identifier
     caps.escape_literal = None
+    caps.format_datetime_literal = format_bigquery_datetime_literal
     caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE)
     caps.wei_precision = (76, 38)
     caps.max_identifier_length = 1024
diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py
index eadedb742e..86eaa9236a 100644
--- a/dlt/destinations/sql_jobs.py
+++ b/dlt/destinations/sql_jobs.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional
 
 import yaml
 
-from dlt.common.data_writers.escape import format_datetime_literal
 from dlt.common.logger import pretty_format_exception
 from dlt.common.pendulum import pendulum
@@ -521,28 +520,30 @@
         staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"])
 
         # get column names
-        escape_id = sql_client.capabilities.escape_identifier
+        caps = sql_client.capabilities
+        escape_id = caps.escape_identifier
         from_, to = list(map(escape_id, get_validity_column_names(root_table)))  # validity columns
         hash_ = escape_id(
             get_first_column_name_with_prop(root_table, "x-row-version")
         )  # row hash column
 
         # define values for validity columns
+        format_datetime_literal = caps.format_datetime_literal
+        if format_datetime_literal is None:
+            format_datetime_literal = (
+                DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal
+            )
         boundary_ts = format_datetime_literal(
             current_load_package()["state"]["created_at"],
-            sql_client.capabilities.timestamp_precision,
-        )
-        active_record_ts = format_datetime_literal(
-            HIGH_TS, sql_client.capabilities.timestamp_precision
+            caps.timestamp_precision,
         )
+        active_record_ts = format_datetime_literal(HIGH_TS, caps.timestamp_precision)
 
         # retire updated and deleted records
         sql.append(f"""
-            UPDATE {root_table_name} SET {to} = '{boundary_ts}'
-            WHERE NOT EXISTS (
-                SELECT s.{hash_} FROM {staging_root_table_name} AS s
-                WHERE {root_table_name}.{hash_} = s.{hash_}
-            ) AND {to} = '{active_record_ts}';
+            UPDATE {root_table_name} SET {to} = {boundary_ts}
+            WHERE {to} = {active_record_ts}
+            AND {hash_} NOT IN (SELECT {hash_} FROM {staging_root_table_name});
         """)
 
         # insert new active records in root table
@@ -550,9 +551,9 @@
         col_str = ", ".join([c for c in columns if c not in (from_, to)])
         sql.append(f"""
             INSERT INTO {root_table_name} ({col_str}, {from_}, {to})
-            SELECT {col_str}, '{boundary_ts}' AS {from_}, '{active_record_ts}' AS {to}
+            SELECT {col_str}, {boundary_ts} AS {from_}, {active_record_ts} AS {to}
             FROM {staging_root_table_name} AS s
-            WHERE NOT EXISTS (SELECT s.{hash_} FROM {root_table_name} AS f WHERE f.{hash_} = s.{hash_});
+            WHERE {hash_} NOT IN (SELECT {hash_} FROM {root_table_name});
         """)
 
         # insert list elements for new active records in child tables
diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py
index cf313eaa61..65a0742195 100644
--- a/tests/load/pipeline/test_scd2.py
+++ b/tests/load/pipeline/test_scd2.py
@@ -81,6 +81,7 @@ def assert_records_as_set(actual: List[Dict[str, Any]], expected: List[Dict[str,
     assert actual_set == expected_set
 
 
+@pytest.mark.essential
 @pytest.mark.parametrize(
     "destination_config,simple,validity_column_names",
     [  # test basic case for alle SQL destinations supporting merge