From 817d956dba33287bb73cbb1a3216d90ecb3c1d34 Mon Sep 17 00:00:00 2001 From: Edward Date: Tue, 1 Oct 2024 10:51:52 +0100 Subject: [PATCH 1/7] pgvector collection fix --- src/vanna/pgvector/pgvector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index cf0c2a23..8cbc6a46 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -31,17 +31,17 @@ def __init__(self, config=None): from sentence_transformers import SentenceTransformer self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2") - self.sql_vectorstore = PGVector( + self.sql_collection = PGVector( embeddings=self.embedding_function, collection_name="sql", connection=self.connection_string, ) - self.ddl_vectorstore = PGVector( + self.ddl_collection = PGVector( embeddings=self.embedding_function, collection_name="ddl", connection=self.connection_string, ) - self.documentation_vectorstore = PGVector( + self.documentation_collection = PGVector( embeddings=self.embedding_function, collection_name="documentation", connection=self.connection_string, From d4f20503f45a2602ceb27511e9330ebe5d911315 Mon Sep 17 00:00:00 2001 From: Edward Date: Tue, 1 Oct 2024 10:52:17 +0100 Subject: [PATCH 2/7] pgvector async fix --- src/vanna/pgvector/pgvector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index 8cbc6a46..c1ec5f06 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -94,16 +94,16 @@ def get_collection(self, collection_name): case _: raise ValueError("Specified collection does not exist.") - async def get_similar_question_sql(self, question: str) -> list: + def get_similar_question_sql(self, question: str) -> list: documents = self.sql_collection.similarity_search(query=question, k=self.n_results) return [ast.literal_eval(document.page_content) for document in documents] - async def get_related_ddl(self, question: str, **kwargs) -> list: - documents = await self.ddl_collection.similarity_search(query=question, k=self.n_results) + def get_related_ddl(self, question: str, **kwargs) -> list: + documents = self.ddl_collection.similarity_search(query=question, k=self.n_results) return [document.page_content for document in documents] - async def get_related_documentation(self, question: str, **kwargs) -> list: - documents = await self.documentation_collection.similarity_search(query=question, k=self.n_results) + def get_related_documentation(self, question: str, **kwargs) -> list: + documents = self.documentation_collection.similarity_search(query=question, k=self.n_results) return [document.page_content for document in documents] def train( From 5b0e61cbc177134f6b7d27f51a7452ff3239597b Mon Sep 17 00:00:00 2001 From: Edward Date: Tue, 1 Oct 2024 11:02:11 +0100 Subject: [PATCH 3/7] pgvector remove submit_prompt --- src/vanna/pgvector/pgvector.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index c1ec5f06..d3714a17 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -251,15 +251,3 @@ def remove_collection(self, collection_name: str) -> bool: def generate_embedding(self, *args, **kwargs): pass - - def submit_prompt(self, *args, **kwargs): - pass - - def system_message(self, message: str) -> any: - return {"role": "system", "content": message} - - def user_message(self, message: str) -> any: - return {"role": "user", "content": message} - - def assistant_message(self, message: str) -> any: - return {"role": "assistant", "content": message} From 269d5542b0b23f63bb95ee03b81df7e1baea72a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 1 Oct 2024 12:33:38 +0200 Subject: [PATCH 4/7] Fix default embedding model used with PGVector --- src/vanna/pgvector/pgvector.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index d3714a17..5d196b67 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -28,8 +28,8 @@ def __init__(self, config=None): if config and "embedding_function" in config: self.embedding_function = config.get("embedding_function") else: - from sentence_transformers import SentenceTransformer - self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2") + from langchain_huggingface import HuggingFaceEmbeddings + self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") self.sql_collection = PGVector( embeddings=self.embedding_function, @@ -251,3 +251,15 @@ def remove_collection(self, collection_name: str) -> bool: def generate_embedding(self, *args, **kwargs): pass + + def submit_prompt(self, *args, **kwargs): + pass + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} From a10461705f5a164c69b6ff976e1966e719e8c66e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 1 Oct 2024 13:24:37 +0200 Subject: [PATCH 5/7] Remove redundant declaration of methods --- src/vanna/pgvector/pgvector.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index 5d196b67..3cddeb46 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -251,15 +251,3 @@ def remove_collection(self, collection_name: str) -> bool: def generate_embedding(self, *args, **kwargs): pass - - def submit_prompt(self, *args, **kwargs): - pass - - def system_message(self, message: str) -> any: - return {"role": "system", "content": message} - - def user_message(self, message: str) -> any: - return {"role": "user", "content": message} - - def assistant_message(self, message: str) -> any: - return {"role": "assistant", "content": message} From 26d647fd98bd62594a81114e9a194cf071f9095c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Tue, 1 Oct 2024 13:25:48 +0200 Subject: [PATCH 6/7] Updated pgvector test --- tests/test_pgvector.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/test_pgvector.py b/tests/test_pgvector.py index 4bc1dea9..84d92091 100644 --- a/tests/test_pgvector.py +++ b/tests/test_pgvector.py @@ -3,7 +3,10 @@ from dotenv import load_dotenv from vanna.pgvector import PG_VectorStore +from vanna.openai import OpenAI_Chat + +# assume .env file placed next to file with provided env vars load_dotenv() @@ -18,7 +21,33 @@ def get_vanna_connection_string(): return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}" -def test_pgvector(): - connection_string = get_vanna_connection_string() - pgclient = PG_VectorStore(config={"connection_string": connection_string}) - assert pgclient is not None +def test_pgvector_e2e(): + # configure Vanna to use OpenAI and PGVector + class VannaCustom(PG_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + PG_VectorStore.__init__(self, config=config) + OpenAI_Chat.__init__(self, config=config) + + vn = VannaCustom(config={ + 'api_key': os.environ['OPENAI_API_KEY'], + 'model': 'gpt-3.5-turbo', + "connection_string": get_vanna_connection_string(), + }) + + # connect to SQLite database + vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + + # train Vanna on DDLs + df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") + for ddl in df_ddl['sql'].to_list(): + vn.train(ddl=ddl) + assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default + + question = "What are the top 7 customers by sales?" + sql = vn.generate_sql(question) + df = vn.run_sql(sql) + assert len(df) == 7 + + # test if Vanna can generate an answer + answer = vn.ask(question) + assert answer is not None From 85586ac18b526ee9e89b842593efda8db5bf3f6c Mon Sep 17 00:00:00 2001 From: Edward Date: Tue, 1 Oct 2024 15:01:02 +0100 Subject: [PATCH 7/7] Fix for test database port --- tests/test_pgvector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pgvector.py b/tests/test_pgvector.py index 84d92091..04443257 100644 --- a/tests/test_pgvector.py +++ b/tests/test_pgvector.py @@ -13,7 +13,7 @@ def get_vanna_connection_string(): server = os.environ.get("PG_SERVER") driver = "psycopg" - port = 5434 + port = os.environ.get("PG_PORT", 5432) database = os.environ.get("PG_DATABASE") username = os.environ.get("PG_USERNAME") password = os.environ.get("PG_PASSWORD")