diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py index dca8cd5f8fd08..90187100d1587 100644 --- a/libs/community/langchain_community/llms/bedrock.py +++ b/libs/community/langchain_community/llms/bedrock.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import json import warnings from abc import ABC -from typing import Any, Dict, Iterator, List, Mapping, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM @@ -15,6 +17,9 @@ get_token_ids_anthropic, ) +if TYPE_CHECKING: + from botocore.config import Config + HUMAN_PROMPT = "\n\nHuman:" ASSISTANT_PROMPT = "\n\nAssistant:" ALTERNATION_ERROR = ( @@ -163,6 +168,9 @@ class BedrockBase(BaseModel, ABC): See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html """ + config: Optional[Config] = None + """An optional botocore.config.Config instance to pass to the client.""" + model_id: str """Id of the model to call, e.g., amazon.titan-text-express-v1, this is equivalent to the modelId property in the list-foundation-models api""" @@ -212,6 +220,8 @@ def validate_environment(cls, values: Dict) -> Dict: client_params["region_name"] = values["region_name"] if values["endpoint_url"]: client_params["endpoint_url"] = values["endpoint_url"] + if values["config"]: + client_params["config"] = values["config"] values["client"] = session.client("bedrock-runtime", **client_params)