diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 94f8b7d5..03597f70 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -64,7 +64,13 @@ def remove_markdown(self, query: str) -> str: def format_sql_query_intermediate_steps(self, step: str) -> str: pattern = r"```sql(.*?)```" - return re.sub(pattern, self.format_sql_query, step) + + def formatter(match): + original_sql = match.group(1) + formatted_sql = self.format_sql_query(original_sql) + return "```sql\n" + formatted_sql + "\n```" + + return re.sub(pattern, formatter, step, flags=re.DOTALL) @classmethod def get_upper_bound_limit(cls) -> int: diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 3a1ecb28..899e6e87 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -150,7 +150,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): name = "SqlDbQuery" description = """ - Input: A SQL query between ```sql and ``` tags. + Input: A well-formed multi-line SQL query between ```sql and ``` tags. Output: Result from the database or an error message if the query is incorrect. If an error occurs, rewrite the query and retry. Use this tool to execute SQL queries.