diff --git a/core/config/backend_config.go b/core/config/backend_config.go index a49792330b9f..eda663603055 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -27,9 +27,11 @@ type BackendConfig struct { Backend string `yaml:"backend"` TemplateConfig TemplateConfig `yaml:"template"` - PromptStrings, InputStrings []string `yaml:"-"` - InputToken [][]int `yaml:"-"` - functionCallString, functionCallNameString string `yaml:"-"` + PromptStrings, InputStrings []string `yaml:"-"` + InputToken [][]int `yaml:"-"` + functionCallString, functionCallNameString string `yaml:"-"` + ResponseFormat string `yaml:"-"` + ResponseFormatMap map[string]interface{} `yaml:"-"` FunctionsConfig functions.FunctionsConfig `yaml:"function"` diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index b2e7aa755aff..6b4899a51669 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -183,8 +183,13 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup noActionDescription = config.FunctionsConfig.NoActionDescriptionName } - if input.ResponseFormat.Type == "json_object" { - input.Grammar = functions.JSONBNF + if config.ResponseFormatMap != nil { + d := schema.ChatCompletionResponseFormat{} + dat, _ := json.Marshal(config.ResponseFormatMap) + _ = json.Unmarshal(dat, &d) + if d.Type == "json_object" { + input.Grammar = functions.JSONBNF + } } config.Grammar = input.Grammar diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index bcd46db55c8a..9554a2dc10e9 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -69,8 +69,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a return fmt.Errorf("failed reading parameters from request:%w", err) } - if input.ResponseFormat.Type == "json_object" { - input.Grammar = functions.JSONBNF + if config.ResponseFormatMap != nil { + d := schema.ChatCompletionResponseFormat{} + dat, _ := json.Marshal(config.ResponseFormatMap) + _ = json.Unmarshal(dat, &d) + if d.Type == "json_object" { + input.Grammar = functions.JSONBNF + } } config.Grammar = input.Grammar diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 9e806b3e51a4..9de513a42a6c 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -149,10 +149,8 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon return fmt.Errorf("invalid value for 'size'") } - b64JSON := false - if input.ResponseFormat.Type == "b64_json" { - b64JSON = true - } + b64JSON := config.ResponseFormat == "b64_json" + // src and clip_skip var result []schema.Item for _, i := range config.PromptStrings { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index d25e05b56f3c..941a66e36f80 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -129,6 +129,15 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque config.Maxtokens = input.Maxtokens } + if input.ResponseFormat != nil { + switch responseFormat := input.ResponseFormat.(type) { + case string: + config.ResponseFormat = responseFormat + case map[string]interface{}: + config.ResponseFormatMap = responseFormat + } + } + switch stop := input.Stop.(type) { case string: if stop != "" { diff --git a/core/schema/openai.go b/core/schema/openai.go index 177dc7ecc580..ec8c2c3bdb23 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -99,6 +99,8 @@ type OpenAIModel struct { Object string `json:"object"` } +type ImageGenerationResponseFormat string + type ChatCompletionResponseFormatType string type ChatCompletionResponseFormat struct { @@ -114,7 +116,7 @@ type OpenAIRequest struct { // whisper File string `json:"file" validate:"required"` //whisper/image - ResponseFormat ChatCompletionResponseFormat `json:"response_format"` + ResponseFormat interface{} `json:"response_format,omitempty"` // image Size string `json:"size"` // Prompt is read only by completion/image API calls diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go index 8fcd1280df15..670b34652826 100644 --- a/tests/e2e-aio/e2e_test.go +++ b/tests/e2e-aio/e2e_test.go @@ -123,13 +123,36 @@ var _ = Describe("E2E test", func() { openai.ImageRequest{ Prompt: "test", Size: openai.CreateImageSize512x512, - //ResponseFormat: openai.CreateImageResponseFormatURL, }, ) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL)) }) + It("correctly changes the response format to url", func() { + resp, err := client.CreateImage(context.TODO(), + openai.ImageRequest{ + Prompt: "test", + Size: openai.CreateImageSize512x512, + ResponseFormat: openai.CreateImageResponseFormatURL, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) + Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL)) + }) + It("correctly changes the response format to base64", func() { + resp, err := client.CreateImage(context.TODO(), + openai.ImageRequest{ + Prompt: "test", + Size: openai.CreateImageSize512x512, + ResponseFormat: openai.CreateImageResponseFormatB64JSON, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) + Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON)) + }) }) Context("embeddings", func() { It("correctly", func() {