From 42231616473c35210a78ec395503aa12dfb18aaa Mon Sep 17 00:00:00 2001 From: dayuy <973860441@qq.com> Date: Fri, 23 Feb 2024 09:36:31 +0800 Subject: [PATCH] feat: Add api to support copying existing rag --- apiserver/graph/generated/generated.go | 222 +++++++++++++++++++++++- apiserver/graph/generated/models_gen.go | 13 +- apiserver/graph/impl/rag.resolvers.go | 9 + apiserver/graph/schema/rag.graphqls | 7 + apiserver/pkg/rag/rag.go | 26 +++ gqlgen.yaml | 2 + 6 files changed, 273 insertions(+), 6 deletions(-) diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 861ab3dfc..06661bb0a 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -610,9 +610,10 @@ type ComplexityRoot struct { } RAGMutation struct { - CreateRag func(childComplexity int, input CreateRAGInput) int - DeleteRag func(childComplexity int, input DeleteRAGInput) int - UpdateRag func(childComplexity int, input UpdateRAGInput) int + CreateRag func(childComplexity int, input CreateRAGInput) int + DeleteRag func(childComplexity int, input DeleteRAGInput) int + DuplicateRag func(childComplexity int, input DuplicateRAGInput) int + UpdateRag func(childComplexity int, input UpdateRAGInput) int } RAGQuery struct { @@ -894,6 +895,7 @@ type RAGMutationResolver interface { CreateRag(ctx context.Context, obj *RAGMutation, input CreateRAGInput) (*Rag, error) UpdateRag(ctx context.Context, obj *RAGMutation, input UpdateRAGInput) (*Rag, error) DeleteRag(ctx context.Context, obj *RAGMutation, input DeleteRAGInput) (*string, error) + DuplicateRag(ctx context.Context, obj *RAGMutation, input DuplicateRAGInput) (*Rag, error) } type RAGQueryResolver interface { GetRag(ctx context.Context, obj *RAGQuery, name string, namespace string) (*Rag, error) @@ -3783,6 +3785,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.RAGMutation.DeleteRag(childComplexity, args["input"].(DeleteRAGInput)), true + case "RAGMutation.duplicateRAG": + if e.complexity.RAGMutation.DuplicateRag == nil { + break + } + + args, err := ec.field_RAGMutation_duplicateRAG_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.RAGMutation.DuplicateRag(childComplexity, args["input"].(DuplicateRAGInput)), true + case "RAGMutation.updateRAG": if e.complexity.RAGMutation.UpdateRag == nil { break @@ -4498,6 +4512,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputDeleteDataProcessInput, ec.unmarshalInputDeleteRAGInput, ec.unmarshalInputDeleteVersionedDatasetInput, + ec.unmarshalInputDuplicateRAGInput, ec.unmarshalInputEndpointInput, ec.unmarshalInputFileFilter, ec.unmarshalInputFileGroup, @@ -6921,6 +6936,12 @@ input UpdateRAGInput { suspend: Boolean } +input DuplicateRAGInput { + name: String! + namespace: String! + displayName: String +} + input DeleteRAGInput { name: String! namespace: String! @@ -6954,6 +6975,7 @@ type RAGMutation { createRAG(input: CreateRAGInput!): RAG! updateRAG(input: UpdateRAGInput!): RAG! deleteRAG(input: DeleteRAGInput!): Void + duplicateRAG(input: DuplicateRAGInput!): RAG! } type RAGQuery { @@ -8451,6 +8473,21 @@ func (ec *executionContext) field_RAGMutation_deleteRAG_args(ctx context.Context return args, nil } +func (ec *executionContext) field_RAGMutation_duplicateRAG_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 DuplicateRAGInput + if tmp, ok := rawArgs["input"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) + arg0, err = ec.unmarshalNDuplicateRAGInput2githubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐDuplicateRAGInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["input"] = arg0 + return args, nil +} + func (ec *executionContext) field_RAGMutation_updateRAG_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -23295,6 +23332,8 @@ func (ec *executionContext) fieldContext_Mutation_RAG(ctx context.Context, field return ec.fieldContext_RAGMutation_updateRAG(ctx, field) case "deleteRAG": return ec.fieldContext_RAGMutation_deleteRAG(ctx, field) + case "duplicateRAG": + return ec.fieldContext_RAGMutation_duplicateRAG(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type RAGMutation", field.Name) }, @@ -26669,6 +26708,101 @@ func (ec *executionContext) fieldContext_RAGMutation_deleteRAG(ctx context.Conte return fc, nil } +func (ec *executionContext) _RAGMutation_duplicateRAG(ctx context.Context, field graphql.CollectedField, obj *RAGMutation) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_RAGMutation_duplicateRAG(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.RAGMutation().DuplicateRag(rctx, obj, fc.Args["input"].(DuplicateRAGInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*Rag) + fc.Result = res + return ec.marshalNRAG2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐRag(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_RAGMutation_duplicateRAG(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "RAGMutation", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "name": + return ec.fieldContext_RAG_name(ctx, field) + case "namespace": + return ec.fieldContext_RAG_namespace(ctx, field) + case "labels": + return ec.fieldContext_RAG_labels(ctx, field) + case "annotations": + return ec.fieldContext_RAG_annotations(ctx, field) + case "creator": + return ec.fieldContext_RAG_creator(ctx, field) + case "displayName": + return ec.fieldContext_RAG_displayName(ctx, field) + case "description": + return ec.fieldContext_RAG_description(ctx, field) + case "creationTimestamp": + return ec.fieldContext_RAG_creationTimestamp(ctx, field) + case "completeTimestamp": + return ec.fieldContext_RAG_completeTimestamp(ctx, field) + case "application": + return ec.fieldContext_RAG_application(ctx, field) + case "datasets": + return ec.fieldContext_RAG_datasets(ctx, field) + case "judgeLLM": + return ec.fieldContext_RAG_judgeLLM(ctx, field) + case "metrics": + return ec.fieldContext_RAG_metrics(ctx, field) + case "storage": + return ec.fieldContext_RAG_storage(ctx, field) + case "serviceAccountName": + return ec.fieldContext_RAG_serviceAccountName(ctx, field) + case "suspend": + return ec.fieldContext_RAG_suspend(ctx, field) + case "status": + return ec.fieldContext_RAG_status(ctx, field) + case "phase": + return ec.fieldContext_RAG_phase(ctx, field) + case "phaseMessage": + return ec.fieldContext_RAG_phaseMessage(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type RAG", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_RAGMutation_duplicateRAG_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + func (ec *executionContext) _RAGQuery_getRAG(ctx context.Context, field graphql.CollectedField, obj *RAGQuery) (ret graphql.Marshaler) { fc, err := ec.fieldContext_RAGQuery_getRAG(ctx, field) if err != nil { @@ -34120,6 +34254,47 @@ func (ec *executionContext) unmarshalInputDeleteVersionedDatasetInput(ctx contex return it, nil } +func (ec *executionContext) unmarshalInputDuplicateRAGInput(ctx context.Context, obj interface{}) (DuplicateRAGInput, error) { + var it DuplicateRAGInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"name", "namespace", "displayName"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "name": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("name")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Name = data + case "namespace": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("namespace")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Namespace = data + case "displayName": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("displayName")) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + it.DisplayName = data + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputEndpointInput(ctx context.Context, obj interface{}) (EndpointInput, error) { var it EndpointInput asMap := map[string]interface{}{} @@ -41984,6 +42159,42 @@ func (ec *executionContext) _RAGMutation(ctx context.Context, sel ast.SelectionS continue } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "duplicateRAG": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._RAGMutation_duplicateRAG(ctx, field, obj) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) default: panic("unknown field " + strconv.Quote(field.Name)) @@ -43887,6 +44098,11 @@ func (ec *executionContext) unmarshalNDeleteVersionedDatasetInput2githubᚗcom return res, graphql.ErrorOnPath(ctx, err) } +func (ec *executionContext) unmarshalNDuplicateRAGInput2githubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐDuplicateRAGInput(ctx context.Context, v interface{}) (DuplicateRAGInput, error) { + res, err := ec.unmarshalInputDuplicateRAGInput(ctx, v) + return res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalNEmbedder2githubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐEmbedder(ctx context.Context, sel ast.SelectionSet, v Embedder) graphql.Marshaler { return ec._Embedder(ctx, sel, &v) } diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 90a9d56bc..2456a6d37 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -684,6 +684,12 @@ type DeleteVersionedDatasetInput struct { FieldSelector *string `json:"fieldSelector,omitempty"` } +type DuplicateRAGInput struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + DisplayName *string `json:"displayName,omitempty"` +} + type Embedder struct { ID *string `json:"id,omitempty"` Name string `json:"name"` @@ -1367,9 +1373,10 @@ type RAGMetricInput struct { } type RAGMutation struct { - CreateRag Rag `json:"createRAG"` - UpdateRag Rag `json:"updateRAG"` - DeleteRag *string `json:"deleteRAG,omitempty"` + CreateRag Rag `json:"createRAG"` + UpdateRag Rag `json:"updateRAG"` + DeleteRag *string `json:"deleteRAG,omitempty"` + DuplicateRag Rag `json:"duplicateRAG"` } type RAGQuery struct { diff --git a/apiserver/graph/impl/rag.resolvers.go b/apiserver/graph/impl/rag.resolvers.go index eb9629bb2..34d900a43 100644 --- a/apiserver/graph/impl/rag.resolvers.go +++ b/apiserver/graph/impl/rag.resolvers.go @@ -102,6 +102,15 @@ func (r *rAGMutationResolver) DeleteRag(ctx context.Context, obj *generated.RAGM return nil, rag.DeleteRAG(ctx, c, &input) } +// DuplicateRag is the resolver for the duplicateRAG field. +func (r *rAGMutationResolver) DuplicateRag(ctx context.Context, obj *generated.RAGMutation, input generated.DuplicateRAGInput) (*generated.Rag, error) { + c, err := getClientFromCtx(ctx) + if err != nil { + return nil, err + } + return rag.DuplicateRAG(ctx, c, &input) +} + // GetRag is the resolver for the getRAG field. func (r *rAGQueryResolver) GetRag(ctx context.Context, obj *generated.RAGQuery, name string, namespace string) (*generated.Rag, error) { c, err := getClientFromCtx(ctx) diff --git a/apiserver/graph/schema/rag.graphqls b/apiserver/graph/schema/rag.graphqls index db13906de..de81e1bcf 100644 --- a/apiserver/graph/schema/rag.graphqls +++ b/apiserver/graph/schema/rag.graphqls @@ -152,6 +152,12 @@ input UpdateRAGInput { suspend: Boolean } +input DuplicateRAGInput { + name: String! + namespace: String! + displayName: String +} + input DeleteRAGInput { name: String! namespace: String! @@ -185,6 +191,7 @@ type RAGMutation { createRAG(input: CreateRAGInput!): RAG! updateRAG(input: UpdateRAGInput!): RAG! deleteRAG(input: DeleteRAGInput!): Void + duplicateRAG(input: DuplicateRAGInput!): RAG! } type RAGQuery { diff --git a/apiserver/pkg/rag/rag.go b/apiserver/pkg/rag/rag.go index 7b8f6c4de..c70fcb011 100644 --- a/apiserver/pkg/rag/rag.go +++ b/apiserver/pkg/rag/rag.go @@ -628,3 +628,29 @@ func DeleteRAG(ctx context.Context, kubeClient client.Client, input *generated.D } return kubeClient.DeleteAllOf(ctx, &evav1alpha1.RAG{}, opts...) } + +func DuplicateRAG(ctx context.Context, kubeClient client.Client, input *generated.DuplicateRAGInput) (*generated.Rag, error) { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + rag := &evav1alpha1.RAG{} + err := kubeClient.Get(ctx, types.NamespacedName{Namespace: input.Namespace, Name: input.Name}, rag) + if err != nil { + return nil, err + } + rag.Name = generateKubernetesResourceName("rag", 10) + if input.DisplayName != nil { + rag.Spec.DisplayName = *input.DisplayName + } + labels := rag.Labels + if labels == nil { + labels = make(map[string]string) + } + labels["duplication"] = fmt.Sprintf("%s_%s", input.Namespace, input.Name) + rag.SetLabels(labels) + rag.ResourceVersion = "" + rag.Spec.Creator = currentUser + err = kubeClient.Create(ctx, rag) + if err != nil { + return nil, err + } + return rag2model(rag) +} diff --git a/gqlgen.yaml b/gqlgen.yaml index 5341e54d8..5a9c0a98d 100644 --- a/gqlgen.yaml +++ b/gqlgen.yaml @@ -284,6 +284,8 @@ models: resolver: true deleteRAG: resolver: true + duplicateRAG: + resolver: true RAGQuery: fields: getRAG: