Skip to content

Commit

Permalink
Merge pull request #5335 from oasisprotocol/kostko/feature/cbor-rpc-r…
Browse files Browse the repository at this point in the history
…elax

go/common/cbor: Relax CBOR decoding for gRPC/RHP endpoints
  • Loading branch information
kostko authored Aug 3, 2023
2 parents 3293d29 + f76e2b5 commit 11b1662
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 3 deletions.
1 change: 1 addition & 0 deletions .changelog/5335.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
go/common/cbor: Relax CBOR decoding for gRPC/RHP endpoints
30 changes: 30 additions & 0 deletions go/common/cbor/cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,20 @@ var (
MaxMapPairs: 2147483647, // Maximum allowed.
}

// decOptionsRPC are decoding options for gRPC endpoints. They are only used when explicitly
// requested by using the UnmarshalRPC method.
decOptionsRPC = cbor.DecOptions{
DupMapKey: cbor.DupMapKeyEnforcedAPF,
IndefLength: cbor.IndefLengthForbidden,
TagsMd: cbor.TagsForbidden,
MaxArrayElements: 10_000_000, // Usually limited by blob size limits anyway.
MaxMapPairs: 10_000_000, // Usually limited by blob size limits anyway.
}

encMode cbor.EncMode
decMode cbor.DecMode
decModeTrusted cbor.DecMode
decModeRPC cbor.DecMode
)

func init() {
Expand All @@ -69,6 +80,9 @@ func init() {
if decModeTrusted, err = decOptionsTrusted.DecMode(); err != nil {
panic(err)
}
if decModeRPC, err = decOptionsRPC.DecMode(); err != nil {
panic(err)
}
}

// Marshal serializes a given type into a CBOR byte vector.
Expand Down Expand Up @@ -100,6 +114,17 @@ func UnmarshalTrusted(data []byte, dst interface{}) error {
return decModeTrusted.Unmarshal(data, dst)
}

// UnmarshalRPC deserializes a CBOR byte vector into a given type.
//
// This method is suitable for RPC endpoints as it relaxes some decoding restrictions.
func UnmarshalRPC(data []byte, dst interface{}) error {
if data == nil {
return nil
}

return decModeRPC.Unmarshal(data, dst)
}

// MustUnmarshal deserializes a CBOR byte vector into a given type.
// Panics if unmarshal fails.
func MustUnmarshal(data []byte, dst interface{}) {
Expand All @@ -117,3 +142,8 @@ func NewEncoder(w io.Writer) *cbor.Encoder {
func NewDecoder(r io.Reader) *cbor.Decoder {
return decMode.NewDecoder(r)
}

// NewDecoderRPC creates a new CBOR decoder with relaxed decoding restrictions.
func NewDecoderRPC(r io.Reader) *cbor.Decoder {
return decModeRPC.NewDecoder(r)
}
13 changes: 12 additions & 1 deletion go/common/cbor/cbor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,17 @@ func TestEncoderDecoder(t *testing.T) {
err = dec.Decode(&x)
require.NoError(err, "Decode")
require.EqualValues(42, x, "decoded value should be correct")

err = enc.Encode(32)
require.NoError(err, "Encode")

dec = NewDecoderRPC(&buf)
err = dec.Decode(&x)
require.NoError(err, "Decode")
require.EqualValues(32, x, "decoded value should be correct")
}

func TestDecodeUnknowField(t *testing.T) {
func TestDecodeUnknownField(t *testing.T) {
require := require.New(t)

type a struct {
Expand All @@ -69,4 +77,7 @@ func TestDecodeUnknowField(t *testing.T) {

err = UnmarshalTrusted(raw, &dec)
require.NoError(err, "unknown fields from trusted sources should pass")

err = UnmarshalRPC(raw, &dec)
require.NoError(err, "unknown fields from RPC should pass")
}
2 changes: 1 addition & 1 deletion go/common/cbor/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (c *MessageReader) Read(msg interface{}) error {

// Decode message bytes.
r := io.LimitReader(c.reader, int64(length))
dec := NewDecoder(r)
dec := NewDecoderRPC(r)
if err := dec.Decode(msg); err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go/common/grpc/cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (c *CBORCodec) Marshal(v interface{}) ([]byte, error) {
}

func (c *CBORCodec) Unmarshal(data []byte, v interface{}) error {
return cbor.Unmarshal(data, v)
return cbor.UnmarshalRPC(data, v)
}

func (c *CBORCodec) Name() string {
Expand Down

0 comments on commit 11b1662

Please sign in to comment.