diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 40e9a953f57..9dcd239f9eb 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -20,11 +20,26 @@ import ( "vitess.io/vitess/go/vt/log" ) +// CNFOrLimit is the number of OR expressions in a predicate at which the CNF +// rewrite is disabled, because we don't want to send large queries to MySQL. +const CNFOrLimit = 5 + // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN // Note: In order to re-plan, we need to empty the accumulated metadata in the AST, // so ColName.Metadata will be nil:ed out as part of this rewrite func RewritePredicate(ast SQLNode) SQLNode { + count := 0 + _ = Walk(func(node SQLNode) (bool, error) { + if _, isExpr := node.(*OrExpr); isExpr { + count++ + } + + return true, nil + }, ast) + + allowCNF := count < CNFOrLimit + for { printExpr(ast) exprChanged := false @@ -37,7 +52,7 @@ func RewritePredicate(ast SQLNode) SQLNode { return true } - rewritten, state := simplifyExpression(e) + rewritten, state := simplifyExpression(e, allowCNF) if ch, isChange := state.(changed); isChange { printRule(ch.rule, ch.exprMatched) exprChanged = true @@ -52,12 +67,12 @@ func RewritePredicate(ast SQLNode) SQLNode { } } -func simplifyExpression(expr Expr) (Expr, rewriteState) { +func simplifyExpression(expr Expr, allowCNF bool) (Expr, rewriteState) { switch expr := expr.(type) { case *NotExpr: return simplifyNot(expr) case *OrExpr: - return simplifyOr(expr) + return simplifyOr(expr, allowCNF) case *XorExpr: return simplifyXor(expr) case *AndExpr: @@ -113,14 +128,14 @@ func ExtractINFromOR(expr *OrExpr) []Expr { return uniquefy(ins) } -func simplifyOr(expr *OrExpr) (Expr, rewriteState) { +func simplifyOr(expr *OrExpr, allowCNF bool) (Expr, rewriteState) { or := expr // first we search for 
ANDs and see how they can be simplified land, lok := or.Left.(*AndExpr) rand, rok := or.Right.(*AndExpr) - switch { - case lok && rok: + + if lok && rok { // (<> AND <>) OR (<> AND <>) var a, b, c Expr var change changed @@ -128,40 +143,51 @@ func simplifyOr(expr *OrExpr) (Expr, rewriteState) { case Equals.Expr(land.Left, rand.Left): change = newChange("(A and B) or (A and C) => A AND (B OR C)", f(expr)) a, b, c = land.Left, land.Right, rand.Right + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change case Equals.Expr(land.Left, rand.Right): change = newChange("(A and B) or (C and A) => A AND (B OR C)", f(expr)) a, b, c = land.Left, land.Right, rand.Left + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change case Equals.Expr(land.Right, rand.Left): change = newChange("(B and A) or (A and C) => A AND (B OR C)", f(expr)) a, b, c = land.Right, land.Left, rand.Right + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change case Equals.Expr(land.Right, rand.Right): change = newChange("(B and A) or (C and A) => A AND (B OR C)", f(expr)) a, b, c = land.Right, land.Left, rand.Left - default: - return expr, noChange{} + return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change } - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, change - case lok: - // (<> AND <>) OR <> + } + + // (<> AND <>) OR <> + if lok { // Simplification if Equals.Expr(or.Right, land.Left) || Equals.Expr(or.Right, land.Right) { return or.Right, newChange("(A AND B) OR A => A", f(expr)) } - // Distribution Law - return &AndExpr{Left: &OrExpr{Left: land.Left, Right: or.Right}, Right: &OrExpr{Left: land.Right, Right: or.Right}}, - newChange("(A AND B) OR C => (A OR C) AND (B OR C)", f(expr)) - case rok: - // <> OR (<> AND <>) + + if allowCNF { + // Distribution Law + return &AndExpr{Left: &OrExpr{Left: land.Left, Right: or.Right}, Right: &OrExpr{Left: land.Right, Right: or.Right}}, + newChange("(A AND B) OR C => (A OR C) AND (B OR C)", 
f(expr)) + } + } + + // <> OR (<> AND <>) + if rok { // Simplification if Equals.Expr(or.Left, rand.Left) || Equals.Expr(or.Left, rand.Right) { return or.Left, newChange("A OR (A AND B) => A", f(expr)) } - // Distribution Law - return &AndExpr{ - Left: &OrExpr{Left: or.Left, Right: rand.Left}, - Right: &OrExpr{Left: or.Left, Right: rand.Right}, - }, - newChange("C OR (A AND B) => (C OR A) AND (C OR B)", f(expr)) + + if allowCNF { + // Distribution Law + return &AndExpr{ + Left: &OrExpr{Left: or.Left, Right: rand.Left}, + Right: &OrExpr{Left: or.Left, Right: rand.Right}, + }, + newChange("C OR (A AND B) => (C OR A) AND (C OR B)", f(expr)) + } } // next, we want to try to turn multiple ORs into an IN when possible @@ -257,7 +283,6 @@ func simplifyAnd(expr *AndExpr) (Expr, rewriteState) { and := expr if or, ok := and.Left.(*OrExpr); ok { // Simplification - if Equals.Expr(or.Left, and.Right) { return and.Right, newChange("(A OR B) AND A => A", f(expr)) } diff --git a/go/vt/sqlparser/predicate_rewriting_test.go b/go/vt/sqlparser/predicate_rewriting_test.go index 34e23597894..fba3d2f01dd 100644 --- a/go/vt/sqlparser/predicate_rewriting_test.go +++ b/go/vt/sqlparser/predicate_rewriting_test.go @@ -91,7 +91,7 @@ func TestSimplifyExpression(in *testing.T) { expr, err := ParseExpr(tc.in) require.NoError(t, err) - expr, didRewrite := simplifyExpression(expr) + expr, didRewrite := simplifyExpression(expr, true) assert.True(t, didRewrite.changed()) assert.Equal(t, tc.expected, String(expr)) }) @@ -129,6 +129,17 @@ func TestRewritePredicate(in *testing.T) { }, { in: "A and (B or A)", expected: "A", + }, { + in: "(a = 1 and b = 41) or (a = 2 and b = 42)", + // this might look weird, but it allows the planner to use either a or b in a vindex operation + expected: "a in (1, 2) and (a = 1 or b = 42) and ((b = 41 or a = 2) and b in (41, 42))", + }, { + in: "(a = 1 and b = 41) or (a = 2 and b = 42) or (a = 3 and b = 43)", + expected: "a in (1, 2, 3) and (a in (1, 2) or b = 43) and ((a = 
1 or b = 42 or a = 3) and (a = 1 or b = 42 or b = 43)) and ((b = 41 or a = 2 or a = 3) and (b = 41 or a = 2 or b = 43) and ((b in (41, 42) or a = 3) and b in (41, 42, 43)))", + }, { + // this has too many OR expressions in it, so we don't even try the CNF rewriting + in: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", + expected: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", }} for _, tc := range tests { @@ -164,6 +175,9 @@ func TestExtractINFromOR(in *testing.T) { }, { in: "(a in (1, 5) and B or C and a in (5, 7))", expected: "a in (1, 5, 7)", + }, { + in: "(a = 5 and b = 1 or b = 2 and a = 6 or b = 3 and a = 4)", + expected: "", }} for _, tc := range tests { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 10440fc4af5..43df08b0f9c 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -1273,7 +1273,7 @@ func TestSelectINFromOR(t *testing.T) { _, err := executorExec(ctx, executor, session, "select 1 from user where id = 1 and name = 'apa' or id = 2 and name = 'toto'", nil) require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select 1 from `user` where id = 1 and `name` = 'apa' or id = 2 and `name` = 'toto'", + Sql: "select 1 from `user` where id in ::__vals and (id = 1 or `name` = 'toto') and (`name` = 'apa' or id = 2) and `name` in ('apa', 'toto')", BindVariables: map[string]*querypb.BindVariable{ "__vals": sqltypes.TestBindVariable([]any{int64(1), int64(2)}), }, diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index a3753375292..3af60651eea 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -4110,7 +4110,7 @@ "Sharded": true }, "FieldQuery": "select id from `user` where 1 
!= 1", - "Query": "select id from `user` where id = 5 and `name` = 'foo' or id = 12 and `name` = 'bar'", + "Query": "select id from `user` where id in ::__vals and (id = 5 or `name` = 'bar') and (`name` = 'foo' or id = 12) and `name` in ('foo', 'bar')", "Table": "`user`", "Values": [ "(INT64(5), INT64(12))"