
Commit 66ae9f7
feat: {{#lowpriority}} placeholder (#789)
Co-authored-by: stevenksmith <[email protected]>
sceuick and stevenksmith authored Jan 2, 2024
1 parent 9f2bff9 commit 66ae9f7
Showing 10 changed files with 120 additions and 92 deletions.
25 changes: 18 additions & 7 deletions common/grammar.ts
@@ -25,7 +25,7 @@ Expression = content:Parent* {
return results
}
Parent "parent-node" = v:(BotIterator / ChatEmbedIterator / HistoryIterator / HistoryInsert / Condition / Placeholder / Text) { return v }
Parent "parent-node" = v:(BotIterator / ChatEmbedIterator / HistoryIterator / HistoryInsert / LowPriority / Condition / Placeholder / Text) { return v }
ManyPlaceholder "repeatable-placeholder" = OP i:(Character / User / Random / Roll) CL {
return { kind: 'placeholder', value: i }
@@ -40,6 +40,8 @@ HistoryInsert "history-insert" = OP "#insert"i WS "="? WS line:[0-9]|1..2| CL ch
ChatEmbedIterator "chat-embed-iterator" = OP "#each" WS loop:ChatEmbed CL children:(ChatEmbedChild / LoopText)* CloseLoop { return { kind: 'each', value: loop, children } }
ChatEmbedChild = i:(ChatEmbedRef / ManyPlaceholder) { return i }
LowPriority "lowpriority" = OP "#lowpriority"i CL children:(Placeholder / LowPriorityText)* CloseLowPriority { return { kind: 'lowpriority', children } }
Placeholder "placeholder"
= OP WS interp:Interp WS pipes:Pipe* CL {
@@ -58,24 +60,33 @@ HistoryCondition "history-condition" = OP "#if" WS prop:HistoryProperty CL sub:(
return { kind: 'history-if', prop, children: sub.flat() }
}
ConditionChild = Placeholder / Condition
ConditionChild = Placeholder / Condition / LowPriority
Condition "if" = OP "#if" WS value:Word CL sub:(ConditionChild / ConditionText)* CloseCondition {
return { kind: 'if', value, children: sub.flat() }
}
InsertText "insert-text" = !(BotChild / HistoryChild / CloseCondition / CloseInsert) ch:(.) { return ch }
LowPriorityText "lowpriority-text" = !(BotChild / HistoryChild / CloseCondition / CloseLowPriority) ch:(.) { return ch }
LoopText "loop-text" = !(BotChild / ChatEmbedChild / HistoryChild / CloseCondition / CloseLoop) ch:(.) { return ch }
ConditionText = !(ConditionChild / CloseCondition) ch:. { return ch }
Text "text" = !(Placeholder / Condition / BotIterator / HistoryIterator / ChatEmbedIterator) ch:. { return ch }
CSV "csv" = words:WordList* WS last:Word { return [...words, last] }
WordList = word:Word WS "," WS { return word }
DelimitedWords "csv" = head:Phrase tail:("," WS p:Phrase { return p })* { return [head, ...tail] }
Symbol = ch:("'" / "_" / "-" / "?" / "!" / "#" / "@" / "$" / "^" / "&" / "*" / "(" / ")" / "=" / "+" / "%" / "~" / ":" / ";" / "<" / ">" / "." / "/" / "|" / "\`" / "[" / "]") {
return ch
}
Phrase = text:(QuotedPhrase / CommalessPhrase) { return text }
QuotedPhrase = '"' text:(BasicChar / Symbol / "," / " ")+ '"' { return text.join('') }
CommalessPhrase = text:(BasicChar / Symbol / '"' / " ")+ { return text.join('') }
CloseCondition = OP "/if"i CL
CloseLoop = OP "/each"i CL
CloseInsert = OP "/insert"i CL
Word "word" = text:[a-zA-Z_ 0-9\!\?\.\'\#\@\%\"\&\*\=\+\-]+ { return text.join('') }
CloseLowPriority = OP "/lowpriority"i CL
BasicChar = [a-zA-Z0-9]
Word "word" = text:([a-zA-Z_ 0-9] / Symbol)+ { return text.join('') }
Pipe "pipe" = _ "|" _ fn:Handler { return fn }
@@ -112,7 +123,7 @@ Message "message" = ("msg"i / "message"i / "text"i) { return "message" }
ChatAge "chat-age" = "chat_age"i { return "chat_age" }
IdleDuration "idle-duration" = "idle_duration"i { return "idle_duration" }
UserEmbed "user-embed" = "user_embed"i { return "user_embed" }
Random "random" = "random"i ":"? WS words:CSV { return { kind: "random", values: words } }
Random "random" = ("random"i) ":"? WS words:DelimitedWords { return { kind: "random", values: words } }
Roll "roll" = ("roll"i / "dice"i) ":"? WS "d"|0..1| max:[0-9]|0..10| { return { kind: 'roll', values: +max.join('') || 20 } }
// Iterable entities
@@ -123,7 +134,7 @@ History "history" = ( "history"i / "messages"i / "msgs"i / "msg"i) { return "his
Interp "interp"
= Character
/ UserEmbed
/ User
/ Scenario
/ Persona
/ Impersonate
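
To illustrate the new rule, a minimal template using it might look like the sketch below. This is illustrative only and not part of the commit: it assumes the grammar's OP/CL tokens are the usual "{{" / "}}" delimiters and reuses the user_embed placeholder defined elsewhere in this file.

// Hypothetical template (TypeScript string), not from the codebase. The
// {{#lowpriority}} block wraps content that is only kept when the finished
// prompt still has token budget left over.
const exampleTemplate = `
Main instructions that are always kept.
{{#lowpriority}}Background notes: {{user_embed}}{{/lowpriority}}
`
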
35 changes: 18 additions & 17 deletions common/prompt.ts
@@ -430,23 +430,17 @@ export async function buildPromptParts(

if (opts.userEmbeds) {
const embeds = opts.userEmbeds.map((line) => line.text)
const fit = await fillPromptWithLines(
encoder,
opts.settings?.memoryUserEmbedLimit || 500,
'',
embeds
)
const fit = (
await fillPromptWithLines(encoder, opts.settings?.memoryUserEmbedLimit || 500, '', embeds)
).adding
parts.userEmbeds = fit
}

if (opts.chatEmbeds) {
const embeds = opts.chatEmbeds.map((line) => `${line.name}: ${line.text}`)
const fit = await fillPromptWithLines(
encoder,
opts.settings?.memoryChatEmbedLimit || 500,
'',
embeds
)
const fit = (
await fillPromptWithLines(encoder, opts.settings?.memoryChatEmbedLimit || 500, '', embeds)
).adding
parts.chatEmbeds = fit
}

@@ -560,7 +554,7 @@ export async function getLinesForPrompt(

const history = messages.slice().sort(sortMessagesDesc).map(formatMsg)

const lines = await fillPromptWithLines(encoder, maxContext, '', history)
const lines = (await fillPromptWithLines(encoder, maxContext, '', history)).adding

if (opts.trimSentences) {
return lines.map(trimSentence)
@@ -588,11 +582,16 @@ export async function fillPromptWithLines(
tokenLimit: number,
amble: string,
lines: string[],
inserts: Map<number, string> = new Map()
) {
inserts: Map<number, string> = new Map(),
lowpriority: { idToReplace: string; content: string }[] = []
): Promise<{ adding: string[]; unusedTokens: number }> {
const insertsCost = await encoder([...inserts.values()].join(' '))
const tokenLimitMinusInserts = tokenLimit - insertsCost
let count = await encoder(amble)
const ambleWithoutLowPriorityPlaceholders = lowpriority.reduce(
(amble, { idToReplace }) => amble.replace(idToReplace, ''),
amble
)
let count = await encoder(ambleWithoutLowPriorityPlaceholders)
const adding: string[] = []

let linesAddedCount = 0
@@ -615,7 +614,9 @@
adding.push(formatInsert(remainingInserts))
}

return adding
const unusedTokens = tokenLimitMinusInserts - count

return { adding, unusedTokens }
}

export function insertsDeeperThanConvoHistory(
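
Because fillPromptWithLines now resolves to an object rather than a bare array of lines, callers destructure the result (or read .adding), as in the hunks above. A minimal sketch of the new call shape follows; the encoder and the literal values are stand-ins, not taken from the codebase.

// Sketch only: countTokens and the literals below are stand-ins. The call shape
// matches the updated fillPromptWithLines signature.
import { fillPromptWithLines } from '/common/prompt'

const countTokens = async (text: string) => Math.ceil(text.length / 4) // rough token estimate

const { adding, unusedTokens } = await fillPromptWithLines(
  countTokens,                 // encoder
  1024,                        // token limit
  'System prompt',             // amble, counted against the limit
  ['Alice: hi', 'Bob: hello'], // history lines, newest last
  new Map(),                   // inserts
  []                           // low-priority blocks
)
// `adding` holds the lines that fit; `unusedTokens` is what parseTemplate later
// spends on {{#lowpriority}} content.
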
83 changes: 68 additions & 15 deletions common/template-parser.ts
@@ -6,18 +6,29 @@ import peggy from 'peggy'
import { elapsedSince } from './util'
import { v4 } from 'uuid'

const parser = peggy.generate(grammar.trim(), {
error: (stage, msg, loc) => {
console.error({ loc, stage }, msg)
},
})
const parser = loadParser()

type PNode = PlaceHolder | ConditionNode | IteratorNode | InsertNode | string
function loadParser() {
try {
const parser = peggy.generate(grammar.trim(), {
error: (stage, msg, loc) => {
console.error({ loc, stage }, msg)
},
})
return parser
} catch (ex) {
console.error(ex)
throw ex
}
}

type PNode = PlaceHolder | ConditionNode | IteratorNode | InsertNode | LowPriorityNode | string

type PlaceHolder = { kind: 'placeholder'; value: Holder; values?: any; pipes?: string[] }
type ConditionNode = { kind: 'if'; value: Holder; values?: any; children: PNode[] }
type IteratorNode = { kind: 'each'; value: IterableHolder; children: CNode[] }
type InsertNode = { kind: 'history-insert'; values: number; children: PNode[] }
type LowPriorityNode = { kind: 'lowpriority'; children: PNode[] }

type CNode =
| Exclude<PNode, { kind: 'each' }>
@@ -96,6 +107,7 @@ export type TemplateOpts = {
*/
repeatable?: boolean
inserts?: Map<number, string>
lowpriority?: { idToReplace: string; content: string }[]
}

/**
@@ -123,6 +135,7 @@ export async function parseTemplate(
const ast = parser.parse(template, {}) as PNode[]
readInserts(template, opts, ast)
let output = render(template, opts, ast)
let unusedTokens = 0

if (opts.limit && opts.limit.output) {
// const lastIndex = Object.keys(opts.limit.output).reduce((prev, curr) => {
@@ -135,17 +148,32 @@
// }

for (const [id, lines] of Object.entries(opts.limit.output)) {
const trimmed = (
await fillPromptWithLines(
opts.limit.encoder,
opts.limit.context,
output,
lines,
opts.inserts
)
).reverse()
const filled = await fillPromptWithLines(
opts.limit.encoder,
opts.limit.context,
output,
lines,
opts.inserts,
opts.lowpriority
)
unusedTokens = filled.unusedTokens
const trimmed = filled.adding.reverse()
output = output.replace(id, trimmed.join('\n'))
}

// Add the low-priority blocks if we still have the budget for them,
// now that the conversation history has been inserted.
// We start from the bottom (a somewhat arbitrary design choice),
// hence the reverse().
for (const { idToReplace, content } of (opts.lowpriority ?? []).reverse()) {
const contentLength = await opts.limit!.encoder(content)
if (contentLength > unusedTokens) {
output = output.replace(idToReplace, '')
} else {
output = output.replace(idToReplace, content)
unusedTokens -= contentLength
}
}
}

const result = render(output, opts).replace(/\r\n/g, '\n').replace(/\n\n+/g, '\n\n').trimStart()
@@ -251,7 +279,32 @@ function renderNode(node: PNode, opts: TemplateOpts) {

case 'if':
return renderCondition(node, node.children, opts)

case 'lowpriority':
return renderLowPriority(node, opts)
}
}

/**
* This only returns a UUID, but adds the string meant to replace the UUID to the
* opts object. The UUID is only replaced with the actual content (or removed) after
* the prompt is built once, because low-priority content is NOT added if the
* rest of the prompt already takes up the token budget.
* It's up to the rest of the prompt-building to remove the UUIDs when
* calculating their token budget.
* This is somewhat grungy string manipulation, but it is unavoidable given the way
* prompt segments get turned into strings at the same time as their tokens are counted.
*/
function renderLowPriority(node: LowPriorityNode, opts: TemplateOpts) {
const output: string[] = []
for (const child of node.children) {
const result = renderNode(child, opts)
if (result) output.push(result)
}
opts.lowpriority = opts.lowpriority ?? []
const lowpriorityBlockId = '__' + v4() + '__'
opts.lowpriority.push({ idToReplace: lowpriorityBlockId, content: output.join('') })
return lowpriorityBlockId
}

function renderProp(node: CNode, opts: TemplateOpts, entity: unknown, i: number) {
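
The interplay between renderLowPriority and the second pass in parseTemplate can be hard to follow from the diff alone. The standalone sketch below restates that pass in simplified form; the function name and encoder type are stand-ins, not part of the commit.

// Simplified restatement of the low-priority pass in parseTemplate: each block's
// UUID is either substituted with its content or removed, depending on the budget.
type LowPriorityBlock = { idToReplace: string; content: string }

async function applyLowPriority(
  output: string,
  blocks: LowPriorityBlock[],
  unusedTokens: number,
  encoder: (text: string) => Promise<number>
): Promise<string> {
  // Walk from the bottom of the template upward, mirroring the .reverse() above.
  for (const { idToReplace, content } of blocks.slice().reverse()) {
    const cost = await encoder(content)
    if (cost > unusedTokens) {
      output = output.replace(idToReplace, '') // over budget: drop the block entirely
    } else {
      output = output.replace(idToReplace, content) // fits: keep it and spend the budget
      unusedTokens -= cost
    }
  }
  return output
}
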
2 changes: 1 addition & 1 deletion srv/adapter/ooba.ts
@@ -132,7 +132,7 @@ export function getThirdPartyPayload(opts: AdapterProps, stops: string[] = []) {
model: gen.thirdPartyModel || '',
stream: opts.kind === 'summary' ? false : gen.streamResponse ?? true,
temperature: gen.temp ?? 0.5,
max_tokens: opts.chat.mode === 'adventure' ? 400 : gen.maxTokens ?? 200,
max_tokens: gen.maxTokens ?? 200,
best_of: gen.swipesPerGeneration,
n: gen.swipesPerGeneration,
prompt,
3 changes: 1 addition & 2 deletions srv/adapter/openai.ts
@@ -48,8 +48,7 @@ export const handleOAI: ModelAdapter = async function* (opts) {

const oaiModel = gen.thirdPartyModel || gen.oaiModel || defaultPresets.openai.oaiModel

const maxResponseLength =
opts.chat.mode === 'adventure' ? 400 : gen.maxTokens ?? defaultPresets.openai.maxTokens
const maxResponseLength = gen.maxTokens ?? defaultPresets.openai.maxTokens

const body: any = {
model: oaiModel,
46 changes: 1 addition & 45 deletions srv/api/chat/message.ts
@@ -1,17 +1,13 @@
import { UnwrapBody, assertValid } from '/common/valid'
import { store } from '../../db'
import { createTextStreamV2, getResponseEntities, inferenceAsync } from '../../adapter/generate'
import { createTextStreamV2, getResponseEntities } from '../../adapter/generate'
import { AppRequest, StatusError, errors, handle } from '../wrap'
import { sendGuest, sendMany, sendOne } from '../ws'
import { obtainLock, releaseLock } from './lock'
import { AppSchema } from '../../../common/types/schema'
import { v4 } from 'uuid'
import { Response } from 'express'
import { publishMany } from '../ws/handle'
import { GuidanceParams, runGuidance } from '/common/guidance/guidance-parser'
import { cyoaTemplate } from '/common/mode-templates'
import { fillPromptWithLines } from '/common/prompt'
import { getTokenCounter } from '/srv/tokenize'

type GenRequest = UnwrapBody<typeof genValidator>

@@ -324,46 +320,6 @@ export const generateMessageV2 = handle(async (req, res) => {

const actions: AppSchema.ChatAction[] = []

if (chat.mode === 'adventure') {
const lines = await fillPromptWithLines(
getTokenCounter('main', undefined),
2048,
'',
body.lines.concat(`${body.replyAs.name}: ${responseText}`)
)

const prompt = cyoaTemplate(
body.settings.service,
body.settings.service === 'openai'
? body.settings.thirdPartyModel || body.settings.oaiModel
: ''
)

const infer = async (params: GuidanceParams) => {
const res = await inferenceAsync({
prompt: params.prompt,
maxTokens: params.tokens,
stop: params.stop,
log,
service: metadata.settings.service!,
settings: metadata.settings,
user: metadata.user,
})
return res.generated
}

const { values } = await runGuidance(prompt, {
infer,
placeholders: {
history: lines.join('\n'),
user: body.impersonate?.name || body.sender.handle,
},
})
actions.push({ emote: values.emote1, action: values.action1 })
actions.push({ emote: values.emote2, action: values.action2 })
actions.push({ emote: values.emote3, action: values.action3 })
}

await releaseLock(chatId)

switch (body.kind) {
1 change: 0 additions & 1 deletion web/pages/Chat/ChatSettings.tsx
@@ -219,7 +219,6 @@ const ChatSettings: Component<{
onChange={(ev) => setMode(ev.value as any)}
items={[
{ label: 'Conversation', value: 'standard' },
{ label: 'Adventure (Experimental)', value: 'adventure' },
{ label: 'Companion', value: 'companion' },
]}
value={state.chat?.mode || 'standard'}
1 change: 0 additions & 1 deletion web/pages/Chat/CreateChatForm.tsx
@@ -239,7 +239,6 @@ const CreateChatForm: Component<{
}
items={[
{ label: 'Conversation', value: 'standard' },
{ label: 'Adventure (Experimental)', value: 'adventure' },
{ label: 'Companion', value: 'companion' },
]}
value={'standard'}
2 changes: 1 addition & 1 deletion web/shared/Card.tsx
@@ -17,7 +17,7 @@ export const Card: Component<{
ariaLabel?: string
}> = (props) => {
const cardBg = useBgStyle({
hex: props.bg ? getSettingColor(props.bg) : 'bg-700',
hex: props.bg ? props.bg : 'bg-700',
blur: false,
opacity: props.bgOpacity ?? 0.08,
})