Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SCALRCORE-30060 Pygohcl > Add variable validation #24

Merged
merged 5 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 239 additions & 14 deletions pygohcl.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,267 @@ import "C"
import (
"encoding/json"
"fmt"
"strings"

"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/ext/tryfunc"
"github.com/hashicorp/hcl/v2/hclparse"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/convert"
"github.com/zclconf/go-cty/cty/function"
"github.com/zclconf/go-cty/cty/function/stdlib"
"strings"
)

//export Parse
func Parse(a *C.char) (resp C.parseResponse) {
defer func() {
if err := recover(); err != nil {
defer func() {
if err := recover(); err != nil {
retValue := fmt.Sprintf("panic HCL: %v", err)
resp = C.parseResponse{nil, C.CString(retValue)}
}
}()
}
}()

input := C.GoString(a)
hclFile, diags := hclparse.NewParser().ParseHCL([]byte(input), "tmp.hcl")
if diags.HasErrors() {
errors := make([]string, 0, len(diags))
for _, diag := range diags {
errors = append(errors, diag.Error())
}

return C.parseResponse{nil, C.CString(fmt.Sprintf("invalid HCL: %s", strings.Join(errors, ", ")))}
return C.parseResponse{nil, C.CString(diagErrorsToString(diags, "invalid HCL: %s"))}
}
hclMap, err := convertFile(hclFile)
if err != nil {
return C.parseResponse{nil, C.CString(fmt.Sprintf("cannot convert HCL to Go map representation: %s", err))}
}
hclInJson, err := json.Marshal(hclMap)
if err != nil {
return C.parseResponse{nil, C.CString(fmt.Sprintf("cannot Go map representation to JSON: %s", err))}
return C.parseResponse{nil, C.CString(fmt.Sprintf("cannot convert Go map representation to JSON: %s", err))}
}
resp = C.parseResponse{C.CString(string(hclInJson)), nil}

return
}

func main() {
//export ParseAttributes
func ParseAttributes(a *C.char) (resp C.parseResponse) {
defer func() {
if err := recover(); err != nil {
retValue := fmt.Sprintf("panic HCL: %v", err)
resp = C.parseResponse{nil, C.CString(retValue)}
}
}()

input := C.GoString(a)
hclFile, parseDiags := hclsyntax.ParseConfig([]byte(input), "tmp.hcl", hcl.InitialPos)
if parseDiags.HasErrors() {
return C.parseResponse{nil, C.CString(diagErrorsToString(parseDiags, "invalid HCL: %s"))}
}

var diags hcl.Diagnostics
hclMap := make(jsonObj)
c := converter{}

attrs, attrsDiags := hclFile.Body.JustAttributes()
diags = diags.Extend(attrsDiags)

for _, attr := range attrs {
_, valueDiags := attr.Expr.Value(nil)
diags = diags.Extend(valueDiags)
if valueDiags.HasErrors() {
continue
}

value, err := c.convertExpression(attr.Expr.(hclsyntax.Expression))
if err != nil {
diags.Append(&hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Error processing variable value",
Detail: fmt.Sprintf("Cannot convert HCL to Go map representation: %s.", err),
Subject: attr.NameRange.Ptr(),
})
continue
}

hclMap[attr.Name] = value
}

hclInJson, err := json.Marshal(hclMap)
if err != nil {
diags.Append(&hcl.Diagnostic{
Severity: hcl.DiagError,
Summary: "Error preparing JSON result",
Detail: fmt.Sprintf("Cannot convert Go map representation to JSON: %s.", err),
})
return C.parseResponse{nil, C.CString(diagErrorsToString(diags, ""))}
}
if diags.HasErrors() {
resp = C.parseResponse{C.CString(string(hclInJson)), C.CString(diagErrorsToString(diags, ""))}
} else {
resp = C.parseResponse{C.CString(string(hclInJson)), nil}
}

return
}

