diff --git a/api/app-node/agent/v1alpha1/agent_types.go b/api/app-node/agent/v1alpha1/agent_types.go index 252c61f62..720081d4f 100644 --- a/api/app-node/agent/v1alpha1/agent_types.go +++ b/api/app-node/agent/v1alpha1/agent_types.go @@ -51,7 +51,8 @@ type Options struct { // +kubebuilder:validation:Maximum=10 // +kubebuilder:default=5 MaxIterations int `json:"maxIterations,omitempty"` - + // Memory for chain memory + Memory node.Memory `json:"memory,omitempty"` // The options below might be used later // prompt prompts.PromptTemplate // outputKey string diff --git a/api/app-node/agent/v1alpha1/zz_generated.deepcopy.go b/api/app-node/agent/v1alpha1/zz_generated.deepcopy.go index 58a9eabc9..892a830d9 100644 --- a/api/app-node/agent/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/agent/v1alpha1/zz_generated.deepcopy.go @@ -143,6 +143,7 @@ func (in *AgentStatus) DeepCopy() *AgentStatus { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Options) DeepCopyInto(out *Options) { *out = *in + out.Memory = in.Memory } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Options. diff --git a/api/app-node/chain/v1alpha1/llmchain_types.go b/api/app-node/chain/v1alpha1/llmchain_types.go index 11b4b6c1b..4d5d9de89 100644 --- a/api/app-node/chain/v1alpha1/llmchain_types.go +++ b/api/app-node/chain/v1alpha1/llmchain_types.go @@ -37,7 +37,7 @@ type CommonChainConfig struct { MaxNumberOfConccurent int `json:"maxNumberOfConccurent,omitempty"` // for memory - Memory Memory `json:"memory,omitempty"` + Memory node.Memory `json:"memory,omitempty"` // Model is the model to use in an llm call.like `gpt-3.5-turbo` or `chatglm_turbo` // Usually this value is just empty @@ -68,16 +68,6 @@ type CommonChainConfig struct { RepetitionPenalty float64 `json:"repetitionPenalty,omitempty"` } -type Memory struct { - // MaxTokenLimit is the maximum number of tokens to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize. - MaxTokenLimit int `json:"maxTokenLimit,omitempty"` - // ConversionWindowSize is the maximum number of conversation rounds in memory.Can only use MaxTokenLimit or ConversionWindowSize. - // +kubebuilder:validation:Minimum=0 - // +kubebuilder:validation:Maximum=30 - // +kubebuilder:default=5 - ConversionWindowSize int `json:"conversionWindowSize,omitempty"` -} - // LLMChainStatus defines the observed state of LLMChain type LLMChainStatus struct { // ObservedGeneration is the last observed generation. diff --git a/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go b/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go index efd6edd5f..7966645e3 100644 --- a/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go @@ -230,21 +230,6 @@ func (in *LLMChainStatus) DeepCopy() *LLMChainStatus { return out } -// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. -func (in *Memory) DeepCopyInto(out *Memory) { - *out = *in -} - -// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Memory. -func (in *Memory) DeepCopy() *Memory { - if in == nil { - return nil - } - out := new(Memory) - in.DeepCopyInto(out) - return out -} - // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *RetrievalQAChain) DeepCopyInto(out *RetrievalQAChain) { *out = *in diff --git a/api/app-node/common_type.go b/api/app-node/common_type.go index 18493f19b..f88486d29 100644 --- a/api/app-node/common_type.go +++ b/api/app-node/common_type.go @@ -27,6 +27,16 @@ const ( OutputLengthAnnotationKey = v1alpha1.Group + `/output-rules` ) +type Memory struct { + // MaxTokenLimit is the maximum number of tokens to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize. + MaxTokenLimit int `json:"maxTokenLimit,omitempty"` + // ConversionWindowSize is the maximum number of conversation rounds in memory.Can only use MaxTokenLimit or ConversionWindowSize. + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=30 + // +kubebuilder:default=5 + ConversionWindowSize int `json:"conversionWindowSize,omitempty"` +} + type Ref struct { Kind string `json:"kind,omitempty"` Group string `json:"group,omitempty"` diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index 63bd7d49c..dfbb494d3 100644 --- a/apiserver/pkg/application/application.go +++ b/apiserver/pkg/application/application.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + apinode "github.com/kubeagi/arcadia/api/app-node" apiagent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" apidocumentloader "github.com/kubeagi/arcadia/api/app-node/documentloader/v1alpha1" @@ -440,7 +441,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat Description: "qachain", }, CommonChainConfig: apichain.CommonChainConfig{ - Memory: apichain.Memory{ + Memory: apinode.Memory{ ConversionWindowSize: pointer.IntDeref(input.ConversionWindowSize, 0), }, Model: pointer.StringDeref(input.Model, ""), @@ -473,7 +474,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat Description: "qachain", }, CommonChainConfig: apichain.CommonChainConfig{ - Memory: apichain.Memory{ + Memory: apinode.Memory{ ConversionWindowSize: pointer.IntDeref(input.ConversionWindowSize, 0), }, Model: pointer.StringDeref(input.Model, ""), diff --git a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_agents.yaml b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_agents.yaml index 269ef5dd1..766d0cce0 100644 --- a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_agents.yaml +++ b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_agents.yaml @@ -67,6 +67,22 @@ spec: maximum: 10 minimum: 1 type: integer + memory: + description: Memory for chain memory + properties: + conversionWindowSize: + default: 5 + description: ConversionWindowSize is the maximum number of + conversation rounds in memory.Can only use MaxTokenLimit + or ConversionWindowSize. + maximum: 30 + minimum: 0 + type: integer + maxTokenLimit: + description: MaxTokenLimit is the maximum number of tokens + to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize. + type: integer + type: object showToolAction: default: false description: Whether to show tool action in the streaming output diff --git a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_agents.yaml b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_agents.yaml index 269ef5dd1..766d0cce0 100644 --- a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_agents.yaml +++ b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_agents.yaml @@ -67,6 +67,22 @@ spec: maximum: 10 minimum: 1 type: integer + memory: + description: Memory for chain memory + properties: + conversionWindowSize: + default: 5 + description: ConversionWindowSize is the maximum number of + conversation rounds in memory.Can only use MaxTokenLimit + or ConversionWindowSize. + maximum: 30 + minimum: 0 + type: integer + maxTokenLimit: + description: MaxTokenLimit is the maximum number of tokens + to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize. + type: integer + type: object showToolAction: default: false description: Whether to show tool action in the streaming output diff --git a/pkg/appruntime/agent/executor.go b/pkg/appruntime/agent/executor.go index 0d8b9fa14..ac6ad6b9e 100644 --- a/pkg/appruntime/agent/executor.go +++ b/pkg/appruntime/agent/executor.go @@ -24,12 +24,14 @@ import ( "github.com/tmc/langchaingo/agents" "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" + langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/appruntime/chain" "github.com/kubeagi/arcadia/pkg/appruntime/log" "github.com/kubeagi/arcadia/pkg/appruntime/tools" ) @@ -59,6 +61,13 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a } allowedTools := tools.InitTools(ctx, instance.Spec.AllowedTools) + var history langchaingoschema.ChatMessageHistory + if v3, ok := args[base.LangchaingoChatMessageHistoryKeyInArg]; ok && v3 != nil { + history, ok = v3.(langchaingoschema.ChatMessageHistory) + if !ok { + return args, errors.New("history not memory.ChatMessageHistory") + } + } // Initialize executor using langchaingo executorOptions := func(o *agents.CreationOptions) { agents.WithCallbacksHandler(log.KLogHandler{LogLevel: 3})(o) @@ -70,6 +79,7 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a agents.WithCallbacksHandler(streamHandler)(o) } } + agents.WithMemory(chain.GetMemory(llm, instance.Spec.AgentConfig.Options.Memory, history, "", ""))(o) } executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions) if err != nil { diff --git a/pkg/appruntime/chain/apichain.go b/pkg/appruntime/chain/apichain.go index cb961bce9..abc7bf96b 100644 --- a/pkg/appruntime/chain/apichain.go +++ b/pkg/appruntime/chain/apichain.go @@ -88,8 +88,8 @@ func (l *APIChain) Run(ctx context.Context, _ client.Client, args map[string]any options := GetChainOptions(instance.Spec.CommonChainConfig) chain := chains.NewAPIChain(llm, http.DefaultClient) - chain.RequestChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") - chain.AnswerChain.Memory = getMemory(llm, instance.Spec.Memory, history, "input", "") + chain.RequestChain.Memory = GetMemory(llm, instance.Spec.Memory, history, "", "") + chain.AnswerChain.Memory = GetMemory(llm, instance.Spec.Memory, history, "input", "") l.APIChain = chain apiDoc := instance.Spec.APIDoc if apiDoc == "" { diff --git a/pkg/appruntime/chain/common.go b/pkg/appruntime/chain/common.go index 4b9d83f41..d3e5bec9d 100644 --- a/pkg/appruntime/chain/common.go +++ b/pkg/appruntime/chain/common.go @@ -27,6 +27,7 @@ import ( langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/klog/v2" + appnode "github.com/kubeagi/arcadia/api/app-node" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" ) @@ -85,7 +86,7 @@ func GetChainOptions(config v1alpha1.CommonChainConfig) []chains.ChainCallOption return options } -func getMemory(llm llms.Model, config v1alpha1.Memory, history langchaingoschema.ChatMessageHistory, inputKey, outputKey string) langchaingoschema.Memory { +func GetMemory(llm llms.Model, config appnode.Memory, history langchaingoschema.ChatMessageHistory, inputKey, outputKey string) langchaingoschema.Memory { if inputKey == "" { inputKey = "question" } diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index 515d4c064..8b410384d 100644 --- a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -112,7 +112,7 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any } chain := chains.NewLLMChain(llm, prompt) if history != nil { - chain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") + chain.Memory = GetMemory(llm, instance.Spec.Memory, history, "", "") } l.LLMChain = *chain diff --git a/pkg/appruntime/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go index 9bb5c8cc5..a59fabc32 100644 --- a/pkg/appruntime/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -113,9 +113,9 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st llmChain := chains.NewLLMChain(llm, prompt) if history != nil { - llmChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") + llmChain.Memory = GetMemory(llm, instance.Spec.Memory, history, "", "") } - chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), chains.LoadCondenseQuestionGenerator(llm), retriever, getMemory(llm, instance.Spec.Memory, history, "", "")) + chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), chains.LoadCondenseQuestionGenerator(llm), retriever, GetMemory(llm, instance.Spec.Memory, history, "", "")) l.ConversationalRetrievalQA = chain args["query"] = args["question"] var out string diff --git a/tests/example-test.sh b/tests/example-test.sh index 3ce309dff..1f493fd74 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -587,6 +587,11 @@ info "8.6 tool test" kubectl apply -f config/samples/app_llmchain_chat_with_bot_tool.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot-tool" sleep 3 +info "8.6.1 conversation test" +info "23*34 结果应该是 782, 结果再乘2是 1564, 再减去564是 1000" +getRespInAppChat "base-chat-with-bot-tool" "arcadia" "计算 23*34 的结果" "" "false" +getRespInAppChat "base-chat-with-bot-tool" "arcadia" "结果再乘2" ${resp_conversation_id} "false" +getRespInAppChat "base-chat-with-bot-tool" "arcadia" "结果再减去564" ${resp_conversation_id} "false" # info "8.6.1 bingsearch test" # getRespInAppChat "base-chat-with-bot-tool" "arcadia" "用30字介绍一下时速云" "" "true" # if [ -z "$references" ] || [ "$references" = "null" ]; then