Commit

Convert information into dict

Fokko committed May 8, 2023
1 parent cb41ab0 commit 17ad88f
Showing 2 changed files with 106 additions and 62 deletions.
162 changes: 103 additions & 59 deletions dbt/adapters/spark/impl.py
@@ -37,8 +37,8 @@
DROP_RELATION_MACRO_NAME = "drop_relation"
FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties"

KEY_TABLE_OWNER = "Owner"
KEY_TABLE_STATISTICS = "Statistics"
KEY_TABLE_OWNER = "owner"
KEY_TABLE_STATISTICS = "statistics"

TABLE_OR_VIEW_NOT_FOUND_MESSAGES = (
"[TABLE_OR_VIEW_NOT_FOUND]",
@@ -58,6 +58,14 @@ class SparkConfig(AdapterConfig):
merge_update_columns: Optional[str] = None


@dataclass(frozen=True)
class RelationInfo:
table_schema: str
table_name: str
columns: List[Tuple[str, str]]
properties: Dict[str, str]


class SparkAdapter(SQLAdapter):
COLUMN_NAMES = (
"table_database",
@@ -79,9 +87,7 @@ class SparkAdapter(SQLAdapter):
"stats:rows:description",
"stats:rows:include",
)
INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE)
INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE)
INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE)
INFORMATION_COLUMN_REGEX = re.compile(r" \|-- (.*): (.*) \(nullable = (.*)\)")
HUDI_METADATA_COLUMNS = [
"_hoodie_commit_time",
"_hoodie_commit_seqno",
@@ -91,7 +97,6 @@ class SparkAdapter(SQLAdapter):
]

Relation: TypeAlias = SparkRelation
RelationInfo = Tuple[str, str, str]
Column: TypeAlias = SparkColumn
ConnectionManager: TypeAlias = SparkConnectionManager
AdapterSpecificConfigs: TypeAlias = SparkConfig
@@ -139,13 +144,42 @@ def add_schema_to_cache(self, schema) -> str:
def _get_relation_information(self, row: agate.Row) -> RelationInfo:
"""relation info was fetched with SHOW TABLES EXTENDED"""
try:
_schema, name, _, information = row
# Example lines:
# Database: dbt_schema
# Table: names
# Owner: fokkodriesprong
# Created Time: Mon May 08 18:06:47 CEST 2023
# Last Access: UNKNOWN
# Created By: Spark 3.3.2
# Type: MANAGED
# Provider: hive
# Table Properties: [transient_lastDdlTime=1683562007]
# Statistics: 16 bytes
# Schema: root
# |-- idx: integer (nullable = false)
# |-- name: string (nullable = false)
table_properties = {}
columns = []
_schema, name, _, information_blob = row
for line in information_blob.split("\n"):
if line:
if line.startswith(" |--"):
# A column
m = self.INFORMATION_COLUMN_REGEX.match(line)
columns.append(
(m[1], m[2])
)
else:
# A property
parts = line.split(": ", maxsplit=2)
table_properties[parts[0].lower()] = parts[1]

except ValueError:
raise dbt.exceptions.DbtRuntimeError(
f'Invalid value from "show tables extended ...", got {len(row)} values, expected 4'
)

return _schema, name, information
return RelationInfo(_schema, name, columns, table_properties)

def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:
"""Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement"""
@@ -165,13 +199,49 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:
logger.debug(f"Error while retrieving information about {table_name}: {e.msg}")
table_results = AttrDict()

information = ""
for info_row in table_results:
info_type, info_value, _ = info_row
if not info_type.startswith("#"):
information += f"{info_type}: {info_value}\n"
# idx int
# name string
#
# # Partitioning
# Not partitioned
#
# # Metadata Columns
# _spec_id int
# _partition struct<>
# _file string
# _pos bigint
# _deleted boolean
#
# # Detailed Table Information
# Name sandbox.dbt_tabular3.names
# Location s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb
# Provider iceberg

# Wrap it in an iter, so we continue reading the properties from where we stopped reading columns
table_results_itr = iter(table_results)

# First the columns
columns = []
for info_row in table_results_itr:
if info_row[0] == '':
break
columns.append(
(info_row[0], info_row[1])
)

return _schema, name, information
# Next all the properties
table_properties = {}
for info_row in table_results_itr:
info_type, info_value, _ = info_row
if not info_type.startswith("#") and info_type != '':
table_properties[info_type.lower()] = info_value

return RelationInfo(
_schema,
name,
columns,
table_properties
)

def _build_spark_relation_list(
self,
@@ -181,23 +251,24 @@ def _build_spark_relation_list(
"""Aggregate relations with format metadata included."""
relations = []
for row in row_list:
_schema, name, information = relation_info_func(row)
relation = relation_info_func(row)

rel_type: RelationType = (
RelationType.View if "Type: VIEW" in information else RelationType.Table
RelationType.View if relation.properties.get("type") == "VIEW" else RelationType.Table
)
is_delta: bool = "Provider: delta" in information
is_hudi: bool = "Provider: hudi" in information
is_iceberg: bool = "Provider: iceberg" in information
is_delta: bool = relation.properties.get("provider") == "delta"
is_hudi: bool = relation.properties.get("provider") == "hudi"
is_iceberg: bool = relation.properties.get("provider") == "iceberg"

relation: BaseRelation = self.Relation.create( # type: ignore
schema=_schema,
identifier=name,
schema=relation.table_schema,
identifier=relation.table_name,
type=rel_type,
information=information,
is_delta=is_delta,
is_iceberg=is_iceberg,
is_hudi=is_hudi,
columns=relation.columns,
properties=relation.properties,
)
relations.append(relation)

@@ -250,44 +321,26 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[
return super().get_relation(database, schema, identifier)

def parse_describe_extended(
self, relation: BaseRelation, raw_rows: AttrDict
self, relation: SparkRelation, raw_rows: AttrDict
) -> List[SparkColumn]:
# Convert the Row to a dict
dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows]
# Find the separator between the rows and the metadata provided
# by the DESCRIBE TABLE EXTENDED statement
pos = self.find_table_information_separator(dict_rows)

# Remove rows that start with a hash, they are comments
rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")]
metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]}

raw_table_stats = metadata.get(KEY_TABLE_STATISTICS)
raw_table_stats = relation.properties.get(KEY_TABLE_STATISTICS)
table_stats = SparkColumn.convert_table_stats(raw_table_stats)
return [
SparkColumn(
table_database=None,
table_schema=relation.schema,
table_name=relation.name,
table_type=relation.type,
table_owner=str(metadata.get(KEY_TABLE_OWNER)),
table_owner=relation.properties.get(KEY_TABLE_OWNER, ""),
table_stats=table_stats,
column=column["col_name"],
column=column_name,
column_index=idx,
dtype=column["data_type"],
dtype=column_type,
)
for idx, column in enumerate(rows)
for idx, (column_name, column_type) in enumerate(relation.columns)
]

@staticmethod
def find_table_information_separator(rows: List[dict]) -> int:
pos = 0
for row in rows:
if not row["col_name"] or row["col_name"].startswith("#"):
break
pos += 1
return pos

def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
columns = []
try:
@@ -309,20 +362,11 @@ def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS]
return columns

def parse_columns_from_information(self, relation: BaseRelation) -> List[SparkColumn]:
if hasattr(relation, "information"):
information = relation.information or ""
else:
information = ""
owner_match = re.findall(self.INFORMATION_OWNER_REGEX, information)
owner = owner_match[0] if owner_match else None
matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, information)
def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]:
owner = relation.properties.get(KEY_TABLE_OWNER, "")
columns = []
stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, information)
raw_table_stats = stats_match[0] if stats_match else None
table_stats = SparkColumn.convert_table_stats(raw_table_stats)
for match_num, match in enumerate(matches):
column_name, column_type, nullable = match.groups()
table_stats = SparkColumn.convert_table_stats(relation.properties.get(KEY_TABLE_STATISTICS))
for match_num, (column_name, column_type) in enumerate(relation.columns):
column = SparkColumn(
table_database=None,
table_schema=relation.schema,
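For readers skimming the impl.py change: the new parsing path in _get_relation_information splits the information blob returned by SHOW TABLES EXTENDED into column tuples plus a lowercase-keyed property dict. Below is a minimal, self-contained sketch of that idea; the helper name parse_information and the sample blob are illustrative only and are not part of the commit.

import re
from typing import Dict, List, Tuple

INFORMATION_COLUMN_REGEX = re.compile(r" \|-- (.*): (.*) \(nullable = (.*)\)")


def parse_information(information_blob: str) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
    # Column lines look like " |-- idx: integer (nullable = false)";
    # every other non-empty line is treated as a "Key: value" property.
    columns: List[Tuple[str, str]] = []
    properties: Dict[str, str] = {}
    for line in information_blob.split("\n"):
        if not line:
            continue
        if line.startswith(" |--"):
            match = INFORMATION_COLUMN_REGEX.match(line)
            if match:
                columns.append((match[1], match[2]))
        else:
            parts = line.split(": ", maxsplit=2)
            if len(parts) > 1:
                properties[parts[0].lower()] = parts[1]
    return columns, properties


# Example using the blob format shown in the commit's inline comment:
blob = (
    "Database: dbt_schema\n"
    "Table: names\n"
    "Owner: fokkodriesprong\n"
    "Provider: hive\n"
    "Statistics: 16 bytes\n"
    "Schema: root\n"
    " |-- idx: integer (nullable = false)\n"
    " |-- name: string (nullable = false)\n"
)
cols, props = parse_information(blob)
assert cols == [("idx", "integer"), ("name", "string")]
assert props["provider"] == "hive" and props["statistics"] == "16 bytes"
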
6 changes: 3 additions & 3 deletions dbt/adapters/spark/relation.py
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import Optional, TypeVar, List, Tuple, Dict
from dataclasses import dataclass, field

from dbt.adapters.base.relation import BaseRelation, Policy
@@ -33,8 +33,8 @@ class SparkRelation(BaseRelation):
is_delta: Optional[bool] = None
is_hudi: Optional[bool] = None
is_iceberg: Optional[bool] = None
# TODO: make this a dict everywhere
information: Optional[str] = None
columns: List[Tuple[str, str]] = field(default_factory=list)
properties: Dict[str, str] = field(default_factory=dict)

def __post_init__(self):
if self.database != self.schema and self.database:
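To show how the new columns and properties fields replace regex scans over the old free-form information string, here is a hedged, self-contained sketch. MiniRelation is a simplified stand-in, not the real SparkRelation, and the lowercase property keys (type, provider, owner, statistics) mirror what the parsing code in impl.py now produces.

from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class MiniRelation:
    # Simplified stand-in carrying the two new fields.
    schema: str
    identifier: str
    columns: List[Tuple[str, str]] = field(default_factory=list)
    properties: Dict[str, str] = field(default_factory=dict)


def describe(relation: MiniRelation) -> Dict[str, object]:
    # Mirrors the dict lookups the adapter now performs instead of
    # checking for substrings like "Provider: delta" in a text blob.
    provider = relation.properties.get("provider")
    return {
        "is_view": relation.properties.get("type") == "VIEW",
        "is_delta": provider == "delta",
        "is_iceberg": provider == "iceberg",
        "owner": relation.properties.get("owner", ""),
        "raw_stats": relation.properties.get("statistics"),
        "column_names": [name for name, _ in relation.columns],
    }


rel = MiniRelation(
    schema="dbt_schema",
    identifier="names",
    columns=[("idx", "integer"), ("name", "string")],
    properties={"type": "MANAGED", "provider": "hive", "owner": "fokkodriesprong", "statistics": "16 bytes"},
)
print(describe(rel))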