//export EvalValidationRule
func EvalValidationRule(c *C.char, e *C.char, n *C.char, v *C.char) (resp *C.char) {
defer func() {
if err := recover(); err != nil {
retValue := fmt.Sprintf("panic HCL: %v", err)
resp = C.CString(retValue)
}
}()

condition := C.GoString(c)
errorMsg := C.GoString(e)
varName := C.GoString(n)
varValue := C.GoString(v)

// First evaluate variable value to get its cty representation

varValueCty, diags := expressionValue(varValue, nil)
if diags.HasErrors() {
if containsError(diags, "Variables not allowed") {
// Try again to handle the case when a string value was provided without enclosing quotes
varValueCty, diags = expressionValue(fmt.Sprintf("%q", varValue), nil)
}
}
if diags.HasErrors() {
return C.CString(diagErrorsToString(diags, "cannot process variable value: %s"))
}

// Now evaluate the condition

hclCtx := &hcl.EvalContext{
Variables: map[string]cty.Value{
"var": cty.ObjectVal(map[string]cty.Value{
varName: varValueCty,
}),
},
Functions: knownFunctions,
}
conditionCty, diags := expressionValue(condition, hclCtx)
if diags.HasErrors() {
return C.CString(diagErrorsToString(diags, "cannot process condition expression: %s"))
}

if conditionCty.IsNull() {
return C.CString("condition expression result is null")
}

conditionCty, err := convert.Convert(conditionCty, cty.Bool)
if err != nil {
return C.CString("condition expression result must be bool")
}

if conditionCty.True() {
return nil
}

// Finally evaluate the error message expression

var errorMsgValue = "cannot process error message expression"
errorMsgCty, diags := expressionValue(errorMsg, hclCtx)
if diags.HasErrors() {
errorMsgCty, diags = expressionValue(fmt.Sprintf("%q", errorMsg), hclCtx)
}
if !diags.HasErrors() && !errorMsgCty.IsNull() {
errorMsgCty, err = convert.Convert(errorMsgCty, cty.String)
if err == nil {
errorMsgValue = errorMsgCty.AsString()
}
}
return C.CString(errorMsgValue)
}

func diagErrorsToString(diags hcl.Diagnostics, format string) string {
diagErrs := diags.Errs()
errors := make([]string, 0, len(diagErrs))
for _, err := range diagErrs {
errors = append(errors, err.Error())
}
if format == "" {
return strings.Join(errors, ", ")
}
return fmt.Sprintf(format, strings.Join(errors, ", "))
}

func containsError(diags hcl.Diagnostics, e string) bool {
for _, err := range diags.Errs() {
if strings.Contains(err.Error(), e) {
return true
}
}
return false
}

func expressionValue(in string, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) {
var diags hcl.Diagnostics

expr, diags := hclsyntax.ParseExpression([]byte(in), "tmp.hcl", hcl.InitialPos)
if diags.HasErrors() {
return cty.NilVal, diags
}

val, diags := expr.Value(ctx)
if diags.HasErrors() {
return cty.NilVal, diags
}

return val, diags
}

var knownFunctions = map[string]function.Function{
"abs": stdlib.AbsoluteFunc,
"can": tryfunc.CanFunc,
"ceil": stdlib.CeilFunc,
"chomp": stdlib.ChompFunc,
"coalescelist": stdlib.CoalesceListFunc,
"compact": stdlib.CompactFunc,
"concat": stdlib.ConcatFunc,
"contains": stdlib.ContainsFunc,
"csvdecode": stdlib.CSVDecodeFunc,
"distinct": stdlib.DistinctFunc,
"element": stdlib.ElementFunc,
"chunklist": stdlib.ChunklistFunc,
"flatten": stdlib.FlattenFunc,
"floor": stdlib.FloorFunc,
"format": stdlib.FormatFunc,
"formatdate": stdlib.FormatDateFunc,
"formatlist": stdlib.FormatListFunc,
"indent": stdlib.IndentFunc,
"join": stdlib.JoinFunc,
"jsondecode": stdlib.JSONDecodeFunc,
"jsonencode": stdlib.JSONEncodeFunc,
"keys": stdlib.KeysFunc,
"log": stdlib.LogFunc,
"lower": stdlib.LowerFunc,
"max": stdlib.MaxFunc,
"merge": stdlib.MergeFunc,
"min": stdlib.MinFunc,
"parseint": stdlib.ParseIntFunc,
"pow": stdlib.PowFunc,
"range": stdlib.RangeFunc,
"regex": stdlib.RegexFunc,
"regexall": stdlib.RegexAllFunc,
"reverse": stdlib.ReverseListFunc,
"setintersection": stdlib.SetIntersectionFunc,
"setproduct": stdlib.SetProductFunc,
"setsubtract": stdlib.SetSubtractFunc,
"setunion": stdlib.SetUnionFunc,
"signum": stdlib.SignumFunc,
"slice": stdlib.SliceFunc,
"sort": stdlib.SortFunc,
"split": stdlib.SplitFunc,
"strrev": stdlib.ReverseFunc,
"substr": stdlib.SubstrFunc,
"timeadd": stdlib.TimeAddFunc,
"title": stdlib.TitleFunc,
"trim": stdlib.TrimFunc,
"trimprefix": stdlib.TrimPrefixFunc,
"trimspace": stdlib.TrimSpaceFunc,
"trimsuffix": stdlib.TrimSuffixFunc,
"try": tryfunc.TryFunc,
"upper": stdlib.UpperFunc,
"values": stdlib.ValuesFunc,
"zipmap": stdlib.ZipmapFunc,
}

