Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko committed May 10, 2023
1 parent f697b80 commit 06cfeda
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions dbt/adapters/spark/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from dbt.adapters.base import BaseRelation
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
from dbt.events import AdapterLogger
from dbt.flags import get_flags
from dbt.utils import executor, AttrDict

logger = AdapterLogger("Spark")
Expand All @@ -33,6 +34,8 @@
LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
LIST_RELATIONS_SHOW_TABLES_MACRO_NAME = "list_relations_show_tables_without_caching"
DESCRIBE_TABLE_EXTENDED_MACRO_NAME = "describe_table_extended_without_caching"
DROP_RELATION_MACRO_NAME = "drop_relation"
FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties"

KEY_TABLE_OWNER = "Owner"
KEY_TABLE_STATISTICS = "Statistics"
Expand Down Expand Up @@ -103,29 +106,41 @@ def date_function(cls) -> str:
return "current_timestamp()"

@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table, col_idx):
return "string"

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table, col_idx):
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "double" if decimals else "bigint"

@classmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table, col_idx):
return "date"

@classmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table, col_idx):
return "time"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table, col_idx):
return "timestamp"

def quote(self, identifier):
return "`{}`".format(identifier)

def add_schema_to_cache(self, schema) -> str:
"""Cache a new schema in dbt. It will show up in `list relations`."""
if schema is None:
name = self.nice_connection_name()
raise dbt.exceptions.CompilationError(
"Attempted to cache a null schema for {}".format(name)
)
if get_flags().USE_CACHE:
self.cache.add_schema(None, schema)
# so jinja doesn't render things
return ""

def _get_relation_information(self, row: agate.Row) -> RelationInfo:
"""relation info was fetched with SHOW TABLES EXTENDED"""
try:
Expand Down Expand Up @@ -279,9 +294,11 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[
def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
assert isinstance(relation, SparkRelation)
if relation.columns is not None and len(relation.columns) > 0:
print(f"HAZ: {relation.columns}")
columns = relation.columns
properties = relation.properties
else:
print(f"NOT HAZ: {relation.columns}")
try:
describe_extended_result = self.execute_macro(
GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation}
Expand Down Expand Up @@ -349,6 +366,12 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str,
as_dict["table_database"] = None
yield as_dict

def get_properties(self, relation: Relation) -> Dict[str, str]:
properties = self.execute_macro(
FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation}
)
return dict(properties)

def get_catalog(self, manifest):
schema_map = self._get_catalog_schemas(manifest)
if len(schema_map) > 1:
Expand Down

0 comments on commit 06cfeda

Please sign in to comment.