diff --git a/database/migrations/000072_profile_selector_type.down.sql b/database/migrations/000072_profile_selector_type.down.sql new file mode 100644 index 0000000000..1fdb2de352 --- /dev/null +++ b/database/migrations/000072_profile_selector_type.down.sql @@ -0,0 +1,19 @@ +-- Copyright 2024 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +BEGIN; + +DROP TYPE IF EXISTS profile_selector; + +COMMIT; diff --git a/database/migrations/000072_profile_selectors_type.up.sql b/database/migrations/000072_profile_selectors_type.up.sql new file mode 100644 index 0000000000..8173933733 --- /dev/null +++ b/database/migrations/000072_profile_selectors_type.up.sql @@ -0,0 +1,25 @@ +-- Copyright 2024 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +BEGIN; + +CREATE TYPE profile_selector AS ( + id UUID, + profile_id UUID, + entity entities, + selector TEXT, + comment TEXT +); + +COMMIT; \ No newline at end of file diff --git a/database/query/profiles.sql b/database/query/profiles.sql index 92702007b0..5f528b9ebd 100644 --- a/database/query/profiles.sql +++ b/database/query/profiles.sql @@ -63,11 +63,40 @@ SELECT * FROM profiles WHERE id = $1 AND project_id = $2 FOR UPDATE; SELECT * FROM profiles WHERE lower(name) = lower(sqlc.arg(name)) AND project_id = $1 FOR UPDATE; -- name: ListProfilesByProjectID :many -SELECT sqlc.embed(profiles), sqlc.embed(profiles_with_entity_profiles) FROM profiles JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +WITH helper AS( + SELECT pr.id as profid, + ARRAY_AGG(ROW(ps.id, ps.profile_id, ps.entity, ps.selector, ps.comment)::profile_selector) FILTER (WHERE ps.id IS NOT NULL) AS selectors + FROM profiles pr + LEFT JOIN profile_selectors ps + ON pr.id = ps.profile_id + WHERE pr.project_id = $1 + GROUP BY pr.id +) +SELECT + sqlc.embed(profiles), + sqlc.embed(profiles_with_entity_profiles), + helper.selectors::profile_selector[] AS profiles_with_selectors +FROM profiles +JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +JOIN helper ON profiles.id = helper.profid WHERE profiles.project_id = $1; -- name: ListProfilesByProjectIDAndLabel :many -SELECT sqlc.embed(profiles), sqlc.embed(profiles_with_entity_profiles) FROM profiles JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +WITH helper AS( + SELECT pr.id as profid, + ARRAY_AGG(ROW(ps.id, ps.profile_id, ps.entity, ps.selector, ps.comment)::profile_selector) FILTER (WHERE ps.id IS NOT NULL) AS selectors + FROM profiles pr + LEFT JOIN profile_selectors ps + ON pr.id = ps.profile_id + WHERE pr.project_id = $1 + GROUP BY pr.id +) +SELECT sqlc.embed(profiles), + sqlc.embed(profiles_with_entity_profiles), + helper.selectors::profile_selector[] AS profiles_with_selectors +FROM profiles +JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +JOIN helper ON profiles.id = helper.profid WHERE profiles.project_id = $1 AND ( -- the most common case first, if the include_labels is empty, we list profiles with no labels diff --git a/internal/db/domain.go b/internal/db/domain.go index 74159151de..7215ab683e 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -15,12 +15,62 @@ package db import ( + "fmt" "slices" "strings" "github.com/sqlc-dev/pqtype" ) +func (s *ProfileSelector) Scan(value interface{}) error { + if value == nil { + return nil + } + + // Convert the value to a string + bytes, ok := value.([]byte) + if !ok { + return fmt.Errorf("failed to scan SelectorInfo: %v", value) + } + str := string(bytes) + fmt.Println(str) + + // Remove the parentheses + str = strings.TrimPrefix(str, "(") + str = strings.TrimSuffix(str, ")") + + // Split the string by commas to get the individual field values + parts := strings.Split(str, ",") + + // Assign the values to the struct fields + if len(parts) != 5 { + return fmt.Errorf("failed to scan SelectorInfo: unexpected number of fields") + } + + if err := s.ID.Scan(parts[0]); err != nil { + return fmt.Errorf("failed to scan id: %v", err) + } + + if err := s.ProfileID.Scan(parts[1]); err != nil { + return fmt.Errorf("failed to scan profile_id: %v", err) + } + + s.Entity = NullEntities{} + if parts[2] != "" { + if err := s.Entity.Scan(parts[2]); err != nil { + return fmt.Errorf("failed to scan entity: %v", err) + } + } + + selector := strings.TrimPrefix(parts[3], "\"") + selector = strings.TrimSuffix(selector, "\"") + s.Selector = selector + + s.Comment = parts[4] + + return nil +} + // This file contains domain-level methods for db structs // CanImplement returns true if the provider implements the given type. diff --git a/internal/db/profiles.sql.go b/internal/db/profiles.sql.go index 8a984351e4..07b530bbcc 100644 --- a/internal/db/profiles.sql.go +++ b/internal/db/profiles.sql.go @@ -358,13 +358,29 @@ func (q *Queries) GetProfileForEntity(ctx context.Context, arg GetProfileForEnti } const listProfilesByProjectID = `-- name: ListProfilesByProjectID :many -SELECT profiles.id, profiles.name, profiles.provider, profiles.project_id, profiles.remediate, profiles.alert, profiles.created_at, profiles.updated_at, profiles.provider_id, profiles.subscription_id, profiles.display_name, profiles.labels, profiles_with_entity_profiles.id, profiles_with_entity_profiles.entity, profiles_with_entity_profiles.profile_id, profiles_with_entity_profiles.contextual_rules, profiles_with_entity_profiles.created_at, profiles_with_entity_profiles.updated_at, profiles_with_entity_profiles.profid FROM profiles JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +WITH helper AS( + SELECT pr.id as profid, + ARRAY_AGG(ROW(ps.id, ps.profile_id, ps.entity, ps.selector, ps.comment)::profile_selector) FILTER (WHERE ps.id IS NOT NULL) AS selectors + FROM profiles pr + LEFT JOIN profile_selectors ps + ON pr.id = ps.profile_id + WHERE pr.project_id = $1 + GROUP BY pr.id +) +SELECT + profiles.id, profiles.name, profiles.provider, profiles.project_id, profiles.remediate, profiles.alert, profiles.created_at, profiles.updated_at, profiles.provider_id, profiles.subscription_id, profiles.display_name, profiles.labels, + profiles_with_entity_profiles.id, profiles_with_entity_profiles.entity, profiles_with_entity_profiles.profile_id, profiles_with_entity_profiles.contextual_rules, profiles_with_entity_profiles.created_at, profiles_with_entity_profiles.updated_at, profiles_with_entity_profiles.profid, + helper.selectors::profile_selector[] AS profiles_with_selectors +FROM profiles +JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +JOIN helper ON profiles.id = helper.profid WHERE profiles.project_id = $1 ` type ListProfilesByProjectIDRow struct { Profile Profile `json:"profile"` ProfilesWithEntityProfile ProfilesWithEntityProfile `json:"profiles_with_entity_profile"` + ProfilesWithSelectors []ProfileSelector `json:"profiles_with_selectors"` } func (q *Queries) ListProfilesByProjectID(ctx context.Context, projectID uuid.UUID) ([]ListProfilesByProjectIDRow, error) { @@ -396,6 +412,7 @@ func (q *Queries) ListProfilesByProjectID(ctx context.Context, projectID uuid.UU &i.ProfilesWithEntityProfile.CreatedAt, &i.ProfilesWithEntityProfile.UpdatedAt, &i.ProfilesWithEntityProfile.Profid, + pq.Array(&i.ProfilesWithSelectors), ); err != nil { return nil, err } @@ -411,7 +428,21 @@ func (q *Queries) ListProfilesByProjectID(ctx context.Context, projectID uuid.UU } const listProfilesByProjectIDAndLabel = `-- name: ListProfilesByProjectIDAndLabel :many -SELECT profiles.id, profiles.name, profiles.provider, profiles.project_id, profiles.remediate, profiles.alert, profiles.created_at, profiles.updated_at, profiles.provider_id, profiles.subscription_id, profiles.display_name, profiles.labels, profiles_with_entity_profiles.id, profiles_with_entity_profiles.entity, profiles_with_entity_profiles.profile_id, profiles_with_entity_profiles.contextual_rules, profiles_with_entity_profiles.created_at, profiles_with_entity_profiles.updated_at, profiles_with_entity_profiles.profid FROM profiles JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +WITH helper AS( + SELECT pr.id as profid, + ARRAY_AGG(ROW(ps.id, ps.profile_id, ps.entity, ps.selector, ps.comment)::profile_selector) FILTER (WHERE ps.id IS NOT NULL) AS selectors + FROM profiles pr + LEFT JOIN profile_selectors ps + ON pr.id = ps.profile_id + WHERE pr.project_id = $1 + GROUP BY pr.id +) +SELECT profiles.id, profiles.name, profiles.provider, profiles.project_id, profiles.remediate, profiles.alert, profiles.created_at, profiles.updated_at, profiles.provider_id, profiles.subscription_id, profiles.display_name, profiles.labels, + profiles_with_entity_profiles.id, profiles_with_entity_profiles.entity, profiles_with_entity_profiles.profile_id, profiles_with_entity_profiles.contextual_rules, profiles_with_entity_profiles.created_at, profiles_with_entity_profiles.updated_at, profiles_with_entity_profiles.profid, + helper.selectors::profile_selector[] AS profiles_with_selectors +FROM profiles +JOIN profiles_with_entity_profiles ON profiles.id = profiles_with_entity_profiles.profid +JOIN helper ON profiles.id = helper.profid WHERE profiles.project_id = $1 AND ( -- the most common case first, if the include_labels is empty, we list profiles with no labels @@ -438,6 +469,7 @@ type ListProfilesByProjectIDAndLabelParams struct { type ListProfilesByProjectIDAndLabelRow struct { Profile Profile `json:"profile"` ProfilesWithEntityProfile ProfilesWithEntityProfile `json:"profiles_with_entity_profile"` + ProfilesWithSelectors []ProfileSelector `json:"profiles_with_selectors"` } func (q *Queries) ListProfilesByProjectIDAndLabel(ctx context.Context, arg ListProfilesByProjectIDAndLabelParams) ([]ListProfilesByProjectIDAndLabelRow, error) { @@ -469,6 +501,7 @@ func (q *Queries) ListProfilesByProjectIDAndLabel(ctx context.Context, arg ListP &i.ProfilesWithEntityProfile.CreatedAt, &i.ProfilesWithEntityProfile.UpdatedAt, &i.ProfilesWithEntityProfile.Profid, + pq.Array(&i.ProfilesWithSelectors), ); err != nil { return nil, err } diff --git a/internal/db/profiles_test.go b/internal/db/profiles_test.go index 9fef277aa1..dfdb3b20fe 100644 --- a/internal/db/profiles_test.go +++ b/internal/db/profiles_test.go @@ -50,6 +50,23 @@ func createRandomProfile(t *testing.T, projectID uuid.UUID, labels []string) Pro return prof } +func createRepoSelector(t *testing.T, profileId uuid.UUID, sel string, comment string) ProfileSelector { + return createEntitySelector(t, profileId, NullEntities{Entities: EntitiesRepository, Valid: true}, sel, comment) +} + +func createEntitySelector(t *testing.T, profileId uuid.UUID, ent NullEntities, sel string, comment string) ProfileSelector { + dbSel, err := testQueries.CreateSelector(context.Background(), CreateSelectorParams{ + ProfileID: profileId, + Entity: ent, + Selector: sel, + Comment: comment, + }) + require.NoError(t, err) + require.NotEmpty(t, dbSel) + + return dbSel +} + func createRandomRuleType(t *testing.T, projectID uuid.UUID) RuleType { t.Helper() @@ -192,6 +209,111 @@ func createTestRandomEntities(t *testing.T) *testRandomEntities { } } +func matchIdWithListLabelRow(t *testing.T, id uuid.UUID) func(r ListProfilesByProjectIDAndLabelRow) bool { + t.Helper() + + return func(r ListProfilesByProjectIDAndLabelRow) bool { + return r.Profile.ID == id + } +} + +func matchIdWithListRow(t *testing.T, id uuid.UUID) func(r ListProfilesByProjectIDRow) bool { + t.Helper() + + return func(r ListProfilesByProjectIDRow) bool { + return r.Profile.ID == id + } +} + +func findRowWithLabels(t *testing.T, rows []ListProfilesByProjectIDAndLabelRow, id uuid.UUID) int { + t.Helper() + + return slices.IndexFunc(rows, matchIdWithListLabelRow(t, id)) +} + +func findRow(t *testing.T, rows []ListProfilesByProjectIDRow, id uuid.UUID) int { + t.Helper() + + return slices.IndexFunc(rows, matchIdWithListRow(t, id)) +} + +func TestProfileListWithSelectors(t *testing.T) { + t.Parallel() + + randomEntities := createTestRandomEntities(t) + + noSelectors := createRandomProfile(t, randomEntities.proj.ID, []string{}) + oneSelectorProfile := createRandomProfile(t, randomEntities.proj.ID, []string{}) + oneSel := createRepoSelector(t, oneSelectorProfile.ID, "one_selector1", "one_comment1") + + multiSelectorProfile := createRandomProfile(t, randomEntities.proj.ID, []string{}) + mulitSel1 := createRepoSelector(t, multiSelectorProfile.ID, "multi_selector1", "multi_comment1") + mulitSel2 := createRepoSelector(t, multiSelectorProfile.ID, "multi_selector2", "multi_comment2") + mulitSel3 := createRepoSelector(t, multiSelectorProfile.ID, "multi_selector3", "multi_comment3") + + genericSelectorProfile := createRandomProfile(t, randomEntities.proj.ID, []string{}) + genericSel := createEntitySelector(t, genericSelectorProfile.ID, NullEntities{}, "gen_selector1", "gen_comment1") + + t.Run("list profiles with selectors using the label list", func(t *testing.T) { + t.Parallel() + + rows, err := testQueries.ListProfilesByProjectIDAndLabel( + context.Background(), ListProfilesByProjectIDAndLabelParams{ + ProjectID: randomEntities.proj.ID, + }) + require.NoError(t, err) + + require.Len(t, rows, 4) + + noSelIdx := findRowWithLabels(t, rows, noSelectors.ID) + require.True(t, noSelIdx >= 0, "noSelectors not found in rows") + require.Empty(t, rows[noSelIdx].ProfilesWithSelectors) + + oneSelIdx := findRowWithLabels(t, rows, oneSelectorProfile.ID) + require.True(t, oneSelIdx >= 0, "oneSelector not found in rows") + require.Len(t, rows[oneSelIdx].ProfilesWithSelectors, 1) + require.Contains(t, rows[oneSelIdx].ProfilesWithSelectors, oneSel) + + multiSelIdx := findRowWithLabels(t, rows, multiSelectorProfile.ID) + require.True(t, multiSelIdx >= 0, "multiSelectorProfile not found in rows") + require.Len(t, rows[multiSelIdx].ProfilesWithSelectors, 3) + require.Subset(t, rows[multiSelIdx].ProfilesWithSelectors, []ProfileSelector{mulitSel1, mulitSel2, mulitSel3}) + + genSelIdx := findRowWithLabels(t, rows, genericSelectorProfile.ID) + require.Len(t, rows[genSelIdx].ProfilesWithSelectors, 1) + require.Contains(t, rows[genSelIdx].ProfilesWithSelectors, genericSel) + }) + + t.Run("list profiles with selectors using the non-label list", func(t *testing.T) { + t.Parallel() + + rows, err := testQueries.ListProfilesByProjectID( + context.Background(), randomEntities.proj.ID) + require.NoError(t, err) + + require.Len(t, rows, 4) + + noSelIdx := findRow(t, rows, noSelectors.ID) + require.True(t, noSelIdx >= 0, "noSelectors not found in rows") + require.Empty(t, rows[noSelIdx].ProfilesWithSelectors) + + oneSelIdx := findRow(t, rows, oneSelectorProfile.ID) + require.True(t, oneSelIdx >= 0, "oneSelector not found in rows") + require.Len(t, rows[oneSelIdx].ProfilesWithSelectors, 1) + require.Contains(t, rows[oneSelIdx].ProfilesWithSelectors, oneSel) + + multiSelIdx := findRow(t, rows, multiSelectorProfile.ID) + require.True(t, multiSelIdx >= 0, "multiSelectorProfile not found in rows") + require.Len(t, rows[multiSelIdx].ProfilesWithSelectors, 3) + require.Subset(t, rows[multiSelIdx].ProfilesWithSelectors, []ProfileSelector{mulitSel1, mulitSel2, mulitSel3}) + + genSelIdx := findRow(t, rows, genericSelectorProfile.ID) + require.Len(t, rows[genSelIdx].ProfilesWithSelectors, 1) + require.Contains(t, rows[genSelIdx].ProfilesWithSelectors, genericSel) + }) + +} + func TestProfileLabels(t *testing.T) { t.Parallel() diff --git a/sqlc.yaml b/sqlc.yaml index c2abf32040..f1b3caf5b5 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -24,4 +24,12 @@ packages: emit_prepared_queries: false emit_interface: true emit_exact_table_names: false - emit_empty_slices: true \ No newline at end of file + emit_empty_slices: true + overrides: + - db_type: "uuid" + go_type: + import: "github.com/google/uuid" + type: "UUID" + - db_type: profile_selector + go_type: + type: "ProfileSelector"