Skip to content

Commit

Permalink
fix: ai.prompt API now also allows the model parameter to be a string…
Browse files Browse the repository at this point in the history
… with simply the model's name

as originally intended
  • Loading branch information
chhoumann committed Jul 1, 2024
1 parent 08f2693 commit d946c49
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions src/quickAddApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export class QuickAddApi {
ai: {
prompt: async (
prompt: string,
model: Model,
model: Model | string,
settings?: Partial<{
variableName: string;
shouldAssignVariables: boolean;
Expand All @@ -131,17 +131,29 @@ export class QuickAddApi {
choiceExecutor
).format;

const modelProvider = getModelProvider(model.name);
let _model: Model;
if (typeof model === "string") {
const foundModel = getModelByName(model);
if (!foundModel) {
throw new Error(`Model '${model}' not found.`);
}

_model = foundModel;
} else {
_model = model;
}

const modelProvider = getModelProvider(_model.name);

if (!modelProvider) {
throw new Error(
`Model '${model.name}' not found in any provider`
`Model '${_model.name}' not found in any provider`
);
}

const assistantRes = await Prompt(
{
model,
model: _model,
prompt,
apiKey: modelProvider.apiKey,
modelOptions: settings?.modelOptions ?? {},
Expand Down Expand Up @@ -173,7 +185,7 @@ export class QuickAddApi {
chunkedPrompt: async (
text: string,
promptTemplate: string,
model: string,
model: Model | string,
settings?: Partial<{
variableName: string;
shouldAssignVariables: boolean;
Expand Down Expand Up @@ -201,13 +213,19 @@ export class QuickAddApi {
choiceExecutor
).format;

const _model = getModelByName(model);
let _model: Model;
if (typeof model === "string") {
const foundModel = getModelByName(model);
if (!foundModel) {
throw new Error(`Model ${model} not found.`);
}

if (!_model) {
throw new Error(`Model ${model} not found.`);
_model = foundModel;
} else {
_model = model;
}

const modelProvider = getModelProvider(model);
const modelProvider = getModelProvider(_model.name);

if (!modelProvider) {
throw new Error(
Expand Down

0 comments on commit d946c49

Please sign in to comment.