Make AWS Bedrock work for structured data by upgrading to new ChatProvider (Converse)

Also improve the provider data structure, as individual providers may or may not support structured output

Improve tests.

Note: all tests now pass!
scosman committed Sep 2, 2024
1 parent a91d4e6 commit b148042
Showing 4 changed files with 132 additions and 93 deletions.
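
The heart of the change is the new per-provider data structure in ml_model_list.py: each KilnModel now carries a list of KilnModelProvider entries, and each provider can opt out of structured output on its own. A minimal usage sketch follows (the imports and field names come from the diff below; the langchain_model_from signature of (model_name, provider_name=None) and the call sites are assumptions, not part of this commit):

# Sketch only: exercises the provider structure introduced in this commit.
from kiln_ai.adapters.ml_model_list import (
    ModelProviderName,
    built_in_models,
    langchain_model_from,
)

# Each model lists its providers; a provider can individually disable
# structured output (e.g. Bedrock Llama in the diff below).
for model in built_in_models:
    for provider in model.providers:
        print(model.model_name, provider.name, provider.supports_structured_output)

# Resolving a model/provider pair builds the matching LangChain chat model;
# Amazon Bedrock now goes through ChatBedrockConverse. Assumed signature:
# langchain_model_from(model_name, provider_name=None).
chat_model = langchain_model_from("mistral_large", ModelProviderName.amazon_bedrock)
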
169 changes: 101 additions & 68 deletions libs/core/kiln_ai/adapters/ml_model_list.py
@@ -3,15 +3,15 @@
from typing import Dict, List

import httpx
-from langchain_aws import ChatBedrock
+from langchain_aws import ChatBedrockConverse
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


-class ModelProviders(str, Enum):
+class ModelProviderName(str, Enum):
openai = "openai"
groq = "groq"
amazon_bedrock = "amazon_bedrock"
@@ -34,10 +34,17 @@ class ModelName(str, Enum):
mistral_large = "mistral_large"


+class KilnModelProvider(BaseModel):
+    name: ModelProviderName
+    # Allow overriding the model level setting
+    supports_structured_output: bool = True
+    provider_options: Dict = {}


class KilnModel(BaseModel):
model_family: str
model_name: str
-    provider_config: Dict[ModelProviders, Dict]
+    providers: List[KilnModelProvider]
supports_structured_output: bool = True


@@ -46,79 +46,103 @@ class KilnModel(BaseModel):
KilnModel(
model_family=ModelFamily.gpt,
model_name=ModelName.gpt_4o_mini,
-        provider_config={
-            ModelProviders.openai: {
-                "model": "gpt-4o-mini",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o-mini"},
+            ),
+        ],
),
# GPT 4o
KilnModel(
model_family=ModelFamily.gpt,
model_name=ModelName.gpt_4o,
-        provider_config={
-            ModelProviders.openai: {
-                "model": "gpt-4o",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o"},
+            ),
+        ],
),
# Llama 3.1-8b
KilnModel(
model_family=ModelFamily.llama,
model_name=ModelName.llama_3_1_8b,
-        provider_config={
-            ModelProviders.groq: {
-                "model": "llama-3.1-8b-instant",
-            },
-            # Doesn't reliably work with tool calling / structured output
-            # https://www.reddit.com/r/LocalLLaMA/comments/1ece00h/llama_31_8b_instruct_functiontool_calling_seems/
-            # ModelProviders.amazon_bedrock: {
-            #     "model_id": "meta.llama3-1-8b-instruct-v1:0",
-            #     "region_name": "us-west-2", # Llama 3.1 only in west-2
-            # },
-            ModelProviders.ollama: {
-                "model": "llama3.1",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-8b-instant"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # bedrock llama doesn't support structured output, should check again later
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-8b-instruct-v1:0",
+                    "region_name": "us-west-2", # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1"},
+            ),
+        ],
),
# Llama 3.1 70b
KilnModel(
model_family=ModelFamily.llama,
model_name=ModelName.llama_3_1_70b,
-        provider_config={
-            ModelProviders.groq: {
-                "model": "llama-3.1-70b-versatile",
-            },
-            ModelProviders.amazon_bedrock: {
-                "model_id": "meta.llama3-1-70b-instruct-v1:0",
-                "region_name": "us-west-2", # Llama 3.1 only in west-2
-            },
-            # ModelProviders.ollama: {
-            #     "model": "llama3.1:70b",
-            # },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-70b-versatile"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # bedrock llama doesn't support structured output, should check again later
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-70b-instruct-v1:0",
+                    "region_name": "us-west-2", # Llama 3.1 only in west-2
+                },
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     provider=ModelProviders.ollama,
+            #     provider_options={"model": "llama3.1:70b"},
+            # ),
+        ],
),
# Mistral Large
KilnModel(
model_family=ModelFamily.mistral,
model_name=ModelName.mistral_large,
-        provider_config={
-            ModelProviders.amazon_bedrock: {
-                "model_id": "mistral.mistral-large-2407-v1:0",
-                "region_name": "us-west-2", # only in west-2
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "mistral.mistral-large-2407-v1:0",
+                    "region_name": "us-west-2", # only in west-2
+                },
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     provider=ModelProviders.ollama,
+            #     provider_options={"model": "mistral-large"},
+            # ),
+        ],
),
# Phi 3.5
KilnModel(
model_family=ModelFamily.phi,
model_name=ModelName.phi_3_5,
supports_structured_output=False,
-        provider_config={
-            ModelProviders.ollama: {
-                "model": "phi3.5",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "phi3.5"},
+            ),
+        ],
),
]

@@ -135,27 +166,29 @@ def langchain_model_from(
raise ValueError(f"Model {model_name} not found")

# If a provider is provided, select the provider from the model's provider_config
-    provider: ModelProviders | None = None
-    if model.provider_config is None or len(model.provider_config) == 0:
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
raise ValueError(f"Model {model_name} has no providers")
-    if provider_name is None:
+    elif provider_name is None:
# TODO: priority order
-        provider_name = list(model.provider_config.keys())[0]
-    if provider_name not in ModelProviders.__members__:
-        raise ValueError(f"Invalid provider: {provider_name}")
-    if provider_name not in model.provider_config:
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+    if provider is None:
raise ValueError(f"Provider {provider_name} not found for model {model_name}")
-    model_provider_props = model.provider_config[provider_name]
-    provider = ModelProviders(provider_name)

-    if provider == ModelProviders.openai:
-        return ChatOpenAI(**model_provider_props)
-    elif provider == ModelProviders.groq:
-        return ChatGroq(**model_provider_props)
-    elif provider == ModelProviders.amazon_bedrock:
-        return ChatBedrock(**model_provider_props)
-    elif provider == ModelProviders.ollama:
-        return ChatOllama(**model_provider_props, base_url=ollama_base_url())

+    if provider.name == ModelProviderName.openai:
+        return ChatOpenAI(**provider.provider_options)
+    elif provider.name == ModelProviderName.groq:
+        return ChatGroq(**provider.provider_options)
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        return ChatBedrockConverse(**provider.provider_options)
+    elif provider.name == ModelProviderName.ollama:
+        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+    else:
+        raise ValueError(f"Invalid model or provider: {model_name} - {provider_name}")


def ollama_base_url():
20 changes: 8 additions & 12 deletions libs/core/kiln_ai/adapters/prompt_adapters.py
@@ -3,7 +3,7 @@

import kiln_ai.datamodel.models as models
from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages.base import BaseMessageChunk
+from langchain_core.messages.base import BaseMessage

from .base_adapter import BaseAdapter
from .ml_model_list import langchain_model_from
@@ -56,24 +56,20 @@ async def run(self, input: str) -> str:
# TODO cleanup
prompt = self.build_prompt()
prompt += f"\n\n{input}"
+        response = self.model.invoke(prompt)
if self.__is_structured:
-            response = self.model.invoke(prompt)
if not isinstance(response, dict) or "parsed" not in response:
raise RuntimeError(f"structured response not returned: {response}")
structured_response = response["parsed"]
# TODO: not JSON, use a dict here
return json.dumps(structured_response)
-        else:
-            answer = ""
-            async for chunk in self.model.astream(prompt):
-                if not isinstance(chunk, BaseMessageChunk) or not isinstance(
-                    chunk.content, str
-                ):
-                    raise RuntimeError(f"chunk is not a string: {chunk}")
-
-                print(chunk.content, end="", flush=True)
-                answer += chunk.content
-            return answer
+        if not isinstance(response, BaseMessage):
+            raise RuntimeError(f"response is not a BaseMessage: {response}")
+        text_content = response.content
+        if not isinstance(text_content, str):
+            raise RuntimeError(f"response is not a string: {text_content}")
+        return text_content


class SimplePromptAdapter(BaseLangChainPromptAdapter):
12 changes: 9 additions & 3 deletions libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -56,7 +56,7 @@ async def test_amazon_bedrock(tmp_path):
or os.getenv("AWS_ACCESS_KEY_ID") is None
):
pytest.skip("AWS keys not set")
-    await run_simple_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+    await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")


async def test_mock(tmp_path):
@@ -72,8 +72,14 @@ async def test_mock(tmp_path):
async def test_all_built_in_models(tmp_path):
task = build_test_task(tmp_path)
for model in built_in_models:
-        for provider in model.provider_config:
-            await run_simple_task(task, model.model_name, provider)
+        for provider in model.providers:
+            try:
+                print(f"Running {model.model_name} {provider.name}")
+                await run_simple_task(task, model.model_name, provider.name)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Error running {model.model_name} {provider}"
+                ) from e


def build_test_task(tmp_path: Path):
24 changes: 14 additions & 10 deletions libs/core/kiln_ai/adapters/test_structured_output.py
@@ -4,8 +4,6 @@
import kiln_ai.datamodel.models as models
import pytest
from kiln_ai.adapters.ml_model_list import (
-    ModelName,
-    ModelProviders,
built_in_models,
ollama_online,
)
@@ -20,10 +18,7 @@ async def test_structured_output_groq(tmp_path):

@pytest.mark.paid
async def test_structured_output_bedrock(tmp_path):
-    pytest.skip(
-        "bedrock not working with structured output. New API, hopefully fixed soon by langchain"
-    )
-    await run_structured_output_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
+    await run_structured_output_test(tmp_path, "mistral_large", "amazon_bedrock")


@pytest.mark.ollama
@@ -35,6 +30,11 @@ async def test_structured_output_ollama_phi(tmp_path):
await run_structured_output_test(tmp_path, "phi_3_5", "ollama")


+@pytest.mark.ollama
+async def test_structured_output_gpt_4o_mini(tmp_path):
+    await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")


@pytest.mark.ollama
async def test_structured_output_ollama_llama(tmp_path):
if not await ollama_online():
@@ -51,13 +51,17 @@ async def test_all_built_in_models_structured_output(tmp_path):
f"Skipping {model.model_name} because it does not support structured output"
)
continue
-        for provider in model.provider_config:
-            if provider == ModelProviders.amazon_bedrock:
-                # TODO: bedrock not working, should investigate and fix
+        for provider in model.providers:
+            if not provider.supports_structured_output:
+                print(
+                    f"Skipping {model.model_name} {provider.name} because it does not support structured output"
+                )
continue
try:
print(f"Running {model.model_name} {provider}")
-                await run_structured_output_test(tmp_path, model.model_name, provider)
+                await run_structured_output_test(
+                    tmp_path, model.model_name, provider.name
+                )
except Exception as e:
raise RuntimeError(
f"Error running {model.model_name} {provider}"
