Skip to content

Commit

Permalink
Add ids in sync-schemas endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Apr 19, 2024
1 parent f392cbd commit 27a5bfc
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 55 deletions.
61 changes: 26 additions & 35 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,10 @@
MAX_ROWS_TO_CREATE_CSV_FILE = 50


def async_scanning(scanner, database, scanner_request, storage):
def async_scanning(scanner, database, table_descriptions, storage):
scanner.scan(
database,
scanner_request.db_connection_id,
scanner_request.table_names,
table_descriptions,
TableDescriptionRepository(storage),
QueryHistoryRepository(storage),
)
Expand Down Expand Up @@ -133,43 +132,35 @@ def scan_db(
self, scanner_request: ScannerRequest, background_tasks: BackgroundTasks
) -> list[TableDescriptionResponse]:
"""Takes a db_connection_id and scan all the tables columns"""
try:
db_connection_repository = DatabaseConnectionRepository(self.storage)
scanner_repository = TableDescriptionRepository(self.storage)
data = {}
for id in scanner_request.ids:
table_description = scanner_repository.find_by_id(id)
if not table_description:
raise Exception("Table description not found")
if table_description.schema_name not in data.keys():
data[table_description.schema_name] = []
data[table_description.schema_name].append(table_description)

db_connection_repository = DatabaseConnectionRepository(self.storage)
scanner = self.system.instance(Scanner)
database_connection_service = DatabaseConnectionService(scanner, self.storage)
for schema, table_descriptions in data.items():
db_connection = db_connection_repository.find_by_id(
scanner_request.db_connection_id
table_descriptions[0].db_connection_id
)

if not db_connection:
raise DatabaseConnectionNotFoundError(
f"Database connection {scanner_request.db_connection_id} not found"
)

database = SQLDatabase.get_sql_engine(db_connection, True)
all_tables = database.get_tables_and_views()

if scanner_request.table_names:
for table in scanner_request.table_names:
if table not in all_tables:
raise HTTPException(
status_code=404,
detail=f"Table named: {table} doesn't exist",
) # noqa: B904
else:
scanner_request.table_names = all_tables

scanner = self.system.instance(Scanner)
rows = scanner.synchronizing(
scanner_request,
TableDescriptionRepository(self.storage),
database = database_connection_service.get_sql_database(
db_connection, schema
)
except Exception as e:
return error_response(e, scanner_request.dict(), "invalid_database_sync")

background_tasks.add_task(
async_scanning, scanner, database, scanner_request, self.storage
)
return [TableDescriptionResponse(**row.dict()) for row in rows]
background_tasks.add_task(
async_scanning, scanner, database, table_descriptions, self.storage
)
return [
TableDescriptionResponse(**row.dict())
for _, table_descriptions in data.items()
for row in table_descriptions
]

@override
def create_database_connection(
Expand Down
3 changes: 1 addition & 2 deletions dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ class Scanner(Component, ABC):
def scan(
self,
db_engine: SQLDatabase,
db_connection_id: str,
table_names: list[str] | None,
table_descriptions: list[TableDescription],
repository: TableDescriptionRepository,
query_history_repository: QueryHistoryRepository,
) -> None:
Expand Down
23 changes: 7 additions & 16 deletions dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,7 @@ def scan_single_table(
def scan(
self,
db_engine: SQLDatabase,
db_connection_id: str,
table_names: list[str] | None,
table_descriptions: list[TableDescription],
repository: TableDescriptionRepository,
query_history_repository: QueryHistoryRepository,
) -> None:
Expand All @@ -295,32 +294,24 @@ def scan(
if db_engine.engine.dialect.name in services.keys():
scanner_service = services[db_engine.engine.dialect.name]()

inspector = inspect(db_engine.engine)
inspect(db_engine.engine)
meta = MetaData(bind=db_engine.engine)
MetaData.reflect(meta, views=True)
tables = inspector.get_table_names() + inspector.get_view_names()
if table_names:
table_names = [table.lower() for table in table_names]
tables = [
table for table in tables if table and table.lower() in table_names
]
if len(tables) == 0:
raise ValueError("No table found")

for table in tables:
for table in table_descriptions:
try:
self.scan_single_table(
meta=meta,
table=table,
table=table.table_name,
db_engine=db_engine,
db_connection_id=db_connection_id,
db_connection_id=table.db_connection_id,
repository=repository,
scanner_service=scanner_service,
)
except Exception as e:
repository.save_table_info(
TableDescription(
db_connection_id=db_connection_id,
db_connection_id=table.db_connection_id,
table_name=table,
status=TableDescriptionStatus.FAILED.value,
error_message=f"{e}",
Expand All @@ -329,7 +320,7 @@ def scan(
try:
logger.info(f"Get logs table: {table}")
query_history = scanner_service.get_logs(
table, db_engine, db_connection_id
table.table_name, db_engine, table.db_connection_id
)
if len(query_history) > 0:
for query in query_history:
Expand Down
14 changes: 14 additions & 0 deletions dataherald/sql_database/services/database_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ def __init__(self, scanner: Scanner, storage: DB):
self.scanner = scanner
self.storage = storage

def get_sql_database(
    self, database_connection: DatabaseConnection, schema: str = None
) -> SQLDatabase:
    """Return the SQL engine for a connection, optionally scoped to a schema.

    When ``schema`` is truthy, the connection URI stored on
    ``database_connection`` is decrypted, passed through
    ``add_schema_in_uri`` together with the schema and the dialect value,
    encrypted again and written back onto ``database_connection``. The
    argument is therefore modified in place before the engine is requested.

    When ``schema`` is falsy the connection is used exactly as received.
    """
    cipher = FernetEncrypt()
    if not schema:
        return SQLDatabase.get_sql_engine(database_connection, True)

    plain_uri = cipher.decrypt(database_connection.connection_uri)
    scoped_uri = self.add_schema_in_uri(
        plain_uri, schema, database_connection.dialect.value
    )
    # Store the rewritten URI encrypted, the same form it was read in.
    database_connection.connection_uri = cipher.encrypt(scoped_uri)
    return SQLDatabase.get_sql_engine(database_connection, True)

def get_current_schema(
self, database_connection: DatabaseConnection
) -> list[str] | None:
Expand Down
13 changes: 11 additions & 2 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,19 @@ class SupportedDatabase(Enum):
BIGQUERY = "BIGQUERY"


class ScannerRequest(DBConnectionValidation):
table_names: list[str] | None
class ScannerRequest(BaseModel):
    """Payload describing which stored table descriptions should be scanned."""

    # Identifiers to scan; every entry must be parseable as an ObjectId.
    ids: list[str] | None
    metadata: dict | None

    @validator("ids")
    def ids_validation(cls, ids: list | None = None):
        """Check that every supplied id is a well-formed ObjectId.

        Returns the ids unchanged so the validated value is kept as-is.
        ``None`` is passed through untouched because the field is optional.

        Raises:
            ValueError: if any entry cannot be parsed as an ObjectId.
        """
        if ids is None:
            return ids
        try:
            for object_id in ids:
                ObjectId(object_id)
        except (InvalidId, TypeError) as err:
            # ObjectId raises TypeError for non-string input; report both
            # failure modes with the same validation message.
            raise ValueError("Must be a valid ObjectId") from err
        return ids


class DatabaseConnectionRequest(BaseModel):
alias: str
Expand Down

0 comments on commit 27a5bfc

Please sign in to comment.