Skip to content

Commit

Permalink
store ok packets as raw and add parsing of raw packets for VTGate to …
Browse files Browse the repository at this point in the history
…retreive query result

Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Nov 8, 2024
1 parent 6313ead commit ab80841
Show file tree
Hide file tree
Showing 14 changed files with 883 additions and 673 deletions.
30 changes: 17 additions & 13 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,13 @@ func (c *Conn) readPacketAsMemBuffer() (mem.Buffer, error) {
return mem.SliceBuffer(data), nil
}

const RawPacketsPos = 20

func updateProtoHeader(b []byte, v int) {
b[0] = byte(protowire.EncodeTag(1, protowire.BytesType))
b[0] = byte(protowire.EncodeTag(RawPacketsPos, protowire.BytesType))
switch {
case v < 1<<28:
// Proto packet data size is 4 bytes.
b[1] = byte((v>>0)&0x7f | 0x80)
b[2] = byte((v>>7)&0x7f | 0x80)
b[3] = byte((v>>14)&0x7f | 0x80)
Expand Down Expand Up @@ -1564,50 +1567,47 @@ type PacketOK struct {
sessionStateData string
}

func (c *Conn) parseOKPacket(packetOK *PacketOK, in []byte) error {
func parseOKPacket(packetOK *PacketOK, in []byte, queryInfoEnabled, sessionTrackingEnabled bool) error {
data := &coder{
data: in,
pos: 1, // We already read the type.
}
var ok bool

// Affected rows.
affectedRows, ok := data.readLenEncInt()
packetOK.affectedRows, ok = data.readLenEncInt()
if !ok {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid OK packet affectedRows: %v", data.data)
}
packetOK.affectedRows = affectedRows

// Last Insert ID.
lastInsertID, ok := data.readLenEncInt()
packetOK.lastInsertID, ok = data.readLenEncInt()
if !ok {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid OK packet lastInsertID: %v", data.data)
}
packetOK.lastInsertID = lastInsertID

// Status flags.
statusFlags, ok := data.readUint16()
packetOK.statusFlags, ok = data.readUint16()
if !ok {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid OK packet statusFlags: %v", data.data)
}
packetOK.statusFlags = statusFlags

// assuming CapabilityClientProtocol41
// Warnings.
warnings, ok := data.readUint16()
packetOK.warnings, ok = data.readUint16()
if !ok {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid OK packet warnings: %v", data.data)
}
packetOK.warnings = warnings

// info
info, _ := data.readLenEncInfo()
if c.enableQueryInfo {
if queryInfoEnabled {
packetOK.info = info
}

if c.Capabilities&uint32(CapabilityClientSessionTrack) == CapabilityClientSessionTrack {
if sessionTrackingEnabled {
// session tracking
if statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
if packetOK.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
length, ok := data.readLenEncInt()
if !ok || length == 0 {
// In case we have no more data or a zero length string, there's no additional information so
Expand Down Expand Up @@ -1747,3 +1747,7 @@ func (c *Conn) IsMarkedForClose() bool {
func (c *Conn) IsShuttingDown() bool {
return c.listener.shutdown.Load()
}

func (c *Conn) isSessionTrack() bool {
return c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack
}
8 changes: 4 additions & 4 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func TestBasicPackets(t *testing.T) {
assert.EqualValues(data[0], OKPacket, "OKPacket")

var packetOk PacketOK
err = cConn.parseOKPacket(&packetOk, data)
err = parseOKPacket(&packetOk, data, cConn.enableQueryInfo, cConn.isSessionTrack())
require.NoError(err)
assert.EqualValues(12, packetOk.affectedRows)
assert.EqualValues(34, packetOk.lastInsertID)
Expand All @@ -274,7 +274,7 @@ func TestBasicPackets(t *testing.T) {
require.NotEmpty(data)
assert.EqualValues(data[0], OKPacket, "OKPacket")

err = cConn.parseOKPacket(&packetOk, data)
err = parseOKPacket(&packetOk, data, cConn.enableQueryInfo, cConn.isSessionTrack())
require.NoError(err)
assert.EqualValues(23, packetOk.affectedRows)
assert.EqualValues(45, packetOk.lastInsertID)
Expand All @@ -297,7 +297,7 @@ func TestBasicPackets(t *testing.T) {
require.NotEmpty(data)
assert.True(cConn.isEOFPacket(data), "expected EOF")

err = cConn.parseOKPacket(&packetOk, data)
err = parseOKPacket(&packetOk, data, cConn.enableQueryInfo, cConn.isSessionTrack())
require.NoError(err)
assert.EqualValues(12, packetOk.affectedRows)
assert.EqualValues(34, packetOk.lastInsertID)
Expand Down Expand Up @@ -693,7 +693,7 @@ func TestOkPackets(t *testing.T) {
sConn.Capabilities = testCase.cc
// parse the packet
var packetOk PacketOK
err := cConn.parseOKPacket(&packetOk, data)
err := parseOKPacket(&packetOk, data, cConn.enableQueryInfo, cConn.isSessionTrack())
if testCase.expectedErr != "" {
require.Error(t, err)
require.Equal(t, testCase.expectedErr, err.Error())
Expand Down
116 changes: 73 additions & 43 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,18 +428,19 @@ func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfield
return res, warnings, err
}

func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool, uint16, error) {
func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (result *sqltypes.Result, more bool, warnings uint16, err error) {
var packetOk PacketOK

// Get the result.
colNumber, err := c.readComQueryResponse(&packetOk)
first, colNumber, err := c.readComQueryResponseAsMemBuf(&packetOk)
if err != nil {
return nil, false, 0, err
}
more := packetOk.statusFlags&ServerMoreResultsExists != 0
warnings := packetOk.warnings
more = packetOk.statusFlags&ServerMoreResultsExists != 0
warnings = packetOk.warnings
if colNumber == 0 {
// OK packet, means no results. Just use the numbers.
first.Free()
return &sqltypes.Result{
RowsAffected: packetOk.affectedRows,
InsertID: packetOk.lastInsertID,
Expand All @@ -449,8 +450,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
}, more, warnings, nil
}

var rawPackets []mem.Buffer
var data mem.Buffer
rawPackets := []mem.Buffer{first}

defer func() {
if err != nil {
Expand All @@ -463,41 +463,37 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
// Read column headers. One packet per column.
// Build the fields.
for i := 0; i < colNumber; i++ {
data, err = c.readPacketAsMemBuffer()
data, err := c.readPacketAsMemBuffer()
if err != nil {
err = sqlerror.NewSQLError(sqlerror.CRMalformedPacket, "", "")
return nil, false, 0, err
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, "", "")
}
rawPackets = append(rawPackets, data)
}

if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
// EOF is only present here if it's not deprecated.
data, err = c.readPacketAsMemBuffer()
data, err := c.readPacketAsMemBuffer()
if err != nil {
err = sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
return nil, false, 0, err
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
}
rawPackets = append(rawPackets, data)
defer data.Free()

if c.isEOFPacket(data.ReadOnlyData()) {
rawPackets = rawPackets[:len(rawPackets)-1]
// empty by design
} else if isErrorPacket(data.ReadOnlyData()) {
err = ParseErrorPacket(data.ReadOnlyData())
return nil, false, 0, err
return nil, false, 0, ParseErrorPacket(data.ReadOnlyData())
} else {
err = vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data)
return nil, false, 0, err
return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data)
}
}

var rowcount int

// Read each row until EOF or OK packet.
for {
data, err = c.readPacketAsMemBuffer()
data, err := c.readPacketAsMemBuffer()
if err != nil {
err = sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
return nil, false, 0, err
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
}
rawPackets = append(rawPackets, data)

Expand All @@ -514,20 +510,19 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
}
more = (statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = statusFlags

rawPackets = rawPackets[:len(rawPackets)-1]
// rawPackets = rawPackets[:len(rawPackets)-1]
} else {
var packetEof PacketOK
if err = c.parseOKPacket(&packetEof, data.ReadOnlyData()); err != nil {
var packetOK PacketOK
if err = parseOKPacket(&packetOK, data.ReadOnlyData(), c.enableQueryInfo, c.isSessionTrack()); err != nil {
return nil, false, 0, err
}
warnings = packetEof.warnings
more = (packetEof.statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = packetEof.statusFlags
warnings = packetOK.warnings
more = (packetOK.statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = packetOK.statusFlags

rawPackets = rawPackets[:len(rawPackets)-1]
result.SessionStateChanges = packetEof.sessionStateData
result.Info = packetEof.info
// rawPackets = rawPackets[:len(rawPackets)-1]
result.SessionStateChanges = packetOK.sessionStateData
result.Info = packetOK.info
}

// log.Errorf("DEBUG: setting result cached proto to %v", rawPackets)
Expand All @@ -536,8 +531,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool

} else if isErrorPacket(data.ReadOnlyData()) {
// Error packet.
err = ParseErrorPacket(data.ReadOnlyData())
return nil, false, 0, err
return nil, false, 0, ParseErrorPacket(data.ReadOnlyData())
}

if maxrows == FETCH_NO_ROWS {
Expand All @@ -549,8 +543,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
if err = c.drainResults(); err != nil {
return nil, false, 0, err
}
err = vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows)
return nil, false, 0, err
return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows)
}

rowcount++
Expand Down Expand Up @@ -647,15 +640,15 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
more = (statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = statusFlags
} else {
var packetEof PacketOK
if err := c.parseOKPacket(&packetEof, data); err != nil {
var packetOK PacketOK
if err = parseOKPacket(&packetOK, data, c.enableQueryInfo, c.isSessionTrack()); err != nil {
return nil, false, 0, err
}
warnings = packetEof.warnings
more = (packetEof.statusFlags & ServerMoreResultsExists) != 0
result.SessionStateChanges = packetEof.sessionStateData
result.StatusFlags = packetEof.statusFlags
result.Info = packetEof.info
warnings = packetOK.warnings
more = (packetOK.statusFlags & ServerMoreResultsExists) != 0
result.SessionStateChanges = packetOK.sessionStateData
result.StatusFlags = packetOK.statusFlags
result.Info = packetOK.info
}
return result, more, warnings, nil

Expand Down Expand Up @@ -720,7 +713,7 @@ func (c *Conn) readComQueryResponse(packetOk *PacketOK) (int, error) {

switch data[0] {
case OKPacket:
return 0, c.parseOKPacket(packetOk, data)
return 0, parseOKPacket(packetOk, data, c.enableQueryInfo, c.isSessionTrack())
case ErrPacket:
// Error
return 0, ParseErrorPacket(data)
Expand All @@ -738,6 +731,43 @@ func (c *Conn) readComQueryResponse(packetOk *PacketOK) (int, error) {
return int(n), nil
}

func (c *Conn) readComQueryResponseAsMemBuf(packetOk *PacketOK) (buf mem.Buffer, res int, err error) {
defer func() {
if buf != nil && err != nil {
buf.Free()
buf = nil
}
}()
buf, err = c.readPacketAsMemBuffer()
if err != nil {
return buf, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()
data := buf.ReadOnlyData()[5:]
if len(data) == 0 {
return buf, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "invalid empty COM_QUERY response packet")
}

switch data[0] {
case OKPacket:
return buf, 0, parseOKPacket(packetOk, data, c.enableQueryInfo, c.isSessionTrack())
case ErrPacket:
// Error
return buf, 0, ParseErrorPacket(data)
case 0xfb:
// Local infile
return buf, 0, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented")
}
n, pos, ok := readLenEncInt(data, 0)
if !ok {
return buf, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")
}
if pos != len(data) {
return buf, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extra data in COM_QUERY response")
}
return buf, int(n), nil
}

//
// Server side methods.
//
Expand Down
Loading

0 comments on commit ab80841

Please sign in to comment.