Merge pull request #90 from bjwswang/main

chore: make sure embedding and chat use same LLM

bjwswang authored Sep 13, 2023
2 parents 2954f8b + e21516c commit 3c14126
Showing 6 changed files with 127 additions and 65 deletions.
113 changes: 76 additions & 37 deletions arctl/README.md
@@ -8,7 +8,7 @@ arctl(arcadia command line tool)
go install github.com/kubeagi/arcadia/arctl@latest
```

If the build succeeds, `arctl` will be placed at `bin/arctl` under the `arcadia` directory.
Now give it a try! ❤️

```shell
❯ arctl -h
@@ -30,8 +30,47 @@ Flags:
Use "arctl [command] --help" for more information about a command.
```

### Build from local code

1. Clone `arcadia`

```shell
git clone https://github.com/kubeagi/arcadia.git
```

2. Build

```shell
make arctl
```

3. Have a try! ❤️

```shell
❯ ./bin/arctl chat -h
Do LLM chat with similarity search(optional)

Usage:
arctl chat [usage] [flags]

Flags:
--enable-embedding-search enable embedding similarity search
-h, --help help for chat
--llm-apikey string apiKey to access embedding/llm service. Required when embedding similarity search is enabled
--llm-type string llm type to use for embedding & chat(Only zhipuai,openai supported now) (default "zhipuai")
--method string Invoke method used when access LLM service(invoke/sse-invoke) (default "sse-invoke")
--model string which model to use: chatglm_lite/chatglm_std/chatglm_pro (default "chatglm_lite")
--namespace string namespace/collection to query from (default "arcadia")
--num-docs int number of documents to be returned with SimilarSearch (default 5)
--question string question text to be asked
--score-threshold float score threshold for similarity search(Higher is better)
--temperature float32 temperature for chat (default 0.95)
--top-p float32 top-p for chat (default 0.7)
--vector-store string vector stores to use(Only chroma supported now) (default "http://localhost:8000")
```

## Usage
### Load documents into vector store
### Load documents into vector store with embedding service

```shell
❯ arctl load -h
@@ -41,24 +80,25 @@ Usage:
arctl load [usage] [flags]

Flags:
--chunk-overlap int chunk overlap for embedding (default 30)
--chunk-size int chunk size for embedding (default 300)
--document string path of the document to load
--document-language string language of the document(Only text,html,csv supported now) (default "text")
--embedding-llm-apikey string apiKey to access embedding service
--embedding-llm-type string llm type to use(Only zhipuai,openai supported now) (default "zhipuai")
-h, --help help for load
--namespace string namespace/collection of the document to load into (default "arcadia")
--vector-store string vector stores to use(Only chroma supported now) (default "http://localhost:8000")
--chunk-overlap int chunk overlap for embedding (default 30)
--chunk-size int chunk size for embedding (default 300)
--document string path of the document to load
--document-language string language of the document(Only text,html,csv supported now) (default "text")
-h, --help help for load
--llm-apikey string apiKey to access embedding service
--llm-type string llm type to use(Only zhipuai,openai supported now) (default "zhipuai")
--namespace string namespace/collection of the document to load into (default "arcadia")
--vector-store string vector stores to use(Only chroma supported now) (default "http://localhost:8000")
```

Required Arguments:
- `--embedding-llm-apikey`
- `--llm-apikey`
- `--document`

For example:
> This loads `./README.md` into the vector store (chromadb at http://localhost:8000) with the help of the `zhipuai` embedding service and its API key `26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm`
```shell
arctl load --embedding-llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --document ./README.md
arctl load --llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --document ./README.md
```
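The `--chunk-size` and `--chunk-overlap` flags control how the document is split before embedding. A minimal sketch of overlapping chunking — the helper below is illustrative, not the loader's actual splitter:

```go
package main

import "fmt"

// splitText splits s into chunks of at most chunkSize runes, where each
// chunk overlaps the previous one by `overlap` runes -- the same idea as
// `arctl load --chunk-size 300 --chunk-overlap 30`.
func splitText(s string, chunkSize, overlap int) []string {
	runes := []rune(s)
	step := chunkSize - overlap
	if step <= 0 {
		step = chunkSize // guard against a nonsensical overlap
	}
	var chunks []string
	for start := 0; start < len(runes); start += step {
		end := start + chunkSize
		if end > len(runes) {
			end = len(runes)
		}
		chunks = append(chunks, string(runes[start:end]))
		if end == len(runes) {
			break
		}
	}
	return chunks
}

func main() {
	for i, c := range splitText("The quick brown fox jumps over the lazy dog", 16, 4) {
		fmt.Printf("chunk %d: %q\n", i, c)
	}
}
```

The overlap keeps a little shared context at chunk boundaries, so a sentence cut in half is still retrievable from either side.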

### Chat with LLM
@@ -70,38 +110,37 @@ Usage:
arctl chat [usage] [flags]

Flags:
--chat-llm-apikey string apiKey to access embedding service
--chat-llm-type string llm type to use(Only zhipuai,openai supported now) (default "zhipuai")
--embedding-llm-apikey string apiKey to access embedding service. Required when embedding similarity search is enabled
--embedding-llm-type string llm type to use(Only zhipuai,openai supported now) (default "zhipuai")
--enable-embedding-search enable embedding similarity search
-h, --help help for chat
--method string Invoke method used when access LLM service(invoke/sse-invoke) (default "sse-invoke")
--model string which model to use: chatglm_lite/chatglm_std/chatglm_pro (default "chatglm_lite")
--namespace string namespace/collection to query from (default "arcadia")
--num-docs int number of documents to be returned with SimilarSearch (default 3)
--question string question text to be asked
--score-threshold float score threshold for similarity search(Higher is better)
--temperature float32 temperature for chat (default 0.95)
--top-p float32 top-p for chat (default 0.7)
--vector-store string vector stores to use(Only chroma supported now) (default "http://localhost:8000")
--enable-embedding-search enable embedding similarity search(false by default)
-h, --help help for chat
--llm-apikey string apiKey to access embedding/llm service. Required when embedding similarity search is enabled
--llm-type string llm type to use for embedding & chat(Only zhipuai,openai supported now) (default "zhipuai")
--method string Invoke method used when access LLM service(invoke/sse-invoke) (default "sse-invoke")
--model string which model to use: chatglm_lite/chatglm_std/chatglm_pro (default "chatglm_lite")
--namespace string namespace/collection to query from (default "arcadia")
--num-docs int number of documents to be returned with SimilarSearch (default 5)
--question string question text to be asked
--score-threshold float score threshold for similarity search(Higher is better)
--temperature float32 temperature for chat (default 0.95)
--top-p float32 top-p for chat (default 0.7)
--vector-store string vector stores to use(Only chroma supported now) (default "http://localhost:8000")
```

Now `arctl chat` has two modes, controlled by the flag `--enable-embedding-search`:
- normal chat without embedding search (default)
- similarity search with embedding enabled

#### Normal chat(Without embedding)
#### Chat without embedding

> This chats with the `zhipuai` LLM (authenticated with its API key) using model `chatglm_pro`, without embedding search
```shell
arctl chat --chat-llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --question "介绍一下Arcadia"
arctl chat --llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --question "介绍一下Arcadia"
```

Required Arguments:
- `--chat-llm-apikey`
- `--llm-apikey`
- `--question`


**Output:**
```shell
Prompts: [{user 介绍一下Arcadia}]
@@ -116,16 +155,17 @@ One of the best-known games developed by Arcadia is "Second Life" (第二
Overall, Arcadia is a company with a strong reputation and influence in the game industry; it keeps releasing new titles that bring players ever richer gaming experiences.
```

#### Enable Similarity Search
#### Chat with embedding

> This chats with the `zhipuai` LLM (authenticated with its API key) using model `chatglm_pro`, with embedding search enabled
```shell
arctl chat --enable-embedding-search --chat-llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --embedding-llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --num-docs 10 --question "介绍一下Arcadia"
arctl chat --enable-embedding-search --llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --num-docs 10 --question "介绍一下Arcadia"
```

Required Arguments:
- `--enable-embedding-search`
- `--embedding-llm-apikey`
- `--chat-llm-apikey`
- `--llm-apikey`
- `--question`

**Output:**
@@ -161,7 +201,6 @@ The Arcadia project contains a number of packages for enhancing AI functionality in Golang. Among
3. Embedding Service
- ✅ zhipuai
- ✅ openai
4. LLM Service
26 changes: 12 additions & 14 deletions arctl/chat.go
@@ -35,9 +35,8 @@ import (

var (
question string

// chat with LLM
chatLLMType string
chatAPIKey string
model string
method string
temperature float32
@@ -102,25 +101,24 @@ func NewChatCmd() *cobra.Command {
}

// For similarity search
cmd.Flags().BoolVar(&enableSimilaritySearch, "enable-embedding-search", false, "enable embedding similarity search")
cmd.Flags().StringVar(&embeddingLLMType, "embedding-llm-type", string(llms.ZhiPuAI), "llm type to use(Only zhipuai,openai supported now)")
cmd.Flags().StringVar(&embeddingLLMApiKey, "embedding-llm-apikey", "", "apiKey to access embedding service. Required when embedding similarity search is enabled")
cmd.Flags().BoolVar(&enableSimilaritySearch, "enable-embedding-search", false, "enable embedding similarity search(false by default)")
cmd.Flags().StringVar(&vectorStore, "vector-store", "http://localhost:8000", "vector stores to use(Only chroma supported now)")
// Similarity search params
cmd.Flags().StringVar(&nameSpace, "namespace", "arcadia", "namespace/collection to query from")
cmd.Flags().Float64Var(&scoreThreshold, "score-threshold", 0, "score threshold for similarity search(Higher is better)")
cmd.Flags().IntVar(&numDocs, "num-docs", 5, "number of documents to be returned with SimilarSearch")

cmd.Flags().StringVar(&question, "question", "", "question text to be asked")

// For LLM chat
cmd.Flags().StringVar(&chatLLMType, "chat-llm-type", string(llms.ZhiPuAI), "llm type to use(Only zhipuai,openai supported now)")
cmd.Flags().StringVar(&chatAPIKey, "chat-llm-apikey", "", "apiKey to access embedding service")
cmd.Flags().StringVar(&llmType, "llm-type", string(llms.ZhiPuAI), "llm type to use for embedding & chat(Only zhipuai,openai supported now)")
cmd.Flags().StringVar(&apiKey, "llm-apikey", "", "apiKey to access embedding/llm service. Required when embedding similarity search is enabled")
cmd.Flags().StringVar(&question, "question", "", "question text to be asked")
// LLM Chat params
cmd.PersistentFlags().StringVar(&model, "model", string(zhipuai.ZhiPuAILite), "which model to use: chatglm_lite/chatglm_std/chatglm_pro")
cmd.PersistentFlags().StringVar(&method, "method", "sse-invoke", "Invoke method used when access LLM service(invoke/sse-invoke)")
cmd.PersistentFlags().Float32Var(&temperature, "temperature", 0.95, "temperature for chat")
cmd.PersistentFlags().Float32Var(&topP, "top-p", 0.7, "top-p for chat")

if err = cmd.MarkFlagRequired("chat-llm-apikey"); err != nil {
if err = cmd.MarkFlagRequired("llm-apikey"); err != nil {
panic(err)
}
if err = cmd.MarkFlagRequired("question"); err != nil {
@@ -134,14 +132,14 @@ func SimilaritySearch(ctx context.Context) ([]schema.Document, error) {
var embedder embeddings.Embedder
var err error

if embeddingLLMApiKey == "" {
if apiKey == "" {
return nil, errors.New("llm-apikey is required when embedding similarity search is enabled")
}

switch embeddingLLMType {
switch llmType {
case "zhipuai":
embedder, err = zhipuaiembeddings.NewZhiPuAI(
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(embeddingLLMApiKey)),
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)),
)
if err != nil {
return nil, err
@@ -169,7 +167,7 @@ func SimilaritySearch(ctx context.Context) ([]schema.Document, error) {

func Chat(ctx context.Context, similarDocs []schema.Document) error {
// Only for zhipuai
client := zhipuai.NewZhiPuAI(chatAPIKey)
client := zhipuai.NewZhiPuAI(apiKey)

params := zhipuai.DefaultModelParams()
params.Model = zhipuai.Model(model)
14 changes: 7 additions & 7 deletions arctl/load.go
@@ -35,8 +35,8 @@ import (
)

var (
embeddingLLMType string
embeddingLLMApiKey string
llmType string
apiKey string

document string
language string
@@ -61,8 +61,8 @@ func NewLoadCmd() *cobra.Command {
},
}

cmd.Flags().StringVar(&embeddingLLMType, "embedding-llm-type", string(llms.ZhiPuAI), "llm type to use(Only zhipuai,openai supported now)")
cmd.Flags().StringVar(&embeddingLLMApiKey, "embedding-llm-apikey", "", "apiKey to access embedding service")
cmd.Flags().StringVar(&llmType, "llm-type", string(llms.ZhiPuAI), "llm type to use(Only zhipuai,openai supported now)")
cmd.Flags().StringVar(&apiKey, "llm-apikey", "", "apiKey to access embedding service")

cmd.Flags().StringVar(&vectorStore, "vector-store", "http://localhost:8000", "vector stores to use(Only chroma supported now)")

@@ -73,7 +73,7 @@ cmd.Flags().StringVar(&document, "document", "", "path of the document to load")
cmd.Flags().IntVar(&chunkSize, "chunk-size", 300, "chunk size for embedding")
cmd.Flags().IntVar(&chunkOverlap, "chunk-overlap", 30, "chunk overlap for embedding")

if err = cmd.MarkFlagRequired("embedding-llm-apikey"); err != nil {
if err = cmd.MarkFlagRequired("llm-apikey"); err != nil {
panic(err)
}
if err = cmd.MarkFlagRequired("document"); err != nil {
@@ -112,10 +112,10 @@ func EmbedDocuments(ctx context.Context, documents []schema.Document) error {
var embedder embeddings.Embedder
var err error

switch embeddingLLMType {
switch llmType {
case "zhipuai":
embedder, err = zhipuaiembeddings.NewZhiPuAI(
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(embeddingLLMApiKey)),
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)),
)
if err != nil {
return err
4 changes: 4 additions & 0 deletions pkg/llms/zhipuai/api.go
@@ -201,6 +201,10 @@ func (z *ZhiPuAI) CreateEmbedding(ctx context.Context, inputTexts []string) ([][
if err != nil {
return nil, err
}
if postResponse.Code != 200 {
return nil, fmt.Errorf("embedding failed: %s", postResponse.String())
}

embeddings = append(embeddings, postResponse.Data.Embedding)
}
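The new guard in `CreateEmbedding` turns a non-200 ZhiPuAI status code into a Go error instead of silently appending data from a failed request. The pattern, reduced to a self-contained sketch (the struct is a simplified stand-in, not the real `EmbeddingResponse`):

```go
package main

import "fmt"

// embeddingResult is a simplified stand-in for ZhiPuAI's embedding response.
type embeddingResult struct {
	Code int
	Msg  string
}

// checkEmbedding rejects any response whose code is not 200, so callers
// never consume a payload from a failed request.
func checkEmbedding(r embeddingResult) error {
	if r.Code != 200 {
		return fmt.Errorf("embedding failed: code=%d msg=%s", r.Code, r.Msg)
	}
	return nil
}

func main() {
	fmt.Println(checkEmbedding(embeddingResult{Code: 200, Msg: "ok"}))
	fmt.Println(checkEmbedding(embeddingResult{Code: 401, Msg: "invalid apikey"}))
}
```

Surfacing the service-level code as an error means the caller's existing `if err != nil` path handles API failures and transport failures uniformly.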

34 changes: 27 additions & 7 deletions pkg/llms/zhipuai/response.go
@@ -22,20 +22,40 @@ import (
"github.com/kubeagi/arcadia/pkg/llms"
)

type Response struct {
Code int `json:"code"`
Data *Data `json:"data"`
Msg string `json:"msg"`
Success bool `json:"success"`
}

type EmbeddingResponse struct {
Code int `json:"code"`
Data *EmbeddingData `json:"data"`
Msg string `json:"msg"`
Success bool `json:"success"`
}

func (embeddingResp *EmbeddingResponse) Unmarshall(bytes []byte) error {
return json.Unmarshal(bytes, embeddingResp)
}

func (embeddingResp *EmbeddingResponse) Type() llms.LLMType {
return llms.ZhiPuAI
}

func (embeddingResp *EmbeddingResponse) Bytes() []byte {
bytes, err := json.Marshal(embeddingResp)
if err != nil {
return []byte{}
}
return bytes
}

func (embeddingResp *EmbeddingResponse) String() string {
return string(embeddingResp.Bytes())
}

type Response struct {
Code int `json:"code"`
Data *Data `json:"data"`
Msg string `json:"msg"`
Success bool `json:"success"`
}

func (response *Response) Unmarshall(bytes []byte) error {
return json.Unmarshal(bytes, response)
}
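The `Bytes`/`String` pair above renders a response by re-marshalling the struct, which is what the new `postResponse.String()` call in the embedding error path prints. A self-contained sketch of the same pattern (trimmed struct, illustrative only):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// resp is a trimmed version of the EmbeddingResponse shape above.
type resp struct {
	Code    int    `json:"code"`
	Msg     string `json:"msg"`
	Success bool   `json:"success"`
}

// Bytes re-marshals the struct; on error it returns an empty slice,
// matching the behavior in response.go.
func (r *resp) Bytes() []byte {
	b, err := json.Marshal(r)
	if err != nil {
		return []byte{}
	}
	return b
}

// String renders the marshalled bytes, giving a compact JSON view of the
// response that can be embedded directly in an error message.
func (r *resp) String() string {
	return string(r.Bytes())
}

func main() {
	r := &resp{Code: 200, Msg: "ok", Success: true}
	fmt.Println(r.String())
}
```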
1 change: 1 addition & 0 deletions pkg/llms/zhipuai/sse_client.go
@@ -65,6 +65,7 @@ func Stream(apiURL, token string, params ModelParams, timeout time.Duration, han
}
defer resp.Body.Close()

fmt.Printf("%v", resp)
// parse response body as stream events
eventChan, errorChan := NewSSEClient().Events(resp)

