Skip to content

Commit

Permalink
Add parent typing to datamodel
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Aug 20, 2024
1 parent 55b5ee8 commit a19c918
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
19 changes: 18 additions & 1 deletion libs/core/kiln_ai/datamodel/basemodel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pydantic import BaseModel, computed_field, Field
from pydantic import BaseModel, computed_field, Field, field_validator
from typing import Optional
from pathlib import Path
from typing import Type, TypeVar
from abc import ABCMeta, abstractmethod
import uuid
from builtins import classmethod


# ID is a 10 digit hex string
Expand Down Expand Up @@ -68,6 +69,22 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
def relationship_name(self) -> str:
pass

@classmethod
@abstractmethod
def parent_type(cls) -> Type[KilnBaseModel]:
pass

@field_validator("parent")
@classmethod
def check_parent_type(cls, v: Optional[KilnBaseModel]) -> Optional[KilnBaseModel]:
if v is not None:
expected_parent_type = cls.parent_type()
if not isinstance(v, expected_parent_type):
raise ValueError(
f"Parent must be of type {expected_parent_type}, but was {type(v)}"
)
return v

def build_child_filename(self) -> Path:
# Default implementation for readable filenames.
# Can be overridden, but probably shouldn't be.
Expand Down
25 changes: 21 additions & 4 deletions libs/core/kiln_ai/datamodel/test_basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def relationship_name(self) -> str:
def build_child_filename(self) -> Path:
return Path("child.kiln")

def parent_type():
return KilnBaseModel


def test_parented_model_path_gen(tmp_path):
parent = KilnBaseModel(path=tmp_path)
Expand All @@ -77,16 +80,23 @@ def test_parented_model_path_gen(tmp_path):
assert child_path.parent.parent == tmp_path.parent


class BaseParentExample(KilnBaseModel):
pass


# Instance of the parented model for abstract methods, with default name builder
class DefaultParentedModel(KilnParentedModel):
name: Optional[str] = None

def relationship_name(self) -> str:
return "children"

def parent_type():
return BaseParentExample


def test_build_default_child_filename(tmp_path):
parent = KilnBaseModel(path=tmp_path)
parent = BaseParentExample(path=tmp_path)
child = DefaultParentedModel(parent=parent)
child_path = child.build_path()
child_path_without_id = child_path.name[10:]
Expand All @@ -103,7 +113,7 @@ def test_build_default_child_filename(tmp_path):


def test_serialize_child(tmp_path):
parent = KilnBaseModel(path=tmp_path)
parent = BaseParentExample(path=tmp_path)
child = DefaultParentedModel(parent=parent, name="Name")

expected_path = child.build_path()
Expand Down Expand Up @@ -136,7 +146,7 @@ def test_serialize_child(tmp_path):

def test_save_to_set_location(tmp_path):
# Keeps the OG path if parent and path are both set
parent = KilnBaseModel(path=tmp_path)
parent = BaseParentExample(path=tmp_path)
child_path = tmp_path.parent / "child.kiln"
child = DefaultParentedModel(path=child_path, parent=parent, name="Name")
assert child.build_path() == child_path
Expand All @@ -154,7 +164,14 @@ def test_save_to_set_location(tmp_path):

def test_parent_without_path():
# no path from parent or direct path
parent = KilnBaseModel()
parent = BaseParentExample()
child = DefaultParentedModel(parent=parent, name="Name")
with pytest.raises(ValueError):
child.save_to_file()


def test_parent_wrong_type():
# DefaultParentedModel is parented to BaseParentExample, not KilnBaseModel
parent = KilnBaseModel()
with pytest.raises(ValueError):
DefaultParentedModel(parent=parent, name="Name")

0 comments on commit a19c918

Please sign in to comment.