Skip to content

Commit

Permalink
Autosave and examples that get run.
Browse files Browse the repository at this point in the history
WIP as we want to save more information about them (source, model, etc)
  • Loading branch information
scosman committed Sep 30, 2024
1 parent 9db2207 commit 5552341
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 5 deletions.
38 changes: 37 additions & 1 deletion libs/core/kiln_ai/adapters/base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
from abc import ABCMeta, abstractmethod
from typing import Dict

from kiln_ai.datamodel import Task
from kiln_ai.datamodel import (
Example,
ExampleOutput,
ExampleOutputSource,
ExampleSource,
Task,
)
from kiln_ai.datamodel.json_schema import validate_schema
from kiln_ai.utils.config import Config


class BaseAdapter(metaclass=ABCMeta):
Expand Down Expand Up @@ -34,6 +41,11 @@ async def invoke(self, input: Dict | str) -> Dict | str:
raise RuntimeError(
f"response is not a string for non-structured task: {result}"
)

# Save the example and output
if Config.shared().autosave_examples:
self.save_example(input, result)

return result

def has_strctured_output(self) -> bool:
Expand All @@ -47,6 +59,30 @@ async def _run(self, input: Dict | str) -> Dict | str:
def adapter_specific_instructions(self) -> str | None:
return None

# create an example and example output
def save_example(self, input: Dict | str, output: Dict | str) -> Example:
    """Store an input/output pair as an Example with a single ExampleOutput.

    Dictionary values are serialized with ``json.dumps``; strings are stored
    unchanged. The Example is created under ``self.kiln_task`` and saved
    first, then the ExampleOutput is created under that Example and saved.

    Args:
        input: The input that was given to the task.
        output: The result produced for that input.

    Returns:
        The newly created Example (the parent of the saved output).
    """
    # Dictionaries are stored as JSON text; strings pass through untouched.
    serialized_input = input
    if isinstance(input, dict):
        serialized_input = json.dumps(input)

    serialized_output = output
    if isinstance(output, dict):
        serialized_output = json.dumps(output)

    # TODO P2: check for existing example with this input, and use it instead of creating a new one
    # TODO P1: this isn't necessarily synthetic. Should pass this in.
    new_example = Example(
        parent=self.kiln_task,
        input=serialized_input,
        source=ExampleSource.synthetic,
    )
    new_example.save_to_file()

    # Record the configured user id as the creator of this output.
    output_properties = {"creator": Config.shared().user_id}
    new_output = ExampleOutput(
        parent=new_example,
        output=serialized_output,
        source=ExampleOutputSource.synthetic,
        source_properties=output_properties,
    )
    new_output.save_to_file()

    return new_example


class BasePromptBuilder(metaclass=ABCMeta):
def __init__(self, task: Task, adapter: BaseAdapter | None = None):
Expand Down
103 changes: 103 additions & 0 deletions libs/core/kiln_ai/adapters/test_saving_adapter_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from unittest.mock import AsyncMock, patch

import pytest
from kiln_ai.adapters.base_adapter import BaseAdapter
from kiln_ai.datamodel import (
Example,
ExampleOutput,
ExampleOutputSource,
ExampleSource,
Project,
Task,
)
from kiln_ai.utils.config import Config


class TestAdapter(BaseAdapter):
    """Minimal BaseAdapter stub whose ``_run`` always yields a fixed string."""

    async def _run(self, input: dict | str) -> dict | str:
        # The input is ignored; every run produces the same canned output.
        canned_output = "Test output"
        return canned_output


@pytest.fixture
def test_task(tmp_path):
    """Provide a saved Task whose parent Project lives under pytest's tmp_path."""
    project_file = tmp_path / "test_project.kiln"
    parent_project = Project(name="test_project", path=project_file)
    parent_project.save_to_file()

    saved_task = Task(parent=parent_project, name="test_task")
    saved_task.save_to_file()
    return saved_task


def test_save_example_isolation(test_task):
    """save_example stores an Example plus one ExampleOutput, readable from disk."""
    input_data = "Test input"
    output_data = "Test output"

    def assert_example_fields(candidate):
        # Fields every stored example must carry, in memory or reloaded.
        assert candidate.input == input_data
        assert candidate.source == ExampleSource.synthetic

    def assert_single_output(owner):
        # The example must own exactly one output with the expected contents.
        owned_outputs = owner.outputs()
        assert len(owned_outputs) == 1
        only_output = owned_outputs[0]
        assert only_output.parent.id == owner.id
        assert only_output.output == output_data
        assert only_output.source == ExampleOutputSource.synthetic
        assert only_output.source_properties == {"creator": Config.shared().user_id}
        assert only_output.requirement_ratings == {}

    adapter = TestAdapter(test_task)
    example = adapter.save_example(input_data, output_data)

    # Check that the example was saved correctly
    assert example.parent == test_task
    assert_example_fields(example)

    # Check that the example output was saved correctly
    assert_single_output(example)

    # Verify that the data can be read back from disk
    reloaded_examples = Task.load_from_file(test_task.path).examples()
    assert len(reloaded_examples) == 1
    reloaded_example = reloaded_examples[0]
    assert_example_fields(reloaded_example)
    assert_single_output(reloaded_example)


@pytest.mark.asyncio
async def test_autosave_false(test_task):
    """invoke() must not persist any example when autosave_examples is False."""
    with patch("kiln_ai.utils.config.Config.shared") as shared_stub:
        # Config.shared() now returns a mock with autosave switched off.
        shared_stub.return_value.autosave_examples = False

        adapter = TestAdapter(test_task)
        await adapter.invoke("Test input")

        # Check that no examples were saved
        saved_examples = test_task.examples()
        assert len(saved_examples) == 0


