Skip to content

Commit

Permalink
Add the ability to save ratings!
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Oct 16, 2024
1 parent febe11f commit a08d52d
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 59 deletions.
6 changes: 3 additions & 3 deletions app/web_ui/src/lib/api_schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ export interface components {
/** @default five_star */
type: components["schemas"]["TaskOutputRatingType"];
/**
* Rating
* @description The rating value (typically 1-5 stars).
* Value
* @description The overall rating value (typically 1-5 stars).
*/
rating: number;
value?: number | null;
/**
* Requirement Ratings
* @description The ratings of the requirements of the task. The keys are the ids of the requirements. The values are the ratings (typically 1-5 stars).
Expand Down
16 changes: 10 additions & 6 deletions app/web_ui/src/routes/(app)/run/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
},
},
}
$: output = response?.output
$: subtitle = $current_task ? "Task: " + $current_task.name : ""
Expand Down Expand Up @@ -175,9 +174,14 @@
/>
</FormContainer>
</div>
<div
class="mt-10 max-w-[1400px] {submitting || output == null ? 'hidden' : ''}"
>
<Output {response} json_schema={$current_task?.output_json_schema} />
</div>
{#if $current_task && !submitting && response != null && $current_project}
<div class="mt-10 max-w-[1400px]">
<Output
{response}
json_schema={$current_task?.output_json_schema}
task={$current_task}
project_id={$current_project.id}
/>
</div>
{/if}
</AppPage>
111 changes: 81 additions & 30 deletions app/web_ui/src/routes/(app)/run/output.svelte
Original file line number Diff line number Diff line change
@@ -1,52 +1,101 @@
<script lang="ts">
import type { Task, TaskRequirement } from "$lib/stores"
import { current_task } from "$lib/stores"
import type { Task } from "$lib/stores"
import FormContainer from "$lib/utils/form_container.svelte"
import FormElement from "$lib/utils/form_element.svelte"
import Rating from "./rating.svelte"
export let json_schema: string | null = null
let repair_instructions: string | null = null
import { type components } from "$lib/api_schema.d"
import createClient from "openapi-fetch"
import { type components, type paths } from "$lib/api_schema.d"
export let response: components["schemas"]["RunTaskResponse"] | null = null
export let project_id: string
export let task: Task
export let response: components["schemas"]["RunTaskResponse"]
$: output = response?.output
$: output_valid =
output &&
((json_schema && output.structured_output) ||
(!json_schema && output.plaintext_output))
// TODO warn_before_onload
let updated_run: components["schemas"]["TaskRun"] | null = null
$: run = updated_run || response?.run
// TODO warn_before_unload
// TODO: we aren't loading existing ratings from the server
let overall_rating: 1 | 2 | 3 | 4 | 5 | null = null
let requirement_ratings: (1 | 2 | 3 | 4 | 5 | null)[] = []
let prior_task: Task | null = null
let requirement_ratings: (1 | 2 | 3 | 4 | 5 | null)[] = Array(
task.requirements.length,
).fill(null)
current_task.subscribe((task) => {
if (task) {
let original_ratings = requirement_ratings
requirement_ratings = []
for (const requirement of task.requirements) {
// Look up prior rating, if any
let prior_index = prior_task?.requirements.findIndex(
(req: TaskRequirement) => req.id === requirement.id,
)
let value =
prior_index && prior_index >= 0 ? original_ratings[prior_index] : null
requirement_ratings.push(value)
async function save_ratings() {
try {
let requirement_ratings_obj: Record<string, 1 | 2 | 3 | 4 | 5 | null> = {}
task.requirements.forEach((req, index) => {
requirement_ratings_obj[req.id] = requirement_ratings[index]
})
let patch_body = {
output: {
rating: {
value: overall_rating,
type: "five_star",
requirement_ratings: requirement_ratings_obj,
},
},
}
const client = createClient<paths>({
baseUrl: "http://localhost:8757",
})
const {
data, // only present if 2XX response
error: fetch_error, // only present if 4XX or 5XX response
} = await client.PATCH(
"/api/projects/{project_id}/task/{task_id}/run/{run_id}",
{
params: {
path: {
project_id: project_id,
task_id: task.id || "",
run_id: run?.id || "",
},
},
// @ts-expect-error type checking and PATCH don't mix
body: patch_body,
},
)
if (fetch_error) {
// TODO: check error message extraction
throw new Error("Failed to run task: " + fetch_error)
}
prior_task = task
updated_run = data
} catch (err) {
// TODO: better error handling
console.error("Failed to save ratings", err)
}
})
function save_ratings() {
console.log("Overall rating", overall_rating)
$current_task?.requirements.forEach((req, index) => {
console.log("Requirement", req.name, requirement_ratings[index])
})
}
function attempt_repair() {
console.log("Attempting repair")
}
// Watch for changes to ratings and save them if they change
let prior_overall_rating: 1 | 2 | 3 | 4 | 5 | null = overall_rating
let prior_requirement_ratings: (1 | 2 | 3 | 4 | 5 | null)[] =
requirement_ratings
$: {
if (
overall_rating !== prior_overall_rating ||
!areArraysEqual(requirement_ratings, prior_requirement_ratings)
) {
save_ratings()
}
prior_overall_rating = overall_rating
prior_requirement_ratings = [...requirement_ratings]
}
function areArraysEqual(arr1: unknown[], arr2: unknown[]): boolean {
if (arr1.length !== arr2.length) return false
return arr1.every((value, index) => value === arr2[index])
}
</script>

<div>
Expand Down Expand Up @@ -98,8 +147,8 @@
<div class="flex items-center">
<Rating bind:rating={overall_rating} size={7} />
</div>
{#if $current_task?.requirements}
{#each $current_task.requirements as requirement, index}
{#if task.requirements}
{#each task.requirements as requirement, index}
<div class="flex items-center">
{requirement.name}:
</div>
Expand All @@ -109,7 +158,6 @@
{/each}
{/if}
</div>
<button class="mt-4 link" on:click={save_ratings}>Save Ratings</button>
</div>
</div>

Expand All @@ -133,4 +181,7 @@
</FormContainer>
{/if}
{/if}

<h1>Raw Data</h1>
<pre>{JSON.stringify(run, null, 2)}</pre>
</div>
10 changes: 7 additions & 3 deletions libs/core/kiln_ai/datamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class TaskOutputRating(KilnBaseModel):
"""

type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
value: float = Field(description="The overall rating value (typically 1-5 stars).")
value: float | None = Field(
description="The overall rating value (typically 1-5 stars).",
default=None,
)
requirement_ratings: Dict[ID_TYPE, float] = Field(
default={},
description="The ratings of the requirements of the task. The keys are the ids of the requirements. The values are the ratings (typically 1-5 stars).",
Expand All @@ -60,7 +63,7 @@ class TaskOutputRating(KilnBaseModel):
# Used to select high quality outputs for example selection (MultiShotPromptBuilder, etc)
def is_high_quality(self) -> bool:
if self.type == TaskOutputRatingType.five_star:
return self.value >= 4
return self.value is not None and self.value >= 4
return False

@model_validator(mode="after")
Expand All @@ -69,7 +72,8 @@ def validate_rating(self) -> Self:
raise ValueError(f"Invalid rating type: {self.type}")

if self.type == TaskOutputRatingType.five_star:
self._validate_five_star(self.value, "overall rating")
if self.value is not None:
self._validate_five_star(self.value, "overall rating")
for req_id, req_rating in self.requirement_ratings.items():
self._validate_five_star(req_rating, f"requirement rating for {req_id}")

Expand Down
17 changes: 11 additions & 6 deletions libs/studio/kiln_studio/task_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,26 @@ class RunTaskResponse(BaseModel):
run: TaskRun | None = None


def deep_update(source, update):
def deep_update(
source: Dict[str, Any] | None, update: Dict[str, Any | None]
) -> Dict[str, Any]:
if source is None:
return update
return {k: v for k, v in update.items() if v is not None}
for key, value in update.items():
if isinstance(value, dict):
source[key] = deep_update(source.get(key, {}), value)
if value is None:
source.pop(key, None)
elif isinstance(value, dict):
if key not in source or not isinstance(source[key], dict):
source[key] = {}
source[key] = deep_update(source[key], value)
else:
source[key] = value
return source
return {k: v for k, v in source.items() if v is not None}


def connect_task_management(app: FastAPI):
@app.post("/api/projects/{project_id}/task")
async def create_task(project_id: str, task_data: Dict[str, Any]):
print(f"Creating task for project {project_id} with data {task_data}")
parent_project = project_from_id(project_id)

task = Task.validate_and_save_with_subrelations(
Expand Down
81 changes: 70 additions & 11 deletions libs/studio/kiln_studio/test_task_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,6 @@ async def test_run_task_structured_input(client, tmp_path):
assert res["run"] is None


def test_deep_update_with_none_source():
source = None
update = {"a": 1, "b": {"c": 2}}
result = deep_update(source, update)
assert result == {"a": 1, "b": {"c": 2}}


def test_deep_update_with_empty_source():
source = {}
update = {"a": 1, "b": {"c": 2}}
Expand Down Expand Up @@ -443,6 +436,72 @@ def test_deep_update_with_mixed_types():
assert result == {"a": "new", "b": {"c": 4, "d": {"e": 5}}}


def test_deep_update_with_none_values():
# Test case 1: Basic removal of keys
source = {"a": 1, "b": 2, "c": 3}
update = {"a": None, "b": 4}
result = deep_update(source, update)
assert result == {"b": 4, "c": 3}

# Test case 2: Nested dictionaries
source = {"x": 1, "y": {"y1": 10, "y2": 20, "y3": {"y3a": 100, "y3b": 200}}, "z": 3}
update = {"y": {"y2": None, "y3": {"y3b": None, "y3c": 300}}, "z": None}
result = deep_update(source, update)
assert result == {"x": 1, "y": {"y1": 10, "y3": {"y3a": 100, "y3c": 300}}}

# Test case 3: Update with empty dictionary
source = {"a": 1, "b": 2}
update = {}
result = deep_update(source, update)
assert result == {"a": 1, "b": 2}

# Test case 4: Update missing with none elements
source = {"a": 1, "b": {"d": 1}}
update = {"b": {"e": {"f": {"h": 1, "j": None}, "g": None}}}
result = deep_update(source, update)
assert result == {"a": 1, "b": {"d": 1, "e": {"f": {"h": 1}}}}

# Test case 5: Mixed types
source = {"a": 1, "b": {"x": 10, "y": 20}, "c": [1, 2, 3]}
update = {"b": {"y": None, "z": 30}, "c": None, "d": 4}
result = deep_update(source, update)
assert result == {"a": 1, "b": {"x": 10, "z": 30}, "d": 4}

# Test case 6: Update with
source = {}
update = {"a": {"b": None, "c": None}}
result = deep_update(source, update)
assert result == {"a": {}}

# Test case 7: Update with
source = {
"output": {
"rating": None,
},
}
update = {
"output": {
"rating": {
"value": 2,
"type": "five_star",
"requirement_ratings": {
"148753630565": None,
"988847661375": 3,
"474350686960": None,
},
}
}
}
result = deep_update(source, update)
assert result["output"]["rating"]["value"] == 2
assert result["output"]["rating"]["type"] == "five_star"
assert result["output"]["rating"]["requirement_ratings"] == {
# "148753630565": None,
"988847661375": 3,
# "474350686960": None,
}


def test_update_run_method():
run = TaskRun(
input="Test input",
Expand All @@ -463,12 +522,12 @@ def test_update_run_method():
assert updated_run.input == "Updated input"

update = {
"output": {"rating": {"rating": 4, "type": TaskOutputRatingType.five_star}}
"output": {"rating": {"value": 4, "type": TaskOutputRatingType.five_star}}
}
dumped = run.model_dump()
merged = deep_update(dumped, update)
updated_run = TaskRun.model_validate(merged)
assert updated_run.output.rating.rating == 4
assert updated_run.output.rating.value == 4
assert updated_run.output.rating.type == TaskOutputRatingType.five_star


Expand Down Expand Up @@ -506,12 +565,12 @@ async def test_update_run(client, tmp_path):
"name": "Update output rating",
"patch": {
"output": {
"rating": {"rating": 4, "type": TaskOutputRatingType.five_star},
"rating": {"value": 4, "type": TaskOutputRatingType.five_star},
}
},
"expected": {
"output": {
"rating": {"rating": 4, "type": TaskOutputRatingType.five_star},
"rating": {"value": 4, "type": TaskOutputRatingType.five_star},
}
},
},
Expand Down

0 comments on commit a08d52d

Please sign in to comment.