Openrouter support: every model from one provider!
Also add mistral_nemo
scosman committed Sep 9, 2024
1 parent 737aba6 commit f31f447
Showing 4 changed files with 84 additions and 10 deletions.
12 changes: 11 additions & 1 deletion libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -69,11 +69,21 @@ async def _run(self, input: str) -> Dict | str:
):
raise RuntimeError(f"structured response not returned: {response}")
structured_response = response["parsed"]
return structured_response
return self._munge_response(structured_response)
else:
if not isinstance(response, BaseMessage):
raise RuntimeError(f"response is not a BaseMessage: {response}")
text_content = response.content
if not isinstance(text_content, str):
raise RuntimeError(f"response is not a string: {text_content}")
return text_content

def _munge_response(self, response: Dict) -> Dict:
# Mistral Large tool calling format is a bit different. Convert to standard format.
if (
"name" in response
and response["name"] == "task_response"
and "arguments" in response
):
return response["arguments"]
return response
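
To make the unwrapping concrete, here is a standalone restatement of the helper with hypothetical payloads (the envelope shape matches the comment above; the values are made up for illustration):

    def munge_response(response: dict) -> dict:
        # Mistral Large wraps parsed output in a tool-call envelope; unwrap it.
        if (
            "name" in response
            and response["name"] == "task_response"
            and "arguments" in response
        ):
            return response["arguments"]
        return response

    assert munge_response({"name": "task_response", "arguments": {"answer": 42}}) == {"answer": 42}
    assert munge_response({"answer": 42}) == {"answer": 42}  # already standard: returned unchanged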
62 changes: 59 additions & 3 deletions libs/core/kiln_ai/adapters/ml_model_list.py
@@ -1,5 +1,6 @@
import os
from enum import Enum
from os import getenv
from typing import Dict, List

import httpx
@@ -16,6 +17,7 @@ class ModelProviderName(str, Enum):
groq = "groq"
amazon_bedrock = "amazon_bedrock"
ollama = "ollama"
openrouter = "openrouter"


class ModelFamily(str, Enum):
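
Because ModelProviderName subclasses str, the new member constructs from and compares against plain strings, which keeps string-based provider lookups simple. A small illustrative sketch (the class name here is hypothetical; it mirrors ModelProviderName above):

    from enum import Enum

    class ProviderName(str, Enum):  # mirrors ModelProviderName above
        openrouter = "openrouter"

    assert ProviderName("openrouter") is ProviderName.openrouter
    assert ProviderName.openrouter == "openrouter"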
@@ -32,6 +34,7 @@ class ModelName(str, Enum):
gpt_4o = "gpt_4o"
phi_3_5 = "phi_3_5"
mistral_large = "mistral_large"
mistral_nemo = "mistral_nemo"


class KilnModelProvider(BaseModel):
@@ -58,6 +61,10 @@ class KilnModel(BaseModel):
name=ModelProviderName.openai,
provider_options={"model": "gpt-4o-mini"},
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "openai/gpt-4o-mini"},
),
],
),
# GPT 4o
@@ -69,6 +76,10 @@ class KilnModel(BaseModel):
name=ModelProviderName.openai,
provider_options={"model": "gpt-4o"},
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "openai/gpt-4o-2024-08-06"},
),
],
),
# Llama 3.1-8b
@@ -82,8 +93,6 @@ class KilnModel(BaseModel):
),
KilnModelProvider(
name=ModelProviderName.amazon_bedrock,
# bedrock llama doesn't support structured output, should check again later
supports_structured_output=False,
provider_options={
"model": "meta.llama3-1-8b-instruct-v1:0",
"region_name": "us-west-2", # Llama 3.1 only in west-2
@@ -93,6 +102,10 @@ class KilnModel(BaseModel):
name=ModelProviderName.ollama,
provider_options={"model": "llama3.1"},
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
),
],
),
# Llama 3.1 70b
@@ -106,20 +119,35 @@ class KilnModel(BaseModel):
),
KilnModelProvider(
name=ModelProviderName.amazon_bedrock,
# bedrock llama doesn't support structured output, should check again later
# TODO: this should work but a bug in the bedrock response schema
supports_structured_output=False,
provider_options={
"model": "meta.llama3-1-70b-instruct-v1:0",
"region_name": "us-west-2", # Llama 3.1 only in west-2
},
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
),
# TODO: enable once tests update to check if model is available
# KilnModelProvider(
# provider=ModelProviders.ollama,
# provider_options={"model": "llama3.1:70b"},
# ),
],
),
# Mistral Nemo
KilnModel(
family=ModelFamily.mistral,
name=ModelName.mistral_nemo,
providers=[
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "mistralai/mistral-nemo"},
),
],
),
# Mistral Large
KilnModel(
family=ModelFamily.mistral,
@@ -132,6 +160,10 @@ class KilnModel(BaseModel):
"region_name": "us-west-2", # only in west-2
},
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "mistralai/mistral-large"},
),
# TODO: enable once tests update to check if model is available
# KilnModelProvider(
# provider=ModelProviders.ollama,
@@ -149,6 +181,10 @@ class KilnModel(BaseModel):
name=ModelProviderName.ollama,
provider_options={"model": "phi3.5"},
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
),
],
),
]
@@ -185,6 +221,26 @@ def langchain_model_from(name: str, provider_name: str | None = None) -> BaseCha
return ChatBedrockConverse(**provider.provider_options)
elif provider.name == ModelProviderName.ollama:
return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
elif provider.name == ModelProviderName.openrouter:
api_key = getenv("OPENROUTER_API_KEY")
if api_key is None:
raise ValueError(
"OPENROUTER_API_KEY environment variable must be set to use OpenRouter. "
"Get your API key from https://openrouter.ai/settings/keys"
)
base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
return ChatOpenAI(
**provider.provider_options,
openai_api_key=api_key, # type: ignore[arg-type]
openai_api_base=base_url, # type: ignore[arg-type]
# TODO: should pass these
# model_kwargs={
# "headers": {
# "HTTP-Referer": "https://kiln-ai.com",
# "X-Title": "KilnAI",
# }
# },
)
else:
raise ValueError(f"Invalid model or provider: {name} - {provider_name}")

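A minimal usage sketch for the new provider path (the import path follows the file layout shown in this diff, and the key value is a placeholder, not a real credential):

    import os

    # Placeholder key; real keys come from https://openrouter.ai/settings/keys
    os.environ["OPENROUTER_API_KEY"] = "sk-or-..."

    from kiln_ai.adapters.ml_model_list import langchain_model_from

    # Looks up the "openrouter" provider entry for the built-in llama_3_1_8b model
    # and returns a ChatOpenAI client pointed at OpenRouter's OpenAI-compatible API.
    model = langchain_model_from("llama_3_1_8b", "openrouter")
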
13 changes: 8 additions & 5 deletions libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -15,6 +15,11 @@ async def test_groq(tmp_path):
await run_simple_test(tmp_path, "llama_3_1_8b", "groq")


@pytest.mark.paid
async def test_openrouter(tmp_path):
await run_simple_test(tmp_path, "llama_3_1_8b", "openrouter")


@pytest.mark.ollama
async def test_ollama_phi(tmp_path):
# Check if Ollama API is running
@@ -74,12 +79,10 @@ async def test_all_built_in_models(tmp_path):
for model in built_in_models:
for provider in model.providers:
try:
print(f"Running {model.model_name} {provider.name}")
await run_simple_task(task, model.model_name, provider.name)
print(f"Running {model.name} {provider.name}")
await run_simple_task(task, model.name, provider.name)
except Exception as e:
raise RuntimeError(
f"Error running {model.model_name} {provider}"
) from e
raise RuntimeError(f"Error running {model.name} {provider}") from e


def build_test_task(tmp_path: Path):
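The new test_openrouter above reuses the existing paid marker, so it only runs when selected explicitly. A sketch of how such a marker is typically registered (hypothetical conftest.py; the project's actual pytest configuration is not part of this diff):

    # conftest.py (hypothetical)
    def pytest_configure(config):
        # Register the mark so pytest doesn't warn about it; select these
        # network-hitting tests explicitly with: pytest -m paid
        config.addinivalue_line("markers", "paid: tests that call paid external APIs")
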
7 changes: 6 additions & 1 deletion libs/core/kiln_ai/adapters/test_structured_output.py
@@ -17,6 +17,11 @@ async def test_structured_output_groq(tmp_path):
await run_structured_output_test(tmp_path, "llama_3_1_8b", "groq")


@pytest.mark.paid
async def test_structured_output_openrouter(tmp_path):
await run_structured_output_test(tmp_path, "mistral_nemo", "openrouter")


@pytest.mark.paid
async def test_structured_output_bedrock(tmp_path):
await run_structured_output_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
@@ -98,7 +103,7 @@ async def test_all_built_in_models_structured_output(tmp_path):
)
continue
try:
print(f"Running {model.name} {provider}")
print(f"Running {model.name} {provider.name}")
await run_structured_output_test(tmp_path, model.name, provider.name)
except Exception as e:
raise RuntimeError(f"Error running {model.name} {provider}") from e
