Skip to content

Commit

Permalink
Merge pull request #2409 from actiontech/cherry-pick-v2-2.2311
Browse files Browse the repository at this point in the history
Cherry pick v2 2.2311
  • Loading branch information
ColdWaterLW authored May 7, 2024
2 parents 912d5cb + 874a9a0 commit 7c60cd1
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions sqle/driver/mysql/rollback.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mysql

import (
"database/sql"
"encoding/hex"
"fmt"
"strconv"
"strings"
Expand All @@ -11,6 +12,7 @@ import (

"github.com/pingcap/parser/ast"
_model "github.com/pingcap/parser/model"
parserMysql "github.com/pingcap/parser/mysql"
)

func (i *MysqlDriverImpl) GenerateRollbackSql(node ast.Node) (string, string, error) {
Expand Down Expand Up @@ -449,6 +451,12 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
return rollbackSql, "", nil
}

// 将二进制字段转化为十六进制字段
func getHexStrFromBytesStr(byteStr string) string {
encode := []byte(byteStr)
return hex.EncodeToString(encode)
}

// generateDeleteRollbackSql generate insert SQL for delete.
func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (string, string, error) {
// not support multi-table syntax
Expand Down Expand Up @@ -497,8 +505,10 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin
values := []string{}

columnsName := []string{}
colNameDefMap := make(map[string]*ast.ColumnDef)
for _, col := range createTableStmt.Cols {
columnsName = append(columnsName, col.Name.Name.String())
colNameDefMap[col.Name.Name.String()] = col
}
for _, record := range records {
if len(record) != len(columnsName) {
Expand All @@ -508,7 +518,13 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin
for _, name := range columnsName {
v := "NULL"
if record[name].Valid {
v = fmt.Sprintf("'%s'", record[name].String)
colDef := colNameDefMap[name]
if parserMysql.HasBinaryFlag(colDef.Tp.Flag) {
hexStr := getHexStrFromBytesStr(record[name].String)
v = fmt.Sprintf("X'%s'", hexStr)
} else {
v = fmt.Sprintf("'%s'", record[name].String)
}
}
vs = append(vs, v)
}
Expand Down Expand Up @@ -581,13 +597,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin
if err != nil {
return "", "", err
}
columnsName := []string{}
rollbackSql := ""
colNameDefMap := make(map[string]*ast.ColumnDef)
for _, col := range createTableStmt.Cols {
columnsName = append(columnsName, col.Name.Name.String())
colNameDefMap[col.Name.Name.String()] = col
}
for _, record := range records {
if len(record) != len(columnsName) {
if len(record) != len(colNameDefMap) {
return "", "", nil
}
where := []string{}
Expand All @@ -610,7 +626,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin
name := col.Name.Name.O
v := "NULL"
if record[name].Valid {
v = fmt.Sprintf("'%s'", record[name].String)
colDef := colNameDefMap[name]
if parserMysql.HasBinaryFlag(colDef.Tp.Flag) {
hexStr := getHexStrFromBytesStr(record[name].String)
v = fmt.Sprintf("X'%s'", hexStr)
} else {
v = fmt.Sprintf("'%s'", record[name].String)
}
}

if colChanged {
Expand Down

0 comments on commit 7c60cd1

Please sign in to comment.