From 34d59441f13307f07a66e988a6f1ef075234d22f Mon Sep 17 00:00:00 2001
From: mohammadrezapourreza
Date: Wed, 17 Apr 2024 09:12:12 -0400
Subject: [PATCH] DH-5738/fixing the malformed sql queries in intermediate steps

---
Review notes (this section is ignored by `git am`):
* Line breaks and indentation were restored from a whitespace-collapsed copy
  of this patch; confirm the context lines against the target file before
  applying.
* format_sql_query_intermediate_steps now unwraps the Match object that
  re.sub passes to its callback, matches multi-line fenced blocks
  (re.DOTALL) and keeps the fence markers in the streamed text.
  Assumes format_sql_query takes and returns a str; TODO confirm.
* Hunk line counts and the diffstat were updated for the extra added lines;
  the post-image blob id on the "index" line still refers to the original
  revision of this patch.

 dataherald/sql_generator/__init__.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py
index b270d5a7..94f8b7d5 100644
--- a/dataherald/sql_generator/__init__.py
+++ b/dataherald/sql_generator/__init__.py
@@ -62,6 +62,18 @@ def remove_markdown(self, query: str) -> str:
             return matches[0].strip()
         return query
 
+    def format_sql_query_intermediate_steps(self, step: str) -> str:
+        """Reformat every ```sql fenced block found in an agent step."""
+        pattern = r"```sql(.*?)```"
+
+        def reformat(match: re.Match) -> str:
+            # re.sub hands the callback a Match object, not the SQL text.
+            formatted_sql = self.format_sql_query(match.group(1))
+            return "```sql\n" + formatted_sql + "\n```"
+
+        # DOTALL lets "." span the newlines inside a multi-line fenced block.
+        return re.sub(pattern, reformat, step, flags=re.DOTALL)
+
     @classmethod
     def get_upper_bound_limit(cls) -> int:
         top_k = os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", None)
@@ -170,12 +182,19 @@ def stream_agent_steps( # noqa: C901
                 ):
                     if "actions" in chunk:
                         for message in chunk["messages"]:
-                            queue.put(message.content + "\n")
+                            queue.put(
+                                self.format_sql_query_intermediate_steps(
+                                    message.content
+                                )
+                                + "\n"
+                            )
                     elif "steps" in chunk:
                         for step in chunk["steps"]:
                             queue.put(f"\n**Observation:**\n {step.observation}\n")
                     elif "output" in chunk:
-                        queue.put(f'\n**Final Answer:**\n {chunk["output"]}')
+                        queue.put(
+                            f'\n**Final Answer:**\n {self.format_sql_query_intermediate_steps(chunk["output"])}'
+                        )
                         if "```sql" in chunk["output"]:
                             response.sql = replace_unprocessable_characters(
                                 self.remove_markdown(chunk["output"])