Skip to content

Commit

Permalink
Merge branch 'master' into GODRIVER-2935
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Sep 7, 2023
2 parents 25b7685 + 84a4385 commit c0085e1
Show file tree
Hide file tree
Showing 13 changed files with 421 additions and 98 deletions.
35 changes: 8 additions & 27 deletions .evergreen/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -437,37 +437,18 @@ functions:
make -s evg-test-enterprise-auth
run-atlas-test:
- command: ec2.assume_role
params:
role_arn: "${aws_test_secrets_role}"
- command: shell.exec
type: test
params:
shell: "bash"
working_dir: src/go.mongodb.org/mongo-driver
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
script: |
# DO NOT ECHO WITH XTRACE
if [ "Windows_NT" = "$OS" ]; then
export GOPATH=$(cygpath -w $(dirname $(dirname $(dirname `pwd`))))
export GOCACHE=$(cygpath -w "$(pwd)/.cache")
else
export GOPATH=$(dirname $(dirname $(dirname `pwd`)))
export GOCACHE="$(pwd)/.cache"
fi;
export GOPATH="$GOPATH"
export GOROOT="${GO_DIST}"
export GOCACHE="$GOCACHE"
export PATH="${GCC_PATH}:${GO_DIST}/bin:$PATH"
export ATLAS_FREE="${atlas_free_tier_uri}"
export ATLAS_REPLSET="${atlas_replica_set_uri}"
export ATLAS_SHARD="${atlas_sharded_uri}"
export ATLAS_TLS11="${atlas_tls_v11_uri}"
export ATLAS_TLS12="${atlas_tls_v12_uri}"
export ATLAS_FREE_SRV="${atlas_free_tier_uri_srv}"
export ATLAS_REPLSET_SRV="${atlas_replica_set_uri_srv}"
export ATLAS_SHARD_SRV="${atlas_sharded_uri_srv}"
export ATLAS_TLS11_SRV="${atlas_tls_v11_uri_srv}"
export ATLAS_TLS12_SRV="${atlas_tls_v12_uri_srv}"
export ATLAS_SERVERLESS="${atlas_serverless_uri}"
export ATLAS_SERVERLESS_SRV="${atlas_serverless_uri_srv}"
make -s evg-test-atlas
${PREPARE_SHELL}
bash etc/run-atlas-test.sh
run-ocsp-test:
- command: shell.exec
Expand Down Expand Up @@ -2228,7 +2209,7 @@ tasks:
export AZUREKMS_VMNAME=${AZUREKMS_VMNAME}
echo '${testazurekms_privatekey}' > /tmp/testazurekms.prikey
export AZUREKMS_PRIVATEKEYPATH=/tmp/testazurekms.prikey
AZUREKMS_CMD="LD_LIBRARY_PATH=./install/libmongocrypt/lib MONGODB_URI='mongodb://localhost:27017' PROVIDER='azure' ./testkms" $DRIVERS_TOOLS/.evergreen/csfle/azurekms/run-command.sh
AZUREKMS_CMD="LD_LIBRARY_PATH=./install/libmongocrypt/lib MONGODB_URI='mongodb://localhost:27017' PROVIDER='azure' AZUREKMS_KEY_NAME='${AZUREKMS_KEY_NAME}' AZUREKMS_KEY_VAULT_ENDPOINT='${AZUREKMS_KEY_VAULT_ENDPOINT}' ./testkms" $DRIVERS_TOOLS/.evergreen/csfle/azurekms/run-command.sh
- name: "testazurekms-fail-task"
# testazurekms-fail-task runs without environment variables.
Expand All @@ -2250,7 +2231,7 @@ tasks:
LD_LIBRARY_PATH=./install/libmongocrypt/lib \
MONGODB_URI='mongodb://localhost:27017' \
EXPECT_ERROR='unable to retrieve azure credentials' \
PROVIDER='azure' \
PROVIDER='azure' AZUREKMS_KEY_NAME='${AZUREKMS_KEY_NAME}' AZUREKMS_KEY_VAULT_ENDPOINT='${AZUREKMS_KEY_VAULT_ENDPOINT}' \
./testkms
- name: "test-fuzz"
Expand Down
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ internal/test/compilecheck/compilecheck.so

# Ignore api report files
api-report.md
api-report.txt
api-report.txt

