From 809ac310957f838af6df866e294b18702069b3a3 Mon Sep 17 00:00:00 2001 From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:23:06 -0400 Subject: [PATCH] DH5669/store db dialect in database connection collection (#447) * DH5669/store db dialect in database connection collection * modify the code for the tests * test the response * DH5669/removing the validation on object for dialect * DH5669/adding the script to update dialect * DH5669/reformat with black * DH5669/rename the function to make more sense --- .../scripts/populate_dialect_db_connection.py | 25 ++++++++++++++ dataherald/sql_database/models/types.py | 33 ++++++++++++++++--- docs/api.create_database_connection.rst | 1 + docs/api.list_database_connections.rst | 2 ++ docs/api.update_database_connection.rst | 1 + 5 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 dataherald/scripts/populate_dialect_db_connection.py diff --git a/dataherald/scripts/populate_dialect_db_connection.py b/dataherald/scripts/populate_dialect_db_connection.py new file mode 100644 index 00000000..64fea056 --- /dev/null +++ b/dataherald/scripts/populate_dialect_db_connection.py @@ -0,0 +1,25 @@ +import dataherald.config +from dataherald.config import System +from dataherald.db import DB +from dataherald.sql_database.models.types import DatabaseConnection +from dataherald.utils.encrypt import FernetEncrypt + +if __name__ == "__main__": + settings = dataherald.config.Settings() + system = System(settings) + system.start() + storage = system.instance(DB) + fernet_encrypt = FernetEncrypt() + database_connections = storage.find_all("database_connections") + for database_connection in database_connections: + if not database_connection.get("dialect"): + decrypted_uri = fernet_encrypt.decrypt( + database_connection["connection_uri"] + ) + dialect_prefix = DatabaseConnection.get_dialect(decrypted_uri) + dialect = DatabaseConnection.set_dialect(dialect_prefix) + storage.update_or_create( + "database_connections", + {"_id": database_connection["_id"]}, + {"dialect": dialect}, + ) diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index c4ed5b5c..941a4390 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -1,6 +1,7 @@ import os import re from datetime import datetime +from enum import Enum from typing import Any from pydantic import BaseModel, BaseSettings, Extra, Field, validator @@ -75,9 +76,23 @@ class InvalidURIFormatError(Exception): pass +class SupportedDialects(Enum): + POSTGRES = "postgresql" + MYSQL = "mysql" + MSSQL = "mssql" + DATABRICKS = "databricks" + SNOWFLAKE = "snowflake" + CLICKHOUSE = "clickhouse" + AWSATHENA = "awsathena" + DUCKDB = "duckdb" + BIGQUERY = "bigquery" + SQLITE = "sqlite" + + class DatabaseConnection(BaseModel): id: str | None alias: str + dialect: SupportedDialects | None use_ssh: bool = False connection_uri: str | None path_to_credentials_file: str | None @@ -88,21 +103,29 @@ class DatabaseConnection(BaseModel): created_at: datetime = Field(default_factory=datetime.now) @classmethod - def validate_uri(cls, input_string): + def get_dialect(cls, input_string): pattern = r"([^:/]+):/+([^/]+)/([^/]+)" match = re.match(pattern, input_string) if not match: raise InvalidURIFormatError(f"Invalid URI format: {input_string}") + return match.group(1) + + @classmethod + def set_dialect(cls, input_string): + for dialect in SupportedDialects: + if dialect.value in input_string: + return dialect.value + return None @validator("connection_uri", pre=True, always=True) - def connection_uri_format(cls, value: str): + def connection_uri_format(cls, value: str, values): fernet_encrypt = FernetEncrypt() try: fernet_encrypt.decrypt(value) - return value except Exception: - cls.validate_uri(value) - return fernet_encrypt.encrypt(value) + dialect_prefix = cls.get_dialect(value) + values["dialect"] = cls.set_dialect(dialect_prefix) + value = fernet_encrypt.encrypt(value) return value @validator("llm_api_key", pre=True, always=True) diff --git a/docs/api.create_database_connection.rst b/docs/api.create_database_connection.rst index 79f06a92..e7ca37be 100644 --- a/docs/api.create_database_connection.rst +++ b/docs/api.create_database_connection.rst @@ -85,6 +85,7 @@ HTTP 201 code response { "id": "64f251ce9614e0e94b0520bc", "alias": "string_999", + dialect: "postgresql", "use_ssh": true, "connection_uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", "path_to_credentials_file": "string", diff --git a/docs/api.list_database_connections.rst b/docs/api.list_database_connections.rst index 02c58823..1396ee23 100644 --- a/docs/api.list_database_connections.rst +++ b/docs/api.list_database_connections.rst @@ -18,6 +18,7 @@ HTTP 200 code response { "id": "64dfa0e103f5134086f7090c", "alias": "databricks", + "dialect": "databricks", "use_ssh": false, "connection_uri": "foooAABk91Q4wjoR2h07GR7_72BdQnxi8Rm6i_EjyS-mzz_o2c3RAWaEqnlUvkK5eGD5kUfE5xheyivl1Wfbk_EM7CgV4SvdLmOOt7FJV-3kG4zAbar=", "path_to_credentials_file": null, @@ -27,6 +28,7 @@ HTTP 200 code response { "id": "64e52c5f7d6dc4bc510d6d28", "alias": "postgres", + "dialect": "postgres", "use_ssh": true, "connection_uri": null, "path_to_credentials_file": "bar-LWxPdFcjQw9lU7CeK_2ELR3jGBq0G_uQ7E2rfPLk2RcFR4aDO9e2HmeAQtVpdvtrsQ_0zjsy9q7asdsadXExYJ0g==", diff --git a/docs/api.update_database_connection.rst b/docs/api.update_database_connection.rst index 05fe2e32..f3aaa1bd 100644 --- a/docs/api.update_database_connection.rst +++ b/docs/api.update_database_connection.rst @@ -42,6 +42,7 @@ HTTP 200 code response { "id": "64f251ce9614e0e94b0520bc", "alias": "string_999", + "dialect": "sqlite", "use_ssh": false, "connection_uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", "path_to_credentials_file": "string",