Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve handling for YAML version directives #1038

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions parser/yaml/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,35 @@ func (yp *Parser) Unmarshal(p []byte, v interface{}) error {
}

func separateSubDocuments(data []byte) [][]byte {
// Determine line ending style
linebreak := "\n"
if bytes.Contains(data, []byte("\r\n---\r\n")) {
if bytes.Contains(data, []byte("\r\n")) {
linebreak = "\r\n"
}
separator := fmt.Sprintf("%s---%s", linebreak, linebreak)

return bytes.Split(data, []byte(linebreak+"---"+linebreak))
// Count actual document separators
parts := bytes.Split(data, []byte(separator))

// If we have a directive, first part is not a separate document
if bytes.HasPrefix(data, []byte("%")) {
if len(parts) <= 2 {
// Single document with directive
return [][]byte{data}
}
// Multiple documents - combine directive with first real document
firstDoc := append(parts[0], append([]byte(separator), parts[1]...)...)
result := [][]byte{firstDoc}
result = append(result, parts[2:]...)
return result
}

// No directive case
if len(parts) <= 1 {
// Single document
return [][]byte{data}
}
return parts
}

func unmarshalMultipleDocuments(subDocuments [][]byte, v interface{}) error {
Expand Down
82 changes: 77 additions & 5 deletions parser/yaml/yaml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaml_test

import (
"reflect"
"strings"
"testing"

"github.com/open-policy-agent/conftest/parser/yaml"
Expand All @@ -15,6 +16,12 @@ func TestYAMLParser(t *testing.T) {
expectedResult interface{}
shouldError bool
}{
{
name: "empty config",
controlConfigs: []byte(``),
expectedResult: nil,
shouldError: false,
},
{
name: "a single config",
controlConfigs: []byte(`sample: true`),
Expand Down Expand Up @@ -44,21 +51,86 @@ nice: true`),
},
shouldError: false,
},
{
name: "a single config with multiple yaml subdocs with crlf line endings",
controlConfigs: []byte(strings.ReplaceAll(`---
sample: true
---
hello: true
---
nice: true`, "\n", "\r\n")),
expectedResult: []interface{}{
map[string]interface{}{
"sample": true,
},
map[string]interface{}{
"hello": true,
},
map[string]interface{}{
"nice": true,
},
},
shouldError: false,
},
{
name: "multiple documents with one invalid yaml",
controlConfigs: []byte(`---
valid: true
---
invalid:
- not closed
[
---
also_valid: true`),
expectedResult: nil,
shouldError: true,
},
{
name: "yaml with version directive",
controlConfigs: []byte(`%YAML 1.1
---
group_id: 1234`),
expectedResult: map[string]interface{}{
"group_id": float64(1234),
},
shouldError: false,
},
{
name: "yaml with version directive and multiple documents",
controlConfigs: []byte(`%YAML 1.1
---
group_id: 1234
---
other_id: 5678
---
third_id: 9012`),
expectedResult: []interface{}{
map[string]interface{}{
"group_id": float64(1234),
},
map[string]interface{}{
"other_id": float64(5678),
},
map[string]interface{}{
"third_id": float64(9012),
},
},
shouldError: false,
},
}

for _, test := range testTable {
t.Run(test.name, func(t *testing.T) {
var unmarshalledConfigs interface{}
yamlParser := new(yaml.Parser)

if err := yamlParser.Unmarshal(test.controlConfigs, &unmarshalledConfigs); err != nil {
err := yamlParser.Unmarshal(test.controlConfigs, &unmarshalledConfigs)
if test.shouldError && err == nil {
t.Error("expected error but got none")
} else if !test.shouldError && err != nil {
t.Errorf("errors unmarshalling: %v", err)
}

if unmarshalledConfigs == nil {
t.Error("error seeing actual value in object, received nil")
}

if !reflect.DeepEqual(test.expectedResult, unmarshalledConfigs) {
t.Errorf("Expected\n%T : %v\n to equal\n%T : %v\n", unmarshalledConfigs, unmarshalledConfigs, test.expectedResult, test.expectedResult)
}
Expand Down
Loading