diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index d17a2932bc8f2..730b7f8908f95 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -328,13 +328,7 @@ class GetPopulation(BaseModel):
 def _init_chat_model_helper(
     model: str, *, model_provider: Optional[str] = None, **kwargs: Any
 ) -> BaseChatModel:
-    model_provider = model_provider or _attempt_infer_model_provider(model)
-    if not model_provider:
-        raise ValueError(
-            f"Unable to infer model provider for {model=}, please specify "
-            f"model_provider directly."
-        )
-    model_provider = model_provider.replace("-", "_").lower()
+    model, model_provider = _parse_model(model, model_provider)
     if model_provider == "openai":
         _check_pkg("langchain_openai")
         from langchain_openai import ChatOpenAI
@@ -461,6 +455,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     return None
 
 
+def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
+    if (
+        not model_provider
+        and ":" in model
+        and model.split(":")[0] in _SUPPORTED_PROVIDERS
+    ):
+        model_provider = model.split(":")[0]
+        model = ":".join(model.split(":")[1:])
+    model_provider = model_provider or _attempt_infer_model_provider(model)
+    if not model_provider:
+        raise ValueError(
+            f"Unable to infer model provider for {model=}, please specify "
+            f"model_provider directly."
+        )
+    model_provider = model_provider.replace("-", "_").lower()
+    return model, model_provider
+
+
 def _check_pkg(pkg: str) -> None:
     if not util.find_spec(pkg):
         pkg_kebab = pkg.replace("_", "-")
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py
index fd5e3e80f21fd..c06e844073150 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_base.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 from unittest import mock
 
 import pytest
@@ -26,7 +27,6 @@ def test_all_imports() -> None:
     "langchain_openai",
     "langchain_anthropic",
     "langchain_fireworks",
-    "langchain_mistralai",
     "langchain_groq",
 )
 @pytest.mark.parametrize(
@@ -38,10 +38,14 @@ def test_all_imports() -> None:
         ("mixtral-8x7b-32768", "groq"),
     ],
 )
-def test_init_chat_model(model_name: str, model_provider: str) -> None:
-    _: BaseChatModel = init_chat_model(
+def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
+    llm1: BaseChatModel = init_chat_model(
         model_name, model_provider=model_provider, api_key="foo"
     )
+    llm2: BaseChatModel = init_chat_model(
+        f"{model_provider}:{model_name}", api_key="foo"
+    )
+    assert llm1.dict() == llm2.dict()
 
 
 def test_init_missing_dep() -> None:
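
A minimal usage sketch of the new "provider:model" form, mirroring the gpt-4o/openai case from the parametrized test above. It assumes langchain-openai is installed and passes a dummy api_key only so the client can be constructed (no requests are made).

    from langchain.chat_models import init_chat_model

    # Explicit provider, as before this change.
    llm_explicit = init_chat_model("gpt-4o", model_provider="openai", api_key="foo")

    # New "provider:model" form: _parse_model splits on the first ":" when the
    # prefix is one of the supported providers and strips it from the model name.
    llm_prefixed = init_chat_model("openai:gpt-4o", api_key="foo")

    # Both forms should construct equivalent ChatOpenAI instances.
    assert llm_explicit.dict() == llm_prefixed.dict()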