diff --git a/apiserver/graph/schema/model.gql b/apiserver/graph/schema/model.gql index 9c2a7c1e7..63497516f 100644 --- a/apiserver/graph/schema/model.gql +++ b/apiserver/graph/schema/model.gql @@ -21,6 +21,10 @@ query listModels($input: ListModelInput!,$filesInput: FileFilter){ message types updateTimestamp + huggingFaceRepo + modelScopeRepo + revision + modelSource files(input: $filesInput) { totalCount hasNextPage @@ -59,6 +63,10 @@ query getModel($name: String!, $namespace: String!,$filesInput: FileFilter) { message types updateTimestamp + huggingFaceRepo + modelScopeRepo + revision + modelSource files(input: $filesInput) { totalCount hasNextPage @@ -95,6 +103,10 @@ mutation createModel($input: CreateModelInput!) { message types updateTimestamp + huggingFaceRepo + modelScopeRepo + revision + modelSource } } } @@ -117,6 +129,10 @@ mutation updateModel($input: UpdateModelInput) { message types updateTimestamp + huggingFaceRepo + modelScopeRepo + revision + modelSource } } } diff --git a/apiserver/pkg/common/common.go b/apiserver/pkg/common/common.go index bd446b5f2..eee760c36 100644 --- a/apiserver/pkg/common/common.go +++ b/apiserver/pkg/common/common.go @@ -52,6 +52,13 @@ var ( ModelTypeEmbedding = "embedding" ) +// ModelSource +var ( + ModelSourceLocal = "local" + ModelSourceModelscope = "modelscope" + ModelSourceHuggingface = "huggingface" +) + func SystemDatasourceOSS(ctx context.Context, mgrClient client.Client) (*datasource.OSS, error) { systemDatasource, err := config.GetSystemDatasource(ctx, mgrClient) if err != nil { diff --git a/apiserver/pkg/model/model.go b/apiserver/pkg/model/model.go index edba6fe34..71ca7b6ee 100644 --- a/apiserver/pkg/model/model.go +++ b/apiserver/pkg/model/model.go @@ -77,6 +77,10 @@ func obj2model(model *v1alpha1.Model) (*generated.Model, error) { UpdateTimestamp: &updateTime, Status: &status, Message: &message, + HuggingFaceRepo: &model.Spec.HuggingFaceRepo, + ModelScopeRepo: &model.Spec.ModelScopeRepo, + Revision: &model.Spec.Revision, + ModelSource: &model.Spec.ModelSource, } return &md, nil } @@ -91,6 +95,18 @@ func CreateModel(ctx context.Context, c client.Client, input generated.CreateMod Types: input.Types, }, } + if *input.ModelSource == common.ModelSourceModelscope { + if *input.Revision == "" { + return nil, errors.New("argument revision is required") + } + model.Spec.ModelScopeRepo = *input.ModelScopeRepo + model.Spec.Revision = *input.Revision + } + if *input.ModelSource == common.ModelSourceHuggingface { + model.Spec.HuggingFaceRepo = *input.HuggingFaceRepo + model.Spec.Revision = *input.Revision + } + model.Spec.ModelSource = *input.ModelSource model.Spec.DisplayName = pointer.StringDeref(input.DisplayName, model.Spec.DisplayName) model.Spec.Description = pointer.StringDeref(input.Description, model.Spec.Description) common.SetCreator(ctx, &model.Spec.CommonSpec) @@ -114,6 +130,17 @@ func UpdateModel(ctx context.Context, c client.Client, input *generated.UpdateMo model.Spec.Description = pointer.StringDeref(input.Description, model.Spec.Description) model.Spec.Types = pointer.StringDeref(input.Types, model.Spec.Types) + if model.Spec.ModelSource == common.ModelSourceModelscope { + if *input.Revision == "" { + return nil, errors.New("argument revision is required") + } + model.Spec.ModelScopeRepo = *input.ModelScopeRepo + model.Spec.Revision = *input.Revision + } + if model.Spec.ModelSource == common.ModelSourceHuggingface { + model.Spec.HuggingFaceRepo = *input.HuggingFaceRepo + model.Spec.Revision = *input.Revision + } err = c.Update(ctx, model) if err != nil { return nil, err