fix: PR comments regarding Dataclasses
sagar-salvi-apptware committed Nov 14, 2024
1 parent cafe877 commit bb251aa
Showing 4 changed files with 347 additions and 229 deletions.
121 changes: 81 additions & 40 deletions metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py
@@ -27,11 +27,16 @@
 )
 from datahub.ingestion.api.source import MetadataWorkUnitProcessor
 from datahub.ingestion.api.workunit import MetadataWorkUnit
-from datahub.ingestion.source.cassandra.cassandra_api import CassandraAPIInterface
+from datahub.ingestion.source.cassandra.cassandra_api import (
+    CassandraAPI,
+    CassandraColumn,
+    CassandraKeyspace,
+    CassandraTable,
+    CassandraView,
+)
 from datahub.ingestion.source.cassandra.cassandra_config import CassandraSourceConfig
 from datahub.ingestion.source.cassandra.cassandra_profiling import CassandraProfiler
 from datahub.ingestion.source.cassandra.cassandra_utils import (
-    COL_NAMES,
     SYSTEM_KEYSPACE_LIST,
     CassandraToSchemaFieldConverter,
 )
@@ -107,6 +112,9 @@ class CassandraEntities:
     tables: Dict[str, List[str]] = field(
         default_factory=dict
     )  # Maps keyspace -> tables
+    columns: Dict[str, List[CassandraColumn]] = field(
+        default_factory=dict
+    )  # Maps tables -> columns


 @platform_name("Cassandra")
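
The new columns field gives CassandraEntities a per-table column cache: columns fetched once during schema extraction are reused later for lineage and profiling instead of being re-queried from Cassandra (see the get_upstream_fields_of_field_in_datasource and generate_profiles changes further down). A minimal standalone sketch of the pattern — the Entities name and table key are made up for illustration:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Entities:
    # Maps table name -> columns already fetched for that table
    columns: Dict[str, List[str]] = field(default_factory=dict)

data = Entities()
# fill while scanning the schema
data.columns.setdefault("ks1.users", []).append("user_id")
# reuse later without another round-trip to Cassandra
assert data.columns.get("ks1.users", []) == ["user_id"]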
Expand Down Expand Up @@ -141,7 +149,7 @@ def __init__(self, ctx: PipelineContext, config: CassandraSourceConfig):
self.platform = PLATFORM_NAME_IN_DATAHUB
self.config = config
self.report = CassandraSourceReport()
self.cassandra_api = CassandraAPIInterface(config, self.report)
self.cassandra_api = CassandraAPI(config, self.report)
self.cassandra_data = CassandraEntities()
# For profiling
self.profiler = CassandraProfiler(config, self.report, self.cassandra_api)
@@ -165,17 +173,17 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
     def get_workunits_internal(
         self,
     ) -> Iterable[MetadataWorkUnit]:
-        keyspaces = self.cassandra_api.get_keyspaces()
+        keyspaces: List[CassandraKeyspace] = self.cassandra_api.get_keyspaces()
         for keyspace in keyspaces:
-            keyspace_name: str = getattr(keyspace, COL_NAMES["keyspace_name"])
+            keyspace_name: str = keyspace.keyspace_name
             if keyspace_name in SYSTEM_KEYSPACE_LIST:
                 continue

             if not self.config.keyspace_pattern.allowed(keyspace_name):
                 self.report.report_dropped(keyspace_name)
                 continue

-            yield from self._generate_keyspace_container(keyspace_name)
+            yield from self._generate_keyspace_container(keyspace)

             try:
                 yield from self._extract_tables_from_keyspace(keyspace_name)
@@ -186,7 +194,6 @@ def get_workunits_internal(
                     context=keyspace_name,
                     exc=e,
                 )
-
             try:
                 yield from self._extract_views_from_keyspace(keyspace_name)
             except Exception as e:
@@ -199,15 +206,15 @@ def get_workunits_internal(

         # Profiling
         if self.config.is_profiling_enabled():
-            for keyspace in self.cassandra_data.keyspaces:
-                tables = self.cassandra_data.tables.get(keyspace, [])
-                self.report.set_ingestion_stage(keyspace, PROFILING)
+            for keyspace_name in self.cassandra_data.keyspaces:
+                tables = self.cassandra_data.tables.get(keyspace_name, [])
+                self.report.set_ingestion_stage(keyspace_name, PROFILING)
                 with ThreadPoolExecutor(
                     max_workers=self.config.profiling.max_workers
                 ) as executor:
                     future_to_dataset = {
                         executor.submit(
-                            self.generate_profiles, keyspace, table_name
+                            self.generate_profiles, keyspace_name, table_name
                         ): table_name
                         for table_name in tables
                     }
@@ -219,17 +226,24 @@ def get_workunits_internal(
                             self.report.profiling_skipped_other[table_name] += 1
                             self.report.report_failure(
                                 message="Failed to profile for table",
-                                context=f"{keyspace}.{table_name}",
+                                context=f"{keyspace_name}.{table_name}",
                                 exc=exc,
                             )

     def _generate_keyspace_container(
-        self, keyspace_name: str
+        self, keyspace: CassandraKeyspace
     ) -> Iterable[MetadataWorkUnit]:
-        keyspace_container_key = self._generate_keyspace_container_key(keyspace_name)
+        keyspace_container_key = self._generate_keyspace_container_key(
+            keyspace.keyspace_name
+        )
         yield from gen_containers(
             container_key=keyspace_container_key,
-            name=keyspace_name,
+            name=keyspace.keyspace_name,
+            qualified_name=keyspace.keyspace_name,
+            extra_properties={
+                "durable_writes": str(keyspace.durable_writes),
+                "replication": json.dumps(keyspace.replication),
+            },
             sub_types=[DatasetContainerSubTypes.KEYSPACE],
         )
@@ -246,10 +260,10 @@ def _extract_tables_from_keyspace(
         self, keyspace_name: str
     ) -> Iterable[MetadataWorkUnit]:
         self.cassandra_data.keyspaces.append(keyspace_name)
-        tables = self.cassandra_api.get_tables(keyspace_name)
+        tables: List[CassandraTable] = self.cassandra_api.get_tables(keyspace_name)
         for table in tables:
             # define the dataset urn for this table to be used downstream
-            table_name: str = getattr(table, COL_NAMES["table_name"])
+            table_name: str = table.table_name
             dataset_name: str = f"{keyspace_name}.{table_name}"

             if not self.config.table_pattern.allowed(dataset_name):
@@ -299,13 +313,25 @@ def _extract_tables_from_keyspace(
                     qualifiedName=f"{keyspace_name}.{table_name}",
                     description=table.comment,
                     customProperties={
                         "id": str(table.id),
                         "bloom_filter_fp_chance": str(table.bloom_filter_fp_chance),
-                        "caching": str(table.caching),
-                        "cdc": str(table.cdc),
-                        "compaction": str(table.compaction),
-                        "compression": str(table.compression),
+                        "caching": json.dumps(table.caching),
+                        "compaction": json.dumps(table.compaction),
+                        "compression": json.dumps(table.compression),
                         "crc_check_chance": str(table.crc_check_chance),
+                        "dclocal_read_repair_chance": str(
+                            table.dclocal_read_repair_chance
+                        ),
+                        "default_time_to_live": str(table.default_time_to_live),
+                        "extensions": json.dumps(table.extensions),
+                        "gc_grace_seconds": str(table.gc_grace_seconds),
+                        "max_index_interval": str(table.max_index_interval),
+                        "min_index_interval": str(table.min_index_interval),
+                        "memtable_flush_period_in_ms": str(
+                            table.memtable_flush_period_in_ms
+                        ),
+                        "read_repair_chance": str(table.read_repair_chance),
+                        "speculative_retry": str(table.speculative_retry),
                     },
                 ),
             ).as_workunit()
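
For context on the str -> json.dumps switch above: Cassandra returns table options such as caching, compaction, and compression as Python dicts, and str() renders those with single quotes, which is not valid JSON for downstream consumers of customProperties. A standalone illustration (not part of the changed file):

import json

caching = {"keys": "ALL", "rows_per_partition": "NONE"}
print(str(caching))         # {'keys': 'ALL', 'rows_per_partition': 'NONE'} - not valid JSON
print(json.dumps(caching))  # {"keys": "ALL", "rows_per_partition": "NONE"} - valid JSON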
@@ -330,7 +356,9 @@ def _extract_tables_from_keyspace(
     def _extract_columns_from_table(
         self, keyspace_name: str, table_name: str, dataset_urn: str
     ) -> Iterable[MetadataWorkUnit]:
-        column_infos = self.cassandra_api.get_columns(keyspace_name, table_name)
+        column_infos: List[CassandraColumn] = self.cassandra_api.get_columns(
+            keyspace_name, table_name
+        )
         schema_fields: List[SchemaField] = list(
             CassandraToSchemaFieldConverter.get_schema_fields(column_infos)
         )
@@ -340,15 +368,11 @@ def _extract_columns_from_table(
             )
             return

-        # remove any value that is type bytes, so it can be converted to json
         jsonable_column_infos: List[Dict[str, Any]] = []
         for column in column_infos:
+            self.cassandra_data.columns.setdefault(table_name, []).append(column)
             column_dict = column._asdict()
-            jsonable_column_dict = column_dict.copy()
-            for key, value in column_dict.items():
-                if isinstance(value, bytes):
-                    jsonable_column_dict.pop(key)
-            jsonable_column_infos.append(jsonable_column_dict)
+            jsonable_column_infos.append(column_dict)

         schema_metadata: SchemaMetadata = SchemaMetadata(
             schemaName=table_name,
@@ -370,9 +394,9 @@ def _extract_views_from_keyspace(
         self, keyspace_name: str
     ) -> Iterable[MetadataWorkUnit]:

-        views = self.cassandra_api.get_views(keyspace_name)
+        views: List[CassandraView] = self.cassandra_api.get_views(keyspace_name)
         for view in views:
-            view_name: str = getattr(view, COL_NAMES["view_name"])
+            view_name: str = view.view_name
             dataset_name: str = f"{keyspace_name}.{view_name}"
             self.report.report_entity_scanned(dataset_name)
             dataset_urn: str = make_dataset_urn_with_platform_instance(
@@ -412,14 +436,26 @@ def _extract_views_from_keyspace(
                     qualifiedName=f"{keyspace_name}.{view_name}",
                     description=view.comment,
                     customProperties={
                         "base_table_id": str(view.id),
                         "bloom_filter_fp_chance": str(view.bloom_filter_fp_chance),
-                        "caching": str(view.caching),
-                        "cdc": str(view.cdc),
-                        "compaction": str(view.compaction),
-                        "compression": str(view.compression),
+                        "caching": json.dumps(view.caching),
+                        "compaction": json.dumps(view.compaction),
+                        "compression": json.dumps(view.compression),
                         "crc_check_chance": str(view.crc_check_chance),
-                        "include_all_columns": str(view.include_all_columns),
+                        "dclocal_read_repair_chance": str(
+                            view.dclocal_read_repair_chance
+                        ),
+                        "default_time_to_live": str(view.default_time_to_live),
+                        "extensions": json.dumps(view.extensions),
+                        "gc_grace_seconds": str(view.gc_grace_seconds),
+                        "max_index_interval": str(view.max_index_interval),
+                        "min_index_interval": str(view.min_index_interval),
+                        "include_all_columns": str(view.include_all_columns),
+                        "memtable_flush_period_in_ms": str(
+                            view.memtable_flush_period_in_ms
+                        ),
+                        "read_repair_chance": str(view.read_repair_chance),
+                        "speculative_retry": str(view.speculative_retry),
                     },
                 ),
             ).as_workunit()
@@ -439,12 +475,12 @@ def _extract_views_from_keyspace(
             # NOTE: we don't need to use 'base_table_id' since table is always in same keyspace, see https://docs.datastax.com/en/cql-oss/3.3/cql/cql_reference/cqlCreateMaterializedView.html#cqlCreateMaterializedView__keyspace-name
             upstream_urn: str = make_dataset_urn_with_platform_instance(
                 platform=self.platform,
-                name=f"{keyspace_name}.{getattr(view, COL_NAMES['base_table_name'])}",
+                name=f"{keyspace_name}.{view.table_name}",
                 env=self.config.env,
                 platform_instance=self.config.platform_instance,
             )
             fineGrainedLineages = self.get_upstream_fields_of_field_in_datasource(
-                keyspace_name, view_name, dataset_urn, upstream_urn
+                view_name, dataset_urn, upstream_urn
             )
             yield MetadataChangeProposalWrapper(
                 entityUrn=dataset_urn,
@@ -485,12 +521,17 @@ def generate_profiles(
             env=self.config.env,
             platform_instance=self.config.platform_instance,
         )
-        yield from self.profiler.get_workunits(dataset_urn, keyspace, table_name)
+        yield from self.profiler.get_workunits(
+            dataset_urn,
+            keyspace,
+            table_name,
+            self.cassandra_data.columns.get(table_name, []),
+        )

     def get_upstream_fields_of_field_in_datasource(
-        self, keyspace_name: str, table_name: str, dataset_urn: str, upstream_urn: str
+        self, table_name: str, dataset_urn: str, upstream_urn: str
     ) -> List[FineGrainedLineageClass]:
-        column_infos = self.cassandra_api.get_columns(keyspace_name, table_name)
+        column_infos = self.cassandra_data.columns.get(table_name, [])
         # Collect column-level lineage
         fine_grained_lineages = []
         for column_info in column_infos:
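The diff relies on dataclasses defined in cassandra_api.py, one of the 4 changed files, whose diff is not shown on this page. A rough sketch of what two of them might look like — field lists are inferred from the attribute accesses above, and the _asdict helper is assumed because the ingestion code still calls column._asdict():

from dataclasses import asdict, dataclass
from typing import Any, Dict

@dataclass
class CassandraKeyspace:
    keyspace_name: str
    durable_writes: bool
    replication: Dict[str, str]

@dataclass
class CassandraColumn:
    keyspace_name: str
    table_name: str
    column_name: str
    type: str

    def _asdict(self) -> Dict[str, Any]:
        # mirror NamedTuple._asdict() so existing call sites keep working
        return asdict(self)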
