diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py new file mode 100644 index 0000000..3814bae --- /dev/null +++ b/libs/core/kiln_ai/adapters/base_adapter.py @@ -0,0 +1,7 @@ +from abc import ABCMeta, abstractmethod + + +class BaseAdapter(metaclass=ABCMeta): + @abstractmethod + async def run(self, input: str) -> str: + pass diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 2529b9e..d6a3f66 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -39,7 +39,7 @@ class ModelProviders(str, Enum): } -def model_from(model_name: str, provider: str) -> BaseChatModel: +def langchain_model_from(model_name: str, provider: str) -> BaseChatModel: if model_name not in ModelName.__members__: raise ValueError(f"Invalid model_name: {model_name}") model_name = ModelName(model_name) diff --git a/libs/core/kiln_ai/adapters/prompt_adapters.py b/libs/core/kiln_ai/adapters/prompt_adapters.py index d9ff5be..c1531c2 100644 --- a/libs/core/kiln_ai/adapters/prompt_adapters.py +++ b/libs/core/kiln_ai/adapters/prompt_adapters.py @@ -3,10 +3,11 @@ import kiln_ai.datamodel.models as models from langchain_core.language_models.chat_models import BaseChatModel -from .ml_model_list import model_from +from .base_adapter import BaseAdapter +from .ml_model_list import langchain_model_from -class BasePromptAdapter(metaclass=ABCMeta): +class BaseLangChainPromptAdapter(BaseAdapter, metaclass=ABCMeta): def __init__( self, kiln_task: models.Task, @@ -18,7 +19,7 @@ def __init__( if custom_model is not None: self.model = custom_model elif model_name is not None and provider is not None: - self.model = model_from(model_name, provider) + self.model = langchain_model_from(model_name, provider) else: raise ValueError( "model_name and provider must be provided if custom_model is not provided" @@ -43,7 +44,7 @@ async def run(self, input: str) -> str: return answer -class SimplePromptAdapter(BasePromptAdapter): +class SimplePromptAdapter(BaseLangChainPromptAdapter): def build_prompt(self) -> str: base_prompt = self.kiln_task.instruction