diff --git a/sqle/driver/mysql/rollback.go b/sqle/driver/mysql/rollback.go index 6891044322..63a1b2b4b8 100644 --- a/sqle/driver/mysql/rollback.go +++ b/sqle/driver/mysql/rollback.go @@ -2,6 +2,7 @@ package mysql import ( "database/sql" + "encoding/hex" "fmt" "strconv" "strings" @@ -11,6 +12,7 @@ import ( "github.com/pingcap/parser/ast" _model "github.com/pingcap/parser/model" + parserMysql "github.com/pingcap/parser/mysql" ) func (i *MysqlDriverImpl) GenerateRollbackSql(node ast.Node) (string, string, error) { @@ -449,6 +451,12 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin return rollbackSql, "", nil } +// 将二进制字段转化为十六进制字段 +func getHexStrFromBytesStr(byteStr string) string { + encode := []byte(byteStr) + return hex.EncodeToString(encode) +} + // generateDeleteRollbackSql generate insert SQL for delete. func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (string, string, error) { // not support multi-table syntax @@ -497,8 +505,10 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin values := []string{} columnsName := []string{} + colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { columnsName = append(columnsName, col.Name.Name.String()) + colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { if len(record) != len(columnsName) { @@ -508,7 +518,13 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin for _, name := range columnsName { v := "NULL" if record[name].Valid { - v = fmt.Sprintf("'%s'", record[name].String) + colDef := colNameDefMap[name] + if parserMysql.HasBinaryFlag(colDef.Tp.Flag) { + hexStr := getHexStrFromBytesStr(record[name].String) + v = fmt.Sprintf("X'%s'", hexStr) + } else { + v = fmt.Sprintf("'%s'", record[name].String) + } } vs = append(vs, v) } @@ -581,13 +597,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin if err != nil { return "", "", err } - columnsName := []string{} rollbackSql := "" + colNameDefMap := make(map[string]*ast.ColumnDef) for _, col := range createTableStmt.Cols { - columnsName = append(columnsName, col.Name.Name.String()) + colNameDefMap[col.Name.Name.String()] = col } for _, record := range records { - if len(record) != len(columnsName) { + if len(record) != len(colNameDefMap) { return "", "", nil } where := []string{} @@ -610,7 +626,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin name := col.Name.Name.O v := "NULL" if record[name].Valid { - v = fmt.Sprintf("'%s'", record[name].String) + colDef := colNameDefMap[name] + if parserMysql.HasBinaryFlag(colDef.Tp.Flag) { + hexStr := getHexStrFromBytesStr(record[name].String) + v = fmt.Sprintf("X'%s'", hexStr) + } else { + v = fmt.Sprintf("'%s'", record[name].String) + } } if colChanged { diff --git a/sqle/model/utils.go b/sqle/model/utils.go index 61210aef1e..e982cfb043 100644 --- a/sqle/model/utils.go +++ b/sqle/model/utils.go @@ -87,7 +87,7 @@ type Model struct { func NewStorage(user, password, host, port, schema string, debug bool) (*Storage, error) { log.Logger().Infof("connecting to storage, host: %s, port: %s, user: %s, schema: %s", host, port, user, schema) - db, err := gorm.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local", + db, err := gorm.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, schema)) if err != nil { log.Logger().Errorf("connect to storage failed, error: %v", err)