Skip to content

Commit

Permalink
Merge pull request #380 from Lumerin-protocol/test
Browse files Browse the repository at this point in the history
Add ClaudeAI support and fix test models
  • Loading branch information
abs2023 authored Jan 10, 2025
2 parents 7ac394d + d53f991 commit 783b00a
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 3 deletions.
207 changes: 207 additions & 0 deletions proxy-router/internal/aiengine/claudeai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package aiengine

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

c "github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal"
gcs "github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal/chatstorage/genericchatstorage"
"github.com/MorpheusAIs/Morpheus-Lumerin-Node/proxy-router/internal/lib"
"github.com/sashabaranov/go-openai"
)

// ClaudeAIResponse represents the top-level structure of the ClaudeAI JSON response.
type ClaudeAIResponse struct {
Content []ClaudeAIContent `json:"content"`
ID string `json:"id"`
Model string `json:"model"`
Role string `json:"role"`
StopReason string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Type string `json:"type"`
Usage ClaudeAIUsage `json:"usage"`
}

// ClaudeAIContent represents each item in the "content" array.
type ClaudeAIContent struct {
Text string `json:"text"`
Type string `json:"type"`
}

// ClaudeAIUsage represents the usage statistics of the request/response.
type ClaudeAIUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

type ClaudeAIStreamResponse struct {
Type string `json:"type"`
Delta ClaudeAIStreamDelta `json:"delta"`
ContentBlock ClaudeAIStreamContentBlock `json:"content_block"`
Message ClaudeAIStreamMessage `json:"message"`
}

type ClaudeAIStreamMessage struct {
ID string `json:"id"`
Role string `json:"role"`
Model string `json:"model"`
}

type ClaudeAIStreamDelta struct {
Type string `json:"type"`
Text string `json:"text"`
}

type ClaudeAIStreamContentBlock struct {
Type string `json:"type"`
Text string `json:"text"`
}

const API_TYPE_CLAUDEAI = "claudeai"

type ClaudeAI struct {
baseURL string
apiKey string
modelName string
client *http.Client
log lib.ILogger
}

func NewClaudeAIEngine(modelName, baseURL, apiKey string, log lib.ILogger) *ClaudeAI {
return &ClaudeAI{
baseURL: baseURL,
modelName: modelName,
apiKey: apiKey,
client: &http.Client{},
log: log,
}
}

func (a *ClaudeAI) Prompt(ctx context.Context, compl *openai.ChatCompletionRequest, cb gcs.CompletionCallback) error {
compl.Model = a.modelName
compl.MaxTokens = 1024
requestBody, err := json.Marshal(compl)
if err != nil {
return fmt.Errorf("failed to encode request: %v", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", a.baseURL+"/messages", bytes.NewReader(requestBody))
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}

if a.apiKey != "" {
req.Header.Set("x-api-key", a.apiKey)
}
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set(c.HEADER_CONTENT_TYPE, c.CONTENT_TYPE_JSON)
req.Header.Set(c.HEADER_CONNECTION, c.CONNECTION_KEEP_ALIVE)
if compl.Stream {
req.Header.Set(c.HEADER_ACCEPT, c.CONTENT_TYPE_EVENT_STREAM)
}
resp, err := a.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %v", err)
}
defer resp.Body.Close()
if isContentTypeStream(resp.Header) {
return a.readStream(ctx, resp.Body, cb)
}

return a.readResponse(ctx, resp.Body, cb)
}

func (a *ClaudeAI) readResponse(ctx context.Context, body io.Reader, cb gcs.CompletionCallback) error {
var compl ClaudeAIResponse
if err := json.NewDecoder(body).Decode(&compl); err != nil {
return fmt.Errorf("failed to decode response: %v", err)
}

var openaiCompl openai.ChatCompletionResponse
openaiCompl.ID = compl.ID
openaiCompl.Model = compl.Model
openaiCompl.Choices = make([]openai.ChatCompletionChoice, len(compl.Content))
for i, content := range compl.Content {
openaiCompl.Choices[i].Message.Content = content.Text
openaiCompl.Choices[i].Message.Role = compl.Role
}
openaiCompl.Usage.PromptTokens = compl.Usage.InputTokens
openaiCompl.Usage.CompletionTokens = compl.Usage.OutputTokens
openaiCompl.Usage.TotalTokens = compl.Usage.InputTokens + compl.Usage.OutputTokens

chunk := gcs.NewChunkText(&openaiCompl)
err := cb(ctx, chunk)
if err != nil {
return fmt.Errorf("callback failed: %v", err)
}

return nil
}

