Skip to content

Commit

Permalink
support base64 image
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Nov 19, 2023
1 parent 6e670c0 commit 57d0fc3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 32 deletions.
64 changes: 64 additions & 0 deletions common/image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package common

import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"github.com/chai2010/webp"
"image"
"io"
"net/http"
"strings"
)

func DecodeBase64ImageData(base64String string) (image.Config, error) {
// 去除base64数据的URL前缀(如果有)
if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:]
}

// 将base64字符串解码为字节切片
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, err
}

// 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData)
config, err := getImageConfig(reader)
return config, err
}

func DecodeUrlImageData(imageUrl string) (image.Config, error) {
response, err := http.Get(imageUrl)
if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, err
}

// 限制读取的字节数,防止下载整个图片
limitReader := io.LimitReader(response.Body, 8192)
config, err := getImageConfig(limitReader)
response.Body.Close()
return config, err
}

func getImageConfig(reader io.Reader) (image.Config, error) {
// 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
SysLog(err.Error())
config, err = webp.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
SysLog(err.Error())
}
}
if err != nil {
return image.Config{}, err
}
return config, nil
}
18 changes: 9 additions & 9 deletions controller/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ import (
func UpdateMidjourneyTask() {
//revocer
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
for {
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
Expand Down Expand Up @@ -55,7 +55,6 @@ func UpdateMidjourneyTask() {
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
Expand All @@ -68,8 +67,8 @@ func UpdateMidjourneyTask() {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
Expand All @@ -83,12 +82,12 @@ func UpdateMidjourneyTask() {
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Fatalf("UpdateMidjourneyTask error1: %v", err3)
log.Printf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Fatalf("UpdateMidjourneyTask error2: %v", err4)
log.Printf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
Expand Down Expand Up @@ -138,6 +137,7 @@ func UpdateMidjourneyTask() {
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
cancel()
}
}
}
Expand Down
29 changes: 10 additions & 19 deletions controller/relay-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/chai2010/webp"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"image"
Expand Down Expand Up @@ -75,29 +74,21 @@ func getImageToken(imageUrl MessageImageUrl) (int, error) {
if imageUrl.Detail == "low" {
return 85, nil
}

response, err := http.Get(imageUrl.Url)
var config image.Config
var err error
if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, err = common.DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, err = common.DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
fmt.Println("Error: Failed to get the URL")
return 0, err
}

// 限制读取的字节数,防止下载整个图片
limitReader := io.LimitReader(response.Body, 8192)

response.Body.Close()

// 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(limitReader)
if err != nil {
common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
config, err = webp.DecodeConfig(limitReader)
if err != nil {
common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
}
}
if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
}
if config.Width < 512 && config.Height < 512 {
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
Expand Down
11 changes: 7 additions & 4 deletions web/src/components/LogsTable.js
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ const LogsTable = () => {
return (
record.type === 0 || record.type === 2 ?
<div>
<Tag color='grey' size='large' onClick={()=>{
<Tag color='grey' size='large' onClick={() => {
copyText(text)
}}> {text} </Tag>
</div>
Expand All @@ -133,7 +133,7 @@ const LogsTable = () => {
return (
record.type === 0 || record.type === 2 ?
<div>
<Tag color={stringToColor(text)} size='large' onClick={()=>{
<Tag color={stringToColor(text)} size='large' onClick={() => {
copyText(text)
}}> {text} </Tag>
</div>
Expand Down Expand Up @@ -202,11 +202,12 @@ const LogsTable = () => {
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
let now = new Date();
// 初始化start_timestamp为前一天
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(0),
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
});
Expand Down Expand Up @@ -338,7 +339,7 @@ const LogsTable = () => {
showSuccess('已复制:' + text);
} else {
// setSearchKeyword(text);
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
}
}

Expand Down Expand Up @@ -412,10 +413,12 @@ const LogsTable = () => {
name='model_name'
onChange={value => handleInputChange(value, 'model_name')}/>
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
initValue={start_timestamp}
value={start_timestamp} type='dateTime'
name='start_timestamp'
onChange={value => handleInputChange(value, 'start_timestamp')}/>
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
initValue={end_timestamp}
value={end_timestamp} type='dateTime'
name='end_timestamp'
onChange={value => handleInputChange(value, 'end_timestamp')}/>
Expand Down

0 comments on commit 57d0fc3

Please sign in to comment.