diff --git a/internal/history/service.go b/internal/history/service.go index bc763b156e..4801185852 100644 --- a/internal/history/service.go +++ b/internal/history/service.go @@ -279,66 +279,115 @@ func toSQLFilter( return nil } - // filters on entity names - if len(filter.IncludedEntityNames()) != 0 { - params.Entitynames = filter.IncludedEntityNames() + if err := paramsFromEntityTypeFilter(filter, params); err != nil { + return err } - if len(filter.ExcludedEntityNames()) != 0 { - params.Notentitynames = filter.ExcludedEntityNames() + if err := paramsFromEntityNameFilter(filter, params); err != nil { + return err } - - // filters on profile names - if len(filter.IncludedProfileNames()) != 0 { - params.Profilenames = filter.IncludedProfileNames() + if err := paramsFromProfileNameFilter(filter, params); err != nil { + return err } - if len(filter.ExcludedProfileNames()) != 0 { - params.Notprofilenames = filter.ExcludedProfileNames() + if err := paramsFromRemediationFilter(filter, params); err != nil { + return err } + if err := paramsFromAlertFilter(filter, params); err != nil { + return err + } + if err := paramsFromStatusFilter(filter, params); err != nil { + return err + } + return paramsFromTimeRangeFilter(filter, params) +} - // filters on entity types +func paramsFromEntityTypeFilter( + filter EntityTypeFilter, + params *db.ListEvaluationHistoryParams, +) error { if len(filter.IncludedEntityTypes()) != 0 { - entityTypes, err := convertEntities( + entityTypes, err := convert( filter.IncludedEntityTypes(), + mapEntities, ) if err != nil { - return errors.New("internal error") + return err } params.Entitytypes = entityTypes } if len(filter.ExcludedEntityTypes()) != 0 { - entityTypes, err := convertEntities( + entityTypes, err := convert( filter.ExcludedEntityTypes(), + mapEntities, ) if err != nil { return errors.New("internal error") } params.Notentitytypes = entityTypes } + return nil +} - // filters on remediation status +func paramsFromEntityNameFilter( + filter EntityNameFilter, + params *db.ListEvaluationHistoryParams, +) error { + if len(filter.IncludedEntityNames()) != 0 { + params.Entitynames = filter.IncludedEntityNames() + } + if len(filter.ExcludedEntityNames()) != 0 { + params.Notentitynames = filter.ExcludedEntityNames() + } + return nil +} + +func paramsFromProfileNameFilter( + filter ProfileNameFilter, + params *db.ListEvaluationHistoryParams, +) error { + if len(filter.IncludedProfileNames()) != 0 { + params.Profilenames = filter.IncludedProfileNames() + } + if len(filter.ExcludedProfileNames()) != 0 { + params.Notprofilenames = filter.ExcludedProfileNames() + } + return nil +} + +func paramsFromRemediationFilter( + filter RemediationFilter, + params *db.ListEvaluationHistoryParams, +) error { if len(filter.IncludedRemediations()) != 0 { - remediations, err := convertRemediationStatusTypes( + remediations, err := convert( filter.IncludedRemediations(), + mapRemediationStatusTypes, ) if err != nil { - return errors.New("internal error") + return err } params.Remediations = remediations } if len(filter.ExcludedRemediations()) != 0 { - remediations, err := convertRemediationStatusTypes( + remediations, err := convert( filter.ExcludedRemediations(), + mapRemediationStatusTypes, ) if err != nil { - return errors.New("internal error") + return err } params.Notremediations = remediations } + return nil +} - // filters on alert status +func paramsFromAlertFilter( + filter AlertFilter, + params *db.ListEvaluationHistoryParams, +) error { if len(filter.IncludedAlerts()) != 0 { - alerts, err := convertAlertStatusTypes( + alerts, err := convert( filter.IncludedAlerts(), + mapAlertStatusTypes, ) if err != nil { return errors.New("internal error") @@ -346,36 +395,49 @@ func toSQLFilter( params.Alerts = alerts } if len(filter.ExcludedAlerts()) != 0 { - alerts, err := convertAlertStatusTypes( + alerts, err := convert( filter.ExcludedAlerts(), + mapAlertStatusTypes, ) if err != nil { - return errors.New("internal error") + return err } params.Notalerts = alerts } + return nil +} - // filters on evaluation status +func paramsFromStatusFilter( + filter StatusFilter, + params *db.ListEvaluationHistoryParams, +) error { if len(filter.IncludedStatuses()) != 0 { - statuses, err := convertEvalStatusTypes( + statuses, err := convert( filter.IncludedStatuses(), + mapEvalStatusTypes, ) if err != nil { - return errors.New("internal error") + return err } params.Statuses = statuses } if len(filter.ExcludedStatuses()) != 0 { - statuses, err := convertEvalStatusTypes( + statuses, err := convert( filter.ExcludedStatuses(), + mapEvalStatusTypes, ) if err != nil { - return errors.New("internal error") + return err } params.Notstatuses = statuses } + return nil +} - // filters on time range +func paramsFromTimeRangeFilter( + filter TimeRangeFilter, + params *db.ListEvaluationHistoryParams, +) error { if filter.GetFrom() != nil { params.Fromts = sql.NullTime{ Time: *filter.GetFrom(), @@ -388,14 +450,21 @@ func toSQLFilter( Valid: true, } } - return nil } -func convertEntities(values []string) ([]db.Entities, error) { - converted := []db.Entities{} +func convert[ + T db.Entities | + db.RemediationStatusTypes | + db.AlertStatusTypes | + db.EvalStatusTypes, +]( + values []string, + mapf func(string) (T, error), +) ([]T, error) { + converted := []T{} for _, v := range values { - dbObj, err := mapEntities(v) + dbObj, err := mapf(v) if err != nil { return nil, err } @@ -420,20 +489,6 @@ func mapEntities(value string) (db.Entities, error) { } } -func convertRemediationStatusTypes( - values []string, -) ([]db.RemediationStatusTypes, error) { - converted := []db.RemediationStatusTypes{} - for _, v := range values { - dbObj, err := mapRemediationStatusTypes(v) - if err != nil { - return nil, err - } - converted = append(converted, dbObj) - } - return converted, nil -} - //nolint:goconst func mapRemediationStatusTypes( value string, @@ -457,20 +512,6 @@ func mapRemediationStatusTypes( } } -func convertAlertStatusTypes( - values []string, -) ([]db.AlertStatusTypes, error) { - converted := []db.AlertStatusTypes{} - for _, v := range values { - dbObj, err := mapAlertStatusTypes(v) - if err != nil { - return nil, err - } - converted = append(converted, dbObj) - } - return converted, nil -} - //nolint:goconst func mapAlertStatusTypes( value string, @@ -492,20 +533,6 @@ func mapAlertStatusTypes( } } -func convertEvalStatusTypes( - values []string, -) ([]db.EvalStatusTypes, error) { - converted := []db.EvalStatusTypes{} - for _, v := range values { - dbObj, err := mapEvalStatusTypes(v) - if err != nil { - return nil, err - } - converted = append(converted, dbObj) - } - return converted, nil -} - //nolint:goconst func mapEvalStatusTypes( value string,