From b1480425183d9993d61df2b9b5fc20aa6c7a4285 Mon Sep 17 00:00:00 2001
From: scosman
Date: Sun, 1 Sep 2024 22:32:52 -0400
Subject: [PATCH] Make AWS Bedrock work for structured data by upgrading to
 the new Converse chat model (ChatBedrockConverse)

Also improve the provider data structure, as individual providers may or may
not support structured output.

Improve tests.

Note: all tests now pass!
---
 libs/core/kiln_ai/adapters/ml_model_list.py   | 169 +++++++++++-------
 libs/core/kiln_ai/adapters/prompt_adapters.py |  20 +--
 .../kiln_ai/adapters/test_prompt_adaptors.py  |  12 +-
 .../adapters/test_structured_output.py        |  24 +--
 4 files changed, 132 insertions(+), 93 deletions(-)

diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py
index 7462444..25c966a 100644
--- a/libs/core/kiln_ai/adapters/ml_model_list.py
+++ b/libs/core/kiln_ai/adapters/ml_model_list.py
@@ -3,7 +3,7 @@
 from typing import Dict, List
 
 import httpx
-from langchain_aws import ChatBedrock
+from langchain_aws import ChatBedrockConverse
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_groq import ChatGroq
 from langchain_ollama import ChatOllama
@@ -11,7 +11,7 @@
 from pydantic import BaseModel
 
 
-class ModelProviders(str, Enum):
+class ModelProviderName(str, Enum):
     openai = "openai"
     groq = "groq"
     amazon_bedrock = "amazon_bedrock"
@@ -34,10 +34,17 @@ class ModelName(str, Enum):
     mistral_large = "mistral_large"
 
 
+class KilnModelProvider(BaseModel):
+    name: ModelProviderName
+    # Allow overriding the model level setting
+    supports_structured_output: bool = True
+    provider_options: Dict = {}
+
+
 class KilnModel(BaseModel):
     model_family: str
     model_name: str
-    provider_config: Dict[ModelProviders, Dict]
+    providers: List[KilnModelProvider]
     supports_structured_output: bool = True
 
 
@@ -46,79 +53,103 @@
     KilnModel(
         model_family=ModelFamily.gpt,
         model_name=ModelName.gpt_4o_mini,
-        provider_config={
-            ModelProviders.openai: {
-                "model": "gpt-4o-mini",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o-mini"},
+            ),
+        ],
     ),
     # GPT 4o
     KilnModel(
         model_family=ModelFamily.gpt,
         model_name=ModelName.gpt_4o,
-        provider_config={
-            ModelProviders.openai: {
-                "model": "gpt-4o",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o"},
+            ),
+        ],
     ),
     # Llama 3.1-8b
     KilnModel(
         model_family=ModelFamily.llama,
         model_name=ModelName.llama_3_1_8b,
-        provider_config={
-            ModelProviders.groq: {
-                "model": "llama-3.1-8b-instant",
-            },
-            # Doesn't reliably work with tool calling / structured output
-            # https://www.reddit.com/r/LocalLLaMA/comments/1ece00h/llama_31_8b_instruct_functiontool_calling_seems/
-            # ModelProviders.amazon_bedrock: {
-            #     "model_id": "meta.llama3-1-8b-instruct-v1:0",
-            #     "region_name": "us-west-2",  # Llama 3.1 only in west-2
-            # },
-            ModelProviders.ollama: {
-                "model": "llama3.1",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-8b-instant"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # bedrock llama doesn't support structured output, should check again later
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-8b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1"},
+            ),
+        ],
     ),
     # Llama 3.1 70b
     KilnModel(
         model_family=ModelFamily.llama,
         model_name=ModelName.llama_3_1_70b,
-        provider_config={
-            ModelProviders.groq: {
-                "model": "llama-3.1-70b-versatile",
-            },
-            ModelProviders.amazon_bedrock: {
-                "model_id": "meta.llama3-1-70b-instruct-v1:0",
-                "region_name": "us-west-2",  # Llama 3.1 only in west-2
-            },
-            # ModelProviders.ollama: {
-            #     "model": "llama3.1:70b",
-            # },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-70b-versatile"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # bedrock llama doesn't support structured output, should check again later
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-70b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            # TODO: enable once tests are updated to check if the model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={"model": "llama3.1:70b"},
+            # ),
+        ],
     ),
     # Mistral Large
     KilnModel(
         model_family=ModelFamily.mistral,
         model_name=ModelName.mistral_large,
-        provider_config={
-            ModelProviders.amazon_bedrock: {
-                "model_id": "mistral.mistral-large-2407-v1:0",
-                "region_name": "us-west-2",  # only in west-2
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "mistral.mistral-large-2407-v1:0",
+                    "region_name": "us-west-2",  # only in west-2
+                },
+            ),
+            # TODO: enable once tests are updated to check if the model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={"model": "mistral-large"},
+            # ),
+        ],
     ),
     # Phi 3.5
     KilnModel(
         model_family=ModelFamily.phi,
         model_name=ModelName.phi_3_5,
         supports_structured_output=False,
-        provider_config={
-            ModelProviders.ollama: {
-                "model": "phi3.5",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "phi3.5"},
+            ),
+        ],
     ),
 ]
@@ -135,27 +166,29 @@ def langchain_model_from(
         raise ValueError(f"Model {model_name} not found")
 
     # If a provider is provided, select the provider from the model's provider_config
-    provider: ModelProviders | None = None
-    if model.provider_config is None or len(model.provider_config) == 0:
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
         raise ValueError(f"Model {model_name} has no providers")
-    if provider_name is None:
+    elif provider_name is None:
         # TODO: priority order
-        provider_name = list(model.provider_config.keys())[0]
-    if provider_name not in ModelProviders.__members__:
-        raise ValueError(f"Invalid provider: {provider_name}")
-    if provider_name not in model.provider_config:
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+    if provider is None:
         raise ValueError(f"Provider {provider_name} not found for model {model_name}")
-    model_provider_props = model.provider_config[provider_name]
-    provider = ModelProviders(provider_name)
-
-    if provider == ModelProviders.openai:
-        return ChatOpenAI(**model_provider_props)
-    elif provider == ModelProviders.groq:
-        return ChatGroq(**model_provider_props)
-    elif provider == ModelProviders.amazon_bedrock:
-        return ChatBedrock(**model_provider_props)
-    elif provider == ModelProviders.ollama:
-        return ChatOllama(**model_provider_props, base_url=ollama_base_url())
+
+    if provider.name == ModelProviderName.openai:
+        return ChatOpenAI(**provider.provider_options)
+    elif provider.name == ModelProviderName.groq:
+        return ChatGroq(**provider.provider_options)
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        return ChatBedrockConverse(**provider.provider_options)
+    elif provider.name == ModelProviderName.ollama:
+        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+    else:
+        raise ValueError(f"Invalid model or provider: {model_name} - {provider_name}")
 
 
 def ollama_base_url():
diff --git a/libs/core/kiln_ai/adapters/prompt_adapters.py b/libs/core/kiln_ai/adapters/prompt_adapters.py
index ac1b84f..5068d8b 100644
--- a/libs/core/kiln_ai/adapters/prompt_adapters.py
+++ b/libs/core/kiln_ai/adapters/prompt_adapters.py
@@ -3,7 +3,7 @@
 import kiln_ai.datamodel.models as models
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages.base import BaseMessageChunk
+from langchain_core.messages.base import BaseMessage
 
 from .base_adapter import BaseAdapter
 from .ml_model_list import langchain_model_from
 
@@ -56,24 +56,20 @@ async def run(self, input: str) -> str:
         # TODO cleanup
         prompt = self.build_prompt()
         prompt += f"\n\n{input}"
+        response = self.model.invoke(prompt)
         if self.__is_structured:
-            response = self.model.invoke(prompt)
             if not isinstance(response, dict) or "parsed" not in response:
                 raise RuntimeError(f"structured response not returned: {response}")
             structured_response = response["parsed"]
             # TODO: not JSON, use a dict here
             return json.dumps(structured_response)
         else:
-            answer = ""
-            async for chunk in self.model.astream(prompt):
-                if not isinstance(chunk, BaseMessageChunk) or not isinstance(
-                    chunk.content, str
-                ):
-                    raise RuntimeError(f"chunk is not a string: {chunk}")
-
-                print(chunk.content, end="", flush=True)
-                answer += chunk.content
-            return answer
+            if not isinstance(response, BaseMessage):
+                raise RuntimeError(f"response is not a BaseMessage: {response}")
+            text_content = response.content
+            if not isinstance(text_content, str):
+                raise RuntimeError(f"response is not a string: {text_content}")
+            return text_content
 
 
 class SimplePromptAdapter(BaseLangChainPromptAdapter):
diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py
index d7244da..4f2eab3 100644
--- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py
+++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -56,7 +56,7 @@ async def test_amazon_bedrock(tmp_path):
         or os.getenv("AWS_ACCESS_KEY_ID") is None
     ):
         pytest.skip("AWS keys not set")
-    await run_simple_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+    await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
 
 
 async def test_mock(tmp_path):
@@ -72,8 +72,14 @@
 async def test_all_built_in_models(tmp_path):
     task = build_test_task(tmp_path)
     for model in built_in_models:
-        for provider in model.provider_config:
-            await run_simple_task(task, model.model_name, provider)
+        for provider in model.providers:
+            try:
+                print(f"Running {model.model_name} {provider.name}")
+                await run_simple_task(task, model.model_name, provider.name)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Error running {model.model_name} {provider}"
+                ) from e
 
 
 def build_test_task(tmp_path: Path):
diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/test_structured_output.py
index c84bbe3..0c436eb 100644
--- a/libs/core/kiln_ai/adapters/test_structured_output.py
+++ b/libs/core/kiln_ai/adapters/test_structured_output.py
@@ -4,8 +4,6 @@
 import kiln_ai.datamodel.models as models
 import pytest
 from kiln_ai.adapters.ml_model_list import (
-    ModelName,
-    ModelProviders,
     built_in_models,
     ollama_online,
 )
@@ -20,10 +18,7 @@ async def test_structured_output_groq(tmp_path):
 
 @pytest.mark.paid
 async def test_structured_output_bedrock(tmp_path):
-    pytest.skip(
-        "bedrock not working with structured output. New API, hopefully fixed soon by langchain"
-    )
-    await run_structured_output_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
+    await run_structured_output_test(tmp_path, "mistral_large", "amazon_bedrock")
 
 
 @pytest.mark.ollama
@@ -35,6 +30,11 @@ async def test_structured_output_ollama_phi(tmp_path):
     await run_structured_output_test(tmp_path, "phi_3_5", "ollama")
 
 
+@pytest.mark.paid
+async def test_structured_output_gpt_4o_mini(tmp_path):
+    await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")
+
+
 @pytest.mark.ollama
 async def test_structured_output_ollama_llama(tmp_path):
     if not await ollama_online():
@@ -51,13 +51,17 @@ async def test_all_built_in_models_structured_output(tmp_path):
                 f"Skipping {model.model_name} because it does not support structured output"
             )
             continue
-        for provider in model.provider_config:
-            if provider == ModelProviders.amazon_bedrock:
-                # TODO: bedrock not working, should investigate and fix
+        for provider in model.providers:
+            if not provider.supports_structured_output:
+                print(
+                    f"Skipping {model.model_name} {provider.name} because it does not support structured output"
+                )
                 continue
             try:
                 print(f"Running {model.model_name} {provider}")
-                await run_structured_output_test(tmp_path, model.model_name, provider)
+                await run_structured_output_test(
+                    tmp_path, model.model_name, provider.name
+                )
             except Exception as e:
                 raise RuntimeError(
                     f"Error running {model.model_name} {provider}"
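
Usage sketch (illustrative only, not part of the patch): with the new KilnModelProvider
structure, callers can check a provider's supports_structured_output flag before building
a LangChain chat model. The specific model chosen and the filtering pattern below are
assumptions for the example; the imports and call signature follow the code above.

    from kiln_ai.adapters.ml_model_list import built_in_models, langchain_model_from

    # Look up a built-in model by name (example choice)
    model = next(m for m in built_in_models if m.model_name == "llama_3_1_70b")

    # Per-provider override: skip providers that can't return structured output
    provider = next(
        (p for p in model.providers if p.supports_structured_output), None
    )
    if provider is None:
        raise ValueError(f"No structured-output provider for {model.model_name}")

    # Builds ChatOpenAI / ChatGroq / ChatBedrockConverse / ChatOllama based on provider.name
    chat_model = langchain_model_from(model.model_name, provider.name)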