Skip to content

Commit

Permalink
Respect Straight Join in Vitess query planning (#15528)
Browse files Browse the repository at this point in the history
Co-authored-by: Andres Taylor <[email protected]>
  • Loading branch information
GuptaManan100 and systay authored Mar 21, 2024
1 parent bb049b1 commit 3d313b9
Show file tree
Hide file tree
Showing 17 changed files with 220 additions and 109 deletions.
20 changes: 0 additions & 20 deletions go/test/endtoend/vtgate/gen4/gen4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,26 +187,6 @@ func TestSubQueriesOnOuterJoinOnCondition(t *testing.T) {
}
}

func TestPlannerWarning(t *testing.T) {
mcmp, closer := start(t)
defer closer()

// straight_join query
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`)
utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`)

// execute same query again.
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`)
utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`)

// random query to reset the warning.
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1`)

// execute same query again.
_ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`)
utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`)
}

func TestHashJoin(t *testing.T) {
mcmp, closer := start(t)
defer closer()
Expand Down
27 changes: 27 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,30 @@ func TestAlterTableWithView(t *testing.T) {

mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`)
}

// TestStraightJoin tests that Vitess respects the ordering of join in a STRAIGHT JOIN query.
func TestStraightJoin(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (1,0,10), (2,10,10), (3,4,20), (4,30,20), (5,40,10)")
mcmp.Exec(`insert into t1(id1, id2) values (10, 11), (20, 13)`)

mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col",
`[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`,
)
// Verify that in a normal join query, vitess joins tbl with t1.
res, err := mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col", 100, false)
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "tbl_t1")

// Test the same query with a straight join
mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col",
`[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`,
)
// Verify that in a straight join query, vitess joins t1 with tbl.
res, err = mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col", 100, false)
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl")
}
20 changes: 20 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,26 @@ func (node DatabaseOptionType) ToString() string {
}
}

// IsCommutative returns whether the join type supports rearranging or not.
func (joinType JoinType) IsCommutative() bool {
switch joinType {
case StraightJoinType, LeftJoinType, RightJoinType, NaturalLeftJoinType, NaturalRightJoinType:
return false
default:
return true
}
}

// IsInner returns whether the join type is an inner join or not.
func (joinType JoinType) IsInner() bool {
switch joinType {
case StraightJoinType, NaturalJoinType, NormalJoinType:
return true
default:
return false
}
}

// ToString returns the type as a string
func (ty LockType) ToString() string {
switch ty {
Expand Down
10 changes: 2 additions & 8 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3343,18 +3343,12 @@ func TestGen4SelectStraightJoin(t *testing.T) {
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{
{
Sql: "select u.id from `user` as u, user2 as u2 where u.id = u2.id",
Sql: "select u.id from `user` as u straight_join user2 as u2 on u.id = u2.id",
BindVariables: map[string]*querypb.BindVariable{},
},
}
wantWarnings := []*querypb.QueryWarning{
{
Code: 1235,
Message: "straight join is converted to normal join",
},
}
utils.MustMatch(t, wantQueries, sbc1.Queries)
utils.MustMatch(t, wantWarnings, session.Warnings)
require.Empty(t, session.Warnings)
}

func TestGen4MultiColumnVindexEqual(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func transformApplyJoinPlan(ctx *plancontext.PlanningContext, n *operators.Apply
return nil, err
}
opCode := engine.InnerJoin
if n.LeftJoin {
if !n.JoinType.IsInner() {
opCode = engine.LeftJoin
}

Expand Down
32 changes: 11 additions & 21 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ var _ FromStatement = (*sqlparser.Select)(nil)
var _ FromStatement = (*sqlparser.Update)(nil)
var _ FromStatement = (*sqlparser.Delete)(nil)

func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) {
func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) {
stmt := qb.stmt.(FromStatement)
otherStmt := other.stmt.(FromStatement)

Expand All @@ -222,24 +222,18 @@ func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser
sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
}

newFromClause := append(stmt.GetFrom(), otherStmt.GetFrom()...)
stmt.SetFrom(newFromClause)
qb.mergeWhereClauses(stmt, otherStmt)
qb.addPredicate(onCondition)
}

func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) {
stmt := qb.stmt.(FromStatement)
otherStmt := other.stmt.(FromStatement)

if sel, isSel := stmt.(*sqlparser.Select); isSel {
otherSel := otherStmt.(*sqlparser.Select)
sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
var newFromClause []sqlparser.TableExpr
switch joinType {
case sqlparser.NormalJoinType:
newFromClause = append(stmt.GetFrom(), otherStmt.GetFrom()...)
qb.addPredicate(onCondition)
default:
newFromClause = []sqlparser.TableExpr{buildJoin(stmt, otherStmt, onCondition, joinType)}
}

newFromClause := []sqlparser.TableExpr{buildOuterJoin(stmt, otherStmt, onCondition)}
stmt.SetFrom(newFromClause)
qb.mergeWhereClauses(stmt, otherStmt)
}

func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) {
Expand All @@ -254,7 +248,7 @@ func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) {
}
}

