diff --git a/ast/check.go b/ast/check.go index 23d1ed8fa1..b0388acb9e 100644 --- a/ast/check.go +++ b/ast/check.go @@ -33,15 +33,18 @@ type typeChecker struct { allowNet []string input types.Type allowUndefinedFuncs bool + schemaTypes map[string]types.Type } // newTypeChecker returns a new typeChecker object that has no errors. func newTypeChecker() *typeChecker { - tc := &typeChecker{} - tc.exprCheckers = map[string]exprChecker{ - "eq": tc.checkExprEq, + return &typeChecker{ + builtins: make(map[string]*Builtin), + schemaTypes: make(map[string]types.Type), + exprCheckers: map[string]exprChecker{ + "eq": checkExprEq, + }, } - return tc } func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv { @@ -196,20 +199,42 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { return result } +func (tc *typeChecker) getSchemaType(schemaAnnot *SchemaAnnotation, rule *Rule) (types.Type, *Error) { + if refType, exists := tc.schemaTypes[schemaAnnot.Schema.String()]; exists { + return refType, nil + } + + refType, err := processAnnotation(tc.ss, schemaAnnot, rule, tc.allowNet) + if err != nil { + return nil, err + } + + if refType == nil { + return nil, nil + } + + tc.schemaTypes[schemaAnnot.Schema.String()] = refType + return refType, nil + +} + func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { env = env.wrap() schemaAnnots := getRuleAnnotation(as, rule) for _, schemaAnnot := range schemaAnnots { - ref, refType, err := processAnnotation(tc.ss, schemaAnnot, rule, tc.allowNet) + refType, err := tc.getSchemaType(schemaAnnot, rule) if err != nil { tc.err([]*Error{err}) continue } + + ref := schemaAnnot.Path if ref == nil && refType == nil { continue } + prefixRef, t := getPrefix(env, ref) if t == nil || len(prefixRef) == len(ref) { env.tree.Put(ref, refType) @@ -404,7 +429,7 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error { return nil } -func (tc *typeChecker) checkExprEq(env *TypeEnv, expr *Expr) *Error { +func checkExprEq(env *TypeEnv, expr *Expr) *Error { pre := getArgTypes(env, expr.Operands()) exp := Equality.Decl.FuncArgs() @@ -1266,17 +1291,17 @@ func getRuleAnnotation(as *AnnotationSet, rule *Rule) (result []*SchemaAnnotatio return result } -func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allowNet []string) (Ref, types.Type, *Error) { +func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allowNet []string) (types.Type, *Error) { var schema interface{} if annot.Schema != nil { if ss == nil { - return nil, nil, nil + return nil, nil } schema = ss.Get(annot.Schema) if schema == nil { - return nil, nil, NewError(TypeErr, rule.Location, "undefined schema: %v", annot.Schema) + return nil, NewError(TypeErr, rule.Location, "undefined schema: %v", annot.Schema) } } else if annot.Definition != nil { schema = *annot.Definition @@ -1284,10 +1309,10 @@ func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allow tpe, err := loadSchema(schema, allowNet) if err != nil { - return nil, nil, NewError(TypeErr, rule.Location, err.Error()) + return nil, NewError(TypeErr, rule.Location, err.Error()) } - return annot.Path, tpe, nil + return tpe, nil } func errAnnotationRedeclared(a *Annotations, other *Location) *Error {