Skip to content

Commit

Permalink
Simplify propagation of action settings (#5132)
Browse files Browse the repository at this point in the history
* Simplify propagation of action settings

Currently, action settings are passed around via a complex dance of
getter/setter functions that span multiple places in the executor.

However, this is not necessary, as we've already read these settings at
the beginning of execution, and can just set them in place where we need
them.

This simplifies this logic and leverages the aggregate profile structure
to set this. It also enforces the setting as part of the action driver
initiatlization.

Signed-off-by: Juan Antonio Osorio <[email protected]>

* Fix unit tests

Signed-off-by: Juan Antonio Osorio <[email protected]>

* Fix linter issues

Signed-off-by: Juan Antonio Osorio <[email protected]>

---------

Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX authored Dec 4, 2024
1 parent fd38412 commit d26452d
Show file tree
Hide file tree
Showing 18 changed files with 86 additions and 90 deletions.
24 changes: 6 additions & 18 deletions internal/engine/actions/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import (

// RuleActionsEngine is the engine responsible for processing all actions i.e., remediation and alerts
type RuleActionsEngine struct {
actions map[engif.ActionType]engif.Action
actionsOnOff map[engif.ActionType]models.ActionOpt
actions map[engif.ActionType]engif.Action
}

// NewRuleActions creates a new rule actions engine
Expand All @@ -40,13 +39,13 @@ func NewRuleActions(
actionConfig *models.ActionConfiguration,
) (*RuleActionsEngine, error) {
// Create the remediation engine
remEngine, err := remediate.NewRuleRemediator(ruletype, provider)
remEngine, err := remediate.NewRuleRemediator(ruletype, provider, actionConfig.Remediate)
if err != nil {
return nil, fmt.Errorf("cannot create rule remediator: %w", err)
}

// Create the alert engine
alertEngine, err := alert.NewRuleAlert(ctx, ruletype, provider)
alertEngine, err := alert.NewRuleAlert(ctx, ruletype, provider, actionConfig.Alert)
if err != nil {
return nil, fmt.Errorf("cannot create rule alerter: %w", err)
}
Expand All @@ -56,20 +55,9 @@ func NewRuleActions(
remEngine.Class(): remEngine,
alertEngine.Class(): alertEngine,
},
// The on/off state of the actions is an integral part of the action engine
// and should be set upon creation.
actionsOnOff: map[engif.ActionType]models.ActionOpt{
remEngine.Class(): remEngine.GetOnOffState(actionConfig.Remediate),
alertEngine.Class(): alertEngine.GetOnOffState(actionConfig.Alert),
},
}, nil
}

// GetOnOffState returns the on/off state of the actions
func (rae *RuleActionsEngine) GetOnOffState() map[engif.ActionType]models.ActionOpt {
return rae.actionsOnOff
}

// DoActions processes all actions i.e., remediation and alerts
func (rae *RuleActionsEngine) DoActions(
ctx context.Context,
Expand Down Expand Up @@ -143,7 +131,7 @@ func (rae *RuleActionsEngine) processAction(
// Get action engine
action := rae.actions[actionType]
// Return the result of the action
return action.Do(ctx, cmd, rae.actionsOnOff[actionType], ent, params, metadata)
return action.Do(ctx, cmd, ent, params, metadata)
}

// shouldRemediate returns the action command for remediation taking into account previous evaluations
Expand Down Expand Up @@ -257,14 +245,14 @@ func (rae *RuleActionsEngine) isSkippable(ctx context.Context, actionType engif.
Str("action", string(actionType))

// Get the profile option set for this action type
actionOnOff, ok := rae.actionsOnOff[actionType]
action, ok := rae.actions[actionType]
if !ok {
// If the action is not found, definitely skip it
logger.Msg("action type not found, skipping")
return true
}
// Check the action option
switch actionOnOff {
switch action.GetOnOffState() {
case models.ActionOptOff:
// Action is off, skip
logger.Msg("action is off, skipping")
Expand Down
5 changes: 4 additions & 1 deletion internal/engine/actions/alert/alert.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/mindersec/minder/internal/engine/actions/alert/security_advisory"
engif "github.com/mindersec/minder/internal/engine/interfaces"
pb "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
"github.com/mindersec/minder/pkg/profiles/models"
provinfv1 "github.com/mindersec/minder/pkg/providers/v1"
)

Expand All @@ -26,6 +27,7 @@ func NewRuleAlert(
ctx context.Context,
ruletype *pb.RuleType,
provider provinfv1.Provider,
setting models.ActionOpt,
) (engif.Action, error) {
alertCfg := ruletype.Def.GetAlert()
if alertCfg == nil {
Expand All @@ -44,7 +46,8 @@ func NewRuleAlert(
Msg("provider is not a GitHub provider. Silently skipping alerts.")
return noop.NewNoopAlert(ActionType)
}
return security_advisory.NewSecurityAdvisoryAlert(ActionType, ruletype, alertCfg.GetSecurityAdvisory(), client)
return security_advisory.NewSecurityAdvisoryAlert(
ActionType, ruletype, alertCfg.GetSecurityAdvisory(), client, setting)
}

return nil, fmt.Errorf("unknown alert type: %s", alertCfg.GetType())
Expand Down
3 changes: 1 addition & 2 deletions internal/engine/actions/alert/noop/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ func (_ *Alert) Type() string {
}

// GetOnOffState returns the off state of the noop engine
func (_ *Alert) GetOnOffState(_ models.ActionOpt) models.ActionOpt {
func (_ *Alert) GetOnOffState() models.ActionOpt {
return models.ActionOptOff
}

// Do perform the noop alert
func (a *Alert) Do(
_ context.Context,
_ interfaces.ActionCmd,
_ models.ActionOpt,
_ protoreflect.ProtoMessage,
_ interfaces.ActionsParams,
_ *json.RawMessage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type Alert struct {
summaryTmpl *htmltemplate.Template
descriptionTmpl *htmltemplate.Template
descriptionNoRemTmpl *htmltemplate.Template
setting models.ActionOpt
}

type paramsSA struct {
Expand Down Expand Up @@ -130,6 +131,7 @@ func NewSecurityAdvisoryAlert(
ruleType *pb.RuleType,
saCfg *pb.RuleType_Definition_Alert_AlertTypeSA,
cli provifv1.GitHub,
setting models.ActionOpt,
) (*Alert, error) {
if actionType == "" {
return nil, fmt.Errorf("action type cannot be empty")
Expand Down Expand Up @@ -161,6 +163,7 @@ func NewSecurityAdvisoryAlert(
summaryTmpl: sumT,
descriptionTmpl: descT,
descriptionNoRemTmpl: descNoRemT,
setting: setting,
}, nil
}

Expand All @@ -175,15 +178,14 @@ func (_ *Alert) Type() string {
}

// GetOnOffState returns the alert action state read from the profile
func (_ *Alert) GetOnOffState(actionOpt models.ActionOpt) models.ActionOpt {
return models.ActionOptOrDefault(actionOpt, models.ActionOptOff)
func (alert *Alert) GetOnOffState() models.ActionOpt {
return models.ActionOptOrDefault(alert.setting, models.ActionOptOff)
}

// Do alerts through security advisory
func (alert *Alert) Do(
ctx context.Context,
cmd interfaces.ActionCmd,
setting models.ActionOpt,
entity protoreflect.ProtoMessage,
params interfaces.ActionsParams,
metadata *json.RawMessage,
Expand All @@ -195,7 +197,7 @@ func (alert *Alert) Do(
}

// Process the command based on the action setting
switch setting {
switch alert.setting {
case models.ActionOptOn:
return alert.run(ctx, p, cmd)
case models.ActionOptDryRun:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ func TestSecurityAdvisoryAlert(t *testing.T) {
mockClient := mockghclient.NewMockGitHub(ctrl)
tt.mockSetup(mockClient)

saAlert, err := NewSecurityAdvisoryAlert(tt.actionType, &ruleType, &saCfg, mockClient)
saAlert, err := NewSecurityAdvisoryAlert(
tt.actionType, &ruleType, &saCfg, mockClient, models.ActionOptOn)
require.NoError(t, err)
require.NotNil(t, saAlert)

Expand All @@ -103,7 +104,6 @@ func TestSecurityAdvisoryAlert(t *testing.T) {
retMeta, err := saAlert.Do(
context.Background(),
interfaces.ActionCmdOn,
models.ActionOptOn,
&pbinternal.PullRequest{},
evalParams,
nil,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ type GhBranchProtectRemediator struct {
actionType interfaces.ActionType
cli provifv1.GitHub
patchTemplate *util.SafeTemplate
setting models.ActionOpt
}

// NewGhBranchProtectRemediator creates a new remediation engine that uses the GitHub API for branch protection
func NewGhBranchProtectRemediator(
actionType interfaces.ActionType,
ghp *pb.RuleType_Definition_Remediate_GhBranchProtectionType,
cli provifv1.GitHub,
setting models.ActionOpt,
) (*GhBranchProtectRemediator, error) {
if actionType == "" {
return nil, fmt.Errorf("action type cannot be empty")
Expand All @@ -63,6 +65,7 @@ func NewGhBranchProtectRemediator(
actionType: actionType,
cli: cli,
patchTemplate: patchTemplate,
setting: setting,
}, nil
}

Expand All @@ -87,15 +90,14 @@ func (_ *GhBranchProtectRemediator) Type() string {
}

// GetOnOffState returns the alert action state read from the profile
func (_ *GhBranchProtectRemediator) GetOnOffState(actionOpt models.ActionOpt) models.ActionOpt {
return models.ActionOptOrDefault(actionOpt, models.ActionOptOff)
func (r *GhBranchProtectRemediator) GetOnOffState() models.ActionOpt {
return models.ActionOptOrDefault(r.setting, models.ActionOptOff)
}

// Do perform the remediation
func (r *GhBranchProtectRemediator) Do(
ctx context.Context,
cmd interfaces.ActionCmd,
remAction models.ActionOpt,
ent protoreflect.ProtoMessage,
params interfaces.ActionsParams,
_ *json.RawMessage,
Expand Down Expand Up @@ -161,7 +163,7 @@ func (r *GhBranchProtectRemediator) Do(
return nil, fmt.Errorf("error patching request: %w", err)
}

switch remAction {
switch r.setting {
case models.ActionOptOn:
err = r.cli.UpdateBranchProtection(ctx, repo.Owner, repo.Name, branch, updatedRequest)
case models.ActionOptDryRun:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ func TestBranchProtectionRemediate(t *testing.T) {

prov, err := testGithubProvider(ghApiUrl)
require.NoError(t, err)
engine, err := NewGhBranchProtectRemediator(tt.newRemArgs.actionType, tt.newRemArgs.ghp, prov)
engine, err := NewGhBranchProtectRemediator(
tt.newRemArgs.actionType, tt.newRemArgs.ghp, prov, tt.remArgs.remAction)
if tt.wantInitErr {
require.Error(t, err, "expected error")
return
Expand All @@ -343,7 +344,7 @@ func TestBranchProtectionRemediate(t *testing.T) {
},
}

retMeta, err := engine.Do(context.Background(), interfaces.ActionCmdOn, tt.remArgs.remAction, tt.remArgs.ent, evalParams, nil)
retMeta, err := engine.Do(context.Background(), interfaces.ActionCmdOn, tt.remArgs.ent, evalParams, nil)
if tt.wantErr {
require.Error(t, err, "expected error")
require.Nil(t, retMeta, "expected nil metadata")
Expand Down
3 changes: 1 addition & 2 deletions internal/engine/actions/remediate/noop/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ func (_ *Remediator) Type() string {
}

// GetOnOffState returns the off state of the noop engine
func (_ *Remediator) GetOnOffState(_ models.ActionOpt) models.ActionOpt {
func (_ *Remediator) GetOnOffState() models.ActionOpt {
return models.ActionOptOff
}

// Do perform the remediation
func (r *Remediator) Do(
_ context.Context,
_ interfaces.ActionCmd,
_ models.ActionOpt,
_ protoreflect.ProtoMessage,
_ interfaces.ActionsParams,
_ *json.RawMessage,
Expand Down
10 changes: 6 additions & 4 deletions internal/engine/actions/remediate/pull_request/pull_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type pullRequestMetadata struct {
type Remediator struct {
ghCli provifv1.GitHub
actionType interfaces.ActionType
setting models.ActionOpt

prCfg *pb.RuleType_Definition_Remediate_PullRequestRemediation
modificationRegistry modificationRegistry
Expand All @@ -89,6 +90,7 @@ func NewPullRequestRemediate(
actionType interfaces.ActionType,
prCfg *pb.RuleType_Definition_Remediate_PullRequestRemediation,
ghCli provifv1.GitHub,
setting models.ActionOpt,
) (*Remediator, error) {
err := prCfg.Validate()
if err != nil {
Expand All @@ -113,6 +115,7 @@ func NewPullRequestRemediate(
prCfg: prCfg,
actionType: actionType,
modificationRegistry: modRegistry,
setting: setting,

titleTemplate: titleTmpl,
bodyTemplate: bodyTmpl,
Expand Down Expand Up @@ -140,15 +143,14 @@ func (_ *Remediator) Type() string {
}

// GetOnOffState returns the alert action state read from the profile
func (_ *Remediator) GetOnOffState(actionOpt models.ActionOpt) models.ActionOpt {
return models.ActionOptOrDefault(actionOpt, models.ActionOptOff)
func (r *Remediator) GetOnOffState() models.ActionOpt {
return models.ActionOptOrDefault(r.setting, models.ActionOptOff)
}

// Do perform the remediation
func (r *Remediator) Do(
ctx context.Context,
cmd interfaces.ActionCmd,
setting models.ActionOpt,
ent protoreflect.ProtoMessage,
params interfaces.ActionsParams,
metadata *json.RawMessage,
Expand All @@ -158,7 +160,7 @@ func (r *Remediator) Do(
return nil, fmt.Errorf("cannot get PR remediation params: %w", err)
}
var remErr error
switch setting {
switch r.setting {
case models.ActionOptOn:
return r.run(ctx, cmd, p)
case models.ActionOptDryRun:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,8 @@ func TestPullRequestRemediate(t *testing.T) {

provider, err := testGithubProvider()
require.NoError(t, err)
engine, err := NewPullRequestRemediate(tt.newRemArgs.actionType, tt.newRemArgs.prRem, provider)
engine, err := NewPullRequestRemediate(
tt.newRemArgs.actionType, tt.newRemArgs.prRem, provider, tt.remArgs.remAction)
if tt.wantInitErr {
require.Error(t, err, "expected error")
return
Expand Down Expand Up @@ -655,7 +656,6 @@ func TestPullRequestRemediate(t *testing.T) {
})
retMeta, err := engine.Do(context.Background(),
interfaces.ActionCmdOn,
tt.remArgs.remAction,
tt.remArgs.ent,
evalParams,
nil)
Expand Down
10 changes: 7 additions & 3 deletions internal/engine/actions/remediate/remediate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/mindersec/minder/internal/engine/actions/remediate/rest"
engif "github.com/mindersec/minder/internal/engine/interfaces"
pb "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1"
"github.com/mindersec/minder/pkg/profiles/models"
provinfv1 "github.com/mindersec/minder/pkg/providers/v1"
)

Expand All @@ -25,6 +26,7 @@ const ActionType engif.ActionType = "remediate"
func NewRuleRemediator(
rt *pb.RuleType,
provider provinfv1.Provider,
setting models.ActionOpt,
) (engif.Action, error) {
remediate := rt.Def.GetRemediate()
if remediate == nil {
Expand All @@ -41,7 +43,7 @@ func NewRuleRemediator(
if remediate.GetRest() == nil {
return nil, fmt.Errorf("remediations engine missing rest configuration")
}
return rest.NewRestRemediate(ActionType, remediate.GetRest(), client)
return rest.NewRestRemediate(ActionType, remediate.GetRest(), client, setting)

case gh_branch_protect.RemediateType:
client, err := provinfv1.As[provinfv1.GitHub](provider)
Expand All @@ -51,7 +53,8 @@ func NewRuleRemediator(
if remediate.GetGhBranchProtection() == nil {
return nil, fmt.Errorf("remediations engine missing gh_branch_protection configuration")
}
return gh_branch_protect.NewGhBranchProtectRemediator(ActionType, remediate.GetGhBranchProtection(), client)
return gh_branch_protect.NewGhBranchProtectRemediator(
ActionType, remediate.GetGhBranchProtection(), client, setting)

case pull_request.RemediateType:
client, err := provinfv1.As[provinfv1.GitHub](provider)
Expand All @@ -62,7 +65,8 @@ func NewRuleRemediator(
return nil, fmt.Errorf("remediations engine missing pull request configuration")
}

return pull_request.NewPullRequestRemediate(ActionType, remediate.GetPullRequest(), client)
return pull_request.NewPullRequestRemediate(
ActionType, remediate.GetPullRequest(), client, setting)
}

return nil, fmt.Errorf("unknown remediation type: %s", remediate.GetType())
Expand Down
Loading

0 comments on commit d26452d

Please sign in to comment.