diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py
index eb7741c82929b..f952bf9d0421b 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py
@@ -27,11 +27,16 @@
 )
 from datahub.ingestion.api.source import MetadataWorkUnitProcessor
 from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.cassandra.cassandra_api import CassandraAPIInterface
+from datahub.ingestion.source.cassandra.cassandra_api import (
+    CassandraAPI,
+    CassandraColumn,
+    CassandraKeyspace,
+    CassandraTable,
+    CassandraView,
+)
 from datahub.ingestion.source.cassandra.cassandra_config import CassandraSourceConfig
 from datahub.ingestion.source.cassandra.cassandra_profiling import CassandraProfiler
 from datahub.ingestion.source.cassandra.cassandra_utils import (
-    COL_NAMES,
     SYSTEM_KEYSPACE_LIST,
     CassandraToSchemaFieldConverter,
 )
@@ -107,6 +112,9 @@ class CassandraEntities:
     tables: Dict[str, List[str]] = field(
         default_factory=dict
     )  # Maps keyspace -> tables
+    columns: Dict[str, List[CassandraColumn]] = field(
+        default_factory=dict
+    )  # Maps tables -> columns


 @platform_name("Cassandra")
@@ -141,7 +149,7 @@ def __init__(self, ctx: PipelineContext, config: CassandraSourceConfig):
         self.platform = PLATFORM_NAME_IN_DATAHUB
         self.config = config
         self.report = CassandraSourceReport()
-        self.cassandra_api = CassandraAPIInterface(config, self.report)
+        self.cassandra_api = CassandraAPI(config, self.report)
         self.cassandra_data = CassandraEntities()
         # For profiling
         self.profiler = CassandraProfiler(config, self.report, self.cassandra_api)
@@ -165,9 +173,9 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
     def get_workunits_internal(
         self,
     ) -> Iterable[MetadataWorkUnit]:
-        keyspaces = self.cassandra_api.get_keyspaces()
+        keyspaces: List[CassandraKeyspace] = self.cassandra_api.get_keyspaces()
         for keyspace in keyspaces:
-            keyspace_name: str = getattr(keyspace, COL_NAMES["keyspace_name"])
+            keyspace_name: str = keyspace.keyspace_name
             if keyspace_name in SYSTEM_KEYSPACE_LIST:
                 continue
@@ -175,7 +183,7 @@
                 self.report.report_dropped(keyspace_name)
                 continue

-            yield from self._generate_keyspace_container(keyspace_name)
+            yield from self._generate_keyspace_container(keyspace)

             try:
                 yield from self._extract_tables_from_keyspace(keyspace_name)
@@ -186,7 +194,6 @@ def get_workunits_internal(
                     context=keyspace_name,
                     exc=e,
                 )
-
             try:
                 yield from self._extract_views_from_keyspace(keyspace_name)
             except Exception as e:
