Skip to content

Commit

Permalink
premature abstraction!
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Aug 31, 2024
1 parent af08fd4 commit 893792d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
7 changes: 7 additions & 0 deletions libs/core/kiln_ai/adapters/base_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import ABCMeta, abstractmethod


class BaseAdapter(metaclass=ABCMeta):
@abstractmethod
async def run(self, input: str) -> str:
pass
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/ml_model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ModelProviders(str, Enum):
}


def model_from(model_name: str, provider: str) -> BaseChatModel:
def langchain_model_from(model_name: str, provider: str) -> BaseChatModel:
if model_name not in ModelName.__members__:
raise ValueError(f"Invalid model_name: {model_name}")
model_name = ModelName(model_name)
Expand Down
9 changes: 5 additions & 4 deletions libs/core/kiln_ai/adapters/prompt_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import kiln_ai.datamodel.models as models
from langchain_core.language_models.chat_models import BaseChatModel

from .ml_model_list import model_from
from .base_adapter import BaseAdapter
from .ml_model_list import langchain_model_from


class BasePromptAdapter(metaclass=ABCMeta):
class BaseLangChainPromptAdapter(BaseAdapter, metaclass=ABCMeta):
def __init__(
self,
kiln_task: models.Task,
Expand All @@ -18,7 +19,7 @@ def __init__(
if custom_model is not None:
self.model = custom_model
elif model_name is not None and provider is not None:
self.model = model_from(model_name, provider)
self.model = langchain_model_from(model_name, provider)
else:
raise ValueError(
"model_name and provider must be provided if custom_model is not provided"
Expand All @@ -43,7 +44,7 @@ async def run(self, input: str) -> str:
return answer


class SimplePromptAdapter(BasePromptAdapter):
class SimplePromptAdapter(BaseLangChainPromptAdapter):
def build_prompt(self) -> str:
base_prompt = self.kiln_task.instruction

Expand Down

0 comments on commit 893792d

Please sign in to comment.