func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr) *sqlparser.JoinTableExpr {
func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr {
var lhs sqlparser.TableExpr
fromClause := stmt.GetFrom()
if len(fromClause) == 1 {
Expand All @@ -273,7 +267,7 @@ func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sql
return &sqlparser.JoinTableExpr{
LeftExpr: lhs,
RightExpr: rhs,
Join: sqlparser.LeftJoinType,
Join: joinType,
Condition: &sqlparser.JoinCondition{
On: onCondition,
},
Expand Down Expand Up @@ -539,11 +533,7 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) {

qbR := &queryBuilder{ctx: qb.ctx}
buildQuery(op.RHS, qbR)
if op.LeftJoin {
qb.joinOuterWith(qbR, pred)
} else {
qb.joinInnerWith(qbR, pred)
}
qb.joinWith(qbR, pred, op.JoinType)
}

func buildUnion(op *Union, qb *queryBuilder) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr
rhs := createJoinPusher(rootAggr, join.RHS)

columns := &applyJoinColumns{}
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, join.LeftJoin, columns, lhs, rhs)
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, !join.JoinType.IsInner(), columns, lhs, rhs)
join.JoinColumns = columns
if err != nil {
// if we get this error, we just abort the splitting and fall back on simpler ways of solving the same query
Expand Down
14 changes: 10 additions & 4 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type (
ApplyJoin struct {
LHS, RHS Operator

// JoinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
JoinType sqlparser.JoinType
// LeftJoin will be true in the case of an outer join
LeftJoin bool

Expand Down Expand Up @@ -82,12 +85,12 @@ type (
}
)

func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, joinType sqlparser.JoinType) *ApplyJoin {
aj := &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
LeftJoin: leftOuterJoin,
JoinType: joinType,
JoinColumns: &applyJoinColumns{},
JoinPredicates: &applyJoinColumns{},
}
Expand Down Expand Up @@ -139,11 +142,14 @@ func (aj *ApplyJoin) SetRHS(operator Operator) {
}

func (aj *ApplyJoin) MakeInner() {
aj.LeftJoin = false
if aj.IsInner() {
return
}
aj.JoinType = sqlparser.NormalJoinType
}

func (aj *ApplyJoin) IsInner() bool {
return !aj.LeftJoin
return aj.JoinType.IsInner()
}

func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s
case sqlparser.NormalJoinType:
return createInnerJoin(ctx, tableExpr, lhs, rhs)
case sqlparser.LeftJoinType, sqlparser.RightJoinType:
return createOuterJoin(tableExpr, lhs, rhs)
return createLeftOuterJoin(ctx, tableExpr, lhs, rhs)
case sqlparser.StraightJoinType:
return createStraightJoin(ctx, tableExpr, lhs, rhs)
default:
panic(vterrors.VT13001("unsupported: %s", tableExpr.Join.ToString()))
}
Expand Down
83 changes: 58 additions & 25 deletions go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
type Join struct {
LHS, RHS Operator
Predicate sqlparser.Expr
LeftJoin bool
// JoinType is permitted to store only 3 of the possible values
// NormalJoinType, StraightJoinType and LeftJoinType.
JoinType sqlparser.JoinType

noColumns
}
Expand All @@ -42,7 +44,7 @@ func (j *Join) Clone(inputs []Operator) Operator {
LHS: inputs[0],
RHS: inputs[1],
Predicate: j.Predicate,
LeftJoin: j.LeftJoin,
JoinType: j.JoinType,
}
}

