diff --git a/kartothek/core/factory.py b/kartothek/core/factory.py
index e3ffba93..a6ae2928 100644
--- a/kartothek/core/factory.py
+++ b/kartothek/core/factory.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import copy
-from typing import Optional
+from typing import Optional, TypeVar
 
 from kartothek.core.dataset import DatasetMetadata, DatasetMetadataBase
 from kartothek.core.utils import _check_callable
 
@@ -30,6 +30,9 @@ def _ensure_factory(
     )
 
 
+T = TypeVar("T", bound="DatasetFactory")
+
+
 class DatasetFactory(DatasetMetadataBase):
     _nullable_attributes = ["_cache_metadata", "_cache_store"]
 
@@ -166,6 +169,6 @@ def load_all_indices(self, load_partition_indices=True, store=None):
         )
         return self
 
-    def load_partition_indices(self):
+    def load_partition_indices(self: T) -> T:
         self._cache_metadata = self.dataset_metadata.load_partition_indices()
         return self
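
A minimal, self-contained sketch (hypothetical `Base`/`Child` class names, not part of the diff) of the self-typed return pattern introduced above: binding `T` to the class and annotating `self: T` lets type checkers keep the concrete subclass type across chained calls such as `DatasetFactory(...).load_partition_indices()`.

.. code::

    from typing import TypeVar

    T = TypeVar("T", bound="Base")


    class Base:
        def load_partition_indices(self: T) -> T:
            # Returning ``T`` instead of ``"Base"`` preserves the subclass type.
            return self


    class Child(Base):
        def extra(self) -> int:
            return 1


    # A type checker infers ``Child`` here, so ``.extra()`` is accepted.
    Child().load_partition_indices().extra()
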
diff --git a/kartothek/io/dask/delayed.py b/kartothek/io/dask/delayed.py
index b74f3e21..2fcb623a 100644
--- a/kartothek/io/dask/delayed.py
+++ b/kartothek/io/dask/delayed.py
@@ -3,6 +3,7 @@
 
 from collections import defaultdict
 from functools import partial
+from typing import List
 
 import dask
 from dask import delayed
@@ -19,7 +20,7 @@
     delete_top_level_metadata,
 )
 from kartothek.io_components.gc import delete_files, dispatch_files_to_gc
-from kartothek.io_components.merge import align_datasets
+from kartothek.io_components.merge import align_datasets, align_datasets_many
 from kartothek.io_components.metapartition import (
     SINGLE_TABLE,
     MetaPartition,
@@ -142,6 +143,28 @@ def _load_and_merge_mps(mp_list, store, label_merger, metadata_merger, merge_tas
     return mp
 
 
+def _load_and_merge_many_mps(
+    mp_list, store, label_merger, metadata_merger, merge_tasks, is_dispatched: bool
+):
+    if is_dispatched:
+        mp_list = [[mp.load_dataframes(store=store) for mp in mps] for mps in mp_list]
+        mp_list = [
+            MetaPartition.merge_metapartitions(mps).concat_dataframes()
+            for mps in mp_list
+        ]
+    else:
+        mp_list = [mp.load_dataframes(store=store) for mp in mp_list]
+
+    mp = MetaPartition.merge_metapartitions(
+        mp_list, label_merger=label_merger, metadata_merger=metadata_merger
+    )
+
+    for task in merge_tasks:
+        mp = mp.merge_many_dataframes(**task)
+
+    return mp
+
+
 @default_docs
 def merge_datasets_as_delayed(
     left_dataset_uuid,
@@ -230,6 +253,86 @@ def merge_datasets_as_delayed(
     return list(mps)
 
 
+@default_docs
+def merge_many_datasets_as_delayed(
+    dataset_uuids: List[str],
+    store,
+    merge_tasks,
+    match_how="exact",
+    dispatch_by=None,
+    label_merger=None,
+    metadata_merger=None,
+):
+    """
+    A dask.delayed graph to perform the merge of multiple full kartothek datasets.
+
+    Parameters
+    ----------
+    dataset_uuids : List[str]
+        The UUIDs of the datasets to be merged.
+    match_how : Union[str, Callable]
+        Define the partition label matching scheme.
+        Available implementations are:
+
+        * first : The partitions of the first dataset are considered to be the base
+          partitions and **all** partitions of the remaining datasets are
+          joined to the partitions of the first dataset. This should only be
+          used if all but the first dataset contain very few partitions.
+        * prefix_first : The labels of the partitions of the first dataset are
+          considered to be the prefixes to the other datasets.
+        * exact : All partition labels of each dataset need to be an exact match.
+        * callable : A callable with signature func(labels: List[str]) which
+          returns a boolean to determine if the partitions match.
+    merge_tasks : List[Dict]
+        A list of merge tasks. Each item in this list is a dictionary giving
+        explicit instructions for a specific merge.
+        Each dict should contain key/values:
+
+        * 'output_label' : The table for the merged dataframe
+        * 'merge_func' : A callable with signature
+          `merge_func(dfs, **merge_kwargs)` to
+          handle the data preprocessing and merging.
+        * 'merge_kwargs' : The kwargs to be passed to the `merge_func`
+
+        Example:
+
+        .. code::
+
+            >>> merge_tasks = [
+            ...     {
+            ...         "tables": ["first_table", "second_table"],
+            ...         "merge_func": func,
+            ...         "merge_kwargs": {"kwargs of merge_func": ''},
+            ...         "output_label": 'merged_core_data'
+            ...     },
+            ... ]
+
+    """
+    _check_callable(store)
+
+    mps = align_datasets_many(
+        dataset_uuids=dataset_uuids,
+        store=store,
+        match_how=match_how,
+        dispatch_by=dispatch_by,
+    )
+    mps = map_delayed(
+        _load_and_merge_many_mps,
+        mps,
+        store=store,
+        label_merger=label_merger,
+        metadata_merger=metadata_merger,
+        merge_tasks=merge_tasks,
+        is_dispatched=dispatch_by is not None,
+    )
+
+    return list(mps)
+
+
 def _load_and_concat_metapartitions_inner(mps, args, kwargs):
     return MetaPartition.concat_metapartitions(
         [mp.load_dataframes(*args, **kwargs) for mp in mps]
     )
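
A usage sketch for the new `merge_many_datasets_as_delayed` API, mirroring the merge-task layout from the docstring above. The dataset UUIDs, table names, and the storefact URL are placeholders; the snippet assumes two existing datasets whose partition labels line up for `match_how="prefix_first"`.

.. code::

    import dask
    import pandas as pd
    from storefact import get_store_from_url

    from kartothek.io.dask.delayed import merge_many_datasets_as_delayed


    def store_factory():
        # Hypothetical local store; any simplekv-compatible store factory works.
        return get_store_from_url("hfs:///tmp/kartothek_data")


    def left_join(dfs, **kwargs):
        # ``dfs`` arrives in the order given by the task's "tables" entry.
        return pd.merge(dfs[0], dfs[1], **kwargs)


    merge_tasks = [
        {
            "tables": ["core", "predictions"],
            "merge_func": left_join,
            "merge_kwargs": {"how": "left"},
            "output_label": "merged_core_data",
        }
    ]

    delayed_mps = merge_many_datasets_as_delayed(
        dataset_uuids=["core_dataset", "prediction_dataset"],
        store=store_factory,
        merge_tasks=merge_tasks,
        match_how="prefix_first",
    )
    # Each computed MetaPartition carries the merged table under "merged_core_data".
    merged_mps = dask.compute(delayed_mps)[0]
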
diff --git a/kartothek/io_components/merge.py b/kartothek/io_components/merge.py
index c1752fd2..3d21ff9f 100644
--- a/kartothek/io_components/merge.py
+++ b/kartothek/io_components/merge.py
@@ -1,8 +1,14 @@
 import logging
+from functools import partial, reduce
+from typing import Dict, List, Optional
+
+import pandas as pd
 
 from kartothek.core.dataset import DatasetMetadata
+from kartothek.core.factory import DatasetFactory
 from kartothek.io_components.metapartition import MetaPartition
-from kartothek.io_components.utils import _instantiate_store
+from kartothek.io_components.read import dispatch_metapartitions_from_factory
+from kartothek.io_components.utils import _instantiate_store, _make_callable
 
 LOGGER = logging.getLogger(__name__)
 
@@ -104,3 +110,104 @@ def align_datasets(left_dataset_uuid, right_dataset_uuid, store, match_how="exac
                     "found".format(p_1, first_dataset)
                 )
         yield res
+
+
+def align_datasets_many(
+    dataset_uuids: List[str],
+    store,
+    match_how: str = "exact",
+    dispatch_by: Optional[List[str]] = None,
+):
+    """
+    Determine the partition alignment across multiple datasets.
+
+    Parameters
+    ----------
+    dataset_uuids : List[str]
+        The UUIDs of the datasets to be aligned.
+    store : KeyValueStore or callable
+    match_how : str, {first, prefix_first, exact, dispatch_by}
+    dispatch_by : Optional[List[str]]
+        Index columns to group partitions by when ``match_how='dispatch_by'``.
+
+    Yields
+    ------
+    list
+    """
+    if len(dataset_uuids) < 2:
+        raise ValueError("Need at least two datasets for merging.")
+    dataset_factories = [
+        DatasetFactory(
+            dataset_uuid=dataset_uuid,
+            store_factory=_make_callable(store),
+            load_schema=True,
+            load_all_indices=False,
+            load_dataset_metadata=True,
+        ).load_partition_indices()
+        for dataset_uuid in dataset_uuids
+    ]
+
+    store = _instantiate_store(store)
+    mps = [
+        # TODO: Add predicates
+        # We don't pass dispatch_by here since the dispatching happens later
+        list(dispatch_metapartitions_from_factory(dataset_factory=dataset_factory))
+        for dataset_factory in dataset_factories
+    ]
+
+    if match_how == "first":
+        if len(set(len(x) for x in mps)) != 1:
+            raise RuntimeError("All datasets must have the same number of partitions")
+        for mp_0 in mps[0]:
+            for other_mps in zip(*mps[1:]):
+                yield [mp_0] + list(other_mps)
+    elif match_how == "prefix_first":
+        # TODO: write a test which protects against the following scenario!!
+        # Sort the partition labels by length of the labels, starting with the
+        # labels which are the longest. This way we prevent label matching for
+        # similar partitions, e.g. cluster_100 and cluster_1. This, of course,
+        # works only as long as the inner loop removes elements which were
+        # matched already.
+        for mp_0 in mps[0]:
+            res = [mp_0]
+            label_0 = mp_0.label
+            for dataset_i in range(1, len(mps)):
+                for j, mp_i in enumerate(mps[dataset_i]):
+                    if mp_i.label.startswith(label_0):
+                        res.append(mp_i)
+                        del mps[dataset_i][j]
+                        break
+                else:
+                    raise RuntimeError(
+                        f"Did not find a matching partition in dataset "
+                        f"{dataset_uuids[dataset_i]} for partition {label_0}"
+                    )
+            yield res
+    elif match_how == "exact":
+        raise NotImplementedError("exact")
+    elif match_how == "dispatch_by":
+        index_dfs = []
+        for i, factory in enumerate(dataset_factories):
+            df = factory.get_indices_as_dataframe(dispatch_by)
+            index_dfs.append(
+                df.reset_index().rename(
+                    columns={"partition": f"partition_{i}"}, copy=False
+                )
+            )
+        index_df = reduce(partial(pd.merge, on=dispatch_by), index_dfs)
+
+        mps_by_label: List[Dict[str, MetaPartition]] = []
+        for mpx in mps:
+            mps_by_label.append({})
+            for mp in mpx:
+                mps_by_label[-1][mp.label] = mp
+
+        for _, group in index_df.groupby(dispatch_by):
+            res_nested: List[List[MetaPartition]] = []
+            for i in range(len(dataset_uuids)):
+                res_nested.append(
+                    [
+                        mps_by_label[i][label]
+                        for label in group[f"partition_{i}"].unique()
+                    ]
+                )
+            yield res_nested
+    else:
+        raise NotImplementedError(f"matching with '{match_how}' is not supported")
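
The `dispatch_by` branch above aligns datasets through their indices: each dataset contributes a small dataframe mapping index values to partition labels, those frames are inner-joined on the `dispatch_by` columns, and every group of the joined frame yields one nested list of partitions per dataset. A plain-pandas sketch of that label bookkeeping (made-up labels, no kartothek objects involved):

.. code::

    from functools import partial, reduce

    import pandas as pd

    # One frame per dataset: index value -> partition label.
    index_dfs = [
        pd.DataFrame({"P": [1, 1, 2], "partition_0": ["a_1", "a_2", "a_3"]}),
        pd.DataFrame({"P": [1, 2, 2], "partition_1": ["b_1", "b_2", "b_3"]}),
    ]

    # Inner-join all frames on the dispatch columns, then group by them.
    index_df = reduce(partial(pd.merge, on=["P"]), index_dfs)
    for key, group in index_df.groupby(["P"]):
        nested = [
            group[f"partition_{i}"].unique().tolist() for i in range(len(index_dfs))
        ]
        print(key, nested)
    # e.g. P=1 -> [['a_1', 'a_2'], ['b_1']] and P=2 -> [['a_3'], ['b_2', 'b_3']]
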
diff --git a/kartothek/io_components/metapartition.py b/kartothek/io_components/metapartition.py
index c1c2646c..98397a0a 100644
--- a/kartothek/io_components/metapartition.py
+++ b/kartothek/io_components/metapartition.py
@@ -10,7 +10,7 @@
 from collections import Iterable, Iterator, defaultdict, namedtuple
 from copy import copy
 from functools import wraps
-from typing import Any, Dict, Optional, cast
+from typing import Any, Callable, Dict, List, Optional, cast
 
 import numpy as np
 import pandas as pd
@@ -137,7 +137,7 @@ def _impl(self, *method_args, **method_kwargs):
             result = result.add_metapartition(mp, schema_validation=False)
         if not isinstance(result, MetaPartition):
             raise ValueError(
-                "Result for method {} is not a `MetaPartition` but".format(
+                "Result for method {} is not a `MetaPartition` but {}".format(
                     method.__name__, type(method_return)
                 )
             )
@@ -926,6 +926,75 @@ def merge_dataframes(
         )
         return self.copy(files={}, data=new_data, table_meta=new_table_meta)
 
+    @_apply_to_list
+    def merge_many_dataframes(
+        self,
+        tables: List[str],
+        merge_func: Callable,
+        merge_kwargs: Optional[Dict[str, Any]],
+        output_label: str,
+    ):
+        """
+        Merge internal dataframes.
+
+        The referenced dataframes are removed from the internal list and
+        the newly created dataframe is added.
+
+        The merge itself can be completely customized by supplying a
+        callable `merge_func(dfs, **merge_kwargs)` which can
+        handle data pre-processing as well as the merge itself.
+
+        Parameters
+        ----------
+        tables : List[str]
+            The names of the tables to be merged.
+        merge_func : callable
+            The function to take care of the merge.
+            The function should have the signature
+            :func:`func(dfs, **kwargs)`
+        merge_kwargs : dict
+            Keyword arguments which should be supplied to the merge function
+        output_label : str
+            Table name for the newly created dataframe
+
+        Returns
+        -------
+        MetaPartition
+
+        """
+        # Shallow copy
+        new_data = copy(self.data)
+        if merge_kwargs is None:
+            merge_kwargs = {}
+
+        dfs = [new_data.pop(table) for table in tables]
+
+        LOGGER.debug("Merging internal dataframes of %s", self.label)
+
+        try:
+            df_merged = merge_func(dfs, **merge_kwargs)
+        except TypeError:
+            LOGGER.error(
+                "Tried to merge using %s with kwargs: %s",
+                merge_func.__name__,
+                merge_kwargs,
+            )
+            raise
+
+        new_data[output_label] = df_merged
+        new_table_meta = copy(self.table_meta)
+        # The tables are no longer part of the MetaPartition, thus also drop
+        # their schema.
+        for table in tables:
+            del new_table_meta[table]
+        new_table_meta[output_label] = make_meta(
+            df_merged,
+            origin="{}/{}".format(output_label, self.label),
+            partition_keys=self.partition_keys,
+        )
+        return self.copy(files={}, data=new_data, table_meta=new_table_meta)
+
     @_apply_to_list
     def validate_schema_compatible(self, store, dataset_uuid):
         """
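
To make the merge-task contract of `merge_many_dataframes` concrete, the plain-pandas sketch below (hypothetical table names and frames) mirrors what the method does to a MetaPartition's `data` dict: the listed tables are popped, handed to `merge_func`, and the result is stored under `output_label`. On a real MetaPartition the same task dict would be applied via `mp.merge_many_dataframes(**task)`.

.. code::

    from copy import copy

    import pandas as pd

    data = {
        "core": pd.DataFrame({"P": [1, 2], "X": [10, 20]}),
        "lookup": pd.DataFrame({"P": [1, 2], "Y": ["a", "b"]}),
    }
    task = {
        "tables": ["core", "lookup"],
        "merge_func": lambda dfs, **kw: pd.merge(dfs[0], dfs[1], **kw),
        "merge_kwargs": {"on": "P", "how": "left"},
        "output_label": "merged",
    }

    new_data = copy(data)
    dfs = [new_data.pop(table) for table in task["tables"]]
    new_data[task["output_label"]] = task["merge_func"](dfs, **task["merge_kwargs"])
    # ``new_data`` now contains only the "merged" table.
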
diff --git a/kartothek/io_components/read.py b/kartothek/io_components/read.py
index 570c92ab..d6843a96 100644
--- a/kartothek/io_components/read.py
+++ b/kartothek/io_components/read.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Iterator, List, Set, Union, cast
+from typing import Any, Callable, Iterator, List, Optional, Set, Union, cast, overload
 
 import pandas as pd
 
@@ -10,6 +10,30 @@
 from kartothek.serialization import check_predicates, columns_in_predicates
 
 
+@overload
+def dispatch_metapartitions_from_factory(
+    dataset_factory: Union[DatasetFactory, Callable],
+    label_filter: Optional[Callable] = None,
+    concat_partitions_on_primary_index: bool = False,
+    predicates: Optional[Any] = None,
+    store: Optional[Callable] = None,
+    dispatch_by: None = None,
+) -> Iterator[MetaPartition]:
+    ...
+
+
+@overload
+def dispatch_metapartitions_from_factory(
+    dataset_factory: Union[DatasetFactory, Callable],
+    label_filter: Optional[Callable],
+    concat_partitions_on_primary_index,
+    predicates: Optional[Any],
+    store: Optional[Callable],
+    dispatch_by: List[str],
+) -> Iterator[List[MetaPartition]]:
+    ...
+
+
 @normalize_args
 def dispatch_metapartitions_from_factory(
     dataset_factory,
@@ -135,6 +159,6 @@ def dispatch_metapartitions(
         dataset_factory=dataset_factory,
         label_filter=label_filter,
         predicates=predicates,
-        dispatch_by=dispatch_by,
         concat_partitions_on_primary_index=concat_partitions_on_primary_index,
+        dispatch_by=dispatch_by,
     )
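
The two `@overload` stubs above only steer static type checkers: with `dispatch_by=None` the function is seen as yielding single `MetaPartition` objects, while passing a column list makes it yield lists of them. A generic, runnable illustration of that pattern (function and parameter names are made up):

.. code::

    from typing import Iterator, List, overload


    @overload
    def dispatch(items: List[str], group_by: None = None) -> Iterator[str]:
        ...


    @overload
    def dispatch(items: List[str], group_by: str) -> Iterator[List[str]]:
        ...


    def dispatch(items, group_by=None):
        # Single runtime implementation; the overloads only guide type checkers.
        if group_by is None:
            yield from items
        else:
            yield list(items)


    flat = dispatch(["a", "b"])             # inferred as Iterator[str]
    nested = dispatch(["a", "b"], "label")  # inferred as Iterator[List[str]]
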
diff --git a/tests/io/dask/delayed/test_merge_many.py b/tests/io/dask/delayed/test_merge_many.py
new file mode 100644
index 00000000..3c0ee69b
--- /dev/null
+++ b/tests/io/dask/delayed/test_merge_many.py
@@ -0,0 +1,165 @@
+from datetime import date
+
+import dask
+import pandas as pd
+import pandas.testing as pdt
+
+from kartothek.io.dask.delayed import merge_many_datasets_as_delayed
+
+
+def _merge_many(dfs, *args, **kwargs):
+    return pd.merge(dfs[0], dfs[1], *args, **kwargs)
+
+
+MERGE_TASKS = [
+    {
+        "tables": ["table", "PRED"],
+        "merge_func": _merge_many,
+        "merge_kwargs": {"how": "left", "sort": False, "copy": False},
+        "output_label": "merged_core_data",
+    }
+]
+
+MERGE_EXP_CL1 = pd.DataFrame(
+    {
+        "P": [1],
+        "L": [1],
+        "TARGET": [1],
+        "HORIZON": [1],
+        "PRED": [10],
+        "DATE": pd.to_datetime([date(2010, 1, 1)]),
+    }
+)
+
+MERGE_EXP_CL2 = pd.DataFrame(
+    {
+        "P": [2],
+        "L": [2],
+        "TARGET": [2],
+        "HORIZON": [1],
+        "PRED": [10],
+        "DATE": pd.to_datetime([date(2009, 12, 31)]),
+    }
+)
+
+
+def test_merge_many_datasets_prefix_first(
+    dataset, evaluation_dataset, store_factory, store_session_factory, frozen_time
+):
+    df_list = merge_many_datasets_as_delayed(
+        dataset_uuids=[dataset.uuid, evaluation_dataset.uuid],
+        store=store_session_factory,
+        merge_tasks=MERGE_TASKS,
+        match_how="prefix_first",
+    )
+    df_list = dask.compute(df_list)[0]
+    df_list = [mp.data for mp in df_list]
+
+    # Two partitions
+    assert len(df_list) == 2
+    assert len(df_list[1]) == 2
+    assert len(df_list[0]) == 2
+    pdt.assert_frame_equal(
+        df_list[0]["merged_core_data"],
+        MERGE_EXP_CL1,
+        check_like=True,
+        check_dtype=False,
+        check_categorical=False,
+    )
+    pdt.assert_frame_equal(
+        df_list[1]["merged_core_data"],
+        MERGE_EXP_CL2,
+        check_like=True,
+        check_dtype=False,
+        check_categorical=False,
+    )
+
+
+MERGE_TASKS_FIRST = [
+    {
+        "tables": ["table_0", "table_1"],
+        "merge_func": _merge_many,
+        "merge_kwargs": {"how": "outer", "sort": False, "copy": False},
+        "output_label": "merged_core_data",
+    }
+]
+
+MERGE_EXP_CL1_FIRST = pd.DataFrame(
+    {"P": [1], "L": [1], "TARGET": [1], "DATE": pd.to_datetime([date(2010, 1, 1)])}
+)
+
+MERGE_EXP_CL2_FIRST = pd.DataFrame(
+    {
+        "P": [1, 2],
+        "L": [1, 2],
+        "TARGET": [1, 2],
+        "DATE": pd.to_datetime([date(2010, 1, 1), date(2009, 12, 31)]),
+    }
+)
+MERGE_EXP_CL3_FIRST = pd.DataFrame(
+    {
+        "P": [2, 1],
+        "L": [2, 1],
+        "TARGET": [2, 1],
+        "DATE": pd.to_datetime([date(2009, 12, 31), date(2010, 1, 1)]),
+    }
+)
+MERGE_EXP_CL4_FIRST = pd.DataFrame(
+    {"P": [2], "L": [2], "TARGET": [2], "DATE": pd.to_datetime([date(2009, 12, 31)])}
+)
+
+
+def test_merge_many_dataset_first(
+    dataset_partition_keys, store_session_factory, frozen_time
+):
+    df_list = merge_many_datasets_as_delayed(
+        dataset_uuids=[dataset_partition_keys.uuid, dataset_partition_keys.uuid],
+        store=store_session_factory,
+        merge_tasks=MERGE_TASKS_FIRST,
+        match_how="first",
+    )
+    df_list = dask.compute(df_list)[0]
+    df_list = [mp.data for mp in df_list]
+    assert len(df_list) == 4
+    pdt.assert_frame_equal(
+        df_list[0]["merged_core_data"],
+        MERGE_EXP_CL1_FIRST,
+        check_like=True,
+        check_dtype=False,
+        check_categorical=False,
+    )
+    pdt.assert_frame_equal(
+        df_list[1]["merged_core_data"],
+        MERGE_EXP_CL2_FIRST,
+        check_like=True,
+        check_dtype=False,
+        check_categorical=False,
+    )
+    pdt.assert_frame_equal(
+        df_list[2]["merged_core_data"],
+        MERGE_EXP_CL3_FIRST,
+        check_like=True,
+        check_dtype=False,
+        check_categorical=False,
+    )
+    pdt.assert_frame_equal(
+        df_list[3]["merged_core_data"],
+        MERGE_EXP_CL4_FIRST,
+        check_like=True,
+        check_dtype=False,
+        check_categorical=False,
+    )
+
+
+def test_merge_many_dataset_dispatch_by(
+    dataset_partition_keys, store_session_factory, frozen_time
+):
+    df_list = merge_many_datasets_as_delayed(
+        dataset_uuids=[dataset_partition_keys.uuid, dataset_partition_keys.uuid],
+        store=store_session_factory,
+        merge_tasks=MERGE_TASKS_FIRST,
+        match_how="dispatch_by",
+        dispatch_by=["P"],
+    )
+    df_list = dask.compute(df_list)[0]
+    df_list = [mp.data for mp in df_list]