Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【feat】opensearch supports document data update, query by table, embedding, and etc #589

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2baf59e
feat:add support ddl to be updated by engine and table
zyclove May 21, 2024
ece218b
Merge branch 'opensearch_fix' of https://github.com/zyclove/vanna int…
zyclove May 21, 2024
6f48369
feat:add get_similar_tables_metadata
zyclove May 30, 2024
6bbb14c
fix: query modify
zyclove May 31, 2024
abf84bb
feat: add TableMetadata manage
zyclove May 31, 2024
cf603f5
feat: optimize search_tables_metadata
zyclove May 31, 2024
d9bd555
fix: return size
zyclove Jun 4, 2024
6dcff9b
fix: get_full_table_name
zyclove Jun 4, 2024
205ba0a
Merge branch 'vanna-ai:main' into opensearch_fix
zyclove Jun 12, 2024
4f644e3
Merge branch 'opensearch_fix' into get_ddl_by_table_name
zyclove Jun 12, 2024
f72d916
feat: add table_name accurate query
zyclove Jun 13, 2024
23c1f09
feat: generate_sql by table_name
zyclove Jun 13, 2024
9cff51f
feat: support table_name_list accurate query
zyclove Jun 13, 2024
670d829
fix: remove blank str
zyclove Jun 14, 2024
456ce78
fix: opensearch get_similar_question_sql
zyclove Jun 23, 2024
6c61ccd
fix: opensearch wildcard_table_name = f"*{table_name}*"
zyclove Jun 23, 2024
d2ee067
Merge branch 'get_ddl_by_table_name' of https://github.com/zyclove/va…
zyclove Jun 25, 2024
e0b7e1c
feat: add min_score=0.5
zyclove Jun 25, 2024
d4537fb
Merge branch 'vanna-ai:main' into get_ddl_by_table_name
zyclove Jun 26, 2024
150e89b
fix: search_tables_metadata empty impl
zyclove Jul 2, 2024
ba15416
feat:增加embedding
zyclove Jul 17, 2024
5d17810
feat: enable knn
zyclove Jul 17, 2024
834bd33
feat: support knn query
zyclove Jul 17, 2024
7eb0105
feat: adapt dimensions
zyclove Jul 17, 2024
cf8af6e
feat:add biz_type
zyclove Jul 19, 2024
8ba0afc
feat: add biz_type
zyclove Jul 19, 2024
0aef01b
Merge branch 'vanna-ai:main' into get_ddl_by_table_name
zyclove Aug 6, 2024
8dc03f1
Merge branch 'main' into get_ddl_by_table_name
zainhoda Aug 22, 2024
3bfc268
get tests passing
zainhoda Aug 22, 2024
bb8b417
Merge branch 'vanna-ai:main' into get_ddl_by_table_name
zyclove Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 91 additions & 7 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import sqlparse

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
from ..types import TrainingPlan, TrainingPlanItem, TableMetadata
from ..utils import validate_config_path


Expand All @@ -90,7 +90,7 @@ def _response_language(self) -> str:

return f"Respond in the {self.language} language."

def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
def generate_sql(self, question: str, table_name_list: List[str] = None, allow_llm_to_see_data=False, **kwargs) -> str:
"""
Example:
```python
Expand All @@ -112,6 +112,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->

Args:
question (str): The question to generate a SQL query for.
table_name_list (List[str], optional): A list of table names to use in the SQL query. Defaults to None.
allow_llm_to_see_data (bool): Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).

Returns:
Expand All @@ -122,7 +123,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) ->
else:
initial_prompt = None
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
ddl_list = self.get_related_ddl(question=question, table_name_list=table_name_list, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
Expand Down Expand Up @@ -210,6 +211,54 @@ def extract_sql(self, llm_response: str) -> str:

return llm_response

def extract_table_metadata(ddl: str) -> TableMetadata:
"""
Example:
```python
vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)")
```

Extracts the table metadata from a DDL statement. This is useful in case the DDL statement contains other information besides the table metadata.
Override this function if your DDL statements need custom extraction logic.

Args:
ddl (str): The DDL statement.

Returns:
TableMetadata: The extracted table metadata.
"""
pattern_with_catalog_schema = re.compile(
r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(',
re.IGNORECASE
)
pattern_with_schema = re.compile(
r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(',
re.IGNORECASE
)
pattern_with_table = re.compile(
r'CREATE TABLE\s+(\w+)\s*\(',
re.IGNORECASE
)

match_with_catalog_schema = pattern_with_catalog_schema.search(ddl)
match_with_schema = pattern_with_schema.search(ddl)
match_with_table = pattern_with_table.search(ddl)

if match_with_catalog_schema:
catalog = match_with_catalog_schema.group(1)
schema = match_with_catalog_schema.group(2)
table_name = match_with_catalog_schema.group(3)
return TableMetadata(catalog, schema, table_name)
elif match_with_schema:
schema = match_with_schema.group(1)
table_name = match_with_schema.group(2)
return TableMetadata(None, schema, table_name)
elif match_with_table:
table_name = match_with_table.group(1)
return TableMetadata(None, None, table_name)
else:
return TableMetadata()

def is_sql_valid(self, sql: str) -> bool:
"""
Example:
Expand Down Expand Up @@ -383,18 +432,46 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list:
pass

@abstractmethod
def get_related_ddl(self, question: str, **kwargs) -> list:
def get_related_ddl(self, question: str, table_name_list: List[str] = None, **kwargs) -> list:
"""
This method is used to get related DDL statements to a question.

Args:
question (str): The question to get related DDL statements for.
table_name_list (list): A list of table names to get related DDL statements for.

Returns:
list: A list of related DDL statements.
"""
pass

@abstractmethod
def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
biz_type: str = None,
size: int = 10,
**kwargs) -> list:
"""
This method is used to get similar tables metadata.

Args:
engine (str): The database engine.
catalog (str): The catalog.
schema (str): The schema.
table_name (str): The table name.
ddl (str): The DDL statement.
biz_type (str): The business type.
size (int): The number of tables to return.

Returns:
list: A list of tables metadata.
"""
pass

@abstractmethod
def get_related_documentation(self, question: str, **kwargs) -> list:
"""
Expand Down Expand Up @@ -423,12 +500,13 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
pass

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, biz_type: str = None, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.

Args:
ddl (str): The DDL statement to add.
engine (str): The database engine that the DDL statement applies to.

Returns:
str: The ID of the training data that was added.
Expand Down Expand Up @@ -1778,6 +1856,8 @@ def train(
question: str = None,
sql: str = None,
ddl: str = None,
engine: str = None,
biz_type: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
Expand All @@ -1798,8 +1878,12 @@ def train(
question (str): The question to train on.
sql (str): The SQL query to train on.
ddl (str): The DDL statement.
engine (str): The database engine.
biz_type (str): The business type.
documentation (str): The documentation to train on.
plan (TrainingPlan): The training plan to train on.
Returns:
str: The training pl
"""

if question and not sql:
Expand All @@ -1817,12 +1901,12 @@ def train(

if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
return self.add_ddl(ddl=ddl, engine=engine, biz_type=biz_type)

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
self.add_ddl(ddl=item.item_value, engine=engine, biz_type=biz_type)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
Expand Down
15 changes: 13 additions & 2 deletions src/vanna/chromadb/chromadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:

return id

def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, biz_type: str = None, **kwargs) -> str:
id = deterministic_uuid(ddl) + "-ddl"
self.ddl_collection.add(
documents=ddl,
Expand Down Expand Up @@ -240,7 +240,18 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list:
)
)

def get_related_ddl(self, question: str, **kwargs) -> list:
def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
biz_type: str = None,
size: int = 10,
**kwargs) -> list:
return []

def get_related_ddl(self, question: str, table_name_list: List[str] = None, **kwargs) -> list:
return ChromaDB_VectorStore._extract_documents(
self.ddl_collection.query(
query_texts=[question],
Expand Down
3 changes: 2 additions & 1 deletion src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,11 +816,12 @@ def add_training_data(user: any):
question = flask.request.json.get("question")
sql = flask.request.json.get("sql")
ddl = flask.request.json.get("ddl")
biz_type = flask.request.json.get("biz_type")
documentation = flask.request.json.get("documentation")

try:
id = vn.train(
question=question, sql=sql, ddl=ddl, documentation=documentation
question=question, sql=sql, ddl=ddl, biz_type=biz_type, documentation=documentation
)

return jsonify({"id": id})
Expand Down
16 changes: 14 additions & 2 deletions src/vanna/marqo/marqo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from typing import List

import marqo
import pandas as pd
Expand Down Expand Up @@ -49,7 +50,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:

return id

def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, biz_type: str = None, **kwargs) -> str:
id = str(uuid.uuid4()) + "-ddl"
ddl_dict = {
"ddl": ddl,
Expand Down Expand Up @@ -152,12 +153,23 @@ def _extract_documents(data) -> list:
# Return an empty list if 'hits' is not found or not a list
return []

def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
biz_type: str = None,
size: int = 10,
**kwargs) -> list:
return []

def get_similar_question_sql(self, question: str, **kwargs) -> list:
return Marqo_VectorStore._extract_documents(
self.mq.index("vanna-sql").search(question)
)

def get_related_ddl(self, question: str, **kwargs) -> list:
def get_related_ddl(self, question: str, table_name_list: List[str] = None, **kwargs) -> list:
return Marqo_VectorStore._extract_documents(
self.mq.index("vanna-ddl").search(question)
)
Expand Down
13 changes: 12 additions & 1 deletion src/vanna/milvus/milvus_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
)
return _id

def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, biz_type: str = None, **kwargs) -> str:
if len(ddl) == 0:
raise Exception("ddl can not be null")
_id = str(uuid.uuid4()) + "-ddl"
Expand Down Expand Up @@ -225,6 +225,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
df = pd.concat([df, df_doc])
return df

def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
biz_type: str = None,
size: int = 10,
**kwargs) -> list:
return []

def get_similar_question_sql(self, question: str, **kwargs) -> list:
search_params = {
"metric_type": "L2",
Expand Down
17 changes: 15 additions & 2 deletions src/vanna/mock/vectordb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import pandas as pd

from ..base import VannaBase
Expand All @@ -11,7 +13,7 @@ def _get_id(self, value: str, **kwargs) -> str:
# Hash the value and return the ID
return str(hash(value))

def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, biz_type: str = None, **kwargs) -> str:
return self._get_id(ddl)

def add_documentation(self, doc: str, **kwargs) -> str:
Expand All @@ -20,12 +22,23 @@ def add_documentation(self, doc: str, **kwargs) -> str:
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
return self._get_id(question)

def get_related_ddl(self, question: str, **kwargs) -> list:
def get_related_ddl(self, question: str, table_name_list: List[str] = None, **kwargs) -> list:
return []

def get_related_documentation(self, question: str, **kwargs) -> list:
return []

def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
biz_type: str = None,
size: int = 10,
**kwargs) -> list:
return []

def get_similar_question_sql(self, question: str, **kwargs) -> list:
return []

Expand Down
Loading