Skip to content

Commit

Permalink
Merge pull request #3 from keploy/newpg
Browse files Browse the repository at this point in the history
feat: make datarow values editable
  • Loading branch information
Sarthak160 authored Mar 2, 2024
2 parents 33b2755 + 72646b8 commit 51d7ac4
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 20 deletions.
2 changes: 1 addition & 1 deletion authentication_md5_password.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte `json:"salt" yaml:"salt"`
Salt [4]byte `json:"salt" yaml:"salt,omitempty,flow"`
}

// Backend identifies this message as sendable by the PostgreSQL backend.
Expand Down
6 changes: 4 additions & 2 deletions command_complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import (
)

type CommandComplete struct {
CommandTag []byte `json:"command_tag" yaml:"command_tag"`
CommandTag []byte `json:"-" yaml:"-"`
CommandTagType string `json:"command_tag_type" yaml:"command_tag_type"`
}

// Backend identifies this message as sendable by the PostgreSQL backend.
Expand All @@ -24,6 +25,7 @@ func (dst *CommandComplete) Decode(src []byte) error {
}

dst.CommandTag = src[:idx]
dst.CommandTagType = string(dst.CommandTag)

return nil
}
Expand All @@ -34,7 +36,7 @@ func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)

src.CommandTag = []byte(src.CommandTagType)
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)

Expand Down
86 changes: 70 additions & 16 deletions data_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"unicode"

"github.com/jackc/pgio"
Expand Down Expand Up @@ -52,7 +53,7 @@ func (dst *DataRow) Decode(src []byte) error {

// null
if msgSize == -1 {
dst.Values[i] = nil
dst.Values[i] = nil //[]byte{255, 255, 255, 255}
} else {
if len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "DataRow"}
Expand All @@ -62,17 +63,22 @@ func (dst *DataRow) Decode(src []byte) error {
rp += msgSize
}
}
// fmt.Println("DECODED VALUES", dst.Values)
dst.RowValues = []string{}
for _, v := range dst.Values {
// fmt.Println(string(v))
bufStr := ""
// if v == nil {
// bufStr = "NIL"
// dst.RowValues = append(dst.RowValues, bufStr)
// continue
// }
if !IsAsciiPrintable(string(v)) {
bufStr = base64.StdEncoding.EncodeToString(v)
bufStr = "base64:" + bufStr
dst.RowValues = append(dst.RowValues, bufStr)
// println("NON PRINTABLE STRING FOUND !")
continue
bufStr = "b64:" + base64.StdEncoding.EncodeToString(v)
} else {
bufStr = string(v)
}
dst.RowValues = append(dst.RowValues, string(v))
dst.RowValues = append(dst.RowValues, bufStr)
}

return nil
Expand All @@ -84,38 +90,86 @@ func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
src.Values = stringsToBytesArray(src.RowValues)
// src.Values = stringsToBytesArray(src.RowValues)

// epoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)

// // Given date
// givenDate := time.Date(2021, 7, 14, 0, 0, 0, 0, time.UTC)

// Calculate the difference in days
// difference := givenDate.Sub(epoch).Hours() / 24

// Prepare a byte slice to hold the binary representation
// buf := make([]byte, 4)
// binary.BigEndian.PutUint32(buf, uint32(difference))

// // Output the difference in days and the binary representation
// fmt.Printf("Days difference: %d\n", int(difference))
// fmt.Printf("Binary representation: %v\n", buf)
if src.RowValues != nil && len(src.RowValues) > 0 {
// fmt.Println("SRC ROW VALUES *** * ** * * ** ", src.RowValues)
src.Values = stringsToBytesArray(src.RowValues)
}
// fmt.Println("SRC VALUES", src.Values)
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
if v == nil || len(v) == 0{
dst = pgio.AppendInt32(dst, -1)
continue
}



dst = pgio.AppendInt32(dst, int32(len(v)))
dst = append(dst, v...)
}

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

// src.RowValues = []string{}
// src.Values = [][]byte{}
return dst
}

func stringsToBytesArray(strArray []string) [][]byte {
byteArray := make([][]byte, len(strArray))

for i, str := range strArray {
// if str[0] == 'b' && str[1] == 'a' && str[2] == 's' && str[3] == 'e' {
// // slic and decode
// byteArray[i], _ = base64.StdEncoding.DecodeString(str[7:])
// continue
// }
byteArray[i] = []byte(str)
if str == "NIL" {
fmt.Println("NIL AHHAHAHAHAHHAHAHAHAH")
byteArray[i] = []byte{255, 255, 255, 255}
continue
}
byt, isValidBase64 := isValidBase64(str)
if isValidBase64 && byt != nil {
byteArray[i] = byt
} else if IsAsciiPrintable(str) {
byteArray[i] = []byte(str)
}
}

return byteArray
}

func isValidBase64(s string) ([]byte, bool) {
// check if it contains b64:
// then slice the string and decode
if len(s) < 5 {
return nil, false
}
if s[:4] != "b64:" {
return nil, false
}
s = s[4:]

val, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, false
}
fmt.Println("VALUEEEEE", val, "HURRAY", s)
return val, true
}

// checks if s is ascii and printable, aka doesn't include tab, backspace, etc.
func IsAsciiPrintable(s string) bool {
for _, r := range s {
Expand Down
6 changes: 5 additions & 1 deletion row_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ const (
)

type FieldDescription struct {
Name []byte `json:"name" yaml:"name"`
Name []byte `json:"-" yaml:"-"`
FieldName string `json:"name" yaml:"name"`
TableOID uint32 `json:"table_oid" yaml:"table_oid"`
TableAttributeNumber uint16 `json:"table_attribute_number" yaml:"table_attribute_number"`
DataTypeOID uint32 `json:"data_type_oid" yaml:"data_type_oid"`
Expand Down Expand Up @@ -71,6 +72,7 @@ func (dst *RowDescription) Decode(src []byte) error {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fd.Name = src[rp : rp+idx]
fd.FieldName = string(fd.Name)
rp += idx + 1

// Since buf.Next() doesn't return an error if we hit the end of the buffer
Expand Down Expand Up @@ -107,6 +109,8 @@ func (src *RowDescription) Encode(dst []byte) []byte {

dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields {

fd.Name = []byte(fd.FieldName)
dst = append(dst, fd.Name...)
dst = append(dst, 0)

Expand Down

0 comments on commit 51d7ac4

Please sign in to comment.