diff --git a/internal/engine/handler.go b/internal/engine/handler.go index 27bb5a968d..3e699e388e 100644 --- a/internal/engine/handler.go +++ b/internal/engine/handler.go @@ -17,12 +17,14 @@ package engine import ( "context" "fmt" + "slices" "sync" "time" "github.com/ThreeDotsLabs/watermill/message" "github.com/rs/zerolog" + "github.com/stacklok/minder/internal/engine/engcontext" "github.com/stacklok/minder/internal/engine/entities" "github.com/stacklok/minder/internal/events" minderlogger "github.com/stacklok/minder/internal/logger" @@ -44,10 +46,12 @@ type ExecutorEventHandler struct { evt events.Publisher handlerMiddleware []message.HandlerMiddleware wgEntityEventExecution *sync.WaitGroup - // terminationcontext is used to terminate the executor - // when the server is shutting down. - terminationcontext context.Context - executor Executor + executor Executor + // cancels are a set of cancel functions for current entity events in flight. + // This allows us to cancel rule evaluation directly when the context passed + // to NewExecutorEventHandler is cancelled. + cancels []*context.CancelFunc + lock sync.Mutex } // NewExecutorEventHandler creates the event handler for the executor @@ -57,13 +61,23 @@ func NewExecutorEventHandler( handlerMiddleware []message.HandlerMiddleware, executor Executor, ) *ExecutorEventHandler { - return &ExecutorEventHandler{ + eh := &ExecutorEventHandler{ evt: evt, wgEntityEventExecution: &sync.WaitGroup{}, - terminationcontext: ctx, handlerMiddleware: handlerMiddleware, executor: executor, } + go func() { + <-ctx.Done() + eh.lock.Lock() + defer eh.lock.Unlock() + + for _, cancel := range eh.cancels { + (*cancel)() + } + }() + + return eh } // Register implements the Consumer interface. @@ -79,9 +93,23 @@ func (e *ExecutorEventHandler) Wait() { // HandleEntityEvent handles events coming from webhooks/signals // as well as the init event. 
func (e *ExecutorEventHandler) HandleEntityEvent(msg *message.Message) error { - // Grab the context before making a copy of the message - msgCtx := msg.Context() - // Let's not share memory with the caller + + // NOTE: we're _deliberately_ "escaping" from the parent context's Cancel/Done + // completion, because the default watermill behavior for both Go channels and + // SQL is to process messages sequentially, but we need additional parallelism + // beyond that. When we switch to a different message processing system, we + // should aim to remove this goroutine altogether and have the messaging system + // provide the parallelism. + // We _do_ still want to cancel on shutdown, however. + // TODO: Make the DefaultExecutionTimeout applied below configurable + msgCtx := context.WithoutCancel(msg.Context()) + msgCtx, shutdownCancel := context.WithCancel(msgCtx) + + e.lock.Lock() + e.cancels = append(e.cancels, &shutdownCancel) + e.lock.Unlock() + + // Let's not share memory with the caller. Note that this does not copy Context msg = msg.Copy() inf, err := entities.ParseEntityEvent(msg) @@ -95,11 +123,23 @@ func (e *ExecutorEventHandler) HandleEntityEvent(msg *message.Message) error { if inf.Type == pb.Entity_ENTITY_ARTIFACTS { time.Sleep(ArtifactSignatureWaitPeriod) } - // TODO: Make this timeout configurable - ctx, cancel := context.WithTimeout(e.terminationcontext, DefaultExecutionTimeout) - defer cancel() - ts := minderlogger.BusinessRecord(msgCtx) + ctx, cancel := context.WithTimeout(msgCtx, DefaultExecutionTimeout) + defer cancel() + defer func() { + e.lock.Lock() + e.cancels = slices.DeleteFunc(e.cancels, func(cf *context.CancelFunc) bool { + return cf == &shutdownCancel + }) + e.lock.Unlock() + }() + + ctx = engcontext.WithEntityContext(ctx, &engcontext.EntityContext{ + Project: engcontext.Project{ID: inf.ProjectID}, + // TODO: extract Provider name from ProviderID? 
+ }) + + ts := minderlogger.BusinessRecord(ctx) ctx = ts.WithTelemetry(ctx) logger := zerolog.Ctx(ctx) @@ -116,14 +156,14 @@ func (e *ExecutorEventHandler) HandleEntityEvent(msg *message.Message) error { // here even though we also record it in the middleware because the evaluation // is done in a separate goroutine which usually still runs after the middleware // had already recorded the telemetry. - logMsg := zerolog.Ctx(ctx).Info() + logMsg := logger.Info() if err != nil { - logMsg = zerolog.Ctx(ctx).Error() + logMsg = logger.Error() } ts.Record(logMsg).Send() if err != nil { - zerolog.Ctx(ctx).Info(). + logger.Info(). Str("project", inf.ProjectID.String()). Str("provider_id", inf.ProviderID.String()). Str("entity", inf.Type.String()). diff --git a/internal/engine/handler_test.go b/internal/engine/handler_test.go index 7ea02db5ad..cc267987a2 100644 --- a/internal/engine/handler_test.go +++ b/internal/engine/handler_test.go @@ -45,6 +45,8 @@ func TestExecutorEventHandler_handleEntityEvent(t *testing.T) { repositoryID := uuid.New() executionID := uuid.New() + parallelOps := 2 + // -- end expectations evt, err := events.Setup(context.Background(), &serverconfig.EventConfig{ @@ -80,9 +82,11 @@ func TestExecutorEventHandler_handleEntityEvent(t *testing.T) { WithExecutionID(executionID) executor := mockengine.NewMockExecutor(ctrl) - executor.EXPECT(). - EvalEntityEvent(gomock.Any(), gomock.Eq(eiw)). - Return(nil) + for i := 0; i < parallelOps; i++ { + executor.EXPECT(). + EvalEntityEvent(gomock.Any(), gomock.Eq(eiw)). 
+ Return(nil) + } handler := engine.NewExecutorEventHandler( ctx, @@ -97,19 +101,23 @@ func TestExecutorEventHandler_handleEntityEvent(t *testing.T) { msg, err := eiw.BuildMessage() require.NoError(t, err, "expected no error") - // Run in the background - go func() { - t.Log("Running entity event handler") - require.NoError(t, handler.HandleEntityEvent(msg), "expected no error") - }() + // Run in the background, twice + for i := 0; i < parallelOps; i++ { + go func() { + t.Log("Running entity event handler") + require.NoError(t, handler.HandleEntityEvent(msg), "expected no error") + }() + } // expect flush - t.Log("waiting for flush") - result := <-queued - require.NotNil(t, result) - require.Equal(t, providerID.String(), msg.Metadata.Get(entities.ProviderIDEventKey)) - require.Equal(t, "repository", msg.Metadata.Get(entities.EntityTypeEventKey)) - require.Equal(t, projectID.String(), msg.Metadata.Get(entities.ProjectIDEventKey)) + for i := 0; i < parallelOps; i++ { + t.Log("waiting for flush") + result := <-queued + require.NotNil(t, result) + require.Equal(t, providerID.String(), msg.Metadata.Get(entities.ProviderIDEventKey)) + require.Equal(t, "repository", msg.Metadata.Get(entities.EntityTypeEventKey)) + require.Equal(t, projectID.String(), msg.Metadata.Get(entities.ProjectIDEventKey)) + } require.NoError(t, evt.Close(), "expected no error") diff --git a/internal/flags/flags.go b/internal/flags/flags.go index 38d609b08d..7067a0acfc 100644 --- a/internal/flags/flags.go +++ b/internal/flags/flags.go @@ -39,10 +39,12 @@ func fromContext(ctx context.Context) openfeature.EvaluationContext { // Note: engine.EntityFromContext is best-effort, so these values may be zero. ec := engcontext.EntityFromContext(ctx) return openfeature.NewEvaluationContext( - jwt.GetUserSubjectFromContext(ctx), + ec.Project.ID.String(), map[string]interface{}{ - "project": ec.Project.ID.String(), + "project": ec.Project.ID.String(), + // TODO: is this useful, given how provider names are used? 
"provider": ec.Provider.Name, + "user": jwt.GetUserSubjectFromContext(ctx), }, ) }