Skip to content

Commit

Permalink
Merge pull request #840 from Abirdcfly/agent-history
Browse files Browse the repository at this point in the history
feat: agent can get history
  • Loading branch information
nkwangleiGIT authored Mar 13, 2024
2 parents 6323c69 + 3a5a1c0 commit 1beca6d
Show file tree
Hide file tree
Showing 14 changed files with 71 additions and 35 deletions.
3 changes: 2 additions & 1 deletion api/app-node/agent/v1alpha1/agent_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ type Options struct {
// +kubebuilder:validation:Maximum=10
// +kubebuilder:default=5
MaxIterations int `json:"maxIterations,omitempty"`

// Memory for chain memory
Memory node.Memory `json:"memory,omitempty"`
// The options below might be used later
// prompt prompts.PromptTemplate
// outputKey string
Expand Down
1 change: 1 addition & 0 deletions api/app-node/agent/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 1 addition & 11 deletions api/app-node/chain/v1alpha1/llmchain_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type CommonChainConfig struct {
MaxNumberOfConccurent int `json:"maxNumberOfConccurent,omitempty"`

// for memory
Memory Memory `json:"memory,omitempty"`
Memory node.Memory `json:"memory,omitempty"`

// Model is the model to use in an llm call.like `gpt-3.5-turbo` or `chatglm_turbo`
// Usually this value is just empty
Expand Down Expand Up @@ -68,16 +68,6 @@ type CommonChainConfig struct {
RepetitionPenalty float64 `json:"repetitionPenalty,omitempty"`
}

// Memory configures how much prior conversation context is retained.
// Set at most one of MaxTokenLimit or ConversionWindowSize; the two
// options are mutually exclusive.
type Memory struct {
	// MaxTokenLimit is the maximum number of tokens to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize.
	MaxTokenLimit int `json:"maxTokenLimit,omitempty"`
	// ConversionWindowSize is the maximum number of conversation rounds kept in memory. Can only use MaxTokenLimit or ConversionWindowSize.
	// +kubebuilder:validation:Minimum=0
	// +kubebuilder:validation:Maximum=30
	// +kubebuilder:default=5
	ConversionWindowSize int `json:"conversionWindowSize,omitempty"`
}

// LLMChainStatus defines the observed state of LLMChain
type LLMChainStatus struct {
// ObservedGeneration is the last observed generation.
Expand Down
15 changes: 0 additions & 15 deletions api/app-node/chain/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions api/app-node/common_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ const (
OutputLengthAnnotationKey = v1alpha1.Group + `/output-rules`
)

// Memory configures how much prior conversation context is retained
// for a chain or agent. Set at most one of MaxTokenLimit or
// ConversionWindowSize; the two options are mutually exclusive.
type Memory struct {
	// MaxTokenLimit is the maximum number of tokens to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize.
	MaxTokenLimit int `json:"maxTokenLimit,omitempty"`
	// ConversionWindowSize is the maximum number of conversation rounds kept in memory. Can only use MaxTokenLimit or ConversionWindowSize.
	// +kubebuilder:validation:Minimum=0
	// +kubebuilder:validation:Maximum=30
	// +kubebuilder:default=5
	ConversionWindowSize int `json:"conversionWindowSize,omitempty"`
}

type Ref struct {
Kind string `json:"kind,omitempty"`
Group string `json:"group,omitempty"`
Expand Down
5 changes: 3 additions & 2 deletions apiserver/pkg/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"

apinode "github.com/kubeagi/arcadia/api/app-node"
apiagent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1"
apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1"
apidocumentloader "github.com/kubeagi/arcadia/api/app-node/documentloader/v1alpha1"
Expand Down Expand Up @@ -440,7 +441,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat
Description: "qachain",
},
CommonChainConfig: apichain.CommonChainConfig{
Memory: apichain.Memory{
Memory: apinode.Memory{
ConversionWindowSize: pointer.IntDeref(input.ConversionWindowSize, 0),
},
Model: pointer.StringDeref(input.Model, ""),
Expand Down Expand Up @@ -473,7 +474,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat
Description: "qachain",
},
CommonChainConfig: apichain.CommonChainConfig{
Memory: apichain.Memory{
Memory: apinode.Memory{
ConversionWindowSize: pointer.IntDeref(input.ConversionWindowSize, 0),
},
Model: pointer.StringDeref(input.Model, ""),
Expand Down
16 changes: 16 additions & 0 deletions config/crd/bases/arcadia.kubeagi.k8s.com.cn_agents.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ spec:
maximum: 10
minimum: 1
type: integer
memory:
description: Memory for chain memory
properties:
conversionWindowSize:
default: 5
description: ConversionWindowSize is the maximum number of
conversation rounds kept in memory. Can only use MaxTokenLimit
or ConversionWindowSize.
maximum: 30
minimum: 0
type: integer
maxTokenLimit:
description: MaxTokenLimit is the maximum number of tokens
to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize.
type: integer
type: object
showToolAction:
default: false
description: Whether to show tool action in the streaming output
Expand Down
16 changes: 16 additions & 0 deletions deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_agents.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ spec:
maximum: 10
minimum: 1
type: integer
memory:
description: Memory for chain memory
properties:
conversionWindowSize:
default: 5
description: ConversionWindowSize is the maximum number of
conversation rounds kept in memory. Can only use MaxTokenLimit
or ConversionWindowSize.
maximum: 30
minimum: 0
type: integer
maxTokenLimit:
description: MaxTokenLimit is the maximum number of tokens
to keep in memory. Can only use MaxTokenLimit or ConversionWindowSize.
type: integer
type: object
showToolAction:
default: false
description: Whether to show tool action in the streaming output
Expand Down
10 changes: 10 additions & 0 deletions pkg/appruntime/agent/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import (
"github.com/tmc/langchaingo/agents"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
langchaingoschema "github.com/tmc/langchaingo/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1"
"github.com/kubeagi/arcadia/pkg/appruntime/base"
"github.com/kubeagi/arcadia/pkg/appruntime/chain"
"github.com/kubeagi/arcadia/pkg/appruntime/log"
"github.com/kubeagi/arcadia/pkg/appruntime/tools"
)
Expand Down Expand Up @@ -59,6 +61,13 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a
}
allowedTools := tools.InitTools(ctx, instance.Spec.AllowedTools)

var history langchaingoschema.ChatMessageHistory
if v3, ok := args[base.LangchaingoChatMessageHistoryKeyInArg]; ok && v3 != nil {
history, ok = v3.(langchaingoschema.ChatMessageHistory)
if !ok {
return args, errors.New("history not memory.ChatMessageHistory")
}
}
// Initialize executor using langchaingo
executorOptions := func(o *agents.CreationOptions) {
agents.WithCallbacksHandler(log.KLogHandler{LogLevel: 3})(o)
Expand All @@ -70,6 +79,7 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a
agents.WithCallbacksHandler(streamHandler)(o)
}
}
agents.WithMemory(chain.GetMemory(llm, instance.Spec.AgentConfig.Options.Memory, history, "", ""))(o)
}
executor, err := agents.Initialize(llm, allowedTools, agents.ZeroShotReactDescription, executorOptions)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/appruntime/chain/apichain.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ func (l *APIChain) Run(ctx context.Context, _ client.Client, args map[string]any
options := GetChainOptions(instance.Spec.CommonChainConfig)

chain := chains.NewAPIChain(llm, http.DefaultClient)
chain.RequestChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "")
chain.AnswerChain.Memory = getMemory(llm, instance.Spec.Memory, history, "input", "")
chain.RequestChain.Memory = GetMemory(llm, instance.Spec.Memory, history, "", "")
chain.AnswerChain.Memory = GetMemory(llm, instance.Spec.Memory, history, "input", "")
l.APIChain = chain
apiDoc := instance.Spec.APIDoc
if apiDoc == "" {
Expand Down
3 changes: 2 additions & 1 deletion pkg/appruntime/chain/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
langchaingoschema "github.com/tmc/langchaingo/schema"
"k8s.io/klog/v2"

appnode "github.com/kubeagi/arcadia/api/app-node"
"github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1"
"github.com/kubeagi/arcadia/pkg/appruntime/base"
)
Expand Down Expand Up @@ -85,7 +86,7 @@ func GetChainOptions(config v1alpha1.CommonChainConfig) []chains.ChainCallOption
return options
}

func getMemory(llm llms.Model, config v1alpha1.Memory, history langchaingoschema.ChatMessageHistory, inputKey, outputKey string) langchaingoschema.Memory {
func GetMemory(llm llms.Model, config appnode.Memory, history langchaingoschema.ChatMessageHistory, inputKey, outputKey string) langchaingoschema.Memory {
if inputKey == "" {
inputKey = "question"
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/appruntime/chain/llmchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any
}
chain := chains.NewLLMChain(llm, prompt)
if history != nil {
chain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "")
chain.Memory = GetMemory(llm, instance.Spec.Memory, history, "", "")
}
l.LLMChain = *chain

Expand Down
4 changes: 2 additions & 2 deletions pkg/appruntime/chain/retrievalqachain.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st

llmChain := chains.NewLLMChain(llm, prompt)
if history != nil {
llmChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "")
llmChain.Memory = GetMemory(llm, instance.Spec.Memory, history, "", "")
}
chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), chains.LoadCondenseQuestionGenerator(llm), retriever, getMemory(llm, instance.Spec.Memory, history, "", ""))
chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), chains.LoadCondenseQuestionGenerator(llm), retriever, GetMemory(llm, instance.Spec.Memory, history, "", ""))
l.ConversationalRetrievalQA = chain
args["query"] = args["question"]
var out string
Expand Down
5 changes: 5 additions & 0 deletions tests/example-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,11 @@ info "8.6 tool test"
kubectl apply -f config/samples/app_llmchain_chat_with_bot_tool.yaml
waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot-tool"
sleep 3
info "8.6.1 conversation test"
info "23*34 结果应该是 782, 结果再乘2是 1564, 再减去564是 1000"
getRespInAppChat "base-chat-with-bot-tool" "arcadia" "计算 23*34 的结果" "" "false"
getRespInAppChat "base-chat-with-bot-tool" "arcadia" "结果再乘2" ${resp_conversation_id} "false"
getRespInAppChat "base-chat-with-bot-tool" "arcadia" "结果再减去564" ${resp_conversation_id} "false"
# info "8.6.1 bingsearch test"
# getRespInAppChat "base-chat-with-bot-tool" "arcadia" "用30字介绍一下时速云" "" "true"
# if [ -z "$references" ] || [ "$references" = "null" ]; then
Expand Down

0 comments on commit 1beca6d

Please sign in to comment.