From c9b4400363871faed282a8f827254b21798d2e8a Mon Sep 17 00:00:00 2001 From: Michelangelo Mori Date: Tue, 9 Jul 2024 14:47:22 +0200 Subject: [PATCH] Added project id as mandatory filter. --- database/query/eval_history.sql | 14 +- internal/controlplane/handlers_evalstatus.go | 53 ++++- .../controlplane/handlers_evalstatus_test.go | 197 ++++++++++++++++++ internal/db/eval_history.sql.go | 30 ++- internal/history/models.go | 59 ++++++ internal/history/models_test.go | 115 +++++++++- internal/history/service.go | 11 + internal/history/service_test.go | 26 ++- 8 files changed, 479 insertions(+), 26 deletions(-) diff --git a/database/query/eval_history.sql b/database/query/eval_history.sql index d2f777cafc..90a260bf92 100644 --- a/database/query/eval_history.sql +++ b/database/query/eval_history.sql @@ -107,11 +107,12 @@ SELECT s.id::uuid AS evaluation_id, WHEN ere.pull_request_id IS NOT NULL THEN pr.id WHEN ere.artifact_id IS NOT NULL THEN a.id END AS entity_id, - -- entity name - CASE WHEN ere.repository_id IS NOT NULL THEN r.repo_name - WHEN ere.pull_request_id IS NOT NULL THEN pr.pr_number::text - WHEN ere.artifact_id IS NOT NULL THEN a.artifact_name - END AS entity_name, + -- raw fields for entity names + r.repo_owner, + r.repo_name, + pr.pr_number, + a.artifact_name, + j.id as project_id, -- rule type, name, and profile rt.name AS rule_type, ri.name AS rule_name, @@ -135,6 +136,7 @@ SELECT s.id::uuid AS evaluation_id, LEFT JOIN artifacts a ON ere.artifact_id = a.id LEFT JOIN remediation_events re ON re.evaluation_id = s.id LEFT JOIN alert_events ae ON ae.evaluation_id = s.id + LEFT JOIN projects j ON r.project_id = j.id WHERE (sqlc.narg(next)::timestamp without time zone IS NULL OR sqlc.narg(next) > s.most_recent_evaluation) AND (sqlc.narg(prev)::timestamp without time zone IS NULL OR sqlc.narg(prev) < s.most_recent_evaluation) -- inclusion filters @@ -159,5 +161,7 @@ SELECT s.id::uuid AS evaluation_id, AND (sqlc.narg(fromts)::timestamp without time zone IS NULL OR sqlc.narg(tots)::timestamp without time zone IS NULL OR s.most_recent_evaluation BETWEEN sqlc.narg(fromts) AND sqlc.narg(tots)) + -- implicit filter by project id + AND j.id = sqlc.arg(projectId) ORDER BY s.most_recent_evaluation DESC LIMIT sqlc.arg(size)::integer; diff --git a/internal/controlplane/handlers_evalstatus.go b/internal/controlplane/handlers_evalstatus.go index 18278f3d76..787cc03480 100644 --- a/internal/controlplane/handlers_evalstatus.go +++ b/internal/controlplane/handlers_evalstatus.go @@ -79,6 +79,9 @@ func (s *Server) ListEvaluationHistory( opts = append(opts, history.WithTo(in.GetTo().AsTime())) } + // we always filter by project id + opts = append(opts, history.WithProjectIDStr(in.GetContext().GetProject())) + filter, err := history.NewListEvaluationFilter(opts...) if err != nil { return nil, status.Error(codes.InvalidArgument, "invalid filter") @@ -138,10 +141,9 @@ func fromEvaluationHistoryRow( return nil, errors.New("internal error") } entityType := dbEntityToEntity(dbEntityType) - - entityName, ok := row.EntityName.(string) - if !ok { - return nil, errors.New("internal error") + entityName, err := getEntityName(dbEntityType, row) + if err != nil { + return nil, err } var alert *minderv1.EvaluationHistoryAlert @@ -589,3 +591,46 @@ func dbEntityToEntity(dbEnt db.Entities) minderv1.Entity { return minderv1.Entity_ENTITY_UNSPECIFIED } } + +func getEntityName( + dbEnt db.Entities, + row db.ListEvaluationHistoryRow, +) (string, error) { + switch dbEnt { + case db.EntitiesPullRequest: + if !row.RepoOwner.Valid { + return "", errors.New("repo_owner is missing") + } + if !row.RepoName.Valid { + return "", errors.New("repo_name is missing") + } + if !row.PrNumber.Valid { + return "", errors.New("pr_number is missing") + } + return fmt.Sprintf("%s/%s#%d", + row.RepoOwner.String, + row.RepoName.String, + row.PrNumber.Int64, + ), nil + case db.EntitiesArtifact: + if !row.ArtifactName.Valid { + return "", errors.New("artifact_name is missing") + } + return row.ArtifactName.String, nil + case db.EntitiesRepository: + if !row.RepoOwner.Valid { + return "", errors.New("repo_owner is missing") + } + if !row.RepoName.Valid { + return "", errors.New("repo_name is missing") + } + return fmt.Sprintf("%s/%s", + row.RepoOwner.String, + row.RepoName.String, + ), nil + case db.EntitiesBuildEnvironment: + return "", errors.New("invalid entity type") + default: + return "", errors.New("invalid entity type") + } +} diff --git a/internal/controlplane/handlers_evalstatus_test.go b/internal/controlplane/handlers_evalstatus_test.go index 266cdb3fbe..ed9a4dcfb9 100644 --- a/internal/controlplane/handlers_evalstatus_test.go +++ b/internal/controlplane/handlers_evalstatus_test.go @@ -120,3 +120,200 @@ func TestBuildEvalResultAlertFromLRERow(t *testing.T) { }) } } + +func TestDBEntityToEntity(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input db.Entities + output minderv1.Entity + }{ + { + name: "pull request", + input: db.EntitiesPullRequest, + output: minderv1.Entity_ENTITY_PULL_REQUESTS, + }, + { + name: "artifact", + input: db.EntitiesArtifact, + output: minderv1.Entity_ENTITY_ARTIFACTS, + }, + { + name: "repository", + input: db.EntitiesRepository, + output: minderv1.Entity_ENTITY_REPOSITORIES, + }, + { + name: "build environments", + input: db.EntitiesBuildEnvironment, + output: minderv1.Entity_ENTITY_BUILD_ENVIRONMENTS, + }, + { + name: "default", + input: db.Entities("whatever"), + output: minderv1.Entity_ENTITY_UNSPECIFIED, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + res := dbEntityToEntity(tt.input) + require.Equal(t, tt.output, res) + }) + } +} + +func TestGetEntityName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dbEnt db.Entities + row db.ListEvaluationHistoryRow + output string + err bool + }{ + { + name: "pull request", + dbEnt: db.EntitiesPullRequest, + row: db.ListEvaluationHistoryRow{ + RepoOwner: sql.NullString{ + Valid: true, + String: "stacklok", + }, + RepoName: sql.NullString{ + Valid: true, + String: "minder", + }, + PrNumber: sql.NullInt64{ + Valid: true, + Int64: 12345, + }, + }, + output: "stacklok/minder#12345", + }, + { + name: "pull request no repo owner", + dbEnt: db.EntitiesPullRequest, + row: db.ListEvaluationHistoryRow{ + RepoName: sql.NullString{ + Valid: true, + String: "minder", + }, + PrNumber: sql.NullInt64{ + Valid: true, + Int64: 12345, + }, + }, + err: true, + }, + { + name: "pull request no repo name", + dbEnt: db.EntitiesPullRequest, + row: db.ListEvaluationHistoryRow{ + RepoOwner: sql.NullString{ + Valid: true, + String: "stacklok", + }, + PrNumber: sql.NullInt64{ + Valid: true, + Int64: 12345, + }, + }, + err: true, + }, + { + name: "pull request no pr number", + dbEnt: db.EntitiesPullRequest, + row: db.ListEvaluationHistoryRow{ + RepoOwner: sql.NullString{ + Valid: true, + String: "stacklok", + }, + RepoName: sql.NullString{ + Valid: true, + String: "minder", + }, + }, + err: true, + }, + { + name: "artifact", + dbEnt: db.EntitiesArtifact, + row: db.ListEvaluationHistoryRow{ + ArtifactName: sql.NullString{ + Valid: true, + String: "artifact name", + }, + }, + output: "artifact name", + }, + { + name: "repository", + dbEnt: db.EntitiesRepository, + row: db.ListEvaluationHistoryRow{ + RepoOwner: sql.NullString{ + Valid: true, + String: "stacklok", + }, + RepoName: sql.NullString{ + Valid: true, + String: "minder", + }, + }, + output: "stacklok/minder", + }, + { + name: "repository no repo owner", + dbEnt: db.EntitiesRepository, + row: db.ListEvaluationHistoryRow{ + RepoName: sql.NullString{ + Valid: true, + String: "minder", + }, + }, + err: true, + }, + { + name: "repository no repo name", + dbEnt: db.EntitiesRepository, + row: db.ListEvaluationHistoryRow{ + RepoOwner: sql.NullString{ + Valid: true, + String: "stacklok", + }, + }, + err: true, + }, + { + name: "build environments", + dbEnt: db.EntitiesBuildEnvironment, + row: db.ListEvaluationHistoryRow{}, + err: true, + }, + { + name: "default", + dbEnt: db.Entities("whatever"), + row: db.ListEvaluationHistoryRow{}, + err: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + res, err := getEntityName(tt.dbEnt, tt.row) + + if tt.err { + require.Error(t, err) + require.Equal(t, "", res) + return + } + + require.NoError(t, err) + require.Equal(t, tt.output, res) + }) + } +} diff --git a/internal/db/eval_history.sql.go b/internal/db/eval_history.sql.go index 5e77b00911..431fd59d5b 100644 --- a/internal/db/eval_history.sql.go +++ b/internal/db/eval_history.sql.go @@ -203,11 +203,12 @@ SELECT s.id::uuid AS evaluation_id, WHEN ere.pull_request_id IS NOT NULL THEN pr.id WHEN ere.artifact_id IS NOT NULL THEN a.id END AS entity_id, - -- entity name - CASE WHEN ere.repository_id IS NOT NULL THEN r.repo_name - WHEN ere.pull_request_id IS NOT NULL THEN pr.pr_number::text - WHEN ere.artifact_id IS NOT NULL THEN a.artifact_name - END AS entity_name, + -- raw fields for entity names + r.repo_owner, + r.repo_name, + pr.pr_number, + a.artifact_name, + j.id as project_id, -- rule type, name, and profile rt.name AS rule_type, ri.name AS rule_name, @@ -231,6 +232,7 @@ SELECT s.id::uuid AS evaluation_id, LEFT JOIN artifacts a ON ere.artifact_id = a.id LEFT JOIN remediation_events re ON re.evaluation_id = s.id LEFT JOIN alert_events ae ON ae.evaluation_id = s.id + LEFT JOIN projects j ON r.project_id = j.id WHERE ($1::timestamp without time zone IS NULL OR $1 > s.most_recent_evaluation) AND ($2::timestamp without time zone IS NULL OR $2 < s.most_recent_evaluation) -- inclusion filters @@ -255,8 +257,10 @@ SELECT s.id::uuid AS evaluation_id, AND ($15::timestamp without time zone IS NULL OR $16::timestamp without time zone IS NULL OR s.most_recent_evaluation BETWEEN $15 AND $16) + -- implicit filter by project id + AND j.id = $17 ORDER BY s.most_recent_evaluation DESC - LIMIT $17::integer + LIMIT $18::integer ` type ListEvaluationHistoryParams struct { @@ -276,6 +280,7 @@ type ListEvaluationHistoryParams struct { Notstatuses []EvalStatusTypes `json:"notstatuses"` Fromts sql.NullTime `json:"fromts"` Tots sql.NullTime `json:"tots"` + Projectid uuid.UUID `json:"projectid"` Size int32 `json:"size"` } @@ -284,7 +289,11 @@ type ListEvaluationHistoryRow struct { EvaluatedAt time.Time `json:"evaluated_at"` EntityType interface{} `json:"entity_type"` EntityID interface{} `json:"entity_id"` - EntityName interface{} `json:"entity_name"` + RepoOwner sql.NullString `json:"repo_owner"` + RepoName sql.NullString `json:"repo_name"` + PrNumber sql.NullInt64 `json:"pr_number"` + ArtifactName sql.NullString `json:"artifact_name"` + ProjectID uuid.NullUUID `json:"project_id"` RuleType string `json:"rule_type"` RuleName string `json:"rule_name"` ProfileName string `json:"profile_name"` @@ -314,6 +323,7 @@ func (q *Queries) ListEvaluationHistory(ctx context.Context, arg ListEvaluationH pq.Array(arg.Notstatuses), arg.Fromts, arg.Tots, + arg.Projectid, arg.Size, ) if err != nil { @@ -328,7 +338,11 @@ func (q *Queries) ListEvaluationHistory(ctx context.Context, arg ListEvaluationH &i.EvaluatedAt, &i.EntityType, &i.EntityID, - &i.EntityName, + &i.RepoOwner, + &i.RepoName, + &i.PrNumber, + &i.ArtifactName, + &i.ProjectID, &i.RuleType, &i.RuleName, &i.ProfileName, diff --git a/internal/history/models.go b/internal/history/models.go index f7cc659f9a..052f4410c8 100644 --- a/internal/history/models.go +++ b/internal/history/models.go @@ -22,12 +22,17 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/stacklok/minder/internal/db" ) var ( // ErrMalformedCursor represents errors in the cursor payload. ErrMalformedCursor = errors.New("malformed cursor") + // ErrInvalidProjectID is returned when project id is missing + // or malformed form the filter. + ErrInvalidProjectID = errors.New("invalid project id") // ErrInvalidTimeRange is returned the time range from-to is // either missing one end or from is greater than to. ErrInvalidTimeRange = errors.New("invalid time range") @@ -123,6 +128,15 @@ type Filter interface{} // FilterOpt is the option type used to configure filters. type FilterOpt func(Filter) error +// ProjectFilter interface should be implemented by types implementing +// a filter on project id. +type ProjectFilter interface { + // AddProjectID adds a project id for inclusion in the filter. + AddProjectID(uuid.UUID) error + // GetProjectID returns the included project id. + GetProjectID() uuid.UUID +} + // EntityTypeFilter interface should be implemented by types // implementing a filter on entity types. type EntityTypeFilter interface { @@ -219,6 +233,7 @@ type TimeRangeFilter interface { // ListEvaluationFilter is a filter to be used when listing historical // evaluations. type ListEvaluationFilter interface { + ProjectFilter EntityTypeFilter EntityNameFilter ProfileNameFilter @@ -229,6 +244,8 @@ type ListEvaluationFilter interface { } type listEvaluationFilter struct { + // Project ID to include in the selection + projectID uuid.UUID // List of entity types to include in the selection includedEntityTypes []string // List of entity types to exclude from the selection @@ -259,6 +276,17 @@ type listEvaluationFilter struct { to *time.Time } +func (filter *listEvaluationFilter) AddProjectID(projectID uuid.UUID) error { + if projectID == uuid.Nil { + return fmt.Errorf("%w: project id", ErrInvalidIdentifier) + } + filter.projectID = projectID + return nil +} +func (filter *listEvaluationFilter) GetProjectID() uuid.UUID { + return filter.projectID +} + func (filter *listEvaluationFilter) AddEntityType(entityType string) error { if strings.HasPrefix(entityType, "!") { entityType = strings.Split(entityType, "!")[1] // guaranteed to exist @@ -415,6 +443,34 @@ func (filter *listEvaluationFilter) GetTo() *time.Time { var _ Filter = (*listEvaluationFilter)(nil) var _ ListEvaluationFilter = (*listEvaluationFilter)(nil) +// WithProjectIDStr adds a project id (string) to the filter. Whether +// a null uuid is valid or not is determined on a per-endpoint basis. +func WithProjectIDStr(projectID string) FilterOpt { + return func(filter Filter) error { + uuid, err := uuid.Parse(projectID) + if err != nil { + return fmt.Errorf("%w: project id", ErrInvalidIdentifier) + } + inner, ok := filter.(ProjectFilter) + if !ok { + return fmt.Errorf("%w: wrong filter type", ErrInvalidIdentifier) + } + return inner.AddProjectID(uuid) + } +} + +// WithProjectID adds a project id (uuid) to the filter. Whether a +// null uuid is valid or not is determined on a per-endpoint basis. +func WithProjectID(projectID uuid.UUID) FilterOpt { + return func(filter Filter) error { + inner, ok := filter.(ProjectFilter) + if !ok { + return fmt.Errorf("%w: wrong filter type", ErrInvalidIdentifier) + } + return inner.AddProjectID(projectID) + } +} + // WithEntityType adds an entity type string to the filter. The entity // type is added for inclusion unless it starts with a `!` characters, // in which case it is added for exclusion. @@ -552,6 +608,9 @@ func NewListEvaluationFilter(opts ...FilterOpt) (ListEvaluationFilter, error) { // Following we check that time range based filtering is // sound. + if filter.projectID == uuid.Nil { + return nil, fmt.Errorf("%w: missing", ErrInvalidProjectID) + } if filter.to != nil && filter.from == nil { return nil, fmt.Errorf("%w: from is missing", ErrInvalidTimeRange) } diff --git a/internal/history/models_test.go b/internal/history/models_test.go index d5a810058b..ee1d1df9fc 100644 --- a/internal/history/models_test.go +++ b/internal/history/models_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) @@ -161,6 +162,7 @@ func TestListEvaluationFilter(t *testing.T) { filter: func(t *testing.T) (ListEvaluationFilter, error) { t.Helper() return NewListEvaluationFilter( + WithProjectIDStr("deadbeef-0000-0000-0000-000000000000"), WithEntityType("repository"), ) }, @@ -169,6 +171,26 @@ func TestListEvaluationFilter(t *testing.T) { require.Equal(t, []string{"repository"}, filter.IncludedEntityTypes()) }, }, + { + name: "mandatory project id", + filter: func(t *testing.T) (ListEvaluationFilter, error) { + t.Helper() + return NewListEvaluationFilter( + WithEntityType("repository"), + ) + }, + err: true, + }, + { + name: "non-empty project id", + filter: func(t *testing.T) (ListEvaluationFilter, error) { + t.Helper() + return NewListEvaluationFilter( + WithProjectID(uuid.Nil), + ) + }, + err: true, + }, { name: "bogus", filter: func(t *testing.T) (ListEvaluationFilter, error) { @@ -184,14 +206,13 @@ func TestListEvaluationFilter(t *testing.T) { filter: func(t *testing.T) (ListEvaluationFilter, error) { t.Helper() return NewListEvaluationFilter( - WithEntityType("repository"), + WithProjectIDStr("deadbeef-0000-0000-0000-000000000000"), WithFrom(now), WithTo(now), ) }, check: func(t *testing.T, filter ListEvaluationFilter) { t.Helper() - require.Equal(t, []string{"repository"}, filter.IncludedEntityTypes()) require.Equal(t, now, *filter.GetFrom()) require.Equal(t, now, *filter.GetTo()) }, @@ -201,7 +222,7 @@ func TestListEvaluationFilter(t *testing.T) { filter: func(t *testing.T) (ListEvaluationFilter, error) { t.Helper() return NewListEvaluationFilter( - WithEntityType("repository"), + WithProjectIDStr("deadbeef-0000-0000-0000-000000000000"), WithTo(now), ) }, @@ -212,7 +233,7 @@ func TestListEvaluationFilter(t *testing.T) { filter: func(t *testing.T) (ListEvaluationFilter, error) { t.Helper() return NewListEvaluationFilter( - WithEntityType("repository"), + WithProjectIDStr("deadbeef-0000-0000-0000-000000000000"), WithFrom(now), ) }, @@ -223,6 +244,7 @@ func TestListEvaluationFilter(t *testing.T) { filter: func(t *testing.T) (ListEvaluationFilter, error) { t.Helper() return NewListEvaluationFilter( + WithProjectIDStr("deadbeef-0000-0000-0000-000000000000"), WithEntityType("repository"), WithFrom(now.Add(1*time.Millisecond)), WithTo(now), @@ -324,6 +346,9 @@ func TestFilterOptions(t *testing.T) { now := time.Now() + uuidstr := "deadbeef-0000-0000-0000-000000000000" + uuidval := uuid.MustParse(uuidstr) + tests := []struct { name string option func(*testing.T) FilterOpt @@ -331,6 +356,88 @@ func TestFilterOptions(t *testing.T) { check func(*testing.T, Filter) err bool }{ + // project id + { + name: "project id string", + option: func(t *testing.T) FilterOpt { + t.Helper() + return WithProjectIDStr(uuidstr) + }, + filter: func(t *testing.T) Filter { + t.Helper() + return &listEvaluationFilter{} + }, + check: func(t *testing.T, filter Filter) { + t.Helper() + f := filter.(ProjectFilter) + require.Equal(t, uuidval, f.GetProjectID()) + }, + }, + { + name: "project id uuid", + option: func(t *testing.T) FilterOpt { + t.Helper() + return WithProjectID(uuidval) + }, + filter: func(t *testing.T) Filter { + t.Helper() + return &listEvaluationFilter{} + }, + check: func(t *testing.T, filter Filter) { + t.Helper() + f := filter.(ProjectFilter) + require.Equal(t, uuidval, f.GetProjectID()) + }, + }, + { + name: "project id nil", + option: func(t *testing.T) FilterOpt { + t.Helper() + return WithProjectID(uuid.Nil) + }, + filter: func(t *testing.T) Filter { + t.Helper() + return &listEvaluationFilter{} + }, + err: true, + }, + { + name: "project id malformed", + option: func(t *testing.T) FilterOpt { + t.Helper() + return WithProjectIDStr("malformed") + }, + filter: func(t *testing.T) Filter { + t.Helper() + return &listEvaluationFilter{} + }, + err: true, + }, + { + name: "wrong project filter", + option: func(t *testing.T) FilterOpt { + t.Helper() + return WithProjectIDStr(uuidstr) + }, + filter: func(t *testing.T) Filter { + t.Helper() + return foo + }, + err: true, + }, + { + name: "wrong project filter", + option: func(t *testing.T) FilterOpt { + t.Helper() + return WithProjectID(uuidval) + }, + filter: func(t *testing.T) Filter { + t.Helper() + return foo + }, + err: true, + }, + // entity type { name: "entity type in filter", diff --git a/internal/history/service.go b/internal/history/service.go index f9f44edfd5..8c3982bcc3 100644 --- a/internal/history/service.go +++ b/internal/history/service.go @@ -279,6 +279,9 @@ func toSQLFilter( return nil } + if err := paramsFromProjectFilter(filter, params); err != nil { + return err + } if err := paramsFromEntityTypeFilter(filter, params); err != nil { return err } @@ -300,6 +303,14 @@ func toSQLFilter( return paramsFromTimeRangeFilter(filter, params) } +func paramsFromProjectFilter( + filter ProjectFilter, + params *db.ListEvaluationHistoryParams, +) error { + params.Projectid = filter.GetProjectID() + return nil +} + func paramsFromEntityTypeFilter( filter EntityTypeFilter, params *db.ListEvaluationHistoryParams, diff --git a/internal/history/service_test.go b/internal/history/service_test.go index 9614097556..bd575159e2 100644 --- a/internal/history/service_test.go +++ b/internal/history/service_test.go @@ -696,11 +696,27 @@ func makeHistoryRow( alert db.NullAlertStatusTypes, ) db.ListEvaluationHistoryRow { return db.ListEvaluationHistoryRow{ - EvaluationID: id, - EvaluatedAt: evaluatedAt, - EntityType: entityType, - EntityID: id, - EntityName: "repo1", + EvaluationID: id, + EvaluatedAt: evaluatedAt, + EntityType: entityType, + EntityID: id, + RepoOwner: sql.NullString{ + Valid: true, + String: "stacklok", + }, + RepoName: sql.NullString{ + Valid: true, + String: "minder", + }, + PrNumber: sql.NullInt64{ + Valid: true, + Int64: 12345, + }, + ArtifactName: sql.NullString{ + Valid: true, + String: "artifact1", + }, + // EntityName: "repo1", RuleType: "rule_type", RuleName: "rule_name", ProfileName: "profile_name",