func (a *ClaudeAI) readStream(ctx context.Context, body io.Reader, cb gcs.CompletionCallback) error {
var model string
var role string
var messageID string

scanner := bufio.NewScanner(body)
for scanner.Scan() {
line := scanner.Text()

if strings.HasPrefix(line, StreamDataPrefix) {
data := line[len(StreamDataPrefix):] // Skip the "data: " prefix

var compl ClaudeAIStreamResponse
if err := json.Unmarshal([]byte(data), &compl); err != nil {
return fmt.Errorf("error decoding response: %s\n%s", err, line)
}
if compl.Type == "message_stop" {
return nil
}

if compl.Message.ID != "" {
messageID = compl.Message.ID
}
if compl.Message.Role != "" {
role = compl.Message.Role
}
if compl.Message.Model != "" {
model = compl.Message.Model
}
if compl.Delta.Text != "" || compl.ContentBlock.Text != "" {
openaiCompl := openai.ChatCompletionStreamResponse{}
openaiCompl.Choices = make([]openai.ChatCompletionStreamChoice, 1)
openaiCompl.Choices[0].Delta.Content = compl.Delta.Text
openaiCompl.Choices[0].Delta.Role = role
openaiCompl.ID = messageID
openaiCompl.Model = model
openaiCompl.Created = time.Now().Unix()

// Call the callback function with the unmarshalled completion
chunk := gcs.NewChunkStreaming(&openaiCompl)
err := cb(ctx, chunk)
if err != nil {
return fmt.Errorf("callback failed: %v", err)
}
}
}
}

if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading stream: %v", err)
}

return nil
}

func (a *ClaudeAI) ApiType() string {
return API_TYPE_CLAUDEAI
}

var _ AIEngineStream = &OpenAI{}
2 changes: 2 additions & 0 deletions proxy-router/internal/aiengine/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ func ApiAdapterFactory(apiType string, modelName string, url string, apikey stri
return NewProdiaV2Engine(modelName, url, apikey, log), true
case API_TYPE_HYPERBOLIC_SD:
return NewHyperbolicSDEngine(modelName, url, apikey, parameters, log), true
case API_TYPE_CLAUDEAI:
return NewClaudeAIEngine(modelName, url, apikey, log), true
}
return nil, false
}
1 change: 1 addition & 0 deletions proxy-router/internal/aiengine/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func (a *OpenAI) Prompt(ctx context.Context, compl *openai.ChatCompletionRequest
}
defer resp.Body.Close()

a.log.Debugf("AI Model responded with status code: %d", resp.StatusCode)
if isContentTypeStream(resp.Header) {
return a.readStream(ctx, resp.Body, cb)
}
Expand Down
2 changes: 1 addition & 1 deletion proxy-router/internal/config/models-config-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"title": "API Type",
"description": "Defines the type of API to be used with this model",
"type": "string",
"enum": ["openai", "prodia-v2", "hyperbolic-sd"]
"enum": ["openai", "prodia-v2", "hyperbolic-sd", "claudeai"]
},
"apiUrl": {
"title": "API URL",
Expand Down
16 changes: 15 additions & 1 deletion proxy-router/models-config.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
"apiType": "hyperbolic-sd",
"apiUrl": "https://api.hyperbolic.xyz/v1",
"apiKey": "Authentication Token"
}
},
{
"modelId": "0x0000000000000000000000000000000000000000000000000000000000000003",
"modelName": "claude-3-5-sonnet-20241022",
"apiType": "claudeai",
"apiUrl": "https://api.anthropic.com/v1",
"apiKey": "FILL_ME_IN"
},
{
"modelId": "0x0000000000000000000000000000000000000000000000000000000000000004",
"modelName": "inference.sd15.txt2img.v1",
"apiType": "prodia-v2",
"apiUrl": "https://inference.prodia.com/v2",
"apiKey": "FILL_ME_IN"
},
]
}
2 changes: 1 addition & 1 deletion ui-desktop/src/renderer/src/components/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ const Chat = (props) => {

const { data: parts, isChunkIncomplete } = tryParseDataChunk(chunksBuffer);

if(isChunkIncomplete) {
if (isChunkIncomplete) {
continue;
}
else {
Expand Down

0 comments on commit 783b00a

Please sign in to comment.