diff --git a/pkg/engine/deferrer.go b/pkg/engine/deferrer.go new file mode 100644 index 0000000..8b64c5b --- /dev/null +++ b/pkg/engine/deferrer.go @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package engine + +import ( + "context" + "sync" + "time" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/util/workqueue" + log "k8s.io/klog/v2" + + "github.com/NVIDIA/knavigator/pkg/config" +) + +type executor interface { + RunTask(context.Context, *config.Task) error +} + +type Deferrer struct { + executor executor + queue workqueue.DelayingInterface + client kubernetes.Interface + wg sync.WaitGroup +} + +func NewDereffer(client kubernetes.Interface, executor executor) *Deferrer { + return &Deferrer{ + executor: executor, + queue: workqueue.NewDelayingQueue(), + client: client, + } +} + +func (d *Deferrer) ScheduleTermination(taskID string) { + d.wg.Add(1) + d.queue.Add(taskID) +} + +func (d *Deferrer) Start(ctx context.Context) { + go d.start(ctx) +} + +func (d *Deferrer) start(ctx context.Context) { + for { + // Get an item from the queue + obj, shutdown := d.queue.Get() + if shutdown { + break + } + + switch v := obj.(type) { + case string: + log.Info("Wait for running pods", "taskID", v) + err := d.executor.RunTask(ctx, &config.Task{ + ID: "status", + Type: TaskCheckPod, + Params: map[string]interface{}{ + "refTaskId": v, + "status": "Running", + "timeout": "24h", + }, + }) + if err != nil { + log.Error(err, "Failed to watch pods") + d.wg.Done() + } else { + log.Info("AddTask", "type", TaskDeleteObj) + d.queue.AddAfter(&config.Task{ + ID: "delete", + Type: TaskDeleteObj, + Params: map[string]interface{}{"refTaskId": v}, + }, 5*time.Second) + } + + case *config.Task: + log.Info("Deferrer initiates task", "type", v.Type, "ID", v.ID) + + err := d.executor.RunTask(ctx, v) + if err != nil { + log.Error(err, "failed to execute task", "type", v.Type, "ID", v.ID) + } + d.wg.Done() + } + + // Mark the item as done + d.queue.Done(obj) + } +} + +func (d *Deferrer) Wait(ctx context.Context, timeout time.Duration) error { + log.Info("Waiting for deferrer to complete task") + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + done := make(chan struct{}) + + go func() { + d.wg.Wait() + done <- struct{}{} + }() + + select { + case <-done: + d.queue.ShutDown() + log.Info("Deferrer stopped") + return nil + case <-ctx.Done(): + log.Info("Deferrer didn't stop in allocated time") + return ctx.Err() + } +} diff --git a/pkg/engine/deferrer_test.go b/pkg/engine/deferrer_test.go new file mode 100644 index 0000000..12180ea --- /dev/null +++ b/pkg/engine/deferrer_test.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package engine + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/knavigator/pkg/config" +) + +type testExecutor struct { + tasks []string +} + +func (exec *testExecutor) RunTask(_ context.Context, cfg *config.Task) error { + exec.tasks = append(exec.tasks, cfg.ID) + return nil +} + +func TestDeferrer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + exec := &testExecutor{tasks: []string{}} + deferrer := NewDereffer(testLogger, exec) + deferrer.Start(ctx) + + deferrer.Inc(6) + deferrer.AddTask(&config.Task{ID: "t3"}, 3*time.Second) + deferrer.AddTask(&config.Task{ID: "t1"}, 1*time.Second) + deferrer.AddTask(&config.Task{ID: "t5"}, 5*time.Second) + deferrer.AddTask(&config.Task{ID: "t4"}, 4*time.Second) + deferrer.AddTask(&config.Task{ID: "t2"}, 2*time.Second) + deferrer.AddTask(&config.Task{ID: "t6"}, 6*time.Second) + + err := deferrer.Wait(ctx, 8*time.Second) + require.NoError(t, err) + require.Equal(t, []string{"t1", "t2", "t3", "t4", "t5", "t6"}, exec.tasks) +} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 926531e..f79e17e 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -45,6 +45,7 @@ type Eng struct { discoveryClient *discovery.DiscoveryClient objTypeMap map[string]*RegisterObjParams objInfoMap map[string]*ObjInfo + deferrer *Deferrer cleanup *CleanupInfo } @@ -72,6 +73,9 @@ func New(config *rest.Config, cleanupInfo *CleanupInfo, sim ...bool) (*Eng, erro eng.discoveryClient = &discovery.DiscoveryClient{} } + eng.deferrer = NewDereffer(eng.k8sClient, eng) + eng.deferrer.Start(context.TODO()) + return eng, nil } @@ -115,7 +119,7 @@ func (eng *Eng) GetTask(cfg *config.Task) (Runnable, error) { return newConfigureTask(eng.k8sClient, cfg) case TaskSubmitObj: - task, err := newSubmitObjTask(eng.dynamicClient, eng, cfg) + task, err := newSubmitObjTask(eng.dynamicClient, eng, eng.deferrer, cfg) if err != nil { return nil, err } @@ -173,6 +177,9 @@ func (eng *Eng) GetTask(cfg *config.Task) (Runnable, error) { case TaskPause: return newPauseTask(cfg), nil + case TaskWait: + return newWaitTask(eng.deferrer, cfg), nil + default: return nil, fmt.Errorf("unsupported task type %q", cfg.Type) } diff --git a/pkg/engine/submit_object_task.go b/pkg/engine/submit_object_task.go index c5dfc81..49a012e 100644 --- a/pkg/engine/submit_object_task.go +++ b/pkg/engine/submit_object_task.go @@ -37,6 +37,7 @@ type SubmitObjTask struct { submitObjTaskParams client *dynamic.DynamicClient accessor ObjInfoAccessor + deferrer *Deferrer } type submitObjTaskParams struct { @@ -62,7 +63,7 @@ type GenericObject struct { } // newSubmitObjTask initializes and returns SubmitObjTask -func newSubmitObjTask(client *dynamic.DynamicClient, accessor ObjInfoAccessor, cfg *config.Task) (*SubmitObjTask, error) { +func newSubmitObjTask(client *dynamic.DynamicClient, accessor ObjInfoAccessor, deferrer *Deferrer, cfg *config.Task) (*SubmitObjTask, error) { if client == nil { return nil, fmt.Errorf("%s/%s: DynamicClient is not set", cfg.Type, cfg.ID) } @@ -74,6 +75,7 @@ func newSubmitObjTask(client *dynamic.DynamicClient, accessor ObjInfoAccessor, c }, client: client, accessor: accessor, + deferrer: deferrer, } if err := task.validate(cfg.Params); err != nil { @@ -137,8 +139,9 @@ func (task *SubmitObjTask) Exec(ctx context.Context) error { } } - return task.accessor.SetObjInfo(task.taskID, - NewObjInfo(names, objs[0].Metadata.Namespace, regObjParams.gvr, podCount, podRegexp...)) + info := NewObjInfo(names, objs[0].Metadata.Namespace, regObjParams.gvr, podCount, podRegexp...) + //task.deferrer.ScheduleTermination(task.taskID) + return task.accessor.SetObjInfo(task.taskID, info) } func (task *SubmitObjTask) getGenericObjects(regObjParams *RegisterObjParams) ([]GenericObject, []string, int, []string, error) { diff --git a/pkg/engine/types.go b/pkg/engine/types.go index bcce513..be57ba1 100644 --- a/pkg/engine/types.go +++ b/pkg/engine/types.go @@ -36,6 +36,7 @@ const ( TaskUpdateNodes = "UpdateNodes" TaskSleep = "Sleep" TaskPause = "Pause" + TaskWait = "Wait" OpCreate = "create" OpDelete = "delete" @@ -97,6 +98,7 @@ type RegisterObjParams struct { // ObjInfo contains object GVR and an optional list of derived pod names type ObjInfo struct { + TaskID string Names []string Namespace string GVR schema.GroupVersionResource diff --git a/pkg/engine/wait_task.go b/pkg/engine/wait_task.go new file mode 100644 index 0000000..c6c4f76 --- /dev/null +++ b/pkg/engine/wait_task.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package engine + +import ( + "context" + "time" + + "github.com/NVIDIA/knavigator/pkg/config" +) + +type WaitTask struct { + BaseTask + + deferrer *Deferrer +} + +func newWaitTask(deferrer *Deferrer, cfg *config.Task) *WaitTask { + return &WaitTask{ + BaseTask: BaseTask{ + taskType: TaskWait, + taskID: cfg.ID, + }, + deferrer: deferrer, + } +} + +// Exec implements Runnable interface +func (task *WaitTask) Exec(ctx context.Context) error { + return task.deferrer.Wait(ctx, time.Minute) +} diff --git a/pkg/utils/informers.go b/pkg/utils/informers.go new file mode 100644 index 0000000..51e66c9 --- /dev/null +++ b/pkg/utils/informers.go @@ -0,0 +1,36 @@ +package utils + +import ( + "sync" + + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/dynamic/dynamicinformer" + "k8s.io/client-go/tools/cache" +) + +type informerManager struct { + mutex sync.Mutex + factories map[string]dynamicinformer.DynamicSharedInformerFactory +} + +var informerMgr *informerManager + +func init() { + informerMgr = &informerManager{ + factories: make(map[string]dynamicinformer.DynamicSharedInformerFactory), + } +} + +func GetInformer(client dynamic.Interface, namespace string, gvr schema.GroupVersionResource) cache.SharedInformer { + informerMgr.mutex.Lock() + defer informerMgr.mutex.Unlock() + + factory, ok := informerMgr.factories[namespace] + if !ok { + factory = dynamicinformer.NewFilteredDynamicSharedInformerFactory(client, 0, namespace, nil) + informerMgr.factories[namespace] = factory + } + + return factory.ForResource(gvr).Informer() +}