diff --git a/dataherald/services/sql_generations.py b/dataherald/services/sql_generations.py index e766720c..b1890443 100644 --- a/dataherald/services/sql_generations.py +++ b/dataherald/services/sql_generations.py @@ -43,9 +43,13 @@ def update_error(self, sql_generation: SQLGeneration, error: str) -> SQLGenerati sql_generation.error = error return self.sql_generation_repository.update(sql_generation) - def generate_response_with_timeout(self, sql_generator, user_prompt, db_connection): + def generate_response_with_timeout( + self, sql_generator, user_prompt, db_connection, metadata=None + ): return sql_generator.generate_response( - user_prompt=user_prompt, database_connection=db_connection + user_prompt=user_prompt, + database_connection=db_connection, + metadata=metadata, ) def update_the_initial_sql_generation( @@ -70,6 +74,11 @@ def create( else LLMConfig(), metadata=sql_generation_request.metadata, ) + langsmith_metadata = ( + sql_generation_request.metadata.get("lang_smith", {}) + if sql_generation_request.metadata + else {} + ) self.sql_generation_repository.insert(initial_sql_generation) prompt_repository = PromptRepository(self.storage) prompt = prompt_repository.find_by_id(prompt_id) @@ -134,6 +143,7 @@ def create( sql_generator, prompt, db_connection, + metadata=langsmith_metadata, ) try: sql_generation = future.result( @@ -179,6 +189,11 @@ def start_streaming( else LLMConfig(), metadata=sql_generation_request.metadata, ) + langsmith_metadata = ( + sql_generation_request.metadata.get("lang_smith", {}) + if sql_generation_request.metadata + else {} + ) self.sql_generation_repository.insert(initial_sql_generation) prompt_repository = PromptRepository(self.storage) prompt = prompt_repository.find_by_id(prompt_id) @@ -225,6 +240,7 @@ def start_streaming( database_connection=db_connection, response=initial_sql_generation, queue=queue, + metadata=langsmith_metadata, ) except Exception as e: self.update_error(initial_sql_generation, str(e)) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 2f86e37d..1377272a 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -149,6 +149,7 @@ def generate_response( user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, + metadata: dict = None, ) -> SQLGeneration: """Generates a response to a user question.""" pass @@ -160,10 +161,13 @@ def stream_agent_steps( # noqa: C901 response: SQLGeneration, sql_generation_repository: SQLGenerationRepository, queue: Queue, + metadata: dict = None, ): try: with get_openai_callback() as cb: - for chunk in agent_executor.stream({"input": question}): + for chunk in agent_executor.stream( + {"input": question}, {"metadata": metadata} + ): if "actions" in chunk: for message in chunk["messages"]: queue.put(message.content + "\n") @@ -209,6 +213,7 @@ def stream_response( database_connection: DatabaseConnection, response: SQLGeneration, queue: Queue, + metadata: dict = None, ): """Streams a response to a user question.""" pass diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index 3de833fa..b3f00160 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -492,6 +492,7 @@ def generate_response( user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, # noqa: ARG002 + metadata: dict = None, ) -> SQLGeneration: """ generate_response generates a response to a user question using a Finetuning model. @@ -564,7 +565,9 @@ def generate_response( agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE with get_openai_callback() as cb: try: - result = agent_executor.invoke({"input": user_prompt.text}) + result = agent_executor.invoke( + {"input": user_prompt.text}, {"metadata": metadata} + ) result = self.check_for_time_out_or_tool_limit(result) except SQLInjectionError as e: raise SQLInjectionError(e) from e @@ -607,6 +610,7 @@ def stream_response( database_connection: DatabaseConnection, response: SQLGeneration, queue: Queue, + metadata: dict = None, ): context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) @@ -669,6 +673,7 @@ def stream_response( response, sql_generation_repository, queue, + metadata, ), ) thread.start() diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 735fb52b..0bf499c4 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -654,6 +654,7 @@ def generate_response( user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, + metadata: dict = None, ) -> SQLGeneration: context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) @@ -710,7 +711,9 @@ def generate_response( agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE with get_openai_callback() as cb: try: - result = agent_executor.invoke({"input": user_prompt.text}) + result = agent_executor.invoke( + {"input": user_prompt.text}, {"metadata": metadata} + ) result = self.check_for_time_out_or_tool_limit(result) except SQLInjectionError as e: raise SQLInjectionError(e) from e @@ -756,6 +759,7 @@ def stream_response( database_connection: DatabaseConnection, response: SQLGeneration, queue: Queue, + metadata: dict = None, ): context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) @@ -815,6 +819,7 @@ def stream_response( response, sql_generation_repository, queue, + metadata, ), ) thread.start() diff --git a/dataherald/tests/sql_generator/test_generator.py b/dataherald/tests/sql_generator/test_generator.py index 56c91695..f3e4ce6a 100644 --- a/dataherald/tests/sql_generator/test_generator.py +++ b/dataherald/tests/sql_generator/test_generator.py @@ -18,6 +18,7 @@ def generate_response( user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, # noqa: ARG002 + metadata: dict = None, # noqa: ARG002 ) -> SQLGeneration: return SQLGeneration( question_id="651f2d76275132d5b65175eb",