Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rf: centralize base64 image handling and secscan cleanup #2595

Merged
merged 11 commits into from
Jun 24, 2024
2 changes: 1 addition & 1 deletion backend/go/llm/rwkv/rwkv.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))

if model == nil {
return fmt.Errorf("could not load model")
return fmt.Errorf("rwkv could not load model")
}
llm.rwkv = model
return nil
Expand Down
2 changes: 1 addition & 1 deletion core/http/endpoints/openai/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
}
}

return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find "))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find %q", assistantID))
}
}

Expand Down
12 changes: 7 additions & 5 deletions core/http/endpoints/openai/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -183,7 +182,7 @@ func TestAssistantEndpoints(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tt.expectedStatus, response.StatusCode)
if tt.expectedStatus != fiber.StatusOK {
all, _ := ioutil.ReadAll(response.Body)
all, _ := io.ReadAll(response.Body)
assert.Equal(t, tt.expectedStringResult, string(all))
} else {
var result []Assistant
Expand Down Expand Up @@ -279,6 +278,7 @@ func TestAssistantEndpoints(t *testing.T) {
assert.NoError(t, err)
var getAssistant Assistant
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
assert.NoError(t, err)

t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))

Expand Down Expand Up @@ -391,7 +391,10 @@ func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId s
}

var assistantFile AssistantFile
all, err := ioutil.ReadAll(resp.Body)
all, err := io.ReadAll(resp.Body)
if err != nil {
return AssistantFile{}, resp, err
}
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
if err != nil {
return AssistantFile{}, resp, err
Expand Down Expand Up @@ -422,8 +425,7 @@ func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Resp

var resultAssistant Assistant
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)

return resultAssistant, resp, nil
return resultAssistant, resp, err
}

func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
Expand Down
44 changes: 3 additions & 41 deletions core/http/endpoints/openai/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@ package openai

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
Expand All @@ -39,41 +36,6 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
return modelFile, input, err
}

// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()

// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)

// return the base64 string
return encoded, nil
}

// if the string instead is prefixed with "data:image/...;base64,", drop it
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
for _, prefix := range dropPrefix {
if strings.HasPrefix(s, prefix) {
return strings.ReplaceAll(s, prefix, ""), nil
}
}

return "", fmt.Errorf("not valid string")
}

func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
Expand Down Expand Up @@ -187,7 +149,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL.URL)
base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
Expand Down
5 changes: 2 additions & 3 deletions core/http/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ func notFoundHandler(c *fiber.Ctx) error {
// Check if the request accepts JSON
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
})
} else {
// The client expects an HTML response
c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{})
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{})
}
return nil
}

