Skip to content

Commit

Permalink
Cache RuleTypeEngine instances in Executor
Browse files Browse the repository at this point in the history
This simplifies an upcoming PR, and also addresses a comment in the code
about avoiding continuously re-querying for rule types. This PR changes
the executor to load all relevant rule types before evaluation and
construct RuleTypeEngine instances for each. These are then stored in a
cache and used during rule evaluation.

Certain aspects of this change (specifically, the need to query for the
rule type ID) are a bit clunky, but will be simplified in an upcoming
PR.
  • Loading branch information
dmjb committed Jul 15, 2024
1 parent 3be7c67 commit de45dcc
Show file tree
Hide file tree
Showing 13 changed files with 277 additions and 58 deletions.
30 changes: 30 additions & 0 deletions database/mock/store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion database/query/rule_instances.sql
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,12 @@ AND NOT id = ANY(sqlc.arg(updated_ids)::UUID[]);
SELECT id FROM rule_instances
WHERE profile_id = $1
AND entity_type = $2
AND name = $3;
AND name = $3;

-- intended as a temporary transition query
-- this will be removed once rule_instances is used consistently in the engine
-- name: GetRuleTypeIDByRuleNameEntityProfile :one
SELECT rule_type_id FROM rule_instances
WHERE name = $1
AND entity_type = $2
AND profile_id = $3;
6 changes: 6 additions & 0 deletions database/query/rule_types.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ UPDATE rule_type
SET description = $2, definition = sqlc.arg(definition)::jsonb, severity_value = sqlc.arg(severity_value), display_name = sqlc.arg(display_name)
WHERE id = $1
RETURNING *;

-- name: GetRuleTypesByEntityInHierarchy :many
SELECT rt.* FROM rule_type AS rt
JOIN rule_instances AS ri ON ri.rule_type_id = rt.id
WHERE ri.entity_type = $1
AND ri.project_id = ANY(sqlc.arg(projects)::uuid[]);
4 changes: 4 additions & 0 deletions internal/db/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions internal/db/rule_instances.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 49 additions & 0 deletions internal/db/rule_types.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

71 changes: 38 additions & 33 deletions internal/engine/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
minderlogger "github.com/stacklok/minder/internal/logger"
"github.com/stacklok/minder/internal/profiles"
"github.com/stacklok/minder/internal/providers/manager"
"github.com/stacklok/minder/internal/ruletypes"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
provinfv1 "github.com/stacklok/minder/pkg/providers/v1"
)
Expand Down Expand Up @@ -107,8 +106,21 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo

defer e.releaseLockAndFlush(ctx, inf)

// Load all the relevant rule type engines for this entity
ruleEngineCache, err := rtengine.NewRuleEngineCache(
ctx,
e.querier,
entities.EntityTypeToDB(inf.Type),
inf.ProjectID,
provider,
ingestCache,
)
if err != nil {
return fmt.Errorf("unable to fetch rule type instances for project: %w", err)
}

