diff --git a/cmd/server/app/serve.go b/cmd/server/app/serve.go index 6cee7d9c48..72a57b98df 100644 --- a/cmd/server/app/serve.go +++ b/cmd/server/app/serve.go @@ -31,6 +31,7 @@ import ( "github.com/stacklok/minder/internal/config" "github.com/stacklok/minder/internal/controlplane" "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/eea" "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/events" "github.com/stacklok/minder/internal/logger" @@ -109,7 +110,13 @@ var serveCmd = &cobra.Command{ return fmt.Errorf("unable to create server: %w", err) } - exec, err := engine.NewExecutor(store, &cfg.Auth, engine.WithProviderMetrics(providerMetrics)) + aggr := eea.NewEEA(store, evt, &cfg.Events.Aggregator) + + s.ConsumeEvents(aggr) + + exec, err := engine.NewExecutor(ctx, store, &cfg.Auth, evt, + engine.WithProviderMetrics(providerMetrics), + engine.WithAggregatorMiddleware(aggr)) if err != nil { return fmt.Errorf("unable to create executor: %w", err) } @@ -134,6 +141,13 @@ var serveCmd = &cobra.Command{ errg.Go(s.HandleEvents(ctx)) + // Wait for event handlers to start running + <-evt.Running() + + if err := aggr.FlushAll(ctx); err != nil { + return fmt.Errorf("error flushing cache: %w", err) + } + return errg.Wait() }, } diff --git a/database/migrations/000010_entity_execution_lock.down.sql b/database/migrations/000010_entity_execution_lock.down.sql new file mode 100644 index 0000000000..91e6e0d790 --- /dev/null +++ b/database/migrations/000010_entity_execution_lock.down.sql @@ -0,0 +1,19 @@ +-- Copyright 2023 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +DROP INDEX IF EXISTS flush_cache_idx; +DROP INDEX IF EXISTS entity_execution_lock_idx; + +DROP TABLE IF EXISTS flush_cache; +DROP TABLE IF EXISTS entity_execution_lock; \ No newline at end of file diff --git a/database/migrations/000010_entity_execution_lock.up.sql b/database/migrations/000010_entity_execution_lock.up.sql new file mode 100644 index 0000000000..37c5da5fa3 --- /dev/null +++ b/database/migrations/000010_entity_execution_lock.up.sql @@ -0,0 +1,51 @@ +-- Copyright 2023 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + + +--- This implements two tables: +--- * The entity execution lock table, which is used to prevent multiple +--- instances of the same entity from running at the same time. +--- * The flush cache table, which is used to cache entities to be executed +--- once the lock is released. 
+ +CREATE TABLE IF NOT EXISTS entity_execution_lock ( + id UUID NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY, + entity entities NOT NULL, + locked_by UUID NOT NULL, + last_lock_time TIMESTAMP NOT NULL, + repository_id UUID NOT NULL REFERENCES repositories(id) ON DELETE CASCADE, + artifact_id UUID REFERENCES artifacts(id) ON DELETE CASCADE, + pull_request_id UUID REFERENCES pull_requests(id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS entity_execution_lock_idx ON entity_execution_lock( + entity, + repository_id, + COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID), + COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID)); + +CREATE TABLE IF NOT EXISTS flush_cache ( + id UUID NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY, + entity entities NOT NULL, + repository_id UUID NOT NULL REFERENCES repositories(id) ON DELETE CASCADE, + artifact_id UUID REFERENCES artifacts(id) ON DELETE CASCADE, + pull_request_id UUID REFERENCES pull_requests(id) ON DELETE CASCADE, + queued_at TIMESTAMP NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS flush_cache_idx ON flush_cache( + entity, + repository_id, + COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID), + COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID)); \ No newline at end of file diff --git a/database/mock/store.go b/database/mock/store.go index e3d30ed115..10116153b0 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -662,6 +662,36 @@ func (mr *MockStoreMockRecorder) DeleteUser(arg0, arg1 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockStore)(nil).DeleteUser), arg0, arg1) } +// EnqueueFlush mocks base method. +func (m *MockStore) EnqueueFlush(arg0 context.Context, arg1 db.EnqueueFlushParams) (db.FlushCache, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnqueueFlush", arg0, arg1) + ret0, _ := ret[0].(db.FlushCache) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EnqueueFlush indicates an expected call of EnqueueFlush. +func (mr *MockStoreMockRecorder) EnqueueFlush(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnqueueFlush", reflect.TypeOf((*MockStore)(nil).EnqueueFlush), arg0, arg1) +} + +// FlushCache mocks base method. +func (m *MockStore) FlushCache(arg0 context.Context, arg1 db.FlushCacheParams) (db.FlushCache, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FlushCache", arg0, arg1) + ret0, _ := ret[0].(db.FlushCache) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FlushCache indicates an expected call of FlushCache. +func (mr *MockStoreMockRecorder) FlushCache(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushCache", reflect.TypeOf((*MockStore)(nil).FlushCache), arg0, arg1) +} + // GetAccessTokenByProjectID mocks base method. func (m *MockStore) GetAccessTokenByProjectID(arg0 context.Context, arg1 db.GetAccessTokenByProjectIDParams) (db.ProviderAccessToken, error) { m.ctrl.T.Helper() @@ -1082,6 +1112,21 @@ func (mr *MockStoreMockRecorder) GetPullRequest(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPullRequest", reflect.TypeOf((*MockStore)(nil).GetPullRequest), arg0, arg1) } +// GetPullRequestByID mocks base method. 
+func (m *MockStore) GetPullRequestByID(arg0 context.Context, arg1 uuid.UUID) (db.PullRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPullRequestByID", arg0, arg1) + ret0, _ := ret[0].(db.PullRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPullRequestByID indicates an expected call of GetPullRequestByID. +func (mr *MockStoreMockRecorder) GetPullRequestByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPullRequestByID", reflect.TypeOf((*MockStore)(nil).GetPullRequestByID), arg0, arg1) +} + // GetQuerierWithTransaction mocks base method. func (m *MockStore) GetQuerierWithTransaction(arg0 *sql.Tx) db.ExtendQuerier { m.ctrl.T.Helper() @@ -1441,6 +1486,21 @@ func (mr *MockStoreMockRecorder) ListArtifactsByRepoID(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListArtifactsByRepoID", reflect.TypeOf((*MockStore)(nil).ListArtifactsByRepoID), arg0, arg1) } +// ListFlushCache mocks base method. +func (m *MockStore) ListFlushCache(arg0 context.Context) ([]db.FlushCache, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListFlushCache", arg0) + ret0, _ := ret[0].([]db.FlushCache) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListFlushCache indicates an expected call of ListFlushCache. +func (mr *MockStoreMockRecorder) ListFlushCache(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListFlushCache", reflect.TypeOf((*MockStore)(nil).ListFlushCache), arg0) +} + // ListOrganizations mocks base method. func (m *MockStore) ListOrganizations(arg0 context.Context, arg1 db.ListOrganizationsParams) ([]db.Project, error) { m.ctrl.T.Helper() @@ -1666,6 +1726,35 @@ func (mr *MockStoreMockRecorder) ListUsersByRoleId(arg0, arg1 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUsersByRoleId", reflect.TypeOf((*MockStore)(nil).ListUsersByRoleId), arg0, arg1) } +// LockIfThresholdNotExceeded mocks base method. +func (m *MockStore) LockIfThresholdNotExceeded(arg0 context.Context, arg1 db.LockIfThresholdNotExceededParams) (db.EntityExecutionLock, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LockIfThresholdNotExceeded", arg0, arg1) + ret0, _ := ret[0].(db.EntityExecutionLock) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LockIfThresholdNotExceeded indicates an expected call of LockIfThresholdNotExceeded. +func (mr *MockStoreMockRecorder) LockIfThresholdNotExceeded(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockIfThresholdNotExceeded", reflect.TypeOf((*MockStore)(nil).LockIfThresholdNotExceeded), arg0, arg1) +} + +// ReleaseLock mocks base method. +func (m *MockStore) ReleaseLock(arg0 context.Context, arg1 db.ReleaseLockParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReleaseLock", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReleaseLock indicates an expected call of ReleaseLock. +func (mr *MockStoreMockRecorder) ReleaseLock(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseLock", reflect.TypeOf((*MockStore)(nil).ReleaseLock), arg0, arg1) +} + // Rollback mocks base method. 
func (m *MockStore) Rollback(arg0 *sql.Tx) error { m.ctrl.T.Helper() @@ -1695,6 +1784,20 @@ func (mr *MockStoreMockRecorder) UpdateAccessToken(arg0, arg1 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccessToken", reflect.TypeOf((*MockStore)(nil).UpdateAccessToken), arg0, arg1) } +// UpdateLease mocks base method. +func (m *MockStore) UpdateLease(arg0 context.Context, arg1 db.UpdateLeaseParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateLease", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateLease indicates an expected call of UpdateLease. +func (mr *MockStoreMockRecorder) UpdateLease(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLease", reflect.TypeOf((*MockStore)(nil).UpdateLease), arg0, arg1) +} + // UpdateOrganization mocks base method. func (m *MockStore) UpdateOrganization(arg0 context.Context, arg1 db.UpdateOrganizationParams) (db.Project, error) { m.ctrl.T.Helper() diff --git a/database/query/entity_execution_lock.sql b/database/query/entity_execution_lock.sql new file mode 100644 index 0000000000..b333d4dc67 --- /dev/null +++ b/database/query/entity_execution_lock.sql @@ -0,0 +1,70 @@ +-- LockIfThresholdNotExceeded is used to lock an entity for execution. It will +-- attempt to insert or update the entity_execution_lock table only if the +-- last_lock_time is older than the threshold. If the lock is successful, it +-- will return the lock record. If the lock is unsuccessful, it will return +-- NULL. + +-- name: LockIfThresholdNotExceeded :one +INSERT INTO entity_execution_lock( + entity, + locked_by, + last_lock_time, + repository_id, + artifact_id, + pull_request_id +) VALUES( + sqlc.arg(entity)::entities, + gen_random_uuid(), + NOW(), + sqlc.arg(repository_id)::UUID, + sqlc.narg(artifact_id)::UUID, + sqlc.narg(pull_request_id)::UUID +) ON CONFLICT(entity, repository_id, COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID), COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID)) +DO UPDATE SET + locked_by = gen_random_uuid(), + last_lock_time = NOW() +WHERE entity_execution_lock.last_lock_time < (NOW() - (@interval::TEXT || ' seconds')::interval) +RETURNING *; + +-- ReleaseLock is used to release a lock on an entity. It will delete the +-- entity_execution_lock record if the lock is held by the given locked_by +-- value. 
+ +-- name: ReleaseLock :exec +DELETE FROM entity_execution_lock +WHERE entity = sqlc.arg(entity)::entities AND repository_id = sqlc.arg(repository_id)::UUID AND + COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE(sqlc.narg(artifact_id)::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND + COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE(sqlc.narg(pull_request_id)::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND + locked_by = sqlc.arg(locked_by)::UUID; + +-- name: UpdateLease :exec +UPDATE entity_execution_lock SET last_lock_time = NOW() +WHERE entity = $1 AND repository_id = $2 AND +COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE(sqlc.narg(artifact_id)::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND +COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE(sqlc.narg(pull_request_id)::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND +locked_by = sqlc.arg(locked_by)::UUID; + +-- name: EnqueueFlush :one +INSERT INTO flush_cache( + entity, + repository_id, + artifact_id, + pull_request_id +) VALUES( + sqlc.arg(entity)::entities, + sqlc.arg(repository_id)::UUID, + sqlc.narg(artifact_id)::UUID, + sqlc.narg(pull_request_id)::UUID +) ON CONFLICT(entity, repository_id, COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID), COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID)) +DO NOTHING +RETURNING *; + +-- name: FlushCache :one +DELETE FROM flush_cache +WHERE entity = $1 AND repository_id = $2 AND + COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE(sqlc.narg(artifact_id)::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND + COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE(sqlc.narg(pull_request_id)::UUID, '00000000-0000-0000-0000-000000000000'::UUID) +RETURNING *; + +-- name: ListFlushCache :many +SELECT * FROM flush_cache; \ No newline at end of file diff --git a/database/query/pull_requests.sql b/database/query/pull_requests.sql index b9f700f6eb..89a871608d 100644 --- a/database/query/pull_requests.sql +++ b/database/query/pull_requests.sql @@ -19,6 +19,10 @@ RETURNING *; SELECT * FROM pull_requests WHERE repository_id = $1 AND pr_number = $2; +-- name: GetPullRequestByID :one +SELECT * FROM pull_requests +WHERE id = $1; + -- name: DeletePullRequest :exec DELETE FROM pull_requests WHERE repository_id = $1 AND pr_number = $2; \ No newline at end of file diff --git a/internal/config/events.go b/internal/config/events.go index 1998d4eb9c..46bdf8fadc 100644 --- a/internal/config/events.go +++ b/internal/config/events.go @@ -23,6 +23,8 @@ type EventConfig struct { RouterCloseTimeout int64 `mapstructure:"router_close_timeout" default:"10"` // GoChannel is the configuration for the go channel event driver GoChannel GoChannelEventConfig `mapstructure:"go-channel" default:"{}"` + // Aggregator is the configuration for the event aggregator middleware + Aggregator AggregatorConfig `mapstructure:"aggregator" default:"{}"` } // GoChannelEventConfig is the configuration for the go channel event driver @@ -32,4 +34,14 @@ type GoChannelEventConfig struct { BufferSize int64 `mapstructure:"buffer_size" default:"0"` // PersistEvents is whether or not to persist events to the channel PersistEvents bool `mapstructure:"persist_events" default:"false"` + // BlockPublishUntilSubscriberAck is whether or not to block publishing until + // the subscriber acks the message. 
This is useful for testing. + BlockPublishUntilSubscriberAck bool `mapstructure:"block_publish_until_subscriber_ack" default:"false"` +} + +// AggregatorConfig is the configuration for the event aggregator middleware +type AggregatorConfig struct { + // LockInterval is the interval for locking events in seconds. + // This is the threshold between rule evaluations + actions. + LockInterval int64 `mapstructure:"lock_interval" default:"30"` } diff --git a/internal/controlplane/handlers_githubwebhooks.go b/internal/controlplane/handlers_githubwebhooks.go index db82afc783..fecc9ad5b5 100644 --- a/internal/controlplane/handlers_githubwebhooks.go +++ b/internal/controlplane/handlers_githubwebhooks.go @@ -190,7 +190,7 @@ func (s *Server) HandleGitHubWebHook() http.HandlerFunc { wes.accepted = true - if err := s.evt.Publish(engine.InternalEntityEventTopic, m); err != nil { + if err := s.evt.Publish(engine.ExecuteEntityEventTopic, m); err != nil { wes.error = true log.Printf("Error publishing message: %v", err) w.WriteHeader(http.StatusInternalServerError) diff --git a/internal/controlplane/handlers_githubwebhooks_test.go b/internal/controlplane/handlers_githubwebhooks_test.go index f55b2d3aed..2ffdad5bbc 100644 --- a/internal/controlplane/handlers_githubwebhooks_test.go +++ b/internal/controlplane/handlers_githubwebhooks_test.go @@ -31,7 +31,6 @@ import ( "testing" "time" - "github.com/ThreeDotsLabs/watermill/message" "github.com/golang/mock/gomock" "github.com/google/go-github/v53/github" "github.com/google/uuid" @@ -46,6 +45,7 @@ import ( "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/util/rand" + "github.com/stacklok/minder/internal/util/testqueue" ) // MockClient is a mock implementation of the GitHub client. 
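Note on the test helpers: the hunks below switch the webhook tests from a file-local passthrough queue to the shared github.com/stacklok/minder/internal/util/testqueue package; the old unexported helper is deleted at the end of this test file (see the removed passthroughQueue further below). The new package itself is not part of this section, so the following is only a minimal sketch of what it presumably exports, assuming it mirrors the removed helper:

package testqueue

import "github.com/ThreeDotsLabs/watermill/message"

// PassthroughQueue is a test helper that forwards every message it handles
// onto a channel so tests can assert on what was published.
type PassthroughQueue struct {
	ch chan *message.Message
}

// NewPassthroughQueue creates a new PassthroughQueue.
func NewPassthroughQueue() *PassthroughQueue {
	return &PassthroughQueue{
		ch: make(chan *message.Message),
	}
}

// GetQueue returns the channel that handled messages are forwarded to.
func (q *PassthroughQueue) GetQueue() <-chan *message.Message {
	return q.ch
}

// Pass is the handler registered on a topic; it forwards the message as-is.
func (q *PassthroughQueue) Pass(msg *message.Message) error {
	q.ch <- msg
	return nil
}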
@@ -95,10 +95,10 @@ func (s *UnitTestSuite) TestHandleWebHookPing() { srv := newDefaultServer(t, mockStore) defer srv.evt.Close() - pq := newPassthroughQueue() - queued := pq.getQueue() + pq := testqueue.NewPassthroughQueue() + queued := pq.GetQueue() - srv.evt.Register(engine.InternalEntityEventTopic, pq.pass) + srv.evt.Register(engine.ExecuteEntityEventTopic, pq.Pass) go func() { err := srv.evt.Run(context.Background()) @@ -148,10 +148,10 @@ func (s *UnitTestSuite) TestHandleWebHookUnexistentRepository() { srv := newDefaultServer(t, mockStore) defer srv.evt.Close() - pq := newPassthroughQueue() - queued := pq.getQueue() + pq := testqueue.NewPassthroughQueue() + queued := pq.GetQueue() - srv.evt.Register(engine.InternalEntityEventTopic, pq.pass) + srv.evt.Register(engine.ExecuteEntityEventTopic, pq.Pass) go func() { err := srv.evt.Run(context.Background()) @@ -214,10 +214,10 @@ func (s *UnitTestSuite) TestHandleWebHookRepository() { srv := newDefaultServer(t, mockStore) defer srv.evt.Close() - pq := newPassthroughQueue() - queued := pq.getQueue() + pq := testqueue.NewPassthroughQueue() + queued := pq.GetQueue() - srv.evt.Register(engine.InternalEntityEventTopic, pq.pass) + srv.evt.Register(engine.ExecuteEntityEventTopic, pq.Pass) go func() { err := srv.evt.Run(context.Background()) @@ -331,10 +331,10 @@ func (s *UnitTestSuite) TestHandleWebHookUnexistentRepoPackage() { srv := newDefaultServer(t, mockStore) defer srv.evt.Close() - pq := newPassthroughQueue() - queued := pq.getQueue() + pq := testqueue.NewPassthroughQueue() + queued := pq.GetQueue() - srv.evt.Register(engine.InternalEntityEventTopic, pq.pass) + srv.evt.Register(engine.ExecuteEntityEventTopic, pq.Pass) go func() { err := srv.evt.Run(context.Background()) @@ -393,22 +393,3 @@ func TestAll(t *testing.T) { RunUnitTestSuite(t) // Call other test runner functions for additional test suites } - -type passthroughQueue struct { - ch chan *message.Message -} - -func newPassthroughQueue() *passthroughQueue { - return &passthroughQueue{ - ch: make(chan *message.Message), - } -} - -func (q *passthroughQueue) getQueue() <-chan *message.Message { - return q.ch -} - -func (q *passthroughQueue) pass(msg *message.Message) error { - q.ch <- msg - return nil -} diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 7c1915c5d0..158e0c9083 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -23,6 +23,7 @@ import ( "net/http" "time" + "github.com/ThreeDotsLabs/watermill/message" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/prometheus/client_golang/prometheus/promhttp" _ "github.com/signalfx/splunk-otel-go/instrumentation/github.com/lib/pq/splunkpq" // Auto-instrumented version of lib/pq @@ -145,8 +146,8 @@ func (s *Server) initTracer() (*sdktrace.TracerProvider, error) { } // Register implements events.Registrar -func (s *Server) Register(topic string, handler events.Handler) { - s.evt.Register(topic, handler) +func (s *Server) Register(topic string, handler events.Handler, mdw ...message.HandlerMiddleware) { + s.evt.Register(topic, handler, mdw...) } // ConsumeEvents implements events.Registrar diff --git a/internal/db/entity_execution_lock.sql.go b/internal/db/entity_execution_lock.sql.go new file mode 100644 index 0000000000..710b1c4915 --- /dev/null +++ b/internal/db/entity_execution_lock.sql.go @@ -0,0 +1,238 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.23.0 +// source: entity_execution_lock.sql + +package db + +import ( + "context" + + "github.com/google/uuid" +) + +const enqueueFlush = `-- name: EnqueueFlush :one +INSERT INTO flush_cache( + entity, + repository_id, + artifact_id, + pull_request_id +) VALUES( + $1::entities, + $2::UUID, + $3::UUID, + $4::UUID +) ON CONFLICT(entity, repository_id, COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID), COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID)) +DO NOTHING +RETURNING id, entity, repository_id, artifact_id, pull_request_id, queued_at +` + +type EnqueueFlushParams struct { + Entity Entities `json:"entity"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` +} + +func (q *Queries) EnqueueFlush(ctx context.Context, arg EnqueueFlushParams) (FlushCache, error) { + row := q.db.QueryRowContext(ctx, enqueueFlush, + arg.Entity, + arg.RepositoryID, + arg.ArtifactID, + arg.PullRequestID, + ) + var i FlushCache + err := row.Scan( + &i.ID, + &i.Entity, + &i.RepositoryID, + &i.ArtifactID, + &i.PullRequestID, + &i.QueuedAt, + ) + return i, err +} + +const flushCache = `-- name: FlushCache :one +DELETE FROM flush_cache +WHERE entity = $1 AND repository_id = $2 AND + COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE($3::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND + COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE($4::UUID, '00000000-0000-0000-0000-000000000000'::UUID) +RETURNING id, entity, repository_id, artifact_id, pull_request_id, queued_at +` + +type FlushCacheParams struct { + Entity Entities `json:"entity"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` +} + +func (q *Queries) FlushCache(ctx context.Context, arg FlushCacheParams) (FlushCache, error) { + row := q.db.QueryRowContext(ctx, flushCache, + arg.Entity, + arg.RepositoryID, + arg.ArtifactID, + arg.PullRequestID, + ) + var i FlushCache + err := row.Scan( + &i.ID, + &i.Entity, + &i.RepositoryID, + &i.ArtifactID, + &i.PullRequestID, + &i.QueuedAt, + ) + return i, err +} + +const listFlushCache = `-- name: ListFlushCache :many +SELECT id, entity, repository_id, artifact_id, pull_request_id, queued_at FROM flush_cache +` + +func (q *Queries) ListFlushCache(ctx context.Context) ([]FlushCache, error) { + rows, err := q.db.QueryContext(ctx, listFlushCache) + if err != nil { + return nil, err + } + defer rows.Close() + items := []FlushCache{} + for rows.Next() { + var i FlushCache + if err := rows.Scan( + &i.ID, + &i.Entity, + &i.RepositoryID, + &i.ArtifactID, + &i.PullRequestID, + &i.QueuedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const lockIfThresholdNotExceeded = `-- name: LockIfThresholdNotExceeded :one + +INSERT INTO entity_execution_lock( + entity, + locked_by, + last_lock_time, + repository_id, + artifact_id, + pull_request_id +) VALUES( + $1::entities, + gen_random_uuid(), + NOW(), + $2::UUID, + $3::UUID, + $4::UUID +) ON CONFLICT(entity, repository_id, COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID), COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID)) +DO UPDATE SET + locked_by = 
gen_random_uuid(), + last_lock_time = NOW() +WHERE entity_execution_lock.last_lock_time < (NOW() - ($5::TEXT || ' seconds')::interval) +RETURNING id, entity, locked_by, last_lock_time, repository_id, artifact_id, pull_request_id +` + +type LockIfThresholdNotExceededParams struct { + Entity Entities `json:"entity"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` + Interval string `json:"interval"` +} + +// LockIfThresholdNotExceeded is used to lock an entity for execution. It will +// attempt to insert or update the entity_execution_lock table only if the +// last_lock_time is older than the threshold. If the lock is successful, it +// will return the lock record. If the lock is unsuccessful, it will return +// NULL. +func (q *Queries) LockIfThresholdNotExceeded(ctx context.Context, arg LockIfThresholdNotExceededParams) (EntityExecutionLock, error) { + row := q.db.QueryRowContext(ctx, lockIfThresholdNotExceeded, + arg.Entity, + arg.RepositoryID, + arg.ArtifactID, + arg.PullRequestID, + arg.Interval, + ) + var i EntityExecutionLock + err := row.Scan( + &i.ID, + &i.Entity, + &i.LockedBy, + &i.LastLockTime, + &i.RepositoryID, + &i.ArtifactID, + &i.PullRequestID, + ) + return i, err +} + +const releaseLock = `-- name: ReleaseLock :exec + +DELETE FROM entity_execution_lock +WHERE entity = $1::entities AND repository_id = $2::UUID AND + COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE($3::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND + COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE($4::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND + locked_by = $5::UUID +` + +type ReleaseLockParams struct { + Entity Entities `json:"entity"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` + LockedBy uuid.UUID `json:"locked_by"` +} + +// ReleaseLock is used to release a lock on an entity. It will delete the +// entity_execution_lock record if the lock is held by the given locked_by +// value. 
+func (q *Queries) ReleaseLock(ctx context.Context, arg ReleaseLockParams) error { + _, err := q.db.ExecContext(ctx, releaseLock, + arg.Entity, + arg.RepositoryID, + arg.ArtifactID, + arg.PullRequestID, + arg.LockedBy, + ) + return err +} + +const updateLease = `-- name: UpdateLease :exec +UPDATE entity_execution_lock SET last_lock_time = NOW() +WHERE entity = $1 AND repository_id = $2 AND +COALESCE(artifact_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE($3::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND +COALESCE(pull_request_id, '00000000-0000-0000-0000-000000000000'::UUID) = COALESCE($4::UUID, '00000000-0000-0000-0000-000000000000'::UUID) AND +locked_by = $5::UUID +` + +type UpdateLeaseParams struct { + Entity Entities `json:"entity"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` + LockedBy uuid.UUID `json:"locked_by"` +} + +func (q *Queries) UpdateLease(ctx context.Context, arg UpdateLeaseParams) error { + _, err := q.db.ExecContext(ctx, updateLease, + arg.Entity, + arg.RepositoryID, + arg.ArtifactID, + arg.PullRequestID, + arg.LockedBy, + ) + return err +} diff --git a/internal/db/entity_execution_lock.sql_test.go b/internal/db/entity_execution_lock.sql_test.go new file mode 100644 index 0000000000..397370ea75 --- /dev/null +++ b/internal/db/entity_execution_lock.sql_test.go @@ -0,0 +1,99 @@ +// +// Copyright 2023 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestQueries_LockIfThresholdNotExceeded(t *testing.T) { + t.Parallel() + + org := createRandomOrganization(t) + group := createRandomProject(t, org.ID) + prov := createRandomProvider(t, group.ID) + repo := createRandomRepository(t, group.ID, prov.Name) + + threshold := 1 + concurrentCalls := 10 + + // waitgroup + var wg sync.WaitGroup + var queueCount atomic.Int32 + var effectiveFlush atomic.Int32 + + for i := 0; i < concurrentCalls; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := testQueries.LockIfThresholdNotExceeded(context.Background(), LockIfThresholdNotExceededParams{ + Entity: EntitiesRepository, + RepositoryID: repo.ID, + ArtifactID: uuid.NullUUID{}, + PullRequestID: uuid.NullUUID{}, + Interval: fmt.Sprintf("%d", threshold), + }) + + if err != nil && errors.Is(err, sql.ErrNoRows) { + t.Log("lock had been acquired. 
adding to queue") + // count the number of times we've been queued + queueCount.Add(1) + + _, err := testQueries.EnqueueFlush(context.Background(), EnqueueFlushParams{ + Entity: EntitiesRepository, + RepositoryID: repo.ID, + ArtifactID: uuid.NullUUID{}, + PullRequestID: uuid.NullUUID{}, + }) + if err == nil { + effectiveFlush.Add(1) + } + } else if err != nil { + assert.NoError(t, err, "expected no error") + } + }() + } + + wg.Wait() + + assert.Equal(t, int32(concurrentCalls-1), queueCount.Load(), "expected all but one to be queued") + assert.Equal(t, int32(1), effectiveFlush.Load(), "expected only one flush to be queued") + + t.Log("sleeping for threshold") + time.Sleep(time.Duration(threshold) * time.Second) + + t.Log("Attempting to acquire lock now that threshold has passed") + + _, err := testQueries.LockIfThresholdNotExceeded(context.Background(), LockIfThresholdNotExceededParams{ + Entity: EntitiesRepository, + RepositoryID: repo.ID, + ArtifactID: uuid.NullUUID{}, + PullRequestID: uuid.NullUUID{}, + Interval: fmt.Sprintf("%d", threshold), + }) + + assert.NoError(t, err, "expected no error") +} diff --git a/internal/db/models.go b/internal/db/models.go index 8197fc1244..8599a92690 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -310,6 +310,16 @@ type Entitlement struct { CreatedAt time.Time `json:"created_at"` } +type EntityExecutionLock struct { + ID uuid.UUID `json:"id"` + Entity Entities `json:"entity"` + LockedBy uuid.UUID `json:"locked_by"` + LastLockTime time.Time `json:"last_lock_time"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` +} + type EntityProfile struct { ID uuid.UUID `json:"id"` Entity Entities `json:"entity"` @@ -333,6 +343,15 @@ type Feature struct { UpdatedAt time.Time `json:"updated_at"` } +type FlushCache struct { + ID uuid.UUID `json:"id"` + Entity Entities `json:"entity"` + RepositoryID uuid.UUID `json:"repository_id"` + ArtifactID uuid.NullUUID `json:"artifact_id"` + PullRequestID uuid.NullUUID `json:"pull_request_id"` + QueuedAt time.Time `json:"queued_at"` +} + type Profile struct { ID uuid.UUID `json:"id"` Name string `json:"name"` diff --git a/internal/db/pull_requests.sql.go b/internal/db/pull_requests.sql.go index 4cac8c8815..42e8d69543 100644 --- a/internal/db/pull_requests.sql.go +++ b/internal/db/pull_requests.sql.go @@ -74,6 +74,24 @@ func (q *Queries) GetPullRequest(ctx context.Context, arg GetPullRequestParams) return i, err } +const getPullRequestByID = `-- name: GetPullRequestByID :one +SELECT id, repository_id, pr_number, created_at, updated_at FROM pull_requests +WHERE id = $1 +` + +func (q *Queries) GetPullRequestByID(ctx context.Context, id uuid.UUID) (PullRequest, error) { + row := q.db.QueryRowContext(ctx, getPullRequestByID, id) + var i PullRequest + err := row.Scan( + &i.ID, + &i.RepositoryID, + &i.PrNumber, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const upsertPullRequest = `-- name: UpsertPullRequest :one INSERT INTO pull_requests ( repository_id, diff --git a/internal/db/querier.go b/internal/db/querier.go index 5e88d918de..36510bf3ec 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -54,6 +54,8 @@ type Querier interface { DeleteSessionStateByProjectID(ctx context.Context, arg DeleteSessionStateByProjectIDParams) error DeleteSigningKey(ctx context.Context, arg DeleteSigningKeyParams) error DeleteUser(ctx context.Context, id int32) error + EnqueueFlush(ctx context.Context, arg 
EnqueueFlushParams) (FlushCache, error) + FlushCache(ctx context.Context, arg FlushCacheParams) (FlushCache, error) GetAccessTokenByProjectID(ctx context.Context, arg GetAccessTokenByProjectIDParams) (ProviderAccessToken, error) GetAccessTokenByProvider(ctx context.Context, provider string) ([]ProviderAccessToken, error) GetAccessTokenSinceDate(ctx context.Context, arg GetAccessTokenSinceDateParams) (ProviderAccessToken, error) @@ -84,6 +86,7 @@ type Querier interface { GetProviderByID(ctx context.Context, arg GetProviderByIDParams) (Provider, error) GetProviderByName(ctx context.Context, arg GetProviderByNameParams) (Provider, error) GetPullRequest(ctx context.Context, arg GetPullRequestParams) (PullRequest, error) + GetPullRequestByID(ctx context.Context, id uuid.UUID) (PullRequest, error) GetRepositoryByID(ctx context.Context, id uuid.UUID) (Repository, error) GetRepositoryByIDAndProject(ctx context.Context, arg GetRepositoryByIDAndProjectParams) (Repository, error) GetRepositoryByRepoID(ctx context.Context, repoID int32) (Repository, error) @@ -106,6 +109,7 @@ type Querier interface { ListArtifactVersionsByArtifactID(ctx context.Context, arg ListArtifactVersionsByArtifactIDParams) ([]ArtifactVersion, error) ListArtifactVersionsByArtifactIDAndTag(ctx context.Context, arg ListArtifactVersionsByArtifactIDAndTagParams) ([]ArtifactVersion, error) ListArtifactsByRepoID(ctx context.Context, repositoryID uuid.UUID) ([]Artifact, error) + ListFlushCache(ctx context.Context) ([]FlushCache, error) ListOrganizations(ctx context.Context, arg ListOrganizationsParams) ([]Project, error) ListProfilesByProjectID(ctx context.Context, projectID uuid.UUID) ([]ListProfilesByProjectIDRow, error) // get profile information that instantiate a rule. This is done by joining the profiles with entity_profiles, then correlating those @@ -124,7 +128,18 @@ type Querier interface { ListUsersByOrganization(ctx context.Context, arg ListUsersByOrganizationParams) ([]User, error) ListUsersByProject(ctx context.Context, arg ListUsersByProjectParams) ([]User, error) ListUsersByRoleId(ctx context.Context, roleID int32) ([]int32, error) + // LockIfThresholdNotExceeded is used to lock an entity for execution. It will + // attempt to insert or update the entity_execution_lock table only if the + // last_lock_time is older than the threshold. If the lock is successful, it + // will return the lock record. If the lock is unsuccessful, it will return + // NULL. + LockIfThresholdNotExceeded(ctx context.Context, arg LockIfThresholdNotExceededParams) (EntityExecutionLock, error) + // ReleaseLock is used to release a lock on an entity. It will delete the + // entity_execution_lock record if the lock is held by the given locked_by + // value. + ReleaseLock(ctx context.Context, arg ReleaseLockParams) error UpdateAccessToken(ctx context.Context, arg UpdateAccessTokenParams) (ProviderAccessToken, error) + UpdateLease(ctx context.Context, arg UpdateLeaseParams) error UpdateOrganization(ctx context.Context, arg UpdateOrganizationParams) (Project, error) UpdateProfile(ctx context.Context, arg UpdateProfileParams) (Profile, error) // set clone_url if the value is not an empty string diff --git a/internal/eea/eea.go b/internal/eea/eea.go new file mode 100644 index 0000000000..e2bff40bdb --- /dev/null +++ b/internal/eea/eea.go @@ -0,0 +1,290 @@ +// Copyright 2023 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// Package rule provides the CLI subcommand for managing rules + +// Package eea provides objects and event handlers for the EEA. EEA stands for +// Event Execution Aggregator. The EEA is responsible for aggregating events +// from the webhook and making sure we don't send too many events to the +// executor engine. +package eea + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/google/uuid" + "github.com/rs/zerolog" + + "github.com/stacklok/minder/internal/config" + "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/engine" + "github.com/stacklok/minder/internal/entities" + "github.com/stacklok/minder/internal/events" + "github.com/stacklok/minder/internal/util" +) + +// EEA is the Event Execution Aggregator +type EEA struct { + querier db.Store + evt *events.Eventer + cfg *config.AggregatorConfig +} + +// NewEEA creates a new EEA +func NewEEA(querier db.Store, evt *events.Eventer, cfg *config.AggregatorConfig) *EEA { + return &EEA{ + querier: querier, + evt: evt, + cfg: cfg, + } +} + +// Register implements the Consumer interface. +func (e *EEA) Register(r events.Registrar) { + r.Register(engine.FlushEntityEventTopic, e.FlushMessageHandler) +} + +// AggregateMiddleware will pass on the event to the executor engine +// if the event is ready to be executed. Else it'll cache +// the event until it's ready to be executed. +func (e *EEA) AggregateMiddleware(h message.HandlerFunc) message.HandlerFunc { + return func(msg *message.Message) ([]*message.Message, error) { + msg, err := e.aggregate(msg) + if err != nil { + return nil, fmt.Errorf("error aggregating event: %w", err) + } + + if msg == nil { + return nil, nil + } + + return h(msg) + } +} + +func (e *EEA) aggregate(msg *message.Message) (*message.Message, error) { + ctx := msg.Context() + inf, err := engine.ParseEntityEvent(msg) + if err != nil { + return nil, fmt.Errorf("error unmarshalling payload: %w", err) + } + + repoID, artifactID, pullRequestID := inf.GetEntityDBIDs() + + res, err := e.querier.LockIfThresholdNotExceeded(ctx, db.LockIfThresholdNotExceededParams{ + Entity: entities.EntityTypeToDB(inf.Type), + RepositoryID: repoID, + ArtifactID: artifactID, + PullRequestID: pullRequestID, + Interval: fmt.Sprintf("%d", e.cfg.LockInterval), + }) + + logger := zerolog.Ctx(ctx).Info() + logger = logger.Str("event", msg.UUID). + Str("entity", inf.Type.ToString()). + Str("repository_id", repoID.String()) + + if artifactID.Valid { + logger = logger.Str("artifact_id", artifactID.UUID.String()) + } + + if pullRequestID.Valid { + logger = logger.Str("pull_request_id", pullRequestID.UUID.String()) + } + + // if nothing was retrieved from the database, then we can assume + // that the event is not ready to be executed. 
+ if err != nil && errors.Is(err, sql.ErrNoRows) { + logger.Msg("event not ready to be executed") + + _, err := e.querier.EnqueueFlush(ctx, db.EnqueueFlushParams{ + Entity: entities.EntityTypeToDB(inf.Type), + RepositoryID: repoID, + ArtifactID: artifactID, + PullRequestID: pullRequestID, + }) + if err != nil { + // We already have this item in the queue. + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("error enqueuing flush: %w", err) + } + + return nil, nil + } else if err != nil { + logger.Err(err).Msg("error locking event") + return nil, fmt.Errorf("error locking: %w", err) + } + + logger.Str("execution_id", res.LockedBy.String()).Msg("event ready to be executed") + msg.Metadata.Set(engine.ExecutionIDKey, res.LockedBy.String()) + + return msg, nil +} + +// FlushMessageHandler will flush the cache of events to the executor engine +// if the event is ready to be executed. +func (e *EEA) FlushMessageHandler(msg *message.Message) error { + ctx := msg.Context() + + inf, err := engine.ParseEntityEvent(msg) + if err != nil { + return fmt.Errorf("error unmarshalling payload: %w", err) + } + + repoID, artifactID, pullRequestID := inf.GetEntityDBIDs() + + zerolog.Ctx(ctx).Info(). + Str("event", msg.UUID). + Str("entity", inf.Type.ToString()). + Str("repository_id", repoID.String()).Msg("flushing event") + + _, err = e.querier.FlushCache(ctx, db.FlushCacheParams{ + Entity: entities.EntityTypeToDB(inf.Type), + RepositoryID: repoID, + ArtifactID: artifactID, + PullRequestID: pullRequestID, + }) + // Nothing to do here. If we can't flush the cache, it means + // that the event has already been executed. + if err != nil && errors.Is(err, sql.ErrNoRows) { + zerolog.Ctx(ctx).Info(). + Str("event", msg.UUID). + Str("entity", inf.Type.ToString()). + Str("repository_id", repoID.String()).Msg("no flushing needed") + return nil + } else if err != nil { + return fmt.Errorf("error flushing cache: %w", err) + } + + zerolog.Ctx(ctx).Info(). + Str("event", msg.UUID). + Str("entity", inf.Type.ToString()). + Str("repository_id", repoID.String()).Msg("re-publishing event because of flush") + + // Now that we've flushed the event, let's try to publish it again + // which means, go through the locking process again. + if err := inf.Publish(e.evt); err != nil { + return fmt.Errorf("error publishing execute event: %w", err) + } + + return nil +} + +// FlushAll will flush all events in the cache to the executor engine +func (e *EEA) FlushAll(ctx context.Context) error { + caches, err := e.querier.ListFlushCache(ctx) + if err != nil { + // No rows to flush, this is fine. 
+ if errors.Is(err, sql.ErrNoRows) { + return nil + } + return fmt.Errorf("error listing flush cache: %w", err) + } + + for _, cache := range caches { + cache := cache + + eiw, err := e.buildEntityWrapper(ctx, cache.Entity, + cache.RepositoryID, cache.ArtifactID, cache.PullRequestID) + if err != nil && errors.Is(err, sql.ErrNoRows) { + continue + } else if err != nil { + return fmt.Errorf("error flushing cache: %w", err) + } + + msg, err := eiw.BuildMessage() + if err != nil { + return fmt.Errorf("error flushing cache: %w", err) + } + + msg.SetContext(ctx) + + if err := e.FlushMessageHandler(msg); err != nil { + return fmt.Errorf("error flushing cache: %w", err) + } + } + + return nil +} + +func (e *EEA) buildEntityWrapper( + ctx context.Context, + entity db.Entities, + repoID uuid.UUID, + artID, prID uuid.NullUUID, +) (*engine.EntityInfoWrapper, error) { + switch entity { + case db.EntitiesRepository: + return e.buildRepositoryInfoWrapper(ctx, repoID) + case db.EntitiesArtifact: + return e.buildArtifactInfoWrapper(ctx, repoID, artID) + case db.EntitiesPullRequest: + return e.buildPullRequestInfoWrapper(ctx, repoID, prID) + case db.EntitiesBuildEnvironment: + return nil, fmt.Errorf("build environment entity not supported") + default: + return nil, fmt.Errorf("unknown entity type: %s", entity) + } +} + +func (e *EEA) buildRepositoryInfoWrapper( + ctx context.Context, + repoID uuid.UUID, +) (*engine.EntityInfoWrapper, error) { + r, err := util.GetRepository(ctx, e.querier, repoID) + if err != nil { + return nil, fmt.Errorf("error getting repository: %w", err) + } + + return engine.NewEntityInfoWrapper(). + WithRepository(r). + WithRepositoryID(repoID), nil +} + +func (e *EEA) buildArtifactInfoWrapper( + ctx context.Context, + repoID uuid.UUID, + artID uuid.NullUUID, +) (*engine.EntityInfoWrapper, error) { + a, err := util.GetArtifactWithVersions(ctx, e.querier, repoID, artID.UUID) + if err != nil { + return nil, fmt.Errorf("error getting artifact with versions: %w", err) + } + + return engine.NewEntityInfoWrapper(). + WithRepositoryID(repoID). + WithArtifact(a). + WithArtifactID(artID.UUID), nil +} + +func (e *EEA) buildPullRequestInfoWrapper( + ctx context.Context, + repoID uuid.UUID, + prID uuid.NullUUID, +) (*engine.EntityInfoWrapper, error) { + pr, err := util.GetPullRequest(ctx, e.querier, repoID, prID.UUID) + if err != nil { + return nil, fmt.Errorf("error getting pull request: %w", err) + } + + return engine.NewEntityInfoWrapper(). + WithRepositoryID(repoID). + WithPullRequest(pr). + WithPullRequestID(prID.UUID), nil +} diff --git a/internal/eea/eea_test.go b/internal/eea/eea_test.go new file mode 100644 index 0000000000..6534c7a27b --- /dev/null +++ b/internal/eea/eea_test.go @@ -0,0 +1,199 @@ +// +// Copyright 2023 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package eea_test + +import ( + "context" + "encoding/json" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/minder/internal/config" + "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/eea" + "github.com/stacklok/minder/internal/engine" + "github.com/stacklok/minder/internal/events" + minderv1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" +) + +const ( + providerName = "test-provider" +) + +func TestAggregator(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var concurrentEvents int64 = 100 + + projectID, repoID := createNeededEntities(ctx, t) + + evt, err := events.Setup(ctx, &config.EventConfig{ + Driver: "go-channel", + GoChannel: config.GoChannelEventConfig{ + BufferSize: concurrentEvents, + BlockPublishUntilSubscriberAck: true, + }, + }) + require.NoError(t, err) + + // we'll wait 2 seconds for the lock to be available + var eventThreshold int64 = 2 + + aggr := eea.NewEEA(testQueries, evt, &config.AggregatorConfig{ + LockInterval: eventThreshold, + }) + + rateLimitedMessages := newTestPubSub() + flushedMessages := newTestPubSub() + + rateLimitedMessageTopic := t.Name() + + // This tests that the middleware works as expected + evt.Register(rateLimitedMessageTopic, rateLimitedMessages.Add, aggr.AggregateMiddleware) + + // This tests that flushing works as expected + evt.Register(engine.FlushEntityEventTopic, aggr.FlushMessageHandler) + + // This tests that flushing sends messages to the executor engine + evt.Register(engine.ExecuteEntityEventTopic, flushedMessages.Add, aggr.AggregateMiddleware) + + go func() { + t.Log("Running eventer") + err := evt.Run(ctx) + assert.NoError(t, err, "expected no error when running eventer") + }() + defer evt.Close() + + inf := engine.NewEntityInfoWrapper(). + WithRepository(&minderv1.Repository{}). + WithRepositoryID(repoID). + WithProjectID(projectID). 
+ WithProvider(providerName) + + <-evt.Running() + + t.Log("Publishing events") + var wg sync.WaitGroup + for i := 0; i < int(concurrentEvents); i++ { + wg.Add(1) + go func() { + defer wg.Done() + msg, err := inf.BuildMessage() + require.NoError(t, err, "expected no error when building message") + err = evt.Publish(rateLimitedMessageTopic, msg.Copy()) + require.NoError(t, err, "expected no error when publishing message") + }() + } + + wg.Wait() + rateLimitedMessages.Wait() + + assert.Equal(t, int32(1), rateLimitedMessages.count.Load(), "expected only one message to be published") + + t.Log("Waiting for lock to be available") + time.Sleep(time.Duration(eventThreshold) * time.Second) + + t.Log("Publishing flush events") + var flushWg sync.WaitGroup + for i := 0; i < int(concurrentEvents); i++ { + flushWg.Add(1) + go func() { + defer flushWg.Done() + msg, err := inf.BuildMessage() + require.NoError(t, err, "expected no error when building message") + + err = evt.Publish(engine.FlushEntityEventTopic, msg.Copy()) + require.NoError(t, err, "expected no error when publishing message") + }() + } + + flushWg.Wait() + flushedMessages.Wait() + + // flushing should only happen once + assert.Equal(t, int32(1), flushedMessages.count.Load(), "expected only one message to be published") +} + +func createNeededEntities(ctx context.Context, t *testing.T) (projID uuid.UUID, repoID uuid.UUID) { + t.Helper() + + // setup project + proj, err := testQueries.CreateProject(ctx, db.CreateProjectParams{ + Name: "test-project", + Metadata: json.RawMessage("{}"), + }) + require.NoError(t, err, "expected no error when creating project") + + // setup provider + _, err = testQueries.CreateProvider(ctx, db.CreateProviderParams{ + Name: providerName, + ProjectID: proj.ID, + Implements: []db.ProviderType{db.ProviderTypeRest}, + Definition: json.RawMessage(`{}`), + }) + require.NoError(t, err, "expected no error when creating provider") + + // setup repo + repo, err := testQueries.CreateRepository(ctx, db.CreateRepositoryParams{ + ProjectID: proj.ID, + Provider: providerName, + RepoName: "test-repo", + RepoOwner: "test-owner", + RepoID: 123, + }) + require.NoError(t, err, "expected no error when creating repo") + + return proj.ID, repo.ID +} + +type testPubSub struct { + // counts the number of messages added + count *atomic.Int32 + firstMessageOnce *sync.Once + // allows us to wait for the first message to be added + firstMessage chan struct{} +} + +func newTestPubSub() *testPubSub { + var count atomic.Int32 + return &testPubSub{ + count: &count, + firstMessage: make(chan struct{}), + firstMessageOnce: &sync.Once{}, + } +} + +func (t *testPubSub) Wait() { + <-t.firstMessage +} + +func (t *testPubSub) Add(_ *message.Message) error { + t.count.Add(1) + t.firstMessageOnce.Do(func() { + t.firstMessage <- struct{}{} + }) + return nil +} diff --git a/internal/eea/main_test.go b/internal/eea/main_test.go new file mode 100644 index 0000000000..381b7cb38e --- /dev/null +++ b/internal/eea/main_test.go @@ -0,0 +1,97 @@ +// +// Copyright 2023 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package eea_test + +import ( + "database/sql" + "log" + "os" + "testing" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" // nolint + _ "github.com/golang-migrate/migrate/v4/source/file" // nolint + _ "github.com/lib/pq" + + "github.com/stacklok/minder/internal/db" +) + +var testQueries db.Store +var testDB *sql.DB + +func TestMain(m *testing.M) { + os.Exit(runTestWithInProcessPostgres(m)) +} + +func runTestWithInProcessPostgres(m *testing.M) int { + tmpName, err := os.MkdirTemp("", "mediator-db-test") + if err != nil { + log.Println("cannot create tmpdir:", err) + return -1 + } + + defer func() { + if err := os.RemoveAll(tmpName); err != nil { + log.Println("cannot remove tmpdir:", err) + } + }() + + dbCfg := embeddedpostgres.DefaultConfig(). + Database("mediator"). + RuntimePath(tmpName). + Port(5434) + postgres := embeddedpostgres.NewDatabase(dbCfg) + + if err := postgres.Start(); err != nil { + log.Println("cannot start postgres:", err) + return -1 + } + defer func() { + if err := postgres.Stop(); err != nil { + log.Println("cannot stop postgres:", err) + } + }() + + testDB, err = sql.Open("postgres", "user=postgres dbname=mediator password=postgres host=localhost port=5434 sslmode=disable") + if err != nil { + log.Println("cannot connect to db test instance:", err) + return -1 + } + + configPath := "file://../../database/migrations" + mig, err := migrate.New(configPath, dbCfg.GetConnectionURL()+"?sslmode=disable") + if err != nil { + log.Printf("Error while creating migration instance (%s): %v\n", configPath, err) + return -1 + } + + if err := mig.Up(); err != nil { + log.Println("cannot run db migrations:", err) + return -1 + } + + defer func() { + if err := testDB.Close(); err != nil { + log.Println("cannot close test db:", err) + } + }() + + testQueries = db.NewStore(testDB) + + // Run tests + return m.Run() +} diff --git a/internal/engine/entity_event.go b/internal/engine/entity_event.go index e118782c5a..7265e50605 100644 --- a/internal/engine/entity_event.go +++ b/internal/engine/entity_event.go @@ -52,6 +52,7 @@ type EntityInfoWrapper struct { Entity protoreflect.ProtoMessage Type minderv1.Entity OwnershipData map[string]string + ExecutionID *uuid.UUID } const ( @@ -76,6 +77,8 @@ const ( ArtifactIDEventKey = "artifact_id" // PullRequestIDEventKey is the key for the pull request ID PullRequestIDEventKey = "pull_request_id" + // ExecutionIDKey is the key for the execution ID. This is set when acquiring a lock. 
+ ExecutionIDKey = "execution_id" ) // NewEntityInfoWrapper creates a new EntityInfoWrapper @@ -144,6 +147,13 @@ func (eiw *EntityInfoWrapper) WithPullRequestID(id uuid.UUID) *EntityInfoWrapper return eiw } +// WithExecutionID sets the execution ID +func (eiw *EntityInfoWrapper) WithExecutionID(id uuid.UUID) *EntityInfoWrapper { + eiw.ExecutionID = &id + + return eiw +} + // AsRepository sets the entity type to a repository func (eiw *EntityInfoWrapper) AsRepository() *EntityInfoWrapper { eiw.Type = minderv1.Entity_ENTITY_REPOSITORIES @@ -188,7 +198,7 @@ func (eiw *EntityInfoWrapper) Publish(evt *events.Eventer) error { return err } - if err := evt.Publish(InternalEntityEventTopic, msg); err != nil { + if err := evt.Publish(ExecuteEntityEventTopic, msg); err != nil { return fmt.Errorf("error publishing entity event: %w", err) } @@ -210,6 +220,10 @@ func (eiw *EntityInfoWrapper) ToMessage(msg *message.Message) error { return fmt.Errorf("provider is required") } + if eiw.ExecutionID != nil { + msg.Metadata.Set(ExecutionIDKey, eiw.ExecutionID.String()) + } + msg.Metadata.Set(ProviderEventKey, eiw.Provider) msg.Metadata.Set(EntityTypeEventKey, typ) msg.Metadata.Set(ProjectIDEventKey, eiw.ProjectID.String()) @@ -224,6 +238,30 @@ func (eiw *EntityInfoWrapper) ToMessage(msg *message.Message) error { return nil } +// GetEntityDBIDs returns the repository, artifact and pull request IDs +// from the ownership data +func (eiw *EntityInfoWrapper) GetEntityDBIDs() (repoID uuid.UUID, artifactID uuid.NullUUID, pullRequestID uuid.NullUUID) { + repoID = uuid.MustParse(eiw.OwnershipData[RepositoryIDEventKey]) + + strArtifactID, ok := eiw.OwnershipData[ArtifactIDEventKey] + if ok { + artifactID = uuid.NullUUID{ + UUID: uuid.MustParse(strArtifactID), + Valid: true, + } + } + + strPullRequestID, ok := eiw.OwnershipData[PullRequestIDEventKey] + if ok { + pullRequestID = uuid.NullUUID{ + UUID: uuid.MustParse(strPullRequestID), + Valid: true, + } + } + + return repoID, artifactID, pullRequestID +} + func (eiw *EntityInfoWrapper) withProjectIDFromMessage(msg *message.Message) error { rawID := msg.Metadata.Get(ProjectIDEventKey) if rawID == "" { @@ -261,6 +299,21 @@ func (eiw *EntityInfoWrapper) withPullRequestIDFromMessage(msg *message.Message) return eiw.withIDFromMessage(msg, PullRequestIDEventKey) } +func (eiw *EntityInfoWrapper) withExecutionIDFromMessage(msg *message.Message) error { + executionID := msg.Metadata.Get(ExecutionIDKey) + if executionID == "" { + return fmt.Errorf("%s not found in metadata", ExecutionIDKey) + } + + id, err := uuid.Parse(executionID) + if err != nil { + return fmt.Errorf("error parsing execution ID: %w", err) + } + + eiw.ExecutionID = &id + return nil +} + func (eiw *EntityInfoWrapper) withIDFromMessage(msg *message.Message, key string) error { id, err := getIDFromMessage(msg, key) if err != nil { @@ -305,7 +358,8 @@ func getIDFromMessage(msg *message.Message, key string) (string, error) { return rawID, nil } -func parseEntityEvent(msg *message.Message) (*EntityInfoWrapper, error) { +// ParseEntityEvent parses a message.Message and returns an EntityInfoWrapper +func ParseEntityEvent(msg *message.Message) (*EntityInfoWrapper, error) { out := &EntityInfoWrapper{ OwnershipData: make(map[string]string), } diff --git a/internal/engine/entity_event_test.go b/internal/engine/entity_event_test.go index 83b381b83b..3835071447 100644 --- a/internal/engine/entity_event_test.go +++ b/internal/engine/entity_event_test.go @@ -163,7 +163,7 @@ func Test_parseEntityEvent(t *testing.T) { 
msg.Metadata.Set(PullRequestIDEventKey, tt.args.ownership["pull_request_id"]) } - got, err := parseEntityEvent(msg) + got, err := ParseEntityEvent(msg) if tt.wantErr { require.Error(t, err, "expected error") require.Nil(t, got, "expected nil entity info") diff --git a/internal/engine/eval_status.go b/internal/engine/eval_status.go index 0139143599..149e0eabf9 100644 --- a/internal/engine/eval_status.go +++ b/internal/engine/eval_status.go @@ -42,28 +42,16 @@ func (e *Executor) createEvalStatusParams( return nil, fmt.Errorf("error parsing profile ID: %w", err) } - params := &engif.EvalStatusParams{ - Rule: rule, - Profile: profile, - ProfileID: profileID, - RepoID: uuid.MustParse(inf.OwnershipData[RepositoryIDEventKey]), - EntityType: entities.EntityTypeToDB(inf.Type), - } + repoID, artID, prID := inf.GetEntityDBIDs() - artifactID, ok := inf.OwnershipData[ArtifactIDEventKey] - if ok { - params.ArtifactID = uuid.NullUUID{ - UUID: uuid.MustParse(artifactID), - Valid: true, - } - } - - pullRequestID, ok := inf.OwnershipData[PullRequestIDEventKey] - if ok { - params.PullRequestID = uuid.NullUUID{ - UUID: uuid.MustParse(pullRequestID), - Valid: true, - } + params := &engif.EvalStatusParams{ + Rule: rule, + Profile: profile, + ProfileID: profileID, + EntityType: entities.EntityTypeToDB(inf.Type), + RepoID: repoID, + ArtifactID: artID, + PullRequestID: prID, } // Prepare params for fetching the current rule evaluation from the database diff --git a/internal/engine/executor.go b/internal/engine/executor.go index 3c3e2e68e3..81ccb73d46 100644 --- a/internal/engine/executor.go +++ b/internal/engine/executor.go @@ -17,6 +17,8 @@ package engine import ( "context" "fmt" + "sync" + "time" "github.com/ThreeDotsLabs/watermill/message" "github.com/google/uuid" @@ -28,6 +30,7 @@ import ( evalerrors "github.com/stacklok/minder/internal/engine/errors" "github.com/stacklok/minder/internal/engine/ingestcache" engif "github.com/stacklok/minder/internal/engine/interfaces" + "github.com/stacklok/minder/internal/entities" "github.com/stacklok/minder/internal/events" "github.com/stacklok/minder/internal/providers" providertelemetry "github.com/stacklok/minder/internal/providers/telemetry" @@ -35,15 +38,29 @@ import ( ) const ( - // InternalEntityEventTopic is the topic for internal webhook events - InternalEntityEventTopic = "internal.entity.event" + // ExecuteEntityEventTopic is the topic for internal webhook events + ExecuteEntityEventTopic = "execute.entity.event" + // FlushEntityEventTopic is the topic for flushing internal webhook events + FlushEntityEventTopic = "flush.entity.event" +) + +const ( + // DefaultExecutionTimeout is the timeout for execution of a set + // of profiles on an entity. + DefaultExecutionTimeout = 5 * time.Minute ) // Executor is the engine that executes the rules for a given event type Executor struct { - querier db.Store - crypteng *crypto.Engine - provMt providertelemetry.ProviderMetrics + querier db.Store + evt *events.Eventer + crypteng *crypto.Engine + provMt providertelemetry.ProviderMetrics + aggrMdw events.AggregatorMiddleware + executions *sync.WaitGroup + // terminationcontext is used to terminate the executor + // when the server is shutting down. 
+ terminationcontext context.Context } // ExecutorOption is a function that modifies an executor @@ -56,10 +73,19 @@ func WithProviderMetrics(mt providertelemetry.ProviderMetrics) ExecutorOption { } } +// WithAggregatorMiddleware sets the aggregator middleware for the executor +func WithAggregatorMiddleware(mdw events.AggregatorMiddleware) ExecutorOption { + return func(e *Executor) { + e.aggrMdw = mdw + } +} + // NewExecutor creates a new executor func NewExecutor( + ctx context.Context, querier db.Store, authCfg *config.AuthConfig, + evt *events.Eventer, opts ...ExecutorOption, ) (*Executor, error) { crypteng, err := crypto.EngineFromAuthConfig(authCfg) @@ -68,9 +94,12 @@ func NewExecutor( } e := &Executor{ - querier: querier, - crypteng: crypteng, - provMt: providertelemetry.NewNoopMetrics(), + querier: querier, + crypteng: crypteng, + provMt: providertelemetry.NewNoopMetrics(), + evt: evt, + executions: &sync.WaitGroup{}, + terminationcontext: ctx, } for _, opt := range opts { @@ -82,18 +111,57 @@ func NewExecutor( // Register implements the Consumer interface. func (e *Executor) Register(r events.Registrar) { - r.Register(InternalEntityEventTopic, e.HandleEntityEvent) + if e.aggrMdw == nil { + r.Register(ExecuteEntityEventTopic, e.HandleEntityEvent) + } else { + r.Register(ExecuteEntityEventTopic, e.HandleEntityEvent, + e.aggrMdw.AggregateMiddleware) + } +} + +// Wait waits for all the executions to finish. +func (e *Executor) Wait() { + e.executions.Wait() } // HandleEntityEvent handles events coming from webhooks/signals // as well as the init event. func (e *Executor) HandleEntityEvent(msg *message.Message) error { - inf, err := parseEntityEvent(msg) + // Let's not share memory with the caller + msg = msg.Copy() + + inf, err := ParseEntityEvent(msg) if err != nil { return fmt.Errorf("error unmarshalling payload: %w", err) } - ctx := msg.Context() + e.executions.Add(1) + go func() { + defer e.executions.Done() + // TODO: Make this timeout configurable + ctx, cancel := context.WithTimeout(e.terminationcontext, DefaultExecutionTimeout) + defer cancel() + + if err := inf.withExecutionIDFromMessage(msg); err != nil { + logger := zerolog.Ctx(ctx) + logger.Info(). + Str("message_id", msg.UUID). + Msg("message does not contain execution ID, skipping") + return + } + + if err := e.prepAndEvalEntityEvent(ctx, inf); err != nil { + zerolog.Ctx(ctx).Info(). + Str("project", inf.ProjectID.String()). + Str("provider", inf.Provider). + Str("entity", inf.Type.String()). + Err(err).Msg("got error while evaluating entity event") + } + }() + + return nil +} +func (e *Executor) prepAndEvalEntityEvent(ctx context.Context, inf *EntityInfoWrapper) error { projectID := inf.ProjectID @@ -145,6 +213,8 @@ func (e *Executor) evalEntityEvent( // access. 
 	ingestCache := ingestcache.NewCache()
 
+	defer e.releaseLockAndFlush(ctx, inf)
+
 	// Get profiles relevant to group
 	dbpols, err := e.querier.ListProfilesByProjectID(ctx, *inf.ProjectID)
 	if err != nil {
@@ -166,6 +236,9 @@ func (e *Executor) evalEntityEvent(
 				return err
 			}
 
+			// Update the lock lease at the end of the evaluation
+			defer e.updateLockLease(ctx, *inf.ExecutionID, evalParams)
+
 			// Evaluate the rule
 			evalParams.SetEvalErr(rte.Eval(ctx, inf, evalParams))
 
@@ -187,6 +260,7 @@ func (e *Executor) evalEntityEvent(
 			return fmt.Errorf("error traversing rules for profile %s: %w", p, err)
 		}
 	}
+
 	return nil
 }
 
@@ -243,6 +317,77 @@ func (e *Executor) getEvaluator(
 	return params, rte, nil
 }
 
+func (e *Executor) updateLockLease(
+	ctx context.Context,
+	executionID uuid.UUID,
+	params *engif.EvalStatusParams,
+) {
+	logger := zerolog.Ctx(ctx).Info().
+		Str("entity_type", string(params.EntityType)).
+		Str("execution_id", executionID.String()).
+		Str("repo_id", params.RepoID.String())
+	if params.ArtifactID.Valid {
+		logger = logger.Str("artifact_id", params.ArtifactID.UUID.String())
+	}
+	if params.PullRequestID.Valid {
+		logger = logger.Str("pull_request_id", params.PullRequestID.UUID.String())
+	}
+
+	if err := e.querier.UpdateLease(ctx, db.UpdateLeaseParams{
+		Entity:        params.EntityType,
+		RepositoryID:  params.RepoID,
+		ArtifactID:    params.ArtifactID,
+		PullRequestID: params.PullRequestID,
+		LockedBy:      executionID,
+	}); err != nil {
+		logger.Err(err).Msg("error updating lock lease")
+		return
+	}
+
+	logger.Msg("lock lease updated")
+}
+
+func (e *Executor) releaseLockAndFlush(
+	ctx context.Context,
+	inf *EntityInfoWrapper,
+) {
+	repoID, artID, prID := inf.GetEntityDBIDs()
+
+	logger := zerolog.Ctx(ctx).Info().
+		Str("entity_type", inf.Type.ToString()).
+		Str("execution_id", inf.ExecutionID.String()).
+		Str("repo_id", repoID.String())
+
+	if artID.Valid {
+		logger = logger.Str("artifact_id", artID.UUID.String())
+	}
+	if prID.Valid {
+		logger = logger.Str("pull_request_id", prID.UUID.String())
+	}
+
+	if err := e.querier.ReleaseLock(ctx, db.ReleaseLockParams{
+		Entity:        entities.EntityTypeToDB(inf.Type),
+		RepositoryID:  repoID,
+		ArtifactID:    artID,
+		PullRequestID: prID,
+		LockedBy:      *inf.ExecutionID,
+	}); err != nil {
+		logger.Err(err).Msg("error releasing lock")
+	}
+
+	// We don't need to unset the execution ID because the event is going to be
+	// deleted from the database anyway. The aggregator will take care of that.
+ msg, err := inf.BuildMessage() + if err != nil { + logger.Err(err).Msg("error building message") + return + } + + if err := e.evt.Publish(FlushEntityEventTopic, msg); err != nil { + logger.Err(err).Msg("error publishing flush event") + } +} + func logEval( ctx context.Context, inf *EntityInfoWrapper, diff --git a/internal/engine/executor_test.go b/internal/engine/executor_test.go index bfce3a9781..8d3f7c272f 100644 --- a/internal/engine/executor_test.go +++ b/internal/engine/executor_test.go @@ -15,6 +15,7 @@ package engine_test import ( + "context" "encoding/base64" "encoding/json" "os" @@ -32,6 +33,8 @@ import ( "github.com/stacklok/minder/internal/crypto" "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/engine" + "github.com/stacklok/minder/internal/events" + "github.com/stacklok/minder/internal/util/testqueue" minderv1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) @@ -77,6 +80,7 @@ func TestExecutor_handleEntityEvent(t *testing.T) { profileID := uuid.New() ruleTypeID := uuid.New() repositoryID := uuid.New() + executionID := uuid.New() authtoken := generateFakeAccessToken(t) @@ -225,6 +229,27 @@ default allow = true`, Metadata: meta, Details: "", }).Return(ruleEvalAlertId, nil) + + // Mock update lease for lock + mockStore.EXPECT(). + UpdateLease(gomock.Any(), db.UpdateLeaseParams{ + Entity: db.EntitiesRepository, + RepositoryID: repositoryID, + ArtifactID: uuid.NullUUID{}, + PullRequestID: uuid.NullUUID{}, + LockedBy: executionID, + }).Return(nil) + + // Mock release lock + mockStore.EXPECT(). + ReleaseLock(gomock.Any(), db.ReleaseLockParams{ + Entity: db.EntitiesRepository, + RepositoryID: repositoryID, + ArtifactID: uuid.NullUUID{}, + PullRequestID: uuid.NullUUID{}, + LockedBy: executionID, + }).Return(nil) + // -- end expectations tmpdir := t.TempDir() @@ -235,11 +260,34 @@ default allow = true`, err = os.WriteFile(tokenKeyPath, []byte(fakeTokenKey), 0600) require.NoError(t, err, "expected no error") - e, err := engine.NewExecutor(mockStore, &config.AuthConfig{ - TokenKey: tokenKeyPath, + evt, err := events.Setup(context.Background(), &config.EventConfig{ + Driver: "go-channel", + GoChannel: config.GoChannelEventConfig{ + BlockPublishUntilSubscriberAck: true, + }, }) + require.NoError(t, err, "failed to setup eventer") + + go func() { + t.Log("Running eventer") + err := evt.Run(context.Background()) + require.NoError(t, err, "failed to run eventer") + }() + + pq := testqueue.NewPassthroughQueue() + queued := pq.GetQueue() + + testTimeout := 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + e, err := engine.NewExecutor(ctx, mockStore, &config.AuthConfig{ + TokenKey: tokenKeyPath, + }, evt) require.NoError(t, err, "expected no error") + evt.Register(engine.FlushEntityEventTopic, pq.Pass) + eiw := engine.NewEntityInfoWrapper(). WithProvider(providerName). WithProjectID(projectID). @@ -247,10 +295,26 @@ default allow = true`, Name: "test", RepoId: 123, CloneUrl: "github.com/foo/bar.git", - }).WithRepositoryID(repositoryID) + }).WithRepositoryID(repositoryID). 
+ WithExecutionID(executionID) msg, err := eiw.BuildMessage() require.NoError(t, err, "expected no error") - require.NoError(t, e.HandleEntityEvent(msg), "expected no error") + // Run in the background + go func() { + t.Log("Running entity event handler") + require.NoError(t, e.HandleEntityEvent(msg), "expected no error") + }() + + t.Log("waiting for eventer to start") + <-evt.Running() + + // expect flush + t.Log("waiting for flush") + require.NotNil(t, <-queued, "expected message") + + require.NoError(t, evt.Close(), "expected no error") + + e.Wait() } diff --git a/internal/events/eventer.go b/internal/events/eventer.go index ed12903289..009d5b8015 100644 --- a/internal/events/eventer.go +++ b/internal/events/eventer.go @@ -62,7 +62,7 @@ type Registrar interface { // functions, or to call Register multiple times with different topics and the same // handler function. It's allowed to call Register with both argument the same, but // then events will be delivered twice to the handler, which is probably not what you want. - Register(topic string, handler Handler) + Register(topic string, handler Handler, mdw ...message.HandlerMiddleware) // HandleAll registers all the consumers with the registrar // TODO: should this be a different interface? @@ -76,6 +76,12 @@ type Consumer interface { Register(Registrar) } +// AggregatorMiddleware is an interface that allows the eventer to +// add middleware to the router +type AggregatorMiddleware interface { + AggregateMiddleware(h message.HandlerFunc) message.HandlerFunc +} + // Eventer is a wrapper over the relevant eventing objects in such // a way that they can be easily accessible and configurable. type Eventer struct { @@ -215,10 +221,11 @@ func (e *Eventer) Publish(topic string, messages ...*message.Message) error { func (e *Eventer) Register( topic string, handler message.NoPublishHandlerFunc, + mdw ...message.HandlerMiddleware, ) { // From https://stackoverflow.com/questions/7052693/how-to-get-the-name-of-a-function-in-go funcName := fmt.Sprintf("%s-%s", runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(), topic) - e.router.AddNoPublisherHandler( + hand := e.router.AddNoPublisherHandler( funcName, topic, e.webhookSubscriber, @@ -243,6 +250,10 @@ func (e *Eventer) Register( return nil }, ) + + for _, m := range mdw { + hand.AddMiddleware(m) + } } // ConsumeEvents allows registration of multiple consumers easily diff --git a/internal/util/helpers.go b/internal/util/helpers.go index ce88437e63..abc7765c96 100644 --- a/internal/util/helpers.go +++ b/internal/util/helpers.go @@ -525,8 +525,31 @@ func Int32FromString(v string) (int32, error) { return int32(asInt32), nil } +// GetRepository retrieves a repository from the database +// and converts it to a protobuf +func GetRepository(ctx context.Context, store db.ExtendQuerier, repoID uuid.UUID) (*minderv1.Repository, error) { + dbrepo, err := store.GetRepositoryByID(ctx, repoID) + if err != nil { + return nil, fmt.Errorf("error getting repository: %w", err) + } + + strRepoID := repoID.String() + return &minderv1.Repository{ + Id: &strRepoID, + Owner: dbrepo.RepoOwner, + Name: dbrepo.RepoName, + RepoId: dbrepo.RepoID, + HookUrl: dbrepo.WebhookUrl, + DeployUrl: dbrepo.DeployUrl, + CloneUrl: dbrepo.CloneUrl, + CreatedAt: timestamppb.New(dbrepo.CreatedAt), + UpdatedAt: timestamppb.New(dbrepo.UpdatedAt), + }, nil +} + // GetArtifactWithVersions retrieves an artifact and its versions from the database -func GetArtifactWithVersions(ctx context.Context, store db.Store, repoID, artifactID uuid.UUID) 
(*minderv1.Artifact, error) { +func GetArtifactWithVersions( + ctx context.Context, store db.ExtendQuerier, repoID, artifactID uuid.UUID) (*minderv1.Artifact, error) { // Get repository data - we need the owner and name dbrepo, err := store.GetRepositoryByID(ctx, repoID) if errors.Is(err, sql.ErrNoRows) { @@ -595,3 +618,33 @@ func GetArtifactWithVersions(ctx context.Context, store db.Store, repoID, artifa CreatedAt: timestamppb.New(artifact.CreatedAt), }, nil } + +// GetPullRequest retrieves a pull request from the database +// and converts it to a protobuf +func GetPullRequest( + ctx context.Context, + store db.ExtendQuerier, + repoID, pullRequestID uuid.UUID, +) (*minderv1.PullRequest, error) { + // Get repository data - we need the owner and name + dbrepo, err := store.GetRepositoryByID(ctx, repoID) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("repository not found") + } else if err != nil { + return nil, fmt.Errorf("cannot read repository: %v", err) + } + + dbpr, err := store.GetPullRequestByID(ctx, pullRequestID) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("pull request not found") + } else if err != nil { + return nil, fmt.Errorf("cannot read pull request: %v", err) + } + + // TODO: Do we need extra columns in the pull request table? + return &minderv1.PullRequest{ + Number: int32(dbpr.PrNumber), // TODO: this should be int64 + RepoOwner: dbrepo.RepoOwner, + RepoName: dbrepo.RepoName, + }, nil +} diff --git a/internal/util/testqueue/passthroughqueue.go b/internal/util/testqueue/passthroughqueue.go new file mode 100644 index 0000000000..bfbdcdcea5 --- /dev/null +++ b/internal/util/testqueue/passthroughqueue.go @@ -0,0 +1,43 @@ +// +// Copyright 2023 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testqueue contains queue utilities for testing +package testqueue + +import "github.com/ThreeDotsLabs/watermill/message" + +// PassthroughQueue is a queue that passes messages through. +// It's only useful for testing. +type PassthroughQueue struct { + ch chan *message.Message +} + +// NewPassthroughQueue creates a new PassthroughQueue +func NewPassthroughQueue() *PassthroughQueue { + return &PassthroughQueue{ + ch: make(chan *message.Message), + } +} + +// GetQueue returns the queue +func (q *PassthroughQueue) GetQueue() <-chan *message.Message { + return q.ch +} + +// Pass passes a message through the queue +func (q *PassthroughQueue) Pass(msg *message.Message) error { + q.ch <- msg + return nil +}
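
For readers following the flush flow end to end, below is a minimal sketch (not part of this change) of how the test-only PassthroughQueue can be wired into an in-process eventer to observe messages on FlushEntityEventTopic, mirroring the setup used in executor_test.go. The test name TestObserveFlush and the "test-id" message UUID are illustrative only; everything else (events.Setup, the go-channel driver config, evt.Register/Run/Running/Publish/Close, engine.FlushEntityEventTopic, testqueue.NewPassthroughQueue) comes from the code in this diff.

// TestObserveFlush is an illustrative sketch, not part of this change.
package engine_test

import (
	"context"
	"testing"
	"time"

	"github.com/ThreeDotsLabs/watermill/message"
	"github.com/stretchr/testify/require"

	"github.com/stacklok/minder/internal/config"
	"github.com/stacklok/minder/internal/engine"
	"github.com/stacklok/minder/internal/events"
	"github.com/stacklok/minder/internal/util/testqueue"
)

func TestObserveFlush(t *testing.T) {
	// In-process eventer, same config as in executor_test.go.
	evt, err := events.Setup(context.Background(), &config.EventConfig{
		Driver: "go-channel",
		GoChannel: config.GoChannelEventConfig{
			BlockPublishUntilSubscriberAck: true,
		},
	})
	require.NoError(t, err, "failed to setup eventer")

	// The passthrough queue forwards every message it handles to a channel,
	// so the test can assert on what was published to the flush topic.
	pq := testqueue.NewPassthroughQueue()
	evt.Register(engine.FlushEntityEventTopic, pq.Pass)

	go func() {
		require.NoError(t, evt.Run(context.Background()), "failed to run eventer")
	}()
	<-evt.Running()

	// Publish to the flush topic and expect the message on the other side.
	msg := message.NewMessage("test-id", nil)
	require.NoError(t, evt.Publish(engine.FlushEntityEventTopic, msg))

	select {
	case got := <-pq.GetQueue():
		require.Equal(t, "test-id", got.UUID)
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for flush message")
	}

	require.NoError(t, evt.Close(), "expected no error closing eventer")
}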