Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Oct 25, 2024
1 parent c58abd9 commit f54cc03
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 36 deletions.
13 changes: 4 additions & 9 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (qb *queryBuilder) addTableExpr(
}

func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {
if _, toBeSkipped := qb.ctx.SkipPredicates[expr]; toBeSkipped {
if qb.ctx.ShouldSkip(expr) {
// This is a predicate that was added to the RHS of an ApplyJoin.
// The original predicate will be added, so we don't have to add this here
return
Expand Down Expand Up @@ -566,21 +566,16 @@ func buildProjection(op *Projection, qb *queryBuilder) error {
func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error {
predicates := slice.Map(op.JoinPredicates, func(jc JoinColumn) sqlparser.Expr {
// since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done
qb.ctx.SkipPredicates[jc.RHSExpr] = nil

qb.ctx.SkipJoinPredicates(jc.Original.Expr)
return jc.Original.Expr
})

pred := sqlparser.AndExpressions(predicates...)
err := buildQuery(op.LHS, qb)
if err != nil {
return err
}
// If we are going to add the predicate used in join here
// We should not add the predicate's copy of when it was split into
// two parts. To avoid this, we use the SkipPredicates map.
for _, pred := range op.JoinPredicates {
qb.ctx.SkipPredicates[pred.RHSExpr] = nil
}

qbR := &queryBuilder{ctx: qb.ctx}
err = buildQuery(op.RHS, qbR)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,14 @@ func createOpFromStmt(ctx *plancontext.PlanningContext, stmt sqlparser.Statement
newCtx.VerifyAllFKs = verifyAllFKs
newCtx.ParentFKToIgnore = fkToIgnore

return PlanQuery(newCtx, stmt)
query, err := PlanQuery(newCtx, stmt)
if err != nil {
return nil, err
}

ctx.KeepPredicateInfo(newCtx)

return query, err
}

func getOperatorFromTableExpr(ctx *plancontext.PlanningContext, tableExpr sqlparser.TableExpr, onlyTable bool) (ops.Operator, error) {
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ func BreakExpressionInLHSandRHS(
cursor.Replace(arg)
}, nil).(sqlparser.Expr)

if err != nil {
return JoinColumn{}, err
}
ctx.JoinPredicates[expr] = append(ctx.JoinPredicates[expr], rewrittenExpr)
ctx.AddJoinPredicates(expr, rewrittenExpr)
col.RHSExpr = rewrittenExpr
return
}
14 changes: 14 additions & 0 deletions go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ func (s *planTestSuite) TestForeignKeyPlanning() {
s.testFile("foreignkey_cases.json", vschemaWrapper, false)
}

// TestForeignKeyPlanning tests the planning of foreign keys in a managed mode by Vitess.
func (s *planTestSuite) TestForeignKeyPlanningOne() {
closer := oprewriters.EnableDebugPrinting()
defer closer()
vschema := loadSchema(s.T(), "vschemas/schema.json", true)
s.setFks(vschema)
vschemaWrapper := &vschemawrapper.VSchemaWrapper{
V: vschema,
TestBuilder: TestBuilder,
}

s.testFile("onecase.json", vschemaWrapper, false)
}

func (s *planTestSuite) setFks(vschema *vindexes.VSchema) {
if vschema.Keyspaces["sharded_fk_allow"] != nil {
// FK from multicol_tbl2 referencing multicol_tbl1 that is shard scoped.
Expand Down
94 changes: 86 additions & 8 deletions go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package plancontext
import (
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

Expand All @@ -27,12 +28,16 @@ type PlanningContext struct {
SemTable *semantics.SemTable
VSchema VSchema

// here we add all predicates that were created because of a join condition
// e.g. [FROM tblA JOIN tblB ON a.colA = b.colB] will be rewritten to [FROM tblB WHERE :a_colA = b.colB],
// if we assume that tblB is on the RHS of the join. This last predicate in the WHERE clause is added to the
// map below
JoinPredicates map[sqlparser.Expr][]sqlparser.Expr
SkipPredicates map[sqlparser.Expr]any
// joinPredicates maps each original join predicate (key) to a slice of
// variations of the RHS predicates (value). This map is used to handle
// different scenarios in join planning, where the RHS predicates are
// modified to accommodate dependencies
joinPredicates map[sqlparser.Expr][]sqlparser.Expr

// skipPredicates tracks predicates that should be skipped, typically when
// a join predicate is reverted to its original form during planning.
skipPredicates map[sqlparser.Expr]any

PlannerVersion querypb.ExecuteOptions_PlannerVersion

// If we during planning have turned this expression into an argument name,
Expand Down Expand Up @@ -79,8 +84,8 @@ func CreatePlanningContext(stmt sqlparser.Statement,
ReservedVars: reservedVars,
SemTable: semTable,
VSchema: vschema,
JoinPredicates: map[sqlparser.Expr][]sqlparser.Expr{},
SkipPredicates: map[sqlparser.Expr]any{},
joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{},
skipPredicates: map[sqlparser.Expr]any{},
PlannerVersion: version,
ReservedArguments: map[sqlparser.Expr]string{},
}, nil
Expand Down Expand Up @@ -116,3 +121,76 @@ func (ctx *PlanningContext) GetArgumentFor(expr sqlparser.Expr, f func() string)
ctx.ReservedArguments[expr] = bvName
return bvName
}

// ShouldSkip determines if a given expression should be ignored in the SQL output building.
// It checks against expressions that have been marked to be excluded from further processing.
func (ctx *PlanningContext) ShouldSkip(expr sqlparser.Expr) bool {
for k := range ctx.skipPredicates {
if ctx.SemTable.EqualsExpr(expr, k) {
return true
}
}
return false
}

// AddJoinPredicates associates additional RHS predicates with an existing join predicate.
// This is used to dynamically adjust the RHS predicates based on evolving join conditions.
func (ctx *PlanningContext) AddJoinPredicates(joinPred sqlparser.Expr, predicates ...sqlparser.Expr) {
fn := func(original sqlparser.Expr, rhsExprs []sqlparser.Expr) {
ctx.joinPredicates[original] = append(rhsExprs, predicates...)
}
if ctx.execOnJoinPredicateEqual(joinPred, fn) {
return
}

// we didn't find an existing entry
ctx.joinPredicates[joinPred] = predicates
}

// SkipJoinPredicates marks the predicates related to a specific join predicate as irrelevant
// for the current planning stage. This is used when a join has been pushed under a route and
// the original predicate will be used.
func (ctx *PlanningContext) SkipJoinPredicates(joinPred sqlparser.Expr) error {
fn := func(_ sqlparser.Expr, rhsExprs []sqlparser.Expr) {
ctx.skipThesePredicates(rhsExprs...)
}
if ctx.execOnJoinPredicateEqual(joinPred, fn) {
return nil
}
return vterrors.VT13001("predicate does not exist: " + sqlparser.String(joinPred))
}

// KeepPredicateInfo transfers join predicate information from another context.
// This is useful when nesting queries, ensuring consistent predicate handling across contexts.
func (ctx *PlanningContext) KeepPredicateInfo(other *PlanningContext) {
for k, v := range other.joinPredicates {
ctx.AddJoinPredicates(k, v...)
}
for expr := range other.skipPredicates {
ctx.skipThesePredicates(expr)
}
}

// skipThesePredicates is a utility function to exclude certain predicates from SQL building
func (ctx *PlanningContext) skipThesePredicates(preds ...sqlparser.Expr) {
outer:
for _, expr := range preds {
for k := range ctx.skipPredicates {
if ctx.SemTable.EqualsExpr(expr, k) {
// already skipped
continue outer
}
}
ctx.skipPredicates[expr] = nil
}
}

func (ctx *PlanningContext) execOnJoinPredicateEqual(joinPred sqlparser.Expr, fn func(original sqlparser.Expr, rhsExprs []sqlparser.Expr)) bool {
for key, values := range ctx.joinPredicates {
if ctx.SemTable.EqualsExpr(joinPred, key) {
fn(key, values)
return true
}
}
return false
}
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@
"Sharded": false
},
"FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where 1 != 1",
"Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where u_tbl9.col9 is null and (u_tbl8.col8) in ::fkc_vals limit 1 lock in share mode",
"Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where u_tbl9.col9 is null and (u_tbl8.col8) in ::fkc_vals and :u_tbl9_col9 = 'foo' limit 1 lock in share mode",
"Table": "u_tbl8, u_tbl9"
},
{
Expand Down Expand Up @@ -1208,7 +1208,7 @@
"Sharded": false
},
"FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where 1 != 1",
"Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals limit 1 lock in share mode",
"Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals and :u_tbl3_col3 = 'foo' limit 1 lock in share mode",
"Table": "u_tbl3, u_tbl4"
},
{
Expand Down Expand Up @@ -1297,7 +1297,7 @@
"Sharded": false
},
"FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where 1 != 1",
"Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals limit 1 lock in share mode",
"Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals and :u_tbl3_col3 = :v1 limit 1 lock in share mode",
"Table": "u_tbl3, u_tbl4"
},
{
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/from_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@
"Sharded": true
},
"FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t, user_extra where 1 != 1",
"Query": "select t.id from (select id from `user` where id = 5 and id = :user_extra_user_id) as t, user_extra where t.id = user_extra.user_id",
"Query": "select t.id from (select id from `user` where id = 5) as t, user_extra where t.id = user_extra.user_id",
"Table": "`user`, user_extra",
"Values": [
"INT64(5)"
Expand Down Expand Up @@ -1736,7 +1736,7 @@
"Sharded": true
},
"FieldQuery": "select t.id from (select id, textcol1 as baz from `user` as route1 where 1 != 1) as t, (select id, textcol1 + textcol1 as baz from `user` where 1 != 1) as s where 1 != 1",
"Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3' and id = :t_id) as s where t.id = s.id",
"Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3') as s where t.id = s.id",
"Table": "`user`"
},
"TablesUsed": [
Expand Down
75 changes: 72 additions & 3 deletions go/vt/vtgate/planbuilder/testdata/onecase.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,78 @@
[
{
"comment": "Add your test case here for debugging and run go test -run=One.",
"query": "",
"comment": "Update in a table with shard-scoped foreign keys with cascade that requires a validation of a different parent foreign key",
"query": "update u_tbl6 set col6 = 'foo'",
"plan": {

"QueryType": "UPDATE",
"Original": "update u_tbl6 set col6 = 'foo'",
"Instructions": {
"OperatorType": "FkCascade",
"Inputs": [
{
"InputName": "Selection",
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "unsharded_fk_allow",
"Sharded": false
},
"FieldQuery": "select u_tbl6.col6 from u_tbl6 where 1 != 1",
"Query": "select u_tbl6.col6 from u_tbl6 for update",
"Table": "u_tbl6"
},
{
"InputName": "CascadeChild-1",
"OperatorType": "FKVerify",
"BvName": "fkc_vals",
"Cols": [
0
],
"Inputs": [
{
"InputName": "VerifyParent-1",
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "unsharded_fk_allow",
"Sharded": false
},
"FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1",
"Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait",
"Table": "u_tbl8, u_tbl9"
},
{
"InputName": "PostVerify",
"OperatorType": "Update",
"Variant": "Unsharded",
"Keyspace": {
"Name": "unsharded_fk_allow",
"Sharded": false
},
"TargetTabletType": "PRIMARY",
"Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl8 set col8 = 'foo' where (col8) in ::fkc_vals",
"Table": "u_tbl8"
}
]
},
{
"InputName": "Parent",
"OperatorType": "Update",
"Variant": "Unsharded",
"Keyspace": {
"Name": "unsharded_fk_allow",
"Sharded": false
},
"TargetTabletType": "PRIMARY",
"Query": "update u_tbl6 set col6 = 'foo'",
"Table": "u_tbl6"
}
]
},
"TablesUsed": [
"unsharded_fk_allow.u_tbl6",
"unsharded_fk_allow.u_tbl8",
"unsharded_fk_allow.u_tbl9"
]
}
}
]
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/postprocess_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@
"Sharded": true
},
"FieldQuery": "select * from (select user_id from user_extra where 1 != 1) as eu, `user` as u where 1 != 1",
"Query": "select * from (select user_id from user_extra where user_id = 5 and user_id = :u_id) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc",
"Query": "select * from (select user_id from user_extra where user_id = 5) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc",
"Table": "`user`, user_extra",
"Values": [
"INT64(5)"
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/reference_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@
"Sharded": true
},
"FieldQuery": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where 1 != 1",
"Query": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where sr.foo = :ue_foo and rr.bar = sr.bar and u.id = ue.user_id and sr.foo = ue.foo",
"Query": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where rr.bar = sr.bar and u.id = ue.user_id and sr.foo = ue.foo",
"Table": "`user`, ref, ref_with_source, user_extra"
},
"TablesUsed": [
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/select_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -4110,7 +4110,7 @@
"Sharded": true
},
"FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1",
"Query": "select music.id from (select id from music where music.user_id = 5 and id = :music_id) as other, music where other.id = music.id",
"Query": "select music.id from (select id from music where music.user_id = 5) as other, music where other.id = music.id",
"Table": "music",
"Values": [
"INT64(5)"
Expand All @@ -4136,7 +4136,7 @@
"Sharded": true
},
"FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1",
"Query": "select music.id from (select id from music where music.user_id in ::__vals and id = :music_id) as other, music where other.id = music.id",
"Query": "select music.id from (select id from music where music.user_id in ::__vals) as other, music where other.id = music.id",
"Table": "music",
"Values": [
"(INT64(5), INT64(6), INT64(7))"
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/tpcc_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"Sharded": true
},
"FieldQuery": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where 1 != 1",
"Query": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where c_w_id = :w_id and c_d_id = 15 and c_id = 10 and w_id = 1 and c_w_id = w_id",
"Query": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where c_d_id = 15 and c_id = 10 and w_id = 1 and c_w_id = w_id",
"Table": "customer1, warehouse1",
"Values": [
"INT64(1)"
Expand Down Expand Up @@ -947,7 +947,7 @@
"Sharded": true
},
"FieldQuery": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where 1 != 1 group by o_c_id, o_d_id, o_w_id) as t, orders1 as o where 1 != 1",
"Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 and o_w_id = :o_o_w_id and o_d_id = :o_o_d_id and o_c_id = :o_o_c_id group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1",
"Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1",
"Table": "orders1",
"Values": [
"INT64(1)"
Expand Down

0 comments on commit f54cc03

Please sign in to comment.