diff --git a/libs/core/kiln_ai/adapters/langchain_adapter.py b/libs/core/kiln_ai/adapters/langchain_adapter.py index 2c90201..5e3e8f5 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapter.py +++ b/libs/core/kiln_ai/adapters/langchain_adapter.py @@ -1,33 +1,34 @@ import kiln_ai.datamodel.models as models from .ml_model_list import model_from +from abc import ABCMeta, abstractmethod -class LangChainBaseAdapter: +class LangChainBaseAdapter(metaclass=ABCMeta): def __init__(self, kiln_task: models.Task, model_name: str, provider: str): self.kiln_task = kiln_task self.model = model_from(model_name, provider) - -async def test_run_prompt(prompt: str): - model = model_from(model_name="llama_3_1_8b", provider="amazon_bedrock") - model = model_from(model_name="gpt_4o_mini", provider="openai") - model = model_from(model_name="llama_3_1_8b", provider="groq") - - chunks = [] - answer = "" - async for chunk in model.astream(prompt): - chunks.append(chunk) - print(chunk.content, end="", flush=True) - if isinstance(chunk.content, str): - answer += chunk.content - return answer - - -class ExperimentalKilnAdapter: - def __init__(self, kiln_task: models.Task): - self.kiln_task = kiln_task - - async def run(self): + @abstractmethod + def build_prompt(self) -> str: + pass + + # TODO: don't just append input to prompt + async def run(self, input: str) -> str: + # TODO cleanup + prompt = self.build_prompt() + prompt += f"\n\n{input}" + chunks = [] + answer = "" + async for chunk in self.model.astream(prompt): + chunks.append(chunk) + print(chunk.content, end="", flush=True) + if isinstance(chunk.content, str): + answer += chunk.content + return answer + + +class SimplePromptAdapter(LangChainBaseAdapter): + def build_prompt(self) -> str: base_prompt = self.kiln_task.instruction if len(self.kiln_task.requirements()) > 0: @@ -36,8 +37,4 @@ async def run(self): for i, requirement in enumerate(self.kiln_task.requirements()): base_prompt += f"{i+1}) {requirement.instruction}\n" - base_prompt += ( - "\n\nYou should answer the following question: four plus six times 10\n" - ) - - return await test_run_prompt(base_prompt) + return base_prompt diff --git a/libs/core/kiln_ai/adapters/test_langchain.py b/libs/core/kiln_ai/adapters/test_langchain.py index e9021d4..99a5ff7 100644 --- a/libs/core/kiln_ai/adapters/test_langchain.py +++ b/libs/core/kiln_ai/adapters/test_langchain.py @@ -1,8 +1,8 @@ import kiln_ai.datamodel.models as models -import kiln_ai.adapters.langchain_adapter as ad +from kiln_ai.adapters.langchain_adapter import SimplePromptAdapter import pytest import os - +from pathlib import Path from dotenv import load_dotenv @@ -11,10 +11,28 @@ def load_env(): load_dotenv() -async def test_langchain(tmp_path): +async def test_groq(tmp_path): if os.getenv("GROQ_API_KEY") is None: pytest.skip("GROQ_API_KEY not set") + await run_simple_test(tmp_path, "llama_3_1_8b", "groq") + + +async def test_openai(tmp_path): + if os.getenv("OPENAI_API_KEY") is None: + pytest.skip("OPENAI_API_KEY not set") + await run_simple_test(tmp_path, "gpt_4o_mini", "openai") + +async def test_amazon_bedrock(tmp_path): + if ( + os.getenv("AWS_SECRET_ACCESS_KEY") is None + or os.getenv("AWS_ACCESS_KEY_ID") is None + ): + pytest.skip("AWS keys not set") + await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock") + + +async def run_simple_test(tmp_path: Path, model_name: str, provider: str): project = models.Project(name="test", path=tmp_path / "test.kiln") project.save_to_file() assert project.name == "test" @@ -40,6 +58,8 @@ async def test_langchain(tmp_path): r2.save_to_file() assert len(task.requirements()) == 2 - adapter = ad.ExperimentalKilnAdapter(task) - answer = await adapter.run() + adapter = SimplePromptAdapter(task, model_name=model_name, provider=provider) + answer = await adapter.run( + "You should answer the following question: four plus six times 10" + ) assert "64" in answer