Skip to content

Commit

Permalink
feat: add azure openai support
Browse files Browse the repository at this point in the history
  • Loading branch information
shellfly committed Jun 9, 2023
1 parent ade11ac commit 4bfc79f
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 24 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/chzyer/readline v1.5.1
github.com/kevinburke/ssh_config v1.2.0
github.com/rest-go/rest v0.1.3
github.com/sashabaranov/go-openai v1.5.2
github.com/sashabaranov/go-openai v1.10.1
github.com/stretchr/testify v1.8.2
golang.org/x/crypto v0.7.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OK
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/sashabaranov/go-openai v1.5.2 h1:Gtn5HZEL25//rDDLEX+Anw5FI8TUC6gqIeM9BDBOO18=
github.com/sashabaranov/go-openai v1.5.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.10.1 h1:6WyHJaNzF266VaEEuW6R4YW+Ei0wpMnqRYPGK7fhuhQ=
github.com/sashabaranov/go-openai v1.10.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
Expand Down
42 changes: 35 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package main

import (
"errors"
"flag"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
Expand All @@ -12,6 +14,7 @@ import (
"github.com/atotto/clipboard"
"github.com/briandowns/spinner"
"github.com/chzyer/readline"
"github.com/sashabaranov/go-openai"

"github.com/shellfly/aoi/pkg/chatgpt"
"github.com/shellfly/aoi/pkg/color"
Expand All @@ -24,21 +27,45 @@ the character laughing man who named Aoi, so you named yourself Aoi. Respond
like we are good friend.
`

func main() {
startUp()

func InitClient() (*openai.Client, string, error) {
var model, openaiAPIKey, openaiAPIBaseUrl string
var azureDeployment string
flag.StringVar(&openaiAPIBaseUrl, "openai_api_base_url", os.Getenv("OPENAI_API_BASE_URL"), "OpenAI API Base Url, default: https://api.openai.com")
flag.StringVar(&openaiAPIKey, "openai_api_key", os.Getenv("OPENAI_API_KEY"), "OpenAI API key")
flag.StringVar(&model, "model", "gpt-3.5-turbo", "model to use")
flag.StringVar(&azureDeployment, "azure.deployment", "", "azure deployment name of the model")
flag.Parse()

// Create an AI
ai, err := chatgpt.NewAI(openaiAPIBaseUrl, openaiAPIKey, model)
if openaiAPIKey == "" {
return nil, "", errors.New("Please set the OPENAI_API_KEY environment variable")
}

var config openai.ClientConfig
if azureDeployment != "" {
if openaiAPIBaseUrl == "" {
return nil, "", errors.New("Please set the OPENAI_API_BASE_URL to your azure endpoint")
}
config = openai.DefaultAzureConfig(openaiAPIKey, openaiAPIBaseUrl)
config.AzureModelMapperFunc = func(model string) string {
return azureDeployment
}
} else {
config = openai.DefaultConfig(openaiAPIKey)
if openaiAPIBaseUrl != "" {
config.BaseURL = openaiAPIBaseUrl
}
}
client := openai.NewClientWithConfig(config)
return client, model, nil
}

func main() {
startUp()
client, model, err := InitClient()
if err != nil {
fmt.Println("create ai error: ", err)
return
log.Fatal(err)
}
ai := chatgpt.NewAI(client, model)
ai.SetSystem(system)

configDir := makeDir(".aoi")
Expand Down Expand Up @@ -90,6 +117,7 @@ func main() {
// If previous is finished try to create a new one, otherwise continue
// to reuse it for prompts
if cmd.IsFinished() {
ai.Reset()
cmd, prompts = command.Parse(input)
rl.SetPrompt(color.Yellow(cmd.Prompt(userPrompt)))
} else {
Expand Down
22 changes: 8 additions & 14 deletions pkg/chatgpt/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package chatgpt

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand All @@ -22,26 +21,17 @@ type AI struct {
debug bool
}

func NewAI(apiBaseUrl, apiKey, model string) (*AI, error) {
if apiKey == "" {
return nil, errors.New("Please set the OPENAI_API_KEY environment variable")
}

// Create a new OpenAI API client with the provided API key
config := openai.DefaultConfig(apiKey)
if apiBaseUrl != "" {
config.BaseURL = apiBaseUrl + "/v1"
}
client := openai.NewClientWithConfig(config)
func NewAI(client *openai.Client, model string) *AI {
messages := make([]openai.ChatCompletionMessage, 0, 2*MessageLimit)
ai := &AI{
client: client,
model: model,
messages: messages,
debug: false,
}
return ai, nil
return ai
}

func (ai *AI) SetSystem(system string) {
ai.system = system
ai.messages = []openai.ChatCompletionMessage{NewMessage(openai.ChatMessageRoleSystem, system)}
Expand Down Expand Up @@ -70,7 +60,11 @@ func (ai *AI) Query(prompts []string) (string, error) {
ai.limitTokens()

if ai.debug {
fmt.Println(ai.messages)
fmt.Println("---debug---")
for _, msg := range ai.messages {
fmt.Println(msg)
}
fmt.Println("---debug---")
}
// Set the request parameters for the completion API
req := openai.ChatCompletionRequest{
Expand Down
4 changes: 2 additions & 2 deletions pkg/chatgpt/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

func TestAI(t *testing.T) {
ai, err := NewAI("https:...", "api key", "model")
assert.Nil(t, err)
client := openai.NewClient("api key")
ai := NewAI(client, "model")
t.Run("limit tokens", func(t *testing.T) {
ai.messages = make([]openai.ChatCompletionMessage, MessageLimit+2)
ai.messages[0] = NewMessage("system", "message")
Expand Down

0 comments on commit 4bfc79f

Please sign in to comment.