From a08d52d222301b50df355437ba5e2b3fa21f255f Mon Sep 17 00:00:00 2001
From: scosman
Date: Tue, 15 Oct 2024 22:41:53 -0400
Subject: [PATCH] Add the ability to save ratings!

---
 app/web_ui/src/lib/api_schema.d.ts            |   6 +-
 app/web_ui/src/routes/(app)/run/+page.svelte  |  16 ++-
 app/web_ui/src/routes/(app)/run/output.svelte | 111 +++++++++++++-----
 libs/core/kiln_ai/datamodel/__init__.py       |  10 +-
 libs/studio/kiln_studio/task_management.py    |  17 ++-
 .../kiln_studio/test_task_management.py       |  81 +++++++++++--
 6 files changed, 182 insertions(+), 59 deletions(-)

diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts
index d82dba7..11f7970 100644
--- a/app/web_ui/src/lib/api_schema.d.ts
+++ b/app/web_ui/src/lib/api_schema.d.ts
@@ -377,10 +377,10 @@ export interface components {
       /** @default five_star */
       type: components["schemas"]["TaskOutputRatingType"];
       /**
-       * Rating
-       * @description The rating value (typically 1-5 stars).
+       * Value
+       * @description The overall rating value (typically 1-5 stars).
        */
-      rating: number;
+      value?: number | null;
       /**
        * Requirement Ratings
        * @description The ratings of the requirements of the task. The keys are the ids of the requirements. The values are the ratings (typically 1-5 stars).
diff --git a/app/web_ui/src/routes/(app)/run/+page.svelte b/app/web_ui/src/routes/(app)/run/+page.svelte
index 16a24ff..0393bc7 100644
--- a/app/web_ui/src/routes/(app)/run/+page.svelte
+++ b/app/web_ui/src/routes/(app)/run/+page.svelte
@@ -59,7 +59,6 @@
       },
     },
   }
-  $: output = response?.output

   $: subtitle = $current_task ? "Task: " + $current_task.name : ""
@@ -175,9 +174,14 @@
   />
-  [removed markup lost in extraction]
+  {#if $current_task && !submitting && response != null && $current_project}
+    [added markup lost in extraction]
+  {/if}
diff --git a/app/web_ui/src/routes/(app)/run/output.svelte b/app/web_ui/src/routes/(app)/run/output.svelte
index a52c81d..8ec8b40 100644
--- a/app/web_ui/src/routes/(app)/run/output.svelte
+++ b/app/web_ui/src/routes/(app)/run/output.svelte
@@ -1,52 +1,101 @@
   [script hunk lost in extraction]
@@ -98,8 +147,8 @@
       [context markup lost in extraction]
-      {#if $current_task?.requirements}
-        {#each $current_task.requirements as requirement, index}
+      {#if task.requirements}
+        {#each task.requirements as requirement, index}
           [context markup lost in extraction]
           {requirement.name}:
@@ -109,7 +158,6 @@
         {/each}
       {/if}
-      [removed markup lost in extraction]
@@ -133,4 +181,7 @@
   {/if}
 {/if}
+
+  [added markup lost in extraction: a collapsible "Raw Data" section rendering]
+    {JSON.stringify(run, null, 2)}
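
Note for reviewers: the api_schema.d.ts change above (required `rating: number` becoming optional `value?: number | null`) mirrors the datamodel change below, where TaskOutputRating.value gains a None default so a run can carry requirement ratings before an overall rating is set. A minimal sketch of the resulting behavior, assuming TaskOutputRating can be constructed standalone and using made-up requirement ids:

    from kiln_ai.datamodel import TaskOutputRating, TaskOutputRatingType

    # Requirement ratings without an overall rating are now valid.
    partial = TaskOutputRating(requirement_ratings={"988847661375": 3.0})
    assert partial.value is None
    assert partial.is_high_quality() is False  # guarded against value=None

    # An overall rating is still validated against the five-star range.
    rated = TaskOutputRating(type=TaskOutputRatingType.five_star, value=4.0)
    assert rated.is_high_quality() is True

The is_high_quality() guard matters because that method feeds example selection (MultiShotPromptBuilder and friends), which now has to tolerate unrated outputs.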
diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py
index 5d61b2b..4721033 100644
--- a/libs/core/kiln_ai/datamodel/__init__.py
+++ b/libs/core/kiln_ai/datamodel/__init__.py
@@ -51,7 +51,10 @@ class TaskOutputRating(KilnBaseModel):
     """

     type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
-    value: float = Field(description="The overall rating value (typically 1-5 stars).")
+    value: float | None = Field(
+        description="The overall rating value (typically 1-5 stars).",
+        default=None,
+    )
     requirement_ratings: Dict[ID_TYPE, float] = Field(
         default={},
         description="The ratings of the requirements of the task. The keys are the ids of the requirements. The values are the ratings (typically 1-5 stars).",
@@ -60,7 +63,7 @@
     # Used to select high quality outputs for example selection (MultiShotPromptBuilder, etc)
     def is_high_quality(self) -> bool:
         if self.type == TaskOutputRatingType.five_star:
-            return self.value >= 4
+            return self.value is not None and self.value >= 4
         return False

     @model_validator(mode="after")
@@ -69,7 +72,8 @@ def validate_rating(self) -> Self:
             raise ValueError(f"Invalid rating type: {self.type}")

         if self.type == TaskOutputRatingType.five_star:
-            self._validate_five_star(self.value, "overall rating")
+            if self.value is not None:
+                self._validate_five_star(self.value, "overall rating")
             for req_id, req_rating in self.requirement_ratings.items():
                 self._validate_five_star(req_rating, f"requirement rating for {req_id}")
diff --git a/libs/studio/kiln_studio/task_management.py b/libs/studio/kiln_studio/task_management.py
index aebb702..ad9eba5 100644
--- a/libs/studio/kiln_studio/task_management.py
+++ b/libs/studio/kiln_studio/task_management.py
@@ -29,21 +29,26 @@ class RunTaskResponse(BaseModel):
     run: TaskRun | None = None


-def deep_update(source, update):
+def deep_update(
+    source: Dict[str, Any] | None, update: Dict[str, Any | None]
+) -> Dict[str, Any]:
     if source is None:
-        return update
+        return {k: v for k, v in update.items() if v is not None}
     for key, value in update.items():
-        if isinstance(value, dict):
-            source[key] = deep_update(source.get(key, {}), value)
+        if value is None:
+            source.pop(key, None)
+        elif isinstance(value, dict):
+            if key not in source or not isinstance(source[key], dict):
+                source[key] = {}
+            source[key] = deep_update(source[key], value)
         else:
             source[key] = value
-    return source
+    return {k: v for k, v in source.items() if v is not None}


 def connect_task_management(app: FastAPI):
     @app.post("/api/projects/{project_id}/task")
     async def create_task(project_id: str, task_data: Dict[str, Any]):
-        print(f"Creating task for project {project_id} with data {task_data}")
         parent_project = project_from_id(project_id)
         task = Task.validate_and_save_with_subrelations(
diff --git a/libs/studio/kiln_studio/test_task_management.py b/libs/studio/kiln_studio/test_task_management.py
index 27c8be5..720a20f 100644
--- a/libs/studio/kiln_studio/test_task_management.py
+++ b/libs/studio/kiln_studio/test_task_management.py
@@ -401,13 +401,6 @@ async def test_run_task_structured_input(client, tmp_path):
     assert res["run"] is None


-def test_deep_update_with_none_source():
-    source = None
-    update = {"a": 1, "b": {"c": 2}}
-    result = deep_update(source, update)
-    assert result == {"a": 1, "b": {"c": 2}}
-
-
 def test_deep_update_with_empty_source():
     source = {}
     update = {"a": 1, "b": {"c": 2}}
@@ -443,6 +436,72 @@ def test_deep_update_with_mixed_types():
     assert result == {"a": "new", "b": {"c": 4, "d": {"e": 5}}}
"new", "b": {"c": 4, "d": {"e": 5}}} +def test_deep_update_with_none_values(): + # Test case 1: Basic removal of keys + source = {"a": 1, "b": 2, "c": 3} + update = {"a": None, "b": 4} + result = deep_update(source, update) + assert result == {"b": 4, "c": 3} + + # Test case 2: Nested dictionaries + source = {"x": 1, "y": {"y1": 10, "y2": 20, "y3": {"y3a": 100, "y3b": 200}}, "z": 3} + update = {"y": {"y2": None, "y3": {"y3b": None, "y3c": 300}}, "z": None} + result = deep_update(source, update) + assert result == {"x": 1, "y": {"y1": 10, "y3": {"y3a": 100, "y3c": 300}}} + + # Test case 3: Update with empty dictionary + source = {"a": 1, "b": 2} + update = {} + result = deep_update(source, update) + assert result == {"a": 1, "b": 2} + + # Test case 4: Update missing with none elements + source = {"a": 1, "b": {"d": 1}} + update = {"b": {"e": {"f": {"h": 1, "j": None}, "g": None}}} + result = deep_update(source, update) + assert result == {"a": 1, "b": {"d": 1, "e": {"f": {"h": 1}}}} + + # Test case 5: Mixed types + source = {"a": 1, "b": {"x": 10, "y": 20}, "c": [1, 2, 3]} + update = {"b": {"y": None, "z": 30}, "c": None, "d": 4} + result = deep_update(source, update) + assert result == {"a": 1, "b": {"x": 10, "z": 30}, "d": 4} + + # Test case 6: Update with + source = {} + update = {"a": {"b": None, "c": None}} + result = deep_update(source, update) + assert result == {"a": {}} + + # Test case 7: Update with + source = { + "output": { + "rating": None, + }, + } + update = { + "output": { + "rating": { + "value": 2, + "type": "five_star", + "requirement_ratings": { + "148753630565": None, + "988847661375": 3, + "474350686960": None, + }, + } + } + } + result = deep_update(source, update) + assert result["output"]["rating"]["value"] == 2 + assert result["output"]["rating"]["type"] == "five_star" + assert result["output"]["rating"]["requirement_ratings"] == { + # "148753630565": None, + "988847661375": 3, + # "474350686960": None, + } + + def test_update_run_method(): run = TaskRun( input="Test input", @@ -463,12 +522,12 @@ def test_update_run_method(): assert updated_run.input == "Updated input" update = { - "output": {"rating": {"rating": 4, "type": TaskOutputRatingType.five_star}} + "output": {"rating": {"value": 4, "type": TaskOutputRatingType.five_star}} } dumped = run.model_dump() merged = deep_update(dumped, update) updated_run = TaskRun.model_validate(merged) - assert updated_run.output.rating.rating == 4 + assert updated_run.output.rating.value == 4 assert updated_run.output.rating.type == TaskOutputRatingType.five_star @@ -506,12 +565,12 @@ async def test_update_run(client, tmp_path): "name": "Update output rating", "patch": { "output": { - "rating": {"rating": 4, "type": TaskOutputRatingType.five_star}, + "rating": {"value": 4, "type": TaskOutputRatingType.five_star}, } }, "expected": { "output": { - "rating": {"rating": 4, "type": TaskOutputRatingType.five_star}, + "rating": {"value": 4, "type": TaskOutputRatingType.five_star}, } }, },