From 862f2ae66de4fb058b4ac75385092f84c799a6a5 Mon Sep 17 00:00:00 2001
From: Aleksandr Kukharenko
Date: Mon, 2 Dec 2024 13:46:22 +0200
Subject: [PATCH 1/2] feat: prodia video generation

---
 proxy-router/internal/aiengine/ai_engine.go    |  10 +-
 proxy-router/internal/aiengine/factory.go      |   2 +
 proxy-router/internal/aiengine/prodia_v2.go    | 108 ++++++++++++++++++
 proxy-router/internal/aiengine/remote.go       |   1 +
 .../internal/chatstorage/file_chat_storage.go  |  13 ++-
 .../genericchatstorage/chat_responses.go       |   6 +
 .../genericchatstorage/completion.go           |  31 +++++
 .../genericchatstorage/interface.go            |  12 +-
 proxy-router/internal/constants.go             |   1 +
 .../internal/proxyapi/proxy_sender.go          |  29 ++++-
 .../src/components/chat/Chat.styles.tsx       |  15 ++-
 .../src/renderer/src/components/chat/Chat.tsx  |  81 +++++++------
 .../src/components/chat/interfaces.tsx         |   2 +
 13 files changed, 259 insertions(+), 52 deletions(-)
 create mode 100644 proxy-router/internal/aiengine/prodia_v2.go

diff --git a/proxy-router/internal/aiengine/ai_engine.go b/proxy-router/internal/aiengine/ai_engine.go
index c66f595c..886e41c3 100644
--- a/proxy-router/internal/aiengine/ai_engine.go
+++ b/proxy-router/internal/aiengine/ai_engine.go
@@ -55,7 +55,15 @@ func (a *AiEngine) GetAdapter(ctx context.Context, chatID, modelID, sessionID co
 	}
 
 	if storeChatContext {
-		engine = NewHistory(engine, a.storage, chatID, modelID, forwardChatContext, a.log)
+		var actualModelID common.Hash
+		if modelID == (common.Hash{}) {
+			modelID, err := a.service.GetModelIdSession(ctx, sessionID)
+			if err != nil {
+				return nil, err
+			}
+			actualModelID = modelID
+		}
+		engine = NewHistory(engine, a.storage, chatID, actualModelID, forwardChatContext, a.log)
 	}
 
 	return engine, nil
diff --git a/proxy-router/internal/aiengine/factory.go b/proxy-router/internal/aiengine/factory.go
index 820f82e3..b9edd31e 100644
--- a/proxy-router/internal/aiengine/factory.go
+++ b/proxy-router/internal/aiengine/factory.go
@@ -10,6 +10,8 @@ func ApiAdapterFactory(apiType string, modelName string, url string, apikey stri
 		return NewProdiaSDEngine(modelName, url, apikey, log), true
 	case API_TYPE_PRODIA_SDXL:
 		return NewProdiaSDXLEngine(modelName, url, apikey, log), true
+	case API_TYPE_PRODIA_V2:
+		return NewProdiaV2Engine(modelName, url, apikey, log), true
 	}
 	return nil, false
 }
