From a11d1f1e92a5e172f1c69bc89e16919b0a66fa9e Mon Sep 17 00:00:00 2001 From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:37:04 -0400 Subject: [PATCH] DH-5688/fixing the observations code blocks (#454) * DH-5688/fixing the observations code blocks * DATA-5688/fix the inline comments * DH-5688/reformat with black * Fixing the backticks of the observations * Add newlines after the observations and final answer * removing multi DDL command in create tables * adding newline for excute sql query --- dataherald/sql_generator/__init__.py | 4 ++-- .../dataherald_finetuning_agent.py | 14 +++++++------ .../sql_generator/dataherald_sqlagent.py | 21 +++++++++---------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index dd318bf6..b270d5a7 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -173,9 +173,9 @@ def stream_agent_steps( # noqa: C901 queue.put(message.content + "\n") elif "steps" in chunk: for step in chunk["steps"]: - queue.put(f"**Observation:**\n `{step.observation}`\n") + queue.put(f"\n**Observation:**\n {step.observation}\n") elif "output" in chunk: - queue.put(f'**Final Answer:**\n {chunk["output"]}') + queue.put(f'\n**Final Answer:**\n {chunk["output"]}') if "```sql" in chunk["output"]: response.sql = replace_unprocessable_characters( self.remove_markdown(chunk["output"]) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index e0ff4367..5fe0ed4d 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -227,13 +227,11 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += ( - f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n' - ) + table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' if len(most_similar_tables) > 0: for table in most_similar_tables: table_relevance += ( - f"Table: {table}, relevance score: {max(df['similarities'])}\n" + f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" ) return table_relevance @@ -250,9 +248,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: SQL query. + Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. Use this tool to execute the SQL query on the database, and return the results. + Add newline after both ```sql and ``` tags. """ args_schema: Type[BaseModel] = SQLInput @@ -335,7 +334,8 @@ def _run( {"role": "user", "content": user_prompt}, ], ) - return response.choices[0].message.content + returned_sql = response.choices[0].message.content + return f"```sql\n{returned_sql}```" async def _arun( self, @@ -372,6 +372,7 @@ def _run( tables_schema = "" for table in self.db_scan: if table.table_name in table_names_list: + tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: @@ -385,6 +386,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 857cdd08..3a1ecb28 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -150,10 +150,11 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: SQL query. + Input: A SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. If an error occurs, rewrite the query and retry. Use this tool to execute SQL queries. + Add newline after both ```sql and ``` tags. """ @catch_exceptions() @@ -204,8 +205,8 @@ def _run( run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: response = "Admin: All of the generated SQL queries must follow the below instructions:\n" - for instruction in self.instructions: - response += f"{instruction['instruction']}\n" + for index, instruction in enumerate(self.instructions): + response += f"{index + 1}) {instruction['instruction']}\n" return response async def _arun( @@ -290,13 +291,11 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += ( - f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n' - ) + table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' if len(most_similar_tables) > 0: for table in most_similar_tables: table_relevance += ( - f"Table: {table}, relevance score: {max(df['similarities'])}\n" + f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" ) return table_relevance @@ -404,7 +403,7 @@ def _run( replace_unprocessable_characters(table_name) for table_name in table_names_list ] - tables_schema = "" + tables_schema = "```sql\n" for table in self.db_scan: if table.table_name in table_names_list: tables_schema += table.table_schema + "\n" @@ -420,6 +419,7 @@ def _run( ) if len(descriptions) > 0: tables_schema += f"/*\n{''.join(descriptions)}*/\n" + tables_schema += "```\n" if tables_schema == "": tables_schema += "Tables not found in the database" return tables_schema @@ -516,9 +516,8 @@ def _run( return "Action input for the fewshot_examples_retriever tool should be an integer" returned_output = "" for example in self.few_shot_examples[:number_of_samples]: - returned_output += ( - f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n" - ) + returned_output += f"Question: {example['prompt_text']} \n" + returned_output += f"```sql\n{example['sql']}\n```\n" if returned_output == "": returned_output = "No previously asked Question/SQL pairs are available" return returned_output