Skip to content

Commit

Permalink
Fixing the multiline sql queries
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Apr 18, 2024
1 parent 47b957f commit 8f6044e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def remove_markdown(self, query: str) -> str:

def format_sql_query_intermediate_steps(self, step: str) -> str:
pattern = r"```sql(.*?)```"
return re.sub(pattern, self.format_sql_query, step)

def formatter(match):
original_sql = match.group(1)
formatted_sql = self.format_sql_query(original_sql)
return "```sql\n" + formatted_sql + "\n```"

return re.sub(pattern, formatter, step, flags=re.DOTALL)

@classmethod
def get_upper_bound_limit(cls) -> int:
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: A SQL query between ```sql and ``` tags.
Input: A well-formed multi-line SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
If an error occurs, rewrite the query and retry.
Use this tool to execute SQL queries.
Expand Down

0 comments on commit 8f6044e

Please sign in to comment.