From aae5a65375eade557016912233f9c43d3f7665d0 Mon Sep 17 00:00:00 2001
From: William Phetsinorath <william.phetsinorath@shikanime.studio>
Date: Fri, 8 Dec 2023 15:00:47 +0100
Subject: [PATCH] Update with new VertexAIModelGarden output format

---
 libs/langchain/langchain/llms/vertexai.py | 84 +++++++++++++++++++----
 1 file changed, 71 insertions(+), 13 deletions(-)

diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index 0dedeaf0df853..6ccbf5418891b 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -359,6 +359,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     result_arg: Optional[str] = "generated_text"
     "Set result_arg to None if output of the model is expected to be a string."
     "Otherwise, if it's a dict, provided an argument that contains the result."
+    strip_prefix: bool = False
+    "Whether to strip the prompt from the generated text."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -429,26 +431,79 @@ def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
     def _generate(
         self,
         prompts: List[str],
-        stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         instances = self._prepare_request(prompts, **kwargs)
         response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
-        return self._parse_response(response)
-
-    def _parse_response(self, predictions: "Prediction") -> LLMResult:
-        generations: List[List[Generation]] = []
-        for result in predictions.predictions:
-            generations.append(
-                [
-                    Generation(text=self._parse_prediction(prediction))
-                    for prediction in result
-                ]
+        return self._parse_response(
+            prompts,
+            response,
+            run_manager=run_manager,
+        )
+
+    def _parse_response(
+        self,
+        prompts: List[str],
+        predictions: "Prediction",
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> List[List[GenerationChunk]]:
+        generations: List[List[GenerationChunk]] = []
+        for prompt, result in zip(prompts, predictions.predictions):
+            chunks = [
+                GenerationChunk(text=self._parse_prediction(prediction))
+                for prediction in result
+            ]
+            if self.strip_prefix:
+                chunks = self._strip_generation_context(prompt, chunks)
+            generation = self._aggregate_response(
+                chunks,
+                run_manager=run_manager,
+                verbose=self.verbose,
             )
+            generations.append([generation])
         return LLMResult(generations=generations)
 
+    def _aggregate_response(
+        self,
+        chunks: List[Generation],
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+    ) -> GenerationChunk:
+        final_chunk: Optional[GenerationChunk] = None
+        for chunk in chunks:
+            if final_chunk is None:
+                final_chunk = chunk
+            else:
+                final_chunk += chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=verbose,
+                )
+        if final_chunk is None:
+            raise ValueError("Malformed response from VertexAIModelGarden")
+        return final_chunk
+
+    def _strip_generation_context(
+        self,
+        prompt: str,
+        chunks: List[GenerationChunk],
+    ) -> List[GenerationChunk]:
+        source_context = self._format_generation_context(prompt)
+        generation_context = ""
+        for chunk in chunks:
+            if len(generation_context) >= len(source_context):
+                break
+            generation_context += chunk.text
+        if source_context == generation_context:
+            return chunks[len(generation_context):]
+        return chunks
+
+    def _format_generation_context(self, prompt: str) -> str:
+        return "\n".join(["Prompt:", prompt, "Output:", ""])
+
     def _parse_prediction(self, prediction: Any) -> str:
         if isinstance(prediction, str):
             return prediction
@@ -474,7 +529,6 @@ def _parse_prediction(self, prediction: Any) -> str:
     async def _agenerate(
         self,
         prompts: List[str],
-        stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -483,4 +537,8 @@ async def _agenerate(
         response = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=instances
         )
-        return self._parse_response(response)
+        return self._parse_response(
+            prompts,
+            response,
+            run_manager=run_manager,
+        )