Skip to content

Commit

Permalink
Add ollama support (#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
sceuick authored Jul 17, 2024
1 parent 14c4de4 commit 5346218
Show file tree
Hide file tree
Showing 13 changed files with 195 additions and 36 deletions.
2 changes: 2 additions & 0 deletions common/adapters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export const THIRDPARTY_HANDLERS: { [svc in ThirdPartyFormat]: AIAdapter } = {
ooba: 'ooba',
tabby: 'kobold',
mistral: 'kobold',
ollama: 'kobold',
}

export const THIRDPARTY_FORMATS = [
Expand All @@ -75,6 +76,7 @@ export const THIRDPARTY_FORMATS = [
'koboldcpp',
'tabby',
'mistral',
'ollama',
] as const

export const AI_ADAPTERS = [
Expand Down
2 changes: 1 addition & 1 deletion common/prompt-order.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export const formatHolders: Record<string, Record<string, string>> = {
{{#each msg}}<|im_start|>{{#if .isbot}}assistant{{/if}}{{#if .isuser}}user{{/if}}
{{.name}}: {{.msg}}<|im_end|>{{/each}}`,
post: neat`POST<|im_start|>assistant
post: neat`<|im_start|>assistant
{{post}}`,
system_prompt: neat`{{#if system_prompt}}<|im_start|>system
{{system_prompt}}<|im_end|>{{/if}}`,
Expand Down
4 changes: 2 additions & 2 deletions common/template-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,10 @@ function getPlaceholder(node: PlaceHolder | ConditionNode, opts: TemplateOpts) {

switch (node.value) {
case 'char':
return (opts.replyAs || opts.char).name || ''
return ((opts.replyAs || opts.char).name || '').trim()

case 'user':
return opts.impersonate?.name || opts.sender?.handle || 'You'
return (opts.impersonate?.name || opts.sender?.handle || 'You').trim()

case 'example_dialogue':
return opts.parts?.sampleChat?.join('\n') || ''
Expand Down
2 changes: 1 addition & 1 deletion common/types/ui.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export type UISettings = {
mobileSendOnEnter: boolean
msgOptsInline: { [key in MessageOption]: { outer: boolean; pos: number } }

viewMode?: 'split' | 'standard' | 'background'
viewMode?: 'split' | 'standard' | 'background' | 'background-contain' | 'background-cover'
viewHeight?: number

chatWidth?: ChatWidth
Expand Down
55 changes: 42 additions & 13 deletions srv/adapter/kobold.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export const handleKobold: ModelAdapter = async function* (opts) {
const { members, characters, prompt, mappedSettings } = opts

const body =
opts.gen.thirdPartyFormat === 'ollama' ||
opts.gen.thirdPartyFormat === 'ooba' ||
opts.gen.thirdPartyFormat === 'mistral' ||
opts.gen.thirdPartyFormat === 'tabby' ||
Expand Down Expand Up @@ -124,9 +125,7 @@ async function dispatch(opts: AdapterProps, body: any) {
const baseURL = normalizeUrl(opts.gen.thirdPartyUrl || opts.user.koboldUrl)

const headers: any = await getHeaders(opts)
if (opts.gen.thirdPartyFormat === 'aphrodite') {
await validateModel(opts, baseURL, body, headers)
}
await validateModel(opts, baseURL, body, headers)

switch (opts.gen.thirdPartyFormat) {
case 'llamacpp':
Expand All @@ -153,6 +152,13 @@ async function dispatch(opts: AdapterProps, body: any) {
return stream
}

case 'ollama': {
const url = `${baseURL}/api/generate`
return opts.gen.streamResponse
? streamCompletion(url, body, headers, opts.gen.thirdPartyFormat, opts.log)
: fullCompletion(url, body, headers, opts.gen.thirdPartyFormat, opts.log)
}

default:
const isStreamSupported = await checkStreamSupported(`${baseURL}/api/extra/version`)
return opts.gen.streamResponse && isStreamSupported
Expand Down Expand Up @@ -341,7 +347,7 @@ const streamCompletion = async function* (
}

tokens.push(token)
yield { token: token }
yield { token }
}
} catch (err: any) {
yield { error: `${format} streaming request failed: ${err.message || err}` }
Expand All @@ -358,19 +364,42 @@ const streamCompletion = async function* (
}

async function validateModel(opts: AdapterProps, baseURL: string, payload: any, headers: any) {
if (opts.gen.thirdPartyFormat !== 'aphrodite') return
if (opts.gen.thirdPartyFormat === 'aphrodite') {
const res = await needle('get', `${baseURL}/v1/models`, { headers, json: true })

const res = await needle('get', `${baseURL}/v1/models`, { headers, json: true })
const code = res.statusCode ?? 400
if (code >= 400) {
return
}

const code = res.statusCode ?? 400
if (code >= 400) {
return
if (!Array.isArray(res.body.data)) return
const names = res.body.data.map((data: any) => data.id) as string[]

if (!payload.model || !names.includes(payload.model)) {
payload.model = names[0]
}
}

if (!Array.isArray(res.body.data)) return
const names = res.body.data.map((data: any) => data.id) as string[]
if (opts.gen.thirdPartyFormat === 'ollama') {
const res = await needle('get', `${baseURL}/api/tags`, { headers, json: true })
const code = res.statusCode ?? 400
if (code >= 400) {
return
}

if (!Array.isArray(res.body.models)) return
const models = res.body.models as Array<{ name: string; model: string }>
if (!models.length) return

if (!payload.model || !names.includes(payload.model)) {
payload.model = names[0]
if (!payload.model) {
payload.model = models[0].name
return
}

const match = models.find((m) => m.name === payload.model)
if (!match) {
payload.model = models[0].name
return
}
}
}
46 changes: 46 additions & 0 deletions srv/adapter/payloads.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,52 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {
return body
}

if (format === 'ollama') {
const payload: any = {
prompt,
model: gen.thirdPartyModel,
stream: !!gen.streamResponse,
system: '',
format: opts.jsonSchema ? 'json' : undefined,

options: {
seed: Math.trunc(Math.random() * 1_000_000_000),
num_predict: gen.maxTokens,
top_k: gen.topK,
top_p: gen.topP,
tfs_z: gen.tailFreeSampling,
typical_p: gen.typicalP,
repeat_last_n: gen.repetitionPenaltyRange,
temperature: gen.temp,
repeat_penalty: gen.repetitionPenalty,
presence_penalty: gen.presencePenalty,
frequency_penalty: gen.frequencyPenalty,
mirostat: gen.mirostatToggle && gen.mirostatTau ? 2 : 0,
mirostat_tau: gen.mirostatTau,
mirostat_eta: gen.mirostatLR,
stop: getStoppingStrings(opts, stops),

// ignore_eos: false,
// min_p: gen.min_p,
dynatemp_range: gen.dynatemp_range,
dynatemp_exponent: gen.dynatemp_exponent,
},
}

if (opts.jsonSchema) {
const schema = JSON.stringify(opts.jsonSchema, null, 2)
payload.prompt += `\nRespond using the following JSON Schema:\n${schema}`
}

if (opts.imageData) {
const comma = opts.imageData.indexOf(',')
const base64 = opts.imageData.slice(comma + 1)
payload.images = [base64]
}

return payload
}

if (format === 'mistral') {
const body = {
messages: [{ role: 'user', content: prompt }],
Expand Down
38 changes: 37 additions & 1 deletion srv/adapter/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ export function requestStream(stream: NodeJS.ReadableStream, format?: ThirdParty
if (statusCode > 201) {
emitter.push({ error: `SSE request failed with status code ${statusCode}` })
emitter.done()
} else if (format === 'ollama') {
if (contentType.startsWith('application/x-ndjson')) return

emitter.push({
error: `SSE request received unexpected content-type ${headers['content-type']}`,
})
emitter.done()
} else if (!contentType.startsWith('text/event-stream')) {
emitter.push({
error: `SSE request received unexpected content-type ${headers['content-type']}`,
Expand All @@ -54,9 +61,20 @@ export function requestStream(stream: NodeJS.ReadableStream, format?: ThirdParty
stream.on('data', (chunk: Buffer) => {
const data = incomplete + chunk.toString()
incomplete = ''
const messages = data.split(/\r?\n\r?\n/)

const messages = data.split(/\r?\n\r?\n/).filter((l) => !!l)

for (const msg of messages) {
if (format === 'ollama') {
const event = parseOllama(incomplete + msg, emitter)
const token = event?.response
if (!token) continue

const data = JSON.stringify({ token })
emitter.push({ data })
continue
}

if (format === 'aphrodite') {
const event = parseAphrodite(incomplete + msg, emitter)
if (!event?.data) {
Expand Down Expand Up @@ -124,6 +142,24 @@ function parseAphrodite(msg: string, emitter: EventGenerator<ServerSentEvent>) {
return event
}

function parseOllama(msg: string, emitter: EventGenerator<ServerSentEvent>) {
const event: any = {}
const data = tryParse(msg)
if (!data) return event

Object.assign(event, data)
return event
}

function tryParse(value: any) {
try {
const obj = JSON.parse(value)
return obj
} catch (ex) {
return {}
}
}

function parseEvent(msg: string) {
const event: any = {}
for (const line of msg.split(/\r?\n/)) {
Expand Down
3 changes: 2 additions & 1 deletion web/pages/Chat/components/InputBar.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ImagePlus, Megaphone, MoreHorizontal, PlusCircle, Send, Zap } from 'lucide-solid'
import { ImagePlus, ImageUp, Megaphone, MoreHorizontal, PlusCircle, Send, Zap } from 'lucide-solid'
import {
Component,
createMemo,
Expand Down Expand Up @@ -386,6 +386,7 @@ const InputBar: Component<{
accept="image/jpg,image/png,image/jpeg"
/>
<LabelButton for="imageCaption" schema="secondary" class="w-full" alignLeft>
<ImageUp size={18} />
Attach Image
</LabelButton>
</Show>
Expand Down
4 changes: 3 additions & 1 deletion web/pages/Settings/UISettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ const UISettings: Component = () => {
items={[
{ label: 'Standard', value: 'standard' },
{ label: 'Split', value: 'split' },
{ label: 'Background', value: 'background' },
{ label: 'Background: Auto', value: 'background' },
{ label: 'Background: Cover', value: 'background-cover' },
{ label: 'Background: Contain', value: 'background-contain' },
]}
value={state.ui.viewMode || 'standard'}
onChange={(next) => userStore.saveUI({ viewMode: next.value as any })}
Expand Down
1 change: 1 addition & 0 deletions web/shared/GenerationSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ const GenerationSettings: Component<Props & { onSave: () => void }> = (props) =>
{ label: 'Claude', value: 'claude' },
{ label: 'Textgen (Ooba)', value: 'ooba' },
{ label: 'Llama.cpp', value: 'llamacpp' },
{ label: 'Ollama', value: 'ollama' },
{ label: 'Aphrodite', value: 'aphrodite' },
{ label: 'ExLlamaV2', value: 'exllamav2' },
{ label: 'KoboldCpp', value: 'koboldcpp' },
Expand Down
47 changes: 38 additions & 9 deletions web/shared/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type ImageCacheOpts = {

export function useCharacterBg(src: 'layout' | 'page') {
const location = useLocation()
const isMobile = useMobileDetect()

const isChat = createMemo(() => {
return location.pathname.startsWith('/chat/') || location.pathname.startsWith('/saga/')
Expand All @@ -63,27 +64,37 @@ export function useCharacterBg(src: 'layout' | 'page') {
if (cfg.anonymize) return {}
if (src === 'layout' && isChat()) return {}

const mobile = isMobile()

const base: JSX.CSSProperties = {
'background-repeat': 'no-repeat',
'background-position': 'center',
'background-color': isChat() ? undefined : '',
}

const isBg = state.ui.viewMode?.startsWith('background')
const char = chars.chars.map[chat.active?.char?._id!]
if (
!isChat() ||
state.ui.viewMode !== 'background' ||
!char ||
char.visualType === 'sprite' ||
!char.avatar
) {
return { ...base, 'background-image': `url(${state.background})`, 'background-size': 'cover' }
if (!isChat() || !isBg || !char || char.visualType === 'sprite' || !char.avatar) {
console.log('bg', mobile)
return {
...base,
'background-image': `url(${state.background})`,
'background-size': 'cover',
}
}

const size =
state.ui.viewMode === 'background-contain'
? 'contain'
: state.ui.viewMode === 'background-cover'
? 'cover'
: mobile
? 'contain'
: 'auto'
return {
...base,
'background-image': `url(${getAssetUrl(char.avatar)})`,
'background-size': 'auto',
'background-size': size,
}
})

Expand Down Expand Up @@ -365,6 +376,24 @@ export function useResizeObserver() {
return { size, load, loaded, platform }
}

export function useMobileDetect() {
const [mobile, setMobile] = createSignal(/iPhone|iPad|iPod|Android/i.test(navigator.userAgent))

useEffect(() => {
const timer = setInterval(() => {
const prev = mobile()
const next = /iPhone|iPad|iPod|Android/i.test(navigator.userAgent)

if (prev === next) return
setMobile(/iPhone|iPad|iPod|Android/i.test(navigator.userAgent))
}, 2000)

return () => clearInterval(timer)
})

return mobile
}

export function getWidthPlatform(width: number) {
return width > 1024 ? 'xl' : width > 720 ? 'lg' : 'sm'
}
Expand Down
Loading

0 comments on commit 5346218

Please sign in to comment.