Add support for predicates
xhochy committed Mar 11, 2020
1 parent 9c3765f commit 273271d
Showing 3 changed files with 85 additions and 4 deletions.
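For context, the new `predicates` keyword follows the call pattern exercised by the tests added in this commit. A minimal usage sketch — the dataset UUIDs, store factory, and merge task list are illustrative placeholders, not part of the commit:

import dask
from kartothek.io.dask.delayed import merge_many_datasets_as_delayed

# Predicates use kartothek's disjunctive normal form: a list of AND-clauses
# of (column, operator, value) tuples, where the outer list is OR-ed.
merged = merge_many_datasets_as_delayed(
    dataset_uuids=["dataset_left", "dataset_right"],  # hypothetical UUIDs
    store=store_factory,                              # hypothetical store factory
    merge_tasks=merge_tasks,                          # hypothetical merge task list
    match_how="dispatch_by",
    dispatch_by=["P"],
    predicates=[[("P", "==", 1)]],  # keep only data where P == 1
)
result = dask.compute(merged)[0]  # list of merged MetaPartitions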
22 changes: 20 additions & 2 deletions kartothek/io/dask/delayed.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-


from collections import defaultdict
from copy import copy
from functools import partial
from typing import List

@@ -39,6 +39,7 @@
raise_if_dataset_exists,
store_dataset_from_partitions,
)
from kartothek.serialization import filter_df_from_predicates

from ._update import _update_dask_partitions_one_to_one
from ._utils import (
@@ -144,7 +145,13 @@ def _load_and_merge_mps(mp_list, store, label_merger, metadata_merger, merge_tasks


def _load_and_merge_many_mps(
mp_list, store, label_merger, metadata_merger, merge_tasks, is_dispatched: bool
mp_list,
store,
label_merger,
metadata_merger,
merge_tasks,
is_dispatched: bool,
predicates=None,
):
if is_dispatched:
mp_list = [[mp.load_dataframes(store=store) for mp in mps] for mps in mp_list]
@@ -162,6 +169,14 @@ def _load_and_merge_many_mps(
for task in merge_tasks:
mp = mp.merge_many_dataframes(**task)

if predicates:
new_data = copy(mp.data)
new_data = {
key: filter_df_from_predicates(df, predicates=predicates)
for key, df in new_data.items()
}
mp = mp.copy(data=new_data)

return mp


@@ -262,6 +277,7 @@ def merge_many_datasets_as_delayed(
dispatch_by=None,
label_merger=None,
metadata_merger=None,
predicates=None,
):
"""
A dask.delayed graph to perform the merge of two full kartothek datasets.
@@ -319,6 +335,7 @@
store=store,
match_how=match_how,
dispatch_by=dispatch_by,
predicates=predicates,
)
mps = map_delayed(
_load_and_merge_many_mps,
@@ -328,6 +345,7 @@
metadata_merger=metadata_merger,
merge_tasks=merge_tasks,
is_dispatched=dispatch_by is not None,
predicates=predicates,
)

return list(mps)
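The post-merge filtering added above delegates to `filter_df_from_predicates`, which this commit imports from `kartothek.serialization`. A small standalone sketch of its semantics, with a made-up example frame:

import pandas as pd
from kartothek.serialization import filter_df_from_predicates

df = pd.DataFrame({"P": [1, 1, 2], "TARGET": [1, 2, 1]})

# Inner list = AND, outer list = OR. This keeps rows where
# (P == 1 AND TARGET == 1) OR (P == 2).
filtered = filter_df_from_predicates(
    df, predicates=[[("P", "==", 1), ("TARGET", "==", 1)], [("P", "==", 2)]]
)
print(filtered)  # rows 0 and 2 survive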
9 changes: 7 additions & 2 deletions kartothek/io_components/merge.py
@@ -117,6 +117,7 @@ def align_datasets_many(
store,
match_how: str = "exact",
dispatch_by: Optional[List[str]] = None,
predicates=None,
):
"""
Determine dataset partition alignment
@@ -149,7 +150,11 @@
mps = [
# TODO: Add predicates
# We don't pass dispatch_by here as we will do the dispatching later
list(dispatch_metapartitions_from_factory(dataset_factory=dataset_factory))
list(
dispatch_metapartitions_from_factory(
dataset_factory=dataset_factory, predicates=predicates
)
)
for dataset_factory in dataset_factories
]

@@ -185,7 +190,7 @@
elif match_how == "dispatch_by":
index_dfs = []
for i, factory in enumerate(dataset_factories):
df = factory.get_indices_as_dataframe(dispatch_by)
df = factory.get_indices_as_dataframe(dispatch_by, predicates=predicates)
index_dfs.append(
df.reset_index().rename(
columns={"partition": f"partition_{i}"}, copy=False
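In the `dispatch_by` branch above, predicates now also reach `get_indices_as_dataframe`, so whole partition groups can be pruned before any Parquet data is read. A rough pandas sketch of that pruning idea — the index layout shown is illustrative, not kartothek's internal schema:

import pandas as pd

# Hypothetical index frame: one row per (index value, partition) pair.
index_df = pd.DataFrame(
    {"P": [1, 1, 2], "partition": ["P=1/a", "P=1/b", "P=2/a"]}
).set_index("P")

# A predicate [[("P", "==", 1)]] restricts the frame before dispatching,
# so only the P == 1 partitions produce merge tasks downstream.
pruned = index_df[index_df.index == 1]
print(pruned["partition"].tolist())  # ['P=1/a', 'P=1/b']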
58 changes: 58 additions & 0 deletions tests/io/dask/delayed/test_merge_many.py
@@ -163,3 +163,61 @@ def test_merge_many_dataset_dispatch_by(
)
df_list = dask.compute(df_list)[0]
df_list = [mp.data for mp in df_list]
assert len(df_list) == 2

pdt.assert_frame_equal(
df_list[0]["merged_core_data"],
MERGE_EXP_CL1_FIRST,
check_like=True,
check_dtype=False,
check_categorical=False,
)
pdt.assert_frame_equal(
df_list[1]["merged_core_data"],
MERGE_EXP_CL4_FIRST,
check_like=True,
check_dtype=False,
check_categorical=False,
)

# predicate on primary index
df_list = merge_many_datasets_as_delayed(
dataset_uuids=[dataset_partition_keys.uuid, dataset_partition_keys.uuid],
store=store_session_factory,
merge_tasks=MERGE_TASKS_FIRST,
match_how="dispatch_by",
dispatch_by=["P"],
predicates=[[("P", "==", 1)]],
)
df_list = dask.compute(df_list)[0]
df_list = [mp.data for mp in df_list]
assert len(df_list) == 1

pdt.assert_frame_equal(
df_list[0]["merged_core_data"],
MERGE_EXP_CL1_FIRST,
check_like=True,
check_dtype=False,
check_categorical=False,
)

# predicate on non-index
df_list = merge_many_datasets_as_delayed(
dataset_uuids=[dataset_partition_keys.uuid, dataset_partition_keys.uuid],
store=store_session_factory,
merge_tasks=MERGE_TASKS_FIRST,
match_how="dispatch_by",
dispatch_by=["P"],
predicates=[[("TARGET", "==", 1)]],
)
df_list = dask.compute(df_list)[0]
df_list = [mp.data for mp in df_list]
assert len(df_list) == 1

pdt.assert_frame_equal(
df_list[0]["merged_core_data"],
MERGE_EXP_CL1_FIRST,
check_like=True,
check_dtype=False,
check_categorical=False,
)
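Taken together, the two new test cases cover both halves of the change: a predicate on the primary index `P` is applied while dispatching, pruning partition groups before any data is loaded, while the `filter_df_from_predicates` pass after the merge ensures that predicates on non-index columns such as `TARGET` are still applied to the loaded rows.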
