From 713b227818b12a2854a74bde8ccb4548adefb6fb Mon Sep 17 00:00:00 2001
From: Nestor Qin
Date: Sun, 29 Sep 2024 05:06:15 -0400
Subject: [PATCH] feat: upload and send images in chat

---
 app/client/api.ts                 |  9 +++++-
 app/client/mlcllm.ts              |  5 +++-
 app/client/webllm.ts              |  4 +--
 app/components/chat.module.scss   |  2 ++
 app/components/chat.tsx           | 49 +++++++++++++++++--------------
 app/components/exporter.tsx       |  4 +--
 app/components/home.tsx           |  2 +-
 app/components/settings.tsx      |  2 +-
 app/components/template.tsx       |  2 +-
 app/components/ui-lib.module.scss |  5 +++-
 app/components/ui-lib.tsx         |  6 ++--
 app/constant.ts                   | 33 +++++++++++++++++++--
 app/icons/attachment.svg          | 35 ++++++++++++++++++++++
 app/icons/eye-off.svg             |  2 +-
 app/icons/eye.svg                 |  2 +-
 app/store/chat.ts                 | 16 +++++++---
 app/store/config.ts               |  6 ++--
 app/typing.ts                     |  6 ++++
 app/utils.ts                      | 36 ++++++++++++-----------
 app/worker/service-worker.ts      |  2 +-
 app/worker/web-worker.ts          |  2 +-
 package.json                      |  2 +-
 yarn.lock                         |  6 ++--
 23 files changed, 169 insertions(+), 69 deletions(-)
 create mode 100644 app/icons/attachment.svg

diff --git a/app/client/api.ts b/app/client/api.ts
index 151a0bd5..6812baa9 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -1,4 +1,7 @@
-import { ChatCompletionFinishReason, CompletionUsage } from "@mlc-ai/web-llm";
+import {
+  ChatCompletionFinishReason,
+  CompletionUsage,
+} from "@neet-nestor/web-llm";
 import { CacheType, Model } from "../store";
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
@@ -12,6 +15,10 @@ export interface MultimodalContent {
   image_url?: {
     url: string;
   };
+  dimension?: {
+    width: number;
+    height: number;
+  };
 }
 
 export interface RequestMessage {
diff --git a/app/client/mlcllm.ts b/app/client/mlcllm.ts
index f9fb04a3..77018822 100644
--- a/app/client/mlcllm.ts
+++ b/app/client/mlcllm.ts
@@ -1,6 +1,9 @@
 import log from "loglevel";
 import { ChatOptions, LLMApi } from "./api";
-import { ChatCompletionFinishReason, CompletionUsage } from "@mlc-ai/web-llm";
+import {
+  ChatCompletionFinishReason,
+  CompletionUsage,
+} from "@neet-nestor/web-llm";
 
 export class MlcLLMApi implements LLMApi {
   private endpoint: string;
diff --git a/app/client/webllm.ts b/app/client/webllm.ts
index 8473196e..235223dc 100644
--- a/app/client/webllm.ts
+++ b/app/client/webllm.ts
@@ -12,10 +12,10 @@ import {
   WebWorkerMLCEngine,
   CompletionUsage,
   ChatCompletionFinishReason,
-} from "@mlc-ai/web-llm";
+} from "@neet-nestor/web-llm";
 
 import { ChatOptions, LLMApi, LLMConfig, RequestMessage } from "./api";
-import { LogLevel } from "@mlc-ai/web-llm";
+import { LogLevel } from "@neet-nestor/web-llm";
 import { fixMessage } from "../utils";
 import { DEFAULT_MODELS } from "../constant";
diff --git a/app/components/chat.module.scss b/app/components/chat.module.scss
index 41b31540..d9bcc00b 100644
--- a/app/components/chat.module.scss
+++ b/app/components/chat.module.scss
@@ -477,6 +477,8 @@
   box-sizing: border-box;
   border-radius: 10px;
   border: rgba($color: #888, $alpha: 0.2) 1px solid;
+  height: auto;
+  display: block;
 }
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index c59491ae..0e938d93 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -8,7 +8,6 @@ import React, {
   Fragment,
   RefObject,
   useContext,
-  ReactElement,
 } from "react";
 
 import ShareIcon from "../icons/share.svg";
@@ -21,16 +20,12 @@ import LoadingIcon from "../icons/three-dots.svg";
 import LoadingButtonIcon from "../icons/loading.svg";
 import PromptIcon from "../icons/prompt.svg";
 import MaxIcon from "../icons/max.svg";
-import BrainIcon from "../icons/brain.svg";
 import MinIcon from "../icons/min.svg";
 import ResetIcon from "../icons/reload.svg";
 import BreakIcon from "../icons/break.svg";
 import DeleteIcon from "../icons/clear.svg";
-import PinIcon from "../icons/pin.svg";
 import EditIcon from "../icons/rename.svg";
 import ConfirmIcon from "../icons/confirm.svg";
-import InfoIcon from "../icons/info.svg";
-import CancelIcon from "../icons/cancel.svg";
 import ImageIcon from "../icons/image.svg";
 
 import BottomIcon from "../icons/bottom.svg";
@@ -74,7 +69,6 @@ import {
   Modal,
   Popover,
   Selector,
-  Tooltip,
   showConfirm,
   showPrompt,
   showToast,
@@ -96,6 +90,8 @@ import { MultimodalContent } from "../client/api";
 import { Template, useTemplateStore } from "../store/template";
 import Image from "next/image";
 import { MLCLLMContext, WebLLMContext } from "../context";
+import EyeIcon from "../icons/eye.svg";
+import { ChatImage } from "../typing";
 
 export function ScrollDownToast(prop: { show: boolean; onclick: () => void }) {
   return (
@@ -481,7 +477,7 @@ function useScrollToBottom(
 
 export function ChatActions(props: {
   uploadImage: () => void;
-  setAttachImages: (images: string[]) => void;
+  setAttachImages: (images: ChatImage[]) => void;
   setUploading: (uploading: boolean) => void;
   scrollToBottom: () => void;
   showPromptSetting: () => void;
@@ -553,6 +549,7 @@ export function ChatActions(props: {
             title: m.name,
             value: m.name,
             family: m.family,
+            icon: isVisionModel(m.name) ? <EyeIcon /> : undefined,
           }))}
           onClose={() => setShowModelSelector(false)}
           onSelection={(s) => {
@@ -603,7 +600,7 @@ function _Chat() {
   const [hitBottom, setHitBottom] = useState(true);
   const isMobileScreen = useMobileScreen();
   const navigate = useNavigate();
-  const [attachImages, setAttachImages] = useState<string[]>([]);
+  const [attachImages, setAttachImages] = useState<ChatImage[]>([]);
   const [uploading, setUploading] = useState(false);
   const [showEditPromptModal, setShowEditPromptModal] = useState(false);
   const webllm = useContext(WebLLMContext)!;
@@ -971,15 +968,15 @@ function _Chat() {
          event.preventDefault();
          const file = item.getAsFile();
          if (file) {
-           const images: string[] = [];
+           const images: ChatImage[] = [];
            images.push(...attachImages);
            images.push(
-             ...(await new Promise<string[]>((res, rej) => {
+             ...(await new Promise<ChatImage[]>((res, rej) => {
                setUploading(true);
-               const imagesData: string[] = [];
+               const imagesData: ChatImage[] = [];
                compressImage(file, 256 * 1024)
-                 .then((dataUrl) => {
-                   imagesData.push(dataUrl);
+                 .then((imageData) => {
+                   imagesData.push(imageData);
                    setUploading(false);
                    res(imagesData);
                  })
@@ -1003,11 +1000,11 @@ function _Chat() {
   );
 
   async function uploadImage() {
-    const images: string[] = [];
+    const images: ChatImage[] = [];
     images.push(...attachImages);
 
     images.push(
-      ...(await new Promise<string[]>((res, rej) => {
+      ...(await new Promise<ChatImage[]>((res, rej) => {
        const fileInput = document.createElement("input");
        fileInput.type = "file";
        fileInput.accept =
@@ -1016,12 +1013,12 @@ function _Chat() {
        fileInput.onchange = (event: any) => {
          setUploading(true);
          const files = event.target.files;
-         const imagesData: string[] = [];
+         const imagesData: ChatImage[] = [];
          for (let i = 0; i < files.length; i++) {
            const file = event.target.files[i];
            compressImage(file, 256 * 1024)
-             .then((dataUrl) => {
-               imagesData.push(dataUrl);
+             .then((imageData) => {
+               imagesData.push(imageData);
                if (
                  imagesData.length === 3 ||
                  imagesData.length === files.length
@@ -1225,7 +1222,7 @@ function _Chat() {
          newContent.push({
            type: "image_url",
            image_url: {
-             url: images[i],
+             url: images[i].url,
+           },
+           dimension: {
+             width: images[i].width,
+             height: images[i].height,
            },
          });
        }
@@ -1301,7 +1302,9 @@ function _Chat() {
              {getMessageImages(message).length == 1 && (
              )}
@@ -1321,7 +1324,9 @@ function _Chat() {
                        styles["chat-message-item-image-multi"]
                      }
                      key={index}
-                     src={image}
+                     src={image.url}
+                     width={image.width}
+                     height={image.height}
                      alt=""
                    />
                  );
@@ -1413,7 +1418,7 @@ function _Chat() {
diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx
--- a/app/components/exporter.tsx
+++ b/app/components/exporter.tsx
@@ -451,7 +451,7 @@ export function ImagePreviewer(props: {
                } as React.CSSProperties
              }
            >
-             {getMessageImages(m).map((src, i) => (
+             {getMessageImages(m).map(({ url: src }, i) => (
 void }) {
diff --git a/app/components/template.tsx b/app/components/template.tsx
index f9652ec7..a7b47922 100644
--- a/app/components/template.tsx
+++ b/app/components/template.tsx
@@ -238,7 +238,7 @@ export function ContextPrompts(props: {
       const text = getMessageTextContent(context[i]);
       const newContext: MultimodalContent[] = [{ type: "text", text }];
       for (const img of images) {
-        newContext.push({ type: "image_url", image_url: { url: img } });
+        newContext.push({ type: "image_url", image_url: { url: img.url } });
       }
       context[i].content = newContext;
     }
diff --git a/app/components/ui-lib.module.scss b/app/components/ui-lib.module.scss
index e944c4a6..906176fc 100644
--- a/app/components/ui-lib.module.scss
+++ b/app/components/ui-lib.module.scss
@@ -112,7 +112,10 @@
   align-items: center;
 
   .list-icon {
-    margin-right: 10px;
+    margin-left: 10px;
+    display: flex;
+    align-items: center;
+    justify-content: center;
   }
 
   .list-item-title {
diff --git a/app/components/ui-lib.tsx b/app/components/ui-lib.tsx
index b0588b28..cb3bd90f 100644
--- a/app/components/ui-lib.tsx
+++ b/app/components/ui-lib.tsx
@@ -13,7 +13,7 @@ import MinIcon from "../icons/min.svg";
 import Locale from "../locales";
 
 import { createRoot } from "react-dom/client";
-import React, { HTMLProps, useEffect, useState } from "react";
+import React, { HTMLProps, ReactNode, useEffect, useState } from "react";
 import { IconButton } from "./button";
 
 export function Popover(props: {
@@ -95,7 +95,6 @@ export function ListItem(props: {
       onClick={props.onClick}
     >
-      {props.icon && <div className={styles["list-icon"]}>{props.icon}</div>}
       {props.title}
       {props.subTitle && (
@@ -104,6 +103,7 @@ export function ListItem(props: {
         )}
+      {props.icon && <div className={styles["list-icon"]}>{props.icon}</div>}
       {props.children}
@@ -489,6 +489,7 @@ export function Selector(props: {
     subTitle?: string;
     value: T;
     family?: string;
+    icon?: JSX.Element;
   }>;
   defaultSelectedValue?: T;
   onSelection?: (selection: T[]) => void;
@@ -517,6 +518,7 @@ export function Selector(props: {
                 props.onSelection?.([item.value]);
                 props.onClose?.();
               }}
+              icon={item.icon}
             >
               {selected ? (
diff --git a/app/icons/attachment.svg b/app/icons/attachment.svg
new file mode 100644
--- /dev/null
+++ b/app/icons/attachment.svg
@@ -0,0 +1,35 @@
+
+
+  attachment
+
+
+
+
+
+
+
diff --git a/app/icons/eye-off.svg b/app/icons/eye-off.svg
index dd7e0b80..e45b6dd6 100644
--- a/app/icons/eye-off.svg
+++ b/app/icons/eye-off.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/app/icons/eye.svg b/app/icons/eye.svg
index aab43d43..ca42bd10 100644
--- a/app/icons/eye.svg
+++ b/app/icons/eye.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 3bb19559..15ef4514 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -15,7 +15,11 @@ import { RequestMessage, MultimodalContent, LLMApi } from "../client/api";
 import { estimateTokenLength } from "../utils/token";
 import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
-import { ChatCompletionFinishReason, CompletionUsage } from "@mlc-ai/web-llm";
+import {
+  ChatCompletionFinishReason,
+  CompletionUsage,
+} from "@neet-nestor/web-llm";
+import { ChatImage } from "../typing";
 
 export type ChatMessage = RequestMessage & {
   date: string;
@@ -279,7 +283,7 @@ export const useChatStore = createPersistStore(
       get().summarizeSession(llm);
     },
 
-    onUserInput(content: string, llm: LLMApi, attachImages?: string[]) {
+    onUserInput(content: string, llm: LLMApi, attachImages?: ChatImage[]) {
       const modelConfig = useAppConfig.getState().modelConfig;
 
       const userContent = fillTemplateWith(content, useAppConfig.getState());
@@ -295,11 +299,15 @@ export const useChatStore = createPersistStore(
           },
         ];
         mContent = mContent.concat(
-          attachImages.map((url) => {
+          attachImages.map((imageData) => {
             return {
               type: "image_url",
               image_url: {
-                url: url,
+                url: imageData.url,
+              },
+              dimension: {
+                width: imageData.width,
+                height: imageData.height,
               },
             };
           }),
diff --git a/app/store/config.ts b/app/store/config.ts
index eac6ce7a..312a2462 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -1,4 +1,4 @@
-import { LogLevel } from "@mlc-ai/web-llm";
+import { LogLevel } from "@neet-nestor/web-llm";
 import { ModelRecord } from "../client/api";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -208,9 +208,9 @@ export const useAppConfig = createPersistStore(
   }),
   {
     name: StoreKey.Config,
-    version: 0.53,
+    version: 0.54,
     migrate: (persistedState, version) => {
-      if (version < 0.53) {
+      if (version < 0.54) {
         return {
           ...DEFAULT_CONFIG,
           ...(persistedState as any),
diff --git a/app/typing.ts b/app/typing.ts
index b09722ab..e624a014 100644
--- a/app/typing.ts
+++ b/app/typing.ts
@@ -7,3 +7,9 @@ export interface RequestMessage {
   role: MessageRole;
   content: string;
 }
+
+export interface ChatImage {
+  url: string;
+  width: number;
+  height: number;
+}
diff --git a/app/utils.ts b/app/utils.ts
index e001b442..ae69464e 100644
--- a/app/utils.ts
+++ b/app/utils.ts
@@ -2,6 +2,9 @@ import { useEffect, useState } from "react";
 import { showToast } from "./components/ui-lib";
 import Locale from "./locales";
 import { RequestMessage } from "./client/api";
+import { Model } from "./store";
+import { ModelType, prebuiltAppConfig } from "@neet-nestor/web-llm";
+import { ChatImage } from "./typing";
 
 export function trimTopic(topic: string) {
   // Fix an issue where double quotes still show in the Indonesian language
@@ -51,7 +54,7 @@ export async function downloadAs(text: string, filename: string) {
   document.body.removeChild(element);
 }
 
-export function compressImage(file: File, maxSize: number): Promise<string> {
+export function compressImage(file: File, maxSize: number): Promise<ChatImage> {
   return new Promise((resolve, reject) => {
     const reader = new FileReader();
     reader.onload = (readerEvent: any) => {
@@ -83,7 +86,11 @@ export function compressImage(file: File, maxSize: number): Promise<string> {
         }
       } while (dataUrl.length > maxSize);
-      resolve(dataUrl);
+      resolve({
+        url: dataUrl,
+        width: width,
+        height: height,
+      });
     };
     image.onerror = reject;
     image.src = readerEvent.target.result;
@@ -244,31 +251,26 @@ export function getMessageTextContent(message: RequestMessage) {
   return "";
 }
 
-export function getMessageImages(message: RequestMessage): string[] {
+export function getMessageImages(message: RequestMessage): ChatImage[] {
   if (typeof message.content === "string") {
     return [];
   }
-  const urls: string[] = [];
+  const urls: ChatImage[] = [];
   for (const c of message.content) {
     if (c.type === "image_url") {
-      urls.push(c.image_url?.url ?? "");
+      urls.push({
+        url: c.image_url?.url ?? "",
+        width: c.dimension?.width ?? 0,
+        height: c.dimension?.height ?? 0,
+      });
     }
   }
   return urls;
 }
 
-export function isVisionModel(model: string) {
-  // Note: This is a better way using the TypeScript feature instead of `&&` or `||` (ts v5.5.0-dev.20240314 I've been using)
-
-  const visionKeywords = ["vision", "claude-3", "gemini-1.5-pro"];
-
-  const isGpt4Turbo =
-    model.includes("gpt-4-turbo") && !model.includes("preview");
-
-  return (
-    visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
-  );
-}
+export const isVisionModel = (model: Model) =>
+  prebuiltAppConfig.model_list.find((m) => m.model_id === model)?.model_type ===
+  ModelType.VLM;
 
 // Fix various problems in webllm generation
 export function fixMessage(message: string) {
diff --git a/app/worker/service-worker.ts b/app/worker/service-worker.ts
index 8e520e16..50cd4688 100644
--- a/app/worker/service-worker.ts
+++ b/app/worker/service-worker.ts
@@ -1,4 +1,4 @@
-import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm";
+import { ServiceWorkerMLCEngineHandler } from "@neet-nestor/web-llm";
 import { defaultCache } from "@serwist/next/worker";
 import type { PrecacheEntry, SerwistGlobalConfig } from "serwist";
 import { CacheFirst, ExpirationPlugin, Serwist } from "serwist";
diff --git a/app/worker/web-worker.ts b/app/worker/web-worker.ts
index 12fccbb2..84955c5d 100644
--- a/app/worker/web-worker.ts
+++ b/app/worker/web-worker.ts
@@ -1,5 +1,5 @@
 import log from "loglevel";
-import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";
+import { WebWorkerMLCEngineHandler } from "@neet-nestor/web-llm";
 
 let handler: WebWorkerMLCEngineHandler;
 
diff --git a/package.json b/package.json
index 1064001e..6862d6df 100644
--- a/package.json
+++ b/package.json
@@ -17,7 +17,7 @@
   "dependencies": {
     "@fortaine/fetch-event-source": "^3.0.6",
     "@hello-pangea/dnd": "^16.5.0",
-    "@mlc-ai/web-llm": "^0.2.71",
+    "@neet-nestor/web-llm": "^0.2.71",
     "@serwist/next": "^9.0.2",
     "@svgr/webpack": "^6.5.1",
     "emoji-picker-react": "^4.9.2",
diff --git a/yarn.lock b/yarn.lock
index 6e0babaa..8256707f 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1180,10 +1180,10 @@
     "@jridgewell/resolve-uri" "^3.1.0"
     "@jridgewell/sourcemap-codec" "^1.4.14"
 
-"@mlc-ai/web-llm@^0.2.71":
+"@neet-nestor/web-llm@^0.2.71":
   version "0.2.71"
-  resolved "https://registry.yarnpkg.com/@mlc-ai/web-llm/-/web-llm-0.2.71.tgz#6e0533fc6ad643a852f9c79d650506ea96061924"
-  integrity sha512-VRo+yZFnHsQeI/a0DLNnnUdWtDGiIwL7vYN4YdYXMZMKXGT+rcXRRUIVlH3/AVUbQx/TKHyKl1wxenXg6Bpx1w==
+  resolved "https://registry.yarnpkg.com/@neet-nestor/web-llm/-/web-llm-0.2.71.tgz#8dd96fec43ad3ec1c890c60c8999985764262b92"
+  integrity sha512-cvdFLoCXazp2WvGR3H9OYARAPlz7WrPr2ii1AHGY+FHJZUd0Db7OiSRy+hBxWEtdyeWZyj1dTW8hYhAIiO+Deg==
   dependencies:
     loglevel "^1.9.1"