Make merge write-disposition fall back to append if no primary or merge keys are specified #1225

Merged · 11 commits · Apr 24, 2024
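
For context, a minimal sketch of the behavior this PR introduces (pipeline, resource, and destination names below are placeholders, not taken from this PR): a resource run with write_disposition="merge" but without any primary_key or merge_key hints now falls back to an append-style load instead of attempting key-based deduplication.

import dlt

@dlt.resource(write_disposition="merge")  # note: no primary_key or merge_key hints
def events():
    yield [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}]

pipeline = dlt.pipeline(
    pipeline_name="merge_fallback_demo", destination="duckdb", dataset_name="demo"
)
pipeline.run(events())
pipeline.run(events())
# with no keys to merge on, the second run is appended, so the events table
# should end up with 4 rows instead of 2 deduplicated ones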
109 changes: 59 additions & 50 deletions dlt/destinations/sql_jobs.py
@@ -146,7 +146,10 @@ def generate_sql(


class SqlMergeJob(SqlBaseJob):
"""Generates a list of sql statements that merge the data from staging dataset into destination dataset."""
"""
Generates a list of sql statements that merge the data from staging dataset into destination dataset.
If no merge keys are discovered, falls back to append.
"""

failed_text: str = "Tried to generate a merge sql job for the following tables:"

@@ -383,68 +386,74 @@ def gen_merge_sql(
get_columns_names_with_prop(root_table, "merge_key"),
)
)
key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys)

unique_column: str = None
root_key_column: str = None
# if we do not have any merge keys to select from, we will fall back to a staged append, i.e.
# just skip the delete part
append_fallback = (len(primary_keys) + len(merge_keys)) == 0

if len(table_chain) == 1:
key_table_clauses = cls.gen_key_table_clauses(
root_table_name, staging_root_table_name, key_clauses, for_delete=True
)
# if no child tables, just delete data from top table
for clause in key_table_clauses:
sql.append(f"DELETE {clause};")
else:
key_table_clauses = cls.gen_key_table_clauses(
root_table_name, staging_root_table_name, key_clauses, for_delete=False
)
# use unique hint to create temp table with all identifiers to delete
unique_columns = get_columns_names_with_prop(root_table, "unique")
if not unique_columns:
raise MergeDispositionException(
sql_client.fully_qualified_dataset_name(),
staging_root_table_name,
[t["name"] for t in table_chain],
f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so"
" it is not possible to link child tables to it.",
)
# get first unique column
unique_column = escape_id(unique_columns[0])
# create temp table with unique identifier
create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql(
unique_column, key_table_clauses, sql_client
)
sql.extend(create_delete_temp_table_sql)
if not append_fallback:
key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys)

# delete from child tables first. This is important for databricks which does not support temporary tables,
# but uses temporary views instead
for table in table_chain[1:]:
table_name = sql_client.make_qualified_table_name(table["name"])
root_key_columns = get_columns_names_with_prop(table, "root_key")
if not root_key_columns:
unique_column: str = None
root_key_column: str = None

if len(table_chain) == 1:
key_table_clauses = cls.gen_key_table_clauses(
root_table_name, staging_root_table_name, key_clauses, for_delete=True
)
# if no child tables, just delete data from top table
for clause in key_table_clauses:
sql.append(f"DELETE {clause};")
else:
key_table_clauses = cls.gen_key_table_clauses(
root_table_name, staging_root_table_name, key_clauses, for_delete=False
)
# use unique hint to create temp table with all identifiers to delete
unique_columns = get_columns_names_with_prop(root_table, "unique")
if not unique_columns:
raise MergeDispositionException(
sql_client.fully_qualified_dataset_name(),
staging_root_table_name,
[t["name"] for t in table_chain],
"There is no root foreign key (ie _dlt_root_id) in child table"
f" {table['name']} so it is not possible to refer to top level table"
f" {root_table['name']} unique column {unique_column}",
"There is no unique column (ie _dlt_id) in top table"
f" {root_table['name']} so it is not possible to link child tables to it.",
)
root_key_column = escape_id(root_key_columns[0])
# get first unique column
unique_column = escape_id(unique_columns[0])
# create temp table with unique identifier
create_delete_temp_table_sql, delete_temp_table_name = (
cls.gen_delete_temp_table_sql(unique_column, key_table_clauses, sql_client)
)
sql.extend(create_delete_temp_table_sql)

# delete from child tables first. This is important for databricks which does not support temporary tables,
# but uses temporary views instead
for table in table_chain[1:]:
table_name = sql_client.make_qualified_table_name(table["name"])
root_key_columns = get_columns_names_with_prop(table, "root_key")
if not root_key_columns:
raise MergeDispositionException(
sql_client.fully_qualified_dataset_name(),
staging_root_table_name,
[t["name"] for t in table_chain],
"There is no root foreign key (ie _dlt_root_id) in child table"
f" {table['name']} so it is not possible to refer to top level table"
f" {root_table['name']} unique column {unique_column}",
)
root_key_column = escape_id(root_key_columns[0])
sql.append(
cls.gen_delete_from_sql(
table_name, root_key_column, delete_temp_table_name, unique_column
)
)

# delete from top table now that child tables have been processed
sql.append(
cls.gen_delete_from_sql(
table_name, root_key_column, delete_temp_table_name, unique_column
root_table_name, unique_column, delete_temp_table_name, unique_column
)
)

# delete from top table now that child tables have been processed
sql.append(
cls.gen_delete_from_sql(
root_table_name, unique_column, delete_temp_table_name, unique_column
)
)

# get name of column with hard_delete hint, if specified
not_deleted_cond: str = None
hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete")
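
To make the append fallback added to gen_merge_sql above easier to follow, here is a simplified, hypothetical sketch of the control flow (not the actual dlt implementation; table names, column names, and the generated SQL are placeholders): when neither primary nor merge keys exist, the DELETE statements are skipped entirely, so only the final insert from the staging dataset remains, which amounts to a staged append.

def sketch_merge_sql(primary_keys, merge_keys, root_table, staging_table):
    sql = []
    # same fallback condition as append_fallback in gen_merge_sql above
    append_fallback = (len(primary_keys) + len(merge_keys)) == 0
    if not append_fallback:
        # the real job also builds a temp table of ids and deletes from child
        # tables first; this placeholder only shows a root-table delete
        keys = primary_keys + merge_keys
        condition = " AND ".join(f"d.{k} = s.{k}" for k in keys)
        sql.append(
            f"DELETE FROM {root_table} AS d WHERE EXISTS"
            f" (SELECT 1 FROM {staging_table} AS s WHERE {condition});"
        )
    # the insert step runs in both cases; with no keys it is just a staged append
    sql.append(f"INSERT INTO {root_table} SELECT * FROM {staging_table};")
    return sql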
81 changes: 20 additions & 61 deletions tests/load/pipeline/test_filesystem_pipeline.py
@@ -19,35 +19,19 @@
from tests.common.utils import load_json_case
from tests.utils import ALL_TEST_DATA_ITEM_FORMATS, TestDataItemFormat, skip_if_not_active
from dlt.destinations.path_utils import create_path

from tests.load.pipeline.utils import load_table_counts

skip_if_not_active("filesystem")


def assert_file_matches(
layout: str, job: LoadJobInfo, load_id: str, client: FilesystemClient
) -> None:
"""Verify file contents of load job are identical to the corresponding file in destination"""
local_path = Path(job.file_path)
filename = local_path.name
destination_fn = create_path(
layout,
filename,
client.schema.name,
load_id,
extra_placeholders=client.config.extra_placeholders,
)
destination_path = posixpath.join(client.dataset_path, destination_fn)

assert local_path.read_bytes() == client.fs_client.read_bytes(destination_path)


def test_pipeline_merge_write_disposition(default_buckets_env: str) -> None:
Collaborator Author comment:
This test shows that filesystem always falls back to append when "merge" is set, which is also what the docs say. I'm not sure why this test claimed it would replace when there is a primary key; that did not work before and it does not work that way now, and this test was somehow very strange (not sure if I wrote it), but now it runs correctly. If we want to replace in certain cases, I have to add it and specify that in the docs. (A minimal standalone sketch of this behavior follows after this test diff.)

"""Run pipeline twice with merge write disposition
Resource with primary key falls back to append. Resource without keys falls back to replace.
Regardless of whether a primary key is set or not, filesystem appends
"""
import pyarrow.parquet as pq # Module is evaluated by other tests

os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True"

pipeline = dlt.pipeline(
pipeline_name="test_" + uniq_id(),
destination="filesystem",
@@ -66,50 +50,25 @@ def other_data():
def some_source():
return [some_data(), other_data()]

info1 = pipeline.run(some_source(), write_disposition="merge")
info2 = pipeline.run(some_source(), write_disposition="merge")

client: FilesystemClient = pipeline.destination_client() # type: ignore[assignment]
layout = client.config.layout

append_glob = list(client._get_table_dirs(["some_data"]))[0]
replace_glob = list(client._get_table_dirs(["other_data"]))[0]

append_files = client.fs_client.ls(append_glob, detail=False, refresh=True)
replace_files = client.fs_client.ls(replace_glob, detail=False, refresh=True)

load_id1 = info1.loads_ids[0]
load_id2 = info2.loads_ids[0]

# resource with pk is loaded with append and has 1 copy for each load
assert len(append_files) == 2
assert any(load_id1 in fn for fn in append_files)
assert any(load_id2 in fn for fn in append_files)

# resource without pk is treated as append disposition
assert len(replace_files) == 2
assert any(load_id1 in fn for fn in replace_files)
assert any(load_id2 in fn for fn in replace_files)

# Verify file contents
assert info2.load_packages
for pkg in info2.load_packages:
assert pkg.jobs["completed_jobs"]
for job in pkg.jobs["completed_jobs"]:
assert_file_matches(layout, job, pkg.load_id, client)

complete_fn = f"{client.schema.name}.{LOADS_TABLE_NAME}.%s"
pipeline.run(some_source(), write_disposition="merge")
assert load_table_counts(pipeline, "some_data", "other_data") == {
"some_data": 3,
"other_data": 5,
}

# Test complete_load markers are saved
assert client.fs_client.isfile(posixpath.join(client.dataset_path, complete_fn % load_id1))
assert client.fs_client.isfile(posixpath.join(client.dataset_path, complete_fn % load_id2))
# second load shows that merge always appends on filesystem
pipeline.run(some_source(), write_disposition="merge")
assert load_table_counts(pipeline, "some_data", "other_data") == {
"some_data": 6,
"other_data": 10,
}

# Force replace
# Force replace, back to initial values
pipeline.run(some_source(), write_disposition="replace")
append_files = client.fs_client.ls(append_glob, detail=False, refresh=True)
replace_files = client.fs_client.ls(replace_glob, detail=False, refresh=True)
assert len(append_files) == 1
assert len(replace_files) == 1
assert load_table_counts(pipeline, "some_data", "other_data") == {
"some_data": 3,
"other_data": 5,
}
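
Picking up the author's comment above, a minimal standalone sketch of the same behavior outside the test harness (the bucket URL, environment-variable configuration, and names are assumptions for illustration): on the filesystem destination, "merge" currently behaves like append even when a primary key is declared, so row counts simply accumulate across runs.

import os
import dlt

# placeholder local bucket; the filesystem destination reads its bucket_url
# from configuration such as this environment variable
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file:///tmp/fs_merge_demo"

@dlt.resource(primary_key="id", write_disposition="merge")
def users():
    yield [{"id": 1}, {"id": 2}, {"id": 3}]

pipeline = dlt.pipeline(
    pipeline_name="fs_merge_demo", destination="filesystem", dataset_name="demo"
)
pipeline.run(users())
pipeline.run(users())
# each load lands as its own file and the users table now counts 6 rows,
# i.e. merge behaved like append despite the primary_key hint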


@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
13 changes: 10 additions & 3 deletions tests/load/pipeline/test_merge_disposition.py
@@ -240,10 +240,17 @@ def test_merge_no_child_tables(destination_config: DestinationTestConfiguration)
assert github_2_counts["issues"] == 100 if destination_config.supports_merge else 115


# mark as essential for now
@pytest.mark.essential
@pytest.mark.parametrize(
"destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name
"destination_config",
destinations_configs(default_sql_configs=True, local_filesystem_configs=True),
ids=lambda x: x.name,
)
def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) -> None:
# NOTE: we can test filesystem destination merge behavior here too, it will also fall back!
if destination_config.file_format == "insert_values":
pytest.skip("Insert values row count checking is buggy, skipping")
p = destination_config.setup_pipeline("github_3", full_refresh=True)
github_data = github()
# remove all keys
@@ -264,8 +271,8 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) -
info = p.run(github_data, loader_file_format=destination_config.file_format)
assert_load_info(info)
github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()])
# only ten rows remains. merge falls back to replace when no keys are specified
assert github_1_counts["issues"] == 10 if destination_config.supports_merge else 100 - 45
# we have 10 more rows; merge falls back to append if no keys are present
assert github_1_counts["issues"] == 100 - 45 + 10


@pytest.mark.parametrize(
4 changes: 4 additions & 0 deletions tests/load/utils.py
@@ -226,6 +226,10 @@ def destinations_configs(
DestinationTestConfiguration(destination="synapse", supports_dbt=False),
]

# sanity check that when selecting default destinations, one of each sql destination is actually
# provided
assert set(SQL_DESTINATIONS) == {d.destination for d in destination_configs}
Collaborator Author comment:
This has happened a few times already... not really related to this PR.


if default_vector_configs:
# for now only weaviate
destination_configs += [DestinationTestConfiguration(destination="weaviate")]
57 changes: 25 additions & 32 deletions tests/pipeline/utils.py
@@ -3,6 +3,7 @@
import pytest
import random
from os import environ
import io
Collaborator Author comment:
I pulled this over from my filesystem state branch because I needed those changes here.


import dlt
from dlt.common import json, sleep
@@ -80,7 +81,7 @@ def assert_data_table_counts(p: dlt.Pipeline, expected_counts: DictStrAny) -> No
), f"Table counts do not match, expected {expected_counts}, got {table_counts}"


def load_file(path: str, file: str) -> Tuple[str, List[Dict[str, Any]]]:
def load_file(fs_client, path: str, file: str) -> Tuple[str, List[Dict[str, Any]]]:
"""
util function to load a filesystem destination file and return parsed content
values may not be cast to the right type, especially for insert_values, please
@@ -96,47 +97,43 @@ def load_file(path: str, file: str) -> Tuple[str, List[Dict[str, Any]]]:

# table name will be last element of path
table_name = path.split("/")[-1]

# skip loads table
if table_name == "_dlt_loads":
return table_name, []

full_path = posixpath.join(path, file)

# load jsonl
if ext == "jsonl":
with open(full_path, "rU", encoding="utf-8") as f:
for line in f:
file_text = fs_client.read_text(full_path)
for line in file_text.split("\n"):
if line:
result.append(json.loads(line))

# load insert_values (this is a bit volatile if the exact format of the source file changes)
elif ext == "insert_values":
with open(full_path, "rU", encoding="utf-8") as f:
lines = f.readlines()
# extract col names
cols = lines[0][15:-2].split(",")
for line in lines[2:]:
file_text = fs_client.read_text(full_path)
lines = file_text.split("\n")
cols = lines[0][15:-2].split(",")
for line in lines[2:]:
if line:
values = line[1:-3].split(",")
result.append(dict(zip(cols, values)))

# load parquet
elif ext == "parquet":
import pyarrow.parquet as pq

with open(full_path, "rb") as f:
table = pq.read_table(f)
cols = table.column_names
count = 0
for column in table:
column_name = cols[count]
item_count = 0
for item in column.to_pylist():
if len(result) <= item_count:
result.append({column_name: item})
else:
result[item_count][column_name] = item
item_count += 1
count += 1
file_bytes = fs_client.read_bytes(full_path)
table = pq.read_table(io.BytesIO(file_bytes))
cols = table.column_names
count = 0
for column in table:
column_name = cols[count]
item_count = 0
for item in column.to_pylist():
if len(result) <= item_count:
result.append({column_name: item})
else:
result[item_count][column_name] = item
item_count += 1
count += 1

return table_name, result
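
As a side note on the parquet branch above: pyarrow tables also expose to_pylist(), which returns rows as dictionaries directly, so the column-by-column loop could arguably be collapsed. A small sketch under that assumption (the helper name is hypothetical; fs_client is assumed to be the fsspec-style client already used above):

import io
import pyarrow.parquet as pq

def parquet_rows(fs_client, full_path):
    # read the remote parquet file into memory and let pyarrow turn it into row dicts
    table = pq.read_table(io.BytesIO(fs_client.read_bytes(full_path)))
    return table.to_pylist()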

Expand All @@ -149,18 +146,14 @@ def load_files(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, A
client.dataset_path, detail=False, refresh=True
):
for file in files:
table_name, items = load_file(basedir, file)
table_name, items = load_file(client.fs_client, basedir, file)
if table_name not in table_names:
continue
if table_name in result:
result[table_name] = result[table_name] + items
else:
result[table_name] = items

# loads file is special case
if LOADS_TABLE_NAME in table_names and file.find(".{LOADS_TABLE_NAME}."):
result[LOADS_TABLE_NAME] = []

return result

