Skip to content

Commit

Permalink
DH-5688/fixing the observations code blocks (#454)
Browse files Browse the repository at this point in the history
* DH-5688/fixing the observations code blocks

* DATA-5688/fix the inline comments

* DH-5688/reformat with black

* Fixing the backticks of the observations

* Add newlines after the observations and final answer

* removing multi DDL command in create tables

* adding newline for excute sql query
  • Loading branch information
MohammadrezaPourreza authored Apr 5, 2024
1 parent 30f5226 commit a11d1f1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
4 changes: 2 additions & 2 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def stream_agent_steps( # noqa: C901
queue.put(message.content + "\n")
elif "steps" in chunk:
for step in chunk["steps"]:
queue.put(f"**Observation:**\n `{step.observation}`\n")
queue.put(f"\n**Observation:**\n {step.observation}\n")
elif "output" in chunk:
queue.put(f'**Final Answer:**\n {chunk["output"]}')
queue.put(f'\n**Final Answer:**\n {chunk["output"]}')
if "```sql" in chunk["output"]:
response.sql = replace_unprocessable_characters(
self.remove_markdown(chunk["output"])
Expand Down
14 changes: 8 additions & 6 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,11 @@ def _run(
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

Expand All @@ -250,9 +248,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
Use this tool to execute the SQL query on the database, and return the results.
Add newline after both ```sql and ``` tags.
"""
args_schema: Type[BaseModel] = SQLInput

Expand Down Expand Up @@ -335,7 +334,8 @@ def _run(
{"role": "user", "content": user_prompt},
],
)
return response.choices[0].message.content
returned_sql = response.choices[0].message.content
return f"```sql\n{returned_sql}```"

async def _arun(
self,
Expand Down Expand Up @@ -372,6 +372,7 @@ def _run(
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
Expand All @@ -385,6 +386,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down
21 changes: 10 additions & 11 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
If an error occurs, rewrite the query and retry.
Use this tool to execute SQL queries.
Add newline after both ```sql and ``` tags.
"""

@catch_exceptions()
Expand Down Expand Up @@ -204,8 +205,8 @@ def _run(
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
response = "Admin: All of the generated SQL queries must follow the below instructions:\n"
for instruction in self.instructions:
response += f"{instruction['instruction']}\n"
for index, instruction in enumerate(self.instructions):
response += f"{index + 1}) {instruction['instruction']}\n"
return response

async def _arun(
Expand Down Expand Up @@ -290,13 +291,11 @@ def _run(
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

Expand Down Expand Up @@ -404,7 +403,7 @@ def _run(
replace_unprocessable_characters(table_name)
for table_name in table_names_list
]
tables_schema = ""
tables_schema = "```sql\n"
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += table.table_schema + "\n"
Expand All @@ -420,6 +419,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down Expand Up @@ -516,9 +516,8 @@ def _run(
return "Action input for the fewshot_examples_retriever tool should be an integer"
returned_output = ""
for example in self.few_shot_examples[:number_of_samples]:
returned_output += (
f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n"
)
returned_output += f"Question: {example['prompt_text']} \n"
returned_output += f"```sql\n{example['sql']}\n```\n"
if returned_output == "":
returned_output = "No previously asked Question/SQL pairs are available"
return returned_output
Expand Down

0 comments on commit a11d1f1

Please sign in to comment.