diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index e3038812fd9be..7747877dc91e2 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -364,6 +364,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     result_arg: Optional[str] = "generated_text"
     "Set result_arg to None if output of the model is expected to be a string."
     "Otherwise, if it's a dict, provided an argument that contains the result."
+    strip_prefix: bool = False
+    "Whether to strip the prompt from the generated text."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -441,19 +443,76 @@ def _generate(
         """Run the LLM on the given prompt and input."""
         instances = self._prepare_request(prompts, **kwargs)
         response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
-        return self._parse_response(response)
-
-    def _parse_response(self, predictions: "Prediction") -> LLMResult:
-        generations: List[List[Generation]] = []
-        for result in predictions.predictions:
-            generations.append(
-                [
-                    Generation(text=self._parse_prediction(prediction))
-                    for prediction in result
-                ]
+        return self._parse_response(
+            prompts,
+            response,
+            run_manager=run_manager,
+        )
+
+    def _parse_response(
+        self,
+        prompts: List[str],
+        predictions: "Prediction",
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> LLMResult:
+        generations: List[List[GenerationChunk]] = []
+        for prompt, result in zip(prompts, predictions.predictions):
+            chunks = [
+                GenerationChunk(text=self._parse_prediction(prediction))
+                for prediction in result
+            ]
+            if self.strip_prefix:
+                chunks = self._strip_generation_context(prompt, chunks)
+            generation = self._aggregate_response(
+                chunks,
+                run_manager=run_manager,
+                verbose=self.verbose,
             )
+            generations.append([generation])
         return LLMResult(generations=generations)
 
+    def _aggregate_response(
+        self,
+        chunks: List[GenerationChunk],
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+    ) -> GenerationChunk:
+        final_chunk: Optional[GenerationChunk] = None
+        for chunk in chunks:
+            if final_chunk is None:
+                final_chunk = chunk
+            else:
+                final_chunk += chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=verbose,
+                )
+        if final_chunk is None:
+            raise ValueError("Malformed response from VertexAIModelGarden")
+        return final_chunk
+
+    def _strip_generation_context(
+        self,
+        prompt: str,
+        chunks: List[GenerationChunk],
+    ) -> List[GenerationChunk]:
+        context = self._format_generation_context(prompt)
+        chunk_cursor = 0
+        context_cursor = 0
+        while chunk_cursor < len(chunks) and context_cursor < len(context):
+            chunk = chunks[chunk_cursor]
+            for c in chunk.text:
+                if c == context[context_cursor]:
+                    context_cursor += 1
+                else:
+                    break
+            chunk_cursor += 1
+        return chunks[chunk_cursor:] if chunk_cursor == context_cursor else chunks
+
+    def _format_generation_context(self, prompt: str) -> str:
+        return "\n".join(["Prompt:", prompt, "Output:", prompt])
+
     def _parse_prediction(self, prediction: Any) -> str:
         if isinstance(prediction, str):
             return prediction
@@ -488,4 +547,8 @@ async def _agenerate(
         response = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=instances
         )
-        return self._parse_response(response)
+        return self._parse_response(
+            prompts,
+            response,
+            run_manager=run_manager,
+        )
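
A minimal usage sketch for the new `strip_prefix` flag, assuming this branch is installed; the GCP project and endpoint values below are placeholders, and `result_arg="generated_text"` is just the existing default shown in the diff:

# Minimal usage sketch (assumes this branch of langchain is installed;
# the project and endpoint_id values are placeholders).
from langchain.llms import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-gcp-project",      # placeholder GCP project
    endpoint_id="1234567890",      # placeholder Model Garden endpoint ID
    result_arg="generated_text",   # default: key to read from dict predictions
    strip_prefix=True,             # new flag: drop the echoed prompt context
)

# Some Model Garden deployments echo the prompt (as "Prompt:\n...\nOutput:\n...")
# at the start of the generated text; with strip_prefix=True that echoed context
# is removed before the LLMResult is returned.
print(llm("What is the meaning of life?"))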