-export type PostOptions = Partial<typeof defaultPostOptions>; -export type OpenAIPostOptions = Partial<typeof defaultOpenAIPostoptions>; +export type PostOptions = Partial<PostOptionsType>;
counter: IModelFailureCounter = { nrRetries: 0, nrFailures: 0 }; @@ -96,7 +100,7 @@ export class Model implements IModel { `templates/${this.metaInfo.systemPrompt}`, "utf8" ); - const body = { + let body = { model: this.getModelName(), messages: [ { role: "system", content: systemPrompt }, @@ -104,6 +108,13 @@ export class Model implements IModel { ], ...options, }; + if (Model.LLMORPHEUS_LLM_PROVIDER) { + const provider = Model.LLMORPHEUS_LLM_PROVIDER; + body = { + ...body, + provider: provider, + }; + } performance.mark("llm-query-start"); let res;