func main() {}
90 changes: 90 additions & 0 deletions pygohcl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class HCLInternalError(Exception):
pass


class ValidationError(Exception):
pass


class UnknownFunctionError(ValidationError):
pass


def loadb(data: bytes) -> tp.Dict:
s = ffi.new("char[]", data)
ret = lib.Parse(s)
Expand All @@ -46,3 +54,85 @@ def loads(data: str) -> tp.Dict:
def load(stream: tp.IO) -> tp.Dict:
data = stream.read()
return loadb(data)


def attributes_loadb(data: bytes) -> tp.Dict:
"""
Like :func:`pygohcl.loadb`,
but expects from the input to contain only top-level attributes.

Example:
>>> hcl = '''
... key1 = "value"
... key2 = false
... key3 = [1, 2, 3]
... '''
>>> import pygohcl
>>> print(pygohcl.attributes_loads(hcl))
{'key1': 'value', 'key2': False, 'key3': [1, 2, 3]}

:raises HCLParseError: when the provided input cannot be parsed as valid HCL,
or it contains other blocks, not only attributes.
"""
s = ffi.new("char[]", data)
ret = lib.ParseAttributes(s)
if ret.err != ffi.NULL:
err: bytes = ffi.string(ret.err)
ffi.gc(ret.err, lib.free)
err = err.decode("utf8")
raise HCLParseError(err)
ret_json = ffi.string(ret.json)
ffi.gc(ret.json, lib.free)
return json.loads(ret_json)


def attributes_loads(data: str) -> tp.Dict:
return attributes_loadb(data.encode("utf8"))


def attributes_load(stream: tp.IO) -> tp.Dict:
data = stream.read()
return attributes_loadb(data)


def eval_var_condition(
condition: str, error_message: str, variable_name: str, variable_value: str
) -> None:
"""
This is specific to Terraform/OpenTofu configuration language
and is meant to evaluate results of the `validation` block of a variable definition.

This comes with a limited selection of supported functions.
Terraform/OpenTofu expand this list with their own set
of useful functions, which will not pass this validation.
For that reason a separate `UnknownFunctionError` is raised then,
so the consumer can decide how to treat this case.

Example:
>>> import pygohcl
>>> pygohcl.eval_var_condition(
... condition="var.count < 3",
... error_message="count must be less than 3, but ${var.count} was given",
... variable_name="count",
... variable_value="5",
... )
Traceback (most recent call last):
...
pygohcl.ValidationError: count must be less than 3, but 5 was given

:raises ValidationError: when the condition expression has not evaluated to `True`
:raises UnknownFunctionError: when the condition expression refers to a function
that is not known to the library
"""
c = ffi.new("char[]", condition.encode("utf8"))
e = ffi.new("char[]", error_message.encode("utf8"))
n = ffi.new("char[]", variable_name.encode("utf8"))
v = ffi.new("char[]", variable_value.encode("utf8"))
ret = lib.EvalValidationRule(c, e, n, v)
if ret != ffi.NULL:
err: bytes = ffi.string(ret)
ffi.gc(ret, lib.free)
err = err.decode("utf8")
if "Call to unknown function" in err:
raise UnknownFunctionError(err)
raise ValidationError(err)
2 changes: 2 additions & 0 deletions pygohcl/build_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
} parseResponse;

parseResponse Parse(char* a);
parseResponse ParseAttributes(char* a);
char* EvalValidationRule(char* c, char* e, char* n, char* v);
void free(void *ptr);
"""
)
Expand Down
Loading
Loading