From aae5a65375eade557016912233f9c43d3f7665d0 Mon Sep 17 00:00:00 2001
From: William Phetsinorath <william.phetsinorath@shikanime.studio>
Date: Fri, 8 Dec 2023 15:00:47 +0100
Subject: [PATCH] Update with new VertexAIModelGarden output format

---
 libs/langchain/langchain/llms/vertexai.py | 84 +++++++++++++++++++----
 1 file changed, 71 insertions(+), 13 deletions(-)
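
Example usage (a minimal sketch; the project and endpoint IDs are placeholders, and the
endpoint is assumed to serve a model that echoes the request back as
"Prompt:\n<prompt>\nOutput:\n<generated text>"):

    from langchain.llms import VertexAIModelGarden

    llm = VertexAIModelGarden(
        project="my-gcp-project",     # placeholder project ID
        endpoint_id="1234567890",     # placeholder Model Garden endpoint ID
        result_arg="generated_text",  # default; set to None for plain-string outputs
        strip_prefix=True,            # new flag: drop the echoed prompt from the output
    )
    print(llm("What is the meaning of life?"))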

diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index 0dedeaf0df853..6ccbf5418891b 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -359,6 +359,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     result_arg: Optional[str] = "generated_text"
     "Set result_arg to None if output of the model is expected to be a string."
     "Otherwise, if it's a dict, provided an argument that contains the result."
+    strip_prefix: bool = False
+    "Whether to strip the prompt from the generated text."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -429,26 +431,79 @@ def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
     def _generate(
         self,
         prompts: List[str],
-        stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         instances = self._prepare_request(prompts, **kwargs)
         response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
-        return self._parse_response(response)
-
-    def _parse_response(self, predictions: "Prediction") -> LLMResult:
-        generations: List[List[Generation]] = []
-        for result in predictions.predictions:
-            generations.append(
-                [
-                    Generation(text=self._parse_prediction(prediction))
-                    for prediction in result
-                ]
+        return self._parse_response(
+            prompts,
+            response,
+            run_manager=run_manager,
+        )
+
+    def _parse_response(
+        self,
+        prompts: List[str],
+        predictions: "Prediction",
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> LLMResult:
+        generations: List[List[GenerationChunk]] = []
+        for prompt, result in zip(prompts, predictions.predictions):
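+            # Each endpoint result may contain several predictions; wrap each one as
+            # a GenerationChunk so they can be concatenated and optionally stripped.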
+            chunks = [
+                GenerationChunk(text=self._parse_prediction(prediction))
+                for prediction in result
+            ]
+            if self.strip_prefix:
+                chunks = self._strip_generation_context(prompt, chunks)
+            generation = self._aggregate_response(
+                chunks,
+                run_manager=run_manager,
+                verbose=self.verbose,
             )
+            generations.append([generation])
         return LLMResult(generations=generations)
 
+    def _aggregate_response(
+        self,
+        chunks: List[GenerationChunk],
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+    ) -> GenerationChunk:
+        final_chunk: Optional[GenerationChunk] = None
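+        # Fold the chunks into a single generation, reporting each piece to the
+        # callback manager via on_llm_new_token as it is merged.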
+        for chunk in chunks:
+            if final_chunk is None:
+                final_chunk = chunk
+            else:
+                final_chunk += chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=verbose,
+                )
+        if final_chunk is None:
+            raise ValueError("Malformed response from VertexAIModelGarden")
+        return final_chunk
+
+    def _strip_generation_context(
+        self,
+        prompt: str,
+        chunks: List[GenerationChunk],
+    ) -> List[GenerationChunk]:
+        source_context = self._format_generation_context(prompt)
+        generation_context = ""
+        consumed = 0
+        for chunk in chunks:
+            if len(generation_context) >= len(source_context):
+                break
+            generation_context += chunk.text
+            consumed += 1
+        if source_context == generation_context:
+            # Drop only the chunks that reproduced the echoed prompt context.
+            return chunks[consumed:]
+        return chunks
+
+    def _format_generation_context(self, prompt: str) -> str:
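+        # The updated Model Garden output format echoes the request back as
+        # "Prompt:\n<prompt>\nOutput:\n" ahead of the generated text; rebuild that
+        # preamble here so _strip_generation_context can match and remove it.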
+        return "\n".join(["Prompt:", prompt, "Output:", ""])
+
     def _parse_prediction(self, prediction: Any) -> str:
         if isinstance(prediction, str):
             return prediction
@@ -474,7 +529,6 @@ def _parse_prediction(self, prediction: Any) -> str:
     async def _agenerate(
         self,
         prompts: List[str],
-        stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -483,4 +537,8 @@ async def _agenerate(
         response = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=instances
         )
-        return self._parse_response(response)
+        return self._parse_response(
+            prompts,
+            response,
+            run_manager=run_manager,
+        )