Skip to content

Commit

Permalink
fix: add error based on expected regal dir (#1343)
Browse files Browse the repository at this point in the history
Fixes #1341

Signed-off-by: Charlie Egan <[email protected]>
  • Loading branch information
charlieegan3 authored Jan 16, 2025
1 parent df90dec commit 2c6ee8e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 32 deletions.
15 changes: 9 additions & 6 deletions cmd/fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,18 @@ func fix(args []string, params *fixCommandParams) error {
return fmt.Errorf("could not find potential roots: %w", err)
}

versionsMap, err := config.AllRegoVersions(regalDir.Name(), &userConfig)
if err != nil {
return fmt.Errorf("failed to get all Rego versions: %w", err)
}

f := fixer.NewFixer()
f.RegisterRoots(roots...)
f.RegisterFixes(fixes.NewDefaultFixes()...)
f.SetRegoVersionsMap(versionsMap)

if userConfigFile != nil {
versionsMap, err := config.AllRegoVersions(filepath.Dir(userConfigFile.Name()), &userConfig)
if err != nil {
return fmt.Errorf("failed to get all Rego versions: %w", err)
}

f.SetRegoVersionsMap(versionsMap)
}

if !slices.Contains([]string{"error", "rename"}, params.conflictMode) {
return fmt.Errorf("invalid conflict mode: %s, expected 'error' or 'rename'", params.conflictMode)
Expand Down
4 changes: 4 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ func FromMap(confMap map[string]any) (Config, error) {
func AllRegoVersions(root string, conf *Config) (map[string]ast.RegoVersion, error) {
versionsMap := make(map[string]ast.RegoVersion)

if conf == nil {
return versionsMap, nil
}

manifestLocations, err := rio.FindManifestLocations(root)
if err != nil {
return nil, fmt.Errorf("failed to find manifest locations: %w", err)
Expand Down
70 changes: 47 additions & 23 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,37 +542,61 @@ func TestUnmarshalProjectRootsAsStringOrObject(t *testing.T) {
func TestAllRegoVersions(t *testing.T) {
t.Parallel()

bs := []byte(`project:
testCases := map[string]struct {
Config string
FS map[string]string
Expected map[string]ast.RegoVersion
}{
"values from config": {
Config: `project:
rego-version: 0
roots:
- path: foo
rego-version: 1
`)
`,
FS: map[string]string{
"bar/baz/.manifest": `{"rego_version": 1}`,
},
Expected: map[string]ast.RegoVersion{
"": ast.RegoV0,
"bar/baz": ast.RegoV1,
"foo": ast.RegoV1,
},
},
"no config": {
Config: "",
FS: map[string]string{
"bar/baz/.manifest": `{"rego_version": 1}`,
},
Expected: map[string]ast.RegoVersion{},
},
}

var conf Config
for testName, testData := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()

if err := yaml.Unmarshal(bs, &conf); err != nil {
t.Fatal(err)
}
var conf *Config

fs := map[string]string{
"bar/baz/.manifest": `{"rego_version": 1}`,
}
if testData.Config != "" {
var loadedConf Config
if err := yaml.Unmarshal([]byte(testData.Config), &loadedConf); err != nil {
t.Fatal(err)
}

test.WithTempFS(fs, func(root string) {
versions, err := AllRegoVersions(root, &conf)
if err != nil {
t.Fatal(err)
}
conf = &loadedConf
}

expected := map[string]ast.RegoVersion{
"": ast.RegoV0,
"foo": ast.RegoV1,
"bar/baz": ast.RegoV1,
}
test.WithTempFS(testData.FS, func(root string) {
versions, err := AllRegoVersions(root, conf)
if err != nil {
t.Fatal(err)
}

if !maps.Equal(versions, expected) {
t.Errorf("expected %v, got %v", expected, versions)
}
})
if !maps.Equal(versions, testData.Expected) {
t.Errorf("expected %v, got %v", testData.Expected, versions)
}
})
})
}
}
7 changes: 4 additions & 3 deletions pkg/fixer/fixer.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,15 @@ func (f *Fixer) applyLinterFixes(
return fmt.Errorf("failed to list files: %w", err)
}

if f.versionsMap == nil {
return errors.New("rego versions map not set")
var versionsMap map[string]ast.RegoVersion
if f.versionsMap != nil {
versionsMap = f.versionsMap
}

for {
fixMadeInIteration := false

in, err := fp.ToInput(f.versionsMap)
in, err := fp.ToInput(versionsMap)
if err != nil {
return fmt.Errorf("failed to generate linter input: %w", err)
}
Expand Down

0 comments on commit 2c6ee8e

Please sign in to comment.