Skip to content

Commit

Permalink
DH-5638/add langsmith metadata params (#441)
Browse files Browse the repository at this point in the history
* DH-5638/add langsmith metadata params

* DH-5638/change tests
  • Loading branch information
MohammadrezaPourreza authored Mar 22, 2024
1 parent 3dbd483 commit cf88a1b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 5 deletions.
20 changes: 18 additions & 2 deletions dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def update_error(self, sql_generation: SQLGeneration, error: str) -> SQLGenerati
sql_generation.error = error
return self.sql_generation_repository.update(sql_generation)

def generate_response_with_timeout(self, sql_generator, user_prompt, db_connection):
def generate_response_with_timeout(
self, sql_generator, user_prompt, db_connection, metadata=None
):
return sql_generator.generate_response(
user_prompt=user_prompt, database_connection=db_connection
user_prompt=user_prompt,
database_connection=db_connection,
metadata=metadata,
)

def update_the_initial_sql_generation(
Expand All @@ -70,6 +74,11 @@ def create(
else LLMConfig(),
metadata=sql_generation_request.metadata,
)
langsmith_metadata = (
sql_generation_request.metadata.get("lang_smith", {})
if sql_generation_request.metadata
else {}
)
self.sql_generation_repository.insert(initial_sql_generation)
prompt_repository = PromptRepository(self.storage)
prompt = prompt_repository.find_by_id(prompt_id)
Expand Down Expand Up @@ -134,6 +143,7 @@ def create(
sql_generator,
prompt,
db_connection,
metadata=langsmith_metadata,
)
try:
sql_generation = future.result(
Expand Down Expand Up @@ -179,6 +189,11 @@ def start_streaming(
else LLMConfig(),
metadata=sql_generation_request.metadata,
)
langsmith_metadata = (
sql_generation_request.metadata.get("lang_smith", {})
if sql_generation_request.metadata
else {}
)
self.sql_generation_repository.insert(initial_sql_generation)
prompt_repository = PromptRepository(self.storage)
prompt = prompt_repository.find_by_id(prompt_id)
Expand Down Expand Up @@ -225,6 +240,7 @@ def start_streaming(
database_connection=db_connection,
response=initial_sql_generation,
queue=queue,
metadata=langsmith_metadata,
)
except Exception as e:
self.update_error(initial_sql_generation, str(e))
Expand Down
7 changes: 6 additions & 1 deletion dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def generate_response(
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None,
metadata: dict = None,
) -> SQLGeneration:
"""Generates a response to a user question."""
pass
Expand All @@ -160,10 +161,13 @@ def stream_agent_steps( # noqa: C901
response: SQLGeneration,
sql_generation_repository: SQLGenerationRepository,
queue: Queue,
metadata: dict = None,
):
try:
with get_openai_callback() as cb:
for chunk in agent_executor.stream({"input": question}):
for chunk in agent_executor.stream(
{"input": question}, {"metadata": metadata}
):
if "actions" in chunk:
for message in chunk["messages"]:
queue.put(message.content + "\n")
Expand Down Expand Up @@ -209,6 +213,7 @@ def stream_response(
database_connection: DatabaseConnection,
response: SQLGeneration,
queue: Queue,
metadata: dict = None,
):
"""Streams a response to a user question."""
pass
7 changes: 6 additions & 1 deletion dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def generate_response(
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None, # noqa: ARG002
metadata: dict = None,
) -> SQLGeneration:
"""
generate_response generates a response to a user question using a Finetuning model.
Expand Down Expand Up @@ -564,7 +565,9 @@ def generate_response(
agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
with get_openai_callback() as cb:
try:
result = agent_executor.invoke({"input": user_prompt.text})
result = agent_executor.invoke(
{"input": user_prompt.text}, {"metadata": metadata}
)
result = self.check_for_time_out_or_tool_limit(result)
except SQLInjectionError as e:
raise SQLInjectionError(e) from e
Expand Down Expand Up @@ -607,6 +610,7 @@ def stream_response(
database_connection: DatabaseConnection,
response: SQLGeneration,
queue: Queue,
metadata: dict = None,
):
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
Expand Down Expand Up @@ -669,6 +673,7 @@ def stream_response(
response,
sql_generation_repository,
queue,
metadata,
),
)
thread.start()
7 changes: 6 additions & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ def generate_response(
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None,
metadata: dict = None,
) -> SQLGeneration:
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
Expand Down Expand Up @@ -710,7 +711,9 @@ def generate_response(
agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
with get_openai_callback() as cb:
try:
result = agent_executor.invoke({"input": user_prompt.text})
result = agent_executor.invoke(
{"input": user_prompt.text}, {"metadata": metadata}
)
result = self.check_for_time_out_or_tool_limit(result)
except SQLInjectionError as e:
raise SQLInjectionError(e) from e
Expand Down Expand Up @@ -756,6 +759,7 @@ def stream_response(
database_connection: DatabaseConnection,
response: SQLGeneration,
queue: Queue,
metadata: dict = None,
):
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
Expand Down Expand Up @@ -815,6 +819,7 @@ def stream_response(
response,
sql_generation_repository,
queue,
metadata,
),
)
thread.start()
1 change: 1 addition & 0 deletions dataherald/tests/sql_generator/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_response(
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None, # noqa: ARG002
metadata: dict = None, # noqa: ARG002
) -> SQLGeneration:
return SQLGeneration(
question_id="651f2d76275132d5b65175eb",
Expand Down

0 comments on commit cf88a1b

Please sign in to comment.