diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 01b151686..bbd1a140b 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -28,6 +28,7 @@
 
 logger = AdapterLogger("Spark")
 
+GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME = "get_columns_in_relation_raw"
 LIST_SCHEMAS_MACRO_NAME = "list_schemas"
 LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
 LIST_RELATIONS_SHOW_TABLES_MACRO_NAME = "list_relations_show_tables_without_caching"
@@ -155,7 +156,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo:
 
         return RelationInfo(_schema, name, columns, table_properties)
 
-    def _parse_describe_table(
+    def _parse_describe_table_extended(
         self, table_results: agate.Table
     ) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
         # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns
@@ -195,7 +196,7 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn
             logger.debug(f"Error while retrieving information about {table_name}: {e.msg}")
             table_results = AttrDict()
 
-        columns, table_properties = self._parse_describe_table(table_results)
+        columns, table_properties = self._parse_describe_table_extended(table_results)
         return RelationInfo(_schema, name, columns, table_properties)
 
     def _build_spark_relation_list(
@@ -276,9 +277,29 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[
         return super().get_relation(database, schema, identifier)
 
     def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
-        assert isinstance(relation, SparkRelation), f"Expected SparkRelation, got: {relation}"
+        assert isinstance(relation, SparkRelation)
+        if rel_columns := relation.columns:
+            columns = rel_columns
+            properties = relation.properties
+        else:
+            try:
+                describe_extended_result = self.execute_macro(
+                    GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation}
+                )
+                columns, properties = self._parse_describe_table_extended(describe_extended_result)
+            except dbt.exceptions.DbtRuntimeError as e:
+                # Spark throws an error when the table doesn't exist, whereas other
+                # CDWs just return an empty list; normalize the behavior here
+                errmsg = getattr(e, "msg", "")
+                found_msgs = (msg in errmsg for msg in TABLE_OR_VIEW_NOT_FOUND_MESSAGES)
+                if any(found_msgs):
+                    columns = []
+                    properties = {}
+                else:
+                    raise e
+
         # Convert the Row to a dict
-        raw_table_stats = relation.properties.get(KEY_TABLE_STATISTICS)
+        raw_table_stats = properties.get(KEY_TABLE_STATISTICS)
         table_stats = SparkColumn.convert_table_stats(raw_table_stats)
         return [
             SparkColumn(
@@ -286,13 +307,13 @@ def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
                 table_schema=relation.schema,
                 table_name=relation.name,
                 table_type=relation.type,
-                table_owner=relation.properties.get(KEY_TABLE_OWNER, ""),
+                table_owner=properties.get(KEY_TABLE_OWNER, ""),
                 table_stats=table_stats,
                 column=column_name,
                 column_index=idx,
                 dtype=column_type,
             )
-            for idx, (column_name, column_type) in enumerate(relation.columns)
+            for idx, (column_name, column_type) in enumerate(columns)
             if column_name not in self.HUDI_METADATA_COLUMNS
         ]
 
@@ -385,19 +406,21 @@ def get_rows_different_sql(
         column_names: Optional[List[str]] = None,
         except_operator: str = "EXCEPT",
     ) -> str:
-        """Generate SQL for a query that returns a single row with a two
+        """Generate SQL for a query that returns a single row with two
         columns: the number of rows that are different between the two
         relations and the number of mismatched rows.
         """
         # This method only really exists for test reasons.
         names: List[str]
-        if column_names is None:
+        if not column_names:
             columns = self.get_columns_in_relation(relation_a)
             names = sorted((self.quote(c.name) for c in columns))
         else:
             names = sorted((self.quote(n) for n in column_names))
 
         columns_csv = ", ".join(names)
+        assert columns_csv, f"Could not determine columns for: {relation_a}"
+
         sql = COLUMNS_EQUAL_SQL.format(
             columns=columns_csv,
             relation_a=str(relation_a),
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 3bec3df33..aa972913e 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -323,7 +323,7 @@ def test_parse_relation(self):
 
         config = self._get_target_http(self.project_cfg)
         adapter = SparkAdapter(config)
-        columns, properties = adapter._parse_describe_table(input_cols)
+        columns, properties = adapter._parse_describe_table_extended(input_cols)
         relation_info = adapter._build_spark_relation_list(
             columns, lambda a: RelationInfo(relation.schema, relation.name, columns, properties)
         )
@@ -414,7 +414,7 @@ def test_parse_relation_with_integer_owner(self):
         ]
 
         config = self._get_target_http(self.project_cfg)
-        _, properties = SparkAdapter(config)._parse_describe_table(plain_rows)
+        _, properties = SparkAdapter(config)._parse_describe_table_extended(plain_rows)
 
         self.assertEqual(properties.get(KEY_TABLE_OWNER), "1234")
 
@@ -448,7 +448,7 @@ def test_parse_relation_with_statistics(self):
         ]
 
         config = self._get_target_http(self.project_cfg)
-        columns, properties = SparkAdapter(config)._parse_describe_table(plain_rows)
+        columns, properties = SparkAdapter(config)._parse_describe_table_extended(plain_rows)
         spark_relation = SparkRelation.create(
             schema=relation.schema,
             identifier=relation.name,
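
Note on the impl.py change above, for readers without the full file at hand: when relation.columns is empty, get_columns_in_relation now falls back to executing the get_columns_in_relation_raw macro (a DESCRIBE TABLE EXTENDED) and hands the result to the renamed _parse_describe_table_extended, which splits the rows into a column list and a table-property dict. The snippet below is a minimal, self-contained sketch of that split, for illustration only: the function name, the plain-tuple row shape, and the sample rows are hypothetical (the real method consumes an agate.Table), so read it as a guide to the behavior rather than the adapter's implementation.

    # Sketch only (not the adapter's code): split DESCRIBE TABLE EXTENDED rows
    # into (columns, properties). Row shape and sample data are hypothetical.
    from typing import Dict, List, Tuple

    def split_describe_extended(
        rows: List[Tuple[str, str, str]],
    ) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
        columns: List[Tuple[str, str]] = []
        properties: Dict[str, str] = {}
        it = iter(rows)  # single iterator, so property parsing resumes where column parsing stopped
        for col_name, data_type, _comment in it:
            # Column rows come first; a blank or "# ..." row ends the column section.
            if not col_name or col_name.startswith("#"):
                break
            columns.append((col_name, data_type))
        for key, value, _comment in it:
            # Remaining rows carry table metadata such as Owner and Statistics.
            if key and not key.startswith("#"):
                properties[key] = value
        return columns, properties

    rows = [
        ("col1", "decimal(22,0)", ""),
        ("col2", "string", ""),
        ("", "", ""),
        ("# Detailed Table Information", "", ""),
        ("Owner", "1234", ""),
        ("Statistics", "1109049927 bytes, 14093476 rows", ""),
    ]
    cols, props = split_describe_extended(rows)
    assert cols == [("col1", "decimal(22,0)"), ("col2", "string")]
    assert props["Owner"] == "1234"

The unit-test hunks are consistent with this: only the call sites are renamed to _parse_describe_table_extended, since the parsing behavior itself is unchanged.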