Skip to content

Commit

Permalink
GODRIVER-3009 Fix concurrent panic in struct codec. (#1477)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu authored Nov 30, 2023
1 parent d33301f commit 868e9c0
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 6 deletions.
9 changes: 3 additions & 6 deletions bson/bsoncodec/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) {
// If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for
// concurrent use by multiple goroutines after all codecs and encoders are registered.
func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
if valueType == nil {
return nil, ErrNoEncoder{Type: valueType}
}
enc, found := r.lookupTypeEncoder(valueType)
if found {
if enc == nil {
Expand All @@ -400,15 +403,10 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
if found {
return r.typeEncoders.LoadOrStore(valueType, enc), nil
}
if valueType == nil {
r.storeTypeEncoder(valueType, nil)
return nil, ErrNoEncoder{Type: valueType}
}

if v, ok := r.kindEncoders.Load(valueType.Kind()); ok {
return r.storeTypeEncoder(valueType, v), nil
}
r.storeTypeEncoder(valueType, nil)
return nil, ErrNoEncoder{Type: valueType}
}

Expand Down Expand Up @@ -474,7 +472,6 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) {
if v, ok := r.kindDecoders.Load(valueType.Kind()); ok {
return r.storeTypeDecoder(valueType, v), nil
}
r.storeTypeDecoder(valueType, nil)
return nil, ErrNoDecoder{Type: valueType}
}

Expand Down
30 changes: 30 additions & 0 deletions bson/bsoncodec/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,36 @@ func TestRegistry(t *testing.T) {
})
})
}
t.Run("nil type", func(t *testing.T) {
t.Parallel()

t.Run("Encoder", func(t *testing.T) {
t.Parallel()

wanterr := ErrNoEncoder{Type: reflect.TypeOf(nil)}

gotcodec, goterr := reg.LookupEncoder(nil)
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
}
if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("codecs did not match: got %#v, want nil", gotcodec)
}
})
t.Run("Decoder", func(t *testing.T) {
t.Parallel()

wanterr := ErrNilType

gotcodec, goterr := reg.LookupDecoder(nil)
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
}
if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("codecs did not match: got %v: want nil", gotcodec)
}
})
})
// lookup a type whose pointer implements an interface and expect that the registered hook is
// returned
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions bson/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"reflect"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -380,3 +381,19 @@ func TestMarshalExtJSONIndent(t *testing.T) {
})
}
}

func TestMarshalConcurrently(t *testing.T) {
t.Parallel()

const size = 10_000

wg := sync.WaitGroup{}
wg.Add(size)
for i := 0; i < size; i++ {
go func() {
defer wg.Done()
_, _ = Marshal(struct{ LastError error }{})
}()
}
wg.Wait()
}
19 changes: 19 additions & 0 deletions bson/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package bson
import (
"math/rand"
"reflect"
"sync"
"testing"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
Expand Down Expand Up @@ -773,3 +774,21 @@ func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) {
})
}
}

func TestUnmarshalConcurrently(t *testing.T) {
t.Parallel()

const size = 10_000

data := []byte{16, 0, 0, 0, 10, 108, 97, 115, 116, 101, 114, 114, 111, 114, 0, 0}
wg := sync.WaitGroup{}
wg.Add(size)
for i := 0; i < size; i++ {
go func() {
defer wg.Done()
var res struct{ LastError error }
_ = Unmarshal(data, &res)
}()
}
wg.Wait()
}

0 comments on commit 868e9c0

Please sign in to comment.