Merge pull request #311 from Lumerin-protocol/feat/prodia-video
feat: prodia video generation
alex-sandrk authored Dec 2, 2024
2 parents f470309 + 2f67891 commit 63dce8d
Showing 15 changed files with 267 additions and 54 deletions.
8 changes: 7 additions & 1 deletion docs/models-config.json.md
@@ -2,7 +2,7 @@

 - `modelId` (required) is the model id
 - `modelName` (required) is the name of the model
-- `apiType` (required) is the type of the model api. Currently supported values are "prodia" and "openai"
+- `apiType` (required) is the type of the model api. Currently supported values are "prodia-sd", "prodia-sdxl", "prodia-v2" and "openai"
 - `apiUrl` (required) is the url of the LLM server or model API
 - `apiKey` (optional) is the api key for the model
 - `concurrentSlots` (optional) are number of available distinct chats on the llm server and used for capacity policy
@@ -66,6 +66,12 @@ This file enables the morpheus-proxy-router to route requests to the correct mod
     "apiUrl": "http://llmserver.domain.io:8080/v1",
     "concurrentSlots": 8,
     "capacityPolicy": "simple"
   },
+  "0xe086adc275c99e32bb10b0aff5e8bfc391aad18cbb184727a75b2569149425c6": {
+    "apiUrl": "https://inference.prodia.com/v2",
+    "modelName": "inference.mochi1.txt2vid.v1",
+    "apiType": "prodia-v2",
+    "apiKey": "replace-with-your-api-key"
+  }
 }
 ```
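
For illustration, here is a minimal sketch of reading an entry like the one above, assuming a hypothetical `ModelConfig` struct. The JSON keys are the documented ones; the Go type itself is not the router's actual internal config type:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// ModelConfig mirrors the documented fields; this struct is illustrative,
// not the proxy-router's real internal type.
type ModelConfig struct {
	ModelName       string `json:"modelName"`
	ApiType         string `json:"apiType"`
	ApiURL          string `json:"apiUrl"`
	ApiKey          string `json:"apiKey"`
	ConcurrentSlots int    `json:"concurrentSlots"`
	CapacityPolicy  string `json:"capacityPolicy"`
}

func main() {
	raw := `{
	  "0xe086adc275c99e32bb10b0aff5e8bfc391aad18cbb184727a75b2569149425c6": {
	    "apiUrl": "https://inference.prodia.com/v2",
	    "modelName": "inference.mochi1.txt2vid.v1",
	    "apiType": "prodia-v2",
	    "apiKey": "replace-with-your-api-key"
	  }
	}`

	// Keys are model IDs; values describe how to reach the model API.
	var cfg map[string]ModelConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err)
	}
	for id, m := range cfg {
		fmt.Printf("%s -> %s via %s\n", id, m.ModelName, m.ApiType)
	}
}
```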
10 changes: 9 additions & 1 deletion proxy-router/internal/aiengine/ai_engine.go
@@ -55,7 +55,15 @@ func (a *AiEngine) GetAdapter(ctx context.Context, chatID, modelID, sessionID co
 	}

 	if storeChatContext {
-		engine = NewHistory(engine, a.storage, chatID, modelID, forwardChatContext, a.log)
+		actualModelID := modelID
+		if modelID == (common.Hash{}) {
+			modelID, err := a.service.GetModelIdSession(ctx, sessionID)
+			if err != nil {
+				return nil, err
+			}
+			actualModelID = modelID
+		}
+		engine = NewHistory(engine, a.storage, chatID, actualModelID, forwardChatContext, a.log)
 	}

 	return engine, nil
2 changes: 2 additions & 0 deletions proxy-router/internal/aiengine/factory.go
@@ -10,6 +10,8 @@ func ApiAdapterFactory(apiType string, modelName string, url string, apikey stri
 		return NewProdiaSDEngine(modelName, url, apikey, log), true
 	case API_TYPE_PRODIA_SDXL:
 		return NewProdiaSDXLEngine(modelName, url, apikey, log), true
+	case API_TYPE_PRODIA_V2:
+		return NewProdiaV2Engine(modelName, url, apikey, log), true
 	}
 	return nil, false
 }
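
The factory keys construction off the configured `apiType` and reports unknown types through an `ok` flag. A self-contained sketch of the same dispatch shape, using stand-in types rather than the real `aiengine` ones:

```go
package main

import "fmt"

// Engine is a stand-in for the real adapter interface.
type Engine interface{ ApiType() string }

type prodiaV2Stub struct{}

func (prodiaV2Stub) ApiType() string { return "prodia-v2" }

// adapterFor mirrors ApiAdapterFactory's switch-and-ok-flag shape.
func adapterFor(apiType string) (Engine, bool) {
	switch apiType {
	case "prodia-v2":
		return prodiaV2Stub{}, true
	}
	return nil, false // unknown apiType: the caller must check the flag
}

func main() {
	if eng, ok := adapterFor("prodia-v2"); ok {
		fmt.Println("resolved adapter:", eng.ApiType())
	}
}
```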
108 changes: 108 additions & 0 deletions proxy-router/internal/aiengine/prodia_v2.go
@@ -0,0 +1,108 @@
package aiengine

import (
	"bytes"
	"context"
	b64 "encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"

	c "github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal"
	gcs "github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal/chatstorage/genericchatstorage"
	"github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal/lib"
	"github.com/sashabaranov/go-openai"
)

const API_TYPE_PRODIA_V2 = "prodia-v2"
const PRODIA_V2_DEFAULT_BASE_URL = "https://inference.prodia.com/v2"

var (
	ErrCapacity               = errors.New("unable to schedule job with current token")
	ErrBadResponse            = errors.New("bad response")
	ErrVideoGenerationRequest = errors.New("video generation error")
)

type ProdiaV2 struct {
	modelName string
	apiURL    string
	apiKey    string

	log lib.ILogger
}

func NewProdiaV2Engine(modelName, apiURL, apiKey string, log lib.ILogger) *ProdiaV2 {
	if apiURL == "" {
		apiURL = PRODIA_V2_DEFAULT_BASE_URL
	}
	return &ProdiaV2{
		modelName: modelName,
		apiURL:    apiURL,
		apiKey:    apiKey,
		log:       log,
	}
}

func (s *ProdiaV2) Prompt(ctx context.Context, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) error {
	body := map[string]interface{}{
		"type": s.modelName,
		"config": map[string]string{
			"prompt": prompt.Messages[len(prompt.Messages)-1].Content,
		},
	}

	payload, err := json.Marshal(body)
	if err != nil {
		err = lib.WrapError(ErrImageGenerationInvalidRequest, err)
		s.log.Error(err)
		return err
	}

	req, err := http.NewRequest("POST", fmt.Sprintf("%s/job", s.apiURL), bytes.NewReader(payload))
	if err != nil {
		err = lib.WrapError(ErrImageGenerationRequest, err)
		s.log.Error(err)
		return err
	}

	req.Header.Add(c.HEADER_ACCEPT, c.CONTENT_TYPE_VIDEO_MP4)
	req.Header.Add(c.HEADER_CONTENT_TYPE, c.CONTENT_TYPE_JSON)
	req.Header.Add(c.HEADER_AUTHORIZATION, fmt.Sprintf("Bearer %s", s.apiKey))

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		err = lib.WrapError(ErrImageGenerationRequest, err)
		s.log.Error(err)
		return err
	}
	defer res.Body.Close()

	if res.StatusCode == http.StatusTooManyRequests {
		return ErrCapacity
	} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		return lib.WrapError(ErrBadResponse, fmt.Errorf("status code: %d", res.StatusCode))
	}

	response, err := io.ReadAll(res.Body)
	if err != nil {
		err = lib.WrapError(ErrVideoGenerationRequest, err)
		s.log.Error(err)
		return err
	}

	sEnc := b64.StdEncoding.EncodeToString(response)

	dataPrefix := "data:video/mp4;base64,"
	chunk := gcs.NewChunkVideo(&gcs.VideoGenerationResult{
		VideoRawContent: dataPrefix + sEnc,
	})

	return cb(ctx, chunk)
}

func (s *ProdiaV2) ApiType() string {
	return API_TYPE_PRODIA_V2
}

var _ AIEngineStream = &ProdiaV2{}
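
A sketch of driving this engine end to end, assuming `gcs.CompletionCallback` is `func(ctx context.Context, chunk gcs.Chunk) error` (consistent with the `cb(ctx, chunk)` call above) and that a `lib.ILogger` is in scope; imports (`context`, `strings`, `encoding/base64`, `os`) are elided:

```go
// generateVideoToFile is a hypothetical caller inside the aiengine package.
func generateVideoToFile(ctx context.Context, log lib.ILogger) error {
	engine := NewProdiaV2Engine("inference.mochi1.txt2vid.v1", "", "replace-with-your-api-key", log)

	prompt := &openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "a slow pan across a foggy forest at dawn"},
		},
	}

	return engine.Prompt(ctx, prompt, func(ctx context.Context, chunk gcs.Chunk) error {
		// Prompt delivers a single non-streaming chunk whose String() is a
		// data URI; strip the prefix and decode the mp4 bytes.
		raw := strings.TrimPrefix(chunk.String(), "data:video/mp4;base64,")
		data, err := base64.StdEncoding.DecodeString(raw)
		if err != nil {
			return err
		}
		return os.WriteFile("out.mp4", data, 0o644)
	})
}
```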
1 change: 1 addition & 0 deletions proxy-router/internal/aiengine/remote.go
@@ -15,6 +15,7 @@ type RemoteModel struct {

 type ProxyService interface {
 	SendPromptV2(ctx context.Context, sessionID common.Hash, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) (interface{}, error)
+	GetModelIdSession(ctx context.Context, sessionID common.Hash) (common.Hash, error)
 }

 func (p *RemoteModel) Prompt(ctx context.Context, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) error {
13 changes: 8 additions & 5 deletions proxy-router/internal/chatstorage/file_chat_storage.go
@@ -76,16 +76,19 @@ func (cs *ChatStorage) StorePromptResponseToFile(identifier string, isLocal bool
 	}

 	isImageContent := false
+	isVideoRawContent := false
 	if len(responses) > 0 {
 		isImageContent = responses[0].Type() == gcs.ChunkTypeImage
+		isVideoRawContent = responses[0].Type() == gcs.ChunkTypeVideo
 	}

 	newEntry := gcs.ChatMessage{
-		Prompt:         p,
-		Response:       strings.Join(resps, ""),
-		PromptAt:       promptAt.Unix(),
-		ResponseAt:     responseAt.Unix(),
-		IsImageContent: isImageContent,
+		Prompt:            p,
+		Response:          strings.Join(resps, ""),
+		PromptAt:          promptAt.Unix(),
+		ResponseAt:        responseAt.Unix(),
+		IsImageContent:    isImageContent,
+		IsVideoRawContent: isVideoRawContent,
 	}

 	if chatHistory.Messages == nil && len(chatHistory.Messages) == 0 {
@@ -139,3 +139,9 @@ type ImageGenerationResult struct {
 }

 type ImageGenerationCallback func(completion *ImageGenerationResult) error
+
+type VideoGenerationResult struct {
+	VideoRawContent string `json:"videoRawContent"`
+}
+
+type VideoGenerationCallback func(completion *VideoGenerationResult) error
31 changes: 31 additions & 0 deletions proxy-router/internal/chatstorage/genericchatstorage/completion.go
@@ -139,6 +139,36 @@ func (c *ChunkImage) Data() interface{} {
 	return c.data
 }

+type ChunkVideo struct {
+	data *VideoGenerationResult
+}
+
+func NewChunkVideo(data *VideoGenerationResult) *ChunkVideo {
+	return &ChunkVideo{
+		data: data,
+	}
+}
+
+func (c *ChunkVideo) IsStreaming() bool {
+	return false
+}
+
+func (c *ChunkVideo) Tokens() int {
+	return 1
+}
+
+func (c *ChunkVideo) Type() ChunkType {
+	return ChunkTypeVideo
+}
+
+func (c *ChunkVideo) String() string {
+	return c.data.VideoRawContent
+}
+
+func (c *ChunkVideo) Data() interface{} {
+	return c.data
+}
+
 type Chunk interface {
 	IsStreaming() bool
 	Tokens() int
@@ -151,3 +181,4 @@ var _ Chunk = &ChunkText{}
 var _ Chunk = &ChunkImage{}
 var _ Chunk = &ChunkControl{}
 var _ Chunk = &ChunkStreaming{}
+var _ Chunk = &ChunkVideo{}
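
A consumer can branch on `Type()` and downcast `Data()`. A sketch (`fmt` import elided), assuming `ImageGenerationResult` exposes the `ImageUrl` field referenced by the proxy sender later in this diff:

```go
// handleChunk is a hypothetical consumer of the Chunk interface.
func handleChunk(ch Chunk) {
	switch ch.Type() {
	case ChunkTypeVideo:
		res := ch.Data().(*VideoGenerationResult)
		// VideoRawContent is a full "data:video/mp4;base64,..." URI.
		fmt.Println("video data URI length:", len(res.VideoRawContent))
	case ChunkTypeImage:
		res := ch.Data().(*ImageGenerationResult)
		fmt.Println("image url:", res.ImageUrl)
	default:
		fmt.Println("text:", ch.String())
	}
}
```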
@@ -47,12 +47,14 @@ func (h *ChatHistory) AppendChatHistory(req *openai.ChatCompletionRequest) *open
 }

 type ChatMessage struct {
-	Prompt         OpenAiCompletionRequest `json:"prompt"`
-	Response       string                  `json:"response"`
-	PromptAt       int64                   `json:"promptAt"`
-	ResponseAt     int64                   `json:"responseAt"`
-	IsImageContent bool                    `json:"isImageContent"`
+	Prompt            OpenAiCompletionRequest `json:"prompt"`
+	Response          string                  `json:"response"`
+	PromptAt          int64                   `json:"promptAt"`
+	ResponseAt        int64                   `json:"responseAt"`
+	IsImageContent    bool                    `json:"isImageContent"`
+	IsVideoRawContent bool                    `json:"isVideoRawContent"`
 }

 type Chat struct {
 	ChatID  string `json:"chatId"`
 	ModelID string `json:"modelId"`
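
The new flag rides along in the persisted JSON. A quick sketch of the serialized shape; the timestamps are illustrative and the zero-value prompt is abbreviated as `{...}` in the output comment:

```go
msg := ChatMessage{
	Response:          "data:video/mp4;base64,AAAA...",
	PromptAt:          1733126400,
	ResponseAt:        1733126460,
	IsVideoRawContent: true,
}
out, _ := json.Marshal(msg)
fmt.Println(string(out))
// => {"prompt":{...},"response":"data:video/mp4;base64,AAAA...",
//     "promptAt":1733126400,"responseAt":1733126460,
//     "isImageContent":false,"isVideoRawContent":true}
```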
2 changes: 1 addition & 1 deletion proxy-router/internal/config/models-config-schema.json
@@ -24,7 +24,7 @@
       "title": "API Type",
       "description": "Defines the type of API to be used with this model",
       "type": "string",
-      "enum": ["openai", "prodia-sd", "prodia-sdxl"]
+      "enum": ["openai", "prodia-sd", "prodia-sdxl", "prodia-v2"]
     },
     "apiUrl": {
       "title": "API URL",
1 change: 1 addition & 0 deletions proxy-router/internal/constants.go
@@ -5,6 +5,7 @@ const (

 	CONTENT_TYPE_JSON         = "application/json"
 	CONTENT_TYPE_EVENT_STREAM = "text/event-stream"
+	CONTENT_TYPE_VIDEO_MP4    = "video/mp4"

 	CONNECTION_KEEP_ALIVE = "keep-alive"

29 changes: 23 additions & 6 deletions proxy-router/internal/proxyapi/proxy_sender.go
@@ -285,6 +285,14 @@ func (p *ProxyServiceSender) validateMsgSignature(result any, signature lib.HexS
 	return p.morRPC.VerifySignature(result, signature, providerPubicKey, p.log)
 }

+func (p *ProxyServiceSender) GetModelIdSession(ctx context.Context, sessionID common.Hash) (common.Hash, error) {
+	session, err := p.sessionRepo.GetSession(ctx, sessionID)
+	if err != nil {
+		return common.Hash{}, ErrSessionNotFound
+	}
+	return session.ModelID(), nil
+}
+
 func (p *ProxyServiceSender) SendPromptV2(ctx context.Context, sessionID common.Hash, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) (interface{}, error) {
 	session, err := p.sessionRepo.GetSession(ctx, sessionID)
 	if err != nil {
@@ -378,7 +386,7 @@ func (p *ProxyServiceSender) rpcRequestStreamV2(
 ) (interface{}, int, int, error) {
 	const (
 		TIMEOUT_TO_ESTABLISH_CONNECTION   = time.Second * 3
-		TIMEOUT_TO_RECEIVE_FIRST_RESPONSE = time.Second * 5
+		TIMEOUT_TO_RECEIVE_FIRST_RESPONSE = time.Second * 30
 		MAX_RETRIES                       = 5
 	)

@@ -522,12 +530,21 @@
 			} else {
 				var imageGenerationResult gcs.ImageGenerationResult
 				err = json.Unmarshal(aiResponse, &imageGenerationResult)
-				if err != nil {
-					return nil, ttftMs, totalTokens, lib.WrapError(ErrInvalidResponse, err)
+				if err == nil && imageGenerationResult.ImageUrl != "" {
+					totalTokens += 1
+					responses = append(responses, imageGenerationResult)
+					chunk = gcs.NewChunkImage(&imageGenerationResult)
+				} else {
+					var videoGenerationResult gcs.VideoGenerationResult
+					err = json.Unmarshal(aiResponse, &videoGenerationResult)
+					if err == nil && videoGenerationResult.VideoRawContent != "" {
+						totalTokens += 1
+						responses = append(responses, videoGenerationResult)
+						chunk = gcs.NewChunkVideo(&videoGenerationResult)
+					} else {
+						return nil, ttftMs, totalTokens, lib.WrapError(ErrInvalidResponse, err)
+					}
 				}
-				totalTokens += 1
-				responses = append(responses, imageGenerationResult)
-				chunk = gcs.NewChunkImage(&imageGenerationResult)
 			}

 			if ctx.Err() != nil {
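
The new branch probes the payload shape in order: image first (non-empty `ImageUrl`), then video (non-empty `VideoRawContent`), and errors only when neither matches. A standalone distillation of that step (`encoding/json` and `errors` imports elided); the `imageUrl` JSON tag is an assumption, while `videoRawContent` matches the struct added earlier in this diff:

```go
// decodeMediaResponse is a hypothetical distillation of the branch above.
func decodeMediaResponse(aiResponse []byte) (interface{}, error) {
	var img struct {
		ImageUrl string `json:"imageUrl"` // tag assumed; not shown in this diff
	}
	if err := json.Unmarshal(aiResponse, &img); err == nil && img.ImageUrl != "" {
		return img, nil
	}
	var vid struct {
		VideoRawContent string `json:"videoRawContent"`
	}
	if err := json.Unmarshal(aiResponse, &vid); err == nil && vid.VideoRawContent != "" {
		return vid, nil
	}
	return nil, errors.New("response is neither an image nor a video result")
}
```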
15 changes: 13 additions & 2 deletions ui-desktop/src/renderer/src/components/chat/Chat.styles.tsx
@@ -94,7 +94,7 @@ export const Avatar = styled.div`
 `

 export const AvatarHeader = styled.div`
-  color: ${p => p.theme.colors.morMain}
+  color: ${p => p.theme.colors.morMain};
   font-weight: 900;
   padding: 0 8px;
   font-size: 18px;
@@ -118,7 +118,7 @@ export const MessageBody = styled.div`
 `

 export const ChatTitleContainer = styled.div`
-  color: ${p => p.theme.colors.morMain}
+  color: ${p => p.theme.colors.morMain};
   font-weight: 900;
   padding: 0 8px;
   font-size: 18px;
@@ -221,6 +221,17 @@ export const ImageContainer = styled.img`
   }
 `

+export const VideoContainer = styled.div`
+  cursor: pointer;
+  padding: 0.25rem;
+  max-width: 100%;
+  height: 256px;
+  @media (min-height: 700px) {
+    height: 320px;
+  }
+`
+
 export const SubPriceLabel = styled.span`
   color: ${p => p.theme.colors.morMain};
 `