Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DATA-2038/fixing the fallback and confidence score #453

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions dataherald/eval/simple_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
You are a {dialect} expert.
Given a question, a SQL query, and the database schema, analyze the correctness of the SQL query and provide a score.
Score indicates how correctly and accurately SQL query answers the question.
Note that the score should be between 0 and 100. Higher scores means the SQL Query is more accurate.
Note that the score should be between 0 and {MAX_CONFIDENCE}. Higher scores means the SQL Query is more accurate.
Double check the SQL query for the common mistakes, including:
- For columns that can contain NULL values, NULL values should be filtered out by using the IS NOT NULL operator in the WHERE condition
- when intention of the question is to include all rows from both sets, including duplicates, using UNION ALL is better than UNION
Expand Down Expand Up @@ -85,13 +85,34 @@ def answer_parser(self, answer: str) -> int:
output = int(numbers[-1])
return output

def create_sql_results(self, result: Any) -> list:
rows = []
if result:
for row in result:
modified_row = {}
for key, value in zip(row.keys(), row, strict=True):
if type(value) in [
date,
datetime,
]: # Check if the value is an instance of datetime.date
modified_row[key] = str(value)
elif (
type(value) is Decimal
): # Check if the value is an instance of decimal.Decimal
modified_row[key] = float(value)
else:
modified_row[key] = value
rows.append(modified_row)
return rows

@override
def evaluate(
self,
user_prompt: Prompt,
sql_generation: SQLGeneration,
database_connection: DatabaseConnection,
) -> Evaluation:
max_confidence = 100
database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
f"(Simple evaluator) Generating score for the question/sql pair: {str(user_prompt.text)}/ {str(sql_generation.sql)}"
Expand Down Expand Up @@ -142,33 +163,24 @@ def evaluate(
with database._engine.connect() as connection:
execution = connection.execute(text(query))
result = execution.fetchmany(TOP_K)
rows = []
for row in result:
modified_row = {}
for key, value in zip(row.keys(), row, strict=True):
if type(value) in [
date,
datetime,
]: # Check if the value is an instance of datetime.date
modified_row[key] = str(value)
elif (
type(value) is Decimal
): # Check if the value is an instance of decimal.Decimal
modified_row[key] = float(value)
else:
modified_row[key] = value
rows.append(modified_row)
rows = self.create_sql_results(result)

except SQLInjectionError as e:
raise SQLInjectionError(
"Sensitive SQL keyword detected in the query."
) from e
if not rows:
logger.info(
f"(Simple evaluator) SQL query: {sql} returned no results. max confidence is 70"
)
max_confidence = 70
answer = chain.invoke(
{
"dialect": dialect,
"question": user_question,
"SQL": sql,
"SQL_result": "\n".join([str(row) for row in rows]),
"MAX_CONFIDENCE": str(max_confidence),
"schema": schema,
}
)["text"]
Expand Down
6 changes: 3 additions & 3 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def extract_query_from_intermediate_steps(
sql_query = self.remove_markdown(action.tool_input)
if sql_query == "":
for step in intermediate_steps:
action = step[0]
if "SELECT" in action.tool_input.upper():
sql_query = self.remove_markdown(action.tool_input)
thought = str(step[0].log).split("Action:")[0]
if "```sql" in thought:
sql_query = self.remove_markdown(thought)
if not sql_query.upper().strip().startswith("SELECT"):
sql_query = ""
return sql_query
Expand Down
Loading