Skip to content

Commit

Permalink
Merge pull request #102 from Abirdcfly/main
Browse files Browse the repository at this point in the history
feat: add dashscope as llm
  • Loading branch information
bjwswang authored Oct 9, 2023
2 parents f9811e7 + 2a8dbcd commit 2484b8e
Show file tree
Hide file tree
Showing 14 changed files with 683 additions and 20 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ kubectl apply -f config/samples/arcadia_v1alpha1_llm.yaml
kubectl apply -f config/samples/arcadia_v1alpha1_prompt.yaml
```

After prompt got created, you can see the prompt in the following command:
After the prompt has been created, you can view it with the following command:

```shell
kubectl get prompt prompt-zhipuai-sample -oyaml
```

If no error found,you can use this command to get the prompt response data.
If no error is found, you can use this command to get the prompt response data.

```shell
kubectl get prompt prompt-zhipuai-sample --output="jsonpath={.status.data}" | base64 --decode
Expand All @@ -56,18 +56,18 @@ go install github.com/kubeagi/arcadia/arctl@latest

## Packages

To enhace the AI capability in Golang,we developed some packages.
To enhance the AI capability in Golang, we developed some packages.

### LLMs

-[ZhiPuAI(智谱AI)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
-[ZhiPuAI(智谱 AI)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
- [example](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go)

### Embeddings

> Fully compatible with [langchain embeddings](https://github.com/tmc/langchaingo/tree/main/embeddings)
-[ZhiPuAI(智谱AI) Embedding](https://github.com/kubeagi/arcadia/tree/main/pkg/embeddings/zhipuai)
-[ZhiPuAI(智谱 AI) Embedding](https://github.com/kubeagi/arcadia/tree/main/pkg/embeddings/zhipuai)

### VectorStores

Expand All @@ -77,10 +77,10 @@ To enhace the AI capability in Golang,we developed some packages.

## Examples

- [chat_with_document](https://github.com/kubeagi/arcadia/tree/main/examples/chat_with_document): a chat server which allows you chat with your document
- [chat_with_document](https://github.com/kubeagi/arcadia/tree/main/examples/chat_with_document): a chat server which allows you to chat with your document
- [embedding](https://github.com/kubeagi/arcadia/tree/main/examples/embedding) shows how to embed your documents into a vector store with an embedding service
- [rbac](https://github.com/kubeagi/arcadia/blob/main/examples/rbac/main.go) shows to to inquiry the security risks in your RBAC with AI.
- [zhipuai](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) show how to use this [zhipuai client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
- [rbac](https://github.com/kubeagi/arcadia/blob/main/examples/rbac/main.go) shows how to inquire about the security risks in your RBAC with AI.
- [zhipuai](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) shows how to use this [zhipuai client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)

## Contribute to Arcadia

Expand Down
2 changes: 1 addition & 1 deletion controllers/prompt_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom
switch llm.Spec.Type {
case llms.ZhiPuAI:
llmClient = llmszhipuai.NewZhiPuAI(apiKey)
callData = prompt.Spec.ZhiPuAIParams.Marshall()
callData = prompt.Spec.ZhiPuAIParams.Marshal()
case llms.OpenAI:
llmClient = openai.NewOpenAI(apiKey)
default:
Expand Down
79 changes: 79 additions & 0 deletions examples/dashscope/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
Copyright 2023 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"context"
"os"

"github.com/kubeagi/arcadia/pkg/llms"
"github.com/kubeagi/arcadia/pkg/llms/dashscope"
"k8s.io/klog/v2"
)

const (
// samplePrompt is the user question sent to every model in this example,
// for both the plain chat and the SSE chat.
samplePrompt = "how to change a deployment's image?"
)

// main is the entry point of the DashScope example. It reads the API key from
// the first command-line argument and, for each sample model, runs one chat
// with SSE disabled followed by one chat with SSE enabled, both using
// samplePrompt. A missing API key or any chat error aborts with a panic.
func main() {
	if len(os.Args) == 1 {
		panic("api key is empty")
	}
	apiKey := os.Args[1]

	klog.Infof("sample chat start...\nwe use same prompt: %s to test\n", samplePrompt)

	models := []dashscope.Model{dashscope.QWEN14BChat, dashscope.QWEN7BChat}
	for i := range models {
		model := models[i]
		klog.V(0).Infof("\nChat with %s\n", model)

		resp, chatErr := sampleChat(apiKey, model)
		if chatErr != nil {
			panic(chatErr)
		}
		klog.V(0).Infof("Response: \n %s\n", resp)

		klog.V(0).Infoln("\nChat again with sse enable")
		if sseErr := sampleSSEChat(apiKey, model); sseErr != nil {
			panic(sseErr)
		}
	}

	klog.Infoln("sample chat done")
}

// sampleChat sends samplePrompt, preceded by a system message, to the given
// model through a client created with SSE disabled, and returns whatever Call
// returns.
func sampleChat(apiKey string, model dashscope.Model) (llms.Response, error) {
	messages := []dashscope.Message{
		{Role: dashscope.System, Content: "You are a kubernetes expert."},
		{Role: dashscope.User, Content: samplePrompt},
	}

	// Start from the package defaults and override only model and messages.
	params := dashscope.DefaultModelParams()
	params.Model = model
	params.Input.Messages = messages

	return dashscope.NewDashScope(apiKey, false).Call(params.Marshal())
}

// sampleSSEChat sends samplePrompt, preceded by a system message, to the given
// model through a client created with SSE enabled. It returns the error
// reported by StreamCall, or nil when the call succeeds.
func sampleSSEChat(apiKey string, model dashscope.Model) error {
	// Start from the package defaults and override only model and messages.
	params := dashscope.DefaultModelParams()
	params.Model = model
	params.Input.Messages = []dashscope.Message{
		{Role: dashscope.System, Content: "You are a kubernetes expert."},
		{Role: dashscope.User, Content: samplePrompt},
	}
	payload := params.Marshal()

	sseClient := dashscope.NewDashScope(apiKey, true)
	// nil is passed as the handler here; you can define a customized
	// `handler` on `Event` and pass it instead.
	if err := sseClient.StreamCall(context.TODO(), payload, nil); err != nil {
		return err
	}
	return nil
}
66 changes: 66 additions & 0 deletions pkg/llms/dashscope/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
Copyright 2023 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package dashscope

import (
"context"
"errors"

"github.com/kubeagi/arcadia/pkg/llms"
)

const (
// DashScopeChatURL is the text-generation endpoint that Call posts requests to.
DashScopeChatURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
)

// Model is the name of a DashScope model; it is serialized as the "model"
// field of ModelParams.
type Model string

const (
// QWEN14BChat selects the "qwen-14b-chat" model.
QWEN14BChat Model = "qwen-14b-chat"
// QWEN7BChat selects the "qwen-7b-chat" model.
QWEN7BChat Model = "qwen-7b-chat"
)

// Compile-time check that *DashScope satisfies llms.LLM.
var _ llms.LLM = (*DashScope)(nil)

// DashScope is a client for the DashScope text-generation API.
type DashScope struct {
// apiKey is sent as the bearer token in the Authorization header.
apiKey string
// sse, when true, makes requests carry the server-sent-events headers.
sse bool
}

// NewDashScope returns a DashScope client that authenticates with apiKey.
// sse controls whether requests are sent with the server-sent-events headers.
func NewDashScope(apiKey string, sse bool) *DashScope {
	client := DashScope{apiKey: apiKey, sse: sse}
	return &client
}

// Type reports llms.DashScope as the LLM type of this client.
func (z DashScope) Type() llms.LLMType {
return llms.DashScope
}

// Call posts data, a JSON-encoded ModelParams payload, to DashScopeChatURL and
// returns the decoded response.
//
// data is unmarshalled into a ModelParams only to check that it is
// well-formed; the bytes that are sent are the original data. Whenever an
// error is returned, the llms.Response result is an untyped nil, so callers
// can safely compare it against nil.
func (z *DashScope) Call(data []byte) (llms.Response, error) {
	params := ModelParams{}
	if err := params.Unmarshal(data); err != nil {
		return nil, err
	}

	resp, err := do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse)
	if err != nil {
		// Do not forward the *Response from do here: a nil *Response converted
		// to llms.Response is a non-nil interface holding a nil pointer.
		return nil, err
	}
	if resp == nil {
		// do can yield a nil pointer without an error (e.g. a JSON `null`
		// body); report it instead of handing back a typed nil.
		return nil, errors.New("empty response from dashscope")
	}
	return resp, nil
}

// Validate is not implemented for DashScope yet: it always returns a nil
// response together with a "not implemented" error.
func (z *DashScope) Validate() (llms.Response, error) {
return nil, errors.New("not implemented")
}
63 changes: 63 additions & 0 deletions pkg/llms/dashscope/http_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright 2023 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package dashscope

import (
"bytes"
"context"
"encoding/json"
"net/http"
)

// setHeaders fills in the headers of a DashScope request: a JSON content
// type, the bearer token, and an Accept header that depends on sse. When sse
// is true the X-DashScope-SSE header is set to "enable" as well.
func setHeaders(req *http.Request, token string, sse bool) {
	h := req.Header

	// The body is JSON in both modes. Although the documentation says SSE
	// requests should use "text/event-stream" as Content-Type, doing so
	// returns a 400 error and the python sdk doesn't do it either.
	h.Set("Content-Type", "application/json")
	h.Set("Authorization", "Bearer "+token)

	if !sse {
		h.Set("Accept", "*/*")
		return
	}
	h.Set("Accept", "text/event-stream")
	h.Set("X-DashScope-SSE", "enable")
}

// parseHTTPResponse decodes the JSON body of resp into a Response. It returns
// a nil Response together with the decoder's error when decoding fails. The
// body is not closed here; that is left to the caller.
func parseHTTPResponse(resp *http.Response) (data *Response, err error) {
	decoder := json.NewDecoder(resp.Body)
	// Decode through a pointer to the (initially nil) result pointer so the
	// decoder allocates the Response itself.
	err = decoder.Decode(&data)
	if err != nil {
		return nil, err
	}
	return data, nil
}

// req builds a POST request to apiURL carrying data as its body, applies the
// DashScope headers via setHeaders, and sends it with http.DefaultClient. The
// caller owns the returned response and must close its body.
func req(ctx context.Context, apiURL, token string, data []byte, sse bool) (*http.Response, error) {
	body := bytes.NewReader(data)

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body)
	if err != nil {
		return nil, err
	}
	setHeaders(httpReq, token, sse)

	return http.DefaultClient.Do(httpReq)
}
// do performs one request/response round trip: it sends data to apiURL via
// req, decodes the JSON reply with parseHTTPResponse, and closes the response
// body before returning.
func do(ctx context.Context, apiURL, token string, data []byte, sse bool) (*Response, error) {
	httpResp, err := req(ctx, apiURL, token, data, sse)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	return parseHTTPResponse(httpResp)
}
101 changes: 101 additions & 0 deletions pkg/llms/dashscope/params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
Copyright 2023 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package dashscope

import (
"encoding/json"
"errors"

"github.com/kubeagi/arcadia/pkg/llms"
)

// Role identifies the author of a Message; it is serialized as the "role"
// field.
type Role string

const (
// System is the "system" role.
System Role = "system"
// User is the "user" role.
User Role = "user"
// Assistant is the "assistant" role.
Assistant Role = "assistant"
)

// Compile-time check that *ModelParams satisfies llms.ModelParams.
var _ llms.ModelParams = (*ModelParams)(nil)

// +kubebuilder:object:generate=true

// ModelParams is the JSON request body sent to the DashScope text-generation
// API.
//
// ref: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-api-detailes#25745d61fbx49
//
// Do not use 'input.history': according to the document above, that parameter
// will be deprecated soon. Use 'message' in 'parameters.result_format' to keep
// better compatibility.
type ModelParams struct {
// Model is serialized as the "model" field.
Model Model `json:"model"`
// Input carries the chat messages.
Input Input `json:"input"`
// Parameters holds the optional generation settings.
Parameters Parameters `json:"parameters"`
}

// +kubebuilder:object:generate=true

// Input is the "input" object of a request; it holds the chat messages.
type Input struct {
// Messages is serialized as the "messages" array.
Messages []Message `json:"messages"`
}

// Parameters is the "parameters" object of a request. Every field is tagged
// omitempty, so a zero value is left out of the encoded JSON.
type Parameters struct {
// TopP is serialized as "top_p"; ValidateModelParams checks it against the
// range 0 to 1.
TopP float32 `json:"top_p,omitempty"`
// TopK is serialized as "top_k".
TopK int `json:"top_k,omitempty"`
// Seed is serialized as "seed".
Seed int `json:"seed,omitempty"`
// ResultFormat is serialized as "result_format"; DefaultModelParams sets it
// to "message".
ResultFormat string `json:"result_format,omitempty"`
}

// +kubebuilder:object:generate=true

// Message is a single chat message: the role of its author and its text.
// Both fields are tagged omitempty, so empty values are left out of the JSON.
type Message struct {
// Role is the author of the message (System, User or Assistant).
Role Role `json:"role,omitempty"`
// Content is the text of the message.
Content string `json:"content,omitempty"`
}

// DefaultModelParams returns the baseline request parameters: the
// QWEN14BChat model, no messages, top_p 0.5, top_k 0, seed 1234 and the
// "message" result format. Callers are expected to fill in the messages.
func DefaultModelParams() ModelParams {
	var params ModelParams

	params.Model = QWEN14BChat
	// An empty but non-nil slice, so the field encodes as [] rather than null.
	params.Input.Messages = []Message{}
	// TopK is left at its zero value.
	params.Parameters = Parameters{
		TopP:         0.5,
		Seed:         1234,
		ResultFormat: "message",
	}

	return params
}

// Marshal encodes params as JSON. The signature has no error result, so an
// encoding failure is reported as an empty, non-nil byte slice.
func (params *ModelParams) Marshal() []byte {
	if encoded, err := json.Marshal(params); err == nil {
		return encoded
	}
	return []byte{}
}

// Unmarshal decodes the JSON in bytes into params and returns the error from
// json.Unmarshal, if any.
func (params *ModelParams) Unmarshal(bytes []byte) error {
return json.Unmarshal(bytes, params)
}

// ValidateModelParams checks that params holds values the request can carry.
// Currently only top_p is checked: it must lie in the closed range [0, 1].
// It returns nil when params is valid and a descriptive error otherwise.
func ValidateModelParams(params ModelParams) error {
	// Written as a negated range check so that NaN, which fails every
	// comparison and cannot be encoded as JSON, is rejected as well.
	if topP := params.Parameters.TopP; !(topP >= 0 && topP <= 1) {
		// Both bounds are accepted, hence the closed-interval notation.
		return errors.New("top_p must be in [0, 1]")
	}

	return nil
}
Loading

0 comments on commit 2484b8e

Please sign in to comment.