Skip to content

Commit

Permalink
langchain[patch]: init_chat_model provider in model string (#28367)
Browse files Browse the repository at this point in the history
```python
llm = init_chat_model("openai:gpt-4o")
```
  • Loading branch information
baskaryan authored Nov 27, 2024
1 parent 8adc4a5 commit ffe7bd4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
26 changes: 19 additions & 7 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,7 @@ class GetPopulation(BaseModel):
def _init_chat_model_helper(
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
) -> BaseChatModel:
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
raise ValueError(
f"Unable to infer model provider for {model=}, please specify "
f"model_provider directly."
)
model_provider = model_provider.replace("-", "_").lower()
model, model_provider = _parse_model(model, model_provider)
if model_provider == "openai":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
Expand Down Expand Up @@ -461,6 +455,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return None


def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
if (
not model_provider
and ":" in model
and model.split(":")[0] in _SUPPORTED_PROVIDERS
):
model_provider = model.split(":")[0]
model = ":".join(model.split(":")[1:])
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
raise ValueError(
f"Unable to infer model provider for {model=}, please specify "
f"model_provider directly."
)
model_provider = model_provider.replace("-", "_").lower()
return model, model_provider


def _check_pkg(pkg: str) -> None:
if not util.find_spec(pkg):
pkg_kebab = pkg.replace("_", "-")
Expand Down
10 changes: 7 additions & 3 deletions libs/langchain/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional
from unittest import mock

import pytest
Expand Down Expand Up @@ -26,7 +27,6 @@ def test_all_imports() -> None:
"langchain_openai",
"langchain_anthropic",
"langchain_fireworks",
"langchain_mistralai",
"langchain_groq",
)
@pytest.mark.parametrize(
Expand All @@ -38,10 +38,14 @@ def test_all_imports() -> None:
("mixtral-8x7b-32768", "groq"),
],
)
def test_init_chat_model(model_name: str, model_provider: str) -> None:
_: BaseChatModel = init_chat_model(
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
llm1: BaseChatModel = init_chat_model(
model_name, model_provider=model_provider, api_key="foo"
)
llm2: BaseChatModel = init_chat_model(
f"{model_provider}:{model_name}", api_key="foo"
)
assert llm1.dict() == llm2.dict()


def test_init_missing_dep() -> None:
Expand Down

0 comments on commit ffe7bd4

Please sign in to comment.