From 06cfedab164a586cbeb947a25ea3b308af80da74 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Wed, 10 May 2023 15:39:38 +0200
Subject: [PATCH] Cleanup

---
 dbt/adapters/spark/impl.py | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 26b9ed1dd..72b95acca 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -24,6 +24,7 @@
 from dbt.adapters.base import BaseRelation
 from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
 from dbt.events import AdapterLogger
+from dbt.flags import get_flags
 from dbt.utils import executor, AttrDict
 
 logger = AdapterLogger("Spark")
@@ -33,6 +34,8 @@
 LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
 LIST_RELATIONS_SHOW_TABLES_MACRO_NAME = "list_relations_show_tables_without_caching"
 DESCRIBE_TABLE_EXTENDED_MACRO_NAME = "describe_table_extended_without_caching"
+DROP_RELATION_MACRO_NAME = "drop_relation"
+FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties"
 
 KEY_TABLE_OWNER = "Owner"
 KEY_TABLE_STATISTICS = "Statistics"
@@ -103,29 +106,41 @@ def date_function(cls) -> str:
         return "current_timestamp()"
 
     @classmethod
-    def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
+    def convert_text_type(cls, agate_table, col_idx):
         return "string"
 
     @classmethod
-    def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
+    def convert_number_type(cls, agate_table, col_idx):
         decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
         return "double" if decimals else "bigint"
 
     @classmethod
-    def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
+    def convert_date_type(cls, agate_table, col_idx):
         return "date"
 
     @classmethod
-    def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
+    def convert_time_type(cls, agate_table, col_idx):
         return "time"
 
     @classmethod
-    def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
+    def convert_datetime_type(cls, agate_table, col_idx):
         return "timestamp"
 
     def quote(self, identifier):
         return "`{}`".format(identifier)
 
+    def add_schema_to_cache(self, schema) -> str:
+        """Cache a new schema in dbt. It will show up in `list relations`."""
+        if schema is None:
+            name = self.nice_connection_name()
+            raise dbt.exceptions.CompilationError(
+                "Attempted to cache a null schema for {}".format(name)
+            )
+        if get_flags().USE_CACHE:
+            self.cache.add_schema(None, schema)
+        # so jinja doesn't render things
+        return ""
+
     def _get_relation_information(self, row: agate.Row) -> RelationInfo:
         """relation info was fetched with SHOW TABLES EXTENDED"""
         try:
@@ -279,9 +294,11 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[
     def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
         assert isinstance(relation, SparkRelation)
         if relation.columns is not None and len(relation.columns) > 0:
+            logger.debug(f"Using cached columns: {relation.columns}")
             columns = relation.columns
             properties = relation.properties
         else:
+            logger.debug(f"No cached columns for {relation}; describing the relation")
             try:
                 describe_extended_result = self.execute_macro(
                     GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation}
@@ -349,6 +366,12 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str,
                 as_dict["table_database"] = None
             yield as_dict
 
+    def get_properties(self, relation: Relation) -> Dict[str, str]:
+        properties = self.execute_macro(
+            FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation}
+        )
+        return dict(properties)
+
     def get_catalog(self, manifest):
         schema_map = self._get_catalog_schemas(manifest)
         if len(schema_map) > 1:
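
Note (not part of the patch): the add_schema_to_cache method added above matches the
implementation on dbt-core's SQLAdapter. It refuses to cache a null schema, touches
the cache only when the USE_CACHE flag is set, and returns an empty string so that
the Jinja call site renders nothing. Below is a minimal stub-based sketch of that
contract; FakeCache and the use_cache argument are illustrative stand-ins, not dbt
APIs:

    class FakeCache:
        """Stand-in for dbt's relations cache; only records schemas."""

        def __init__(self):
            self.schemas = set()

        def add_schema(self, database, schema):
            self.schemas.add((database, schema))


    def add_schema_to_cache(cache, schema, use_cache=True):
        if schema is None:
            # dbt raises a CompilationError here; ValueError keeps the sketch standalone
            raise ValueError("Attempted to cache a null schema")
        if use_cache:
            cache.add_schema(None, schema)
        return ""  # empty string, so a Jinja call site renders nothing


    cache = FakeCache()
    assert add_schema_to_cache(cache, "analytics") == ""
    assert (None, "analytics") in cache.schemas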
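
Note (not part of the patch): get_properties wraps the fetch_tbl_properties macro,
which runs SHOW TBLPROPERTIES against the relation, and flattens the resulting
key/value rows into a plain dict. A stub-based sketch of that round-trip follows;
FakeAdapter and the hard-coded property rows are assumptions for illustration:

    from typing import Dict


    class FakeAdapter:
        """Stand-in for SparkAdapter; mimics only the macro round-trip."""

        def execute_macro(self, macro_name, kwargs):
            # A live adapter would render and run the named macro against Spark
            # and return two-column (key, value) rows; these are hard-coded here.
            return [("delta.minReaderVersion", "1"), ("Type", "MANAGED")]

        def get_properties(self, relation) -> Dict[str, str]:
            properties = self.execute_macro(
                "fetch_tbl_properties", kwargs={"relation": relation}
            )
            return dict(properties)


    print(FakeAdapter().get_properties(relation=None))
    # prints {'delta.minReaderVersion': '1', 'Type': 'MANAGED'}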