diff --git a/cmd/dev/app/rule_type/rttst.go b/cmd/dev/app/rule_type/rttst.go index 129bb8ab6e..1767cfb4a1 100644 --- a/cmd/dev/app/rule_type/rttst.go +++ b/cmd/dev/app/rule_type/rttst.go @@ -31,12 +31,12 @@ import ( serverconfig "github.com/stacklok/minder/internal/config/server" "github.com/stacklok/minder/internal/db" - "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/engine/actions" "github.com/stacklok/minder/internal/engine/entities" "github.com/stacklok/minder/internal/engine/errors" "github.com/stacklok/minder/internal/engine/eval/rego" engif "github.com/stacklok/minder/internal/engine/interfaces" + "github.com/stacklok/minder/internal/engine/ruleengine" "github.com/stacklok/minder/internal/logger" "github.com/stacklok/minder/internal/profiles" "github.com/stacklok/minder/internal/providers/credentials" @@ -159,7 +159,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { off := "off" profile.Alert = &off - rules, err := engine.GetRulesFromProfileOfType(profile, ruletype) + ruleInstances, err := ruleengine.GetRulesFromProfileOfType(profile, ruletype) if err != nil { return fmt.Errorf("error getting relevant fragment: %w", err) } @@ -172,7 +172,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { // TODO: use cobra context here ctx := context.Background() - eng, err := engine.NewRuleTypeEngine(ctx, ruletype, prov) + eng, err := ruleengine.NewRuleTypeEngine(ctx, ruletype, prov) if err != nil { return fmt.Errorf("cannot create rule type engine: %w", err) } @@ -189,16 +189,16 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { return fmt.Errorf("error creating rule type engine: %w", err) } - if len(rules) == 0 { + if len(ruleInstances) == 0 { return fmt.Errorf("no rules found with type %s", ruletype.Name) } - return runEvaluationForRules(cmd, eng, inf, remediateStatus, remMetadata, rules, actionEngine) + return runEvaluationForRules(cmd, eng, inf, remediateStatus, remMetadata, ruleInstances, actionEngine) } func runEvaluationForRules( cmd *cobra.Command, - eng *engine.RuleTypeEngine, + eng *ruleengine.RuleTypeEngine, inf *entities.EntityInfoWrapper, remediateStatus db.NullRemediationStatusTypes, remMetadata pqtype.NullRawMessage, diff --git a/database/mock/store.go b/database/mock/store.go index 2a882d6a49..cff3125300 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -1331,6 +1331,36 @@ func (mr *MockStoreMockRecorder) GetRuleTypeByName(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuleTypeByName", reflect.TypeOf((*MockStore)(nil).GetRuleTypeByName), arg0, arg1) } +// GetRuleTypeIDByRuleNameEntityProfile mocks base method. +func (m *MockStore) GetRuleTypeIDByRuleNameEntityProfile(arg0 context.Context, arg1 db.GetRuleTypeIDByRuleNameEntityProfileParams) (uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRuleTypeIDByRuleNameEntityProfile", arg0, arg1) + ret0, _ := ret[0].(uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRuleTypeIDByRuleNameEntityProfile indicates an expected call of GetRuleTypeIDByRuleNameEntityProfile. +func (mr *MockStoreMockRecorder) GetRuleTypeIDByRuleNameEntityProfile(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuleTypeIDByRuleNameEntityProfile", reflect.TypeOf((*MockStore)(nil).GetRuleTypeIDByRuleNameEntityProfile), arg0, arg1) +} + +// GetRuleTypesByEntityInHierarchy mocks base method. +func (m *MockStore) GetRuleTypesByEntityInHierarchy(arg0 context.Context, arg1 db.GetRuleTypesByEntityInHierarchyParams) ([]db.RuleType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRuleTypesByEntityInHierarchy", arg0, arg1) + ret0, _ := ret[0].([]db.RuleType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRuleTypesByEntityInHierarchy indicates an expected call of GetRuleTypesByEntityInHierarchy. +func (mr *MockStoreMockRecorder) GetRuleTypesByEntityInHierarchy(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuleTypesByEntityInHierarchy", reflect.TypeOf((*MockStore)(nil).GetRuleTypesByEntityInHierarchy), arg0, arg1) +} + // GetSelectorByID mocks base method. func (m *MockStore) GetSelectorByID(arg0 context.Context, arg1 uuid.UUID) (db.ProfileSelector, error) { m.ctrl.T.Helper() diff --git a/database/query/rule_instances.sql b/database/query/rule_instances.sql index f27a8557e6..d0def68871 100644 --- a/database/query/rule_instances.sql +++ b/database/query/rule_instances.sql @@ -57,4 +57,12 @@ AND NOT id = ANY(sqlc.arg(updated_ids)::UUID[]); SELECT id FROM rule_instances WHERE profile_id = $1 AND entity_type = $2 -AND name = $3; \ No newline at end of file +AND name = $3; + +-- intended as a temporary transition query +-- this will be removed once rule_instances is used consistently in the engine +-- name: GetRuleTypeIDByRuleNameEntityProfile :one +SELECT rule_type_id FROM rule_instances +WHERE name = $1 +AND entity_type = $2 +AND profile_id = $3; \ No newline at end of file diff --git a/database/query/rule_types.sql b/database/query/rule_types.sql index 5b3227ad08..2743fd21d2 100644 --- a/database/query/rule_types.sql +++ b/database/query/rule_types.sql @@ -36,3 +36,9 @@ UPDATE rule_type SET description = $2, definition = sqlc.arg(definition)::jsonb, severity_value = sqlc.arg(severity_value), display_name = sqlc.arg(display_name) WHERE id = $1 RETURNING *; + +-- name: GetRuleTypesByEntityInHierarchy :many +SELECT rt.* FROM rule_type AS rt +JOIN rule_instances AS ri ON ri.rule_type_id = rt.id +WHERE ri.entity_type = $1 +AND ri.project_id = ANY(sqlc.arg(projects)::uuid[]); diff --git a/internal/db/querier.go b/internal/db/querier.go index b0363aa658..7e3908000c 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -144,6 +144,10 @@ type Querier interface { GetRuleInstancesForProfileEntity(ctx context.Context, arg GetRuleInstancesForProfileEntityParams) ([]RuleInstance, error) GetRuleTypeByID(ctx context.Context, id uuid.UUID) (RuleType, error) GetRuleTypeByName(ctx context.Context, arg GetRuleTypeByNameParams) (RuleType, error) + // intended as a temporary transition query + // this will be removed once rule_instances is used consistently in the engine + GetRuleTypeIDByRuleNameEntityProfile(ctx context.Context, arg GetRuleTypeIDByRuleNameEntityProfileParams) (uuid.UUID, error) + GetRuleTypesByEntityInHierarchy(ctx context.Context, arg GetRuleTypesByEntityInHierarchyParams) ([]RuleType, error) GetSelectorByID(ctx context.Context, id uuid.UUID) (ProfileSelector, error) GetSelectorsByProfileID(ctx context.Context, profileID uuid.UUID) ([]ProfileSelector, error) GetSubscriptionByProjectBundle(ctx context.Context, arg GetSubscriptionByProjectBundleParams) (Subscription, error) diff --git a/internal/db/rule_instances.sql.go b/internal/db/rule_instances.sql.go index 4e334399ed..b17daa03e0 100644 --- a/internal/db/rule_instances.sql.go +++ b/internal/db/rule_instances.sql.go @@ -132,6 +132,28 @@ func (q *Queries) GetRuleInstancesForProfileEntity(ctx context.Context, arg GetR return items, nil } +const getRuleTypeIDByRuleNameEntityProfile = `-- name: GetRuleTypeIDByRuleNameEntityProfile :one +SELECT rule_type_id FROM rule_instances +WHERE name = $1 +AND entity_type = $2 +AND profile_id = $3 +` + +type GetRuleTypeIDByRuleNameEntityProfileParams struct { + Name string `json:"name"` + EntityType Entities `json:"entity_type"` + ProfileID uuid.UUID `json:"profile_id"` +} + +// intended as a temporary transition query +// this will be removed once rule_instances is used consistently in the engine +func (q *Queries) GetRuleTypeIDByRuleNameEntityProfile(ctx context.Context, arg GetRuleTypeIDByRuleNameEntityProfileParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, getRuleTypeIDByRuleNameEntityProfile, arg.Name, arg.EntityType, arg.ProfileID) + var rule_type_id uuid.UUID + err := row.Scan(&rule_type_id) + return rule_type_id, err +} + const upsertRuleInstance = `-- name: UpsertRuleInstance :one INSERT INTO rule_instances ( diff --git a/internal/db/rule_types.sql.go b/internal/db/rule_types.sql.go index f5735727ff..f3f035e237 100644 --- a/internal/db/rule_types.sql.go +++ b/internal/db/rule_types.sql.go @@ -140,6 +140,55 @@ func (q *Queries) GetRuleTypeByName(ctx context.Context, arg GetRuleTypeByNamePa return i, err } +const getRuleTypesByEntityInHierarchy = `-- name: GetRuleTypesByEntityInHierarchy :many +SELECT rt.id, rt.name, rt.provider, rt.project_id, rt.description, rt.guidance, rt.definition, rt.created_at, rt.updated_at, rt.severity_value, rt.provider_id, rt.subscription_id, rt.display_name FROM rule_type AS rt +JOIN rule_instances AS ri ON ri.rule_type_id = rt.id +WHERE ri.entity_type = $1 +AND ri.project_id = ANY($2::uuid[]) +` + +type GetRuleTypesByEntityInHierarchyParams struct { + EntityType Entities `json:"entity_type"` + Projects []uuid.UUID `json:"projects"` +} + +func (q *Queries) GetRuleTypesByEntityInHierarchy(ctx context.Context, arg GetRuleTypesByEntityInHierarchyParams) ([]RuleType, error) { + rows, err := q.db.QueryContext(ctx, getRuleTypesByEntityInHierarchy, arg.EntityType, pq.Array(arg.Projects)) + if err != nil { + return nil, err + } + defer rows.Close() + items := []RuleType{} + for rows.Next() { + var i RuleType + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Provider, + &i.ProjectID, + &i.Description, + &i.Guidance, + &i.Definition, + &i.CreatedAt, + &i.UpdatedAt, + &i.SeverityValue, + &i.ProviderID, + &i.SubscriptionID, + &i.DisplayName, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listRuleTypesByProject = `-- name: ListRuleTypesByProject :many SELECT id, name, provider, project_id, description, guidance, definition, created_at, updated_at, severity_value, provider_id, subscription_id, display_name FROM rule_type WHERE project_id = $1 ` diff --git a/internal/engine/actions/alert/alert.go b/internal/engine/actions/alert/alert.go index 8c42372858..b33e9c6a3a 100644 --- a/internal/engine/actions/alert/alert.go +++ b/internal/engine/actions/alert/alert.go @@ -56,7 +56,12 @@ func NewRuleAlert( Msg("provider is not a GitHub provider. Silently skipping alerts.") return noop.NewNoopAlert(ActionType) } - return security_advisory.NewSecurityAdvisoryAlert(ActionType, ruletype.GetSeverity(), alertCfg.GetSecurityAdvisory(), client) + return security_advisory.NewSecurityAdvisoryAlert( + ActionType, + ruletype, + alertCfg.GetSecurityAdvisory(), + client, + ) } return nil, fmt.Errorf("unknown alert type: %s", alertCfg.GetType()) diff --git a/internal/engine/actions/alert/security_advisory/security_advisory.go b/internal/engine/actions/alert/security_advisory/security_advisory.go index 5007ebf4e3..d4eb718ccd 100644 --- a/internal/engine/actions/alert/security_advisory/security_advisory.go +++ b/internal/engine/actions/alert/security_advisory/security_advisory.go @@ -98,7 +98,7 @@ If you have any questions or believe that this evaluation is incorrect, please d type Alert struct { actionType interfaces.ActionType cli provifv1.GitHub - sev *pb.Severity + ruleType *pb.RuleType saCfg *pb.RuleType_Definition_Alert_AlertTypeSA summaryTmpl *htmltemplate.Template descriptionTmpl *htmltemplate.Template @@ -133,7 +133,7 @@ type alertMetadata struct { // NewSecurityAdvisoryAlert creates a new security-advisory alert action func NewSecurityAdvisoryAlert( actionType interfaces.ActionType, - sev *pb.Severity, + ruleType *pb.RuleType, saCfg *pb.RuleType_Definition_Alert_AlertTypeSA, cli provifv1.GitHub, ) (*Alert, error) { @@ -160,7 +160,7 @@ func NewSecurityAdvisoryAlert( return &Alert{ actionType: actionType, cli: cli, - sev: sev, + ruleType: ruleType, saCfg: saCfg, summaryTmpl: sumT, descriptionTmpl: descT, @@ -349,16 +349,16 @@ func (alert *Alert) getParamsForSecurityAdvisory( // Get the severity result.Template.Severity = alert.getSeverityString() // Get the guidance - result.Template.Guidance = params.GetRuleType().Guidance + result.Template.Guidance = alert.ruleType.Guidance // Get the rule type name - result.Template.Rule = params.GetRuleType().Name + result.Template.Rule = alert.ruleType.Name // Get the profile name result.Template.Profile = params.GetProfile().Name // Get the rule name result.Template.Name = params.GetRule().Name // Check if remediation is available for the rule type - if params.GetRuleType().Def.Remediate != nil { + if alert.ruleType.Def.Remediate != nil { result.Template.RuleRemediation = "already available" } else { result.Template.RuleRemediation = "not available yet" @@ -386,7 +386,7 @@ func (alert *Alert) getParamsForSecurityAdvisory( func (alert *Alert) getSeverityString() string { if alert.saCfg.Severity == "" { - ruleSev := alert.sev.GetValue().Enum().AsString() + ruleSev := alert.ruleType.Severity.GetValue().Enum().AsString() if ruleSev == "info" || ruleSev == "unknown" { return "low" } diff --git a/internal/engine/executor.go b/internal/engine/executor.go index 9e89448ff0..80ef02f033 100644 --- a/internal/engine/executor.go +++ b/internal/engine/executor.go @@ -30,11 +30,11 @@ import ( evalerrors "github.com/stacklok/minder/internal/engine/errors" "github.com/stacklok/minder/internal/engine/ingestcache" engif "github.com/stacklok/minder/internal/engine/interfaces" + "github.com/stacklok/minder/internal/engine/ruleengine" "github.com/stacklok/minder/internal/history" minderlogger "github.com/stacklok/minder/internal/logger" "github.com/stacklok/minder/internal/profiles" "github.com/stacklok/minder/internal/providers/manager" - "github.com/stacklok/minder/internal/ruletypes" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" provinfv1 "github.com/stacklok/minder/pkg/providers/v1" ) @@ -101,8 +101,20 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo defer e.releaseLockAndFlush(ctx, inf) + ruleEngineCache, err := ruleengine.NewRuleEngineCache( + ctx, + e.querier, + entities.EntityTypeToDB(inf.Type), + inf.ProjectID, + provider, + ingestCache, + ) + if err != nil { + return fmt.Errorf("unable to fetch rule type instances for project: %w", err) + } + err = e.forProjectsInHierarchy( - ctx, inf, func(ctx context.Context, profile *pb.Profile, hierarchy []uuid.UUID) error { + ctx, inf, func(ctx context.Context, profile *pb.Profile) error { // Get only these rules that are relevant for this entity type relevant, err := profiles.GetRulesForEntity(profile, inf.Type) if err != nil { @@ -113,7 +125,7 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo err = profiles.TraverseRules(relevant, func(rule *pb.Profile_Rule) error { // Get the engine evaluator for this rule type evalParams, ruleEngine, actionEngine, err := e.getEvaluator( - ctx, inf, provider, profile, rule, hierarchy, ingestCache) + ctx, inf, provider, profile, rule, ruleEngineCache) if err != nil { return err } @@ -130,7 +142,7 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo evalParams.SetActionsErr(ctx, actionsErr) // Log the evaluation - logEval(ctx, inf, evalParams) + logEval(ctx, inf, evalParams, ruleEngine.Meta.Name) // Create or update the evaluation status return e.createOrUpdateEvalStatus(ctx, evalParams) @@ -157,15 +169,14 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo func (e *executor) forProjectsInHierarchy( ctx context.Context, inf *entities.EntityInfoWrapper, - f func(context.Context, *pb.Profile, []uuid.UUID) error, + f func(context.Context, *pb.Profile) error, ) error { projList, err := e.querier.GetParentProjects(ctx, inf.ProjectID) if err != nil { return fmt.Errorf("error getting parent projects: %w", err) } - for idx, projID := range projList { - projectHierarchy := projList[idx:] + for _, projID := range projList { // Get profiles relevant to project dbpols, err := e.querier.ListProfilesByProjectID(ctx, projID) if err != nil { @@ -173,7 +184,7 @@ func (e *executor) forProjectsInHierarchy( } for _, profile := range profiles.MergeDatabaseListIntoProfiles(dbpols) { - if err := f(ctx, profile, projectHierarchy); err != nil { + if err := f(ctx, profile); err != nil { return err } } @@ -188,49 +199,43 @@ func (e *executor) getEvaluator( provider provinfv1.Provider, profile *pb.Profile, rule *pb.Profile_Rule, - hierarchy []uuid.UUID, - ingestCache ingestcache.Cache, -) (*engif.EvalStatusParams, *RuleTypeEngine, *actions.RuleActionsEngine, error) { + ruleEngineCache ruleengine.Cache, +) (*engif.EvalStatusParams, *ruleengine.RuleTypeEngine, *actions.RuleActionsEngine, error) { // Create eval status params params, err := e.createEvalStatusParams(ctx, inf, profile, rule) if err != nil { return nil, nil, nil, fmt.Errorf("error creating eval status params: %w", err) } - // Load Rule Class from database - // TODO(jaosorior): Rule types should be cached in memory so - // we don't have to query the database for each rule. - dbrt, err := e.querier.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - Projects: hierarchy, - Name: rule.Type, - }) - if err != nil { - return nil, nil, nil, fmt.Errorf("error getting rule type when traversing profile %s: %w", params.ProfileID, err) + // TODO: Once we use the rule instance table, this will no longer be necessary + if profile.Id == nil { + return nil, nil, nil, fmt.Errorf("profile %s missing ID", profile.Name) } - - // Parse the rule type - ruleType, err := ruletypes.RuleTypePBFromDB(&dbrt) + profileID, err := uuid.Parse(*profile.Id) if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing rule type when traversing profile %s: %w", params.ProfileID, err) + return nil, nil, nil, fmt.Errorf("unable to parse %s as profile ID", *profile.Id) } - // Save the rule type uuid - ruleTypeID, err := uuid.Parse(*ruleType.Id) + ruleTypeID, err := e.querier.GetRuleTypeIDByRuleNameEntityProfile(ctx, + db.GetRuleTypeIDByRuleNameEntityProfileParams{ + ProfileID: profileID, + EntityType: entities.EntityTypeToDB(inf.Type), + Name: rule.Name, + }, + ) if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing rule type ID: %w", err) + return nil, nil, nil, fmt.Errorf("unable to retrieve rule type ID: %w", err) } + params.RuleTypeID = ruleTypeID - params.RuleType = ruleType // Create the rule type engine - rte, err := NewRuleTypeEngine(ctx, ruleType, provider) + rte, err := ruleEngineCache.GetRuleEngine(ruleTypeID) if err != nil { return nil, nil, nil, fmt.Errorf("error creating rule type engine: %w", err) } - rte = rte.WithIngesterCache(ingestCache) - - actionEngine, err := actions.NewRuleActions(ctx, profile, ruleType, provider) + actionEngine, err := actions.NewRuleActions(ctx, profile, rte.GetRuleType(), provider) if err != nil { return nil, nil, nil, fmt.Errorf("cannot create rule actions engine: %w", err) } @@ -297,6 +302,7 @@ func logEval( ctx context.Context, inf *entities.EntityInfoWrapper, params *engif.EvalStatusParams, + ruleTypeName string, ) { evalLog := params.DecorateLogger( zerolog.Ctx(ctx).With(). @@ -313,5 +319,5 @@ func logEval( Msg("entity evaluation - completed") // log business logic - minderlogger.BusinessRecord(ctx).AddRuleEval(params) + minderlogger.BusinessRecord(ctx).AddRuleEval(params, ruleTypeName) } diff --git a/internal/engine/executor_test.go b/internal/engine/executor_test.go index 369ba2fab5..f29cfe7657 100644 --- a/internal/engine/executor_test.go +++ b/internal/engine/executor_test.go @@ -147,7 +147,8 @@ func TestExecutor_handleEntityEvent(t *testing.T) { mockStore.EXPECT(). GetParentProjects(gomock.Any(), projectID). - Return([]uuid.UUID{projectID}, nil) + Return([]uuid.UUID{projectID}, nil). + Times(2) mockStore.EXPECT(). ListProfilesByProjectID(gomock.Any(), projectID). @@ -197,15 +198,22 @@ default allow = true`, require.NoError(t, err, "expected no error") mockStore.EXPECT(). - GetRuleTypeByName(gomock.Any(), db.GetRuleTypeByNameParams{ - Projects: []uuid.UUID{projectID}, - Name: passthroughRuleType, - }).Return(db.RuleType{ - ID: ruleTypeID, - Name: passthroughRuleType, - ProjectID: projectID, - Definition: marshalledRTD, - }, nil) + GetRuleTypeIDByRuleNameEntityProfile(gomock.Any(), gomock.Any()). + Return(ruleTypeID, nil) + + mockStore.EXPECT(). + GetRuleTypesByEntityInHierarchy(gomock.Any(), db.GetRuleTypesByEntityInHierarchyParams{ + EntityType: db.EntitiesRepository, + Projects: []uuid.UUID{projectID}, + }). + Return([]db.RuleType{ + { + ID: ruleTypeID, + Name: passthroughRuleType, + ProjectID: projectID, + Definition: marshalledRTD, + }, + }, nil) ruleEvalId := uuid.New() diff --git a/internal/engine/interfaces/interface.go b/internal/engine/interfaces/interface.go index c1bfdc1edc..b3093ab49f 100644 --- a/internal/engine/interfaces/interface.go +++ b/internal/engine/interfaces/interface.go @@ -131,7 +131,6 @@ type EvalStatusParams struct { Result *Result Profile *pb.Profile Rule *pb.Profile_Rule - RuleType *pb.RuleType ProfileID uuid.UUID RepoID uuid.NullUUID ArtifactID uuid.NullUUID @@ -208,16 +207,16 @@ func (e *EvalStatusParams) GetRule() *pb.Profile_Rule { return e.Rule } +// GetRuleTypeID returns the rule type ID +func (e *EvalStatusParams) GetRuleTypeID() uuid.UUID { + return e.RuleTypeID +} + // GetEvalStatusFromDb returns the evaluation status from the database func (e *EvalStatusParams) GetEvalStatusFromDb() *db.ListRuleEvaluationsByProfileIdRow { return e.EvalStatusFromDb } -// GetRuleType returns the rule type -func (e *EvalStatusParams) GetRuleType() *pb.RuleType { - return e.RuleType -} - // GetProfile returns the profile func (e *EvalStatusParams) GetProfile() *pb.Profile { return e.Profile @@ -240,7 +239,7 @@ func (e *EvalStatusParams) DecorateLogger(l zerolog.Logger) zerolog.Logger { Str("profile_id", e.ProfileID.String()). Str("rule_type", e.GetRule().GetType()). Str("rule_name", e.GetRule().GetName()). - Str("rule_type_id", e.GetRuleType().GetId()). + Str("rule_type_id", e.RuleTypeID.String()). Logger() if e.RepoID.Valid { outl = outl.With().Str("repository_id", e.RepoID.UUID.String()).Logger() @@ -275,6 +274,6 @@ type ActionsParams interface { GetActionsErr() evalerrors.ActionsError GetEvalErr() error GetEvalStatusFromDb() *db.ListRuleEvaluationsByProfileIdRow - GetRuleType() *pb.RuleType GetProfile() *pb.Profile + GetRuleTypeID() uuid.UUID } diff --git a/internal/engine/ruleengine/cache.go b/internal/engine/ruleengine/cache.go new file mode 100644 index 0000000000..87d01f9744 --- /dev/null +++ b/internal/engine/ruleengine/cache.go @@ -0,0 +1,88 @@ +// Copyright 2024 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ruleengine + +import ( + "context" + "fmt" + + "github.com/google/uuid" + + "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/engine/ingestcache" + "github.com/stacklok/minder/internal/ruletypes" + provinfv1 "github.com/stacklok/minder/pkg/providers/v1" +) + +// Cache contains a set of RuleTypeEngine instances +type Cache interface { + GetRuleEngine(ruleTypeID uuid.UUID) (*RuleTypeEngine, error) +} + +type ruleEngineCache struct { + engines map[uuid.UUID]*RuleTypeEngine +} + +// NewRuleEngineCache creates the rule engine cache +func NewRuleEngineCache( + ctx context.Context, + store db.Querier, + entityType db.Entities, + projectID uuid.UUID, + provider provinfv1.Provider, + ingestCache ingestcache.Cache, +) (Cache, error) { + // Get the full project hierarchy + hierarchy, err := store.GetParentProjects(ctx, projectID) + if err != nil { + return nil, fmt.Errorf("error getting parent projects: %w", err) + } + + // For all projects in the hierarchy, get all the rule types used in each + // rule instance of the specified entity type. + ruleTypes, err := store.GetRuleTypesByEntityInHierarchy(ctx, db.GetRuleTypesByEntityInHierarchyParams{ + EntityType: entityType, + Projects: hierarchy, + }) + if err != nil { + return nil, err + } + + engines := make(map[uuid.UUID]*RuleTypeEngine, len(ruleTypes)) + for _, ruleType := range ruleTypes { + // Parse the rule type + pbRuleType, err := ruletypes.RuleTypePBFromDB(&ruleType) + if err != nil { + return nil, fmt.Errorf("error parsing rule type when parsing rule type %s: %w", ruleType.ID, err) + } + + // Create the rule type engine + ruleEngine, err := NewRuleTypeEngine(ctx, pbRuleType, provider) + if err != nil { + return nil, fmt.Errorf("error creating rule type engine: %w", err) + } + + engines[ruleType.ID] = ruleEngine.WithIngesterCache(ingestCache) + } + + return &ruleEngineCache{engines: engines}, nil +} + +func (r *ruleEngineCache) GetRuleEngine(ruleTypeID uuid.UUID) (*RuleTypeEngine, error) { + if ruleTypeEngine, ok := r.engines[ruleTypeID]; ok { + return ruleTypeEngine, nil + } + return nil, fmt.Errorf("unknown rule type with ID: %s", ruleTypeID) +} diff --git a/internal/engine/rule_type_engine.go b/internal/engine/ruleengine/engine.go similarity index 96% rename from internal/engine/rule_type_engine.go rename to internal/engine/ruleengine/engine.go index 60036aa3d7..879669243c 100644 --- a/internal/engine/rule_type_engine.go +++ b/internal/engine/ruleengine/engine.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -package engine +// Package ruleengine contains the RuleTypeEngine type. +package ruleengine import ( "context" @@ -131,6 +132,11 @@ func (r *RuleTypeEngine) GetRuleInstanceValidator() *profiles.RuleValidator { return r.ruleValidator } +// GetRuleType returns the rule type PB structure. +func (r *RuleTypeEngine) GetRuleType() *minderv1.RuleType { + return r.ruletype +} + // Eval runs the rule type engine against the given entity func (r *RuleTypeEngine) Eval( ctx context.Context, diff --git a/internal/logger/telemetry_store.go b/internal/logger/telemetry_store.go index e38eaf1f42..68d9db927d 100644 --- a/internal/logger/telemetry_store.go +++ b/internal/logger/telemetry_store.go @@ -105,16 +105,12 @@ type TelemetryStore struct { // AddRuleEval is a convenience method to add a rule evaluation result to the telemetry store. func (ts *TelemetryStore) AddRuleEval( evalInfo interfaces.ActionsParams, + ruleTypeName string, ) { if ts == nil { return } - // Get rule type ID - ruleTypeID, err := uuid.Parse(evalInfo.GetRuleType().GetId()) - if err != nil { - return - } // Get profile ID profileID, err := uuid.Parse(evalInfo.GetProfile().GetId()) if err != nil { @@ -122,7 +118,7 @@ func (ts *TelemetryStore) AddRuleEval( } red := RuleEvalData{ - RuleType: RuleType{Name: evalInfo.GetRuleType().GetName(), ID: ruleTypeID}, + RuleType: RuleType{Name: ruleTypeName, ID: evalInfo.GetRuleTypeID()}, Profile: Profile{Name: evalInfo.GetProfile().GetName(), ID: profileID}, EvalResult: errors.EvalErrorAsString(evalInfo.GetEvalErr()), Actions: map[interfaces.ActionType]ActionEvalData{ diff --git a/internal/logger/telemetry_store_test.go b/internal/logger/telemetry_store_test.go index 1af5e2e296..50d97be6c4 100644 --- a/internal/logger/telemetry_store_test.go +++ b/internal/logger/telemetry_store_test.go @@ -48,15 +48,10 @@ func TestTelemetryStore_Record(t *testing.T) { name: "nil telemetry", evalParamsFunc: func() *engif.EvalStatusParams { ep := &engif.EvalStatusParams{} - ep.Profile = &minderv1.Profile{ Name: "artifact_profile", Id: &testUUIDString, } - ep.RuleType = &minderv1.RuleType{ - Name: "artifact_signature", - Id: &testUUIDString, - } ep.SetEvalErr(enginerr.NewErrEvaluationFailed("evaluation failure reason")) ep.SetActionsOnOff(map[engif.ActionType]engif.ActionOpt{ alert.ActionType: engif.ActionOptOn, @@ -71,22 +66,18 @@ func TestTelemetryStore_Record(t *testing.T) { recordFunc: func(ctx context.Context, evalParams engif.ActionsParams) { logger.BusinessRecord(ctx).Project = testUUID logger.BusinessRecord(ctx).Repository = testUUID - logger.BusinessRecord(ctx).AddRuleEval(evalParams) + logger.BusinessRecord(ctx).AddRuleEval(evalParams, ruleTypeName) }, }, { name: "standard telemetry", telemetry: &logger.TelemetryStore{}, evalParamsFunc: func() *engif.EvalStatusParams { ep := &engif.EvalStatusParams{} - + ep.RuleTypeID = testUUID ep.Profile = &minderv1.Profile{ Name: "artifact_profile", Id: &testUUIDString, } - ep.RuleType = &minderv1.RuleType{ - Name: "artifact_signature", - Id: &testUUIDString, - } ep.SetEvalErr(enginerr.NewErrEvaluationFailed("evaluation failure reason")) ep.SetActionsOnOff(map[engif.ActionType]engif.ActionOpt{ alert.ActionType: engif.ActionOptOff, @@ -101,7 +92,7 @@ func TestTelemetryStore_Record(t *testing.T) { recordFunc: func(ctx context.Context, evalParams engif.ActionsParams) { logger.BusinessRecord(ctx).Project = testUUID logger.BusinessRecord(ctx).Repository = testUUID - logger.BusinessRecord(ctx).AddRuleEval(evalParams) + logger.BusinessRecord(ctx).AddRuleEval(evalParams, ruleTypeName) }, expected: `{ "project": "00000000-0000-0000-0000-000000000001", @@ -188,3 +179,5 @@ func TestTelemetryStore_Record(t *testing.T) { } } + +const ruleTypeName = "artifact_signature" diff --git a/internal/profiles/rule_validator_test.go b/internal/profiles/rule_validator_test.go index 099e919e43..a7bf3ef0e0 100644 --- a/internal/profiles/rule_validator_test.go +++ b/internal/profiles/rule_validator_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/stacklok/minder/internal/engine" + "github.com/stacklok/minder/internal/engine/ruleengine" "github.com/stacklok/minder/internal/profiles" minderv1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) @@ -67,7 +67,7 @@ func TestExampleRulesAreValidatedCorrectly(t *testing.T) { rval, err := profiles.NewRuleValidator(rt) require.NoError(t, err, "failed to create rule validator for rule type %s", path) - rules, err := engine.GetRulesFromProfileOfType(pol, rt) + rules, err := ruleengine.GetRulesFromProfileOfType(pol, rt) require.NoError(t, err, "failed to get rules from profile for rule type %s", path) t.Log("validating rules")