Skip to content

Commit

Permalink
Merge pull request #549 from pygeek/feat-db-opts
Browse files Browse the repository at this point in the history
#548: Added support for additional db connect options.
  • Loading branch information
zainhoda authored Jul 25, 2024
2 parents b7604e2 + c1a4275 commit 64164cf
Showing 1 changed file with 79 additions and 51 deletions.
130 changes: 79 additions & 51 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def connect_to_snowflake(
database: str,
role: Union[str, None] = None,
warehouse: Union[str, None] = None,
**kwargs
):
try:
snowflake = __import__("snowflake.connector")
Expand Down Expand Up @@ -765,7 +766,8 @@ def connect_to_snowflake(
password=password,
account=account,
database=database,
client_session_keep_alive=True
client_session_keep_alive=True,
**kwargs
)

def run_sql_snowflake(sql: str) -> pd.DataFrame:
Expand All @@ -791,13 +793,13 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame:
self.run_sql = run_sql_snowflake
self.run_sql_is_set = True

def connect_to_sqlite(self, url: str):
def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
"""
Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Args:
url (str): The URL of the database to connect to.
check_same_thread (str): Allow the connection may be accessed in multiple threads.
Returns:
None
"""
Expand All @@ -816,7 +818,11 @@ def connect_to_sqlite(self, url: str):
url = path

# Connect to the database
conn = sqlite3.connect(url, check_same_thread=False)
conn = sqlite3.connect(
url,
check_same_thread=check_same_thread,
**kwargs
)

def run_sql_sqlite(sql: str):
return pd.read_sql_query(sql, conn)
Expand All @@ -832,6 +838,7 @@ def connect_to_postgres(
user: str = None,
password: str = None,
port: int = None,
**kwargs
):
"""
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down Expand Up @@ -901,6 +908,7 @@ def connect_to_postgres(
user=user,
password=password,
port=port,
**kwargs
)
except psycopg2.Error as e:
raise ValidationError(e)
Expand Down Expand Up @@ -932,12 +940,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:


def connect_to_mysql(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
**kwargs
):

try:
Expand Down Expand Up @@ -981,12 +990,15 @@ def connect_to_mysql(
conn = None

try:
conn = pymysql.connect(host=host,
user=user,
password=password,
database=dbname,
port=port,
cursorclass=pymysql.cursors.DictCursor)
conn = pymysql.connect(
host=host,
user=user,
password=password,
database=dbname,
port=port,
cursorclass=pymysql.cursors.DictCursor,
**kwargs
)
except pymysql.Error as e:
raise ValidationError(e)

Expand Down Expand Up @@ -1016,12 +1028,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql = run_sql_mysql

def connect_to_clickhouse(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
**kwargs
):

try:
Expand Down Expand Up @@ -1071,6 +1084,7 @@ def connect_to_clickhouse(
username=user,
password=password,
database=dbname,
**kwargs
)
print(conn)
except Exception as e:
Expand All @@ -1093,10 +1107,11 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql = run_sql_clickhouse

def connect_to_oracle(
self,
user: str = None,
password: str = None,
dsn: str = None,
self,
user: str = None,
password: str = None,
dsn: str = None,
**kwargs
):

"""
Expand Down Expand Up @@ -1149,7 +1164,8 @@ def connect_to_oracle(
user=user,
password=password,
dsn=dsn,
)
**kwargs
)
except oracledb.Error as e:
raise ValidationError(e)

Expand Down Expand Up @@ -1181,7 +1197,12 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql_is_set = True
self.run_sql = run_sql_oracle

def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
def connect_to_bigquery(
self,
cred_file_path: str = None,
project_id: str = None,
**kwargs
):
"""
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
**Example:**
Expand Down Expand Up @@ -1243,7 +1264,11 @@ def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None
)

try:
conn = bigquery.Client(project=project_id, credentials=credentials)
conn = bigquery.Client(
project=project_id,
credentials=credentials,
**kwargs
)
except:
raise ImproperlyConfigured(
"Could not connect to bigquery please correct credentials"
Expand All @@ -1266,7 +1291,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql_is_set = True
self.run_sql = run_sql_bigquery

def connect_to_duckdb(self, url: str, init_sql: str = None):
def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs):
"""
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down Expand Up @@ -1304,7 +1329,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
f.write(response.content)

# Connect to the database
conn = duckdb.connect(path)
conn = duckdb.connect(path, **kwargs)
if init_sql:
conn.query(init_sql)

Expand All @@ -1315,7 +1340,7 @@ def run_sql_duckdb(sql: str):
self.run_sql = run_sql_duckdb
self.run_sql_is_set = True

def connect_to_mssql(self, odbc_conn_str: str):
def connect_to_mssql(self, odbc_conn_str: str, **kwargs):
"""
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down Expand Up @@ -1348,7 +1373,7 @@ def connect_to_mssql(self, odbc_conn_str: str):

from sqlalchemy import create_engine

engine = create_engine(connection_url)
engine = create_engine(connection_url, **kwargs)

def run_sql_mssql(sql: str):
# Execute the SQL statement and return the result as a pandas DataFrame
Expand All @@ -1363,16 +1388,17 @@ def run_sql_mssql(sql: str):
self.run_sql = run_sql_mssql
self.run_sql_is_set = True
def connect_to_presto(
self,
host: str,
catalog: str = 'hive',
schema: str = 'default',
user: str = None,
password: str = None,
port: int = None,
combined_pem_path: str = None,
protocol: str = 'https',
requests_kwargs: dict = None
self,
host: str,
catalog: str = 'hive',
schema: str = 'default',
user: str = None,
password: str = None,
port: int = None,
combined_pem_path: str = None,
protocol: str = 'https',
requests_kwargs: dict = None,
**kwargs
):
"""
Connect to a Presto database using the specified parameters.
Expand Down Expand Up @@ -1445,7 +1471,8 @@ def connect_to_presto(
schema=schema,
port=port,
protocol=protocol,
requests_kwargs=requests_kwargs)
requests_kwargs=requests_kwargs,
**kwargs)
except presto.Error as e:
raise ValidationError(e)

Expand Down Expand Up @@ -1478,13 +1505,14 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql = run_sql_presto

def connect_to_hive(
self,
host: str = None,
dbname: str = 'default',
user: str = None,
password: str = None,
port: int = None,
auth: str = 'CUSTOM'
self,
host: str = None,
dbname: str = 'default',
user: str = None,
password: str = None,
port: int = None,
auth: str = 'CUSTOM',
**kwargs
):
"""
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down

0 comments on commit 64164cf

Please sign in to comment.