@pytest.mark.asyncio
async def test_autosave_true(test_task):
    """invoke() persists one example and one output when autosave_examples is True."""
    prompt = "Test input"

    with patch("kiln_ai.utils.config.Config.shared") as shared_stub:
        # Config.shared() now returns a mock with autosave on and a known user.
        fake_config = shared_stub.return_value
        fake_config.autosave_examples = True
        fake_config.user_id = "test_user"

        await TestAdapter(test_task).invoke(prompt)

        # Check that an example was saved
        saved_examples = test_task.examples()
        assert len(saved_examples) == 1
        saved_example = saved_examples[0]
        assert saved_example.input == prompt

        saved_outputs = saved_example.outputs()
        assert len(saved_outputs) == 1
        saved_output = saved_outputs[0]
        assert saved_output.output == "Test output"
        assert saved_output.source_properties["creator"] == "test_user"
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/datamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class ExampleOutput(KilnParentedModel):
# TODO: add structure/validation to this. For human creator_id. Model ID and version and provider for models
source_properties: Dict[str, str] = Field(
default={},
description="Additional properties of the source, e.g. the name of the human who provided the output or the model that generated the output.",
description="Additional properties of the source, e.g. the user name of the human who provided the output or the model that generated the output.",
)
rating: ReasonRating | None = Field(
default=None,
Expand Down
6 changes: 5 additions & 1 deletion libs/core/kiln_ai/datamodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from builtins import classmethod
from datetime import datetime
from pathlib import Path
from typing import Optional, Self, Type, TypeVar
from typing import TYPE_CHECKING, Optional, Self, Type, TypeVar

from kiln_ai.utils.config import Config
from pydantic import (
Expand Down Expand Up @@ -108,6 +108,10 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
id: ID_TYPE = ID_FIELD
_parent: KilnBaseModel | None = None

# workaround to tell typechecker that we support the parent property, even though it's not a stock property
if TYPE_CHECKING:
parent: KilnBaseModel # type: ignore

def __init__(self, **data):
super().__init__(**data)
if "parent" in data:
Expand Down
14 changes: 13 additions & 1 deletion libs/core/kiln_ai/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def __init__(self, properties: Dict[str, ConfigProperty] | None = None):
"user_id": ConfigProperty(
str,
env_var="KILN_USER_ID",
default_lambda=lambda: pwd.getpwuid(os.getuid()).pw_name,
default_lambda=_get_user_id,
),
"autosave_examples": ConfigProperty(
bool,
env_var="KILN_AUTOSAVE_EXAMPLES",
default=True,
),
}

Expand Down Expand Up @@ -60,3 +65,10 @@ def __setattr__(self, name: str, value: Any) -> None:
self._values[name] = self._properties[name].type(value)
else:
raise AttributeError(f"Config has no attribute '{name}'")


def _get_user_id():
    """Return the current OS account name, or "unknown_user" as a fallback.

    The name is looked up with ``pwd.getpwuid(os.getuid())``. The fallback is
    returned when that lookup raises any exception or yields an empty or
    missing ``pw_name``.
    """
    try:
        account_name = pwd.getpwuid(os.getuid()).pw_name
    except Exception:
        # Deliberately broad: any lookup failure falls back to the placeholder.
        return "unknown_user"

    if account_name:
        return account_name
    return "unknown_user"
38 changes: 37 additions & 1 deletion libs/core/kiln_ai/utils/test_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import pwd

import pytest

from libs.core.kiln_ai.utils.config import Config, ConfigProperty
from libs.core.kiln_ai.utils.config import Config, ConfigProperty, _get_user_id


def TestConfig():
Expand Down Expand Up @@ -107,7 +108,42 @@ def default_lambda():
assert config.lambda_property == "lambda_value"


def test_get_user_id_none(monkeypatch):
    """A passwd entry whose pw_name is None yields the "unknown_user" fallback."""

    class NamelessEntry:
        pw_name = None

    # Replace the passwd lookup so it returns an entry without a name.
    monkeypatch.setattr(pwd, "getpwuid", lambda _uid: NamelessEntry())

    assert _get_user_id() == "unknown_user"


def test_get_user_id_exception(monkeypatch):
    """A passwd lookup that raises yields the "unknown_user" fallback."""

    def failing_getpwuid(_uid):
        raise Exception("Test exception")

    # Replace the passwd lookup with one that always fails.
    monkeypatch.setattr(pwd, "getpwuid", failing_getpwuid)

    resolved = _get_user_id()
    assert resolved == "unknown_user"


def test_get_user_id_valid(monkeypatch):
    """A passwd entry with a real pw_name is returned unchanged."""

    class NamedEntry:
        pw_name = "test_user"

    # Replace the passwd lookup so it returns an entry with a known name.
    monkeypatch.setattr(pwd, "getpwuid", lambda _uid: NamedEntry())

    assert _get_user_id() == "test_user"


def test_user_id_default(reset_config):
    """A fresh Config exposes a non-empty default user_id."""
    default_user_id = Config().user_id
    # The concrete value varies by machine, so only check that it is non-empty.
    assert len(default_user_id) > 0


def test_autosave_examples_default(reset_config):
    """A fresh Config has a truthy autosave_examples value."""
    autosave_enabled = Config().autosave_examples
    assert autosave_enabled

0 comments on commit 5552341

Please sign in to comment.