diff --git a/pygohcl.go b/pygohcl.go index d2ddd95..fde9d2b 100644 --- a/pygohcl.go +++ b/pygohcl.go @@ -8,29 +8,30 @@ import "C" import ( "encoding/json" "fmt" - "strings" - + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/ext/tryfunc" "github.com/hashicorp/hcl/v2/hclparse" + "github.com/hashicorp/hcl/v2/hclsyntax" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/convert" + "github.com/zclconf/go-cty/cty/function" + "github.com/zclconf/go-cty/cty/function/stdlib" + "strings" ) //export Parse func Parse(a *C.char) (resp C.parseResponse) { - defer func() { - if err := recover(); err != nil { + defer func() { + if err := recover(); err != nil { retValue := fmt.Sprintf("panic HCL: %v", err) resp = C.parseResponse{nil, C.CString(retValue)} - } - }() + } + }() input := C.GoString(a) hclFile, diags := hclparse.NewParser().ParseHCL([]byte(input), "tmp.hcl") if diags.HasErrors() { - errors := make([]string, 0, len(diags)) - for _, diag := range diags { - errors = append(errors, diag.Error()) - } - - return C.parseResponse{nil, C.CString(fmt.Sprintf("invalid HCL: %s", strings.Join(errors, ", ")))} + return C.parseResponse{nil, C.CString(diagErrorsToString(diags, "invalid HCL: %s"))} } hclMap, err := convertFile(hclFile) if err != nil { @@ -38,12 +39,236 @@ func Parse(a *C.char) (resp C.parseResponse) { } hclInJson, err := json.Marshal(hclMap) if err != nil { - return C.parseResponse{nil, C.CString(fmt.Sprintf("cannot Go map representation to JSON: %s", err))} + return C.parseResponse{nil, C.CString(fmt.Sprintf("cannot convert Go map representation to JSON: %s", err))} } resp = C.parseResponse{C.CString(string(hclInJson)), nil} return } -func main() { +//export ParseAttributes +func ParseAttributes(a *C.char) (resp C.parseResponse) { + defer func() { + if err := recover(); err != nil { + retValue := fmt.Sprintf("panic HCL: %v", err) + resp = C.parseResponse{nil, C.CString(retValue)} + } + }() + + input := C.GoString(a) + hclFile, parseDiags := hclsyntax.ParseConfig([]byte(input), "tmp.hcl", hcl.InitialPos) + if parseDiags.HasErrors() { + return C.parseResponse{nil, C.CString(diagErrorsToString(parseDiags, "invalid HCL: %s"))} + } + + var diags hcl.Diagnostics + hclMap := make(jsonObj) + c := converter{} + + attrs, attrsDiags := hclFile.Body.JustAttributes() + diags = diags.Extend(attrsDiags) + + for _, attr := range attrs { + _, valueDiags := attr.Expr.Value(nil) + diags = diags.Extend(valueDiags) + if valueDiags.HasErrors() { + continue + } + + value, err := c.convertExpression(attr.Expr.(hclsyntax.Expression)) + if err != nil { + diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Error processing variable value", + Detail: fmt.Sprintf("Cannot convert HCL to Go map representation: %s.", err), + Subject: attr.NameRange.Ptr(), + }) + continue + } + + hclMap[attr.Name] = value + } + + hclInJson, err := json.Marshal(hclMap) + if err != nil { + diags.Append(&hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Error preparing JSON result", + Detail: fmt.Sprintf("Cannot convert Go map representation to JSON: %s.", err), + }) + return C.parseResponse{nil, C.CString(diagErrorsToString(diags, ""))} + } + if diags.HasErrors() { + resp = C.parseResponse{C.CString(string(hclInJson)), C.CString(diagErrorsToString(diags, ""))} + } else { + resp = C.parseResponse{C.CString(string(hclInJson)), nil} + } + + return +} + +//export EvalValidationRule +func EvalValidationRule(c *C.char, e *C.char, n *C.char, v *C.char) (resp *C.char) { + defer func() { + if err := recover(); err != nil { + retValue := fmt.Sprintf("panic HCL: %v", err) + resp = C.CString(retValue) + } + }() + + condition := C.GoString(c) + errorMsg := C.GoString(e) + varName := C.GoString(n) + varValue := C.GoString(v) + + // First evaluate variable value to get its cty representation + + varValueCty, diags := expressionValue(varValue, nil) + if diags.HasErrors() { + if containsError(diags, "Variables not allowed") { + // Try again to handle the case when a string value was provided without enclosing quotes + varValueCty, diags = expressionValue(fmt.Sprintf("%q", varValue), nil) + } + } + if diags.HasErrors() { + return C.CString(diagErrorsToString(diags, "cannot process variable value: %s")) + } + + // Now evaluate the condition + + hclCtx := &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "var": cty.ObjectVal(map[string]cty.Value{ + varName: varValueCty, + }), + }, + Functions: knownFunctions, + } + conditionCty, diags := expressionValue(condition, hclCtx) + if diags.HasErrors() { + return C.CString(diagErrorsToString(diags, "cannot process condition expression: %s")) + } + + if conditionCty.IsNull() { + return C.CString("condition expression result is null") + } + + conditionCty, err := convert.Convert(conditionCty, cty.Bool) + if err != nil { + return C.CString("condition expression result must be bool") + } + + if conditionCty.True() { + return nil + } + + // Finally evaluate the error message expression + + var errorMsgValue = "cannot process error message expression" + errorMsgCty, diags := expressionValue(errorMsg, hclCtx) + if diags.HasErrors() { + errorMsgCty, diags = expressionValue(fmt.Sprintf("%q", errorMsg), hclCtx) + } + if !diags.HasErrors() && !errorMsgCty.IsNull() { + errorMsgCty, err = convert.Convert(errorMsgCty, cty.String) + if err == nil { + errorMsgValue = errorMsgCty.AsString() + } + } + return C.CString(errorMsgValue) +} + +func diagErrorsToString(diags hcl.Diagnostics, format string) string { + diagErrs := diags.Errs() + errors := make([]string, 0, len(diagErrs)) + for _, err := range diagErrs { + errors = append(errors, err.Error()) + } + if format == "" { + return strings.Join(errors, ", ") + } + return fmt.Sprintf(format, strings.Join(errors, ", ")) +} + +func containsError(diags hcl.Diagnostics, e string) bool { + for _, err := range diags.Errs() { + if strings.Contains(err.Error(), e) { + return true + } + } + return false } + +func expressionValue(in string, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { + var diags hcl.Diagnostics + + expr, diags := hclsyntax.ParseExpression([]byte(in), "tmp.hcl", hcl.InitialPos) + if diags.HasErrors() { + return cty.NilVal, diags + } + + val, diags := expr.Value(ctx) + if diags.HasErrors() { + return cty.NilVal, diags + } + + return val, diags +} + +var knownFunctions = map[string]function.Function{ + "abs": stdlib.AbsoluteFunc, + "can": tryfunc.CanFunc, + "ceil": stdlib.CeilFunc, + "chomp": stdlib.ChompFunc, + "coalescelist": stdlib.CoalesceListFunc, + "compact": stdlib.CompactFunc, + "concat": stdlib.ConcatFunc, + "contains": stdlib.ContainsFunc, + "csvdecode": stdlib.CSVDecodeFunc, + "distinct": stdlib.DistinctFunc, + "element": stdlib.ElementFunc, + "chunklist": stdlib.ChunklistFunc, + "flatten": stdlib.FlattenFunc, + "floor": stdlib.FloorFunc, + "format": stdlib.FormatFunc, + "formatdate": stdlib.FormatDateFunc, + "formatlist": stdlib.FormatListFunc, + "indent": stdlib.IndentFunc, + "join": stdlib.JoinFunc, + "jsondecode": stdlib.JSONDecodeFunc, + "jsonencode": stdlib.JSONEncodeFunc, + "keys": stdlib.KeysFunc, + "log": stdlib.LogFunc, + "lower": stdlib.LowerFunc, + "max": stdlib.MaxFunc, + "merge": stdlib.MergeFunc, + "min": stdlib.MinFunc, + "parseint": stdlib.ParseIntFunc, + "pow": stdlib.PowFunc, + "range": stdlib.RangeFunc, + "regex": stdlib.RegexFunc, + "regexall": stdlib.RegexAllFunc, + "reverse": stdlib.ReverseListFunc, + "setintersection": stdlib.SetIntersectionFunc, + "setproduct": stdlib.SetProductFunc, + "setsubtract": stdlib.SetSubtractFunc, + "setunion": stdlib.SetUnionFunc, + "signum": stdlib.SignumFunc, + "slice": stdlib.SliceFunc, + "sort": stdlib.SortFunc, + "split": stdlib.SplitFunc, + "strrev": stdlib.ReverseFunc, + "substr": stdlib.SubstrFunc, + "timeadd": stdlib.TimeAddFunc, + "title": stdlib.TitleFunc, + "trim": stdlib.TrimFunc, + "trimprefix": stdlib.TrimPrefixFunc, + "trimspace": stdlib.TrimSpaceFunc, + "trimsuffix": stdlib.TrimSuffixFunc, + "try": tryfunc.TryFunc, + "upper": stdlib.UpperFunc, + "values": stdlib.ValuesFunc, + "zipmap": stdlib.ZipmapFunc, +} + +func main() {} diff --git a/pygohcl/__init__.py b/pygohcl/__init__.py index d143c1e..cd4c59e 100644 --- a/pygohcl/__init__.py +++ b/pygohcl/__init__.py @@ -24,6 +24,14 @@ class HCLInternalError(Exception): pass +class ValidationError(Exception): + pass + + +class UnknownFunctionError(ValidationError): + pass + + def loadb(data: bytes) -> tp.Dict: s = ffi.new("char[]", data) ret = lib.Parse(s) @@ -46,3 +54,85 @@ def loads(data: str) -> tp.Dict: def load(stream: tp.IO) -> tp.Dict: data = stream.read() return loadb(data) + + +def attributes_loadb(data: bytes) -> tp.Dict: + """ + Like :func:`pygohcl.loadb`, + but expects from the input to contain only top-level attributes. + + Example: + >>> hcl = ''' + ... key1 = "value" + ... key2 = false + ... key3 = [1, 2, 3] + ... ''' + >>> import pygohcl + >>> print(pygohcl.attributes_loads(hcl)) + {'key1': 'value', 'key2': False, 'key3': [1, 2, 3]} + + :raises HCLParseError: when the provided input cannot be parsed as valid HCL, + or it contains other blocks, not only attributes. + """ + s = ffi.new("char[]", data) + ret = lib.ParseAttributes(s) + if ret.err != ffi.NULL: + err: bytes = ffi.string(ret.err) + ffi.gc(ret.err, lib.free) + err = err.decode("utf8") + raise HCLParseError(err) + ret_json = ffi.string(ret.json) + ffi.gc(ret.json, lib.free) + return json.loads(ret_json) + + +def attributes_loads(data: str) -> tp.Dict: + return attributes_loadb(data.encode("utf8")) + + +def attributes_load(stream: tp.IO) -> tp.Dict: + data = stream.read() + return attributes_loadb(data) + + +def eval_var_condition( + condition: str, error_message: str, variable_name: str, variable_value: str +) -> None: + """ + This is specific to Terraform/OpenTofu configuration language + and is meant to evaluate results of the `validation` block of a variable definition. + + This comes with a limited selection of supported functions. + Terraform/OpenTofu expand this list with their own set + of useful functions, which will not pass this validation. + For that reason a separate `UnknownFunctionError` is raised then, + so the consumer can decide how to treat this case. + + Example: + >>> import pygohcl + >>> pygohcl.eval_var_condition( + ... condition="var.count < 3", + ... error_message="count must be less than 3, but ${var.count} was given", + ... variable_name="count", + ... variable_value="5", + ... ) + Traceback (most recent call last): + ... + pygohcl.ValidationError: count must be less than 3, but 5 was given + + :raises ValidationError: when the condition expression has not evaluated to `True` + :raises UnknownFunctionError: when the condition expression refers to a function + that is not known to the library + """ + c = ffi.new("char[]", condition.encode("utf8")) + e = ffi.new("char[]", error_message.encode("utf8")) + n = ffi.new("char[]", variable_name.encode("utf8")) + v = ffi.new("char[]", variable_value.encode("utf8")) + ret = lib.EvalValidationRule(c, e, n, v) + if ret != ffi.NULL: + err: bytes = ffi.string(ret) + ffi.gc(ret, lib.free) + err = err.decode("utf8") + if "Call to unknown function" in err: + raise UnknownFunctionError(err) + raise ValidationError(err) diff --git a/pygohcl/build_cffi.py b/pygohcl/build_cffi.py index 9a44c37..de5098f 100644 --- a/pygohcl/build_cffi.py +++ b/pygohcl/build_cffi.py @@ -18,6 +18,8 @@ } parseResponse; parseResponse Parse(char* a); + parseResponse ParseAttributes(char* a); + char* EvalValidationRule(char* c, char* e, char* n, char* v); void free(void *ptr); """ ) diff --git a/tests/test_attributes.py b/tests/test_attributes.py new file mode 100644 index 0000000..69eef89 --- /dev/null +++ b/tests/test_attributes.py @@ -0,0 +1,87 @@ +import pytest +import pygohcl + + +def test_basic(): + s = """ + var1 = "value" + var2 = 2 + var3 = true + """ + assert pygohcl.attributes_loads(s) == {"var1": "value", "var2": 2, "var3": True} + + +def test_list(): + s = """ + var1 = ["value1", "value2", "value3"] + var2 = [1, 2, 3] + var3 = [true, false] + """ + assert pygohcl.attributes_loads(s) == { + "var1": ["value1", "value2", "value3"], + "var2": [1, 2, 3], + "var3": [True, False], + } + + +def test_empty_list(): + s = """ + var = [] + """ + assert pygohcl.attributes_loads(s) == {"var": []} + + +def test_non_hcl(): + s = """ + + """ + with pytest.raises(pygohcl.HCLParseError) as err: + pygohcl.attributes_loads(s) + assert "invalid HCL" in str(err.value) + + +def test_non_attributes(): + """ + When the content is mixed with not expected but valid HCL. + """ + s = """ + var = "value" + variable "test" {} + """ + with pytest.raises(pygohcl.HCLParseError) as err: + pygohcl.attributes_loads(s) + assert "Blocks are not allowed" in str(err.value) + + +def test_variable_in_value(): + s = """ + var1 = "value" + var2 = value + """ + with pytest.raises(pygohcl.HCLParseError) as err: + pygohcl.attributes_loads(s) + assert "Variables not allowed" in str(err.value) + + +def test_multiple_errors(): + """ + Make sure the processing doesn't stop at first error and all found issues are reported. + """ + s = """ + var = value + variable "test" {} + """ + with pytest.raises(pygohcl.HCLParseError) as err: + pygohcl.attributes_loads(s) + assert "Variables not allowed" in str(err.value) + assert "Blocks are not allowed" in str(err.value) + + +def test_heredoc(): + s = """ + var = <