diff --git a/proxy-router/internal/aiengine/prodia_v2.go b/proxy-router/internal/aiengine/prodia_v2.go
new file mode 100644
index 00000000..9606b63f
--- /dev/null
+++ b/proxy-router/internal/aiengine/prodia_v2.go
@@ -0,0 +1,108 @@
+package aiengine
+
+import (
+	"bytes"
+	"context"
+	b64 "encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+
+	c "github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal"
+	gcs "github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal/chatstorage/genericchatstorage"
+	"github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal/lib"
+	"github.com/sashabaranov/go-openai"
+)
+
+const API_TYPE_PRODIA_V2 = "prodia-v2"
+const PRODIA_V2_DEFAULT_BASE_URL = "https://inference.prodia.com/v2"
+
+var (
+	ErrCapacity               = errors.New("unable to schedule job with current token")
+	ErrBadResponse            = errors.New("bad response")
+	ErrVideoGenerationRequest = errors.New("video generation error")
+)
+
+type ProdiaV2 struct {
+	modelName string
+	apiURL    string
+	apiKey    string
+
+	log lib.ILogger
+}
+
+func NewProdiaV2Engine(modelName, apiURL, apiKey string, log lib.ILogger) *ProdiaV2 {
+	if apiURL == "" {
+		apiURL = PRODIA_V2_DEFAULT_BASE_URL
+	}
+	return &ProdiaV2{
+		modelName: modelName,
+		apiURL:    apiURL,
+		apiKey:    apiKey,
+		log:       log,
+	}
+}
+
+func (s *ProdiaV2) Prompt(ctx context.Context, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) error {
+	body := map[string]interface{}{
+		"type": s.modelName,
+		"config": map[string]string{
+			"prompt": prompt.Messages[len(prompt.Messages)-1].Content,
+		},
+	}
+
+	payload, err := json.Marshal(body)
+	if err != nil {
+		err = lib.WrapError(ErrImageGenerationInvalidRequest, err)
+		s.log.Error(err)
+		return err
+	}
+
+	req, err := http.NewRequest("POST", fmt.Sprintf("%s/job", s.apiURL), bytes.NewReader(payload))
+	if err != nil {
+		err = lib.WrapError(ErrImageGenerationRequest, err)
+		s.log.Error(err)
+	}
+
+	req.Header.Add(c.HEADER_ACCEPT, c.CONTENT_TYPE_VIDEO_MP4)
+	req.Header.Add(c.HEADER_CONTENT_TYPE, c.CONTENT_TYPE_JSON)
+	req.Header.Add(c.HEADER_AUTHORIZATION, fmt.Sprintf("Bearer %s", s.apiKey))
+
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		err = lib.WrapError(ErrImageGenerationRequest, err)
+		s.log.Error(err)
+		return err
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode == http.StatusTooManyRequests {
+		return ErrCapacity
+	} else if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
+		return lib.WrapError(ErrBadResponse, fmt.Errorf("status code: %d", res.StatusCode))
+	}
+
+	response, err := io.ReadAll(res.Body)
+	if err != nil {
+		err = lib.WrapError(ErrVideoGenerationRequest, err)
+		s.log.Error(err)
+		return err
+	}
+
+	sEnc := b64.StdEncoding.EncodeToString(response)
+
+	dataPrefix := "data:video/mp4;base64,"
+	chunk := gcs.NewChunkVideo(&gcs.VideoGenerationResult{
+		VideoRawContent: dataPrefix + sEnc,
+	})
+
+	return cb(ctx, chunk)
+}
+
+func (s *ProdiaV2) ApiType() string {
+	return API_TYPE_PRODIA_V2
+}
+
+var _ AIEngineStream = &ProdiaV2{}
diff --git a/proxy-router/internal/aiengine/remote.go b/proxy-router/internal/aiengine/remote.go
index 86088600..ba958b2a 100644
--- a/proxy-router/internal/aiengine/remote.go
+++ b/proxy-router/internal/aiengine/remote.go
@@ -15,6 +15,7 @@ type RemoteModel struct {
 
 type ProxyService interface {
 	SendPromptV2(ctx context.Context, sessionID common.Hash, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) (interface{}, error)
+	GetModelIdSession(ctx context.Context, sessionID common.Hash) (common.Hash, error)
 }
 
 func (p *RemoteModel) Prompt(ctx context.Context, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) error {
diff --git a/proxy-router/internal/chatstorage/file_chat_storage.go b/proxy-router/internal/chatstorage/file_chat_storage.go
index d10c7325..964baa4f 100644
--- a/proxy-router/internal/chatstorage/file_chat_storage.go
+++ b/proxy-router/internal/chatstorage/file_chat_storage.go
@@ -76,16 +76,19 @@ func (cs *ChatStorage) StorePromptResponseToFile(identifier string, isLocal bool
 	}
 
 	isImageContent := false
+	isVideoRawContent := false
 	if len(responses) > 0 {
 		isImageContent = responses[0].Type() == gcs.ChunkTypeImage
+		isVideoRawContent = responses[0].Type() == gcs.ChunkTypeVideo
 	}
 
 	newEntry := gcs.ChatMessage{
-		Prompt:         p,
-		Response:       strings.Join(resps, ""),
-		PromptAt:       promptAt.Unix(),
-		ResponseAt:     responseAt.Unix(),
-		IsImageContent: isImageContent,
+		Prompt:            p,
+		Response:          strings.Join(resps, ""),
+		PromptAt:          promptAt.Unix(),
+		ResponseAt:        responseAt.Unix(),
+		IsImageContent:    isImageContent,
+		IsVideoRawContent: isVideoRawContent,
 	}
 
 	if chatHistory.Messages == nil && len(chatHistory.Messages) == 0 {
diff --git a/proxy-router/internal/chatstorage/genericchatstorage/chat_responses.go b/proxy-router/internal/chatstorage/genericchatstorage/chat_responses.go
index 4eb6129d..f016d943 100644
--- a/proxy-router/internal/chatstorage/genericchatstorage/chat_responses.go
+++ b/proxy-router/internal/chatstorage/genericchatstorage/chat_responses.go
@@ -139,3 +139,9 @@ type ImageGenerationResult struct {
 }
 
 type ImageGenerationCallback func(completion *ImageGenerationResult) error
+
+type VideoGenerationResult struct {
+	VideoRawContent string `json:"videoRawContent"`
+}
+
+type VideoGenerationCallback func(completion *VideoGenerationResult) error
diff --git a/proxy-router/internal/chatstorage/genericchatstorage/completion.go b/proxy-router/internal/chatstorage/genericchatstorage/completion.go
index eb9fd0fc..b75fe402 100644
--- a/proxy-router/internal/chatstorage/genericchatstorage/completion.go
+++ b/proxy-router/internal/chatstorage/genericchatstorage/completion.go
@@ -139,6 +139,36 @@ func (c *ChunkImage) Data() interface{} {
 	return c.data
 }
 
+type ChunkVideo struct {
+	data *VideoGenerationResult
+}
+
+func NewChunkVideo(data *VideoGenerationResult) *ChunkVideo {
+	return &ChunkVideo{
+		data: data,
+	}
+}
+
+func (c *ChunkVideo) IsStreaming() bool {
+	return false
+}
+
+func (c *ChunkVideo) Tokens() int {
+	return 1
+}
+
+func (c *ChunkVideo) Type() ChunkType {
+	return ChunkTypeVideo
+}
+
+func (c *ChunkVideo) String() string {
+	return c.data.VideoRawContent
+}
+
+func (c *ChunkVideo) Data() interface{} {
+	return c.data
+}
+
 type Chunk interface {
 	IsStreaming() bool
 	Tokens() int
@@ -151,3 +181,4 @@ var _ Chunk = &ChunkText{}
 var _ Chunk = &ChunkImage{}
 var _ Chunk = &ChunkControl{}
 var _ Chunk = &ChunkStreaming{}
+var _ Chunk = &ChunkVideo{}
diff --git a/proxy-router/internal/chatstorage/genericchatstorage/interface.go b/proxy-router/internal/chatstorage/genericchatstorage/interface.go
index 690c00f1..bc2cf9d9 100644
--- a/proxy-router/internal/chatstorage/genericchatstorage/interface.go
+++ b/proxy-router/internal/chatstorage/genericchatstorage/interface.go
@@ -47,12 +47,14 @@ func (h *ChatHistory) AppendChatHistory(req *openai.ChatCompletionRequest) *open
 }
 
 type ChatMessage struct {
-	Prompt         OpenAiCompletionRequest `json:"prompt"`
-	Response       string                  `json:"response"`
-	PromptAt       int64                   `json:"promptAt"`
-	ResponseAt     int64                   `json:"responseAt"`
-	IsImageContent bool                    `json:"isImageContent"`
+	Prompt            OpenAiCompletionRequest `json:"prompt"`
+	Response          string                  `json:"response"`
+	PromptAt          int64                   `json:"promptAt"`
+	ResponseAt        int64                   `json:"responseAt"`
+	IsImageContent    bool                    `json:"isImageContent"`
+	IsVideoRawContent bool                    `json:"isVideoRawContent"`
 }
+
 type Chat struct {
 	ChatID    string `json:"chatId"`
 	ModelID   string `json:"modelId"`
diff --git a/proxy-router/internal/constants.go b/proxy-router/internal/constants.go
index 8efcb1ff..1ee3288e 100644
--- a/proxy-router/internal/constants.go
+++ b/proxy-router/internal/constants.go
@@ -5,6 +5,7 @@ const (
 
 	CONTENT_TYPE_JSON         = "application/json"
 	CONTENT_TYPE_EVENT_STREAM = "text/event-stream"
+	CONTENT_TYPE_VIDEO_MP4    = "video/mp4"
 
 	CONNECTION_KEEP_ALIVE = "keep-alive"
 
diff --git a/proxy-router/internal/proxyapi/proxy_sender.go b/proxy-router/internal/proxyapi/proxy_sender.go
index e02cf385..09071da4 100644
--- a/proxy-router/internal/proxyapi/proxy_sender.go
+++ b/proxy-router/internal/proxyapi/proxy_sender.go
@@ -285,6 +285,14 @@ func (p *ProxyServiceSender) validateMsgSignature(result any, signature lib.HexS
 	return p.morRPC.VerifySignature(result, signature, providerPubicKey, p.log)
 }
 
+func (p *ProxyServiceSender) GetModelIdSession(ctx context.Context, sessionID common.Hash) (common.Hash, error) {
+	session, err := p.sessionRepo.GetSession(ctx, sessionID)
+	if err != nil {
+		return common.Hash{}, ErrSessionNotFound
+	}
+	return session.ModelID(), nil
+}
+
 func (p *ProxyServiceSender) SendPromptV2(ctx context.Context, sessionID common.Hash, prompt *openai.ChatCompletionRequest, cb gcs.CompletionCallback) (interface{}, error) {
 	session, err := p.sessionRepo.GetSession(ctx, sessionID)
 	if err != nil {
@@ -378,7 +386,7 @@ func (p *ProxyServiceSender) rpcRequestStreamV2(
 ) (interface{}, int, int, error) {
 	const (
 		TIMEOUT_TO_ESTABLISH_CONNECTION   = time.Second * 3
-		TIMEOUT_TO_RECEIVE_FIRST_RESPONSE = time.Second * 5
+		TIMEOUT_TO_RECEIVE_FIRST_RESPONSE = time.Second * 30
 		MAX_RETRIES                       = 5
 	)
 
@@ -522,12 +530,21 @@ func (p *ProxyServiceSender) rpcRequestStreamV2(
 		} else {
 			var imageGenerationResult gcs.ImageGenerationResult
 			err = json.Unmarshal(aiResponse, &imageGenerationResult)
-			if err != nil {
-				return nil, ttftMs, totalTokens, lib.WrapError(ErrInvalidResponse, err)
+			if err == nil && imageGenerationResult.ImageUrl != "" {
+				totalTokens += 1
+				responses = append(responses, imageGenerationResult)
+				chunk = gcs.NewChunkImage(&imageGenerationResult)
+			} else {
+				var videoGenerationResult gcs.VideoGenerationResult
+				err = json.Unmarshal(aiResponse, &videoGenerationResult)
+				if err == nil && videoGenerationResult.VideoRawContent != "" {
+					totalTokens += 1
+					responses = append(responses, videoGenerationResult)
+					chunk = gcs.NewChunkVideo(&videoGenerationResult)
+				} else {
+					return nil, ttftMs, totalTokens, lib.WrapError(ErrInvalidResponse, err)
+				}
 			}
-			totalTokens += 1
-			responses = append(responses, imageGenerationResult)
-			chunk = gcs.NewChunkImage(&imageGenerationResult)
 		}
 
 		if ctx.Err() != nil {
diff --git a/ui-desktop/src/renderer/src/components/chat/Chat.styles.tsx b/ui-desktop/src/renderer/src/components/chat/Chat.styles.tsx
index f50d6216..a1043c3d 100644
--- a/ui-desktop/src/renderer/src/components/chat/Chat.styles.tsx
+++ b/ui-desktop/src/renderer/src/components/chat/Chat.styles.tsx
@@ -94,7 +94,7 @@ export const Avatar = styled.div`
 `
 
 export const AvatarHeader = styled.div`
-  color: ${p => p.theme.colors.morMain}
+  color: ${p => p.theme.colors.morMain};
   font-weight: 900;
   padding: 0 8px;
   font-size: 18px;
@@ -118,7 +118,7 @@ export const MessageBody = styled.div`
 `
 
 export const ChatTitleContainer = styled.div`
-  color: ${p => p.theme.colors.morMain}
+  color: ${p => p.theme.colors.morMain};
   font-weight: 900;
   padding: 0 8px;
   font-size: 18px;
@@ -221,6 +221,17 @@ export const ImageContainer = styled.img`
   }
 `
 
+export const VideoContainer = styled.div`
+  cursor: pointer;
+  padding: 0.25rem;
+  max-width: 100%;
+  height: 256px;
+
+  @media (min-height: 700px) {
+    height: 320px;
+  }
+`
+
 export const SubPriceLabel = styled.span`
   color: ${p => p.theme.colors.morMain};
 `
\ No newline at end of file
diff --git a/ui-desktop/src/renderer/src/components/chat/Chat.tsx b/ui-desktop/src/renderer/src/components/chat/Chat.tsx
index 7024c393..f0faf34a 100644
--- a/ui-desktop/src/renderer/src/components/chat/Chat.tsx
+++ b/ui-desktop/src/renderer/src/components/chat/Chat.tsx
@@ -18,7 +18,8 @@ import {
   SendBtn,
   LoadingCover,
   ImageContainer,
-  SubPriceLabel
+  SubPriceLabel,
+  VideoContainer
 } from './Chat.styles';
 import { BtnAccent } from '../dashboard/BalanceBlock.styles';
 import { withRouter } from 'react-router-dom';
@@ -223,7 +224,7 @@ const Chat = (props) => {
         const aiColor = getColor(aiIcon);
 
         messages.push({ id: makeId(16), text: m.prompt.messages[0].content, user: userMessage.user, role: userMessage.role, icon: userMessage.icon, color: userMessage.color });
-        messages.push({ id: makeId(16), text: m.response, user: modelName, role: "assistant", icon: aiIcon, color: aiColor, isImageContent: m.isImageContent });
+        messages.push({ id: makeId(16), text: m.response, user: modelName, role: "assistant", icon: aiIcon, color: aiColor, isImageContent: m.isImageContent, isVideoRawContent: m.isVideoRawContent });
       });
       setMessages(messages);
     }
@@ -272,7 +273,7 @@ const Chat = (props) => {
       return;
     }
 
-    const selectedModel = chainData.models.find((m: any) => m.Id == modelId);
+    const selectedModel = chainData.isLocal ? chainData.models.find((m: any) => m.Id == modelId) : chainData.models.find((m: any) => m.Id == modelId && m.bids);
     setSelectedModel(selectedModel);
 
     setIsReadonly(false);
@@ -421,8 +422,9 @@ const Chat = (props) => {
      }
 
      const imageContent = part.imageUrl;
+     const videoRawContent = part.videoRawContent;
 
-     if (!part?.id && !imageContent) {
+     if (!part?.id && !imageContent && !videoRawContent) {
        return;
      }
@@ -432,6 +434,9 @@ const Chat = (props) => {
      if (imageContent) {
        result = [...otherMessages, { id: part.job, user: modelName, role: "assistant", text: imageContent, isImageContent: true, ...iconProps }];
      }
+     if (videoRawContent) {
+       result = [...otherMessages, { id: part.job, user: modelName, role: "assistant", text: videoRawContent, isVideoRawContent: true, ...iconProps }];
+     }
      else {
        const text = `${message?.text || ''}${part?.choices[0]?.delta?.content || ''}`.replace("<|im_start|>", "").replace("<|im_end|>", "");
        result = [...otherMessages, { id: part.id, user: modelName, role: "assistant", text: text, ...iconProps }];
@@ -521,7 +526,7 @@ const Chat = (props) => {
     setIsReadonly(false);
     setChat({ id: generateHashId(), createdAt: new Date(), modelId, isLocal });
 
-    const selectedModel = chainData.models.find((m: any) => m.Id == modelId);
+    const selectedModel = isLocal ? chainData.models.find((m: any) => m.Id == modelId) : chainData.models.find((m: any) => m.Id == modelId && m.bids);
     setSelectedModel(selectedModel);
 
     if (isLocal) {
@@ -534,7 +539,7 @@ const Chat = (props) => {
     const openModelSession = openSessions.find(s => s.ModelAgentId == modelId);
 
     if (openModelSession) {
-      const selectedBid = selectedModel.bids.find(b => b.Id == openModelSession.BidID);
+      const selectedBid = selectedModel.bids.find(b => b.Id == openModelSession.BidID && b.bids);
       if (selectedBid) {
         setSelectedBid(selectedBid);
       }
@@ -721,6 +726,42 @@ const Chat = (props) => {
     )
 }
 
+const renderMessage = (message, onOpenImage) => {
+  if (message.isImageContent) {
+    return ({ onOpenImage(message.text)} />})
+  }
+
+  if (message.isVideoRawContent) {
+    return ()
+  }
+
+  return (
+
+
+          ) : (
+
+              {children}
+
+          )
+        }
+      }}
+    />
+  )
+};
+
 const Message = ({ message, onOpenImage }) => {
     return (
@@ -730,33 +771,7 @@ const Message = ({ message, onOpenImage }) => {
                 {message.user}
                 {
-                    message.isImageContent
-                        ? ({ onOpenImage(message.text)} />})
-                        : (
-
-                            ) : (
-
-                                {children}
-
-                            )
-                        }
-                    }}
-                    />
-                )
+                    renderMessage(message, onOpenImage)
                 }
) diff --git a/ui-desktop/src/renderer/src/components/chat/interfaces.tsx b/ui-desktop/src/renderer/src/components/chat/interfaces.tsx index e8d31bb0..df40c786 100644 --- a/ui-desktop/src/renderer/src/components/chat/interfaces.tsx +++ b/ui-desktop/src/renderer/src/components/chat/interfaces.tsx @@ -22,6 +22,7 @@ export interface HistoryMessage { icon: string; color: string; isImageContent?: boolean; + isVideoRawContent?: boolean; } export interface ChatHistoryInterface { @@ -36,6 +37,7 @@ export interface ChatMessage { promptAt: number; responseAt: number; isImageContent?: boolean; + isVideoRawContent?: boolean; } export interface ChatPrompt { From ffc1cd225df24c52bb091a0493a61061a2f055c6 Mon Sep 17 00:00:00 2001 From: Aleksandr Kukharenko Date: Mon, 2 Dec 2024 18:20:27 +0200 Subject: [PATCH 2/2] update docs --- docs/models-config.json.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/models-config.json.md b/docs/models-config.json.md index 0452b7a8..a767c42f 100644 --- a/docs/models-config.json.md +++ b/docs/models-config.json.md @@ -1,7 +1,7 @@ # Example models config file. Local model configurations are stored in this file * `root_key` (required) is the model id * `modelName` (required) is the name of the model -* `apiType` (required) is the type of the model api. Currently supported values are "prodia" and "openai" +* `apiType` (required) is the type of the model api. Currently supported values are "prodia-sd", "prodia-sdxl", "prodia-v2" and "openai" * `apiUrl` (required) is the url of the LLM server or model API * `apiKey` (optional) is the api key for the model * `cononcurrentSlots` (optional) are number of available distinct chats on the llm server and used for capacity policy @@ -37,6 +37,12 @@ "apiUrl": "http://llmserver.domain.io:8080/v1", "concurrentSlots": 8, "capacityPolicy": "simple" + }, + "0xe086adc275c99e32bb10b0aff5e8bfc391aad18cbb184727a75b2569149425c6": { + "modelName": "inference.mochi1.txt2vid.v1", + "apiType": "prodia-v2", + "apiUrl": "https://inference.prodia.com/v2", + "apiKey": "replace-with-your-api-key" } } ``` \ No newline at end of file