From ab324a250a53adfc4f3b1d4c947d46f65ab31cd6 Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Mon, 6 May 2024 18:32:55 +0800 Subject: [PATCH 1/2] delete && insert generate rollback sql func add binary data parse --- sqle/driver/mysql/rollback.go | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/sqle/driver/mysql/rollback.go b/sqle/driver/mysql/rollback.go index 6891044322..7be00dde31 100644 --- a/sqle/driver/mysql/rollback.go +++ b/sqle/driver/mysql/rollback.go @@ -2,6 +2,7 @@ package mysql import ( "database/sql" + "encoding/hex" "fmt" "strconv" "strings" @@ -11,6 +12,7 @@ import ( "github.com/pingcap/parser/ast" _model "github.com/pingcap/parser/model" + parserMysql "github.com/pingcap/parser/mysql" ) func (i *MysqlDriverImpl) GenerateRollbackSql(node ast.Node) (string, string, error) { @@ -449,6 +451,12 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin return rollbackSql, "", nil } +// 将二进制字段转化为十六进制字段 +func getHexStrFromBytesStr(byteStr string) string { + encode := []byte(byteStr) + return hex.EncodeToString(encode) +} + // generateDeleteRollbackSql generate insert SQL for delete. func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (string, string, error) { // not support multi-table syntax @@ -497,8 +505,10 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin values := []string{} columnsName := []string{} + colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { columnsName = append(columnsName, col.Name.Name.String()) + colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { if len(record) != len(columnsName) { @@ -508,7 +518,13 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin for _, name := range columnsName { v := "NULL" if record[name].Valid { - v = fmt.Sprintf("'%s'", record[name].String) + colDef := colNameDefMap[name] + if parserMysql.HasBinaryFlag(colDef.Tp.Flag) { + hexStr := getHexStrFromBytesStr(record[name].String) + v = fmt.Sprintf("X'%s'", hexStr) + } else { + v = fmt.Sprintf("'%s'", record[name].String) + } } vs = append(vs, v) } @@ -583,8 +599,10 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin } columnsName := []string{} rollbackSql := "" + colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { columnsName = append(columnsName, col.Name.Name.String()) + colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { if len(record) != len(columnsName) { @@ -610,7 +628,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin name := col.Name.Name.O v := "NULL" if record[name].Valid { - v = fmt.Sprintf("'%s'", record[name].String) + colDef := colNameDefMap[name] + if parserMysql.HasBinaryFlag(colDef.Tp.Flag) { + hexStr := getHexStrFromBytesStr(record[name].String) + v = fmt.Sprintf("X'%s'", hexStr) + } else { + v = fmt.Sprintf("'%s'", record[name].String) + } } if colChanged { From 874a9a051bd674626c1d061f2a97613514a24d6e Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Tue, 7 May 2024 13:53:43 +0800 Subject: [PATCH 2/2] remove extra var --- sqle/driver/mysql/rollback.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqle/driver/mysql/rollback.go b/sqle/driver/mysql/rollback.go index 7be00dde31..63a1b2b4b8 100644 --- a/sqle/driver/mysql/rollback.go +++ b/sqle/driver/mysql/rollback.go @@ -597,15 +597,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin if err != nil { return "", "", err } - columnsName := []string{} rollbackSql := "" colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { - columnsName = append(columnsName, col.Name.Name.String()) colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { - if len(record) != len(columnsName) { + if len(record) != len(colNameDefMap) { return "", "", nil } where := []string{}