# Ignore secrets files
secrets-expansion.yml
secrets-export.sh
5 changes: 0 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
ATLAS_URIS = "$(ATLAS_FREE)" "$(ATLAS_REPLSET)" "$(ATLAS_SHARD)" "$(ATLAS_TLS11)" "$(ATLAS_TLS12)" "$(ATLAS_FREE_SRV)" "$(ATLAS_REPLSET_SRV)" "$(ATLAS_SHARD_SRV)" "$(ATLAS_TLS11_SRV)" "$(ATLAS_TLS12_SRV)" "$(ATLAS_SERVERLESS)" "$(ATLAS_SERVERLESS_SRV)"
TEST_TIMEOUT = 1800

### Utility targets. ###
Expand Down Expand Up @@ -128,10 +127,6 @@ build-aws-ecs-test:
evg-test:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s -p 1 ./... >> test.suite

.PHONY: evg-test-atlas
evg-test-atlas:
go run ./cmd/testatlas/main.go $(ATLAS_URIS)

.PHONY: evg-test-atlas-data-lake
evg-test-atlas-data-lake:
ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestUnifiedSpecs/atlas-data-lake-testing >> spec_test.suite
Expand Down
6 changes: 6 additions & 0 deletions cmd/testatlas/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ func main() {
uris := flag.Args()
ctx := context.Background()

fmt.Printf("Running atlas tests for %d uris\n", len(uris))

for idx, uri := range uris {
fmt.Printf("Running test %d\n", idx)

// Set a low server selection timeout so we fail fast if there are errors.
clientOpts := options.Client().
ApplyURI(uri).
Expand All @@ -41,6 +45,8 @@ func main() {
panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err))
}
}

fmt.Println("Finished!")
}

func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
Expand Down
20 changes: 18 additions & 2 deletions cmd/testkms/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ var datakeyopts = map[string]primitive.M{
"key": "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
},
"azure": bson.M{
"keyVaultEndpoint": "https://keyvault-drivers-2411.vault.azure.net/keys/",
"keyName": "KEY-NAME",
"keyVaultEndpoint": "",
"keyName": "",
},
"gcp": bson.M{
"projectId": "devprod-drivers",
Expand Down Expand Up @@ -53,6 +53,20 @@ func main() {
default:
ok = true
}
if provider == "azure" {
azureKmsKeyName := os.Getenv("AZUREKMS_KEY_NAME")
azureKmsKeyVaultEndpoint := os.Getenv("AZUREKMS_KEY_VAULT_ENDPOINT")
if azureKmsKeyName == "" {
fmt.Println("ERROR: Please set required AZUREKMS_KEY_NAME environment variable.")
ok = false
}
if azureKmsKeyVaultEndpoint == "" {
fmt.Println("ERROR: Please set required AZUREKMS_KEY_VAULT_ENDPOINT environment variable.")
ok = false
}
datakeyopts["azure"]["keyName"] = azureKmsKeyName
datakeyopts["azure"]["keyVaultEndpoint"] = azureKmsKeyVaultEndpoint
}
if !ok {
providers := make([]string, 0, len(datakeyopts))
for p := range datakeyopts {
Expand All @@ -63,6 +77,8 @@ func main() {
fmt.Println("- MONGODB_URI as a MongoDB URI. Example: 'mongodb://localhost:27017'")
fmt.Println("- EXPECT_ERROR as an optional expected error substring.")
fmt.Println("- PROVIDER as a KMS provider, which supports:", strings.Join(providers, ", "))
fmt.Println("- AZUREKMS_KEY_NAME as the Azure key name. Required if PROVIDER=azure.")
fmt.Println("- AZUREKMS_KEY_VAULT_ENDPOINT as the Azure key name. Required if PROVIDER=azure.")
os.Exit(1)
}

Expand Down
12 changes: 12 additions & 0 deletions etc/get_aws_secrets.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env bash
# get-aws-secrets
# Gets AWS secrets from the vault
set -eu

if [ -z "$DRIVERS_TOOLS" ]; then
echo "Please define DRIVERS_TOOLS variable"
exit 1
fi

bash $DRIVERS_TOOLS/.evergreen/auth_aws/setup_secrets.sh $@
. ./secrets-export.sh
11 changes: 11 additions & 0 deletions etc/run-atlas-test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# run-atlas-test
# Run atlas connectivity tests.
set -eu
set +x

# Get the atlas secrets.
. etc/get_aws_secrets.sh drivers/atlas_connect

echo "Running cmd/testatlas/main.go"
go run ./cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS"
3 changes: 0 additions & 3 deletions mongo/integration/client_side_encryption_prose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ func TestClientSideEncryptionProse(t *testing.T) {
}
})
mt.Run("4. bson size limits", func(mt *mtest.T) {
// TODO(GODRIVER-2872): Fix and unskip this test case.
mt.Skip("Test fails frequently, skipping. See GODRIVER-2872")

kmsProviders := map[string]map[string]interface{}{
"local": {
"key": localMasterKey,
Expand Down
2 changes: 0 additions & 2 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ func TestClient(t *testing.T) {
"expected security field to be type %v, got %v", bson.TypeMaxKey, security.Type)
_, found := security.Document().LookupErr("SSLServerSubjectName")
assert.Nil(mt, found, "SSLServerSubjectName not found in result")
_, found = security.Document().LookupErr("SSLServerHasCertificateAuthority")
assert.Nil(mt, found, "SSLServerHasCertificateAuthority not found in result")
})
mt.RunOpts("x509", mtest.NewOptions().Auth(true).SSL(true), func(mt *mtest.T) {
testCases := []struct {
Expand Down
114 changes: 70 additions & 44 deletions x/mongo/driver/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,72 @@ type CompressionOpts struct {
UncompressedSize int32
}

var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
// destination writer. It panics on any errors and should only be used at
// package initialization time.
func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
if err != nil {
panic(err)
}
return enc
}

var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
0: nil, // zstd.speedNotSet
zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest),
zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault),
zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression),
}

func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
if v, ok := zstdEncoders.Load(level); ok {
return v.(*zstd.Encoder), nil
}
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
if err != nil {
return nil, err
if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
return zstdEncoders[level], nil
}
zstdEncoders.Store(level, encoder)
return encoder, nil
// The level is outside the expected range, return an error.
return nil, fmt.Errorf("invalid zstd compression level: %d", level)
}

