Skip to content

Commit

Permalink
fix: Do not reload model for summarization tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Neet-Nestor committed May 17, 2024
1 parent 0c28da9 commit 8dc7536
Showing 1 changed file with 59 additions and 27 deletions.
86 changes: 59 additions & 27 deletions app/client/webllm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,37 +80,12 @@ export class WebLLMApi implements LLMApi {
);
}

async chatCompletion(
stream: boolean,
messages: RequestMessage[],
onUpdate?: (message: string, chunk: string) => void,
) {
let reply: string | null = "";

const completion = await this.engine!.chatCompletion({
stream: stream,
messages: messages as ChatCompletionMessageParam[],
});

if (stream) {
const asyncGenerator = completion as AsyncIterable<ChatCompletionChunk>;
for await (const chunk of asyncGenerator) {
if (chunk.choices[0].delta.content) {
reply += chunk.choices[0].delta.content;
onUpdate?.(reply, chunk.choices[0].delta.content);
}
}
return reply;
}
return (completion as ChatCompletion).choices[0].message.content;
}

async chat(options: ChatOptions): Promise<void> {
// in case the service worker is dead, revive it by firing a fetch event
fetch("/ping.txt");

if (this.isConfigChanged(options.config)) {
this.llmConfig = options.config;
if (this.isDifferentConfig(options.config)) {
this.llmConfig = { ...(this.llmConfig || {}), ...options.config };
try {
await this.initModel(options.onUpdate);
} catch (e) {
Expand Down Expand Up @@ -157,6 +132,63 @@ export class WebLLMApi implements LLMApi {
total: 0,
};
}

isDifferentConfig(config: LLMConfig): boolean {
if (!this.llmConfig) {
return true;
}

// Compare required fields
if (this.llmConfig.model !== config.model) {
return true;
}

// Compare optional fields
const optionalFields: (keyof LLMConfig)[] = [
"temperature",
"top_p",
"stream",
"presence_penalty",
"frequency_penalty",
];

for (const field of optionalFields) {
if (
this.llmConfig[field] !== undefined &&
config[field] !== undefined &&
config[field] !== config[field]
) {
return true;
}
}

return false;
}

async chatCompletion(
stream: boolean,
messages: RequestMessage[],
onUpdate?: (message: string, chunk: string) => void,
) {
let reply: string | null = "";

const completion = await this.engine!.chatCompletion({
stream: stream,
messages: messages as ChatCompletionMessageParam[],
});

if (stream) {
const asyncGenerator = completion as AsyncIterable<ChatCompletionChunk>;
for await (const chunk of asyncGenerator) {
if (chunk.choices[0].delta.content) {
reply += chunk.choices[0].delta.content;
onUpdate?.(reply, chunk.choices[0].delta.content);
}
}
return reply;
}
return (completion as ChatCompletion).choices[0].message.content;
}
}

export const WebLLMContext = createContext<WebLLMApi | null>(null);

0 comments on commit 8dc7536

Please sign in to comment.