Skip to content

Commit

Permalink
calculate max nesting only once, and count nesting level backwards
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Nov 6, 2024
1 parent 23582ce commit 88f7bec
Showing 1 changed file with 33 additions and 41 deletions.
74 changes: 33 additions & 41 deletions dlt/common/normalizers/json/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
TColumnName,
TSimpleRegex,
DLT_NAME_PREFIX,
TTableSchema,
)
from dlt.common.schema.utils import (
column_name_validator,
Expand Down Expand Up @@ -96,7 +95,7 @@ def _reset(self) -> None:
# self.primary_keys = Dict[str, ]

def _flatten(
self, table: str, dict_row: DictStrAny, parent_path: Tuple[str, ...], _r_lvl: int
self, table: str, dict_row: DictStrAny, _r_lvl: int
) -> Tuple[DictStrAny, Dict[Tuple[str, ...], Sequence[Any]]]:
out_rec_row: DictStrAny = {}
out_rec_list: Dict[Tuple[str, ...], Sequence[Any]] = {}
Expand All @@ -116,13 +115,11 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) -
)
# for lists and dicts we must check if type is possibly nested
if isinstance(v, (dict, list)):
if not self._is_nested_type(
self.schema, table, nested_name, self.max_nesting, parent_path, __r_lvl
):
if not self._is_nested_type(self.schema, table, nested_name, __r_lvl):
# TODO: if schema contains table {table}__{nested_name} then convert v into single element list
if isinstance(v, dict):
# flatten the dict more
norm_row_dicts(v, __r_lvl + 1, path + (norm_k,))
norm_row_dicts(v, __r_lvl - 1, path + (norm_k,))
else:
# pass the list to out_rec_list
out_rec_list[path + (schema_naming.normalize_table_identifier(k),)] = v
Expand Down Expand Up @@ -174,9 +171,9 @@ def _add_row_id(
flattened_row: DictStrAny,
parent_row_id: str,
pos: int,
_r_lvl: int,
is_root: bool = False,
) -> str:
if _r_lvl == 0: # root table
if is_root: # root table
row_id_type = self._get_root_row_id_type(self.schema, table)
if row_id_type in ("key_hash", "row_hash"):
subset = None
Expand All @@ -201,14 +198,14 @@ def _add_row_id(
flattened_row[self.c_dlt_id] = row_id
return row_id

def _get_propagated_values(self, table: str, row: DictStrAny, _r_lvl: int) -> StrAny:
def _get_propagated_values(self, table: str, row: DictStrAny, is_root: bool) -> StrAny:
extend: DictStrAny = {}

config = self.propagation_config
if config:
# mapping(k:v): propagate property with name "k" as property with name "v" in nested table
mappings: Dict[TColumnName, TColumnName] = {}
if _r_lvl == 0:
if is_root:
mappings.update(config.get("root") or {})
if table in (config.get("tables") or {}):
mappings.update(config["tables"][table])
Expand Down Expand Up @@ -246,13 +243,13 @@ def _normalize_list(
parent_path,
parent_row_id,
idx,
_r_lvl + 1,
_r_lvl - 1,
)
else:
# found non-dict in seq, so wrap it
wrap_v = wrap_in_dict(self.c_value, v)
DataItemNormalizer._extend_row(extend, wrap_v)
self._add_row_id(table, wrap_v, wrap_v, parent_row_id, idx, _r_lvl)
self._add_row_id(table, wrap_v, wrap_v, parent_row_id, idx)
yield (table, self.schema.naming.shorten_fragments(*parent_path)), wrap_v

def _normalize_row(
Expand All @@ -264,6 +261,7 @@ def _normalize_row(
parent_row_id: Optional[str] = None,
pos: Optional[int] = None,
_r_lvl: int = 0,
is_root: bool = False,
) -> TNormalizedRowIterator:
schema = self.schema
table = schema.naming.shorten_fragments(*parent_path, *ident_path)
Expand All @@ -274,10 +272,10 @@ def _normalize_row(
# infer record hash or leave existing primary key if present
row_id = flattened_row.get(self.c_dlt_id, None)
if not row_id:
row_id = self._add_row_id(table, dict_row, flattened_row, parent_row_id, pos, _r_lvl)
row_id = self._add_row_id(table, dict_row, flattened_row, parent_row_id, pos, is_root)

# find fields to propagate to nested tables in config
extend.update(self._get_propagated_values(table, flattened_row, _r_lvl))
extend.update(self._get_propagated_values(table, flattened_row, is_root))

# yield parent table first
should_descend = yield (
Expand All @@ -295,7 +293,7 @@ def _normalize_row(
list_path,
parent_path + ident_path,
row_id,
_r_lvl + 1,
_r_lvl - 1,
)

def extend_schema(self) -> None:
Expand Down Expand Up @@ -361,10 +359,16 @@ def normalize_data_item(
row = cast(DictStrAny, item)
# identify load id if loaded data must be processed after loading incrementally
row[self.c_dlt_load_id] = load_id
# get table name and nesting level
root_table_name = self.schema.naming.normalize_table_identifier(table_name)
max_nesting = self._get_table_nesting_level(self.schema, root_table_name, self.max_nesting)

yield from self._normalize_row(
row,
{},
(self.schema.naming.normalize_table_identifier(table_name),),
(root_table_name,),
_r_lvl=max_nesting, # we count backwards
is_root=True,
)

@classmethod
Expand Down Expand Up @@ -423,26 +427,21 @@ def _normalize_prop(
)

@staticmethod
@lru_cache(maxsize=None)
def _get_table_nesting_level(
schema: Schema, table_name: str, parent_path: Tuple[str, ...]
schema: Schema, table_name: str, default_nesting: int = 1000
) -> Optional[int]:
"""gets table nesting level, will inherit from parent if not set"""

# try go get table directly

table = schema.tables.get(table_name)
max_nesting = None

if table and (max_nesting := cast(int, table.get("x-normalizer", {}).get("max_nesting"))):
if (
table
and (max_nesting := cast(int, table.get("x-normalizer", {}).get("max_nesting")))
is not None
):
return max_nesting

# if table is not found, try to get it from root path
if max_nesting is None and parent_path:
table = schema.tables.get(parent_path[0])

if table:
return cast(int, table.get("x-normalizer", {}).get("max_nesting"))

return None
return default_nesting

@staticmethod
@lru_cache(maxsize=None)
Expand All @@ -458,22 +457,15 @@ def _is_nested_type(
schema: Schema,
table_name: str,
field_name: str,
max_nesting: int,
parent_path: Tuple[str, ...],
_r_lvl: int,
) -> bool:
"""For those paths the nested objects should be left in place.
Cache perf: max_nesting < _r_lvl: ~2x faster, full check 10x faster
"""
# turn everything at the recursion level into nested type
max_table_nesting = DataItemNormalizer._get_table_nesting_level(
schema, table_name, parent_path
)
if max_table_nesting is not None:
max_nesting = max_table_nesting

assert _r_lvl <= max_nesting
if _r_lvl == max_nesting:
# nesting level is counted backwards
# is we have traversed to or beyond the calculated nesting level, we detect a nested type
if _r_lvl <= 0:
return True

column: TColumnSchema = None
Expand Down

0 comments on commit 88f7bec

Please sign in to comment.