diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py
index 7462444..25c966a 100644
--- a/libs/core/kiln_ai/adapters/ml_model_list.py
+++ b/libs/core/kiln_ai/adapters/ml_model_list.py
@@ -3,7 +3,7 @@
 from typing import Dict, List
 
 import httpx
-from langchain_aws import ChatBedrock
+from langchain_aws import ChatBedrockConverse
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_groq import ChatGroq
 from langchain_ollama import ChatOllama
@@ -11,7 +11,7 @@
 from pydantic import BaseModel
 
 
-class ModelProviders(str, Enum):
+class ModelProviderName(str, Enum):
     openai = "openai"
     groq = "groq"
     amazon_bedrock = "amazon_bedrock"
@@ -34,10 +34,17 @@ class ModelName(str, Enum):
     mistral_large = "mistral_large"
 
 
+class KilnModelProvider(BaseModel):
+    name: ModelProviderName
+    # Allow overriding the model level setting
+    supports_structured_output: bool = True
+    provider_options: Dict = {}
+
+
 class KilnModel(BaseModel):
     model_family: str
     model_name: str
-    provider_config: Dict[ModelProviders, Dict]
+    providers: List[KilnModelProvider]
     supports_structured_output: bool = True
 
 
@@ -46,79 +53,103 @@
     KilnModel(
         model_family=ModelFamily.gpt,
         model_name=ModelName.gpt_4o_mini,
-        provider_config={
-            ModelProviders.openai: {
-                "model": "gpt-4o-mini",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o-mini"},
+            ),
+        ],
     ),
     # GPT 4o
     KilnModel(
         model_family=ModelFamily.gpt,
         model_name=ModelName.gpt_4o,
-        provider_config={
-            ModelProviders.openai: {
-                "model": "gpt-4o",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o"},
+            ),
+        ],
     ),
     # Llama 3.1-8b
     KilnModel(
         model_family=ModelFamily.llama,
         model_name=ModelName.llama_3_1_8b,
-        provider_config={
-            ModelProviders.groq: {
-                "model": "llama-3.1-8b-instant",
-            },
-            # Doesn't reliably work with tool calling / structured output
-            # https://www.reddit.com/r/LocalLLaMA/comments/1ece00h/llama_31_8b_instruct_functiontool_calling_seems/
-            # ModelProviders.amazon_bedrock: {
-            #     "model_id": "meta.llama3-1-8b-instruct-v1:0",
-            #     "region_name": "us-west-2",  # Llama 3.1 only in west-2
-            # },
-            ModelProviders.ollama: {
-                "model": "llama3.1",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-8b-instant"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # bedrock llama doesn't support structured output, should check again later
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-8b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1"},
+            ),
+        ],
     ),
     # Llama 3.1 70b
     KilnModel(
         model_family=ModelFamily.llama,
         model_name=ModelName.llama_3_1_70b,
-        provider_config={
-            ModelProviders.groq: {
-                "model": "llama-3.1-70b-versatile",
-            },
-            ModelProviders.amazon_bedrock: {
-                "model_id": "meta.llama3-1-70b-instruct-v1:0",
-                "region_name": "us-west-2",  # Llama 3.1 only in west-2
-            },
-            # ModelProviders.ollama: {
-            #     "model": "llama3.1:70b",
-            # },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-70b-versatile"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # bedrock llama doesn't support structured output, should check again later
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-70b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            # TODO: enable once tests are updated to check if the model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={"model": "llama3.1:70b"},
+            # ),
+        ],
     ),
     # Mistral Large
     KilnModel(
         model_family=ModelFamily.mistral,
         model_name=ModelName.mistral_large,
-        provider_config={
-            ModelProviders.amazon_bedrock: {
-                "model_id": "mistral.mistral-large-2407-v1:0",
-                "region_name": "us-west-2",  # only in west-2
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "mistral.mistral-large-2407-v1:0",
+                    "region_name": "us-west-2",  # only in west-2
+                },
+            ),
+            # TODO: enable once tests are updated to check if the model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={"model": "mistral-large"},
+            # ),
+        ],
     ),
     # Phi 3.5
     KilnModel(
         model_family=ModelFamily.phi,
         model_name=ModelName.phi_3_5,
         supports_structured_output=False,
-        provider_config={
-            ModelProviders.ollama: {
-                "model": "phi3.5",
-            },
-        },
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "phi3.5"},
+            ),
+        ],
     ),
 ]
@@ -135,27 +166,29 @@ def langchain_model_from(
         raise ValueError(f"Model {model_name} not found")
 
     # If a provider is provided, select the provider from the model's provider_config
-    provider: ModelProviders | None = None
-    if model.provider_config is None or len(model.provider_config) == 0:
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
         raise ValueError(f"Model {model_name} has no providers")
-    if provider_name is None:
+    elif provider_name is None:
         # TODO: priority order
-        provider_name = list(model.provider_config.keys())[0]
-    if provider_name not in ModelProviders.__members__:
-        raise ValueError(f"Invalid provider: {provider_name}")
-    if provider_name not in model.provider_config:
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+    if provider is None:
         raise ValueError(f"Provider {provider_name} not found for model {model_name}")
-    model_provider_props = model.provider_config[provider_name]
-    provider = ModelProviders(provider_name)
-
-    if provider == ModelProviders.openai:
-        return ChatOpenAI(**model_provider_props)
-    elif provider == ModelProviders.groq:
-        return ChatGroq(**model_provider_props)
-    elif provider == ModelProviders.amazon_bedrock:
-        return ChatBedrock(**model_provider_props)
-    elif provider == ModelProviders.ollama:
-        return ChatOllama(**model_provider_props, base_url=ollama_base_url())
+
+    if provider.name == ModelProviderName.openai:
+        return ChatOpenAI(**provider.provider_options)
+    elif provider.name == ModelProviderName.groq:
+        return ChatGroq(**provider.provider_options)
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        return ChatBedrockConverse(**provider.provider_options)
+    elif provider.name == ModelProviderName.ollama:
+        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+    else:
+        raise ValueError(f"Invalid model or provider: {model_name} - {provider_name}")
 
 
 def ollama_base_url():
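Note: the snippet below is a minimal, self-contained sketch of how the reworked provider list is meant to be consumed. The pick_provider helper is hypothetical (it is not part of this patch), but it mirrors the selection logic langchain_model_from now uses: default to the first listed provider, otherwise match on ModelProviderName.

from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel


class ModelProviderName(str, Enum):
    openai = "openai"
    ollama = "ollama"


class KilnModelProvider(BaseModel):
    name: ModelProviderName
    supports_structured_output: bool = True
    provider_options: Dict = {}


class KilnModel(BaseModel):
    model_family: str
    model_name: str
    providers: List[KilnModelProvider]
    supports_structured_output: bool = True


def pick_provider(model: KilnModel, provider_name: Optional[str] = None) -> KilnModelProvider:
    # Hypothetical helper mirroring langchain_model_from's selection logic:
    # no provider requested -> first listed provider; otherwise match by name.
    if not model.providers:
        raise ValueError(f"Model {model.model_name} has no providers")
    if provider_name is None:
        return model.providers[0]
    provider = next((p for p in model.providers if p.name == provider_name), None)
    if provider is None:
        raise ValueError(f"Provider {provider_name} not found for model {model.model_name}")
    return provider


gpt_4o_mini = KilnModel(
    model_family="gpt",
    model_name="gpt_4o_mini",
    providers=[
        KilnModelProvider(
            name=ModelProviderName.openai,
            provider_options={"model": "gpt-4o-mini"},
        ),
    ],
)
print(pick_provider(gpt_4o_mini).provider_options)  # {'model': 'gpt-4o-mini'}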
diff --git a/libs/core/kiln_ai/adapters/prompt_adapters.py b/libs/core/kiln_ai/adapters/prompt_adapters.py
index ac1b84f..5068d8b 100644
--- a/libs/core/kiln_ai/adapters/prompt_adapters.py
+++ b/libs/core/kiln_ai/adapters/prompt_adapters.py
@@ -3,7 +3,7 @@
 import kiln_ai.datamodel.models as models
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages.base import BaseMessageChunk
+from langchain_core.messages.base import BaseMessage
 
 from .base_adapter import BaseAdapter
 from .ml_model_list import langchain_model_from
@@ -56,24 +56,20 @@ async def run(self, input: str) -> str:
         # TODO cleanup
         prompt = self.build_prompt()
         prompt += f"\n\n{input}"
+        response = self.model.invoke(prompt)
         if self.__is_structured:
-            response = self.model.invoke(prompt)
             if not isinstance(response, dict) or "parsed" not in response:
                 raise RuntimeError(f"structured response not returned: {response}")
             structured_response = response["parsed"]
             # TODO: not JSON, use a dict here
             return json.dumps(structured_response)
         else:
-            answer = ""
-            async for chunk in self.model.astream(prompt):
-                if not isinstance(chunk, BaseMessageChunk) or not isinstance(
-                    chunk.content, str
-                ):
-                    raise RuntimeError(f"chunk is not a string: {chunk}")
-
-                print(chunk.content, end="", flush=True)
-                answer += chunk.content
-            return answer
+            if not isinstance(response, BaseMessage):
+                raise RuntimeError(f"response is not a BaseMessage: {response}")
+            text_content = response.content
+            if not isinstance(text_content, str):
+                raise RuntimeError(f"response is not a string: {text_content}")
+            return text_content
 
 
 class SimplePromptAdapter(BaseLangChainPromptAdapter):
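For reference, the adapter's new non-structured path boils down to a single invoke call plus the BaseMessage/str checks above. A rough standalone sketch of that pattern follows; the ChatOllama setup is illustrative only and is not taken from this diff.

from langchain_core.messages.base import BaseMessage
from langchain_ollama import ChatOllama

# Illustrative model setup; any LangChain chat model behaves the same way here.
model = ChatOllama(model="llama3.1", base_url="http://localhost:11434")

response = model.invoke("Say hello in one short sentence.")
if not isinstance(response, BaseMessage):
    raise RuntimeError(f"response is not a BaseMessage: {response}")
text_content = response.content
if not isinstance(text_content, str):
    raise RuntimeError(f"response is not a string: {text_content}")
print(text_content)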
diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py
index d7244da..4f2eab3 100644
--- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py
+++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -56,7 +56,7 @@ async def test_amazon_bedrock(tmp_path):
         or os.getenv("AWS_ACCESS_KEY_ID") is None
     ):
         pytest.skip("AWS keys not set")
-    await run_simple_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+    await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
 
 
 async def test_mock(tmp_path):
@@ -72,8 +72,14 @@ async def test_mock(tmp_path):
 async def test_all_built_in_models(tmp_path):
     task = build_test_task(tmp_path)
     for model in built_in_models:
-        for provider in model.provider_config:
-            await run_simple_task(task, model.model_name, provider)
+        for provider in model.providers:
+            try:
+                print(f"Running {model.model_name} {provider.name}")
+                await run_simple_task(task, model.model_name, provider.name)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Error running {model.model_name} {provider}"
+                ) from e
 
 
 def build_test_task(tmp_path: Path):
diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/test_structured_output.py
index c84bbe3..0c436eb 100644
--- a/libs/core/kiln_ai/adapters/test_structured_output.py
+++ b/libs/core/kiln_ai/adapters/test_structured_output.py
@@ -4,8 +4,6 @@
 import kiln_ai.datamodel.models as models
 import pytest
 from kiln_ai.adapters.ml_model_list import (
-    ModelName,
-    ModelProviders,
     built_in_models,
     ollama_online,
 )
@@ -20,10 +18,7 @@ async def test_structured_output_groq(tmp_path):
 
 @pytest.mark.paid
 async def test_structured_output_bedrock(tmp_path):
-    pytest.skip(
-        "bedrock not working with structured output. New API, hopefully fixed soon by langchain"
-    )
-    await run_structured_output_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
+    await run_structured_output_test(tmp_path, "mistral_large", "amazon_bedrock")
 
 
@@ -35,6 +30,11 @@ async def test_structured_output_ollama_phi(tmp_path):
     await run_structured_output_test(tmp_path, "phi_3_5", "ollama")
 
 
+@pytest.mark.paid
+async def test_structured_output_gpt_4o_mini(tmp_path):
+    await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")
+
+
 @pytest.mark.ollama
 async def test_structured_output_ollama_llama(tmp_path):
     if not await ollama_online():
@@ -51,13 +51,17 @@ async def test_all_built_in_models_structured_output(tmp_path):
                 f"Skipping {model.model_name} because it does not support structured output"
             )
             continue
-        for provider in model.provider_config:
-            if provider == ModelProviders.amazon_bedrock:
-                # TODO: bedrock not working, should investigate and fix
+        for provider in model.providers:
+            if not provider.supports_structured_output:
+                print(
+                    f"Skipping {model.model_name} {provider.name} because it does not support structured output"
+                )
                 continue
             try:
                 print(f"Running {model.model_name} {provider}")
-                await run_structured_output_test(tmp_path, model.model_name, provider)
+                await run_structured_output_test(
+                    tmp_path, model.model_name, provider.name
+                )
             except Exception as e:
                 raise RuntimeError(
                     f"Error running {model.model_name} {provider}"
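The structured branch in prompt_adapters.py expects the model's response to be a dict containing a "parsed" key. That is the shape LangChain returns when a chat model is wrapped with with_structured_output(..., include_raw=True); the wrapping itself happens outside this diff, so the snippet below is a hedged sketch of that assumed setup rather than the project's actual wiring, and the Joke schema is purely illustrative.

from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


# Assumed setup: wrap the base chat model so invoke() returns
# {"raw": ..., "parsed": ..., "parsing_error": ...}.
base_model = ChatOpenAI(model="gpt-4o-mini")
structured_model = base_model.with_structured_output(Joke, include_raw=True)

response = structured_model.invoke("Tell me a short joke about pottery kilns.")
if not isinstance(response, dict) or "parsed" not in response:
    raise RuntimeError(f"structured response not returned: {response}")
print(response["parsed"])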