Skip to content

Commit

Permalink
DH-5738/fixing the malformed sql queries in intermediate steps
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Apr 17, 2024
1 parent cf1a2e7 commit 34d5944
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def remove_markdown(self, query: str) -> str:
return matches[0].strip()
return query

def format_sql_query_intermediate_steps(self, step: str) -> str:
pattern = r"```sql(.*?)```"
return re.sub(pattern, self.format_sql_query, step)

@classmethod
def get_upper_bound_limit(cls) -> int:
top_k = os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", None)
Expand Down Expand Up @@ -170,12 +174,19 @@ def stream_agent_steps( # noqa: C901
):
if "actions" in chunk:
for message in chunk["messages"]:
queue.put(message.content + "\n")
queue.put(
self.format_sql_query_intermediate_steps(
message.content
)
+ "\n"
)
elif "steps" in chunk:
for step in chunk["steps"]:
queue.put(f"\n**Observation:**\n {step.observation}\n")
elif "output" in chunk:
queue.put(f'\n**Final Answer:**\n {chunk["output"]}')
queue.put(
f'\n**Final Answer:**\n {self.format_sql_query_intermediate_steps(chunk["output"])}'
)
if "```sql" in chunk["output"]:
response.sql = replace_unprocessable_characters(
self.remove_markdown(chunk["output"])
Expand Down

0 comments on commit 34d5944

Please sign in to comment.