err = e.forProjectsInHierarchy(
ctx, inf, func(ctx context.Context, profile *pb.Profile, hierarchy []uuid.UUID) error {
ctx, inf, func(ctx context.Context, profile *pb.Profile) error {
profileStartTime := time.Now()
defer e.metrics.TimeProfileEvaluation(ctx, profileStartTime)
// Get only these rules that are relevant for this entity type
Expand All @@ -121,7 +133,7 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo
err = profiles.TraverseRules(relevant, func(rule *pb.Profile_Rule) error {
// Get the engine evaluator for this rule type
evalParams, ruleEngine, actionEngine, err := e.getEvaluator(
ctx, inf, provider, profile, rule, hierarchy, ingestCache)
ctx, inf, provider, profile, rule, ruleEngineCache)
if err != nil {
return err
}
Expand All @@ -138,7 +150,7 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo
evalParams.SetActionsErr(ctx, actionsErr)

// Log the evaluation
logEval(ctx, inf, evalParams)
logEval(ctx, inf, evalParams, ruleEngine.Meta.Name)

// Create or update the evaluation status
return e.createOrUpdateEvalStatus(ctx, evalParams)
Expand All @@ -165,23 +177,22 @@ func (e *executor) EvalEntityEvent(ctx context.Context, inf *entities.EntityInfo
func (e *executor) forProjectsInHierarchy(
ctx context.Context,
inf *entities.EntityInfoWrapper,
f func(context.Context, *pb.Profile, []uuid.UUID) error,
f func(context.Context, *pb.Profile) error,
) error {
projList, err := e.querier.GetParentProjects(ctx, inf.ProjectID)
if err != nil {
return fmt.Errorf("error getting parent projects: %w", err)
}

for idx, projID := range projList {
projectHierarchy := projList[idx:]
for _, projID := range projList {
// Get profiles relevant to project
dbpols, err := e.querier.ListProfilesByProjectID(ctx, projID)
if err != nil {
return fmt.Errorf("error getting profiles: %w", err)
}

for _, profile := range profiles.MergeDatabaseListIntoProfiles(dbpols) {
if err := f(ctx, profile, projectHierarchy); err != nil {
if err := f(ctx, profile); err != nil {
return err
}
}
Expand All @@ -196,49 +207,42 @@ func (e *executor) getEvaluator(
provider provinfv1.Provider,
profile *pb.Profile,
rule *pb.Profile_Rule,
hierarchy []uuid.UUID,
ingestCache ingestcache.Cache,
ruleEngineCache rtengine.Cache,
) (*engif.EvalStatusParams, *rtengine.RuleTypeEngine, *actions.RuleActionsEngine, error) {
// Create eval status params
params, err := e.createEvalStatusParams(ctx, inf, profile, rule)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating eval status params: %w", err)
}

// Load Rule Class from database
// TODO(jaosorior): Rule types should be cached in memory so
// we don't have to query the database for each rule.
dbrt, err := e.querier.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{
Projects: hierarchy,
Name: rule.Type,
})
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting rule type when traversing profile %s: %w", params.ProfileID, err)
if profile.Id == nil {
return nil, nil, nil, fmt.Errorf("profile %s missing ID", profile.Name)
}

// Parse the rule type
ruleType, err := ruletypes.RuleTypePBFromDB(&dbrt)
profileID, err := uuid.Parse(*profile.Id)
if err != nil {
return nil, nil, nil, fmt.Errorf("error parsing rule type when traversing profile %s: %w", params.ProfileID, err)
return nil, nil, nil, fmt.Errorf("unable to parse %s as profile ID", *profile.Id)
}

// Save the rule type uuid
ruleTypeID, err := uuid.Parse(*ruleType.Id)
// TODO: Once we use the rule instance table, this will no longer be necessary
ruleTypeID, err := e.querier.GetRuleTypeIDByRuleNameEntityProfile(ctx,
db.GetRuleTypeIDByRuleNameEntityProfileParams{
ProfileID: profileID,
EntityType: entities.EntityTypeToDB(inf.Type),
Name: rule.Name,
},
)
if err != nil {
return nil, nil, nil, fmt.Errorf("error parsing rule type ID: %w", err)
return nil, nil, nil, fmt.Errorf("unable to retrieve rule type ID: %w", err)
}

params.RuleTypeID = ruleTypeID
params.RuleTypeName = ruleType.Name

// Create the rule type engine
rte, err := rtengine.NewRuleTypeEngine(ctx, ruleType, provider)
rte, err := ruleEngineCache.GetRuleEngine(ruleTypeID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating rule type engine: %w", err)
}

rte = rte.WithIngesterCache(ingestCache)

actionEngine, err := actions.NewRuleActions(ctx, profile, ruleType, provider)
actionEngine, err := actions.NewRuleActions(ctx, profile, rte.GetRuleType(), provider)
if err != nil {
return nil, nil, nil, fmt.Errorf("cannot create rule actions engine: %w", err)
}
Expand Down Expand Up @@ -305,6 +309,7 @@ func logEval(
ctx context.Context,
inf *entities.EntityInfoWrapper,
params *engif.EvalStatusParams,
ruleTypeName string,
) {
evalLog := params.DecorateLogger(
zerolog.Ctx(ctx).With().
Expand All @@ -321,5 +326,5 @@ func logEval(
Msg("entity evaluation - completed")

// log business logic
minderlogger.BusinessRecord(ctx).AddRuleEval(params)
minderlogger.BusinessRecord(ctx).AddRuleEval(params, ruleTypeName)
}
28 changes: 18 additions & 10 deletions internal/engine/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ func TestExecutor_handleEntityEvent(t *testing.T) {

mockStore.EXPECT().
GetParentProjects(gomock.Any(), projectID).
Return([]uuid.UUID{projectID}, nil)
Return([]uuid.UUID{projectID}, nil).
Times(2)

mockStore.EXPECT().
ListProfilesByProjectID(gomock.Any(), projectID).
Expand Down Expand Up @@ -197,15 +198,22 @@ default allow = true`,
require.NoError(t, err, "expected no error")

mockStore.EXPECT().
GetRuleTypeByName(gomock.Any(), db.GetRuleTypeByNameParams{
Projects: []uuid.UUID{projectID},
Name: passthroughRuleType,
}).Return(db.RuleType{
ID: ruleTypeID,
Name: passthroughRuleType,
ProjectID: projectID,
Definition: marshalledRTD,
}, nil)
GetRuleTypeIDByRuleNameEntityProfile(gomock.Any(), gomock.Any()).
Return(ruleTypeID, nil)

mockStore.EXPECT().
GetRuleTypesByEntityInHierarchy(gomock.Any(), db.GetRuleTypesByEntityInHierarchyParams{
EntityType: db.EntitiesRepository,
Projects: []uuid.UUID{projectID},
}).
Return([]db.RuleType{
{
ID: ruleTypeID,
Name: passthroughRuleType,
ProjectID: projectID,
Definition: marshalledRTD,
},
}, nil)

ruleEvalId := uuid.New()

Expand Down
8 changes: 1 addition & 7 deletions internal/engine/interfaces/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ type EvalStatusParams struct {
Result *Result
Profile *pb.Profile
Rule *pb.Profile_Rule
RuleTypeName string
ProfileID uuid.UUID
RepoID uuid.NullUUID
ArtifactID uuid.NullUUID
Expand Down Expand Up @@ -223,11 +222,6 @@ func (e *EvalStatusParams) GetEvalStatusFromDb() *db.ListRuleEvaluationsByProfil
return e.EvalStatusFromDb
}

// GetRuleTypeName returns the rule type name
func (e *EvalStatusParams) GetRuleTypeName() string {
return e.RuleTypeName
}

// GetProfile returns the profile
func (e *EvalStatusParams) GetProfile() *pb.Profile {
return e.Profile
Expand All @@ -252,6 +246,7 @@ func (e *EvalStatusParams) DecorateLogger(l zerolog.Logger) zerolog.Logger {
Str("rule_name", e.GetRule().GetName()).
Str("rule_type_id", e.GetRuleTypeID().String()).
Str("execution_id", e.ExecutionID.String()).
Str("rule_type_id", e.RuleTypeID.String()).
Logger()
if e.RepoID.Valid {
outl = outl.With().Str("repository_id", e.RepoID.UUID.String()).Logger()
Expand Down Expand Up @@ -286,7 +281,6 @@ type ActionsParams interface {
GetActionsErr() evalerrors.ActionsError
GetEvalErr() error
GetEvalStatusFromDb() *db.ListRuleEvaluationsByProfileIdRow
GetRuleTypeName() string
GetProfile() *pb.Profile
GetRuleTypeID() uuid.UUID
}
Loading

0 comments on commit de45dcc

Please sign in to comment.