From 5c89b15190537ad96733dff2a82df817d861424a Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Thu, 7 Dec 2023 06:48:35 +0000 Subject: [PATCH] set getTableNameCreateTableStmtMap to sessionContext func --- sqle/driver/mysql/rule/rule.go | 32 +++++----------------------- sqle/driver/mysql/session/context.go | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/sqle/driver/mysql/rule/rule.go b/sqle/driver/mysql/rule/rule.go index 552ec90815..a7896578d9 100644 --- a/sqle/driver/mysql/rule/rule.go +++ b/sqle/driver/mysql/rule/rule.go @@ -2622,19 +2622,19 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea if stmt.From == nil { return nil, nil } - tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.From.TableRefs) + tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.From.TableRefs) onConditions = util.GetTableFromOnCondition(stmt.From.TableRefs) case *ast.UpdateStmt: if stmt.TableRefs == nil { return nil, nil } - tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs) + tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs) onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs) case *ast.DeleteStmt: if stmt.TableRefs == nil { return nil, nil } - tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs) + tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs) onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs) default: return nil, nil @@ -2825,28 +2825,6 @@ func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context, return tableNameCreateTableStmtMap } -func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt { - tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt) - tableSources := util.GetTableSources(joinStmt) - for _, tableSource := range tableSources { - if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok { - tableName := tableNameStmt.Name.L - if tableSource.AsName.L != "" { - // 如果使用别名,则需要用别名引用 - tableName = tableSource.AsName.L - } - - createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt) - if err != nil || !exist { - continue - } - // TODO: 跨库的 JOIN 无法区分 - tableNameCreateTableStmtMap[tableName] = createTableStmt - } - } - return tableNameCreateTableStmtMap -} - func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) { var leftType, rightType byte // onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名 @@ -5388,7 +5366,7 @@ func checkWhereCondationUseIndex(ctx *session.Context, whereVisitor *util.WhereW continue } - tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef) + tableNameCreateTableStmtMap := ctx.GetTableNameCreateTableStmtMap(whereExpr.TableRef) util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) { switch x := expr.(type) { case *ast.ColumnNameExpr: @@ -7585,7 +7563,7 @@ func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) { // 如果SQL没有JOIN多表,则不需要审核 return true, fmt.Errorf("sql have not join node") } - tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode) + tableNameCreateTableStmtMap := input.Ctx.GetTableNameCreateTableStmtMap(joinNode) tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap)) for tableName, createTableStmt := range tableNameCreateTableStmtMap { tableIndexes[tableName] = createTableStmt.Constraints diff --git a/sqle/driver/mysql/session/context.go b/sqle/driver/mysql/session/context.go index 40a9286b2d..bbea0418e1 100644 --- a/sqle/driver/mysql/session/context.go +++ b/sqle/driver/mysql/session/context.go @@ -1121,3 +1121,25 @@ func (c *Context) GetExecutor() *executor.Executor { func (c *Context) GetTableIndexesInfo(schema, tableName string) ([]*executor.TableIndexesInfo, error) { return c.e.GetTableIndexesInfo(utils.SupplementalQuotationMarks(schema), utils.SupplementalQuotationMarks(tableName)) } + +func (c *Context) GetTableNameCreateTableStmtMap(joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt { + tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt) + tableSources := util.GetTableSources(joinStmt) + for _, tableSource := range tableSources { + if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok { + tableName := tableNameStmt.Name.L + if tableSource.AsName.L != "" { + // 如果使用别名,则需要用别名引用 + tableName = tableSource.AsName.L + } + + createTableStmt, exist, err := c.GetCreateTableStmt(tableNameStmt) + if err != nil || !exist { + continue + } + // TODO: 跨库的 JOIN 无法区分 + tableNameCreateTableStmtMap[tableName] = createTableStmt + } + } + return tableNameCreateTableStmtMap +}