Expand All @@ -61,8 +63,8 @@ func (j *Join) SetInputs(ops []Operator) {
}

func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) {
if j.LeftJoin {
// we can't merge outer joins into a single QG
if !j.JoinType.IsCommutative() {
// if we can't move tables around, we can't merge these inputs
return j, NoRewrite
}

Expand All @@ -83,38 +85,52 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult
return newOp, Rewrote("merge querygraphs into a single one")
}

func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
if tableExpr.Join == sqlparser.RightJoinType {
func createStraightJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
// for inner joins we can treat the predicates as filters on top of the join
joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join}

return addJoinPredicates(ctx, join.Condition.On, joinOp)
}

func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
// first we switch sides, so we always deal with left outer joins
switch join.Join {
case sqlparser.RightJoinType:
lhs, rhs = rhs, lhs
join.Join = sqlparser.LeftJoinType
case sqlparser.NaturalRightJoinType:
lhs, rhs = rhs, lhs
join.Join = sqlparser.NaturalLeftJoinType
}
subq, _ := getSubQuery(tableExpr.Condition.On)

joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join}

// for outer joins we have to be careful with the predicates we use
var op Operator
subq, _ := getSubQuery(join.Condition.On)
if subq != nil {
panic(vterrors.VT12001("subquery in outer join predicate"))
}
predicate := tableExpr.Condition.On
predicate := join.Condition.On
sqlparser.RemoveKeyspaceInCol(predicate)
return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate}
}
joinOp.Predicate = predicate
op = joinOp

func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
lqg, lok := LHS.(*QueryGraph)
rqg, rok := RHS.(*QueryGraph)
if lok && rok {
op := &QueryGraph{
Tables: append(lqg.Tables, rqg.Tables...),
innerJoins: append(lqg.innerJoins, rqg.innerJoins...),
NoDeps: ctx.SemTable.AndExpressions(lqg.NoDeps, rqg.NoDeps),
}
return op
}
return &Join{LHS: LHS, RHS: RHS}
return op
}

func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
op := createJoin(ctx, lhs, rhs)
return addJoinPredicates(ctx, tableExpr.Condition.On, op)
}

func addJoinPredicates(
ctx *plancontext.PlanningContext,
joinPredicate sqlparser.Expr,
op Operator,
) Operator {
sqc := &SubQueryBuilder{}
outerID := TableID(op)
joinPredicate := tableExpr.Condition.On
sqlparser.RemoveKeyspaceInCol(joinPredicate)
exprs := sqlparser.SplitAndExpression(nil, joinPredicate)
for _, pred := range exprs {
Expand All @@ -127,6 +143,20 @@ func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.Join
return sqc.getRootOperator(op, nil)
}

func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator {
lqg, lok := LHS.(*QueryGraph)
rqg, rok := RHS.(*QueryGraph)
if lok && rok {
op := &QueryGraph{
Tables: append(lqg.Tables, rqg.Tables...),
innerJoins: append(lqg.innerJoins, rqg.innerJoins...),
NoDeps: ctx.SemTable.AndExpressions(lqg.NoDeps, rqg.NoDeps),
}
return op
}
return &Join{LHS: LHS, RHS: RHS}
}

func (j *Join) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator {
return AddPredicate(ctx, j, expr, false, newFilterSinglePredicate)
}
Expand All @@ -150,11 +180,14 @@ func (j *Join) SetRHS(operator Operator) {
}

func (j *Join) MakeInner() {
j.LeftJoin = false
if j.IsInner() {
return
}
j.JoinType = sqlparser.NormalJoinType
}

func (j *Join) IsInner() bool {
return !j.LeftJoin
return j.JoinType.IsInner()
}

func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
Expand Down
Loading

0 comments on commit 3d313b9

Please sign in to comment.