From a131137a0aef1731798cc17fb58412305b9ac9c3 Mon Sep 17 00:00:00 2001
From: adam ulring
Date: Sun, 14 Jul 2024 14:28:05 -0500
Subject: [PATCH 1/5] Added:

- sqlite vector support
- duckdb vector support
- added corresponding tests to test_vanna
- updated import test
- updated pre-commit config to use ruff
- updated pyproject.toml with new duckdb dependency
---
 .pre-commit-config.yaml           |   8 +-
 pyproject.toml                    |   3 +-
 src/vanna/base/base.py            | 482 +++++++++++++++---------------
 src/vanna/duckdb/__init__.py      |   1 +
 src/vanna/duckdb/duckdb_vector.py | 195 ++++++++++++
 src/vanna/sqlite/__init__.py      |   1 +
 src/vanna/sqlite/sqlite_vector.py | 245 +++++++++++++++
 tests/test_imports.py             |   7 +-
 tests/test_vanna.py               | 419 +++++++++++++++++++++++---
 9 files changed, 1074 insertions(+), 287 deletions(-)
 create mode 100644 src/vanna/duckdb/__init__.py
 create mode 100644 src/vanna/duckdb/duckdb_vector.py
 create mode 100644 src/vanna/sqlite/__init__.py
 create mode 100644 src/vanna/sqlite/sqlite_vector.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c64ebe70..690264fe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,8 +12,8 @@ repos:
       - id: debug-statements
       - id: mixed-line-ending

-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.3.3
     hooks:
-      - id: isort
-        args: [ "--profile", "black", "--filter-files" ]
+      #- id: ruff
+      - id: ruff-format
diff --git a/pyproject.toml b/pyproject.toml
index 9b9b56f8..0a7829d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,8 @@ mysql = ["PyMySQL"]
 clickhouse = ["clickhouse_connect"]
 bigquery = ["google-cloud-bigquery"]
 snowflake = ["snowflake-connector-python"]
-duckdb = ["duckdb"]
+duckdb = ["duckdb", "fastembed"]
+sqlite = ["fastembed"]
 google = ["google-generativeai", "google-cloud-aiplatform"]
 all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
 test = ["tox"]
diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py
index 492516ea..7525e1f7 100644
--- a/src/vanna/base/base.py
+++ b/src/vanna/base/base.py
@@ -136,7 +136,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->
         llm_response = self.submit_prompt(prompt, **kwargs)
         self.log(title="LLM Response", message=llm_response)

-        if 'intermediate_sql' in llm_response:
+        if "intermediate_sql" in llm_response:
             if not allow_llm_to_see_data:
                 return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
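The new duckdb and sqlite extras declared in pyproject.toml above would be
installed with the usual pip extras syntax, e.g. pip install "vanna[sqlite]".
The base.py hunk above only reformats the existing allow_llm_to_see_data gate
on intermediate SQL; a minimal sketch of the opt-in call, assuming a trained
Vanna instance vn like the ones constructed in tests/test_vanna.py later in
this patch:

    # Without allow_llm_to_see_data=True, generate_sql() returns the refusal
    # string shown in the hunk above whenever the LLM requests intermediate SQL.
    sql = vn.generate_sql(
        question="What are the top 4 customers by sales?",
        allow_llm_to_see_data=True,  # permit intermediate queries over real rows
    )
    df = vn.run_sql(sql)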
@@ -152,7 +152,11 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> question=question, question_sql_list=question_sql_list, ddl_list=ddl_list, - doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()], + doc_list=doc_list + + [ + f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + + df.to_markdown() + ], **kwargs, ) self.log(title="Final SQL Prompt", message=prompt) @@ -161,7 +165,6 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> except Exception as e: return f"Error running intermediate SQL: {e}" - return self.extract_sql(llm_response) def extract_sql(self, llm_response: str) -> str: @@ -229,7 +232,7 @@ def is_sql_valid(self, sql: str) -> bool: parsed = sqlparse.parse(sql) for statement in parsed: - if statement.get_type() == 'SELECT': + if statement.get_type() == "SELECT": return True return False @@ -251,7 +254,7 @@ def should_generate_chart(self, df: pd.DataFrame) -> bool: bool: True if a chart should be generated, False otherwise. """ - if len(df) > 1 and df.select_dtypes(include=['number']).shape[1] > 0: + if len(df) > 1 and df.select_dtypes(include=["number"]).shape[1] > 0: return True return False @@ -282,8 +285,8 @@ def generate_followup_questions( f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" ), self.user_message( - f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." + - self._response_language() + f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." + + self._response_language() ), ] @@ -327,8 +330,8 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" ), self.user_message( - "Briefly summarize the data based on the question that was asked. 
Do not respond with any additional explanation beyond the summary." + - self._response_language() + "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." + + self._response_language() ), ] @@ -524,7 +527,7 @@ def add_sql_to_prompt( def get_sql_prompt( self, - initial_prompt : str, + initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, @@ -556,8 +559,10 @@ def get_sql_prompt( """ if initial_prompt is None: - initial_prompt = f"You are a {self.dialect} expert. " + \ - "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. " + initial_prompt = ( + f"You are a {self.dialect} expert. " + + "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. " + ) initial_prompt = self.add_ddl_to_prompt( initial_prompt, ddl_list, max_tokens=self.max_tokens @@ -764,7 +769,7 @@ def connect_to_snowflake( password=password, account=account, database=database, - client_session_keep_alive=True + client_session_keep_alive=True, ) def run_sql_snowflake(sql: str) -> pd.DataFrame: @@ -790,7 +795,7 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame: self.run_sql = run_sql_snowflake self.run_sql_is_set = True - def connect_to_sqlite(self, url: str): + def connect_to_sqlite(self, url: str = ":memory:", conn=None): """ Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -803,19 +808,24 @@ def connect_to_sqlite(self, url: str): # URL of the database to download - # Path to save the downloaded database - path = os.path.basename(urlparse(url).path) + if not conn: + path = os.path.basename(urlparse(url).path) + + if path == ":memory:" or path == "": + url = ":memory:" - # Download the database if it doesn't exist - if not os.path.exists(url): - response = requests.get(url) - response.raise_for_status() # Check that the request was successful - with open(path, "wb") as f: - f.write(response.content) - url = path + elif not os.path.exists(url): + response = requests.get(url) + response.raise_for_status() # Check that the request was successful + with open(path, "wb") as f: + f.write(response.content) + url = path + else: + raise ValidationError( + f"Invalid connection settings. Pass a valid url or sqlite conn." 
+ ) - # Connect to the database - conn = sqlite3.connect(url, check_same_thread=False) + conn = sqlite3.connect(url, check_same_thread=False) def run_sql_sqlite(sql: str): return pd.read_sql_query(sql, conn) @@ -929,16 +939,14 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_postgres - def connect_to_mysql( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, ): - try: import pymysql.cursors except ImportError: @@ -980,12 +988,14 @@ def connect_to_mysql( conn = None try: - conn = pymysql.connect(host=host, - user=user, - password=password, - database=dbname, - port=port, - cursorclass=pymysql.cursors.DictCursor) + conn = pymysql.connect( + host=host, + user=user, + password=password, + database=dbname, + port=port, + cursorclass=pymysql.cursors.DictCursor, + ) except pymysql.Error as e: raise ValidationError(e) @@ -1015,14 +1025,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: self.run_sql = run_sql_mysql def connect_to_clickhouse( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, ): - try: import clickhouse_connect except ImportError: @@ -1087,17 +1096,16 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]: except Exception as e: raise e - + self.run_sql_is_set = True self.run_sql = run_sql_clickhouse def connect_to_oracle( - self, - user: str = None, - password: str = None, - dsn: str = None, + self, + user: str = None, + password: str = None, + dsn: str = None, ): - """ Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -1117,7 +1125,6 @@ def connect_to_oracle( try: import oracledb except ImportError: - raise DependencyError( "You need to install required dependencies to execute this method," " run command: \npip install oracledb" @@ -1127,7 +1134,9 @@ def connect_to_oracle( dsn = os.getenv("DSN") if not dsn: - raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid") + raise ImproperlyConfigured( + "Please set your Oracle dsn which should include host:port/sid" + ) if not user: user = os.getenv("USER") @@ -1148,7 +1157,7 @@ def connect_to_oracle( user=user, password=password, dsn=dsn, - ) + ) except oracledb.Error as e: raise ValidationError(e) @@ -1156,7 +1165,9 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]: if conn: try: sql = sql.rstrip() - if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error. + if sql.endswith( + ";" + ): # fix for a known problem with Oracle db where an extra ; will cause an error. sql = sql[:-1] cs = conn.cursor() @@ -1265,7 +1276,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_bigquery - def connect_to_duckdb(self, url: str, init_sql: str = None): + def connect_to_duckdb(self, url: str = ":memory:", init_sql: str = None, conn=None): """ Connect to a DuckDB database. 
This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1276,6 +1287,8 @@ def connect_to_duckdb(self, url: str, init_sql: str = None): Returns: None """ + # TODO consider passing an actual duckdb conn + try: import duckdb except ImportError: @@ -1283,27 +1296,23 @@ def connect_to_duckdb(self, url: str, init_sql: str = None): "You need to install required dependencies to execute this method," " run command: \npip install vanna[duckdb]" ) - # URL of the database to download - if url == ":memory:" or url == "": - path = ":memory:" - else: - # Path to save the downloaded database - print(os.path.exists(url)) - if os.path.exists(url): - path = url - elif url.startswith("md") or url.startswith("motherduck"): - path = url + if not conn: + if url == ":memory:" or url == "": + path = ":memory:" else: - path = os.path.basename(urlparse(url).path) - # Download the database if it doesn't exist - if not os.path.exists(path): - response = requests.get(url) - response.raise_for_status() # Check that the request was successful - with open(path, "wb") as f: - f.write(response.content) - - # Connect to the database - conn = duckdb.connect(path) + if os.path.exists(url): + path = url + elif url.startswith("md") or url.startswith("motherduck"): + path = url + else: + path = os.path.basename(urlparse(url).path) + if not os.path.exists(path): + response = requests.get(url) + response.raise_for_status() + with open(path, "wb") as f: + f.write(response.content) + + conn = duckdb.connect(path) if init_sql: conn.query(init_sql) @@ -1361,19 +1370,20 @@ def run_sql_mssql(sql: str): self.dialect = "T-SQL / Microsoft SQL Server" self.run_sql = run_sql_mssql self.run_sql_is_set = True + def connect_to_presto( - self, - host: str, - catalog: str = 'hive', - schema: str = 'default', - user: str = None, - password: str = None, - port: int = None, - combined_pem_path: str = None, - protocol: str = 'https', - requests_kwargs: dict = None + self, + host: str, + catalog: str = "hive", + schema: str = "default", + user: str = None, + password: str = None, + port: int = None, + combined_pem_path: str = None, + protocol: str = "https", + requests_kwargs: dict = None, ): - """ + """ Connect to a Presto database using the specified parameters. 
Args: @@ -1393,99 +1403,101 @@ def connect_to_presto( Returns: None - """ - try: - from pyhive import presto - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install pyhive" - ) + """ + try: + from pyhive import presto + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install pyhive" + ) - if not host: - host = os.getenv("PRESTO_HOST") - - if not host: - raise ImproperlyConfigured("Please set your presto host") - - if not catalog: - catalog = os.getenv("PRESTO_CATALOG") - - if not catalog: - raise ImproperlyConfigured("Please set your presto catalog") - - if not user: - user = os.getenv("PRESTO_USER") - - if not user: - raise ImproperlyConfigured("Please set your presto user") - - if not password: - password = os.getenv("PRESTO_PASSWORD") - - if not port: - port = os.getenv("PRESTO_PORT") - - if not port: - raise ImproperlyConfigured("Please set your presto port") - - conn = None - - try: - if requests_kwargs is None and combined_pem_path is not None: - # use the combined pem file to verify the SSL connection - requests_kwargs = { - 'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证 - } - conn = presto.Connection(host=host, - username=user, - password=password, - catalog=catalog, - schema=schema, - port=port, - protocol=protocol, - requests_kwargs=requests_kwargs) - except presto.Error as e: - raise ValidationError(e) - - def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - sql = sql.rstrip() - # fix for a known problem with presto db where an extra ; will cause an error. - if sql.endswith(';'): - sql = sql[:-1] - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + if not host: + host = os.getenv("PRESTO_HOST") - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] - ) - return df + if not host: + raise ImproperlyConfigured("Please set your presto host") - except presto.Error as e: - print(e) + if not catalog: + catalog = os.getenv("PRESTO_CATALOG") + + if not catalog: + raise ImproperlyConfigured("Please set your presto catalog") + + if not user: + user = os.getenv("PRESTO_USER") + + if not user: + raise ImproperlyConfigured("Please set your presto user") + + if not password: + password = os.getenv("PRESTO_PASSWORD") + + if not port: + port = os.getenv("PRESTO_PORT") + + if not port: + raise ImproperlyConfigured("Please set your presto port") + + conn = None + + try: + if requests_kwargs is None and combined_pem_path is not None: + # use the combined pem file to verify the SSL connection + requests_kwargs = { + "verify": combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证 + } + conn = presto.Connection( + host=host, + username=user, + password=password, + catalog=catalog, + schema=schema, + port=port, + protocol=protocol, + requests_kwargs=requests_kwargs, + ) + except presto.Error as e: raise ValidationError(e) - except Exception as e: - print(e) - raise e + def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: + if conn: + try: + sql = sql.rstrip() + # fix for a known problem with presto db where an extra ; will cause an error. 
+ if sql.endswith(";"): + sql = sql[:-1] + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() + + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) + return df - self.run_sql_is_set = True - self.run_sql = run_sql_presto + except presto.Error as e: + print(e) + raise ValidationError(e) + + except Exception as e: + print(e) + raise e + + self.run_sql_is_set = True + self.run_sql = run_sql_presto def connect_to_hive( - self, - host: str = None, - dbname: str = 'default', - user: str = None, - password: str = None, - port: int = None, - auth: str = 'CUSTOM' + self, + host: str = None, + dbname: str = "default", + user: str = None, + password: str = None, + port: int = None, + auth: str = "CUSTOM", ): - """ + """ Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1499,78 +1511,80 @@ def connect_to_hive( Returns: None - """ - - try: - from pyhive import hive - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install pyhive" - ) - - if not host: - host = os.getenv("HIVE_HOST") + """ - if not host: - raise ImproperlyConfigured("Please set your hive host") + try: + from pyhive import hive + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install pyhive" + ) - if not dbname: - dbname = os.getenv("HIVE_DATABASE") + if not host: + host = os.getenv("HIVE_HOST") - if not dbname: - raise ImproperlyConfigured("Please set your hive database") + if not host: + raise ImproperlyConfigured("Please set your hive host") - if not user: - user = os.getenv("HIVE_USER") + if not dbname: + dbname = os.getenv("HIVE_DATABASE") - if not user: - raise ImproperlyConfigured("Please set your hive user") + if not dbname: + raise ImproperlyConfigured("Please set your hive database") - if not password: - password = os.getenv("HIVE_PASSWORD") + if not user: + user = os.getenv("HIVE_USER") - if not port: - port = os.getenv("HIVE_PORT") + if not user: + raise ImproperlyConfigured("Please set your hive user") - if not port: - raise ImproperlyConfigured("Please set your hive port") + if not password: + password = os.getenv("HIVE_PASSWORD") - conn = None + if not port: + port = os.getenv("HIVE_PORT") - try: - conn = hive.Connection(host=host, - username=user, - password=password, - database=dbname, - port=port, - auth=auth) - except hive.Error as e: - raise ValidationError(e) + if not port: + raise ImproperlyConfigured("Please set your hive port") - def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + conn = None - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] + try: + conn = hive.Connection( + host=host, + username=user, + password=password, + database=dbname, + port=port, + auth=auth, ) - return df - - except hive.Error as e: - print(e) + except hive.Error as e: raise ValidationError(e) - except Exception as e: - print(e) - raise e + def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]: + if conn: + try: + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() - self.run_sql_is_set = True - self.run_sql = 
run_sql_hive + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) + return df + + except hive.Error as e: + print(e) + raise ValidationError(e) + + except Exception as e: + print(e) + raise e + + self.run_sql_is_set = True + self.run_sql = run_sql_hive def run_sql(self, sql: str, **kwargs) -> pd.DataFrame: """ @@ -1628,7 +1642,9 @@ def ask( question = input("Enter a question: ") try: - sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data) + sql = self.generate_sql( + question=question, allow_llm_to_see_data=allow_llm_to_see_data + ) except Exception as e: print(e) return None, None, None @@ -1802,12 +1818,8 @@ def get_training_plan_generic(self, df) -> TrainingPlan: table_column = df.columns[ df.columns.str.lower().str.contains("table_name") ].to_list()[0] - columns = [database_column, - schema_column, - table_column] - candidates = ["column_name", - "data_type", - "comment"] + columns = [database_column, schema_column, table_column] + candidates = ["column_name", "data_type", "comment"] matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True) columns += df.columns[matches].to_list() diff --git a/src/vanna/duckdb/__init__.py b/src/vanna/duckdb/__init__.py new file mode 100644 index 00000000..3a514092 --- /dev/null +++ b/src/vanna/duckdb/__init__.py @@ -0,0 +1 @@ +from .duckdb_vector import DuckDB_VectorStore diff --git a/src/vanna/duckdb/duckdb_vector.py b/src/vanna/duckdb/duckdb_vector.py new file mode 100644 index 00000000..13021727 --- /dev/null +++ b/src/vanna/duckdb/duckdb_vector.py @@ -0,0 +1,195 @@ +import json +import os +from typing import List + +import duckdb +import numpy as np +import pandas as pd +from fastembed import TextEmbedding + +from vanna.base import VannaBase +from vanna.utils import deterministic_uuid + + +class DuckDB_VectorStore(VannaBase): + def __init__(self, config=None): + super().__init__(config=config) + if config is None: + config = {} + + path = config.get("path", ".") + database_name = config.get("database_name", "vanna.duckdb") + self.database_path = os.path.join(path, database_name) + self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10)) + self.n_results_documentation = config.get( + "n_results_documentation", config.get("n_results", 10) + ) + self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10)) + + self.model_name = self.config.get("model_name", "BAAI/bge-small-en-v1.5") + self.embedding_model = TextEmbedding(model_name=self.model_name) + + conn = duckdb.connect(database=self.database_path) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS embeddings ( + id VARCHAR, + text VARCHAR, + model VARCHAR, + vec FLOAT[384] + ); + """ + ) + conn.close() + + def generate_embedding(self, data: str) -> List[float]: + embeddings = list(self.embedding_model.embed([data])) + return embeddings[0] + + def write_embedding_to_table(self, text, id, embedding): + con = duckdb.connect(database=self.database_path) + embedding_array = np.array(embedding, dtype=np.float32).tolist() + con.execute( + "INSERT INTO embeddings (id, text, model, vec) VALUES (?, ?, ?, ?)", + [id, text, self.model_name, embedding_array], + ) + con.close() + + def add_question_sql(self, question: str, sql: str) -> str: + question_sql_json = json.dumps( + { + "question": question, + "sql": sql, + }, + ensure_ascii=False, + ) + id = deterministic_uuid(question_sql_json) + "-sql" + self.write_embedding_to_table( + 
question_sql_json, id, self.generate_embedding(question_sql_json) + ) + return id + + def add_ddl(self, ddl: str) -> str: + id = deterministic_uuid(ddl) + "-ddl" + self.write_embedding_to_table(ddl, id, self.generate_embedding(ddl)) + return id + + def add_documentation(self, documentation: str) -> str: + id = deterministic_uuid(documentation) + "-doc" + self.write_embedding_to_table( + documentation, id, self.generate_embedding(documentation) + ) + return id + + def get_training_data(self) -> pd.DataFrame: + con = duckdb.connect(database=self.database_path) + sql_data = con.execute("SELECT * FROM embeddings").fetchdf() + con.close() + + df = pd.DataFrame() + + if not sql_data.empty: + df_sql = sql_data[sql_data["id"].str.endswith("-sql")] + df_sql = pd.DataFrame( + { + "id": df_sql["id"], + "question": [json.loads(doc)["question"] for doc in df_sql["text"]], + "content": [json.loads(doc)["sql"] for doc in df_sql["text"]], + "training_data_type": "sql", + } + ) + df = pd.concat([df, df_sql]) + + df_ddl = sql_data[sql_data["id"].str.endswith("-ddl")] + df_ddl = pd.DataFrame( + { + "id": df_ddl["id"], + "question": None, + "content": df_ddl["text"], + "training_data_type": "ddl", + } + ) + df = pd.concat([df, df_ddl]) + + df_doc = sql_data[sql_data["id"].str.endswith("-doc")] + df_doc = pd.DataFrame( + { + "id": df_doc["id"], + "question": None, + "content": df_doc["text"], + "training_data_type": "documentation", + } + ) + df = pd.concat([df, df_doc]) + + return df + + def remove_training_data(self, id: str) -> bool: + con = duckdb.connect(database=self.database_path) + con.execute("DELETE FROM embeddings WHERE id = ?", [id]) + con.close() + return True + + def remove_collection(self, collection_name: str) -> bool: + suffix = {"sql": "-sql", "ddl": "-ddl", "documentation": "-doc"}.get( + collection_name, None + ) + if suffix: + con = duckdb.connect(database=self.database_path) + con.execute("DELETE FROM embeddings WHERE id LIKE ?", ["%" + suffix]) + con.close() + return True + return False + + def query_similar_embeddings(self, query_text: str, top_n: int) -> pd.DataFrame: + query_embedding = self.generate_embedding(query_text) + query_embedding_array = np.array(query_embedding, dtype=np.float32).tolist() + + con = duckdb.connect(database=self.database_path) + results = con.execute( + """ + SELECT text, array_cosine_similarity(vec, ?::FLOAT[384]) AS similarity_score + FROM embeddings + ORDER BY similarity_score DESC + LIMIT ?; + """, + [query_embedding_array, top_n], + ).fetchdf() + con.close() + return results + + def get_similar_question_sql(self, question: str) -> list: + results = self.query_similar_embeddings(question, self.n_results_sql) + similar_questions = [] + for doc in results["text"]: + try: + parsed_doc = json.loads(doc) + similar_questions.append( + {"question": parsed_doc["question"], "sql": parsed_doc["sql"]} + ) + except json.JSONDecodeError as e: + similar_questions.append(doc) + continue + return similar_questions + + def get_related_ddl(self, question: str) -> list: + results = self.query_similar_embeddings(question, self.n_results_ddl) + related_ddls = [] + for doc in results["text"]: + try: + related_ddls.append(json.loads(doc)) + except json.JSONDecodeError as e: + related_ddls.append(doc) + continue + return related_ddls + + def get_related_documentation(self, question: str) -> list: + results = self.query_similar_embeddings(question, self.n_results_documentation) + related_docs = [] + for doc in results["text"]: + try: + related_docs.append(json.loads(doc)) + 
except json.JSONDecodeError as e: + related_docs.append(doc) + continue + return related_docs diff --git a/src/vanna/sqlite/__init__.py b/src/vanna/sqlite/__init__.py new file mode 100644 index 00000000..edb9c1d0 --- /dev/null +++ b/src/vanna/sqlite/__init__.py @@ -0,0 +1 @@ +from .sqlite_vector import SQLite_VectorStore diff --git a/src/vanna/sqlite/sqlite_vector.py b/src/vanna/sqlite/sqlite_vector.py new file mode 100644 index 00000000..03e94f3e --- /dev/null +++ b/src/vanna/sqlite/sqlite_vector.py @@ -0,0 +1,245 @@ +import json +import os +import sqlite3 +from typing import List + +import numpy as np +import pandas as pd +from fastembed import TextEmbedding + +from ..base import VannaBase +from ..utils import deterministic_uuid + + +def cosine_similarity(a, b): + a = np.array(a) + b = np.array(b) + dot_product = np.dot(a, b) + magnitude_a = np.linalg.norm(a) + magnitude_b = np.linalg.norm(b) + if magnitude_a == 0 or magnitude_b == 0: + return 0.0 + return dot_product / (magnitude_a * magnitude_b) + + +def sqlite_information_schema(database_path): + conn = sqlite3.connect(database_path) + tables = pd.read_sql_query( + "SELECT name FROM sqlite_master WHERE type='table';", conn + ) + schema_info = [] + for table in tables["name"]: + table_info = pd.read_sql_query(f"PRAGMA table_info({table});", conn) + for row in table_info.itertuples(): + schema_info.append( + { + "table_catalog": "main", # SQLite does not have catalogs, using a placeholder + "table_schema": "main", # SQLite does not have schemas, using a placeholder + "table_name": table, + "column_name": row.name, + "ordinal_position": row.cid + 1, + "column_default": row.dflt_value, + "is_nullable": not row.notnull, + "data_type": row.type, + "identity_generation": None, # Placeholder as SQLite does not have this + "identity_start": None, # Placeholder as SQLite does not have this + "identity_increment": None, # Placeholder as SQLite does not have this + "identity_maximum": None, # Placeholder as SQLite does not have this + "identity_minimum": None, # Placeholder as SQLite does not have this + "identity_cycle": None, # Placeholder as SQLite does not have this + "is_generated": None, # Placeholder as SQLite does not have this + "generation_expression": None, # Placeholder as SQLite does not have this + "is_updatable": None, # Placeholder as SQLite does not have this + "COLUMN_COMMENT": None, # Placeholder for comments + } + ) + conn.close() + return pd.DataFrame(schema_info) + + +class SQLite_VectorStore(VannaBase): + def __init__(self, config=None): + super().__init__(config=config) + if config is None: + config = {} + + path = config.get("path", ".") + database_name = config.get("database_name", "vanna.sqlite") + self.database_path = os.path.join(path, database_name) + self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10)) + self.n_results_documentation = config.get( + "n_results_documentation", config.get("n_results", 10) + ) + self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10)) + self.model_name = self.config.get("model_name", "BAAI/bge-small-en-v1.5") + self.embedding_model = TextEmbedding(model_name=self.model_name) + + conn = sqlite3.connect(self.database_path) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS embeddings ( + id TEXT, + text TEXT, + model TEXT, + vec BLOB + ); + """ + ) + conn.close() + + def generate_embedding(self, data: str) -> List[float]: + embeddings = list(self.embedding_model.embed([data])) + return embeddings[0] + + def write_embedding_to_table(self, text, 
id, embedding): + con = sqlite3.connect(self.database_path) + embedding_array = np.array(embedding, dtype=np.float32).tobytes() + con.execute( + "INSERT INTO embeddings (id, text, model, vec) VALUES (?, ?, ?, ?)", + (id, text, self.model_name, embedding_array), + ) + con.commit() + con.close() + + def add_question_sql(self, question: str, sql: str) -> str: + question_sql_json = json.dumps( + { + "question": question, + "sql": sql, + }, + ensure_ascii=False, + ) + id = deterministic_uuid(question_sql_json) + "-sql" + self.write_embedding_to_table( + question_sql_json, id, self.generate_embedding(question_sql_json) + ) + return id + + def add_ddl(self, ddl: str) -> str: + id = deterministic_uuid(ddl) + "-ddl" + self.write_embedding_to_table(ddl, id, self.generate_embedding(ddl)) + return id + + def add_documentation(self, documentation: str) -> str: + id = deterministic_uuid(documentation) + "-doc" + self.write_embedding_to_table( + documentation, id, self.generate_embedding(documentation) + ) + return id + + def get_training_data(self) -> pd.DataFrame: + con = sqlite3.connect(self.database_path) + sql_data = pd.read_sql_query("SELECT * FROM embeddings", con) + con.close() + + df = pd.DataFrame() + + if not sql_data.empty: + df_sql = sql_data[sql_data["id"].str.endswith("-sql")] + df_sql = pd.DataFrame( + { + "id": df_sql["id"], + "question": [json.loads(doc)["question"] for doc in df_sql["text"]], + "content": [json.loads(doc)["sql"] for doc in df_sql["text"]], + "training_data_type": "sql", + } + ) + df = pd.concat([df, df_sql]) + + df_ddl = sql_data[sql_data["id"].str.endswith("-ddl")] + df_ddl = pd.DataFrame( + { + "id": df_ddl["id"], + "question": None, + "content": df_ddl["text"], + "training_data_type": "ddl", + } + ) + df = pd.concat([df, df_ddl]) + + df_doc = sql_data[sql_data["id"].str.endswith("-doc")] + df_doc = pd.DataFrame( + { + "id": df_doc["id"], + "question": None, + "content": df_doc["text"], + "training_data_type": "documentation", + } + ) + df = pd.concat([df, df_doc]) + + return df + + def remove_training_data(self, id: str) -> bool: + con = sqlite3.connect(self.database_path) + con.execute("DELETE FROM embeddings WHERE id = ?", (id,)) + con.commit() + con.close() + return True + + def remove_collection(self, collection_name: str) -> bool: + suffix = {"sql": "%-sql", "ddl": "%-ddl", "documentation": "%-doc"}.get( + collection_name, None + ) + if suffix: + con = sqlite3.connect(self.database_path) + con.execute("DELETE FROM embeddings WHERE id LIKE ?", (suffix,)) + con.commit() + con.close() + return True + return False + + def query_similar_embeddings(self, query_text: str, top_n: int) -> pd.DataFrame: + query_embedding = self.generate_embedding(query_text) + + con = sqlite3.connect(self.database_path) + embeddings_data = pd.read_sql_query("SELECT id, text, vec FROM embeddings", con) + con.close() + + embeddings_data["vec"] = embeddings_data["vec"].apply( + lambda x: np.frombuffer(x, dtype=np.float32) + ) + embeddings_data["similarity_score"] = embeddings_data["vec"].apply( + lambda vec: cosine_similarity(query_embedding, vec) + ) + + sorted_results = embeddings_data.sort_values( + by="similarity_score", ascending=False + ).head(top_n) + return sorted_results[["text", "similarity_score"]] + + def get_similar_question_sql(self, question: str) -> list: + results = self.query_similar_embeddings(question, self.n_results_sql) + similar_questions = [] + for doc in results["text"]: + try: + if self._is_json(doc): + parsed_doc = json.loads(doc) + similar_questions.append( + 
{"question": parsed_doc["question"], "sql": parsed_doc["sql"]} + ) + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e} - Document: {doc}") + continue + return similar_questions + + def get_related_ddl(self, question: str) -> list: + results = self.query_similar_embeddings(question, self.n_results_ddl) + related_ddls = [] + for doc in results["text"]: + related_ddls.append(doc) + return related_ddls + + def get_related_documentation(self, question: str) -> list: + results = self.query_similar_embeddings(question, self.n_results_documentation) + related_docs = [] + for doc in results["text"]: + related_docs.append(doc) + return related_docs + + def _is_json(self, text: str) -> bool: + try: + json.loads(text) + except ValueError: + return False + return True diff --git a/tests/test_imports.py b/tests/test_imports.py index ee240f79..373d6d8a 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,9 +1,8 @@ - - def test_regular_imports(): from vanna.anthropic.anthropic_chat import Anthropic_Chat from vanna.base.base import VannaBase from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore + from vanna.duckdb.duckdb_vector import DuckDB_VectorStore from vanna.hf.hf import Hf from vanna.local import LocalContext_OpenAI from vanna.marqo.marqo import Marqo_VectorStore @@ -15,15 +14,18 @@ def test_regular_imports(): from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore from vanna.remote import VannaDefault + from vanna.sqlite.sqlite_vector import SQLite_VectorStore from vanna.vannadb.vannadb_vector import VannaDB_VectorStore from vanna.weaviate.weaviate_vector import WeaviateDatabase from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings + def test_shortcut_imports(): from vanna.anthropic import Anthropic_Chat from vanna.base import VannaBase from vanna.chromadb import ChromaDB_VectorStore + from vanna.duckdb import DuckDB_VectorStore from vanna.hf import Hf from vanna.marqo import Marqo_VectorStore from vanna.milvus import Milvus_VectorStore @@ -32,6 +34,7 @@ def test_shortcut_imports(): from vanna.openai import OpenAI_Chat, OpenAI_Embeddings from vanna.opensearch import OpenSearch_VectorStore from vanna.pinecone import PineconeDB_VectorStore + from vanna.sqlite import SQLite_VectorStore from vanna.vannadb import VannaDB_VectorStore from vanna.vllm import Vllm from vanna.weaviate import WeaviateDatabase diff --git a/tests/test_vanna.py b/tests/test_vanna.py index 67e34d7b..703bb4b2 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -1,4 +1,9 @@ import os +import sys +from tempfile import TemporaryDirectory, NamedTemporaryFile + +import requests +import pytest from vanna.anthropic.anthropic_chat import Anthropic_Chat from vanna.google import GoogleGeminiChat @@ -10,55 +15,95 @@ try: print("Trying to load .env") from dotenv import load_dotenv + load_dotenv() except Exception as e: print(f"Failed to load .env {e}") pass -MY_VANNA_MODEL = 'chinook' -ANTHROPIC_Model = 'claude-3-sonnet-20240229' -MY_VANNA_API_KEY = os.environ['VANNA_API_KEY'] -OPENAI_API_KEY = os.environ['OPENAI_API_KEY'] -MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY'] -ANTHROPIC_API_KEY = os.environ['ANTHROPIC_API_KEY'] -SNOWFLAKE_ACCOUNT = os.environ['SNOWFLAKE_ACCOUNT'] -SNOWFLAKE_USERNAME = os.environ['SNOWFLAKE_USERNAME'] -SNOWFLAKE_PASSWORD = os.environ['SNOWFLAKE_PASSWORD'] +MY_VANNA_MODEL = "chinook" +ANTHROPIC_Model = 
"claude-3-sonnet-20240229" +MY_VANNA_API_KEY = os.environ.get("VANNA_API_KEY") +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") +MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY") +ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY") +SNOWFLAKE_ACCOUNT = os.environ.get("SNOWFLAKE_ACCOUNT") +SNOWFLAKE_USERNAME = os.environ.get("SNOWFLAKE_USERNAME") +SNOWFLAKE_PASSWORD = os.environ.get("SNOWFLAKE_PASSWORD") +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") + + +def is_k_option_passed(): + return "-k" in sys.argv + + +@pytest.hookimpl(tryfirst=True) +def pytest_sessionstart(session): + if not is_k_option_passed(): + assert MY_VANNA_API_KEY is not None, "VANNA_API_KEY is not set" + assert OPENAI_API_KEY is not None, "OPENAI_API_KEY is not set" + assert MISTRAL_API_KEY is not None, "MISTRAL_API_KEY is not set" + assert ANTHROPIC_API_KEY is not None, "ANTHROPIC_API_KEY is not set" + assert SNOWFLAKE_ACCOUNT is not None, "SNOWFLAKE_ACCOUNT is not set" + assert SNOWFLAKE_USERNAME is not None, "SNOWFLAKE_USERNAME is not set" + assert SNOWFLAKE_PASSWORD is not None, "SNOWFLAKE_PASSWORD is not set" + assert GEMINI_API_KEY is not None, "GEMINI_API_KEY is not set" + class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat): def __init__(self, config=None): - VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) + VannaDB_VectorStore.__init__( + self, + vanna_model=MY_VANNA_MODEL, + vanna_api_key=MY_VANNA_API_KEY, + config=config, + ) OpenAI_Chat.__init__(self, config=config) -vn_openai = VannaOpenAI(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) -vn_openai.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +vn_openai = VannaOpenAI(config={"api_key": OPENAI_API_KEY, "model": "gpt-3.5-turbo"}) +vn_openai.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_openai(): sql = vn_openai.generate_sql("What are the top 4 customers by sales?") df = vn_openai.run_sql(sql) assert len(df) == 4 + class VannaMistral(VannaDB_VectorStore, Mistral): def __init__(self, config=None): - VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) - Mistral.__init__(self, config={'api_key': MISTRAL_API_KEY, 'model': 'mistral-tiny'}) + VannaDB_VectorStore.__init__( + self, + vanna_model=MY_VANNA_MODEL, + vanna_api_key=MY_VANNA_API_KEY, + config=config, + ) + Mistral.__init__( + self, config={"api_key": MISTRAL_API_KEY, "model": "mistral-tiny"} + ) + vn_mistral = VannaMistral() -vn_mistral.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') +vn_mistral.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_mistral(): sql = vn_mistral.generate_sql("What are the top 5 customers by sales?") df = vn_mistral.run_sql(sql) assert len(df) == 5 + vn_default = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY) -vn_default.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') +vn_default.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_default(): sql = vn_default.generate_sql("What are the top 6 customers by sales?") df = vn_default.run_sql(sql) assert len(df) == 6 + from vanna.qdrant import Qdrant_VectorStore @@ -67,23 +112,34 @@ def __init__(self, config=None): Qdrant_VectorStore.__init__(self, config=config) OpenAI_Chat.__init__(self, config=config) + from qdrant_client import QdrantClient qdrant_memory_client = QdrantClient(":memory:") -vn_qdrant = VannaQdrant(config={'client': qdrant_memory_client, 'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) 
-vn_qdrant.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') +vn_qdrant = VannaQdrant( + config={ + "client": qdrant_memory_client, + "api_key": OPENAI_API_KEY, + "model": "gpt-3.5-turbo", + } +) +vn_qdrant.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_qdrant(): - df_ddl = vn_qdrant.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") + df_ddl = vn_qdrant.run_sql( + "SELECT type, sql FROM sqlite_master WHERE sql is not null" + ) - for ddl in df_ddl['sql'].to_list(): + for ddl in df_ddl["sql"].to_list(): vn_qdrant.train(ddl=ddl) sql = vn_qdrant.generate_sql("What are the top 7 customers by sales?") df = vn_qdrant.run_sql(sql) assert len(df) == 7 + from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore from vanna.openai.openai_chat import OpenAI_Chat @@ -93,18 +149,22 @@ def __init__(self, config=None): ChromaDB_VectorStore.__init__(self, config=config) OpenAI_Chat.__init__(self, config=config) -vn_chroma = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) -vn_chroma.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +vn_chroma = MyVanna(config={"api_key": OPENAI_API_KEY, "model": "gpt-3.5-turbo"}) +vn_chroma.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_chroma(): existing_training_data = vn_chroma.get_training_data() if len(existing_training_data) > 0: for _, training_data in existing_training_data.iterrows(): - vn_chroma.remove_training_data(training_data['id']) + vn_chroma.remove_training_data(training_data["id"]) - df_ddl = vn_chroma.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") + df_ddl = vn_chroma.run_sql( + "SELECT type, sql FROM sqlite_master WHERE sql is not null" + ) - for ddl in df_ddl['sql'].to_list(): + for ddl in df_ddl["sql"].to_list(): vn_chroma.train(ddl=ddl) sql = vn_chroma.generate_sql("What are the top 7 customers by sales?") @@ -112,6 +172,231 @@ def test_vn_chroma(): assert len(df) == 7 +from vanna.duckdb.duckdb_vector import DuckDB_VectorStore + + +class MyVannaDuckDb(DuckDB_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + DuckDB_VectorStore.__init__(self, config=config) + OpenAI_Chat.__init__(self, config=config) + + +def test_vn_duckdb(): + import duckdb + + with TemporaryDirectory() as temp_dir: + vn_duckdb = MyVannaDuckDb( + config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "path": temp_dir} + ) + database_path = os.path.join(temp_dir, "vanna.duckdb") + conn = duckdb.connect(database=database_path) + conn.close() + vn_duckdb.connect_to_duckdb(database_path) + _ = vn_duckdb.get_training_data() + + conn = duckdb.connect(database=database_path) + + employee_ddl = """ + CREATE TABLE employee ( + employee_id INTEGER, + name VARCHAR, + occupation VARCHAR + ); + """ + + conn.execute(employee_ddl) + + conn.execute( + """ + INSERT INTO employee VALUES + (1, 'Alice Johnson', 'Software Engineer'), + (2, 'Bob Smith', 'Data Scientist'), + (3, 'Charlie Brown', 'Product Manager'), + (4, 'Diana Prince', 'UX Designer'), + (5, 'Ethan Hunt', 'DevOps Engineer'); + """ + ) + conn.commit() + df_information_schema = vn_duckdb.run_sql( + "SELECT * FROM INFORMATION_SCHEMA.COLUMNS" + ) + plan = vn_duckdb.get_training_plan_generic(df_information_schema) + vn_duckdb.train(plan=plan) + + vn_duckdb.train(ddl=employee_ddl) + training_data = vn_duckdb.get_training_data() + assert not training_data.empty + + similar_query = vn_duckdb.query_similar_embeddings("employee id", 3) + assert not similar_query.empty + + sql = vn_duckdb.generate_sql( + question="write 
a query to get all software engineers from the employees table", + allow_llm_to_see_data=True, + ) + df = vn_duckdb.run_sql(sql) + assert df.name[0] == "Alice Johnson" + del vn_duckdb + ################################################################################ + + with NamedTemporaryFile( + suffix=".sqlite", delete=False + ) as temp_sqlite_file, TemporaryDirectory() as temp_duckdb_dir: + database_path = os.path.join(temp_duckdb_dir, "vanna.duckdb") + response = requests.get("https://vanna.ai/Chinook.sqlite") + response.raise_for_status() + temp_sqlite_file.write(response.content) + temp_sqlite_file.flush() + print(f"Downloaded SQLite database to {temp_sqlite_file.name}") + + duckdb_conn = duckdb.connect(database_path) + duckdb_conn.execute("INSTALL sqlite;") + duckdb_conn.execute("LOAD sqlite;") + duckdb_conn.execute( + f"ATTACH '{temp_sqlite_file.name}' AS sqlite_db (TYPE sqlite);" + ) + duckdb_conn.execute("USE sqlite_db;") + print("Fetching list of tables from attached SQLite database...") + tables = duckdb_conn.execute("SHOW TABLES;").fetchall() + print(f"Tables found: {tables}") + + ddls = [] + for table in tables: + table_name = table[0] + duckdb_conn.execute( + f"CREATE TABLE IF NOT EXISTS main.{table_name} AS SELECT * FROM sqlite_db.{table_name}" + ) + print(f"Copied table {table_name} to DuckDB.") + ddl_query = f"DESCRIBE {table_name};" + columns_info = duckdb_conn.execute(ddl_query).fetchall() + ddl = f"CREATE TABLE {table_name} (\n" + ddl += ",\n".join([f" {col[0]} {col[1]}" for col in columns_info]) + ddl += "\n);" + print(f"DDL for table {table_name}:\n{ddl}") + ddls.append(ddl) + + duckdb_tables = duckdb_conn.execute("SHOW TABLES").fetchall() + duckdb_conn.commit() + print("Tables in DuckDB:", duckdb_tables) + + vn_duckdb = MyVannaDuckDb( + config={ + "api_key": OPENAI_API_KEY, + "model": "gpt-4-turbo", + "path": temp_duckdb_dir, + } + ) + vn_duckdb.connect_to_duckdb(conn=duckdb_conn) + df_information_schema = vn_duckdb.run_sql( + "select * from vanna.information_schema.tables" + ) + print(df_information_schema) + df_information_columns = vn_duckdb.run_sql( + "select * from vanna.information_schema.columns" + ) + plan_schema = vn_duckdb.get_training_plan_generic(df_information_schema) + plan_columns = vn_duckdb.get_training_plan_generic(df_information_columns) + vn_duckdb.train(plan=plan_schema) + vn_duckdb.train(plan=plan_columns) + + for ddl in ddls: + vn_duckdb.train(ddl=ddl) + + sql = vn_duckdb.generate_sql("What are the top 7 customers by sales?") + df = vn_duckdb.run_sql(sql) + assert len(df) == 7 + + duckdb_conn.close() + + +from vanna.sqlite.sqlite_vector import ( + SQLite_VectorStore, + sqlite_information_schema, +) + + +class MyVannaSqlite(SQLite_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + SQLite_VectorStore.__init__(self, config=config) + OpenAI_Chat.__init__(self, config=config) + + +def test_vn_sqlite(): + import sqlite3 + + with TemporaryDirectory() as temp_dir: + vn_sqlite = MyVannaSqlite( + config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "path": temp_dir} + ) + database_path = os.path.join(temp_dir, "vanna.sqlite") + conn = sqlite3.connect(database_path) + + employee_ddl = """ + CREATE TABLE employee ( + employee_id INTEGER, + name VARCHAR, + occupation VARCHAR + ); + """ + + conn.execute(employee_ddl) + + conn.execute( + """ + INSERT INTO employee VALUES + (1, 'Alice Johnson', 'Software Engineer'), + (2, 'Bob Smith', 'Data Scientist'), + (3, 'Charlie Brown', 'Product Manager'), + (4, 'Diana Prince', 'UX Designer'), + (5, 
'Ethan Hunt', 'DevOps Engineer'); + """ + ) + conn.commit() + conn.close() + + vn_sqlite.connect_to_sqlite(database_path) + df_information_schema = sqlite_information_schema(database_path) + plan = vn_sqlite.get_training_plan_generic(df_information_schema) + vn_sqlite.train(plan=plan) + + vn_sqlite.train(ddl=employee_ddl) + training_data = vn_sqlite.get_training_data() + assert not training_data.empty + + similar_query = vn_sqlite.query_similar_embeddings("employee id", 3) + assert not similar_query.empty + + sql = vn_sqlite.generate_sql( + question="write a query to get all software engineers from the employees table", + allow_llm_to_see_data=True, + ) + df = vn_sqlite.run_sql(sql) + assert df.name[0] == "Alice Johnson" + + del vn_sqlite + ############################################################# + + vn_sqlite = MyVannaSqlite( + config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo"} + ) + vn_sqlite.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + existing_training_data = vn_sqlite.get_training_data() + if len(existing_training_data) > 0: + for _, training_data in existing_training_data.iterrows(): + vn_sqlite.remove_training_data(training_data["id"]) + + df_ddl = vn_sqlite.run_sql( + "SELECT type, sql FROM sqlite_master WHERE sql is not null" + ) + + for ddl in df_ddl["sql"].to_list(): + vn_sqlite.train(ddl=ddl) + + sql = vn_sqlite.generate_sql("What are the top 7 customers by sales?") + df = vn_sqlite.run_sql(sql) + assert len(df) == 7 + + from vanna.milvus import Milvus_VectorStore @@ -120,18 +405,22 @@ def __init__(self, config=None): Milvus_VectorStore.__init__(self, config=config) OpenAI_Chat.__init__(self, config=config) -vn_milvus = VannaMilvus(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) -vn_milvus.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +vn_milvus = VannaMilvus(config={"api_key": OPENAI_API_KEY, "model": "gpt-3.5-turbo"}) +vn_milvus.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_milvus(): existing_training_data = vn_milvus.get_training_data() if len(existing_training_data) > 0: for _, training_data in existing_training_data.iterrows(): - vn_milvus.remove_training_data(training_data['id']) + vn_milvus.remove_training_data(training_data["id"]) - df_ddl = vn_milvus.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") + df_ddl = vn_milvus.run_sql( + "SELECT type, sql FROM sqlite_master WHERE sql is not null" + ) - for ddl in df_ddl['sql'].to_list(): + for ddl in df_ddl["sql"].to_list(): vn_milvus.train(ddl=ddl) sql = vn_milvus.generate_sql("What are the top 7 customers by sales?") @@ -144,14 +433,31 @@ def __init__(self, config=None): ChromaDB_VectorStore.__init__(self, config=config) OpenAI_Chat.__init__(self, config=config) -vn_chroma_n_results = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results': 1}) -vn_chroma_n_results_ddl = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_ddl': 2}) -vn_chroma_n_results_sql = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_sql': 3}) -vn_chroma_n_results_documentation = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo', 'n_results_documentation': 4}) + +vn_chroma_n_results = MyVanna( + config={"api_key": OPENAI_API_KEY, "model": "gpt-3.5-turbo", "n_results": 1} +) +vn_chroma_n_results_ddl = MyVanna( + config={"api_key": OPENAI_API_KEY, "model": "gpt-3.5-turbo", "n_results_ddl": 2} +) +vn_chroma_n_results_sql = MyVanna( + config={"api_key": OPENAI_API_KEY, 
"model": "gpt-3.5-turbo", "n_results_sql": 3} +) +vn_chroma_n_results_documentation = MyVanna( + config={ + "api_key": OPENAI_API_KEY, + "model": "gpt-3.5-turbo", + "n_results_documentation": 4, + } +) + def test_n_results(): for i in range(1, 10): - vn_chroma.train(question=f"What are the total sales for customer {i}?", sql=f"SELECT SUM(sales) FROM example_sales WHERE customer_id = {i}") + vn_chroma.train( + question=f"What are the total sales for customer {i}?", + sql=f"SELECT SUM(sales) FROM example_sales WHERE customer_id = {i}", + ) for i in range(1, 10): vn_chroma.train(documentation=f"Sample documentation {i}") @@ -170,17 +476,29 @@ def test_n_results(): assert len(vn_chroma_n_results_sql.get_similar_question_sql(question)) == 3 assert len(vn_chroma_n_results_documentation.get_related_ddl(question)) != 4 - assert len(vn_chroma_n_results_documentation.get_related_documentation(question)) == 4 - assert len(vn_chroma_n_results_documentation.get_similar_question_sql(question)) != 4 + assert ( + len(vn_chroma_n_results_documentation.get_related_documentation(question)) == 4 + ) + assert ( + len(vn_chroma_n_results_documentation.get_similar_question_sql(question)) != 4 + ) + class VannaClaude(VannaDB_VectorStore, Anthropic_Chat): def __init__(self, config=None): - VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) - Anthropic_Chat.__init__(self, config={'api_key': ANTHROPIC_API_KEY, 'model': ANTHROPIC_Model}) + VannaDB_VectorStore.__init__( + self, + vanna_model=MY_VANNA_MODEL, + vanna_api_key=MY_VANNA_API_KEY, + config=config, + ) + Anthropic_Chat.__init__( + self, config={"api_key": ANTHROPIC_API_KEY, "model": ANTHROPIC_Model} + ) vn_claude = VannaClaude() -vn_claude.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') +vn_claude.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") def test_vn_claude(): @@ -188,19 +506,28 @@ def test_vn_claude(): df = vn_claude.run_sql(sql) assert len(df) == 8 + class VannaGemini(VannaDB_VectorStore, GoogleGeminiChat): def __init__(self, config=None): - VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) + VannaDB_VectorStore.__init__( + self, + vanna_model=MY_VANNA_MODEL, + vanna_api_key=MY_VANNA_API_KEY, + config=config, + ) GoogleGeminiChat.__init__(self, config=config) -vn_gemini = VannaGemini(config={'api_key': os.environ['GEMINI_API_KEY']}) -vn_gemini.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +vn_gemini = VannaGemini(config={"api_key": GEMINI_API_KEY}) +vn_gemini.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + def test_vn_gemini(): sql = vn_gemini.generate_sql("What are the top 9 customers by sales?") df = vn_gemini.run_sql(sql) assert len(df) == 9 + def test_training_plan(): vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY) @@ -208,10 +535,12 @@ def test_training_plan(): account=SNOWFLAKE_ACCOUNT, username=SNOWFLAKE_USERNAME, password=SNOWFLAKE_PASSWORD, - database='SNOWFLAKE_SAMPLE_DATA', + database="SNOWFLAKE_SAMPLE_DATA", ) - df_information_schema = vn_dummy.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = 'TPCH_SF1' ") + df_information_schema = vn_dummy.run_sql( + "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = 'TPCH_SF1' " + ) plan = vn_dummy.get_training_plan_generic(df_information_schema) assert len(plan._plan) == 8 From 24496afad71149c2082a619d3d388edcfecff8b8 Mon Sep 17 00:00:00 2001 From: adam ulring Date: Sun, 14 Jul 2024 15:12:20 -0500 
Subject: [PATCH 2/5] improve config: use a single "database" key; handle existing local paths explicitly in connect_to_sqlite

- replace the separate "path"/"database_name" config keys with a single
  "database" key in the duckdb and sqlite vector stores
- spell out the existing-local-file case in connect_to_sqlite so a valid
  local path is clearly used as-is rather than treated as a URL to download
- update the vector store code and tests for the config change

---
 src/vanna/base/base.py            |  3 +-
 src/vanna/duckdb/duckdb_vector.py | 16 ++++-----
 src/vanna/sqlite/sqlite_vector.py | 22 ++++++-------
 tests/test_vanna.py               | 54 +++++++++++++++----------------
 4 files changed, 45 insertions(+), 50 deletions(-)

diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py
index 7525e1f7..7c0cb152 100644
--- a/src/vanna/base/base.py
+++ b/src/vanna/base/base.py
@@ -813,7 +813,8 @@ def connect_to_sqlite(self, url: str = ":memory:", conn=None):
 
         if path == ":memory:" or path == "":
             url = ":memory:"
-
+        elif os.path.exists(url):
+            pass
         elif not os.path.exists(url):
             response = requests.get(url)
             response.raise_for_status()  # Check that the request was successful
diff --git a/src/vanna/duckdb/duckdb_vector.py b/src/vanna/duckdb/duckdb_vector.py
index 13021727..ebce0bb8 100644
--- a/src/vanna/duckdb/duckdb_vector.py
+++ b/src/vanna/duckdb/duckdb_vector.py
@@ -17,9 +17,7 @@ def __init__(self, config=None):
         if config is None:
             config = {}
 
-        path = config.get("path", ".")
-        database_name = config.get("database_name", "vanna.duckdb")
-        self.database_path = os.path.join(path, database_name)
+        self.database = config.get("database", ".")
         self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
         self.n_results_documentation = config.get(
             "n_results_documentation", config.get("n_results", 10)
@@ -29,7 +27,7 @@ def __init__(self, config=None):
         self.model_name = self.config.get("model_name", "BAAI/bge-small-en-v1.5")
         self.embedding_model = TextEmbedding(model_name=self.model_name)
 
-        conn = duckdb.connect(database=self.database_path)
+        conn = duckdb.connect(database=self.database)
         conn.execute(
             """
             CREATE TABLE IF NOT EXISTS embeddings (
@@ -47,7 +45,7 @@ def generate_embedding(self, data: str) -> List[float]:
         return embeddings[0]
 
     def write_embedding_to_table(self, text, id, embedding):
-        con = duckdb.connect(database=self.database_path)
+        con = duckdb.connect(database=self.database)
         embedding_array = np.array(embedding, dtype=np.float32).tolist()
         con.execute(
             "INSERT INTO embeddings (id, text, model, vec) VALUES (?, ?, ?, ?)",
@@ -82,7 +80,7 @@ def add_documentation(self, documentation: str) -> str:
         return id
 
     def get_training_data(self) -> pd.DataFrame:
-        con = duckdb.connect(database=self.database_path)
+        con = duckdb.connect(database=self.database)
         sql_data = con.execute("SELECT * FROM embeddings").fetchdf()
         con.close()
 
@@ -125,7 +123,7 @@ def get_training_data(self) -> pd.DataFrame:
         return df
 
     def remove_training_data(self, id: str) -> bool:
-        con = duckdb.connect(database=self.database_path)
+        con = duckdb.connect(database=self.database)
         con.execute("DELETE FROM embeddings WHERE id = ?", [id])
         con.close()
         return True
@@ -135,7 +133,7 @@ def remove_collection(self, collection_name: str) -> bool:
             collection_name, None
         )
         if suffix:
-            con = duckdb.connect(database=self.database_path)
+            con = duckdb.connect(database=self.database)
             con.execute("DELETE FROM embeddings WHERE id LIKE ?", ["%" + suffix])
             con.close()
             return True
@@ -145,7 +143,7 @@ def query_similar_embeddings(self, query_text: str, top_n: int) -> pd.DataFrame:
         query_embedding = self.generate_embedding(query_text)
         query_embedding_array = np.array(query_embedding, dtype=np.float32).tolist()
 
-        con = duckdb.connect(database=self.database_path)
+        con = duckdb.connect(database=self.database)
         results = con.execute(
             """
             SELECT text, array_cosine_similarity(vec, ?::FLOAT[384]) AS
similarity_score diff --git a/src/vanna/sqlite/sqlite_vector.py b/src/vanna/sqlite/sqlite_vector.py index 03e94f3e..eeff87cf 100644 --- a/src/vanna/sqlite/sqlite_vector.py +++ b/src/vanna/sqlite/sqlite_vector.py @@ -1,5 +1,4 @@ import json -import os import sqlite3 from typing import List @@ -22,8 +21,8 @@ def cosine_similarity(a, b): return dot_product / (magnitude_a * magnitude_b) -def sqlite_information_schema(database_path): - conn = sqlite3.connect(database_path) +def sqlite_information_schema(database): + conn = sqlite3.connect(database=database) tables = pd.read_sql_query( "SELECT name FROM sqlite_master WHERE type='table';", conn ) @@ -63,9 +62,7 @@ def __init__(self, config=None): if config is None: config = {} - path = config.get("path", ".") - database_name = config.get("database_name", "vanna.sqlite") - self.database_path = os.path.join(path, database_name) + self.database = config.get("database") self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10)) self.n_results_documentation = config.get( "n_results_documentation", config.get("n_results", 10) @@ -73,8 +70,7 @@ def __init__(self, config=None): self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10)) self.model_name = self.config.get("model_name", "BAAI/bge-small-en-v1.5") self.embedding_model = TextEmbedding(model_name=self.model_name) - - conn = sqlite3.connect(self.database_path) + conn = sqlite3.connect(database=self.database) conn.execute( """ CREATE TABLE IF NOT EXISTS embeddings ( @@ -92,7 +88,7 @@ def generate_embedding(self, data: str) -> List[float]: return embeddings[0] def write_embedding_to_table(self, text, id, embedding): - con = sqlite3.connect(self.database_path) + con = sqlite3.connect(database=self.database) embedding_array = np.array(embedding, dtype=np.float32).tobytes() con.execute( "INSERT INTO embeddings (id, text, model, vec) VALUES (?, ?, ?, ?)", @@ -128,7 +124,7 @@ def add_documentation(self, documentation: str) -> str: return id def get_training_data(self) -> pd.DataFrame: - con = sqlite3.connect(self.database_path) + con = sqlite3.connect(database=self.database) sql_data = pd.read_sql_query("SELECT * FROM embeddings", con) con.close() @@ -171,7 +167,7 @@ def get_training_data(self) -> pd.DataFrame: return df def remove_training_data(self, id: str) -> bool: - con = sqlite3.connect(self.database_path) + con = sqlite3.connect(database=self.database) con.execute("DELETE FROM embeddings WHERE id = ?", (id,)) con.commit() con.close() @@ -182,7 +178,7 @@ def remove_collection(self, collection_name: str) -> bool: collection_name, None ) if suffix: - con = sqlite3.connect(self.database_path) + con = sqlite3.connect(database=self.database) con.execute("DELETE FROM embeddings WHERE id LIKE ?", (suffix,)) con.commit() con.close() @@ -192,7 +188,7 @@ def remove_collection(self, collection_name: str) -> bool: def query_similar_embeddings(self, query_text: str, top_n: int) -> pd.DataFrame: query_embedding = self.generate_embedding(query_text) - con = sqlite3.connect(self.database_path) + con = sqlite3.connect(database=self.database) embeddings_data = pd.read_sql_query("SELECT id, text, vec FROM embeddings", con) con.close() diff --git a/tests/test_vanna.py b/tests/test_vanna.py index 703bb4b2..ef0105d9 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -185,12 +185,10 @@ def test_vn_duckdb(): import duckdb with TemporaryDirectory() as temp_dir: + database_path = f"{temp_dir}/vanna.duckdb" vn_duckdb = MyVannaDuckDb( - config={"api_key": OPENAI_API_KEY, 
"model": "gpt-4-turbo", "path": temp_dir} + config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "database": database_path} ) - database_path = os.path.join(temp_dir, "vanna.duckdb") - conn = duckdb.connect(database=database_path) - conn.close() vn_duckdb.connect_to_duckdb(database_path) _ = vn_duckdb.get_training_data() @@ -242,7 +240,7 @@ def test_vn_duckdb(): with NamedTemporaryFile( suffix=".sqlite", delete=False ) as temp_sqlite_file, TemporaryDirectory() as temp_duckdb_dir: - database_path = os.path.join(temp_duckdb_dir, "vanna.duckdb") + database_path = f"{temp_duckdb_dir}/vanna.duckdb" response = requests.get("https://vanna.ai/Chinook.sqlite") response.raise_for_status() temp_sqlite_file.write(response.content) @@ -283,7 +281,7 @@ def test_vn_duckdb(): config={ "api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", - "path": temp_duckdb_dir, + "database": database_path, } ) vn_duckdb.connect_to_duckdb(conn=duckdb_conn) @@ -325,11 +323,11 @@ def test_vn_sqlite(): import sqlite3 with TemporaryDirectory() as temp_dir: + database_path = f"{temp_dir}/vanna.sqlite" vn_sqlite = MyVannaSqlite( - config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "path": temp_dir} + config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "database": database_path} ) - database_path = os.path.join(temp_dir, "vanna.sqlite") - conn = sqlite3.connect(database_path) + conn = sqlite3.connect(database=database_path) employee_ddl = """ CREATE TABLE employee ( @@ -375,26 +373,28 @@ def test_vn_sqlite(): del vn_sqlite ############################################################# + with NamedTemporaryFile( + suffix=".sqlite", delete=False + ) as temp_sqlite_file: + vn_sqlite = MyVannaSqlite( + config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "database": temp_sqlite_file.name} + ) + vn_sqlite.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") + existing_training_data = vn_sqlite.get_training_data() + if len(existing_training_data) > 0: + for _, training_data in existing_training_data.iterrows(): + vn_sqlite.remove_training_data(training_data["id"]) + + df_ddl = vn_sqlite.run_sql( + "SELECT type, sql FROM sqlite_master WHERE sql is not null" + ) - vn_sqlite = MyVannaSqlite( - config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo"} - ) - vn_sqlite.connect_to_sqlite("https://vanna.ai/Chinook.sqlite") - existing_training_data = vn_sqlite.get_training_data() - if len(existing_training_data) > 0: - for _, training_data in existing_training_data.iterrows(): - vn_sqlite.remove_training_data(training_data["id"]) - - df_ddl = vn_sqlite.run_sql( - "SELECT type, sql FROM sqlite_master WHERE sql is not null" - ) - - for ddl in df_ddl["sql"].to_list(): - vn_sqlite.train(ddl=ddl) + for ddl in df_ddl["sql"].to_list(): + vn_sqlite.train(ddl=ddl) - sql = vn_sqlite.generate_sql("What are the top 7 customers by sales?") - df = vn_sqlite.run_sql(sql) - assert len(df) == 7 + sql = vn_sqlite.generate_sql("What are the top 7 customers by sales?") + df = vn_sqlite.run_sql(sql) + assert len(df) == 7 from vanna.milvus import Milvus_VectorStore From 93a82097d47aa325781a20195dc87613099e114a Mon Sep 17 00:00:00 2001 From: adam ulring Date: Sun, 28 Jul 2024 13:14:00 -0500 Subject: [PATCH 3/5] make float size configurable so different embedding sizes work. 
--- src/vanna/duckdb/duckdb_vector.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/vanna/duckdb/duckdb_vector.py b/src/vanna/duckdb/duckdb_vector.py index ebce0bb8..651ad9bf 100644 --- a/src/vanna/duckdb/duckdb_vector.py +++ b/src/vanna/duckdb/duckdb_vector.py @@ -26,15 +26,18 @@ def __init__(self, config=None): self.model_name = self.config.get("model_name", "BAAI/bge-small-en-v1.5") self.embedding_model = TextEmbedding(model_name=self.model_name) + self.embedding_size = self.config.get( + "embedding_size", 384 + ) # default is size of BAAI/bge-small-en-v1.5 conn = duckdb.connect(database=self.database) conn.execute( - """ + f""" CREATE TABLE IF NOT EXISTS embeddings ( id VARCHAR, text VARCHAR, model VARCHAR, - vec FLOAT[384] + vec FLOAT[{self.embedding_size}] ); """ ) From eeeb9e74f57f192c470ba225ecea401ac5daa85d Mon Sep 17 00:00:00 2001 From: adam ulring Date: Sun, 28 Jul 2024 14:13:05 -0500 Subject: [PATCH 4/5] fix scoping of conn creation in duckdb + sqlite connection methods. If passed... do nothing. --- src/vanna/base/base.py | 6 +- tester_scripts.py | 130 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 4 deletions(-) create mode 100644 tester_scripts.py diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index f47e1c34..d9f40e40 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -835,8 +835,7 @@ def connect_to_sqlite( f"Invalid connection settings. Pass a valid url or sqlite conn." ) - # Connect to the database - conn = sqlite3.connect(url, check_same_thread=check_same_thread, **kwargs) + conn = sqlite3.connect(url, check_same_thread=check_same_thread, **kwargs) def run_sql_sqlite(sql: str): return pd.read_sql_query(sql, conn) @@ -1328,8 +1327,7 @@ def connect_to_duckdb( with open(path, "wb") as f: f.write(response.content) - # Connect to the database - conn = duckdb.connect(path, **kwargs) + conn = duckdb.connect(path, **kwargs) if init_sql: conn.query(init_sql) diff --git a/tester_scripts.py b/tester_scripts.py new file mode 100644 index 00000000..119bb226 --- /dev/null +++ b/tester_scripts.py @@ -0,0 +1,130 @@ +import os +import sqlite3 +from tempfile import TemporaryDirectory + +import duckdb +from vanna.openai import OpenAI_Chat +from vanna.sqlite import SQLite_VectorStore +from vanna.sqlite.sqlite_vector import sqlite_information_schema +from vanna.duckdb import DuckDB_VectorStore + + +class MyVanna(SQLite_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + SQLite_VectorStore.__init__(self, config=config) + OpenAI_Chat.__init__(self, config=config) + + +with TemporaryDirectory() as temp_dir: + database_path = os.path.join(temp_dir, "vanna.sqlite") + vn = MyVanna( + config={ + "api_key": os.environ["OPENAI_API_KEY"], + "model": "gpt-4-turbo", + "database": database_path, + } + ) + conn = sqlite3.connect(database_path) + employee_ddl = """ + CREATE TABLE employee ( + employee_id INTEGER, + name TEXT, + occupation TEXT + ); + """ + conn.execute(employee_ddl) + conn.execute(""" + INSERT INTO employee VALUES + (1, 'Alice Johnson', 'Software Engineer'), + (2, 'Bob Smith', 'Data Scientist'), + (3, 'Charlie Brown', 'Product Manager'), + (4, 'Diana Prince', 'UX Designer'), + (5, 'Ethan Hunt', 'DevOps Engineer'); + """) + results = conn.execute("SELECT * FROM employee").fetchall() + for row in results: + print(row) + conn.close() + print(f"Temporary SQLite file created at: {database_path}") + df_information_schema = sqlite_information_schema(database_path) + print(df_information_schema) 
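+    # get_training_plan_generic() groups these information-schema rows by
+    # table and turns each table's columns into documentation chunks, which
+    # train(plan=...) then embeds into the vector store.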
+ plan = vn.get_training_plan_generic(df_information_schema) + print(plan) + vn.train(plan=plan) + vn.train(ddl=employee_ddl) + training_data = vn.get_training_data() + print(training_data) + similar_query = vn.query_similar_embeddings("employee id", 3) + print(similar_query) + vn.ask(question="which employee is software engineer?") + sql = vn.generate_sql( + question="write a query to get all software engineers from the employees table", + allow_llm_to_see_data=True, + ) + print(sql) + +################################ + + +class MyVannaDuck(DuckDB_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + DuckDB_VectorStore.__init__(self, config=config) + OpenAI_Chat.__init__(self, config=config) + + +with TemporaryDirectory() as temp_dir: + # Define the path for the DuckDB file within the temporary directory + database_path = os.path.join(temp_dir, "vanna.duckdb") + vn_duck = MyVannaDuck( + config={ + "api_key": os.environ["OPENAI_API_KEY"], + "model": "gpt-4-turbo", + "database": database_path, + } + ) + # Connect to the DuckDB database file + conn = duckdb.connect(database=database_path) + # Create the employee table + employee_ddl = """ + CREATE TABLE employee ( + employee_id INTEGER, + name VARCHAR, + occupation VARCHAR + ); + """ + conn.execute(employee_ddl) + conn.execute(""" + INSERT INTO employee VALUES + (1, 'Alice Johnson', 'Software Engineer'), + (2, 'Bob Smith', 'Data Scientist'), + (3, 'Charlie Brown', 'Product Manager'), + (4, 'Diana Prince', 'UX Designer'), + (5, 'Ethan Hunt', 'DevOps Engineer'); + """) + conn.commit() + results = conn.execute("SELECT * FROM employee").fetchall() + for row in results: + print(row) + # Close the connection + conn.close() + print(f"Temporary DuckDB file created at: {database_path}") + vn_duck.connect_to_duckdb(database_path) + df_information_schema = vn_duck.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS") + print(df_information_schema) + plan = vn_duck.get_training_plan_generic(df_information_schema) + print(plan) + vn_duck.train(plan=plan) + vn_duck.train(ddl=employee_ddl) + training_data = vn_duck.get_training_data() + print(training_data) + similar_query = vn_duck.query_similar_embeddings("employee id", 3) + print("similar query: ", similar_query) + # vn_duck.ask(question="which employee is software engineer?") + sql = vn_duck.generate_sql( + question="write a query to get all software engineers from the employee table", + allow_llm_to_see_data=True, + ) + print(sql) + df = vn_duck.run_sql(sql) + print(df.name[0] == "Alice Johnson") + print(df.name[0]) From 88a62d5e7e2781174736d94a666f1655247bf764 Mon Sep 17 00:00:00 2001 From: adam ulring Date: Sun, 28 Jul 2024 14:15:27 -0500 Subject: [PATCH 5/5] remove unintended test file addition. 
--- tester_scripts.py | 130 ---------------------------------------------- 1 file changed, 130 deletions(-) delete mode 100644 tester_scripts.py diff --git a/tester_scripts.py b/tester_scripts.py deleted file mode 100644 index 119bb226..00000000 --- a/tester_scripts.py +++ /dev/null @@ -1,130 +0,0 @@ -import os -import sqlite3 -from tempfile import TemporaryDirectory - -import duckdb -from vanna.openai import OpenAI_Chat -from vanna.sqlite import SQLite_VectorStore -from vanna.sqlite.sqlite_vector import sqlite_information_schema -from vanna.duckdb import DuckDB_VectorStore - - -class MyVanna(SQLite_VectorStore, OpenAI_Chat): - def __init__(self, config=None): - SQLite_VectorStore.__init__(self, config=config) - OpenAI_Chat.__init__(self, config=config) - - -with TemporaryDirectory() as temp_dir: - database_path = os.path.join(temp_dir, "vanna.sqlite") - vn = MyVanna( - config={ - "api_key": os.environ["OPENAI_API_KEY"], - "model": "gpt-4-turbo", - "database": database_path, - } - ) - conn = sqlite3.connect(database_path) - employee_ddl = """ - CREATE TABLE employee ( - employee_id INTEGER, - name TEXT, - occupation TEXT - ); - """ - conn.execute(employee_ddl) - conn.execute(""" - INSERT INTO employee VALUES - (1, 'Alice Johnson', 'Software Engineer'), - (2, 'Bob Smith', 'Data Scientist'), - (3, 'Charlie Brown', 'Product Manager'), - (4, 'Diana Prince', 'UX Designer'), - (5, 'Ethan Hunt', 'DevOps Engineer'); - """) - results = conn.execute("SELECT * FROM employee").fetchall() - for row in results: - print(row) - conn.close() - print(f"Temporary SQLite file created at: {database_path}") - df_information_schema = sqlite_information_schema(database_path) - print(df_information_schema) - plan = vn.get_training_plan_generic(df_information_schema) - print(plan) - vn.train(plan=plan) - vn.train(ddl=employee_ddl) - training_data = vn.get_training_data() - print(training_data) - similar_query = vn.query_similar_embeddings("employee id", 3) - print(similar_query) - vn.ask(question="which employee is software engineer?") - sql = vn.generate_sql( - question="write a query to get all software engineers from the employees table", - allow_llm_to_see_data=True, - ) - print(sql) - -################################ - - -class MyVannaDuck(DuckDB_VectorStore, OpenAI_Chat): - def __init__(self, config=None): - DuckDB_VectorStore.__init__(self, config=config) - OpenAI_Chat.__init__(self, config=config) - - -with TemporaryDirectory() as temp_dir: - # Define the path for the DuckDB file within the temporary directory - database_path = os.path.join(temp_dir, "vanna.duckdb") - vn_duck = MyVannaDuck( - config={ - "api_key": os.environ["OPENAI_API_KEY"], - "model": "gpt-4-turbo", - "database": database_path, - } - ) - # Connect to the DuckDB database file - conn = duckdb.connect(database=database_path) - # Create the employee table - employee_ddl = """ - CREATE TABLE employee ( - employee_id INTEGER, - name VARCHAR, - occupation VARCHAR - ); - """ - conn.execute(employee_ddl) - conn.execute(""" - INSERT INTO employee VALUES - (1, 'Alice Johnson', 'Software Engineer'), - (2, 'Bob Smith', 'Data Scientist'), - (3, 'Charlie Brown', 'Product Manager'), - (4, 'Diana Prince', 'UX Designer'), - (5, 'Ethan Hunt', 'DevOps Engineer'); - """) - conn.commit() - results = conn.execute("SELECT * FROM employee").fetchall() - for row in results: - print(row) - # Close the connection - conn.close() - print(f"Temporary DuckDB file created at: {database_path}") - vn_duck.connect_to_duckdb(database_path) - df_information_schema = 
vn_duck.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS") - print(df_information_schema) - plan = vn_duck.get_training_plan_generic(df_information_schema) - print(plan) - vn_duck.train(plan=plan) - vn_duck.train(ddl=employee_ddl) - training_data = vn_duck.get_training_data() - print(training_data) - similar_query = vn_duck.query_similar_embeddings("employee id", 3) - print("similar query: ", similar_query) - # vn_duck.ask(question="which employee is software engineer?") - sql = vn_duck.generate_sql( - question="write a query to get all software engineers from the employee table", - allow_llm_to_see_data=True, - ) - print(sql) - df = vn_duck.run_sql(sql) - print(df.name[0] == "Alice Johnson") - print(df.name[0])