@@ -199,15 +206,15 @@ def get_workunits_internal(
         # Profiling
         if self.config.is_profiling_enabled():
-            for keyspace in self.cassandra_data.keyspaces:
-                tables = self.cassandra_data.tables.get(keyspace, [])
-                self.report.set_ingestion_stage(keyspace, PROFILING)
+            for keyspace_name in self.cassandra_data.keyspaces:
+                tables = self.cassandra_data.tables.get(keyspace_name, [])
+                self.report.set_ingestion_stage(keyspace_name, PROFILING)
                 with ThreadPoolExecutor(
                     max_workers=self.config.profiling.max_workers
                 ) as executor:
                     future_to_dataset = {
                         executor.submit(
-                            self.generate_profiles, keyspace, table_name
+                            self.generate_profiles, keyspace_name, table_name
                         ): table_name
                         for table_name in tables
                     }
@@ -219,17 +226,24 @@ def get_workunits_internal(
                             self.report.profiling_skipped_other[table_name] += 1
                             self.report.report_failure(
                                 message="Failed to profile for table",
-                                context=f"{keyspace}.{table_name}",
+                                context=f"{keyspace_name}.{table_name}",
                                 exc=exc,
                             )

     def _generate_keyspace_container(
-        self, keyspace_name: str
+        self, keyspace: CassandraKeyspace
     ) -> Iterable[MetadataWorkUnit]:
-        keyspace_container_key = self._generate_keyspace_container_key(keyspace_name)
+        keyspace_container_key = self._generate_keyspace_container_key(
+            keyspace.keyspace_name
+        )
         yield from gen_containers(
             container_key=keyspace_container_key,
-            name=keyspace_name,
+            name=keyspace.keyspace_name,
+            qualified_name=keyspace.keyspace_name,
+            extra_properties={
+                "durable_writes": str(keyspace.durable_writes),
+                "replication": json.dumps(keyspace.replication),
+            },
             sub_types=[DatasetContainerSubTypes.KEYSPACE],
         )
@@ -246,10 +260,10 @@ def _extract_tables_from_keyspace(
         self, keyspace_name: str
     ) -> Iterable[MetadataWorkUnit]:
         self.cassandra_data.keyspaces.append(keyspace_name)
-        tables = self.cassandra_api.get_tables(keyspace_name)
+        tables: List[CassandraTable] = self.cassandra_api.get_tables(keyspace_name)
         for table in tables:
             # define the dataset urn for this table to be used downstream
-            table_name: str = getattr(table, COL_NAMES["table_name"])
+            table_name: str = table.table_name
             dataset_name: str = f"{keyspace_name}.{table_name}"

             if not self.config.table_pattern.allowed(dataset_name):
@@ -299,13 +313,25 @@ def _extract_tables_from_keyspace(
                     qualifiedName=f"{keyspace_name}.{table_name}",
                     description=table.comment,
                     customProperties={
+                        "id": str(table.id),
                         "bloom_filter_fp_chance": str(table.bloom_filter_fp_chance),
-                        "caching": str(table.caching),
-                        "cdc": str(table.cdc),
-                        "compaction": str(table.compaction),
-                        "compression": str(table.compression),
+                        "caching": json.dumps(table.caching),
+                        "compaction": json.dumps(table.compaction),
+                        "compression": json.dumps(table.compression),
+                        "crc_check_chance": str(table.crc_check_chance),
+                        "dclocal_read_repair_chance": str(
+                            table.dclocal_read_repair_chance
+                        ),
+                        "default_time_to_live": str(table.default_time_to_live),
+                        "extensions": json.dumps(table.extensions),
+                        "gc_grace_seconds": str(table.gc_grace_seconds),
                         "max_index_interval": str(table.max_index_interval),
                         "min_index_interval": str(table.min_index_interval),
+                        "memtable_flush_period_in_ms": str(
+                            table.memtable_flush_period_in_ms
+                        ),
+                        "read_repair_chance": str(table.read_repair_chance),
+                        "speculative_retry": str(table.speculative_retry),
                     },
                 ),
             ).as_workunit()
@@ -330,7 +356,9 @@ def _extract_tables_from_keyspace(
     def _extract_columns_from_table(
         self, keyspace_name: str, table_name: str, dataset_urn: str
     ) -> Iterable[MetadataWorkUnit]:
-        column_infos = self.cassandra_api.get_columns(keyspace_name, table_name)
+        column_infos: List[CassandraColumn] = self.cassandra_api.get_columns(
+            keyspace_name, table_name
+        )
         schema_fields: List[SchemaField] = list(
             CassandraToSchemaFieldConverter.get_schema_fields(column_infos)
         )
@@ -340,15 +368,11 @@ def _extract_columns_from_table(
             )
             return

-        # remove any value that is type bytes, so it can be converted to json
         jsonable_column_infos: List[Dict[str, Any]] = []
         for column in column_infos:
+            self.cassandra_data.columns.setdefault(table_name, []).append(column)
             column_dict = column._asdict()
-            jsonable_column_dict = column_dict.copy()
-            for key, value in column_dict.items():
-                if isinstance(value, bytes):
-                    jsonable_column_dict.pop(key)
-            jsonable_column_infos.append(jsonable_column_dict)
+            jsonable_column_infos.append(column_dict)

         schema_metadata: SchemaMetadata = SchemaMetadata(
             schemaName=table_name,
@@ -370,9 +394,9 @@ def _extract_views_from_keyspace(
         self, keyspace_name: str
     ) -> Iterable[MetadataWorkUnit]:
-        views = self.cassandra_api.get_views(keyspace_name)
+        views: List[CassandraView] = self.cassandra_api.get_views(keyspace_name)
         for view in views:
-            view_name: str = getattr(view, COL_NAMES["view_name"])
+            view_name: str = view.view_name
             dataset_name: str = f"{keyspace_name}.{view_name}"
             self.report.report_entity_scanned(dataset_name)
             dataset_urn: str = make_dataset_urn_with_platform_instance(
@@ -412,14 +436,26 @@ def _extract_views_from_keyspace(
                     qualifiedName=f"{keyspace_name}.{view_name}",
                     description=view.comment,
                     customProperties={
+                        "base_table_id": str(view.id),
                         "bloom_filter_fp_chance": str(view.bloom_filter_fp_chance),
-                        "caching": str(view.caching),
-                        "cdc": str(view.cdc),
-                        "compaction": str(view.compaction),
-                        "compression": str(view.compression),
+                        "caching": json.dumps(view.caching),
+                        "compaction": json.dumps(view.compaction),
+                        "compression": json.dumps(view.compression),
+                        "crc_check_chance": str(view.crc_check_chance),
+                        "include_all_columns": str(view.include_all_columns),
+                        "dclocal_read_repair_chance": str(
+                            view.dclocal_read_repair_chance
+                        ),
+                        "default_time_to_live": str(view.default_time_to_live),
+                        "extensions": json.dumps(view.extensions),
+                        "gc_grace_seconds": str(view.gc_grace_seconds),
                         "max_index_interval": str(view.max_index_interval),
                         "min_index_interval": str(view.min_index_interval),
-                        "include_all_columns": str(view.include_all_columns),
+                        "memtable_flush_period_in_ms": str(
+                            view.memtable_flush_period_in_ms
+                        ),
+                        "read_repair_chance": str(view.read_repair_chance),
+                        "speculative_retry": str(view.speculative_retry),
                     },
                 ),
             ).as_workunit()
@@ -439,12 +475,12 @@ def _extract_views_from_keyspace(
             # NOTE: we don't need to use 'base_table_id' since table is always in same keyspace, see https://docs.datastax.com/en/cql-oss/3.3/cql/cql_reference/cqlCreateMaterializedView.html#cqlCreateMaterializedView__keyspace-name
             upstream_urn: str = make_dataset_urn_with_platform_instance(
                 platform=self.platform,
-                name=f"{keyspace_name}.{getattr(view, COL_NAMES['base_table_name'])}",
+                name=f"{keyspace_name}.{view.table_name}",
                 env=self.config.env,
                 platform_instance=self.config.platform_instance,
             )
             fineGrainedLineages = self.get_upstream_fields_of_field_in_datasource(
-                keyspace_name, view_name, dataset_urn, upstream_urn
+                view_name, dataset_urn, upstream_urn
             )
             yield MetadataChangeProposalWrapper(
                 entityUrn=dataset_urn,
@@ -485,12 +521,17 @@ def generate_profiles(
             env=self.config.env,
             platform_instance=self.config.platform_instance,
         )
-        yield from self.profiler.get_workunits(dataset_urn, keyspace, table_name)
+        yield from self.profiler.get_workunits(
+            dataset_urn,
+            keyspace,
+            table_name,
+            self.cassandra_data.columns.get(table_name, []),
+        )

     def get_upstream_fields_of_field_in_datasource(
-        self, keyspace_name: str, table_name: str, dataset_urn: str, upstream_urn: str
+        self, table_name: str, dataset_urn: str, upstream_urn: str
     ) -> List[FineGrainedLineageClass]:
-        column_infos = self.cassandra_api.get_columns(keyspace_name, table_name)
+        column_infos = self.cassandra_data.columns.get(table_name, [])
         # Collect column-level lineage
         fine_grained_lineages = []
         for column_info in column_infos:
diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py
index 1e6be546131ed..08013b9fe4ecb 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_api.py
@@ -1,5 +1,6 @@
+from dataclasses import dataclass
 from ssl import CERT_NONE, PROTOCOL_TLSv1_2, SSLContext
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 from cassandra import DriverException, OperationTimedOut
 from cassandra.auth import PlainTextAuthProvider
@@ -13,13 +14,86 @@
 from datahub.ingestion.api.source import SourceReport
 from datahub.ingestion.source.cassandra.cassandra_config import CassandraSourceConfig
-from datahub.ingestion.source.cassandra.cassandra_utils import (
-    COL_NAMES,
-    CassandraQueries,
-)


-class CassandraAPIInterface:
+@dataclass
+class CassandraKeyspace:
+    keyspace_name: str
+    durable_writes: bool
+    replication: Dict
+
+
+@dataclass
+class CassandraTable:
+    keyspace_name: str
+    table_name: str
+    bloom_filter_fp_chance: Optional[float]
+    caching: Optional[Dict[str, str]]
+    comment: Optional[str]
+    compaction: Optional[Dict[str, Any]]
+    compression: Optional[Dict[str, Any]]
+    crc_check_chance: Optional[float]
+    dclocal_read_repair_chance: Optional[float]
+    default_time_to_live: Optional[int]
+    extensions: Optional[Dict[str, Any]]
+    gc_grace_seconds: Optional[int]
+    id: Optional[str]
+    max_index_interval: Optional[int]
+    memtable_flush_period_in_ms: Optional[int]
+    min_index_interval: Optional[int]
+    read_repair_chance: Optional[float]
+    speculative_retry: Optional[str]
+
+
+@dataclass
+class CassandraColumn:
+    keyspace_name: str
+    table_name: str
+    column_name: str
+    type: str
+    clustering_order: Optional[str]
+    kind: Optional[str]
+    position: Optional[int]
+
+    def _asdict(self):
+        return {
+            "keyspace_name": self.keyspace_name,
+            "table_name": self.table_name,
+            "column_name": self.column_name,
+            "clustering_order": self.clustering_order,
+            "kind": self.kind,
+            "position": self.position,
+            "type": self.type,
+        }
+
+
+@dataclass
+class CassandraView(CassandraTable):
+    view_name: str
+    include_all_columns: Optional[bool]
+    where_clause: str = ""
+
+
+# - Referencing system_schema: https://docs.datastax.com/en/cql-oss/3.x/cql/cql_using/useQuerySystem.html#Table3.ColumnsinSystem_SchemaTables-Cassandra3.0 - #
+# this keyspace contains details about the cassandra cluster's keyspaces, tables, and columns
+
+
+class CassandraQueries:
+    # get all keyspaces
+    GET_KEYSPACES_QUERY = "SELECT * FROM system_schema.keyspaces"
+    # get all tables for a keyspace
+    GET_TABLES_QUERY = "SELECT * FROM system_schema.tables WHERE keyspace_name = %s"
+    # get all columns for a table
+    GET_COLUMNS_QUERY = "SELECT * FROM system_schema.columns WHERE keyspace_name = %s AND table_name = %s"
+    # get all views for a keyspace
+    GET_VIEWS_QUERY = "SELECT * FROM system_schema.views WHERE keyspace_name = %s"
+    # Row Count
+    ROW_COUNT = 'SELECT COUNT(*) AS row_count FROM {}."{}"'
+    # Column Count
+    COLUMN_COUNT = "SELECT COUNT(*) AS column_count FROM system_schema.columns WHERE keyspace_name = '{}' AND table_name = '{}'"
+
+
+class CassandraAPI:
     def __init__(self, config: CassandraSourceConfig, report: SourceReport):
         self.config = config
         self.report = report
@@ -87,17 +161,21 @@ def authenticate(self) -> Session:
             )
             raise

-    def get_keyspaces(self) -> List:
+    def get_keyspaces(self) -> List[CassandraKeyspace]:
         """Fetch all keyspaces."""
         try:
             keyspaces = self.cassandra_session.execute(
                 CassandraQueries.GET_KEYSPACES_QUERY
             )
-            keyspaces = sorted(
-                keyspaces,
-                key=lambda k: getattr(k, COL_NAMES["keyspace_name"]),
-            )
-            return keyspaces
+            keyspace_list = [
+                CassandraKeyspace(
+                    keyspace_name=row.keyspace_name,
+                    durable_writes=row.durable_writes,
+                    replication=dict(row.replication),
+                )
+                for row in keyspaces
+            ]
+            return keyspace_list
         except DriverException as e:
             self.report.warning(
                 message="Failed to fetch keyspaces", context=f"{str(e)}", exc=e
@@ -107,17 +185,36 @@ def get_keyspaces(self) -> List:
             self.report.warning(message="Failed to fetch keyspaces", exc=e)
             return []

-    def get_tables(self, keyspace_name: str) -> List:
+    def get_tables(self, keyspace_name: str) -> List[CassandraTable]:
         """Fetch all tables for a given keyspace."""
         try:
             tables = self.cassandra_session.execute(
                 CassandraQueries.GET_TABLES_QUERY, [keyspace_name]
             )
-            tables = sorted(
-                tables,
-                key=lambda t: getattr(t, COL_NAMES["table_name"]),
-            )
-            return tables
+            table_list = [
+                CassandraTable(
+                    keyspace_name=row.keyspace_name,
+                    table_name=row.table_name,
+                    bloom_filter_fp_chance=row.bloom_filter_fp_chance,
+                    caching=dict(row.caching),
+                    comment=row.comment,
+                    compaction=dict(row.compaction),
+                    compression=dict(row.compression),
+                    crc_check_chance=row.crc_check_chance,
+                    dclocal_read_repair_chance=row.dclocal_read_repair_chance,
+                    default_time_to_live=row.default_time_to_live,
+                    extensions=dict(row.extensions),
+                    gc_grace_seconds=row.gc_grace_seconds,
+                    id=str(row.id) if row.id else None,
+                    max_index_interval=row.max_index_interval,
+                    memtable_flush_period_in_ms=row.memtable_flush_period_in_ms,
+                    min_index_interval=row.min_index_interval,
+                    read_repair_chance=row.read_repair_chance,
+                    speculative_retry=row.speculative_retry,
+                )
+                for row in tables
+            ]
+            return table_list
         except DriverException as e:
             self.report.warning(
                 message="Failed to fetch tables for keyspace",
@@ -133,14 +230,25 @@ def get_tables(self, keyspace_name: str) -> List:
             )
             return []

-    def get_columns(self, keyspace_name: str, table_name: str) -> List:
+    def get_columns(self, keyspace_name: str, table_name: str) -> List[CassandraColumn]:
         """Fetch all columns for a given table."""
         try:
             column_infos = self.cassandra_session.execute(
                 CassandraQueries.GET_COLUMNS_QUERY, [keyspace_name, table_name]
             )
-            column_infos = sorted(column_infos, key=lambda c: c.column_name)
-            return column_infos
+            column_list = [
+                CassandraColumn(
+                    keyspace_name=row.keyspace_name,
+                    table_name=row.table_name,
+                    column_name=row.column_name,
+                    clustering_order=row.clustering_order,
+                    kind=row.kind,
+                    position=row.position,
+                    type=row.type,
+                )
+                for row in column_infos
+            ]
+            return column_list
         except DriverException as e:
             self.report.warning(
                 message="Failed to fetch columns for table", context=f"{str(e)}", exc=e
@@ -154,17 +262,39 @@ def get_columns(self, keyspace_name: str, table_name: str) -> List:
             )
             return []

-    def get_views(self, keyspace_name: str) -> List:
+    def get_views(self, keyspace_name: str) -> List[CassandraView]:
         """Fetch all views for a given keyspace."""
         try:
             views = self.cassandra_session.execute(
                 CassandraQueries.GET_VIEWS_QUERY, [keyspace_name]
             )
-            views = sorted(
-                views,
-                key=lambda v: getattr(v, COL_NAMES["view_name"]),
-            )
-            return views
+            view_list = [
+                CassandraView(
+                    id=row.base_table_id,
+                    table_name=row.base_table_name,
+                    keyspace_name=row.keyspace_name,
+                    view_name=row.view_name,
+                    bloom_filter_fp_chance=row.bloom_filter_fp_chance,
+                    caching=dict(row.caching),
+                    comment=row.comment,
+                    compaction=dict(row.compaction),
+                    compression=dict(row.compression),
+                    crc_check_chance=row.crc_check_chance,
+                    dclocal_read_repair_chance=row.dclocal_read_repair_chance,
+                    default_time_to_live=row.default_time_to_live,
+                    extensions=dict(row.extensions),
+                    gc_grace_seconds=row.gc_grace_seconds,
+                    include_all_columns=row.include_all_columns,
+                    max_index_interval=row.max_index_interval,
+                    memtable_flush_period_in_ms=row.memtable_flush_period_in_ms,
+                    min_index_interval=row.min_index_interval,
+                    read_repair_chance=row.read_repair_chance,
+                    speculative_retry=row.speculative_retry,
+                    where_clause=row.where_clause,
+                )
+                for row in views
+            ]
+            return view_list
         except DriverException as e:
             self.report.warning(
                 message="Failed to fetch views for keyspace", context=f"{str(e)}", exc=e
diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py
index 6749c756bd7fa..70a313b7ccb2c 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_profiling.py
@@ -1,15 +1,19 @@
 import logging
 import time
-from typing import Any, Dict, Iterable, List, Tuple
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, List, Optional

 import numpy as np
 from cassandra.util import OrderedMapSerializedKey, SortedSet

 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.cassandra.cassandra_api import CassandraAPIInterface
+from datahub.ingestion.source.cassandra.cassandra_api import (
+    CassandraAPI,
+    CassandraColumn,
+    CassandraQueries,
+)
 from datahub.ingestion.source.cassandra.cassandra_config import CassandraSourceConfig
-from datahub.ingestion.source.cassandra.cassandra_utils import CassandraQueries
 from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport
 from datahub.metadata.schema_classes import (
     DatasetFieldProfileClass,
@@ -20,6 +24,29 @@
 logger = logging.getLogger(__name__)


+@dataclass
+class ColumnMetric:
+    col_type: str = ""
+    values: List[Any] = field(default_factory=list)
+    null_count: int = 0
+    total_count: int = 0
+    distinct_count: Optional[int] = None
+    min: Optional[Any] = None
+    max: Optional[Any] = None
+    mean: Optional[float] = None
+    stdev: Optional[float] = None
+    median: Optional[float] = None
+    quantiles: Optional[List[float]] = None
+    sample_values: Optional[Any] = None
+
+
+@dataclass
+class ProfileData:
+    row_count: Optional[int] = None
+    column_count: Optional[int] = None
+    column_metrics: Dict[str, ColumnMetric] = field(default_factory=dict)
+
+
 class CassandraProfiler:
     config: CassandraSourceConfig
     report: ProfilingSqlReport
@@ -28,17 +55,19 @@ def __init__(
         self,
         config: CassandraSourceConfig,
         report: ProfilingSqlReport,
-        api: CassandraAPIInterface,
+        api: CassandraAPI,
     ) -> None:
         self.api = api
         self.config = config
         self.report = report

     def get_workunits(
-        self, dataset_urn: str, keyspace_name: str, table_name: str
+        self,
+        dataset_urn: str,
+        keyspace_name: str,
+        table_name: str,
+        columns: List[CassandraColumn],
     ) -> Iterable[MetadataWorkUnit]:
-        columns = self.api.get_columns(keyspace_name, table_name)
-
         if not columns:
             self.report.warning(
                 message="Skipping profiling as no columns found for table",
@@ -47,8 +76,6 @@ def get_workunits(
             self.report.profiling_skipped_other[table_name] += 1
             return

-        columns = [(col.column_name, col.type) for col in columns]
-
         if not self.config.profile_pattern.allowed(f"{keyspace_name}.{table_name}"):
             self.report.profiling_skipped_table_profile_pattern[table_name] += 1
             self.report.warning(
@@ -67,130 +94,102 @@ def get_workunits(
         )
         yield mcp.as_workunit()

-    def populate_profile_aspect(self, profile_data: Dict) -> DatasetProfileClass:
+    def populate_profile_aspect(self, profile_data: ProfileData) -> DatasetProfileClass:
         field_profiles = [
-            self._create_field_profile(field_name, field_stats)
-            for field_name, field_stats in profile_data.get("column_stats", {}).items()
+            self._create_field_profile(column_name, column_metrics)
+            for column_name, column_metrics in profile_data.column_metrics.items()
         ]
         return DatasetProfileClass(
             timestampMillis=round(time.time() * 1000),
-            rowCount=profile_data.get("row_count"),
-            columnCount=profile_data.get("column_count"),
+            rowCount=profile_data.row_count,
+            columnCount=profile_data.column_count,
             fieldProfiles=field_profiles,
         )

     def _create_field_profile(
-        self, field_name: str, field_stats: Dict
+        self, field_name: str, field_stats: ColumnMetric
     ) -> DatasetFieldProfileClass:
-        quantiles = field_stats.get("quantiles")
+        quantiles = field_stats.quantiles
         return DatasetFieldProfileClass(
             fieldPath=field_name,
-            uniqueCount=field_stats.get("distinct_count"),
-            nullCount=field_stats.get("null_count"),
-            min=str(field_stats.get("min")) if field_stats.get("min") else None,
-            max=str(field_stats.get("max")) if field_stats.get("max") else None,
-            mean=str(field_stats.get("mean")) if field_stats.get("mean") else None,
-            median=str(field_stats.get("median"))
-            if field_stats.get("median")
-            else None,
-            stdev=str(field_stats.get("stdev")) if field_stats.get("stdev") else None,
+            uniqueCount=field_stats.distinct_count,
+            nullCount=field_stats.null_count,
+            min=str(field_stats.min) if field_stats.min else None,
+            max=str(field_stats.max) if field_stats.max else None,
+            mean=str(field_stats.mean) if field_stats.mean else None,
+            median=str(field_stats.median) if field_stats.median else None,
+            stdev=str(field_stats.stdev) if field_stats.stdev else None,
             quantiles=[
                 QuantileClass(quantile=str(0.25), value=str(quantiles[0])),
                 QuantileClass(quantile=str(0.75), value=str(quantiles[1])),
             ]
             if quantiles
             else None,
-            sampleValues=field_stats.get("sample_values"),
+            sampleValues=field_stats.sample_values
+            if field_stats.sample_values
+            else None,
         )

     def profile_table(
-        self, keyspace_name: str, table_name: str, columns: List[Tuple[str, str]]
-    ) -> Dict:
-
-        results: Dict[str, Any] = {}
-
-        limit = None
-        if self.config.profiling.limit:
-            limit = self.config.profiling.limit
+        self, keyspace_name: str, table_name: str, columns: List[CassandraColumn]
+    ) -> ProfileData:
+        profile_data = ProfileData()

         if self.config.profiling.row_count:
             resp = self.api.execute(
                 CassandraQueries.ROW_COUNT.format(keyspace_name, table_name)
             )
             if resp:
-                results["row_count"] = resp[0].row_count
+                profile_data.row_count = resp[0].row_count

         if self.config.profiling.column_count:
             resp = self.api.execute(
-                CassandraQueries.COLUMN_COUNT.format(keyspace_name, table_name), limit
+                CassandraQueries.COLUMN_COUNT.format(keyspace_name, table_name)
             )
             if resp:
-                results["column_count"] = resp[0].column_count
+                profile_data.column_count = resp[0].column_count

         if not self.config.profiling.profile_table_level_only:
             resp = self.api.execute(
-                f'SELECT {", ".join([col[0] for col in columns])} FROM {keyspace_name}."{table_name}"',
-                limit,
+                f'SELECT {", ".join([col.column_name for col in columns])} FROM {keyspace_name}."{table_name}"'
             )
-            results["column_metrics"] = resp
-
-        return self._parse_profile_results(results, columns)
+            profile_data.column_metrics = self._collect_column_data(resp, columns)

-    def _parse_profile_results(
-        self, results: Dict[str, Any], columns: List[Tuple[str, str]]
-    ) -> Dict:
-        profile: Dict[str, Any] = {"column_stats": {}}
+        return self._parse_profile_results(profile_data)

-        # Step 1: Parse overall profile metrics
-        self._parse_overall_metrics(results, profile)
+    def _parse_profile_results(self, profile_data: ProfileData) -> ProfileData:
+        for _, column_metrics in profile_data.column_metrics.items():
+            if column_metrics.values:
+                self._compute_field_statistics(column_metrics)

-        # Step 2: Process and parse each column
-        if results.get("column_metrics"):
-            metrics = self._initialize_metrics(columns)
-            self._collect_column_data(results, metrics, columns)
-            self._calculate_statistics(metrics, columns, profile)
+        return profile_data

-        return profile
-
-    def _parse_overall_metrics(
-        self, results: Dict[str, Any], profile: Dict[str, Any]
-    ) -> None:
-        if self.config.profiling.row_count:
-            profile["row_count"] = int(results.get("row_count", 0))
+    def _collect_column_data(
+        self, rows: List[Any], columns: List[CassandraColumn]
+    ) -> Dict[str, ColumnMetric]:
+        metrics = {column.column_name: ColumnMetric() for column in columns}

-        if self.config.profiling.column_count:
-            profile["column_count"] = int(results.get("column_count", 0))
+        for row in rows:
+            for column in columns:
+                if self._is_skippable_type(column.type):
+                    continue

-    def _initialize_metrics(
-        self, columns: List[Tuple[str, str]]
-    ) -> Dict[str, Dict[str, Any]]:
-        return {
-            column: {"values": [], "null_count": 0, "total_count": 0}
-            for column, _ in columns
-        }
+                value: Any = getattr(row, column.column_name, None)
+                metric = metrics[column.column_name]
+                metric.col_type = column.type

-    def _collect_column_data(
-        self,
-        results: Dict[str, Any],
-        metrics: Dict[str, Dict[str, Any]],
-        columns: List[Tuple[str, str]],
-    ) -> None:
-        for row in results.get("column_metrics", []):
-            for cl_name, col_type in columns:
-                if self._is_skippable_type(col_type):
-                    continue
-                value: Any = getattr(row, cl_name, None)
-                metrics[cl_name]["total_count"] += 1
-                if not value:
-                    metrics[cl_name]["null_count"] += 1
+                metric.total_count += 1
+                if value is None:
+                    metric.null_count += 1
                 else:
-                    metrics[cl_name]["values"].extend(self._parse_value(value))
+                    metric.values.extend(self._parse_value(value))
+
+        return metrics

     def _is_skippable_type(self, data_type: str) -> bool:
         return data_type.lower() in ["timeuuid", "blob", "frozen>"]

     def _parse_value(self, value: Any) -> List[Any]:
-        # NOTE for astra db column need to check type
         if isinstance(value, SortedSet):
             return list(value)
         elif isinstance(value, OrderedMapSerializedKey):
@@ -199,56 +198,40 @@ def _parse_value(self, value: Any) -> List[Any]:
             return value
         return [value]

-    def _calculate_statistics(
-        self,
-        metrics: Dict[str, Dict[str, Any]],
-        columns: List[Tuple[str, str]],
-        profile: Dict[str, Any],
-    ) -> None:
-        for column_name, data_type in columns:
-            if column_name not in metrics:
-                continue
-
-            data = metrics[column_name]
-            if not data:
-                continue
-
-            values: List[Any] = data.get("values", [])
-            column_stats: Dict[str, Any] = {}
-
-            if self.config.profiling.include_field_null_count:
-                column_stats["null_count"] = data.get("null_count", 0)
-
-            if values:
-                self._compute_field_statistics(values, data_type, column_stats)
+    def _compute_field_statistics(self, column_metrics: ColumnMetric) -> None:
+        values = column_metrics.values
+        if not values:
+            return

-            profile["column_stats"][column_name] = column_stats
+        # ByDefault Null count is added
+        if not self.config.profiling.include_field_null_count:
+            column_metrics.null_count = 0

-    def _compute_field_statistics(
-        self, values: List[Any], data_type: str, column_stats: Dict[str, Any]
-    ) -> None:
         if self.config.profiling.include_field_distinct_count:
-            column_stats["distinct_count"] = len(set(values))
+            column_metrics.distinct_count = len(set(values))

         if self.config.profiling.include_field_min_value:
-            column_stats["min"] = min(values)
+            column_metrics.min = min(values)

         if self.config.profiling.include_field_max_value:
-            column_stats["max"] = max(values)
+            column_metrics.max = max(values)

-        if self._is_numeric_type(data_type):
+        if values and self._is_numeric_type(column_metrics.col_type):
             if self.config.profiling.include_field_mean_value:
-                column_stats["mean"] = str(np.mean(values))
+                column_metrics.mean = round(float(np.mean(values)), 2)
             if self.config.profiling.include_field_stddev_value:
-                column_stats["stdev"] = str(np.std(values))
+                column_metrics.stdev = round(float(np.std(values)), 2)
             if self.config.profiling.include_field_median_value:
-                column_stats["median"] = str(np.median(values))
+                column_metrics.median = round(float(np.median(values)), 2)
             if self.config.profiling.include_field_quantiles:
-                column_stats["quantiles"] = [
-                    str(np.percentile(values, 25)),
-                    str(np.percentile(values, 75)),
+                column_metrics.quantiles = [
+                    float(np.percentile(values, 25)),
+                    float(np.percentile(values, 75)),
                 ]

+        if values and self.config.profiling.include_field_sample_values:
+            column_metrics.sample_values = [str(v) for v in values[:5]]
+
     def _is_numeric_type(self, data_type: str) -> bool:
         return data_type.lower() in [
             "int",
diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py
index c2331e7627881..4ea8a3fc1112a 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra_utils.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Dict, Generator, List, Optional, Type

+from datahub.ingestion.source.cassandra.cassandra_api import CassandraColumn
 from datahub.metadata.com.linkedin.pegasus2avro.schema import (
     SchemaField,
     SchemaFieldDataType,
@@ -26,37 +27,6 @@
 )


-# these column names are present on the system_schema tables
-COL_NAMES = {
-    "keyspace_name": "keyspace_name",  # present on all tables
-    "table_name": "table_name",  # present on tables table
-    "column_name": "column_name",  # present on columns table
-    "column_type": "type",  # present on columns table
-    "view_name": "view_name",  # present on views table
-    "base_table_name": "base_table_name",  # present on views table
-    "where_clause": "where_clause",  # present on views table
-}
-
-
-# - Referencing system_schema: https://docs.datastax.com/en/cql-oss/3.x/cql/cql_using/useQuerySystem.html#Table3.ColumnsinSystem_SchemaTables-Cassandra3.0 - #
-# this keyspace contains details about the cassandra cluster's keyspaces, tables, and columns
-
-
-class CassandraQueries:
-    # get all keyspaces
-    GET_KEYSPACES_QUERY = "SELECT * FROM system_schema.keyspaces"
-    # get all tables for a keyspace
-    GET_TABLES_QUERY = "SELECT * FROM system_schema.tables WHERE keyspace_name = %s"
-    # get all columns for a table
-    GET_COLUMNS_QUERY = "SELECT * FROM system_schema.columns WHERE keyspace_name = %s AND table_name = %s"
-    # get all views for a keyspace
-    GET_VIEWS_QUERY = "SELECT * FROM system_schema.views WHERE keyspace_name = %s"
-    # Row Count
-    ROW_COUNT = 'SELECT COUNT(*) AS row_count FROM {}."{}"'
-    # Column Count
-    COLUMN_COUNT = "SELECT COUNT(*) AS column_count FROM system_schema.columns WHERE keyspace_name = '{}' AND table_name = '{}'"
-
-
 # This class helps convert cassandra column types to SchemaFieldDataType for use by the datahaub metadata schema
 class CassandraToSchemaFieldConverter:
     # Mapping from cassandra field types to SchemaFieldDataType.
@@ -111,18 +81,12 @@ def get_column_type(cassandra_column_type: str) -> SchemaFieldDataType:
         return SchemaFieldDataType(type=type_class())

     def _get_schema_fields(
-        self, cassandra_column_infos: List
+        self, cassandra_column_infos: List[CassandraColumn]
     ) -> Generator[SchemaField, None, None]:
         # append each schema field (sort so output is consistent)
         for column_info in cassandra_column_infos:
-            # convert namedtuple to dictionary if it isn't already one
-            column_info = (
-                column_info._asdict()
-                if hasattr(column_info, "_asdict")
-                else column_info
-            )
-            column_name: str = column_info[COL_NAMES["column_name"]]
-            cassandra_type: str = column_info[COL_NAMES["column_type"]]
+            column_name: str = column_info.column_name
+            cassandra_type: str = column_info.type

             schema_field_data_type: SchemaFieldDataType = self.get_column_type(
                 cassandra_type
@@ -139,7 +103,7 @@ def _get_schema_fields(
     @classmethod
     def get_schema_fields(
-        cls, cassandra_column_infos: List
+        cls, cassandra_column_infos: List[CassandraColumn]
     ) -> Generator[SchemaField, None, None]:
         converter = cls()
         yield from converter._get_schema_fields(cassandra_column_infos)