diff --git a/src/backend/routers/chat.py b/src/backend/routers/chat.py index dd88fb1cb8..85ae6e0a26 100644 --- a/src/backend/routers/chat.py +++ b/src/backend/routers/chat.py @@ -1,6 +1,6 @@ from typing import Any, Generator -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends from sse_starlette.sse import EventSourceResponse from backend.chat.custom.custom import CustomChat @@ -31,7 +31,6 @@ async def chat_stream( session: DBSessionDep, chat_request: CohereChatRequest, - request: Request, ctx: Context = Depends(get_context), ) -> Generator[ChatResponseEvent, Any, None]: """ @@ -58,7 +57,7 @@ async def chat_stream( managed_tools, next_message_position, ctx, - ) = process_chat(session, chat_request, request, ctx) + ) = process_chat(session, chat_request, ctx) return EventSourceResponse( generate_chat_stream( @@ -86,7 +85,6 @@ async def chat_stream( async def regenerate_chat_stream( session: DBSessionDep, chat_request: CohereChatRequest, - request: Request, ctx: Context = Depends(get_context), ) -> EventSourceResponse: """ @@ -127,7 +125,7 @@ async def regenerate_chat_stream( previous_response_message_ids, managed_tools, ctx, - ) = process_message_regeneration(session, chat_request, request, ctx) + ) = process_message_regeneration(session, chat_request, ctx) return EventSourceResponse( generate_chat_stream( @@ -155,7 +153,6 @@ async def regenerate_chat_stream( async def chat( session: DBSessionDep, chat_request: CohereChatRequest, - request: Request, ctx: Context = Depends(get_context), ) -> NonStreamedChatResponse: """ @@ -197,7 +194,7 @@ async def chat( managed_tools, next_message_position, ctx, - ) = process_chat(session, chat_request, request, ctx) + ) = process_chat(session, chat_request, ctx) response = await generate_chat_response( session, diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index 8f28bae183..06e149a1a6 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -4,7 +4,7 @@ import nltk from cohere.types import StreamedChatResponse -from fastapi import HTTPException, Request +from fastapi import HTTPException from fastapi.encoders import jsonable_encoder from backend.chat.collate import to_dict @@ -74,19 +74,17 @@ def generate_tools_preamble(chat_request: CohereChatRequest) -> str: def process_chat( session: DBSessionDep, - chat_request: BaseChatRequest, - request: Request, + chat_request: CohereChatRequest, ctx: Context, ) -> tuple[ - DBSessionDep, BaseChatRequest, Union[list[str], None], Message, str, str, dict + DBSessionDep, CohereChatRequest, Union[list[str], None], Message, str, str, Context ]: """ Process a chat request. Args: - chat_request (BaseChatRequest): Chat request data. + chat_request (CohereChatRequest): Chat request data. session (DBSessionDep): Database session. - request (Request): Request object. ctx (Context): Context object. Returns: @@ -124,6 +122,10 @@ def process_chat( chat_request.model = agent.model chat_request.preamble = agent.preamble + # If temperature is not defined in the chat request, use the temperature from the agent + if not chat_request.temperature: + chat_request.temperature = agent.temperature + should_store = chat_request.chat_history is None and not is_custom_tool_call( chat_request ) @@ -193,7 +195,6 @@ def process_chat( def process_message_regeneration( session: DBSessionDep, chat_request: CohereChatRequest, - request: Request, ctx: Context, ) -> tuple[Any, CohereChatRequest, Message, list[str], bool, Context]: """ @@ -202,7 +203,6 @@ def process_message_regeneration( Args: session (DBSessionDep): Database session. chat_request (CohereChatRequest): Chat request data. - request (Request): Request object. ctx (Context): Context object. Returns: @@ -224,6 +224,10 @@ def process_message_regeneration( # Set the agent settings in the chat request chat_request.preamble = agent.preamble + # If temperature is not defined in the chat request, use the temperature from the agent + if not chat_request.temperature: + chat_request.temperature = agent.temperature + conversation_id = chat_request.conversation_id ctx.with_conversation_id(conversation_id) diff --git a/src/backend/tests/unit/factories/agent.py b/src/backend/tests/unit/factories/agent.py index 14e74e67e2..71c0d256e9 100644 --- a/src/backend/tests/unit/factories/agent.py +++ b/src/backend/tests/unit/factories/agent.py @@ -19,7 +19,7 @@ class Meta: description = factory.Faker("sentence") preamble = factory.Faker("sentence") version = factory.Faker("random_int") - temperature = factory.Faker("pyfloat") + temperature = factory.Faker("pyfloat", min_value=0.0, max_value=1.0) created_at = factory.Faker("date_time") updated_at = factory.Faker("date_time") tools = factory.List( diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index ef8b020345..3753e63622 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -311,7 +311,6 @@ def test_streaming_fail_chat_missing_message( "loc": ["body", "message"], "msg": "Field required", "input": {}, - "url": "https://errors.pydantic.dev/2.10/v/missing", } ] } diff --git a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx index 9d0a5aa71b..a17ae38716 100644 --- a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx @@ -49,6 +49,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ const fileIds = conversation?.files.map((file) => file.id); setParams({ + temperature: agent?.temperature, tools: agentTools, fileIds, }); diff --git a/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/UpdateAgent.tsx b/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/UpdateAgent.tsx index 76a66930b8..0f17fcec0d 100644 --- a/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/UpdateAgent.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/UpdateAgent.tsx @@ -8,7 +8,11 @@ import { AgentSettingsFields, AgentSettingsForm } from '@/components/AgentSettin import { MobileHeader } from '@/components/Global'; import { DeleteAgent } from '@/components/Modals/DeleteAgent'; import { Button, Icon, Spinner, Text } from '@/components/UI'; -import { DEFAULT_AGENT_MODEL, DEPLOYMENT_COHERE_PLATFORM } from '@/constants'; +import { + DEFAULT_AGENT_MODEL, + DEFAULT_AGENT_TEMPERATURE, + DEPLOYMENT_COHERE_PLATFORM, +} from '@/constants'; import { useContextStore } from '@/context'; import { useIsAgentNameUnique, useNotify, useUpdateAgent } from '@/hooks'; @@ -28,6 +32,7 @@ export const UpdateAgent: React.FC = ({ agent }) => { description: agent.description, deployment: agent.deployment ?? DEPLOYMENT_COHERE_PLATFORM, model: agent.model ?? DEFAULT_AGENT_MODEL, + temperature: agent.temperature ?? DEFAULT_AGENT_TEMPERATURE, tools: agent.tools, preamble: agent.preamble, tools_metadata: agent.tools_metadata, diff --git a/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx b/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx index fddd3035f6..6bb4a8c059 100644 --- a/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx @@ -11,6 +11,7 @@ import { Button, Icon, Text } from '@/components/UI'; import { BACKGROUND_TOOLS, DEFAULT_AGENT_MODEL, + DEFAULT_AGENT_TEMPERATURE, DEFAULT_PREAMBLE, DEPLOYMENT_COHERE_PLATFORM, } from '@/constants'; @@ -23,6 +24,7 @@ const DEFAULT_FIELD_VALUES = { preamble: DEFAULT_PREAMBLE, deployment: DEPLOYMENT_COHERE_PLATFORM, model: DEFAULT_AGENT_MODEL, + temperature: DEFAULT_AGENT_TEMPERATURE, tools: BACKGROUND_TOOLS, is_private: false, }; diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ConfigStep.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ConfigStep.tsx index 8c5ca8ade9..1581cca22c 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ConfigStep.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ConfigStep.tsx @@ -3,7 +3,7 @@ import { useState } from 'react'; import { AgentSettingsFields } from '@/components/AgentSettingsForm'; -import { Dropdown } from '@/components/UI'; +import { Dropdown, Slider } from '@/components/UI'; import { useListAllDeployments } from '@/hooks'; type Props = { @@ -13,7 +13,10 @@ type Props = { }; export const ConfigStep: React.FC = ({ fields, setFields }) => { - const [selectedValue, setSelectedValue] = useState(fields.model); + const [selectedModelValue, setSelectedModelValue] = useState(fields.model); + const [selectedTemperatureValue, setSelectedTemperatureValue] = useState( + fields.temperature + ); const { data: deployments } = useListAllDeployments(); @@ -27,12 +30,23 @@ export const ConfigStep: React.FC = ({ fields, setFields }) => { { setFields({ ...fields, model: model }); - setSelectedValue(model); + setSelectedModelValue(model); }} /> + { + setFields({ ...fields, temperature: temperature }); + setSelectedTemperatureValue(temperature); + }} + > ); }; diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx index c945b958a1..0c235c56a1 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/index.tsx @@ -24,13 +24,13 @@ type RequiredAndNotNull = { type RequireAndNotNullSome = RequiredAndNotNull> & Omit; type CreateAgentSettingsFields = RequireAndNotNullSome< - Omit, - 'name' | 'model' | 'deployment' + Omit, + 'name' | 'model' | 'deployment' | 'temperature' >; type UpdateAgentSettingsFields = RequireAndNotNullSome< - Omit, - 'name' | 'model' | 'deployment' + Omit, + 'name' | 'model' | 'deployment' | 'temperature' > & { is_private?: boolean }; export type AgentSettingsFields = CreateAgentSettingsFields | UpdateAgentSettingsFields; diff --git a/src/interfaces/assistants_web/src/components/UI/Slider.tsx b/src/interfaces/assistants_web/src/components/UI/Slider.tsx new file mode 100644 index 0000000000..6a852cd91a --- /dev/null +++ b/src/interfaces/assistants_web/src/components/UI/Slider.tsx @@ -0,0 +1,82 @@ +'use client'; + +import { ChangeEvent, useEffect, useMemo } from 'react'; + +import { InputLabel, Text } from '@/components/UI'; +import { cn } from '@/utils'; + +type Props = { + label: string; + min: number; + max: number; + step: number; + value: number; + onChange: (value: number) => void; + sublabel?: string; + className?: string; + tooltipLabel?: React.ReactNode; + formatValue?: (value: number) => string; +}; + +/** + * + * Renders a slider with a label, a minimum, maximum and step value, and optional subLabel and tooltip. + * Styling for the thumb is located in main.css + */ +export const Slider: React.FC = ({ + label, + sublabel, + min, + max, + step, + value, + onChange, + tooltipLabel, + formatValue, + className = '', +}) => { + // if `max` is changed dynamically don't allow the value to surpass it + useEffect(() => { + if (value > max) onChange(Math.min(value, max)); + }, [max, onChange, value]); + + // if `min` is changed dynamically don't allow the value to go below it + useEffect(() => { + if (value < min) onChange(Math.max(value, min)); + }, [min, onChange, value]); + + const handleChange = (e: ChangeEvent) => { + const value = Number(e.target.value); + + onChange(value); + }; + + const ticks = useMemo(() => { + return Array.from({ length: (max - min) / step + 1 }, (_, i) => { + return i * step + min; + }); + }, [max, min, step]); + + return ( +
+
+ + {formatValue ? formatValue(value) : value} +
+
+ +
+
+ ); +}; diff --git a/src/interfaces/assistants_web/src/components/UI/index.ts b/src/interfaces/assistants_web/src/components/UI/index.ts index c2c152b359..2ba8c52615 100644 --- a/src/interfaces/assistants_web/src/components/UI/index.ts +++ b/src/interfaces/assistants_web/src/components/UI/index.ts @@ -24,6 +24,7 @@ export * from './RadioGroup'; export * from './Shortcut'; export * from './ShowStepsToggle'; export * from './Skeleton'; +export * from './Slider'; export * from './Spinner'; export * from './Switch'; export * from './Tabs'; diff --git a/src/interfaces/assistants_web/src/constants/conversation.ts b/src/interfaces/assistants_web/src/constants/conversation.ts index 5947da1062..b1c4c4771e 100644 --- a/src/interfaces/assistants_web/src/constants/conversation.ts +++ b/src/interfaces/assistants_web/src/constants/conversation.ts @@ -3,6 +3,7 @@ import { FileAccept } from '@/components/UI'; export const DEFAULT_CONVERSATION_NAME = 'New Conversation'; export const DEFAULT_AGENT_MODEL = 'command-r-plus'; export const DEFAULT_AGENT_ID = 'default'; +export const DEFAULT_AGENT_TEMPERATURE = 0.3; export const DEFAULT_TYPING_VELOCITY = 35; export const CONVERSATION_HISTORY_OFFSET = 100; diff --git a/src/interfaces/assistants_web/src/stores/slices/paramsSlice.ts b/src/interfaces/assistants_web/src/stores/slices/paramsSlice.ts index 9146bc01f3..df0e8f20ee 100644 --- a/src/interfaces/assistants_web/src/stores/slices/paramsSlice.ts +++ b/src/interfaces/assistants_web/src/stores/slices/paramsSlice.ts @@ -1,12 +1,12 @@ import { StateCreator } from 'zustand'; -import { CohereChatRequest, DEFAULT_CHAT_TEMPERATURE } from '@/cohere-client'; +import { CohereChatRequest } from '@/cohere-client'; import { StoreState } from '..'; const INITIAL_STATE: ConfigurableParams = { model: undefined, - temperature: DEFAULT_CHAT_TEMPERATURE, + temperature: undefined, preamble: '', tools: [], fileIds: [],