diff --git a/go.mod b/go.mod index 1e4b1b8..c4b37df 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/chzyer/readline v1.5.1 github.com/kevinburke/ssh_config v1.2.0 github.com/rest-go/rest v0.1.3 - github.com/sashabaranov/go-openai v1.5.2 + github.com/sashabaranov/go-openai v1.10.1 github.com/stretchr/testify v1.8.2 golang.org/x/crypto v0.7.0 ) diff --git a/go.sum b/go.sum index 04d0d9d..544c37b 100644 --- a/go.sum +++ b/go.sum @@ -107,6 +107,8 @@ github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OK github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/sashabaranov/go-openai v1.5.2 h1:Gtn5HZEL25//rDDLEX+Anw5FI8TUC6gqIeM9BDBOO18= github.com/sashabaranov/go-openai v1.5.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.10.1 h1:6WyHJaNzF266VaEEuW6R4YW+Ei0wpMnqRYPGK7fhuhQ= +github.com/sashabaranov/go-openai v1.10.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= diff --git a/main.go b/main.go index 3e11a63..c1f9fc4 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,11 @@ package main import ( + "errors" "flag" "fmt" "io" + "log" "os" "path/filepath" "strings" @@ -12,6 +14,7 @@ import ( "github.com/atotto/clipboard" "github.com/briandowns/spinner" "github.com/chzyer/readline" + "github.com/sashabaranov/go-openai" "github.com/shellfly/aoi/pkg/chatgpt" "github.com/shellfly/aoi/pkg/color" @@ -24,21 +27,45 @@ the character laughing man who named Aoi, so you named yourself Aoi. Respond like we are good friend. ` -func main() { - startUp() - +func InitClient() (*openai.Client, string, error) { var model, openaiAPIKey, openaiAPIBaseUrl string + var azureDeployment string flag.StringVar(&openaiAPIBaseUrl, "openai_api_base_url", os.Getenv("OPENAI_API_BASE_URL"), "OpenAI API Base Url, default: https://api.openai.com") flag.StringVar(&openaiAPIKey, "openai_api_key", os.Getenv("OPENAI_API_KEY"), "OpenAI API key") flag.StringVar(&model, "model", "gpt-3.5-turbo", "model to use") + flag.StringVar(&azureDeployment, "azure.deployment", "", "azure deployment name of the model") flag.Parse() - // Create an AI - ai, err := chatgpt.NewAI(openaiAPIBaseUrl, openaiAPIKey, model) + if openaiAPIKey == "" { + return nil, "", errors.New("Please set the OPENAI_API_KEY environment variable") + } + + var config openai.ClientConfig + if azureDeployment != "" { + if openaiAPIBaseUrl == "" { + return nil, "", errors.New("Please set the OPENAI_API_BASE_URL to your azure endpoint") + } + config = openai.DefaultAzureConfig(openaiAPIKey, openaiAPIBaseUrl) + config.AzureModelMapperFunc = func(model string) string { + return azureDeployment + } + } else { + config = openai.DefaultConfig(openaiAPIKey) + if openaiAPIBaseUrl != "" { + config.BaseURL = openaiAPIBaseUrl + } + } + client := openai.NewClientWithConfig(config) + return client, model, nil +} + +func main() { + startUp() + client, model, err := InitClient() if err != nil { - fmt.Println("create ai error: ", err) - return + log.Fatal(err) } + ai := chatgpt.NewAI(client, model) ai.SetSystem(system) configDir := makeDir(".aoi") @@ -90,6 +117,7 @@ func main() { // If previous is finished try to create a new one, otherwise continue // to reuse it for prompts if cmd.IsFinished() { + ai.Reset() cmd, prompts = command.Parse(input) rl.SetPrompt(color.Yellow(cmd.Prompt(userPrompt))) } else { diff --git a/pkg/chatgpt/ai.go b/pkg/chatgpt/ai.go index 8bba2a7..e1ccc61 100644 --- a/pkg/chatgpt/ai.go +++ b/pkg/chatgpt/ai.go @@ -2,7 +2,6 @@ package chatgpt import ( "context" - "errors" "fmt" "strings" "time" @@ -22,17 +21,7 @@ type AI struct { debug bool } -func NewAI(apiBaseUrl, apiKey, model string) (*AI, error) { - if apiKey == "" { - return nil, errors.New("Please set the OPENAI_API_KEY environment variable") - } - - // Create a new OpenAI API client with the provided API key - config := openai.DefaultConfig(apiKey) - if apiBaseUrl != "" { - config.BaseURL = apiBaseUrl + "/v1" - } - client := openai.NewClientWithConfig(config) +func NewAI(client *openai.Client, model string) *AI { messages := make([]openai.ChatCompletionMessage, 0, 2*MessageLimit) ai := &AI{ client: client, @@ -40,8 +29,9 @@ func NewAI(apiBaseUrl, apiKey, model string) (*AI, error) { messages: messages, debug: false, } - return ai, nil + return ai } + func (ai *AI) SetSystem(system string) { ai.system = system ai.messages = []openai.ChatCompletionMessage{NewMessage(openai.ChatMessageRoleSystem, system)} @@ -70,7 +60,11 @@ func (ai *AI) Query(prompts []string) (string, error) { ai.limitTokens() if ai.debug { - fmt.Println(ai.messages) + fmt.Println("---debug---") + for _, msg := range ai.messages { + fmt.Println(msg) + } + fmt.Println("---debug---") } // Set the request parameters for the completion API req := openai.ChatCompletionRequest{ diff --git a/pkg/chatgpt/ai_test.go b/pkg/chatgpt/ai_test.go index c0e2949..2567a53 100644 --- a/pkg/chatgpt/ai_test.go +++ b/pkg/chatgpt/ai_test.go @@ -9,8 +9,8 @@ import ( ) func TestAI(t *testing.T) { - ai, err := NewAI("https:...", "api key", "model") - assert.Nil(t, err) + client := openai.NewClient("api key") + ai := NewAI(client, "model") t.Run("limit tokens", func(t *testing.T) { ai.messages = make([]openai.ChatCompletionMessage, MessageLimit+2) ai.messages[0] = NewMessage("system", "message")