From 1bf12a9ca460dfcdad845c5226cdb995b7368e24 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Thu, 5 Sep 2024 16:08:17 +0200 Subject: [PATCH] planner: fix with replacement with local-var values Every function down the call stacks of a with-replacement that does this will now get extra variables added to its args to carry along the local variables that will eventually be needed to replace the with-mocked func call. TODOs: - [ ] think hard about variable name clashes (do we need a func for this?) - [ ] more test cases Fixes #5311. Signed-off-by: Stephan Renatus --- internal/planner/planner.go | 27 +++++++++----- internal/planner/rules.go | 13 +++++++ .../withkeyword/test-with-builtin-mock.yaml | 35 +++++++++++++++++++ 3 files changed, 67 insertions(+), 8 deletions(-) diff --git a/internal/planner/planner.go b/internal/planner/planner.go index b75d26ddab..6ebc49adc8 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -160,7 +160,7 @@ func (p *Planner) buildFunctrie() error { return nil } -func (p *Planner) planRules(rules []*ast.Rule) (string, error) { +func (p *Planner) planRules(rules []*ast.Rule, extras ...*ast.Term) (string, error) { // We know the rules with closer to the root (shorter static path) are ordered first. pathRef := rules[0].Ref() @@ -223,10 +223,9 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { } // Initialize parameters for functions. - for i := 0; i < len(rules[0].Head.Args); i++ { + for i := 0; i < len(rules[0].Head.Args)+len(extras); i++ { fn.Params = append(fn.Params, p.newLocal()) } - params := fn.Params[2:] // Initialize return value for partial set/object rules. Complete document @@ -322,7 +321,10 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { // Complete and partial rules are treated as special cases of // functions. If there are no args, the first step is a no-op. - err := p.planFuncParams(params, rule.Head.Args, 0, func() error { + + // If there are local variables to carry along for with-replacements, add them here + args := append(rule.Head.Args.Copy(), extras...) + err := p.planFuncParams(params, args, 0, func() error { // Run planner on the rule body. return p.planQuery(rule.Body, 0, func() error { @@ -690,6 +692,11 @@ func (p *Planner) planWith(e *ast.Expr, iter planiter) error { continue // not a mock } + // NOTE(sr): If w.Target is a function (built-in or user-defined) and w.Value is a variable, + // recording it here will make all the functions (rules or user-defined functions) use an + // extra argument to pass along the locally-defined variable. That way, if need in a nested + // call stack, w.Target needs to be replaced, we'll have the variable available in local + // scope, too. mocks[w.Target.String()] = w.Value } @@ -990,6 +997,8 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { operands := e.Operands() op := e.Operator() + extras := p.mocks.ExtraVars() + if replacement := p.mocks.Lookup(operator); replacement != nil { switch r := replacement.Value.(type) { case ast.Ref: @@ -1018,7 +1027,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { } } - // replacement is a value, or ref + // replacement is a value, or var if bi, ok := p.decls[operator]; ok { return p.planExprCallValue(replacement, len(bi.Decl.FuncArgs().Args), operands, iter) } @@ -1028,8 +1037,9 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { return fmt.Errorf("illegal replacement of operator %q by %v", operator, replacement) // should be unreachable } + // plan g(...) with extra arguments (filled in by planExprCallFunc) if node := p.rules.Lookup(op); node != nil { - name, err = p.planRules(node.Rules()) + name, err = p.planRules(node.Rules(), extras...) if err != nil { return err } @@ -1046,14 +1056,15 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { } if len(operands) < arity || len(operands) > arity+1 { - return fmt.Errorf("illegal call: wrong number of operands: got %v, want %v)", len(operands), arity) + return fmt.Errorf("illegal call: wrong number of operands: got %v, want %v", len(operands), arity) } if relation { return p.planExprCallRelation(name, arity, operands, args, iter) } - return p.planExprCallFunc(name, arity, void, operands, args, iter) + // inject extras into call to g(...) which may be calling f(...) indirectly + return p.planExprCallFunc(name, arity+len(extras), void, append(operands, extras...), args, iter) } } diff --git a/internal/planner/rules.go b/internal/planner/rules.go index f5d6f3fc6c..401ef8e2b6 100644 --- a/internal/planner/rules.go +++ b/internal/planner/rules.go @@ -306,3 +306,16 @@ func (s *functionMocksStack) Lookup(f string) *ast.Term { } return nil } + +func (s *functionMocksStack) ExtraVars() []*ast.Term { + exs := make([]*ast.Term, 0) + current := *s.stack[len(s.stack)-1] + for i := len(current) - 1; i >= 0; i-- { + for _, t := range current[i] { + if _, ok := t.Value.(ast.Var); ok { + exs = append(exs, t) + } + } + } + return exs +} diff --git a/test/cases/testdata/v0/withkeyword/test-with-builtin-mock.yaml b/test/cases/testdata/v0/withkeyword/test-with-builtin-mock.yaml index 9e002905ca..b5fea4e1ba 100644 --- a/test/cases/testdata/v0/withkeyword/test-with-builtin-mock.yaml +++ b/test/cases/testdata/v0/withkeyword/test-with-builtin-mock.yaml @@ -450,3 +450,38 @@ cases: query: data.test.p = x want_result: - x: true + - modules: + - | + package test + p { + my_var := 1 + q with count as my_var + } + q { + r + } + r { + count([1,2,3]) == 1 + } + note: "withkeyword/builtin: indirect call, arity 1, replacement is local variable" + query: data.test.p = x + want_result: + - x: true + - modules: + - | + package test + f(1) = 2 + p { + my_var := 1 + q with f as my_var + } + q { + r + } + r { + f(1) == 1 + } + note: "withkeyword/function: indirect call, arity 1, replacement is local variable" + query: data.test.p = x + want_result: + - x: true