Skip to content

Commit

Permalink
fixes after review
Browse files (browse the repository at this point in the history)
  • Loading branch information
lkuligin committed Jan 3, 2024
1 parent 69905a3 commit a0659af
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 266 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse

import requests
Expand All @@ -26,6 +26,20 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
InputOutputTextPair,
)
from vertexai.preview.generative_models import ( # type: ignore
Content,
GenerativeModel,
Image,
Part,
)

from langchain_google_vertexai.llms import (
_VertexAICommon,
Expand All @@ -34,26 +48,16 @@
)
from langchain_google_vertexai.utils import (
load_image_from_gcs,
raise_vertex_import_error,
)

if TYPE_CHECKING:
from vertexai.language_models import (
ChatMessage,
ChatSession,
CodeChatSession,
InputOutputTextPair,
)
from vertexai.preview.generative_models import Content

logger = logging.getLogger(__name__)


@dataclass
class _ChatHistory:
"""Represents a context and a history of messages."""

history: List["ChatMessage"] = field(default_factory=list)
history: List[ChatMessage] = field(default_factory=list)
context: Optional[str] = None


Expand All @@ -68,7 +72,6 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
ValueError: If a sequence of message has a SystemMessage not at the
first place.
"""
from vertexai.language_models import ChatMessage

vertex_messages, context = [], None
for i, message in enumerate(history):
Expand Down Expand Up @@ -100,9 +103,7 @@ def _is_url(s: str) -> bool:

def _parse_chat_history_gemini(
history: List[BaseMessage], project: Optional[str]
) -> List["Content"]:
from vertexai.preview.generative_models import Content, Image, Part

) -> List[Content]:
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
if isinstance(part, str):
return Part.from_text(part)
Expand All @@ -120,8 +121,9 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Part:
elif path.startswith("data:image/"):
# extract base64 component from image uri
try:
encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(
1
regexp = r"data:image/\w{2,4};base64,(.*)"
encoded = (
re.search(regexp, path).group(1) # type: ignore
)
except AttributeError:
raise ValueError(
Expand Down Expand Up @@ -161,9 +163,7 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Part:
return vertex_messages


def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
from vertexai.language_models import InputOutputTextPair

def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
if len(examples) % 2 != 0:
raise ValueError(
f"Expect examples to have an even amount of messages, got {len(examples)}."
Expand Down Expand Up @@ -223,16 +223,7 @@ def get_lc_namespace(cls) -> List[str]:
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
cls._try_init_vertexai(values)
try:
from vertexai.language_models import ChatModel, CodeChatModel

if is_gemini:
from vertexai.preview.generative_models import (
GenerativeModel,
)
except ImportError:
raise_vertex_import_error()
cls._init_vertexai(values)
if is_gemini:
values["client"] = GenerativeModel(model_name=values["model_name"])
else:
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,14 +3,24 @@
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from google.api_core.exceptions import (
Aborted,
DeadlineExceeded,
InvalidArgument,
ResourceExhausted,
ServiceUnavailable,
)
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
TextEmbeddingInput,
TextEmbeddingModel,
)

from langchain_google_vertexai.llms import _VertexAICommon
from langchain_google_vertexai.utils import raise_vertex_import_error

logger = logging.getLogger(__name__)

Expand All @@ -28,18 +38,14 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._try_init_vertexai(values)
cls._init_vertexai(values)
if values["model_name"] == "textembedding-gecko-default":
logger.warning(
"Model_name will become a required arg for VertexAIEmbeddings "
"starting from Feb-01-2024. Currently the default is set to "
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
try:
from vertexai.language_models import TextEmbeddingModel
except ImportError:
raise_vertex_import_error()
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

Expand Down Expand Up @@ -141,14 +147,8 @@ def _get_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""
from google.api_core.exceptions import (
Aborted,
DeadlineExceeded,
ResourceExhausted,
ServiceUnavailable,
)

errors = [
errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
Expand All @@ -161,8 +161,6 @@ def _get_embeddings_with_retry(
@retry_decorator
def _completion_with_retry(texts_to_process: List[str]) -> Any:
if embeddings_type and self.instance["embeddings_task_type_supported"]:
from vertexai.language_models import TextEmbeddingInput

requests = [
TextEmbeddingInput(text=t, task_type=embeddings_type)
for t in texts_to_process
Expand All @@ -182,7 +180,6 @@ def _prepare_and_validate_batches(
# Returns embeddings of the first text batch that went through,
# and text batches for the rest of the texts.
"""
from google.api_core.exceptions import InvalidArgument

batches = VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
Expand Down
Loading

0 comments on commit a0659af

Please sign in to comment.