diff --git a/sqle/driver/mysql/rollback.go b/sqle/driver/mysql/rollback.go index 6891044322..63a1b2b4b8 100644 --- a/sqle/driver/mysql/rollback.go +++ b/sqle/driver/mysql/rollback.go @@ -2,6 +2,7 @@ package mysql import ( "database/sql" + "encoding/hex" "fmt" "strconv" "strings" @@ -11,6 +12,7 @@ import ( "github.com/pingcap/parser/ast" _model "github.com/pingcap/parser/model" + parserMysql "github.com/pingcap/parser/mysql" ) func (i *MysqlDriverImpl) GenerateRollbackSql(node ast.Node) (string, string, error) { @@ -449,6 +451,12 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin return rollbackSql, "", nil } +// 将二进制字段转化为十六进制字段 +func getHexStrFromBytesStr(byteStr string) string { + encode := []byte(byteStr) + return hex.EncodeToString(encode) +} + // generateDeleteRollbackSql generate insert SQL for delete. func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (string, string, error) { // not support multi-table syntax @@ -497,8 +505,10 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin values := []string{} columnsName := []string{} + colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { columnsName = append(columnsName, col.Name.Name.String()) + colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { if len(record) != len(columnsName) { @@ -508,7 +518,13 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin for _, name := range columnsName { v := "NULL" if record[name].Valid { - v = fmt.Sprintf("'%s'", record[name].String) + colDef := colNameDefMap[name] + if parserMysql.HasBinaryFlag(colDef.Tp.Flag) { + hexStr := getHexStrFromBytesStr(record[name].String) + v = fmt.Sprintf("X'%s'", hexStr) + } else { + v = fmt.Sprintf("'%s'", record[name].String) + } } vs = append(vs, v) } @@ -581,13 +597,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin if err != nil { return "", "", err } - columnsName := []string{} rollbackSql := "" + colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { - columnsName = append(columnsName, col.Name.Name.String()) + colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { - if len(record) != len(columnsName) { + if len(record) != len(colNameDefMap) { return "", "", nil } where := []string{} @@ -610,7 +626,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin name := col.Name.Name.O v := "NULL" if record[name].Valid { - v = fmt.Sprintf("'%s'", record[name].String) + colDef := colNameDefMap[name] + if parserMysql.HasBinaryFlag(colDef.Tp.Flag) { + hexStr := getHexStrFromBytesStr(record[name].String) + v = fmt.Sprintf("X'%s'", hexStr) + } else { + v = fmt.Sprintf("'%s'", record[name].String) + } } if colChanged {