Skip to content

Commit

Permalink
Big refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Aug 28, 2024
1 parent 4d724e5 commit e3a1b4a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
51 changes: 24 additions & 27 deletions libs/core/kiln_ai/adapters/langchain_adapter.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
import kiln_ai.datamodel.models as models
from .ml_model_list import model_from
from abc import ABCMeta, abstractmethod


class LangChainBaseAdapter:
class LangChainBaseAdapter(metaclass=ABCMeta):
def __init__(self, kiln_task: models.Task, model_name: str, provider: str):
    """Bind a Kiln task to a concrete LangChain chat model.

    The model is resolved immediately via ``model_from`` so that an
    invalid model/provider pair fails at construction time.
    """
    self.model = model_from(model_name, provider)
    self.kiln_task = kiln_task


async def test_run_prompt(
    prompt: str,
    model_name: str = "llama_3_1_8b",
    provider: str = "groq",
) -> str:
    """Stream *prompt* through a model and return the accumulated text.

    Args:
        prompt: The full prompt text to send to the model.
        model_name: Model identifier understood by ``model_from``.
            Defaults preserve the original hard-coded llama/groq pair.
        provider: Provider identifier understood by ``model_from``.

    Returns:
        The concatenation of every string chunk streamed back.

    Note: the original body created three models in a row (bedrock,
    openai, groq) and silently used only the last one; the dead
    assignments and the unused ``chunks`` list are removed here.
    """
    model = model_from(model_name=model_name, provider=provider)

    answer = ""
    async for chunk in model.astream(prompt):
        # Echo tokens as they arrive for interactive feedback.
        print(chunk.content, end="", flush=True)
        # chunk.content may be non-str (structured content) — only
        # accumulate plain text.
        if isinstance(chunk.content, str):
            answer += chunk.content
    return answer


class ExperimentalKilnAdapter:
def __init__(self, kiln_task: models.Task):
self.kiln_task = kiln_task

async def run(self):
@abstractmethod
def build_prompt(self) -> str:
    """Produce the base prompt text for this adapter's task.

    Concrete subclasses must implement this; the abstract body is
    intentionally empty (returns ``None`` if called directly).
    """

# TODO: don't just append input to prompt
async def run(self, input: str) -> str:
    """Run the task's prompt plus *input* through the model.

    Builds the task prompt via ``build_prompt()``, appends the raw
    user input (see TODO above — naive concatenation for now),
    streams the model's response, and returns the concatenated
    string content.

    Args:
        input: Raw user input appended after the task prompt.
            (Name shadows the builtin but is kept for caller
            compatibility.)

    Returns:
        All string chunks of the streamed response, joined in order.
    """
    # TODO cleanup
    prompt = self.build_prompt()
    prompt += f"\n\n{input}"
    answer = ""
    # The original also appended every chunk to an unused `chunks`
    # list; that dead accumulation is removed.
    async for chunk in self.model.astream(prompt):
        # Echo tokens as they arrive for interactive feedback.
        print(chunk.content, end="", flush=True)
        # chunk.content may be non-str (structured content) — only
        # accumulate plain text.
        if isinstance(chunk.content, str):
            answer += chunk.content
    return answer


class SimplePromptAdapter(LangChainBaseAdapter):
def build_prompt(self) -> str:
base_prompt = self.kiln_task.instruction

if len(self.kiln_task.requirements()) > 0:
Expand All @@ -36,8 +37,4 @@ async def run(self):
for i, requirement in enumerate(self.kiln_task.requirements()):
base_prompt += f"{i+1}) {requirement.instruction}\n"

base_prompt += (
"\n\nYou should answer the following question: four plus six times 10\n"
)

return await test_run_prompt(base_prompt)
return base_prompt
30 changes: 25 additions & 5 deletions libs/core/kiln_ai/adapters/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import kiln_ai.datamodel.models as models
import kiln_ai.adapters.langchain_adapter as ad
from kiln_ai.adapters.langchain_adapter import SimplePromptAdapter
import pytest
import os

from pathlib import Path
from dotenv import load_dotenv


Expand All @@ -11,10 +11,28 @@ def load_env():
load_dotenv()


async def test_langchain(tmp_path):
async def test_groq(tmp_path):
    """End-to-end smoke test against Groq; skipped without credentials."""
    api_key = os.getenv("GROQ_API_KEY")
    if api_key is None:
        pytest.skip("GROQ_API_KEY not set")
    await run_simple_test(tmp_path, "llama_3_1_8b", "groq")


async def test_openai(tmp_path):
    """End-to-end smoke test against OpenAI; skipped without credentials."""
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        pytest.skip("OPENAI_API_KEY not set")
    await run_simple_test(tmp_path, "gpt_4o_mini", "openai")


async def test_amazon_bedrock(tmp_path):
    """End-to-end smoke test against Bedrock; skipped without AWS keys."""
    required = ("AWS_SECRET_ACCESS_KEY", "AWS_ACCESS_KEY_ID")
    if any(os.getenv(name) is None for name in required):
        pytest.skip("AWS keys not set")
    await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")


async def run_simple_test(tmp_path: Path, model_name: str, provider: str):
project = models.Project(name="test", path=tmp_path / "test.kiln")
project.save_to_file()
assert project.name == "test"
Expand All @@ -40,6 +58,8 @@ async def test_langchain(tmp_path):
r2.save_to_file()
assert len(task.requirements()) == 2

adapter = ad.ExperimentalKilnAdapter(task)
answer = await adapter.run()
adapter = SimplePromptAdapter(task, model_name=model_name, provider=provider)
answer = await adapter.run(
"You should answer the following question: four plus six times 10"
)
assert "64" in answer

0 comments on commit e3a1b4a

Please sign in to comment.