Skip to content

Commit

Permalink
feat: complete the fields of the CR Model
Browse files Browse the repository at this point in the history
  • Loading branch information
dayuy authored and bjwswang committed Mar 19, 2024
1 parent a8203d2 commit 21c524c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
16 changes: 16 additions & 0 deletions apiserver/graph/schema/model.gql
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ query listModels($input: ListModelInput!,$filesInput: FileFilter){
message
types
updateTimestamp
huggingFaceRepo
modelScopeRepo
revision
modelSource
files(input: $filesInput) {
totalCount
hasNextPage
Expand Down Expand Up @@ -59,6 +63,10 @@ query getModel($name: String!, $namespace: String!,$filesInput: FileFilter) {
message
types
updateTimestamp
huggingFaceRepo
modelScopeRepo
revision
modelSource
files(input: $filesInput) {
totalCount
hasNextPage
Expand Down Expand Up @@ -95,6 +103,10 @@ mutation createModel($input: CreateModelInput!) {
message
types
updateTimestamp
huggingFaceRepo
modelScopeRepo
revision
modelSource
}
}
}
Expand All @@ -117,6 +129,10 @@ mutation updateModel($input: UpdateModelInput) {
message
types
updateTimestamp
huggingFaceRepo
modelScopeRepo
revision
modelSource
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions apiserver/pkg/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ var (
ModelTypeEmbedding = "embedding"
)

// ModelSource
var (
ModelSourceLocal = "local"
ModelSourceModelscope = "modelscope"
ModelSourceHuggingface = "huggingface"
)

func SystemDatasourceOSS(ctx context.Context, mgrClient client.Client) (*datasource.OSS, error) {
systemDatasource, err := config.GetSystemDatasource(ctx, mgrClient)
if err != nil {
Expand Down
27 changes: 27 additions & 0 deletions apiserver/pkg/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func obj2model(model *v1alpha1.Model) (*generated.Model, error) {
UpdateTimestamp: &updateTime,
Status: &status,
Message: &message,
HuggingFaceRepo: &model.Spec.HuggingFaceRepo,
ModelScopeRepo: &model.Spec.ModelScopeRepo,
Revision: &model.Spec.Revision,
ModelSource: &model.Spec.ModelSource,
}
return &md, nil
}
Expand All @@ -91,6 +95,18 @@ func CreateModel(ctx context.Context, c client.Client, input generated.CreateMod
Types: input.Types,
},
}
if *input.ModelSource == common.ModelSourceModelscope {
if *input.Revision == "" {
return nil, errors.New("argument revision is required")
}
model.Spec.ModelScopeRepo = *input.ModelScopeRepo
model.Spec.Revision = *input.Revision
}
if *input.ModelSource == common.ModelSourceHuggingface {
model.Spec.HuggingFaceRepo = *input.HuggingFaceRepo
model.Spec.Revision = *input.Revision
}
model.Spec.ModelSource = *input.ModelSource
model.Spec.DisplayName = pointer.StringDeref(input.DisplayName, model.Spec.DisplayName)
model.Spec.Description = pointer.StringDeref(input.Description, model.Spec.Description)
common.SetCreator(ctx, &model.Spec.CommonSpec)
Expand All @@ -114,6 +130,17 @@ func UpdateModel(ctx context.Context, c client.Client, input *generated.UpdateMo
model.Spec.Description = pointer.StringDeref(input.Description, model.Spec.Description)
model.Spec.Types = pointer.StringDeref(input.Types, model.Spec.Types)

if model.Spec.ModelSource == common.ModelSourceModelscope {
if *input.Revision == "" {
return nil, errors.New("argument revision is required")
}
model.Spec.ModelScopeRepo = *input.ModelScopeRepo
model.Spec.Revision = *input.Revision
}
if model.Spec.ModelSource == common.ModelSourceHuggingface {
model.Spec.HuggingFaceRepo = *input.HuggingFaceRepo
model.Spec.Revision = *input.Revision
}
err = c.Update(ctx, model)
if err != nil {
return nil, err
Expand Down

0 comments on commit 21c524c

Please sign in to comment.