Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DH-5688/fixing the observations code blocks
Browse files Browse the repository at this point in the history
MohammadrezaPourreza committed Apr 4, 2024
1 parent 39bef01 commit 19786a0
Showing 2 changed files with 12 additions and 8 deletions.
7 changes: 5 additions & 2 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
@@ -250,7 +250,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
Use this tool to execute the SQL query on the database, and return the results.
"""
@@ -335,7 +335,8 @@ def _run(
{"role": "user", "content": user_prompt},
],
)
return response.choices[0].message.content
returned_sql = response.choices[0].message.content
return f"```sql\n{returned_sql}```"

async def _arun(
self,
@@ -372,6 +373,7 @@ def _run(
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
@@ -385,6 +387,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
13 changes: 7 additions & 6 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
@@ -150,7 +150,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
If an error occurs, rewrite the query and retry.
Use this tool to execute SQL queries.
@@ -204,8 +204,8 @@ def _run(
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
response = "Admin: All of the generated SQL queries must follow the below instructions:\n"
for instruction in self.instructions:
response += f"{instruction['instruction']}\n"
for index, instruction in enumerate(self.instructions):
response += f"{index + 1}) {instruction['instruction']}\n"
return response

async def _arun(
@@ -407,6 +407,7 @@ def _run(
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
@@ -420,6 +421,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
@@ -516,9 +518,8 @@ def _run(
return "Action input for the fewshot_examples_retriever tool should be an integer"
returned_output = ""
for example in self.few_shot_examples[:number_of_samples]:
returned_output += (
f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n"
)
returned_output += f"Question: {example['prompt_text']} \n"
returned_output += f"```sql\n{example['sql']}\n```\n"
if returned_output == "":
returned_output = "No previously asked Question/SQL pairs are available"
return returned_output

0 comments on commit 19786a0

Please sign in to comment.