Update with new VertexAIModelGarden output format
shikanime committed Dec 8, 2023
1 parent a05230a commit aae5a65
Showing 1 changed file with 71 additions and 13 deletions.
84 changes: 71 additions & 13 deletions libs/langchain/langchain/llms/vertexai.py
@@ -359,6 +359,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
    result_arg: Optional[str] = "generated_text"
    "Set result_arg to None if the output of the model is expected to be a string."
    "Otherwise, if it's a dict, provide the argument that contains the result."
    strip_prefix: bool = False
    "Whether to strip the prompt from the generated text."

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
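Together, result_arg and strip_prefix let an endpoint that wraps its answer in a dict and echoes the prompt scaffold be handled at construction time. A minimal usage sketch, assuming google-cloud-aiplatform is installed and a Model Garden endpoint is already deployed; the project and endpoint IDs are placeholders, not values from this commit:

from langchain.llms import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-gcp-project",      # placeholder GCP project ID
    endpoint_id="1234567890",      # placeholder Model Garden endpoint ID
    result_arg="generated_text",   # use None if the endpoint returns a bare string
    strip_prefix=True,             # drop the echoed "Prompt: ... Output:" scaffold
)

print(llm("What is the capital of France?"))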
@@ -429,26 +431,79 @@ def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
        return self._parse_response(response)

    def _parse_response(self, predictions: "Prediction") -> LLMResult:
        generations: List[List[Generation]] = []
        for result in predictions.predictions:
            generations.append(
                [
                    Generation(text=self._parse_prediction(prediction))
                    for prediction in result
                ]
        return self._parse_response(
            prompts,
            response,
            run_manager=run_manager,
        )

    def _parse_response(
        self,
        prompts: List[str],
        predictions: "Prediction",
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        generations: List[List[GenerationChunk]] = []
        for prompt, result in zip(prompts, predictions.predictions):
            chunks = [
                GenerationChunk(text=self._parse_prediction(prediction))
                for prediction in result
            ]
            if self.strip_prefix:
                chunks = self._strip_generation_context(prompt, chunks)
            generation = self._aggregate_response(
                chunks,
                run_manager=run_manager,
                verbose=self.verbose,
            )
            generations.append([generation])
        return LLMResult(generations=generations)
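For orientation, the response object exposes predictions as a nested sequence: one entry per prompt, and each entry is a list of prediction payloads. A standalone sketch of the same mapping, using a made-up stand-in instead of a real aiplatform Prediction:

class FakePrediction:
    # Stand-in for the Vertex AI Prediction object; the payloads are invented.
    predictions = [
        [{"generated_text": "Prompt:\nHi\nOutput:\nHello!"}],     # result for prompt "Hi"
        [{"generated_text": "Prompt:\nBye\nOutput:\nGoodbye!"}],  # result for prompt "Bye"
    ]

prompts = ["Hi", "Bye"]
for prompt, result in zip(prompts, FakePrediction().predictions):
    # With result_arg="generated_text", _parse_prediction would pull this field out.
    texts = [payload["generated_text"] for payload in result]
    print(prompt, "->", texts)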

    def _aggregate_response(
        self,
        chunks: List[GenerationChunk],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        verbose: bool = False,
    ) -> GenerationChunk:
        final_chunk: Optional[GenerationChunk] = None
        for chunk in chunks:
            if final_chunk is None:
                final_chunk = chunk
            else:
                final_chunk += chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    verbose=verbose,
                )
        if final_chunk is None:
            raise ValueError("Malformed response from VertexAIModelGarden")
        return final_chunk
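GenerationChunk supports the + operator, which is what the fold above relies on, and each chunk's text is also forwarded to the callback manager as a new token. A minimal sketch of just the aggregation step; the import path is the one used by langchain around the time of this commit and may differ in later releases:

from langchain.schema.output import GenerationChunk

chunks = [GenerationChunk(text="Hel"), GenerationChunk(text="lo"), GenerationChunk(text=" world")]

final = None
for chunk in chunks:
    # GenerationChunk.__add__ concatenates the text of the two chunks.
    final = chunk if final is None else final + chunk
print(final.text)  # -> "Hello world"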

    def _strip_generation_context(
        self,
        prompt: str,
        chunks: List[GenerationChunk],
    ) -> List[GenerationChunk]:
        source_context = self._format_generation_context(prompt)
        generation_context = ""
        consumed = 0
        for chunk in chunks:
            if len(generation_context) >= len(source_context):
                break
            generation_context += chunk.text
            consumed += 1
        # Strip only when the leading chunks reproduce the prompt scaffold exactly,
        # then drop those consumed chunks and keep the rest.
        if source_context == generation_context:
            return chunks[consumed:]
        return chunks

    def _format_generation_context(self, prompt: str) -> str:
        return "\n".join(["Prompt:", prompt, "Output:", ""])
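The scaffold built here is "Prompt:\n<prompt>\nOutput:\n". A character-level sketch of the stripping idea; the method above works on whole chunks rather than on a single string:

prompt = "What is 2 + 2?"
source_context = "\n".join(["Prompt:", prompt, "Output:", ""])

raw = source_context + "4"  # hypothetical endpoint output that echoes the scaffold
answer = raw[len(source_context):] if raw.startswith(source_context) else raw
print(repr(answer))  # -> '4'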

    def _parse_prediction(self, prediction: Any) -> str:
        if isinstance(prediction, str):
            return prediction
@@ -474,7 +529,6 @@ def _parse_prediction(self, prediction: Any) -> str:
    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
@@ -483,4 +537,8 @@ async def _agenerate(
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)
        return self._parse_response(
            prompts,
            response,
            run_manager=run_manager,
        )
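Since both the sync and async paths now return through the new _parse_response, the async client can be exercised in the usual way. A short sketch, again with placeholder project and endpoint IDs and assuming default Google Cloud credentials are available:

import asyncio

from langchain.llms import VertexAIModelGarden

async def main() -> None:
    llm = VertexAIModelGarden(
        project="my-gcp-project",  # placeholder
        endpoint_id="1234567890",  # placeholder
        strip_prefix=True,
    )
    result = await llm.agenerate(["Tell me a joke."])
    print(result.generations[0][0].text)

asyncio.run(main())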
