Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix with replacement with local-var values #6996

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions internal/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (p *Planner) buildFunctrie() error {
return nil
}

func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
func (p *Planner) planRules(rules []*ast.Rule, extras ...*ast.Term) (string, error) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a shortcut -- I didn't want to adjust all call sites, but only the one where I'd like to have extra args injected...

// We know the rules with closer to the root (shorter static path) are ordered first.
pathRef := rules[0].Ref()

Expand Down Expand Up @@ -223,10 +223,9 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
}

// Initialize parameters for functions.
for i := 0; i < len(rules[0].Head.Args); i++ {
for i := 0; i < len(rules[0].Head.Args)+len(extras); i++ {
fn.Params = append(fn.Params, p.newLocal())
}

params := fn.Params[2:]

// Initialize return value for partial set/object rules. Complete document
Expand Down Expand Up @@ -322,7 +321,10 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {

// Complete and partial rules are treated as special cases of
// functions. If there are no args, the first step is a no-op.
err := p.planFuncParams(params, rule.Head.Args, 0, func() error {

// If there are local variables to carry along for with-replacements, add them here
args := append(rule.Head.Args.Copy(), extras...)
err := p.planFuncParams(params, args, 0, func() error {

// Run planner on the rule body.
return p.planQuery(rule.Body, 0, func() error {
Expand Down Expand Up @@ -690,6 +692,11 @@ func (p *Planner) planWith(e *ast.Expr, iter planiter) error {
continue // not a mock
}

// NOTE(sr): If w.Target is a function (built-in or user-defined) and w.Value is a variable,
// recording it here will make all the functions (rules or user-defined functions) use an
// extra argument to pass along the locally-defined variable. That way, if need in a nested
// call stack, w.Target needs to be replaced, we'll have the variable available in local
// scope, too.
mocks[w.Target.String()] = w.Value
}

Expand Down Expand Up @@ -990,6 +997,8 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
operands := e.Operands()
op := e.Operator()

extras := p.mocks.ExtraVars()

if replacement := p.mocks.Lookup(operator); replacement != nil {
switch r := replacement.Value.(type) {
case ast.Ref:
Expand Down Expand Up @@ -1018,7 +1027,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
}
}

// replacement is a value, or ref
// replacement is a value, or var
if bi, ok := p.decls[operator]; ok {
return p.planExprCallValue(replacement, len(bi.Decl.FuncArgs().Args), operands, iter)
}
Expand All @@ -1028,8 +1037,9 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
return fmt.Errorf("illegal replacement of operator %q by %v", operator, replacement) // should be unreachable
}

// plan g(...) with extra arguments (filled in by planExprCallFunc)
if node := p.rules.Lookup(op); node != nil {
name, err = p.planRules(node.Rules())
name, err = p.planRules(node.Rules(), extras...)
if err != nil {
return err
}
Expand All @@ -1046,14 +1056,15 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
}

if len(operands) < arity || len(operands) > arity+1 {
return fmt.Errorf("illegal call: wrong number of operands: got %v, want %v)", len(operands), arity)
return fmt.Errorf("illegal call: wrong number of operands: got %v, want %v", len(operands), arity)
}

if relation {
return p.planExprCallRelation(name, arity, operands, args, iter)
}

return p.planExprCallFunc(name, arity, void, operands, args, iter)
// inject extras into call to g(...) which may be calling f(...) indirectly
return p.planExprCallFunc(name, arity+len(extras), void, append(operands, extras...), args, iter)
}
}

Expand Down
13 changes: 13 additions & 0 deletions internal/planner/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,16 @@ func (s *functionMocksStack) Lookup(f string) *ast.Term {
}
return nil
}

func (s *functionMocksStack) ExtraVars() []*ast.Term {
exs := make([]*ast.Term, 0)
current := *s.stack[len(s.stack)-1]
for i := len(current) - 1; i >= 0; i-- {
for _, t := range current[i] {
if _, ok := t.Value.(ast.Var); ok {
exs = append(exs, t)
}
}
}
return exs
}
35 changes: 35 additions & 0 deletions test/cases/testdata/v0/withkeyword/test-with-builtin-mock.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,38 @@ cases:
query: data.test.p = x
want_result:
- x: true
- modules:
- |
package test
p {
my_var := 1
q with count as my_var
}
q {
r
}
r {
count([1,2,3]) == 1
}
note: "withkeyword/builtin: indirect call, arity 1, replacement is local variable"
query: data.test.p = x
want_result:
- x: true
- modules:
- |
package test
f(1) = 2
p {
my_var := 1
q with f as my_var
}
q {
r
}
r {
f(1) == 1
}
note: "withkeyword/function: indirect call, arity 1, replacement is local variable"
query: data.test.p = x
want_result:
- x: true
Loading