Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sqlite duckdb vector support #555

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

aulring
Copy link

@aulring aulring commented Jul 14, 2024

Added support for Duckdb and SQLite as vector stores.

Summary of changes:

  • new interface implementations for duckdb and sqlite
  • added corresponding tests
  • added option to run single tests from the test suite for easier test suite/code contribution
  • updated connect_to_sqlite and connect_to_duckdb with option to pass conn
  • updated precommit to ruff

Sorry about the large number of formatting changes in base.py. I ran precommit and it changed like 40 files, so I reverted formatting changes on every file I didn't touch. I ran into a few issues with the current pre-commit file, so I swapped black and isort out with ruff as a suggestion.

In addition to the test suite, here is a test I ran from a fresh install:

import os
import sqlite3
from tempfile import TemporaryDirectory

import duckdb
from vanna.openai import OpenAI_Chat
from vanna.sqlite import SQLite_VectorStore
from vanna.sqlite.sqlite_vector import sqlite_information_schema
from vanna.duckdb import DuckDB_VectorStore


class MyVanna(SQLite_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        SQLite_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


with TemporaryDirectory() as temp_dir:
    database_path = os.path.join(temp_dir, "vanna.sqlite")
    vn = MyVanna(
        config={
            "api_key": os.environ["OPENAI_API_KEY"],
            "model": "gpt-4-turbo",
            "database": database_path,
        }
    )
    conn = sqlite3.connect(database_path)
    employee_ddl = """
    CREATE TABLE employee (
        employee_id INTEGER,
        name TEXT,
        occupation TEXT
    );
    """
    conn.execute(employee_ddl)
    conn.execute("""
    INSERT INTO employee VALUES
    (1, 'Alice Johnson', 'Software Engineer'),
    (2, 'Bob Smith', 'Data Scientist'),
    (3, 'Charlie Brown', 'Product Manager'),
    (4, 'Diana Prince', 'UX Designer'),
    (5, 'Ethan Hunt', 'DevOps Engineer');
    """)
    results = conn.execute("SELECT * FROM employee").fetchall()
    for row in results:
        print(row)
    conn.close()
    print(f"Temporary SQLite file created at: {database_path}")
    df_information_schema = sqlite_information_schema(database_path)
    print(df_information_schema)
    plan = vn.get_training_plan_generic(df_information_schema)
    print(plan)
    vn.train(plan=plan)
    vn.train(ddl=employee_ddl)
    training_data = vn.get_training_data()
    print(training_data)
    similar_query = vn.query_similar_embeddings("employee id", 3)
    print(similar_query)
    vn.ask(question="which employee is software engineer?")
    sql = vn.generate_sql(
        question="write a query to get all software engineers from the employees table",
        allow_llm_to_see_data=True,
    )
    print(sql)

################################


class MyVannaDuck(DuckDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        DuckDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


with TemporaryDirectory() as temp_dir:
    # Define the path for the DuckDB file within the temporary directory
    database_path = os.path.join(temp_dir, "vanna.duckdb")
    vn_duck = MyVannaDuck(
        config={
            "api_key": os.environ["OPENAI_API_KEY"],
            "model": "gpt-4-turbo",
            "database": database_path,
        }
    )
    # Connect to the DuckDB database file
    conn = duckdb.connect(database=database_path)
    # Create the employee table
    employee_ddl = """
    CREATE TABLE employee (
        employee_id INTEGER,
        name VARCHAR,
        occupation VARCHAR
    );
    """
    conn.execute(employee_ddl)
    conn.execute("""
    INSERT INTO employee VALUES
    (1, 'Alice Johnson', 'Software Engineer'),
    (2, 'Bob Smith', 'Data Scientist'),
    (3, 'Charlie Brown', 'Product Manager'),
    (4, 'Diana Prince', 'UX Designer'),
    (5, 'Ethan Hunt', 'DevOps Engineer');
    """)
    conn.commit()
    results = conn.execute("SELECT * FROM employee").fetchall()
    for row in results:
        print(row)
    # Close the connection
    conn.close()
    print(f"Temporary DuckDB file created at: {database_path}")
    vn_duck.connect_to_duckdb(database_path)
    df_information_schema = vn_duck.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")
    print(df_information_schema)
    plan = vn_duck.get_training_plan_generic(df_information_schema)
    print(plan)
    vn_duck.train(plan=plan)
    vn_duck.train(ddl=employee_ddl)
    training_data = vn_duck.get_training_data()
    print(training_data)
    similar_query = vn_duck.query_similar_embeddings("employee id", 3)
    print("similar query: ", similar_query)
    # vn_duck.ask(question="which employee is software engineer?")
    sql = vn_duck.generate_sql(
        question="write a query to get all software engineers from the employee table",
        allow_llm_to_see_data=True,
    )
    print(sql)
    df = vn_duck.run_sql(sql)
    print(df.name[0] == "Alice Johnson")
    print(df.name[0])

adam ulring added 2 commits July 14, 2024 14:28
    - sqlite vector support
    - duckdb vector support
    - added corresponding tests to vanna_test
    - updated import test
    - updated precommit config to use ruff
    - updated pyproject toml with new duckdb dependency
@aulring
Copy link
Author

aulring commented Jul 16, 2024

Need to make FLOAT size configurable in table DDL.

@zainhoda
Copy link
Contributor

I really like this PR! Would you be able to resolve the conflicts in base.py and then I can merge?

@aulring
Copy link
Author

aulring commented Jul 28, 2024

I really like this PR! Would you be able to resolve the conflicts in base.py and then I can merge?

Sure thing! I made a commit to make the embedding size configurable in duckdb and resolved the merge conflicts in base.py. There were some also formatting issues in the method connect_to_presto I fixed.

@zainhoda
Copy link
Contributor

zainhoda commented Aug 6, 2024

Thanks @aulring -- this is very close. The test is failing for me though:

tests/test_vanna.py::test_vn_duckdb FAILED                                                                                                                                                                                       [ 57%]

=============================================================================================================== FAILURES ===============================================================================================================
____________________________________________________________________________________________________________ test_vn_duckdb ____________________________________________________________________________________________________________

self = RangeIndex(start=0, stop=0, step=1), key = 0

    @doc(Index.get_loc)
    def get_loc(self, key) -> int:
        if is_integer(key) or (is_float(key) and key.is_integer()):
            new_key = int(key)
            try:
>               return self._range.index(new_key)
E               ValueError: 0 is not in range

.tox/mac/lib/python3.12/site-packages/pandas/core/indexes/range.py:413: ValueError

The above exception was the direct cause of the following exception:

    def test_vn_duckdb():
        import duckdb
    
        with TemporaryDirectory() as temp_dir:
            database_path = f"{temp_dir}/vanna.duckdb"
            vn_duckdb = MyVannaDuckDb(
                config={"api_key": OPENAI_API_KEY, "model": "gpt-4-turbo", "database": database_path}
            )
            vn_duckdb.connect_to_duckdb(database_path)
            _ = vn_duckdb.get_training_data()
    
            conn = duckdb.connect(database=database_path)
    
            employee_ddl = """
            CREATE TABLE employee (
                employee_id INTEGER,
                name VARCHAR,
                occupation VARCHAR
            );
            """
    
            conn.execute(employee_ddl)
    
            conn.execute(
                """
            INSERT INTO employee VALUES
            (1, 'Alice Johnson', 'Software Engineer'),
            (2, 'Bob Smith', 'Data Scientist'),
            (3, 'Charlie Brown', 'Product Manager'),
            (4, 'Diana Prince', 'UX Designer'),
            (5, 'Ethan Hunt', 'DevOps Engineer');
            """
            )
            conn.commit()
            df_information_schema = vn_duckdb.run_sql(
                "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
            )
            plan = vn_duckdb.get_training_plan_generic(df_information_schema)
            vn_duckdb.train(plan=plan)
    
            vn_duckdb.train(ddl=employee_ddl)
            training_data = vn_duckdb.get_training_data()
            assert not training_data.empty
    
            similar_query = vn_duckdb.query_similar_embeddings("employee id", 3)
            assert not similar_query.empty
    
            sql = vn_duckdb.generate_sql(
                question="write a query to get all software engineers from the employees table",
                allow_llm_to_see_data=True,
            )
            df = vn_duckdb.run_sql(sql)
>           assert df.name[0] == "Alice Johnson"

tests/test_vanna.py:236: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.tox/mac/lib/python3.12/site-packages/pandas/core/series.py:1121: in __getitem__
    return self._get_value(key)
.tox/mac/lib/python3.12/site-packages/pandas/core/series.py:1237: in _get_value
    loc = self.index.get_loc(label)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = RangeIndex(start=0, stop=0, step=1), key = 0

    @doc(Index.get_loc)
    def get_loc(self, key) -> int:
        if is_integer(key) or (is_float(key) and key.is_integer()):
            new_key = int(key)
            try:
                return self._range.index(new_key)
            except ValueError as err:
>               raise KeyError(key) from err
E               KeyError: 0

.tox/mac/lib/python3.12/site-packages/pandas/core/indexes/range.py:415: KeyError
--------------------------------------------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------------------------------------------
Adding ddl: 
        CREATE TABLE employee (
            employee_id INTEGER,
            name VARCHAR,
            occupation VARCHAR
        );
        
SQL Prompt: [{'role': 'system', 'content': "You are a DuckDB SQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\n        CREATE TABLE employee (\n            employee_id INTEGER,\n            name VARCHAR,\n            occupation VARCHAR\n        );\n        \n\nThe following columns are in the employee table in the vanna database:\n\n|    | table_catalog   | table_schema   | table_name   | column_name   | data_type   | COLUMN_COMMENT   |\n|---:|:----------------|:---------------|:-------------|:--------------|:------------|:-----------------|\n|  4 | vanna           | main           | employee     | employee_id   | INTEGER     |                  |\n|  5 | vanna           | main           | employee     | name          | VARCHAR     |                  |\n|  6 | vanna           | main           | employee     | occupation    | VARCHAR     |                  |\n\nThe following columns are in the embeddings table in the vanna database:\n\n|    | table_catalog   | table_schema   | table_name   | column_name   | data_type   | COLUMN_COMMENT   |\n|---:|:----------------|:---------------|:-------------|:--------------|:------------|:-----------------|\n|  0 | vanna           | main           | embeddings   | id            | VARCHAR     |                  |\n|  1 | vanna           | main           | embeddings   | text          | VARCHAR     |                  |\n|  2 | vanna           | main           | embeddings   | model         | VARCHAR     |                  |\n|  3 | vanna           | main           | embeddings   | vec           | FLOAT[384]  |                  |\n\n\n===Additional Context \n\n\n        CREATE TABLE employee (\n            employee_id INTEGER,\n            name VARCHAR,\n            occupation VARCHAR\n        );\n        \n\nThe following columns are in the employee table in the vanna database:\n\n|    | table_catalog   | table_schema   | table_name   | column_name   | data_type   | COLUMN_COMMENT   |\n|---:|:----------------|:---------------|:-------------|:--------------|:------------|:-----------------|\n|  4 | vanna           | main           | employee     | employee_id   | INTEGER     |                  |\n|  5 | vanna           | main           | employee     | name          | VARCHAR     |                  |\n|  6 | vanna           | main           | employee     | occupation    | VARCHAR     |                  |\n\nThe following columns are in the embeddings table in the vanna database:\n\n|    | table_catalog   | table_schema   | table_name   | column_name   | data_type   | COLUMN_COMMENT   |\n|---:|:----------------|:---------------|:-------------|:--------------|:------------|:-----------------|\n|  0 | vanna           | main           | embeddings   | id            | VARCHAR     |                  |\n|  1 | vanna           | main           | embeddings   | text          | VARCHAR     |                  |\n|  2 | vanna           | main           | embeddings   | model         | VARCHAR     |                  |\n|  3 | vanna           | main           | embeddings   | vec           | FLOAT[384]  |                  |\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n3. If the provided context is insufficient, please explain why it can't be generated. \n4. Please use the most relevant table(s). \n5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n6. Ensure that the output SQL is DuckDB SQL-compliant and executable, and free of syntax errors. \n"}, {'role': 'user', 'content': 'write a query to get all software engineers from the employees table'}]
Using model gpt-4-turbo for 993.0 tokens (approx)
LLM Response: ```sql
SELECT * FROM employee WHERE occupation = 'software engineer';

Extracted SQL: SELECT * FROM employee WHERE occupation = 'software engineer';
--------------------------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------------------------

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants