From 0c28da97da97db8c11c350de166f627a9a55b278 Mon Sep 17 00:00:00 2001
From: Nestor Qin
Date: Thu, 16 May 2024 05:04:59 -0400
Subject: [PATCH] fix: Restart service worker after it is killed

---
 .env.template               |  63 -------------
 app/client/api.ts           |   2 -
 app/client/webllm.ts        | 181 ++++++++++++++++++++++--------------
 app/components/chat.tsx     |   9 +-
 app/components/home.tsx     |  53 +++++++----
 app/components/settings.tsx |   1 -
 app/service-worker.ts       |  53 +++++++++--
 package.json                |   4 +-
 public/ping.txt             |   1 +
 yarn.lock                   |  18 +---
 10 files changed, 199 insertions(+), 186 deletions(-)
 delete mode 100644 .env.template
 create mode 100644 public/ping.txt

diff --git a/.env.template b/.env.template
deleted file mode 100644
index b2a0438d..00000000
--- a/.env.template
+++ /dev/null
@@ -1,63 +0,0 @@
-
-# Your openai api key. (required)
-OPENAI_API_KEY=sk-xxxx
-
-# Access password, separated by comma. (optional)
-CODE=your-password
-
-# You can start service behind a proxy
-PROXY_URL=http://localhost:7890
-
-# (optional)
-# Default: Empty
-# Googel Gemini Pro API key, set if you want to use Google Gemini Pro API.
-GOOGLE_API_KEY=
-
-# (optional)
-# Default: https://generativelanguage.googleapis.com/
-# Googel Gemini Pro API url without pathname, set if you want to customize Google Gemini Pro API url.
-GOOGLE_URL=
-
-# Override openai api request base url. (optional)
-# Default: https://api.openai.com
-# Examples: http://your-openai-proxy.com
-BASE_URL=
-
-# Specify OpenAI organization ID.(optional)
-# Default: Empty
-OPENAI_ORG_ID=
-
-# (optional)
-# Default: Empty
-# If you do not want users to use GPT-4, set this value to 1.
-DISABLE_GPT4=
-
-# (optional)
-# Default: Empty
-# If you do not want users to input their own API key, set this value to 1.
-HIDE_USER_API_KEY=
-
-# (optional)
-# Default: Empty
-# If you do want users to query balance, set this value to 1.
-ENABLE_BALANCE_QUERY=
-
-# (optional)
-# Default: Empty
-# If you want to disable parse settings from url, set this value to 1.
-DISABLE_FAST_LINK=
-
-
-# anthropic claude Api Key.(optional)
-ANTHROPIC_API_KEY=
-
-### anthropic claude Api version. (optional)
-ANTHROPIC_API_VERSION=
-
-
-
-### anthropic claude Api url (optional)
-ANTHROPIC_URL=
-
-### (optional)
-WHITE_WEBDEV_ENDPOINTS=
\ No newline at end of file
diff --git a/app/client/api.ts b/app/client/api.ts
index e7ffe0aa..3b1ef64f 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -58,7 +58,5 @@ export interface LLMModelProvider {
 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise;
   abstract usage(): Promise;
-  abstract models(): Promise;
   abstract abort(): Promise;
-  abstract clear(): void;
 }
diff --git a/app/client/webllm.ts b/app/client/webllm.ts
index cc96a625..34c03303 100644
--- a/app/client/webllm.ts
+++ b/app/client/webllm.ts
@@ -1,92 +1,143 @@
 import { createContext } from "react";
 import {
-  CreateWebServiceWorkerEngine,
   InitProgressReport,
   prebuiltAppConfig,
   ChatCompletionMessageParam,
-  WebServiceWorkerEngine,
+  ServiceWorkerEngine,
+  ServiceWorker,
+  ChatCompletionChunk,
+  ChatCompletion,
 } from "@neet-nestor/web-llm";
-import { ChatOptions, LLMApi, LLMConfig } from "./api";
+import { ChatOptions, LLMApi, LLMConfig, RequestMessage } from "./api";
+
+const KEEP_ALIVE_INTERVAL = 10000;
 
 export class WebLLMApi implements LLMApi {
-  private currentModel?: string;
-  private engine?: WebServiceWorkerEngine;
+  private llmConfig?: LLMConfig;
+  engine?: ServiceWorkerEngine;
 
-  constructor(onEngineCrash: () => void) {
-    setInterval(() => {
-      if ((this.engine?.missedHeatbeat || 0) > 2) {
-        onEngineCrash?.();
-      }
-    }, 10000);
+  constructor() {
+    this.engine = new ServiceWorkerEngine(new ServiceWorker());
+    this.engine.keepAlive(
+      window.location.href + "ping.txt",
+      KEEP_ALIVE_INTERVAL,
+    );
+  }
+
+  async initModel(onUpdate?: (message: string, chunk: string) => void) {
+    if (!this.llmConfig) {
+      throw Error("llmConfig is undefined");
+    }
+    if (!this.engine) {
+      this.engine = new ServiceWorkerEngine(new ServiceWorker());
+    }
+    let hasResponse = false;
+    this.engine.setInitProgressCallback((report: InitProgressReport) => {
+      onUpdate?.(report.text, report.text);
+      hasResponse = true;
+    });
+    let initRequest = this.engine.init(this.llmConfig.model, this.llmConfig, {
+      ...prebuiltAppConfig,
+      useIndexedDBCache: this.llmConfig.cache === "index_db",
+    });
+    // In case the service worker is dead, init will halt indefinitely
+    // so we manually retry if timeout
+    let retry = 0;
+    let engine = this.engine;
+    let llmConfig = this.llmConfig;
+    let retryInterval: NodeJS.Timeout;
+
+    await new Promise((resolve, reject) => {
+      retryInterval = setInterval(() => {
+        if (hasResponse) {
+          clearInterval(retryInterval);
+          initRequest.then(resolve);
+          return;
+        }
+        if (retry >= 5) {
+          clearInterval(retryInterval);
+          reject("Model initialization timed out for too many times");
+          return;
+        }
+        retry += 1;
+        initRequest = engine.init(llmConfig.model, llmConfig, {
+          ...prebuiltAppConfig,
+          useIndexedDBCache: llmConfig.cache === "index_db",
+        });
+      }, 5000);
+    });
   }
 
-  clear() {
-    this.engine = undefined;
+  isConfigChanged(config: LLMConfig) {
+    return (
+      this.llmConfig?.model !== config.model ||
+      this.llmConfig?.cache !== config.cache ||
+      this.llmConfig?.temperature !== config.temperature ||
+      this.llmConfig?.top_p !== config.top_p ||
+      this.llmConfig?.presence_penalty !== config.presence_penalty ||
+      this.llmConfig?.frequency_penalty !== config.frequency_penalty
+    );
   }
 
-  async initModel(
-    config: LLMConfig,
+  async chatCompletion(
+    stream: boolean,
+    messages: RequestMessage[],
     onUpdate?: (message: string, chunk: string) => void,
   ) {
-    this.currentModel = config.model;
-    this.engine = await CreateWebServiceWorkerEngine(config.model, {
-      chatOpts: {
-        temperature: config.temperature,
-        top_p: config.top_p,
-        presence_penalty: config.presence_penalty,
-        frequency_penalty: config.frequency_penalty,
-      },
-      appConfig: {
-        ...prebuiltAppConfig,
-        useIndexedDBCache: config.cache === "index_db",
-      },
-      initProgressCallback: (report: InitProgressReport) => {
-        onUpdate?.(report.text, report.text);
-      },
+    let reply: string | null = "";
+
+    const completion = await this.engine!.chatCompletion({
+      stream: stream,
+      messages: messages as ChatCompletionMessageParam[],
     });
+
+    if (stream) {
+      const asyncGenerator = completion as AsyncIterable;
+      for await (const chunk of asyncGenerator) {
+        if (chunk.choices[0].delta.content) {
+          reply += chunk.choices[0].delta.content;
+          onUpdate?.(reply, chunk.choices[0].delta.content);
+        }
+      }
+      return reply;
+    }
+    return (completion as ChatCompletion).choices[0].message.content;
   }
 
   async chat(options: ChatOptions): Promise {
-    if (options.config.model !== this.currentModel) {
+    // in case the service worker is dead, revive it by firing a fetch event
+    fetch("/ping.txt");
+
+    if (this.isConfigChanged(options.config)) {
+      this.llmConfig = options.config;
       try {
-        await this.initModel(options.config, options.onUpdate);
+        await this.initModel(options.onUpdate);
       } catch (e) {
         console.error("Error in initModel", e);
       }
     }
 
     let reply: string | null = "";
-    if (options.config.stream) {
-      try {
-        const asyncChunkGenerator = await this.engine!.chatCompletion({
-          stream: options.config.stream,
-          messages: options.messages as ChatCompletionMessageParam[],
-        });
-
-        for await (const chunk of asyncChunkGenerator) {
-          if (chunk.choices[0].delta.content) {
-            reply += chunk.choices[0].delta.content;
-            options.onUpdate?.(reply, chunk.choices[0].delta.content);
-          }
-        }
-      } catch (err) {
-        console.error("Error in streaming chatCompletion", err);
-        options.onError?.(err as Error);
-        return;
-      }
-    } else {
-      try {
-        const completion = await this.engine!.chatCompletion({
-          stream: options.config.stream,
-          messages: options.messages as ChatCompletionMessageParam[],
-        });
-        reply = completion.choices[0].message.content;
-      } catch (err) {
-        console.error("Error in non-streaming chatCompletion", err);
+    try {
+      reply = await this.chatCompletion(
+        !!options.config.stream,
+        options.messages,
+        options.onUpdate,
+      );
+    } catch (err: any) {
+      if (err.toString().includes("Please call `Engine.reload(model)` first")) {
+        console.error("Error in chatCompletion", err);
         options.onError?.(err as Error);
         return;
       }
+      // Service worker has been stopped. Restart it
+      await this.initModel(options.onUpdate);
+      reply = await this.chatCompletion(
+        !!options.config.stream,
+        options.messages,
+        options.onUpdate,
+      );
     }
 
     if (reply) {
@@ -106,18 +157,6 @@ export class WebLLMApi implements LLMApi {
       total: 0,
     };
   }
-
-  async models() {
-    return prebuiltAppConfig.model_list.map((record) => ({
-      name: record.model_id,
-      available: true,
-      provider: {
-        id: "huggingface",
-        providerName: "huggingface",
-        providerType: "huggingface",
-      },
-    }));
-  }
 }
 
 export const WebLLMContext = createContext(null);
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 4004999a..b831ceae 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -97,7 +97,7 @@ import { ExportMessageModal } from "./exporter";
 import { getClientConfig } from "../config/client";
 import { useAllModels } from "../utils/hooks";
 import { MultimodalContent } from "../client/api";
-import { WebLLMApi, WebLLMContext } from "../client/webllm";
+import { WebLLMContext } from "../client/webllm";
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => ,
@@ -682,8 +682,7 @@ function _Chat() {
   const navigate = useNavigate();
   const [attachImages, setAttachImages] = useState([]);
   const [uploading, setUploading] = useState(false);
-
-  const webllm = useContext(WebLLMContext);
+  const webllm = useContext(WebLLMContext)!;
 
   // prompt hints
   const promptStore = usePromptStore();
@@ -764,7 +763,7 @@ function _Chat() {
     if (isStreaming) return;
     setIsLoading(true);
     chatStore
-      .onUserInput(userInput, webllm!, attachImages)
+      .onUserInput(userInput, webllm, attachImages)
       .then(() => setIsLoading(false));
     setAttachImages([]);
     localStorage.setItem(LAST_INPUT_KEY, userInput);
@@ -922,7 +921,7 @@ function _Chat() {
     const textContent = getMessageTextContent(userMessage);
     const images = getMessageImages(userMessage);
     chatStore
-      .onUserInput(textContent, webllm!, images)
+      .onUserInput(textContent, webllm, images)
       .then(() => setIsLoading(false));
     inputRef.current?.focus();
   };
diff --git a/app/components/home.tsx b/app/components/home.tsx
index cc35d106..41e9e662 100644
--- a/app/components/home.tsx
+++ b/app/components/home.tsx
@@ -28,6 +28,7 @@ import { useAppConfig } from "../store/config";
 import { getClientConfig } from "../config/client";
 import { WebLLMApi, WebLLMContext } from "../client/webllm";
 import Locale from "../locales";
+import { prebuiltAppConfig } from "@neet-nestor/web-llm";
 
 export function Loading(props: { noLogo?: boolean }) {
   return (
@@ -177,32 +178,48 @@ function Screen() {
   );
 }
 
-export function useLoadData(webllm: WebLLMApi) {
+export function useLoadData() {
   const config = useAppConfig();
 
   useEffect(() => {
-    (async () => {
-      if (webllm) {
-        const models = await webllm.models();
-        config.mergeModels(models);
-      }
-    })();
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [webllm]);
+    config.mergeModels(
+      prebuiltAppConfig.model_list.map((record) => ({
+        name: record.model_id,
+        available: true,
+        provider: {
+          id: "huggingface",
+          providerName: "huggingface",
+          providerType: "huggingface",
+        },
+      })),
+    );
+  }, []);
 }
 
+const useWebLLM = () => {
+  const [webllm, setWebLLM] = useState(null);
+  const [isSWAlive, setSWAlive] = useState(true);
+
+  useEffect(() => {
+    setWebLLM(new WebLLMApi());
+  }, []);
+
+  setInterval(() => {
+    if (webllm) {
+      // 10s per heartbeat, dead after 1 min of inactivity
+      setSWAlive(!!webllm.engine && webllm.engine.missedHeatbeat < 6);
+    }
+  });
+
+  return { webllm, isWebllmAlive: isSWAlive };
+};
+
 export function Home() {
   const hasHydrated = useHasHydrated();
   const isServiceWorkerReady = useServiceWorkerReady();
-  const [isEngineCrash, setEngineCrash] = useState(false);
-
-  const webllm = useMemo(() => {
-    return new WebLLMApi(() => {
-      setEngineCrash(true);
-    });
-  }, []);
+  const { webllm, isWebllmAlive } = useWebLLM();
 
-  useLoadData(webllm);
+  useLoadData();
   useSwitchTheme();
   useHtmlLang();
 
@@ -210,7 +227,7 @@ export function Home() {
     return ;
   }
 
-  if (isEngineCrash) {
+  if (!isWebllmAlive) {
     return ;
   }
 
diff --git a/app/components/settings.tsx b/app/components/settings.tsx
index f44a9256..fff92caf 100644
--- a/app/components/settings.tsx
+++ b/app/components/settings.tsx
@@ -471,7 +471,6 @@ export function Settings() {
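
Note on the keep-alive mechanism: the page pings a static asset (public/ping.txt) every KEEP_ALIVE_INTERVAL (10s) via engine.keepAlive(), chat() fires an extra fetch("/ping.txt") to revive a stopped service worker before sending a request, and the useWebLLM hook treats six or more missed heartbeats (about a minute) as a dead worker. The matching app/service-worker.ts hunk is not shown in this excerpt; the sketch below is one way such a handler could answer the pings, with the handler body and the "pong" payload assumed for illustration rather than taken from the patch:

    /// <reference lib="webworker" />
    // Illustrative sketch: answer keep-alive pings inside the worker so every
    // ping is a handled fetch event that wakes the worker and resets its idle timer.
    const worker = self as unknown as ServiceWorkerGlobalScope;

    worker.addEventListener("fetch", (event: FetchEvent) => {
      const url = new URL(event.request.url);
      if (url.pathname.endsWith("/ping.txt")) {
        // Respond locally; no network round trip is needed for a heartbeat.
        event.respondWith(new Response("pong", { status: 200 }));
      }
    });

Serving the ping from inside the worker keeps each heartbeat cheap, and the real public/ping.txt added by this patch still satisfies the request when no service worker is in control yet.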
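
In useWebLLM above, the setInterval call has no delay argument and is re-issued on every render without being cleared. A sketch of the same liveness check pinned to a fixed period inside an effect; HEARTBEAT_CHECK_INTERVAL is a local name chosen here (mirroring KEEP_ALIVE_INTERVAL), not part of the patch:

    import { useEffect, useState } from "react";
    import { WebLLMApi } from "../client/webllm";

    const HEARTBEAT_CHECK_INTERVAL = 10_000; // same 10s cadence as KEEP_ALIVE_INTERVAL

    const useWebLLM = () => {
      const [webllm, setWebLLM] = useState<WebLLMApi | null>(null);
      const [isSWAlive, setSWAlive] = useState(true);

      useEffect(() => {
        setWebLLM(new WebLLMApi());
      }, []);

      useEffect(() => {
        // 10s per heartbeat, dead after 1 min of inactivity
        const id = setInterval(() => {
          if (webllm) {
            setSWAlive(!!webllm.engine && webllm.engine.missedHeatbeat < 6);
          }
        }, HEARTBEAT_CHECK_INTERVAL);
        return () => clearInterval(id);
      }, [webllm]);

      return { webllm, isWebllmAlive: isSWAlive };
    };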