Skip to content

Commit

Permalink
Update with new VertexAIModelGarden output format
Browse files Browse the repository at this point in the history
  • Loading branch information
shikanime committed Dec 11, 2023
1 parent c0f4b95 commit cf42493
Showing 1 changed file with 74 additions and 11 deletions.
85 changes: 74 additions & 11 deletions libs/langchain/langchain/llms/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
result_arg: Optional[str] = "generated_text"
"Set result_arg to None if output of the model is expected to be a string."
"Otherwise, if it's a dict, provided an argument that contains the result."
strip_prefix: bool = False
"Whether to strip the prompt from the generated text."

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -441,19 +443,76 @@ def _generate(
"""Run the LLM on the given prompt and input."""
instances = self._prepare_request(prompts, **kwargs)
response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
return self._parse_response(response)

def _parse_response(self, predictions: "Prediction") -> LLMResult:
generations: List[List[Generation]] = []
for result in predictions.predictions:
generations.append(
[
Generation(text=self._parse_prediction(prediction))
for prediction in result
]
return self._parse_response(
prompts,
response,
run_manager=run_manager,
)

def _parse_response(
    self,
    prompts: List[str],
    predictions: "Prediction",
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> LLMResult:
    """Convert an endpoint prediction response into an ``LLMResult``.

    Prompts are paired positionally with ``predictions.predictions``. Each
    entry of a prediction is parsed into a ``GenerationChunk``; the chunks
    are optionally stripped of the echoed prompt context (when
    ``self.strip_prefix`` is set) and then folded into a single generation
    per prompt.

    Args:
        prompts: The prompts that were sent, in request order.
        predictions: The response object; only its ``predictions``
            attribute is read here.
        run_manager: Optional callback manager, forwarded to the
            aggregation step so it can report each chunk.

    Returns:
        An ``LLMResult`` holding exactly one generation per paired prompt.

    Raises:
        ValueError: Propagated from ``_aggregate_response`` when a
            prediction yields no chunks at all.
    """
    generations: List[List[GenerationChunk]] = []
    # NOTE: zip() stops at the shorter input, so surplus prompts or
    # predictions are silently ignored.
    for prompt, result in zip(prompts, predictions.predictions):
        chunks = [
            GenerationChunk(text=self._parse_prediction(prediction))
            for prediction in result
        ]
        if self.strip_prefix:
            chunks = self._strip_generation_context(prompt, chunks)
        generation = self._aggregate_response(
            chunks,
            run_manager=run_manager,
            verbose=self.verbose,
        )
        generations.append([generation])
    return LLMResult(generations=generations)

def _aggregate_response(
    self,
    chunks: List[GenerationChunk],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    verbose: bool = False,
) -> GenerationChunk:
    """Fold a list of chunks into one ``GenerationChunk``.

    Chunks are combined left to right with ``+``. When a ``run_manager``
    is supplied, ``on_llm_new_token`` is called once per chunk, in order,
    with that chunk's text.

    Args:
        chunks: The chunks to combine.
        run_manager: Optional callback manager notified of every chunk.
        verbose: Forwarded to ``on_llm_new_token``.

    Returns:
        The combination of all chunks.

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    final_chunk: Optional[GenerationChunk] = None
    for chunk in chunks:
        if final_chunk is None:
            final_chunk = chunk
        else:
            final_chunk += chunk
        if run_manager:
            run_manager.on_llm_new_token(
                chunk.text,
                verbose=verbose,
            )
    if final_chunk is None:
        raise ValueError("Malformed response from VertexAIModelGarden")
    return final_chunk

def _strip_generation_context(
    self,
    prompt: str,
    chunks: List[GenerationChunk],
) -> List[GenerationChunk]:
    """Remove the echoed prompt context from the front of ``chunks``.

    The context is whatever ``_format_generation_context(prompt)`` returns.
    It is matched against the concatenated chunk texts, so it may span
    several chunks and may end in the middle of one.

    Args:
        prompt: The prompt the chunks were generated for.
        chunks: The parsed chunks, in order.

    Returns:
        ``chunks`` unchanged when their combined text does not start with
        the full context. Otherwise a new list whose first element carries
        the remainder (possibly empty) of the chunk in which the context
        ended, followed by all later chunks untouched. The result is never
        empty when a match occurs, so it can still be aggregated.
    """
    context = self._format_generation_context(prompt)
    matched = 0  # number of leading context characters consumed so far
    for index, chunk in enumerate(chunks):
        text = chunk.text
        remaining = len(context) - matched
        if len(text) < remaining:
            # This chunk lies entirely inside the context: it must match
            # the next slice of it, otherwise there is no prefix to strip.
            if not context.startswith(text, matched):
                return chunks
            matched += len(text)
            continue
        # The context ends inside (or exactly at the end of) this chunk.
        if not text.startswith(context[matched:]):
            return chunks
        tail = GenerationChunk(
            text=text[remaining:],
            generation_info=chunk.generation_info,
        )
        return [tail, *chunks[index + 1 :]]
    # Ran out of chunks before the whole context was seen: not a prefix.
    return chunks

def _format_generation_context(self, prompt: str) -> str:
return "\n".join(["Prompt:", prompt, "Output:", prompt])

def _parse_prediction(self, prediction: Any) -> str:
if isinstance(prediction, str):
return prediction
Expand Down Expand Up @@ -488,4 +547,8 @@ async def _agenerate(
response = await self.async_client.predict(
endpoint=self.endpoint_path, instances=instances
)
return self._parse_response(response)
return self._parse_response(
prompts,
response,
run_manager=run_manager,
)

0 comments on commit cf42493

Please sign in to comment.