Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frontend: added temperature gauge to assistant form #901

Merged
merged 5 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Generator

from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends
from sse_starlette.sse import EventSourceResponse

from backend.chat.custom.custom import CustomChat
Expand Down Expand Up @@ -31,7 +31,6 @@
async def chat_stream(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context = Depends(get_context),
) -> Generator[ChatResponseEvent, Any, None]:
"""
Expand All @@ -58,7 +57,7 @@ async def chat_stream(
managed_tools,
next_message_position,
ctx,
) = process_chat(session, chat_request, request, ctx)
) = process_chat(session, chat_request, ctx)

return EventSourceResponse(
generate_chat_stream(
Expand Down Expand Up @@ -86,7 +85,6 @@ async def chat_stream(
async def regenerate_chat_stream(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context = Depends(get_context),
) -> EventSourceResponse:
"""
Expand Down Expand Up @@ -127,7 +125,7 @@ async def regenerate_chat_stream(
previous_response_message_ids,
managed_tools,
ctx,
) = process_message_regeneration(session, chat_request, request, ctx)
) = process_message_regeneration(session, chat_request, ctx)

return EventSourceResponse(
generate_chat_stream(
Expand Down Expand Up @@ -155,7 +153,6 @@ async def regenerate_chat_stream(
async def chat(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context = Depends(get_context),
) -> NonStreamedChatResponse:
"""
Expand Down Expand Up @@ -197,7 +194,7 @@ async def chat(
managed_tools,
next_message_position,
ctx,
) = process_chat(session, chat_request, request, ctx)
) = process_chat(session, chat_request, ctx)

response = await generate_chat_response(
session,
Expand Down
20 changes: 12 additions & 8 deletions src/backend/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import nltk
from cohere.types import StreamedChatResponse
from fastapi import HTTPException, Request
from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder

from backend.chat.collate import to_dict
Expand Down Expand Up @@ -74,19 +74,17 @@ def generate_tools_preamble(chat_request: CohereChatRequest) -> str:

def process_chat(
session: DBSessionDep,
chat_request: BaseChatRequest,
request: Request,
chat_request: CohereChatRequest,
ctx: Context,
) -> tuple[
DBSessionDep, BaseChatRequest, Union[list[str], None], Message, str, str, dict
DBSessionDep, CohereChatRequest, Union[list[str], None], Message, str, str, Context
]:
"""
Process a chat request.

Args:
chat_request (BaseChatRequest): Chat request data.
chat_request (CohereChatRequest): Chat request data.
session (DBSessionDep): Database session.
request (Request): Request object.
ctx (Context): Context object.

Returns:
Expand Down Expand Up @@ -124,6 +122,10 @@ def process_chat(
chat_request.model = agent.model
chat_request.preamble = agent.preamble

# If temperature is not defined in the chat request, use the temperature from the agent
if not chat_request.temperature:
chat_request.temperature = agent.temperature

should_store = chat_request.chat_history is None and not is_custom_tool_call(
chat_request
)
Expand Down Expand Up @@ -193,7 +195,6 @@ def process_chat(
def process_message_regeneration(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
ctx: Context,
) -> tuple[Any, CohereChatRequest, Message, list[str], bool, Context]:
"""
Expand All @@ -202,7 +203,6 @@ def process_message_regeneration(
Args:
session (DBSessionDep): Database session.
chat_request (CohereChatRequest): Chat request data.
request (Request): Request object.
ctx (Context): Context object.

Returns:
Expand All @@ -224,6 +224,10 @@ def process_message_regeneration(
# Set the agent settings in the chat request
chat_request.preamble = agent.preamble

# If temperature is not defined in the chat request, use the temperature from the agent
if not chat_request.temperature:
chat_request.temperature = agent.temperature

conversation_id = chat_request.conversation_id
ctx.with_conversation_id(conversation_id)

Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/factories/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Meta:
description = factory.Faker("sentence")
preamble = factory.Faker("sentence")
version = factory.Faker("random_int")
temperature = factory.Faker("pyfloat")
temperature = factory.Faker("pyfloat", min_value=0.0, max_value=1.0)
created_at = factory.Faker("date_time")
updated_at = factory.Faker("date_time")
tools = factory.List(
Expand Down
1 change: 0 additions & 1 deletion src/backend/tests/unit/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def test_streaming_fail_chat_missing_message(
"loc": ["body", "message"],
"msg": "Field required",
"input": {},
"url": "https://errors.pydantic.dev/2.10/v/missing",
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({
const fileIds = conversation?.files.map((file) => file.id);

setParams({
temperature: agent?.temperature,
tools: agentTools,
fileIds,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import { AgentSettingsFields, AgentSettingsForm } from '@/components/AgentSettin
import { MobileHeader } from '@/components/Global';
import { DeleteAgent } from '@/components/Modals/DeleteAgent';
import { Button, Icon, Spinner, Text } from '@/components/UI';
import { DEFAULT_AGENT_MODEL, DEPLOYMENT_COHERE_PLATFORM } from '@/constants';
import {
DEFAULT_AGENT_MODEL,
DEFAULT_AGENT_TEMPERATURE,
DEPLOYMENT_COHERE_PLATFORM,
} from '@/constants';
import { useContextStore } from '@/context';
import { useIsAgentNameUnique, useNotify, useUpdateAgent } from '@/hooks';

Expand All @@ -28,6 +32,7 @@ export const UpdateAgent: React.FC<Props> = ({ agent }) => {
description: agent.description,
deployment: agent.deployment ?? DEPLOYMENT_COHERE_PLATFORM,
model: agent.model ?? DEFAULT_AGENT_MODEL,
temperature: agent.temperature ?? DEFAULT_AGENT_TEMPERATURE,
tools: agent.tools,
preamble: agent.preamble,
tools_metadata: agent.tools_metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { Button, Icon, Text } from '@/components/UI';
import {
BACKGROUND_TOOLS,
DEFAULT_AGENT_MODEL,
DEFAULT_AGENT_TEMPERATURE,
DEFAULT_PREAMBLE,
DEPLOYMENT_COHERE_PLATFORM,
} from '@/constants';
Expand All @@ -23,6 +24,7 @@ const DEFAULT_FIELD_VALUES = {
preamble: DEFAULT_PREAMBLE,
deployment: DEPLOYMENT_COHERE_PLATFORM,
model: DEFAULT_AGENT_MODEL,
temperature: DEFAULT_AGENT_TEMPERATURE,
tools: BACKGROUND_TOOLS,
is_private: false,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { useState } from 'react';

import { AgentSettingsFields } from '@/components/AgentSettingsForm';
import { Dropdown } from '@/components/UI';
import { Dropdown, Slider } from '@/components/UI';
import { useListAllDeployments } from '@/hooks';

type Props = {
Expand All @@ -13,7 +13,10 @@ type Props = {
};

export const ConfigStep: React.FC<Props> = ({ fields, setFields }) => {
const [selectedValue, setSelectedValue] = useState<string | undefined>(fields.model);
const [selectedModelValue, setSelectedModelValue] = useState<string | undefined>(fields.model);
const [selectedTemperatureValue, setSelectedTemperatureValue] = useState<number | undefined>(
fields.temperature
);

const { data: deployments } = useListAllDeployments();

Expand All @@ -27,12 +30,23 @@ export const ConfigStep: React.FC<Props> = ({ fields, setFields }) => {
<Dropdown
label="Model"
options={modelOptions ?? []}
value={selectedValue}
value={selectedModelValue}
onChange={(model) => {
setFields({ ...fields, model: model });
setSelectedValue(model);
setSelectedModelValue(model);
}}
/>
<Slider
label="Temperature"
min={0}
max={1.0}
step={0.1}
value={selectedTemperatureValue || 0}
onChange={(temperature) => {
setFields({ ...fields, temperature: temperature });
setSelectedTemperatureValue(temperature);
}}
></Slider>
</div>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ type RequiredAndNotNull<T> = {
type RequireAndNotNullSome<T, K extends keyof T> = RequiredAndNotNull<Pick<T, K>> & Omit<T, K>;

type CreateAgentSettingsFields = RequireAndNotNullSome<
Omit<CreateAgentRequest, 'version' | 'temperature'>,
'name' | 'model' | 'deployment'
Omit<CreateAgentRequest, 'version'>,
'name' | 'model' | 'deployment' | 'temperature'
>;

type UpdateAgentSettingsFields = RequireAndNotNullSome<
Omit<UpdateAgentRequest, 'version' | 'temperature'>,
'name' | 'model' | 'deployment'
Omit<UpdateAgentRequest, 'version'>,
'name' | 'model' | 'deployment' | 'temperature'
> & { is_private?: boolean };

export type AgentSettingsFields = CreateAgentSettingsFields | UpdateAgentSettingsFields;
Expand Down
82 changes: 82 additions & 0 deletions src/interfaces/assistants_web/src/components/UI/Slider.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
'use client';

import { ChangeEvent, useEffect, useMemo } from 'react';

import { InputLabel, Text } from '@/components/UI';
import { cn } from '@/utils';

type Props = {
label: string;
min: number;
max: number;
step: number;
value: number;
onChange: (value: number) => void;
sublabel?: string;
className?: string;
tooltipLabel?: React.ReactNode;
formatValue?: (value: number) => string;
};

/**
*
* Renders a slider with a label, a minimum, maximum and step value, and optional subLabel and tooltip.
* Styling for the thumb is located in main.css
*/
export const Slider: React.FC<Props> = ({
label,
sublabel,
min,
max,
step,
value,
onChange,
tooltipLabel,
formatValue,
className = '',
}) => {
// if `max` is changed dynamically don't allow the value to surpass it
useEffect(() => {
if (value > max) onChange(Math.min(value, max));
}, [max, onChange, value]);

// if `min` is changed dynamically don't allow the value to go below it
useEffect(() => {
if (value < min) onChange(Math.max(value, min));
}, [min, onChange, value]);

const handleChange = (e: ChangeEvent<HTMLInputElement>) => {
const value = Number(e.target.value);

onChange(value);
};

const ticks = useMemo(() => {
return Array.from({ length: (max - min) / step + 1 }, (_, i) => {
return i * step + min;
});
}, [max, min, step]);

return (
<div className={cn('flex flex-col space-y-4', className)}>
<div className="flex w-full items-center justify-between">
<InputLabel label={label} tooltipLabel={tooltipLabel} sublabel={sublabel} />
<Text>{formatValue ? formatValue(value) : value}</Text>
</div>
<div className="flex items-center">
<input
type="range"
value={value}
max={max}
min={min}
step={step}
onChange={handleChange}
className={cn(
'flex w-full cursor-pointer appearance-none items-center rounded-lg border outline-none active:cursor-grabbing',
'focus-visible:outline focus-visible:outline-1 focus-visible:outline-offset-4 focus-visible:outline-volcanic-100'
)}
/>
</div>
</div>
);
};
1 change: 1 addition & 0 deletions src/interfaces/assistants_web/src/components/UI/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export * from './RadioGroup';
export * from './Shortcut';
export * from './ShowStepsToggle';
export * from './Skeleton';
export * from './Slider';
export * from './Spinner';
export * from './Switch';
export * from './Tabs';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { FileAccept } from '@/components/UI';
export const DEFAULT_CONVERSATION_NAME = 'New Conversation';
export const DEFAULT_AGENT_MODEL = 'command-r-plus';
export const DEFAULT_AGENT_ID = 'default';
export const DEFAULT_AGENT_TEMPERATURE = 0.3;
export const DEFAULT_TYPING_VELOCITY = 35;
export const CONVERSATION_HISTORY_OFFSET = 100;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { StateCreator } from 'zustand';

import { CohereChatRequest, DEFAULT_CHAT_TEMPERATURE } from '@/cohere-client';
import { CohereChatRequest } from '@/cohere-client';

import { StoreState } from '..';

const INITIAL_STATE: ConfigurableParams = {
model: undefined,
temperature: DEFAULT_CHAT_TEMPERATURE,
temperature: undefined,
preamble: '',
tools: [],
fileIds: [],
Expand Down
Loading