var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
// zlibEncodersOffset is the offset into the zlibEncoders array for a given
// compression level.
const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2

var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool

func getZlibEncoder(level int) (*zlibEncoder, error) {
if v, ok := zlibEncoders.Load(level); ok {
return v.(*zlibEncoder), nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
return enc, nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
}
enc := &zlibEncoder{writer: writer, level: level}
return enc, nil
}
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
zlibEncoders.Store(level, encoder)
// The level is outside the expected range, return an error.
return nil, fmt.Errorf("invalid zlib compression level: %d", level)
}

return encoder, nil
func putZlibEncoder(enc *zlibEncoder) {
if enc != nil {
zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
}
}

type zlibEncoder struct {
mu sync.Mutex
writer *zlib.Writer
buf *bytes.Buffer
buf bytes.Buffer
level int
}

func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
e.mu.Lock()
defer e.mu.Unlock()
defer putZlibEncoder(e)

e.buf.Reset()
e.writer.Reset(e.buf)
e.writer.Reset(&e.buf)

_, err := e.writer.Write(src)
if err != nil {
Expand Down Expand Up @@ -105,8 +129,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
}
}

var zstdReaderPool = sync.Pool{
New: func() interface{} {
r, _ := zstd.NewReader(nil)
return r
},
}

// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
Expand All @@ -117,34 +148,29 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er
} else if int32(l) != opts.UncompressedSize {
return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
}
uncompressed = make([]byte, opts.UncompressedSize)
return snappy.Decode(uncompressed, in)
out := make([]byte, opts.UncompressedSize)
return snappy.Decode(out, in)
case wiremessage.CompressorZLib:
r, err := zlib.NewReader(bytes.NewReader(in))
if err != nil {
return nil, err
}
defer func() {
err = r.Close()
}()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
out := make([]byte, opts.UncompressedSize)
if _, err := io.ReadFull(r, out); err != nil {
return nil, err
}
return uncompressed, nil
case wiremessage.CompressorZstd:
r, err := zstd.NewReader(bytes.NewBuffer(in))
if err != nil {
return nil, err
}
defer r.Close()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
if err := r.Close(); err != nil {
return nil, err
}
return uncompressed, nil
return out, nil
case wiremessage.CompressorZstd:
buf := make([]byte, 0, opts.UncompressedSize)
// Using a pool here is about ~20% faster
// than using a single global zstd.Reader
r := zstdReaderPool.Get().(*zstd.Decoder)
out, err := r.DecodeAll(in, buf)
zstdReaderPool.Put(r)
return out, err
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
Expand Down
Loading

0 comments on commit c0085e1

Please sign in to comment.