Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set getTableNameCreateTableStmtMap to sessionContext func #2133

Merged
merged 1 commit into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 5 additions & 27 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2622,19 +2622,19 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea
if stmt.From == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.From.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.From.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.From.TableRefs)
case *ast.UpdateStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
case *ast.DeleteStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
default:
return nil, nil
Expand Down Expand Up @@ -2825,28 +2825,6 @@ func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context,
return tableNameCreateTableStmtMap
}

func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) {
var leftType, rightType byte
// onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名
Expand Down Expand Up @@ -5388,7 +5366,7 @@ func checkWhereCondationUseIndex(ctx *session.Context, whereVisitor *util.WhereW
continue
}

tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
tableNameCreateTableStmtMap := ctx.GetTableNameCreateTableStmtMap(whereExpr.TableRef)
util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
switch x := expr.(type) {
case *ast.ColumnNameExpr:
Expand Down Expand Up @@ -7585,7 +7563,7 @@ func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
// 如果SQL没有JOIN多表,则不需要审核
return true, fmt.Errorf("sql have not join node")
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
tableNameCreateTableStmtMap := input.Ctx.GetTableNameCreateTableStmtMap(joinNode)
tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
for tableName, createTableStmt := range tableNameCreateTableStmtMap {
tableIndexes[tableName] = createTableStmt.Constraints
Expand Down
22 changes: 22 additions & 0 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,25 @@ func (c *Context) GetExecutor() *executor.Executor {
func (c *Context) GetTableIndexesInfo(schema, tableName string) ([]*executor.TableIndexesInfo, error) {
return c.e.GetTableIndexesInfo(utils.SupplementalQuotationMarks(schema), utils.SupplementalQuotationMarks(tableName))
}

func (c *Context) GetTableNameCreateTableStmtMap(joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := c.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}
Loading