func renderEngine() *fiberhtml.Engine {
Expand Down
5 changes: 4 additions & 1 deletion core/startup/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

if options.LibPath != "" {
// If there is a lib directory, set LD_LIBRARY_PATH to include it
library.LoadExternal(options.LibPath)
err := library.LoadExternal(options.LibPath)
if err != nil {
log.Error().Err(err).Str("LibPath", options.LibPath).Msg("Error while loading external libraries")
}
}

// turn off any process that was started by GRPC if the context is canceled
Expand Down
19 changes: 13 additions & 6 deletions pkg/library/dynaload.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package library

import (
"errors"
"fmt"
"os"
"path/filepath"
Expand All @@ -17,14 +18,17 @@ import (
var skipLibraryPath = os.Getenv("LOCALAI_SKIP_LIBRARY_PATH") != ""

// LoadExtractedLibs loads the extracted libraries from the asset dir
func LoadExtractedLibs(dir string) {
func LoadExtractedLibs(dir string) error {
// Skip this if LOCALAI_SKIP_LIBRARY_PATH is set
if skipLibraryPath {
return
return nil
}

var err error = nil
for _, libDir := range []string{filepath.Join(dir, "backend-assets", "lib"), filepath.Join(dir, "lib")} {
LoadExternal(libDir)
err = errors.Join(err, LoadExternal(libDir))
}
return err
}

// LoadLDSO checks if there is a ld.so in the asset dir and if so, prefixes the grpc process with it.
Expand Down Expand Up @@ -57,23 +61,26 @@ func LoadLDSO(assetDir string, args []string, grpcProcess string) ([]string, str
}

// LoadExternal sets the LD_LIBRARY_PATH to include the given directory
func LoadExternal(dir string) {
func LoadExternal(dir string) error {
// Skip this if LOCALAI_SKIP_LIBRARY_PATH is set
if skipLibraryPath {
return
return nil
}

lpathVar := "LD_LIBRARY_PATH"
if runtime.GOOS == "darwin" {
lpathVar = "DYLD_FALLBACK_LIBRARY_PATH" // should it be DYLD_LIBRARY_PATH ?
}

var setErr error = nil
if _, err := os.Stat(dir); err == nil {
ldLibraryPath := os.Getenv(lpathVar)
if ldLibraryPath == "" {
ldLibraryPath = dir
} else {
ldLibraryPath = fmt.Sprintf("%s:%s", ldLibraryPath, dir)
}
os.Setenv(lpathVar, ldLibraryPath)
setErr = errors.Join(setErr, os.Setenv(lpathVar, ldLibraryPath))
}
return setErr
}
21 changes: 10 additions & 11 deletions pkg/model/loader_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package model_test

import (
"github.com/go-skynet/LocalAI/pkg/model"
. "github.com/go-skynet/LocalAI/pkg/model"

. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -44,7 +43,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
"user": {
"template": llama3,
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "user",
RoleName: "user",
Expand All @@ -59,7 +58,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
"assistant": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Expand All @@ -74,7 +73,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
"function_call": {
"template": llama3,
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Expand All @@ -89,7 +88,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
"function_response": {
"template": llama3,
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "tool",
RoleName: "tool",
Expand All @@ -107,7 +106,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
"user": {
"template": chatML,
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "user",
RoleName: "user",
Expand All @@ -122,7 +121,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
"assistant": {
"template": chatML,
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Expand All @@ -137,7 +136,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
"function_call": {
"template": chatML,
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "assistant",
RoleName: "assistant",
Expand All @@ -152,7 +151,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
"function_response": {
"template": chatML,
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
"data": model.ChatMessageTemplateData{
"data": ChatMessageTemplateData{
SystemPrompt: "",
Role: "tool",
RoleName: "tool",
Expand All @@ -175,7 +174,7 @@ var _ = Describe("Templates", func() {
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData))
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated)
})
Expand All @@ -189,7 +188,7 @@ var _ = Describe("Templates", func() {
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData))
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
Expect(err).ToNot(HaveOccurred())
Expect(templated).To(Equal(foo["expected"]), templated)
})
Expand Down
5 changes: 4 additions & 1 deletion pkg/model/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
grpcControlProcess.Stop()
err := grpcControlProcess.Stop()
if err != nil {
log.Error().Err(err).Msg("error while shutting down grpc process")
}
}()

go func() {
Expand Down
11 changes: 7 additions & 4 deletions pkg/utils/base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ func GetImageURLAsBase64(s string) (string, error) {
return encoded, nil
}

// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
// if the string instead is prefixed with "data:image/...;base64,", drop it
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
for _, prefix := range dropPrefix {
if strings.HasPrefix(s, prefix) {
return strings.ReplaceAll(s, prefix, ""), nil
}
}
return "", fmt.Errorf("not valid string")
}
}
9 changes: 8 additions & 1 deletion pkg/utils/base64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@ import (
)

var _ = Describe("utils/base64 tests", func() {
It("GetImageURLAsBase64 can strip data url prefixes", func() {
It("GetImageURLAsBase64 can strip jpeg data url prefixes", func() {
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
input := ""
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).To(Equal("FOO"))
})
It("GetImageURLAsBase64 can strip png data url prefixes", func() {
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
input := ""
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).To(Equal("BAR"))
})
It("GetImageURLAsBase64 returns an error for bogus data", func() {
input := "FOO"
b64, err := GetImageURLAsBase64(input)
Expand Down
23 changes: 0 additions & 23 deletions tests/integration/reflect_test.go

This file was deleted.

Loading