Skip to content

Commit

Permalink
Merge pull request #2133 from actiontech/fix-issue1256
Browse files Browse the repository at this point in the history
set getTableNameCreateTableStmtMap to sessionContext func
  • Loading branch information
ColdWaterLW authored Dec 21, 2023
2 parents 156dd80 + 5c89b15 commit 673b87d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
32 changes: 5 additions & 27 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2622,19 +2622,19 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea
if stmt.From == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.From.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.From.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.From.TableRefs)
case *ast.UpdateStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
case *ast.DeleteStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
default:
return nil, nil
Expand Down Expand Up @@ -2825,28 +2825,6 @@ func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context,
return tableNameCreateTableStmtMap
}

func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) {
var leftType, rightType byte
// onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名
Expand Down Expand Up @@ -5388,7 +5366,7 @@ func checkWhereConditionUseIndex(ctx *session.Context, whereVisitor *util.WhereW
continue
}

tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
tableNameCreateTableStmtMap := ctx.GetTableNameCreateTableStmtMap(whereExpr.TableRef)
util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
switch x := expr.(type) {
case *ast.ColumnNameExpr:
Expand Down Expand Up @@ -7583,7 +7561,7 @@ func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
// 如果SQL没有JOIN多表,则不需要审核
return true, fmt.Errorf("sql have not join node")
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
tableNameCreateTableStmtMap := input.Ctx.GetTableNameCreateTableStmtMap(joinNode)
tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
for tableName, createTableStmt := range tableNameCreateTableStmtMap {
tableIndexes[tableName] = createTableStmt.Constraints
Expand Down
22 changes: 22 additions & 0 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,25 @@ func (c *Context) GetExecutor() *executor.Executor {
func (c *Context) GetTableIndexesInfo(schema, tableName string) ([]*executor.TableIndexesInfo, error) {
return c.e.GetTableIndexesInfo(utils.SupplementalQuotationMarks(schema), utils.SupplementalQuotationMarks(tableName))
}

func (c *Context) GetTableNameCreateTableStmtMap(joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := c.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

0 comments on commit 673b87d

Please sign in to comment.