Skip to content

Commit

Permalink
Merge pull request #817 from 0xff-dev/main
Browse files Browse the repository at this point in the history
chore: kubeagi-runner supports automatically pulling model files
  • Loading branch information
bjwswang authored Mar 8, 2024
2 parents 7a24518 + e47194d commit fd60d86
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
25 changes: 23 additions & 2 deletions pkg/worker/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,16 @@ var _ ModelRunner = (*KubeAGIRunner)(nil)
// KubeAGIRunner is the ModelRunner implementation used for KubeAGI-type
// workers. It builds the "runner" container that serves the reranking model.
type KubeAGIRunner struct {
// c is the client used to read cluster objects; Build uses it to fetch the
// Model resource when modelFileFromRemote is set.
c client.Client
// w is the Worker this runner was created for.
w *arcadiav1alpha1.Worker

// modelFileFromRemote, when true, makes Build look up the Model and, if its
// Spec.HuggingFaceRepo is non-empty, use that value as RERANKING_MODEL_PATH
// instead of the local path under /data/models.
modelFileFromRemote bool
}

func NewKubeAGIRunner(c client.Client, w *arcadiav1alpha1.Worker) (ModelRunner, error) {
// NewKubeAGIRunner returns a KubeAGIRunner holding the given client and worker.
//
// modelFileFromRemote is stored on the runner unchanged; when true, Build
// resolves the reranking model path from the Model's HuggingFace repo rather
// than from the mounted model directory. The returned error is always nil.
func NewKubeAGIRunner(c client.Client, w *arcadiav1alpha1.Worker, modelFileFromRemote bool) (ModelRunner, error) {
return &KubeAGIRunner{
c: c,
w: w,

modelFileFromRemote: modelFileFromRemote,
}, nil
}

Expand All @@ -300,6 +304,23 @@ func (runner *KubeAGIRunner) Build(ctx context.Context, model *arcadiav1alpha1.T

// read worker address
mountPath := "/data/models"
rerankModelPath := fmt.Sprintf("%s/%s", mountPath, model.Name)

if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
return nil, err
}
if m.Spec.HuggingFaceRepo != "" {
rerankModelPath = m.Spec.HuggingFaceRepo
}
/*
TODO: support ModelScope
if m.Spec.ModelScopeRepo != "" {
rerankModelPath = m.Spec.ModelScopeRepo
}
*/
}
container := &corev1.Container{
Name: "runner",
Image: img,
Expand All @@ -309,7 +330,7 @@ func (runner *KubeAGIRunner) Build(ctx context.Context, model *arcadiav1alpha1.T
},
Env: []corev1.EnvVar{
// Only reranking supported for now
{Name: "RERANKING_MODEL_PATH", Value: fmt.Sprintf("%s/%s", mountPath, model.Name)},
{Name: "RERANKING_MODEL_PATH", Value: rerankModelPath},
},
Ports: []corev1.ContainerPort{
{Name: "http", ContainerPort: 21002},
Expand Down
2 changes: 1 addition & 1 deletion pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func (podWorker *PodWorker) Start(ctx context.Context) error {
}
podWorker.r = r
case arcadiav1alpha1.WorkerTypeKubeAGI:
r, err := NewKubeAGIRunner(podWorker.c, podWorker.w.DeepCopy())
r, err := NewKubeAGIRunner(podWorker.c, podWorker.w.DeepCopy(), loader == nil)
if err != nil {
return fmt.Errorf("failed to new a runner with %w", err)
}
Expand Down

0 comments on commit fd60d86

Please sign in to comment.