From fbbafce9003526517843eed7d24e22074e3634de Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Tue, 5 Sep 2023 17:59:03 -0600 Subject: [PATCH 01/22] GODRIVER-2810 Switch to polling monitoring when running within a FaaS environment --- internal/driverutil/hello.go | 101 ++++++++++++++ .../driverutil/{const.go => operation.go} | 0 internal/test/faas/awslambda/mongodb/main.go | 25 +++- mongo/integration/unified/client_entity.go | 2 + mongo/options/clientoptions.go | 36 +++++ mongo/options/clientoptions_test.go | 46 +++++++ .../unified/serverMonitoringMode.json | 129 ++++++++++++++++++ .../unified/serverMonitoringMode.yml | 67 +++++++++ testdata/uri-options/sdam-options.json | 46 +++++++ testdata/uri-options/sdam-options.yml | 35 +++++ x/mongo/driver/connstring/connstring.go | 19 +++ .../driver/connstring/connstring_spec_test.go | 2 + x/mongo/driver/operation/hello.go | 116 ++-------------- x/mongo/driver/operation/hello_test.go | 125 ++++++++--------- x/mongo/driver/topology/server.go | 32 ++++- x/mongo/driver/topology/server_options.go | 38 ++++-- x/mongo/driver/topology/topology_options.go | 1 + 17 files changed, 634 insertions(+), 186 deletions(-) create mode 100644 internal/driverutil/hello.go rename internal/driverutil/{const.go => operation.go} (100%) create mode 100644 testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json create mode 100644 testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml create mode 100644 testdata/uri-options/sdam-options.json create mode 100644 testdata/uri-options/sdam-options.yml diff --git a/internal/driverutil/hello.go b/internal/driverutil/hello.go new file mode 100644 index 0000000000..782f1f31d7 --- /dev/null +++ b/internal/driverutil/hello.go @@ -0,0 +1,101 @@ +package driverutil + +import ( + "os" + "strings" +) + +const AwsLambdaPrefix = "AWS_Lambda_" + +const ( + // FaaS environment variable names + EnvVarAWSExecutionEnv = "AWS_EXECUTION_ENV" + EnvVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" + EnvVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" + EnvVarKService = "K_SERVICE" + EnvVarFunctionName = "FUNCTION_NAME" + EnvVarVercel = "VERCEL" +) + +const ( + // FaaS environment variable names + EnvVarAWSRegion = "AWS_REGION" + EnvVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" + EnvVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" + EnvVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" + EnvVarFunctionRegion = "FUNCTION_REGION" + EnvVarVercelRegion = "VERCEL_REGION" +) + +const ( + // FaaS environment names used by the client + EnvNameAWSLambda = "aws.lambda" + EnvNameAzureFunc = "azure.func" + EnvNameGCPFunc = "gcp.func" + EnvNameVercel = "vercel" +) + +// GetFaasEnvName parses the FaaS environment variable name and returns the +// corresponding name used by the client. If none of the variables or variables +// for multiple names are populated the client.env value MUST be entirely +// omitted. When variables for multiple "client.env.name" values are present, +// "vercel" takes precedence over "aws.lambda"; any other combination MUST cause +// "client.env" to be entirely omitted. +func GetFaasEnvName() string { + envVars := []string{ + EnvVarAWSExecutionEnv, + EnvVarAWSLambdaRuntimeAPI, + EnvVarFunctionsWorkerRuntime, + EnvVarKService, + EnvVarFunctionName, + EnvVarVercel, + } + + // If none of the variables are populated the client.env value MUST be + // entirely omitted. + names := make(map[string]struct{}) + + for _, envVar := range envVars { + val := os.Getenv(envVar) + if val == "" { + continue + } + + var name string + + switch envVar { + case EnvVarAWSExecutionEnv: + if !strings.HasPrefix(val, AwsLambdaPrefix) { + continue + } + + name = EnvNameAWSLambda + case EnvVarAWSLambdaRuntimeAPI: + name = EnvNameAWSLambda + case EnvVarFunctionsWorkerRuntime: + name = EnvNameAzureFunc + case EnvVarKService, EnvVarFunctionName: + name = EnvNameGCPFunc + case EnvVarVercel: + // "vercel" takes precedence over "aws.lambda". + delete(names, EnvNameAWSLambda) + + name = EnvNameVercel + } + + names[name] = struct{}{} + if len(names) > 1 { + // If multiple names are populated the client.env value + // MUST be entirely omitted. + names = nil + + break + } + } + + for name := range names { + return name + } + + return "" +} diff --git a/internal/driverutil/const.go b/internal/driverutil/operation.go similarity index 100% rename from internal/driverutil/const.go rename to internal/driverutil/operation.go diff --git a/internal/test/faas/awslambda/mongodb/main.go b/internal/test/faas/awslambda/mongodb/main.go index a0c55f9085..f9d8765550 100644 --- a/internal/test/faas/awslambda/mongodb/main.go +++ b/internal/test/faas/awslambda/mongodb/main.go @@ -27,11 +27,12 @@ const timeout = 60 * time.Second // event durations, as well as the number of heartbeats, commands, and open // conections. type eventListener struct { - commandCount int - commandDuration int64 - heartbeatCount int - heartbeatDuration int64 - openConnections int + commandCount int + commandDuration int64 + heartbeatAwaitedCount int + heartbeatCount int + heartbeatDuration int64 + openConnections int } // commandMonitor initializes an event.CommandMonitor that will count the number @@ -61,11 +62,19 @@ func (listener *eventListener) serverMonitor() *event.ServerMonitor { succeeded := func(e *event.ServerHeartbeatSucceededEvent) { listener.heartbeatCount++ listener.heartbeatDuration += e.DurationNanos + + if e.Awaited { + listener.heartbeatAwaitedCount++ + } } failed := func(e *event.ServerHeartbeatFailedEvent) { listener.heartbeatCount++ listener.heartbeatDuration += e.DurationNanos + + if e.Awaited { + listener.heartbeatAwaitedCount++ + } } return &event.ServerMonitor{ @@ -150,6 +159,12 @@ func handler(ctx context.Context, request events.APIGatewayProxyRequest) (events return gateway500(), fmt.Errorf("failed to delete: %w", err) } + // Driver must switch to polling monitoring when running within a FaaS + // environment. + if listener.heartbeatAwaitedCount > 0 { + return gateway500(), fmt.Errorf("FaaS environment fialed to switch to polling") + } + var avgCmdDur float64 if count := listener.commandCount; count != 0 { avgCmdDur = float64(listener.commandDuration) / float64(count) diff --git a/mongo/integration/unified/client_entity.go b/mongo/integration/unified/client_entity.go index e63c891039..41fd95b3a1 100644 --- a/mongo/integration/unified/client_entity.go +++ b/mongo/integration/unified/client_entity.go @@ -583,6 +583,8 @@ func setClientOptionsFromURIOptions(clientOpts *options.ClientOptions, uriOpts b clientOpts.SetTimeout(time.Duration(value.(int32)) * time.Millisecond) case "serverselectiontimeoutms": clientOpts.SetServerSelectionTimeout(time.Duration(value.(int32)) * time.Millisecond) + case "servermonitoringmode": + clientOpts.SetServerMonitoringMode(value.(string)) default: return fmt.Errorf("unrecognized URI option %s", key) } diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index f014da418b..21b5395183 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -33,6 +33,12 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) +const ( + ServerMonitoringModeAuto = connstring.ServerMonitoringModeAuto + ServerMonitoringModePoll = connstring.ServerMonitoringModePoll + ServerMonitoringModeStream = connstring.ServerMonitoringModeStream +) + // ContextDialer is an interface that can be implemented by types that can create connections. It should be used to // provide a custom dialer when configuring a Client. // @@ -206,6 +212,7 @@ type ClientOptions struct { RetryReads *bool RetryWrites *bool ServerAPIOptions *ServerAPIOptions + ServerMonitoringMode *string ServerSelectionTimeout *time.Duration SRVMaxHosts *int SRVServiceName *string @@ -306,6 +313,11 @@ func (c *ClientOptions) validate() error { return connstring.ErrSRVMaxHostsWithLoadBalanced } } + + if mode := c.ServerMonitoringMode; mode != nil && !connstring.IsValidServerMonitoringMode(*mode) { + return fmt.Errorf("invalid server monitoring mode: %q", *mode) + } + return nil } @@ -945,6 +957,27 @@ func (c *ClientOptions) SetServerAPIOptions(opts *ServerAPIOptions) *ClientOptio return c } +// SetServerMonitoringMode specifies the server monitoring protocol to use. +// +// Valid modes are: +// - "stream": The client will use a streaming protocol when the server +// supports it. The streaming protocol optimally reduces the time it takes +// for a client to discover server state changes. +// - "poll": The client will periodically check the server using a hello or +// legacy hello command and then sleep for heartbeatFrequencyMS milliseconds +// before running another check. +// - "auto": The client will behave like "poll" mode when running on a FaaS +// (Function as a Service) platform, or like "stream" mode otherwise. The +// client detects its execution environment by following the rules for +// generating the "client.env" handshake metadata field as specified in the +// MongoDB Handshake specification. This is the deafult mode. +func (c *ClientOptions) SetServerMonitoringMode(mode string) *ClientOptions { + fmt.Println("mode: ", mode) + c.ServerMonitoringMode = &mode + + return c +} + // SetSRVMaxHosts specifies the maximum number of SRV results to randomly select during polling. To limit the number // of hosts selected in SRV discovery, this function must be called before ApplyURI. This can also be set through // the "srvMaxHosts" URI option. @@ -1107,6 +1140,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.LoggerOptions != nil { c.LoggerOptions = opt.LoggerOptions } + if opt.ServerMonitoringMode != nil { + c.ServerMonitoringMode = opt.ServerMonitoringMode + } } return c diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index 7c148ca0bd..e729f5f569 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -760,6 +760,52 @@ func TestClientOptions(t *testing.T) { }) } }) + t.Run("server monitoring mode validation", func(t *testing.T) { + t.Parallel() + + // := fmt.Errorf("invalid server monitoring mode: %q", *mode) + + testCases := []struct { + name string + opts *ClientOptions + err error + }{ + { + name: "undefined", + opts: Client(), + err: nil, + }, + { + name: "auto", + opts: Client().SetServerMonitoringMode(ServerMonitoringModeAuto), + err: nil, + }, + { + name: "poll", + opts: Client().SetServerMonitoringMode(ServerMonitoringModePoll), + err: nil, + }, + { + name: "stream", + opts: Client().SetServerMonitoringMode(ServerMonitoringModeStream), + err: nil, + }, + { + name: "invalid", + opts: Client().SetServerMonitoringMode("invalid"), + err: errors.New("invalid server monitoring mode: \"invalid\""), + }, + } + + for _, tc := range testCases { + tc := tc // Capture the range variable + + t.Run(tc.name, func(t *testing.T) { + err := tc.opts.Validate() + assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) + }) + } + }) } func createCertPool(t *testing.T, paths ...string) *x509.CertPool { diff --git a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json new file mode 100644 index 0000000000..520635ba2c --- /dev/null +++ b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json @@ -0,0 +1,129 @@ +{ + "description": "serverMonitoringMode", + "schemaVersion": "1.3", + "tests": [ + { + "description": "connect with serverMonitoringMode=auto", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "serverMonitoringMode": "auto" + } + } + }, + { + "database": { + "id": "dbSdamModeAuto", + "client": "client0", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "dbSdamModeAuto", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + } + ] + }, + { + "description": "connect with serverMonitoringMode=stream", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client1", + "uriOptions": { + "serverMonitoringMode": "stream" + } + } + }, + { + "database": { + "id": "dbSdamModeStream", + "client": "client1", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "dbSdamModeStream", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + } + ] + }, + { + "description": "connect with serverMonitoringMode=poll", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client2", + "uriOptions": { + "serverMonitoringMode": "poll" + } + } + }, + { + "database": { + "id": "dbSdamModePoll", + "client": "client2", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "dbSdamModePoll", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + } + ] + } + ] +} diff --git a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml new file mode 100644 index 0000000000..bb2b2053b7 --- /dev/null +++ b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml @@ -0,0 +1,67 @@ +description: serverMonitoringMode + +schemaVersion: "1.3" + +tests: + - description: "connect with serverMonitoringMode=auto" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: &client0 client0 + uriOptions: + serverMonitoringMode: "auto" + - database: + id: &dbSdamModeAuto dbSdamModeAuto + client: *client0 + databaseName: sdam-tests + - name: runCommand + object: *dbSdamModeAuto + arguments: + commandName: ping + command: { ping: 1 } + expectResult: { ok: 1 } + + - description: "connect with serverMonitoringMode=stream" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: &client1 client1 + uriOptions: + serverMonitoringMode: "stream" + - database: + id: &dbSdamModeStream dbSdamModeStream + client: *client1 + databaseName: sdam-tests + - name: runCommand + object: *dbSdamModeStream + arguments: + commandName: ping + command: { ping: 1 } + expectResult: { ok: 1 } + + - description: "connect with serverMonitoringMode=poll" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: &client2 client2 + uriOptions: + serverMonitoringMode: "poll" + - database: + id: &dbSdamModePoll dbSdamModePoll + client: *client2 + databaseName: sdam-tests + - name: runCommand + object: *dbSdamModePoll + arguments: + commandName: ping + command: { ping: 1 } + expectResult: { ok: 1 } diff --git a/testdata/uri-options/sdam-options.json b/testdata/uri-options/sdam-options.json new file mode 100644 index 0000000000..673f5607ee --- /dev/null +++ b/testdata/uri-options/sdam-options.json @@ -0,0 +1,46 @@ +{ + "tests": [ + { + "description": "serverMonitoringMode=auto", + "uri": "mongodb://example.com/?serverMonitoringMode=auto", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "serverMonitoringMode": "auto" + } + }, + { + "description": "serverMonitoringMode=stream", + "uri": "mongodb://example.com/?serverMonitoringMode=stream", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "serverMonitoringMode": "stream" + } + }, + { + "description": "serverMonitoringMode=poll", + "uri": "mongodb://example.com/?serverMonitoringMode=poll", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "serverMonitoringMode": "poll" + } + }, + { + "description": "invalid serverMonitoringMode", + "uri": "mongodb://example.com/?serverMonitoringMode=invalid", + "valid": true, + "warning": true, + "hosts": null, + "auth": null, + "options": {} + } + ] +} diff --git a/testdata/uri-options/sdam-options.yml b/testdata/uri-options/sdam-options.yml new file mode 100644 index 0000000000..8f72ff4098 --- /dev/null +++ b/testdata/uri-options/sdam-options.yml @@ -0,0 +1,35 @@ +tests: + - description: "serverMonitoringMode=auto" + uri: "mongodb://example.com/?serverMonitoringMode=auto" + valid: true + warning: false + hosts: ~ + auth: ~ + options: + serverMonitoringMode: "auto" + + - description: "serverMonitoringMode=stream" + uri: "mongodb://example.com/?serverMonitoringMode=stream" + valid: true + warning: false + hosts: ~ + auth: ~ + options: + serverMonitoringMode: "stream" + + - description: "serverMonitoringMode=poll" + uri: "mongodb://example.com/?serverMonitoringMode=poll" + valid: true + warning: false + hosts: ~ + auth: ~ + options: + serverMonitoringMode: "poll" + + - description: "invalid serverMonitoringMode" + uri: "mongodb://example.com/?serverMonitoringMode=invalid" + valid: true + warning: true + hosts: ~ + auth: ~ + options: {} diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 983c1dab22..136a4f6730 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -21,6 +21,12 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) +const ( + ServerMonitoringModeAuto = "auto" + ServerMonitoringModePoll = "poll" + ServerMonitoringModeStream = "stream" +) + var ( // ErrLoadBalancedWithMultipleHosts is returned when loadBalanced=true is // specified in a URI with multiple hosts. @@ -125,6 +131,7 @@ type ConnString struct { MaxStalenessSet bool ReplicaSet string Scheme string + ServerMonitoringMode string ServerSelectionTimeout time.Duration ServerSelectionTimeoutSet bool SocketTimeout time.Duration @@ -621,6 +628,12 @@ func (p *parser) addHost(host string) error { return nil } +func IsValidServerMonitoringMode(mode string) bool { + return mode == ServerMonitoringModeAuto || + mode == ServerMonitoringModeStream || + mode == ServerMonitoringModePoll +} + func (p *parser) addOption(pair string) error { kv := strings.SplitN(pair, "=", 2) if len(kv) != 2 || kv[0] == "" { @@ -823,6 +836,12 @@ func (p *parser) addOption(pair string) error { } p.RetryReadsSet = true + case "servermonitoringmode": + if !IsValidServerMonitoringMode(value) { + return fmt.Errorf("invalid value for %q: %q", key, value) + } + + p.ServerMonitoringMode = value case "serverselectiontimeoutms": n, err := strconv.Atoi(value) if err != nil || n < 0 { diff --git a/x/mongo/driver/connstring/connstring_spec_test.go b/x/mongo/driver/connstring/connstring_spec_test.go index a5f646297c..699ae16bdb 100644 --- a/x/mongo/driver/connstring/connstring_spec_test.go +++ b/x/mongo/driver/connstring/connstring_spec_test.go @@ -286,6 +286,8 @@ func verifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map require.Equal(t, value, float64(cs.ZstdLevel)) case "tlsdisableocspendpointcheck": require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck) + case "servermonitoringmode": + require.Equal(t, value, cs.ServerMonitoringMode) default: opt, ok := cs.UnknownOptions[key] require.True(t, ok) diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 3cfa2d450a..4d20b3239e 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -12,10 +12,10 @@ import ( "os" "runtime" "strconv" - "strings" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/bsonutil" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/handshake" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" @@ -31,7 +31,6 @@ import ( // sharded clusters is 512. const maxClientMetadataSize = 512 -const awsLambdaPrefix = "AWS_Lambda_" const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. @@ -125,99 +124,6 @@ func (h *Hello) Result(addr address.Address) description.Server { return description.NewServer(addr, bson.Raw(h.res)) } -const ( - // FaaS environment variable names - envVarAWSExecutionEnv = "AWS_EXECUTION_ENV" - envVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" - envVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" - envVarKService = "K_SERVICE" - envVarFunctionName = "FUNCTION_NAME" - envVarVercel = "VERCEL" -) - -const ( - // FaaS environment variable names - envVarAWSRegion = "AWS_REGION" - envVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" - envVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" - envVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" - envVarFunctionRegion = "FUNCTION_REGION" - envVarVercelRegion = "VERCEL_REGION" -) - -const ( - // FaaS environment names used by the client - envNameAWSLambda = "aws.lambda" - envNameAzureFunc = "azure.func" - envNameGCPFunc = "gcp.func" - envNameVercel = "vercel" -) - -// getFaasEnvName parses the FaaS environment variable name and returns the -// corresponding name used by the client. If none of the variables or variables -// for multiple names are populated the client.env value MUST be entirely -// omitted. When variables for multiple "client.env.name" values are present, -// "vercel" takes precedence over "aws.lambda"; any other combination MUST cause -// "client.env" to be entirely omitted. -func getFaasEnvName() string { - envVars := []string{ - envVarAWSExecutionEnv, - envVarAWSLambdaRuntimeAPI, - envVarFunctionsWorkerRuntime, - envVarKService, - envVarFunctionName, - envVarVercel, - } - - // If none of the variables are populated the client.env value MUST be - // entirely omitted. - names := make(map[string]struct{}) - - for _, envVar := range envVars { - val := os.Getenv(envVar) - if val == "" { - continue - } - - var name string - - switch envVar { - case envVarAWSExecutionEnv: - if !strings.HasPrefix(val, awsLambdaPrefix) { - continue - } - - name = envNameAWSLambda - case envVarAWSLambdaRuntimeAPI: - name = envNameAWSLambda - case envVarFunctionsWorkerRuntime: - name = envNameAzureFunc - case envVarKService, envVarFunctionName: - name = envNameGCPFunc - case envVarVercel: - // "vercel" takes precedence over "aws.lambda". - delete(names, envNameAWSLambda) - - name = envNameVercel - } - - names[name] = struct{}{} - if len(names) > 1 { - // If multiple names are populated the client.env value - // MUST be entirely omitted. - names = nil - - break - } - } - - for name := range names { - return name - } - - return "" -} - // appendClientAppName appends the application metadata to the dst. It is the // responsibility of the caller to check that this appending does not cause dst // to exceed any size limitations. @@ -255,7 +161,7 @@ func appendClientEnv(dst []byte, omitNonName, omitDoc bool) ([]byte, error) { return dst, nil } - name := getFaasEnvName() + name := driverutil.GetFaasEnvName() if name == "" { return dst, nil } @@ -307,15 +213,15 @@ func appendClientEnv(dst []byte, omitNonName, omitDoc bool) ([]byte, error) { if !omitNonName { switch name { - case envNameAWSLambda: - dst = addMem(envVarAWSLambdaFunctionMemorySize) - dst = addRegion(envVarAWSRegion) - case envNameGCPFunc: - dst = addMem(envVarFunctionMemoryMB) - dst = addRegion(envVarFunctionRegion) - dst = addTimeout(envVarFunctionTimeoutSec) - case envNameVercel: - dst = addRegion(envVarVercelRegion) + case driverutil.EnvNameAWSLambda: + dst = addMem(driverutil.EnvVarAWSLambdaFunctionMemorySize) + dst = addRegion(driverutil.EnvVarAWSRegion) + case driverutil.EnvNameGCPFunc: + dst = addMem(driverutil.EnvVarFunctionMemoryMB) + dst = addRegion(driverutil.EnvVarFunctionRegion) + dst = addTimeout(driverutil.EnvVarFunctionTimeoutSec) + case driverutil.EnvNameVercel: + dst = addRegion(driverutil.EnvVarVercelRegion) } } diff --git a/x/mongo/driver/operation/hello_test.go b/x/mongo/driver/operation/hello_test.go index 61ba2fde01..69bfc020d4 100644 --- a/x/mongo/driver/operation/hello_test.go +++ b/x/mongo/driver/operation/hello_test.go @@ -13,6 +13,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -54,18 +55,18 @@ func encodeWithCallback(t *testing.T, cb func(int, []byte) ([]byte, error)) bson // ensure that the local environment does not effect the outcome of a unit // test. func clearTestEnv(t *testing.T) { - t.Setenv(envVarAWSExecutionEnv, "") - t.Setenv(envVarAWSLambdaRuntimeAPI, "") - t.Setenv(envVarFunctionsWorkerRuntime, "") - t.Setenv(envVarKService, "") - t.Setenv(envVarFunctionName, "") - t.Setenv(envVarVercel, "") - t.Setenv(envVarAWSRegion, "") - t.Setenv(envVarAWSLambdaFunctionMemorySize, "") - t.Setenv(envVarFunctionMemoryMB, "") - t.Setenv(envVarFunctionTimeoutSec, "") - t.Setenv(envVarFunctionRegion, "") - t.Setenv(envVarVercelRegion, "") + t.Setenv(driverutil.EnvVarAWSExecutionEnv, "") + t.Setenv(driverutil.EnvVarAWSLambdaRuntimeAPI, "") + t.Setenv(driverutil.EnvVarFunctionsWorkerRuntime, "") + t.Setenv(driverutil.EnvVarKService, "") + t.Setenv(driverutil.EnvVarFunctionName, "") + t.Setenv(driverutil.EnvVarVercel, "") + t.Setenv(driverutil.EnvVarAWSRegion, "") + t.Setenv(driverutil.EnvVarAWSLambdaFunctionMemorySize, "") + t.Setenv(driverutil.EnvVarFunctionMemoryMB, "") + t.Setenv(driverutil.EnvVarFunctionTimeoutSec, "") + t.Setenv(driverutil.EnvVarFunctionRegion, "") + t.Setenv(driverutil.EnvVarVercelRegion, "") } func TestAppendClientName(t *testing.T) { @@ -159,32 +160,32 @@ func TestAppendClientEnv(t *testing.T) { { name: "aws only", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", }, want: []byte(`{"env":{"name":"aws.lambda"}}`), }, { name: "aws mem only", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaFunctionMemorySize: "1024", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSLambdaFunctionMemorySize: "1024", }, want: []byte(`{"env":{"name":"aws.lambda","memory_mb":1024}}`), }, { name: "aws region only", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSRegion: "us-east-2", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSRegion: "us-east-2", }, want: []byte(`{"env":{"name":"aws.lambda","region":"us-east-2"}}`), }, { name: "aws mem and region", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaFunctionMemorySize: "1024", - envVarAWSRegion: "us-east-2", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSLambdaFunctionMemorySize: "1024", + driverutil.EnvVarAWSRegion: "us-east-2", }, want: []byte(`{"env":{"name":"aws.lambda","memory_mb":1024,"region":"us-east-2"}}`), }, @@ -192,50 +193,50 @@ func TestAppendClientEnv(t *testing.T) { name: "aws mem and region with omit fields", omitEnvFields: true, env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaFunctionMemorySize: "1024", - envVarAWSRegion: "us-east-2", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSLambdaFunctionMemorySize: "1024", + driverutil.EnvVarAWSRegion: "us-east-2", }, want: []byte(`{"env":{"name":"aws.lambda"}}`), }, { name: "gcp only", env: map[string]string{ - envVarKService: "servicename", + driverutil.EnvVarKService: "servicename", }, want: []byte(`{"env":{"name":"gcp.func"}}`), }, { name: "gcp mem", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionMemoryMB: "1024", + driverutil.EnvVarKService: "servicename", + driverutil.EnvVarFunctionMemoryMB: "1024", }, want: []byte(`{"env":{"name":"gcp.func","memory_mb":1024}}`), }, { name: "gcp region", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionRegion: "us-east-2", + driverutil.EnvVarKService: "servicename", + driverutil.EnvVarFunctionRegion: "us-east-2", }, want: []byte(`{"env":{"name":"gcp.func","region":"us-east-2"}}`), }, { name: "gcp timeout", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionTimeoutSec: "1", + driverutil.EnvVarKService: "servicename", + driverutil.EnvVarFunctionTimeoutSec: "1", }, want: []byte(`{"env":{"name":"gcp.func","timeout_sec":1}}`), }, { name: "gcp mem, region, and timeout", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionTimeoutSec: "1", - envVarFunctionRegion: "us-east-2", - envVarFunctionMemoryMB: "1024", + driverutil.EnvVarKService: "servicename", + driverutil.EnvVarFunctionTimeoutSec: "1", + driverutil.EnvVarFunctionRegion: "us-east-2", + driverutil.EnvVarFunctionMemoryMB: "1024", }, want: []byte(`{"env":{"name":"gcp.func","memory_mb":1024,"region":"us-east-2","timeout_sec":1}}`), }, @@ -243,32 +244,32 @@ func TestAppendClientEnv(t *testing.T) { name: "gcp mem, region, and timeout with omit fields", omitEnvFields: true, env: map[string]string{ - envVarKService: "servicename", - envVarFunctionTimeoutSec: "1", - envVarFunctionRegion: "us-east-2", - envVarFunctionMemoryMB: "1024", + driverutil.EnvVarKService: "servicename", + driverutil.EnvVarFunctionTimeoutSec: "1", + driverutil.EnvVarFunctionRegion: "us-east-2", + driverutil.EnvVarFunctionMemoryMB: "1024", }, want: []byte(`{"env":{"name":"gcp.func"}}`), }, { name: "vercel only", env: map[string]string{ - envVarVercel: "1", + driverutil.EnvVarVercel: "1", }, want: []byte(`{"env":{"name":"vercel"}}`), }, { name: "vercel region", env: map[string]string{ - envVarVercel: "1", - envVarVercelRegion: "us-east-2", + driverutil.EnvVarVercel: "1", + driverutil.EnvVarVercelRegion: "us-east-2", }, want: []byte(`{"env":{"name":"vercel","region":"us-east-2"}}`), }, { name: "azure only", env: map[string]string{ - envVarFunctionsWorkerRuntime: "go1.x", + driverutil.EnvVarFunctionsWorkerRuntime: "go1.x", }, want: []byte(`{"env":{"name":"azure.func"}}`), }, @@ -405,9 +406,9 @@ func TestEncodeClientMetadata(t *testing.T) { } // Set environment variables to add `env` field to handshake. - t.Setenv(envVarAWSLambdaRuntimeAPI, "lambda") - t.Setenv(envVarAWSLambdaFunctionMemorySize, "123") - t.Setenv(envVarAWSRegion, "us-east-2") + t.Setenv(driverutil.EnvVarAWSLambdaRuntimeAPI, "lambda") + t.Setenv(driverutil.EnvVarAWSLambdaFunctionMemorySize, "123") + t.Setenv(driverutil.EnvVarAWSRegion, "us-east-2") t.Run("nothing is omitted", func(t *testing.T) { got, err := encodeClientMetadata("foo", maxClientMetadataSize) @@ -418,7 +419,7 @@ func TestEncodeClientMetadata(t *testing.T) { Driver: &driver{Name: driverName, Version: version.Driver}, OS: &dist{Type: runtime.GOOS, Architecture: runtime.GOARCH}, Platform: runtime.Version(), - Env: &env{Name: envNameAWSLambda, MemoryMB: 123, Region: "us-east-2"}, + Env: &env{Name: driverutil.EnvNameAWSLambda, MemoryMB: 123, Region: "us-east-2"}, }) assertDocsEqual(t, got, want) @@ -437,7 +438,7 @@ func TestEncodeClientMetadata(t *testing.T) { Driver: &driver{Name: driverName, Version: version.Driver}, OS: &dist{Type: runtime.GOOS, Architecture: runtime.GOARCH}, Platform: runtime.Version(), - Env: &env{Name: envNameAWSLambda}, + Env: &env{Name: driverutil.EnvNameAWSLambda}, }) assertDocsEqual(t, got, want) @@ -453,7 +454,7 @@ func TestEncodeClientMetadata(t *testing.T) { require.NoError(t, err, "error constructing env template: %v", err) // Calculate what the env.name costs. - ndst := bsoncore.AppendStringElement(nil, "name", envNameAWSLambda) + ndst := bsoncore.AppendStringElement(nil, "name", driverutil.EnvNameAWSLambda) // Environment sub name. envSubName := len(edst) - len(ndst) @@ -466,7 +467,7 @@ func TestEncodeClientMetadata(t *testing.T) { Driver: &driver{Name: driverName, Version: version.Driver}, OS: &dist{Type: runtime.GOOS}, Platform: runtime.Version(), - Env: &env{Name: envNameAWSLambda}, + Env: &env{Name: driverutil.EnvNameAWSLambda}, }) assertDocsEqual(t, got, want) @@ -552,38 +553,38 @@ func TestParseFaasEnvName(t *testing.T) { { name: "one aws", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", }, - want: envNameAWSLambda, + want: driverutil.EnvNameAWSLambda, }, { name: "both aws options", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaRuntimeAPI: "hello", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarAWSLambdaRuntimeAPI: "hello", }, - want: envNameAWSLambda, + want: driverutil.EnvNameAWSLambda, }, { name: "multiple variables", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarFunctionsWorkerRuntime: "hello", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarFunctionsWorkerRuntime: "hello", }, want: "", }, { name: "vercel and aws lambda", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarVercel: "hello", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_foo", + driverutil.EnvVarVercel: "hello", }, - want: envNameVercel, + want: driverutil.EnvNameVercel, }, { name: "invalid aws prefix", env: map[string]string{ - envVarAWSExecutionEnv: "foo", + driverutil.EnvVarAWSExecutionEnv: "foo", }, want: "", }, @@ -623,14 +624,14 @@ func BenchmarkClientMetadtaLargeEnv(b *testing.B) { b.ReportAllocs() b.ResetTimer() - b.Setenv(envNameAWSLambda, "foo") + b.Setenv(driverutil.EnvNameAWSLambda, "foo") str := "" for i := 0; i < 512; i++ { str += "a" } - b.Setenv(envVarAWSLambdaRuntimeAPI, str) + b.Setenv(driverutil.EnvVarAWSLambdaRuntimeAPI, str) b.RunParallel(func(pb *testing.PB) { for pb.Next() { diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 88b93b15e6..85dcf589d4 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -17,10 +17,12 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -131,7 +133,12 @@ type updateTopologyCallback func(description.Server) description.Server // ConnectServer creates a new Server and then initializes it using the // Connect method. -func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) { +func ConnectServer( + addr address.Address, + updateCallback updateTopologyCallback, + topologyID primitive.ObjectID, + opts ...ServerOption, +) (*Server, error) { srvr := NewServer(addr, topologyID, opts...) err := srvr.Connect(updateCallback) if err != nil { @@ -785,10 +792,29 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello { return operation. NewHello(). ClusterClock(s.cfg.clock). - Deployment(driver.SingleConnectionDeployment{conn}). + Deployment(driver.SingleConnectionDeployment{C: conn}). ServerAPI(s.cfg.serverAPI) } +// isStreamable returns whether or not we can use the streaming protocol to +// optimally reduces the time it takes for a client to discover server state +// changes. Streaming must be disabled if any of the following are true: +// +// - the client is configured with serverMonitoringMode=poll [P], or +// - the client is configured with serverMonitoringMode=auto [A] and a FaaS +// platform is detected [F], or +// - the server does not support streaming (eg MongoDB <4.4) [S]. +// +// I.e, streaming must be disabled if: P ∨ (A ∧ F) ∨ (~S) ≡ ~P ∧ (~A ∨ ~F) ∧ S +func isStreamable(previousDesc description.Server, srv *Server) bool { + srvMonitoringMode := srv.cfg.serverMonitoringMode + faas := driverutil.GetFaasEnvName() + + return srvMonitoringMode != connstring.ServerMonitoringModePoll && // P + (srvMonitoringMode != connstring.ServerMonitoringModeAuto || faas == "") && // (~A ∨ ~F) + previousDesc.TopologyVersion != nil // S +} + func (s *Server) check() (description.Server, error) { var descPtr *description.Server var err error @@ -824,7 +850,7 @@ func (s *Server) check() (description.Server, error) { heartbeatConn := initConnection{s.conn} baseOperation := s.createBaseOperation(heartbeatConn) previousDescription := s.Description() - streamable := previousDescription.TopologyVersion != nil + streamable := isStreamable(previousDescription, s) s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) switch { diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index 4272b3f751..4504a25355 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -14,23 +14,25 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { - clock *session.ClusterClock - compressionOpts []string - connectionOpts []ConnectionOption - appname string - heartbeatInterval time.Duration - heartbeatTimeout time.Duration - serverMonitor *event.ServerMonitor - registry *bsoncodec.Registry - monitoringDisabled bool - serverAPI *driver.ServerAPIOptions - loadBalanced bool + clock *session.ClusterClock + compressionOpts []string + connectionOpts []ConnectionOption + appname string + heartbeatInterval time.Duration + heartbeatTimeout time.Duration + serverMonitoringMode string + serverMonitor *event.ServerMonitor + registry *bsoncodec.Registry + monitoringDisabled bool + serverAPI *driver.ServerAPIOptions + loadBalanced bool // Connection pool options. maxConns uint64 @@ -202,3 +204,17 @@ func withLogger(fn func() *logger.Logger) ServerOption { cfg.logger = fn() } } + +// withServerMonitoringMode configures the mode (stream, poll, or auto) to use +// for monitoring. +func withServerMonitoringMode(mode *string) ServerOption { + return func(cfg *serverConfig) { + if mode != nil { + cfg.serverMonitoringMode = *mode + + return + } + + cfg.serverMonitoringMode = connstring.ServerMonitoringModeAuto + } +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 8deb614815..7858643dfd 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -362,6 +362,7 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, serverOpts = append( serverOpts, withLogger(func() *logger.Logger { return lgr }), + withServerMonitoringMode(co.ServerMonitoringMode), ) cfgp.logger = lgr From 2a886726cb4fb72aabae4dea493b0375a9fc4ac0 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Tue, 5 Sep 2023 20:32:49 -0600 Subject: [PATCH 02/22] GODRIVER-2972 Fix wiremessage RequestID race in operation.Execute --- x/mongo/driver/operation.go | 23 ++++++++++++++--------- x/mongo/driver/operation_test.go | 6 +++--- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 90573daa53..8e52773503 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -322,7 +322,7 @@ func (op Operation) shouldEncrypt() bool { } // selectServer handles performing server selection for an operation. -func (op Operation) selectServer(ctx context.Context) (Server, error) { +func (op Operation) selectServer(ctx context.Context, requestID int32) (Server, error) { if err := op.Validate(); err != nil { return nil, err } @@ -340,14 +340,14 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) { } ctx = logger.WithOperationName(ctx, op.Name) - ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID()) + ctx = logger.WithOperationID(ctx, requestID) return op.Deployment.SelectServer(ctx, selector) } // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation. -func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) { - server, err := op.selectServer(ctx) +func (op Operation) getServerAndConnection(ctx context.Context, requestID int32) (Server, Connection, error) { + server, err := op.selectServer(ctx, requestID) if err != nil { if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { @@ -530,11 +530,11 @@ func (op Operation) Execute(ctx context.Context) error { } }() for { - wiremessage.NextRequestID() + requestID := wiremessage.NextRequestID() // If the server or connection are nil, try to select a new server and get a new connection. if srvr == nil || conn == nil { - srvr, conn, err = op.getServerAndConnection(ctx) + srvr, conn, err = op.getServerAndConnection(ctx, requestID) if err != nil { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server @@ -629,7 +629,7 @@ func (op Operation) Execute(ctx context.Context) error { } var startedInfo startedInformation - *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn) + *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err @@ -1103,8 +1103,13 @@ func (op Operation) addBatchArray(dst []byte) []byte { return dst } -func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, +func (op Operation) createMsgWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, conn Connection, + requestID int32, ) ([]byte, startedInformation, error) { var info startedInformation var flags wiremessage.MsgFlag @@ -1120,7 +1125,7 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, flags |= wiremessage.ExhaustAllowed } - info.requestID = wiremessage.CurrentRequestID() + info.requestID = requestID wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg) dst = wiremessage.AppendMsgFlags(dst, flags) // Body diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index d4c5a1b6a0..20a5c2066d 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -62,7 +62,7 @@ func TestOperation(t *testing.T) { t.Run("selectServer", func(t *testing.T) { t.Run("returns validation error", func(t *testing.T) { op := &Operation{} - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1) if err == nil { t.Error("Expected a validation error from selectServer, but got ") } @@ -76,7 +76,7 @@ func TestOperation(t *testing.T) { Database: "testing", Selector: want, } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1) noerr(t, err) got := d.params.selector if !cmp.Equal(got, want) { @@ -90,7 +90,7 @@ func TestOperation(t *testing.T) { Deployment: d, Database: "testing", } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1) noerr(t, err) if d.params.selector == nil { t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed .") From 7a61bdb482f2c0f6e9090d7eaf0522024b085104 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:09:16 -0600 Subject: [PATCH 03/22] GODRIVER-2810 Correct expected order for test assertion --- mongo/integration/handshake_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mongo/integration/handshake_test.go b/mongo/integration/handshake_test.go index 3e9989158e..8d706062d2 100644 --- a/mongo/integration/handshake_test.go +++ b/mongo/integration/handshake_test.go @@ -193,7 +193,7 @@ func TestHandshakeProse(t *testing.T) { hello = "hello" } - assert.Equal(mt, pair.CommandName, hello, "expected and actual command name at index %d are different", idx) + assert.Equal(mt, hello, pair.CommandName, "expected and actual command name at index %d are different", idx) sent := pair.Sent From d5a6c9991323c27b97021b27bc11e62ad9a24343 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:28:41 -0600 Subject: [PATCH 04/22] GODRIVER-2810 Fix hello test faas getter reference --- x/mongo/driver/operation/hello_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/mongo/driver/operation/hello_test.go b/x/mongo/driver/operation/hello_test.go index 69bfc020d4..3def602086 100644 --- a/x/mongo/driver/operation/hello_test.go +++ b/x/mongo/driver/operation/hello_test.go @@ -598,7 +598,7 @@ func TestParseFaasEnvName(t *testing.T) { t.Setenv(key, value) } - got := getFaasEnvName() + got := driverutil.GetFaasEnvName() if got != test.want { t.Errorf("parseFaasEnvName(%s) = %s, want %s", test.name, got, test.want) } From 8277d886d4f16c203da6a1e301127635b106cb82 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 12:04:28 -0600 Subject: [PATCH 05/22] GODRIVER-2935 Use OP_QUERY in connection handshakes --- x/mongo/driver/operation.go | 111 ++++++++++++++++++++++++- x/mongo/driver/topology/server_test.go | 39 ++++++++- 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 90573daa53..5a26d17ef4 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -19,6 +19,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/csot" @@ -629,7 +630,7 @@ func (op Operation) Execute(ctx context.Context) error { } var startedInfo startedInformation - *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn) + *wm, startedInfo, err = op.createWireMessage(ctx, (*wm)[:0], desc, maxTimeMS, conn) if err != nil { return err @@ -1103,6 +1104,85 @@ func (op Operation) addBatchArray(dst []byte) []byte { return dst } +func (op Operation) createLegacyHandshakeWireMessage( + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, +) ([]byte, startedInformation, error) { + var info startedInformation + flags := op.secondaryOK(desc) + var wmindex int32 + info.requestID = wiremessage.NextRequestID() + wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) + dst = wiremessage.AppendQueryFlags(dst, flags) + + dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} + + // FullCollectionName + dst = append(dst, op.Database...) + dst = append(dst, dollarCmd[:]...) + dst = append(dst, 0x00) + dst = wiremessage.AppendQueryNumberToSkip(dst, 0) + dst = wiremessage.AppendQueryNumberToReturn(dst, -1) + + wrapper := int32(-1) + rp, err := op.createReadPref(desc, true) + if err != nil { + return dst, info, err + } + if len(rp) > 0 { + wrapper, dst = bsoncore.AppendDocumentStart(dst) + dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query") + } + idx, dst := bsoncore.AppendDocumentStart(dst) + dst, err = op.CommandFn(dst, desc) + if err != nil { + return dst, info, err + } + + if op.Batches != nil && len(op.Batches.Current) > 0 { + dst = op.addBatchArray(dst) + } + + dst, err = op.addReadConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addWriteConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addSession(dst, desc) + if err != nil { + return dst, info, err + } + + dst = op.addClusterTime(dst, desc) + dst = op.addServerAPI(dst) + // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly + // specifies the default behavior of no timeout server-side. + if maxTimeMS > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS)) + } + + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + // Command monitoring only reports the document inside $query + info.cmd = dst[idx:] + + if len(rp) > 0 { + var err error + dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) + dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + if err != nil { + return dst, info, err + } + } + + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil +} + func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, conn Connection, ) ([]byte, startedInformation, error) { @@ -1186,6 +1266,33 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil } +// isLegacyHandshake returns "true" if the operation is the first message of +// the initial handshake and should use a legacy hello. The requirement for +// using a legacy hello as defined by the specifications is as follows: +// +// > If server API version is not requested and loadBalanced: False, drivers +// > MUST use legacy hello for the first message of the initial handshake with +// > the OP_QUERY protocol +func isLegacyHandshake(op Operation, desc description.SelectedServer) bool { + isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0 + + return desc.Kind != description.LoadBalanced && op.ServerAPI == nil && isInitialHandshake +} + +func (op Operation) createWireMessage( + ctx context.Context, + dst []byte, + desc description.SelectedServer, + maxTimeMS uint64, + conn Connection, +) ([]byte, startedInformation, error) { + if isLegacyHandshake(op, desc) { + return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) + } + + return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn) +} + // addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document // has already been added and does not add the final 0 byte. func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { @@ -1830,6 +1937,8 @@ func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInfor logger.KeyFailure, formattedReply)...) } + //fmt.Println("->", redactFinishedInformationResponse(info)) + // If the finished event cannot be published, return early. if !op.canPublishFinishedEvent(info) { return diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index a2abd1fb1f..ba92b6dd94 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -31,9 +31,11 @@ import ( "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) type channelNetConnDialer struct{} @@ -1207,12 +1209,41 @@ func TestServer_ProcessError(t *testing.T) { func includesClientMetadata(t *testing.T, wm []byte) bool { t.Helper() - doc, err := drivertest.GetCommandFromMsgWireMessage(wm) - assert.NoError(t, err) + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + t.Fatal("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + t.Fatal("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + t.Fatal("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + t.Fatal("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + t.Fatal("could not read numberToReturn") + } + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + t.Fatal("could not read query") + } - _, err = doc.LookupErr("client") + if _, err := query.LookupErr("client"); err == nil { + return true + } + if _, err := query.LookupErr("$query", "client"); err == nil { + return true + } - return err == nil + return false } // processErrorTestConn is a driver.Connection implementation used by tests From 728f27729f77acc92a130a935d164bab04604d66 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 12:38:15 -0600 Subject: [PATCH 06/22] GODRIVER-2935 Update sent_message logic to include OP_QUERY for hello --- mongo/integration/mtest/sent_message.go | 65 +++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/mongo/integration/mtest/sent_message.go b/mongo/integration/mtest/sent_message.go index d36075bf81..6b96e061bc 100644 --- a/mongo/integration/mtest/sent_message.go +++ b/mongo/integration/mtest/sent_message.go @@ -37,6 +37,8 @@ type sentMsgParseFn func([]byte) (*SentMessage, error) func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { switch opcode { + case wiremessage.OpQuery: + return parseOpQuery, true case wiremessage.OpMsg: return parseSentOpMsg, true case wiremessage.OpCompressed: @@ -46,6 +48,69 @@ func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { } } +func parseOpQuery(wm []byte) (*SentMessage, error) { + var ok bool + + if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok { + return nil, errors.New("failed to read query flags") + } + if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok { + return nil, errors.New("failed to read full collection name") + } + if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok { + return nil, errors.New("failed to read number to skip") + } + if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok { + return nil, errors.New("failed to read number to return") + } + + query, wm, ok := wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("failed to read query") + } + + // If there is no read preference document, the command document is query. + // Otherwise, query is in the format {$query: , $readPreference: }. + commandDoc := query + var rpDoc bsoncore.Document + + dollarQueryVal, err := query.LookupErr("$query") + if err == nil { + commandDoc = dollarQueryVal.Document() + + rpVal, err := query.LookupErr("$readPreference") + if err != nil { + return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query) + } + rpDoc = rpVal.Document() + } + + // For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command + // document. Pull these sequences out into an ArrayStyle DocumentSequence. + var docSequence *bsoncore.DocumentSequence + cmdElems, _ := commandDoc.Elements() + for _, elem := range cmdElems { + switch elem.Key() { + case "documents", "updates", "deletes": + docSequence = &bsoncore.DocumentSequence{ + Style: bsoncore.ArrayStyle, + Data: elem.Value().Array(), + } + } + if docSequence != nil { + // There can only be one of these arrays in a well-formed command, so we exit the loop once one is found. + break + } + } + + sm := &SentMessage{ + Command: commandDoc, + ReadPreference: rpDoc, + DocumentSequence: docSequence, + } + return sm, nil +} + func parseSentMessage(wm []byte) (*SentMessage, error) { // Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing. _, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm) From 9fec743e3f867071ad82ef40f267d5ab82b678cd Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:16:01 -0600 Subject: [PATCH 07/22] GODRIVER-2935 Add LegacyHandshake operation logic --- x/mongo/driver/legacy.go | 1 + x/mongo/driver/operation.go | 13 +++---------- x/mongo/driver/operation/hello.go | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/x/mongo/driver/legacy.go b/x/mongo/driver/legacy.go index 9f3b8a39ac..c40f1f8091 100644 --- a/x/mongo/driver/legacy.go +++ b/x/mongo/driver/legacy.go @@ -19,4 +19,5 @@ const ( LegacyKillCursors LegacyListCollections LegacyListIndexes + LegacyHandshake ) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 5a26d17ef4..7e79ca485f 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1266,17 +1266,12 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil } -// isLegacyHandshake returns "true" if the operation is the first message of -// the initial handshake and should use a legacy hello. The requirement for -// using a legacy hello as defined by the specifications is as follows: -// -// > If server API version is not requested and loadBalanced: False, drivers -// > MUST use legacy hello for the first message of the initial handshake with -// > the OP_QUERY protocol +// isLegacyHandshake returns True if the operation is the first message of +// the initial handshake and should use a legacy hello. func isLegacyHandshake(op Operation, desc description.SelectedServer) bool { isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0 - return desc.Kind != description.LoadBalanced && op.ServerAPI == nil && isInitialHandshake + return op.Legacy == LegacyHandshake && isInitialHandshake } func (op Operation) createWireMessage( @@ -1937,8 +1932,6 @@ func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInfor logger.KeyFailure, formattedReply)...) } - //fmt.Println("->", redactFinishedInformationResponse(info)) - // If the finished event cannot be published, return early. if !op.canPublishFinishedEvent(info) { return diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 3cfa2d450a..f9200510e1 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -537,8 +537,16 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti return h.createOperation().ExecuteExhaust(ctx, conn) } +// isLegacyHandshake returns True if server API version is not requested and +// loadBalanced is False. If this is the case, then the drivers MUST use legacy +// hello for the first message of the initial handshake with the OP_QUERY +// protocol +func isLegacyHandshake(h *Hello) bool { + return h.serverAPI == nil && h.d.Kind() != description.LoadBalanced +} + func (h *Hello) createOperation() driver.Operation { - return driver.Operation{ + op := driver.Operation{ Clock: h.clock, CommandFn: h.command, Database: "admin", @@ -549,6 +557,12 @@ func (h *Hello) createOperation() driver.Operation { }, ServerAPI: h.serverAPI, } + + if isLegacyHandshake(h) { + op.Legacy = driver.LegacyHandshake + } + + return op } // GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant From fcc4db7d039a1cfd89e5585322d21b0dee62b7d6 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:17:22 -0600 Subject: [PATCH 08/22] GODRIVER-2810 Clean up handshake prose test --- internal/driverutil/hello.go | 43 +++++++--- mongo/integration/handshake_test.go | 103 ++++++++++-------------- mongo/options/clientoptions.go | 36 +++++---- x/mongo/driver/connstring/connstring.go | 20 ++++- x/mongo/driver/topology/server.go | 3 +- 5 files changed, 113 insertions(+), 92 deletions(-) diff --git a/internal/driverutil/hello.go b/internal/driverutil/hello.go index 782f1f31d7..b252d43b1f 100644 --- a/internal/driverutil/hello.go +++ b/internal/driverutil/hello.go @@ -9,30 +9,49 @@ const AwsLambdaPrefix = "AWS_Lambda_" const ( // FaaS environment variable names - EnvVarAWSExecutionEnv = "AWS_EXECUTION_ENV" - EnvVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" + + // EnvVarAWSExecutionEnv is the AWS Execution environment variable. + EnvVarAWSExecutionEnv = "AWS_EXECUTION_ENV" + // EnvVarAWSLambdaRuntimeAPI is the AWS Lambda runtime API variable. + EnvVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" + // EnvVarFunctionsWorkerRuntime is the functions worker runtime variable. EnvVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" - EnvVarKService = "K_SERVICE" - EnvVarFunctionName = "FUNCTION_NAME" - EnvVarVercel = "VERCEL" + // EnvVarKService is the K Service variable. + EnvVarKService = "K_SERVICE" + // EnvVarFunctionName is the function name variable. + EnvVarFunctionName = "FUNCTION_NAME" + // EnvVarVercel is the Vercel variable. + EnvVarVercel = "VERCEL" ) const ( // FaaS environment variable names - EnvVarAWSRegion = "AWS_REGION" + + // EnvVarAWSRegion is the AWS region variable. + EnvVarAWSRegion = "AWS_REGION" + // EnvVarAWSLambdaFunctionMemorySize is the AWS Lambda function memory size variable. EnvVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" - EnvVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" - EnvVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" - EnvVarFunctionRegion = "FUNCTION_REGION" - EnvVarVercelRegion = "VERCEL_REGION" + // EnvVarFunctionMemoryMB is the function memory in megabytes variable. + EnvVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" + // EnvVarFunctionTimeoutSec is the function timeout in seconds variable. + EnvVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" + // EnvVarFunctionRegion is the function region variable. + EnvVarFunctionRegion = "FUNCTION_REGION" + // EnvVarVercelRegion is the Vercel region variable. + EnvVarVercelRegion = "VERCEL_REGION" ) const ( // FaaS environment names used by the client + + // EnvNameAWSLambda is the AWS Lambda environment name. EnvNameAWSLambda = "aws.lambda" + // EnvNameAzureFunc is the Azure Function environment name. EnvNameAzureFunc = "azure.func" - EnvNameGCPFunc = "gcp.func" - EnvNameVercel = "vercel" + // EnvNameGCPFunc is the Google Cloud Function environment name. + EnvNameGCPFunc = "gcp.func" + // EnvNameVercel is the Vercel environment name. + EnvNameVercel = "vercel" ) // GetFaasEnvName parses the FaaS environment variable name and returns the diff --git a/mongo/integration/handshake_test.go b/mongo/integration/handshake_test.go index 8d706062d2..bb508bd900 100644 --- a/mongo/integration/handshake_test.go +++ b/mongo/integration/handshake_test.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/handshake" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/integration/mtest" @@ -51,31 +52,18 @@ func TestHandshakeProse(t *testing.T) { return elems } - const ( - envVarAWSExecutionEnv = "AWS_EXECUTION_ENV" - envVarAWSRegion = "AWS_REGION" - envVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" - envVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" - envVarKService = "K_SERVICE" - envVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" - envVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" - envVarFunctionRegion = "FUNCTION_REGION" - envVarVercel = "VERCEL" - envVarVercelRegion = "VERCEL_REGION" - ) - // Reset the environment variables to avoid environment namespace // collision. - t.Setenv(envVarAWSExecutionEnv, "") - t.Setenv(envVarFunctionsWorkerRuntime, "") - t.Setenv(envVarKService, "") - t.Setenv(envVarVercel, "") - t.Setenv(envVarAWSRegion, "") - t.Setenv(envVarAWSLambdaFunctionMemorySize, "") - t.Setenv(envVarFunctionMemoryMB, "") - t.Setenv(envVarFunctionTimeoutSec, "") - t.Setenv(envVarFunctionRegion, "") - t.Setenv(envVarVercelRegion, "") + t.Setenv(driverutil.EnvVarAWSExecutionEnv, "") + t.Setenv(driverutil.EnvVarFunctionsWorkerRuntime, "") + t.Setenv(driverutil.EnvVarKService, "") + t.Setenv(driverutil.EnvVarVercel, "") + t.Setenv(driverutil.EnvVarAWSRegion, "") + t.Setenv(driverutil.EnvVarAWSLambdaFunctionMemorySize, "") + t.Setenv(driverutil.EnvVarFunctionMemoryMB, "") + t.Setenv(driverutil.EnvVarFunctionTimeoutSec, "") + t.Setenv(driverutil.EnvVarFunctionRegion, "") + t.Setenv(driverutil.EnvVarVercelRegion, "") for _, test := range []struct { name string @@ -85,9 +73,9 @@ func TestHandshakeProse(t *testing.T) { { name: "1. valid AWS", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarAWSRegion: "us-east-2", - envVarAWSLambdaFunctionMemorySize: "1024", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_java8", + driverutil.EnvVarAWSRegion: "us-east-2", + driverutil.EnvVarAWSLambdaFunctionMemorySize: "1024", }, want: clientMetadata(bson.D{ {Key: "name", Value: "aws.lambda"}, @@ -98,7 +86,7 @@ func TestHandshakeProse(t *testing.T) { { name: "2. valid Azure", env: map[string]string{ - envVarFunctionsWorkerRuntime: "node", + driverutil.EnvVarFunctionsWorkerRuntime: "node", }, want: clientMetadata(bson.D{ {Key: "name", Value: "azure.func"}, @@ -107,10 +95,10 @@ func TestHandshakeProse(t *testing.T) { { name: "3. valid GCP", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionMemoryMB: "1024", - envVarFunctionTimeoutSec: "60", - envVarFunctionRegion: "us-central1", + driverutil.EnvVarKService: "servicename", + driverutil.EnvVarFunctionMemoryMB: "1024", + driverutil.EnvVarFunctionTimeoutSec: "60", + driverutil.EnvVarFunctionRegion: "us-central1", }, want: clientMetadata(bson.D{ {Key: "name", Value: "gcp.func"}, @@ -122,8 +110,8 @@ func TestHandshakeProse(t *testing.T) { { name: "4. valid Vercel", env: map[string]string{ - envVarVercel: "1", - envVarVercelRegion: "cdg1", + driverutil.EnvVarVercel: "1", + driverutil.EnvVarVercelRegion: "cdg1", }, want: clientMetadata(bson.D{ {Key: "name", Value: "vercel"}, @@ -133,16 +121,16 @@ func TestHandshakeProse(t *testing.T) { { name: "5. invalid multiple providers", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarFunctionsWorkerRuntime: "node", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_java8", + driverutil.EnvVarFunctionsWorkerRuntime: "node", }, want: clientMetadata(nil), }, { name: "6. invalid long string", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarAWSRegion: func() string { + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_java8", + driverutil.EnvVarAWSRegion: func() string { var s string for i := 0; i < 512; i++ { s += "a" @@ -157,8 +145,8 @@ func TestHandshakeProse(t *testing.T) { { name: "7. invalid wrong types", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarAWSLambdaFunctionMemorySize: "big", + driverutil.EnvVarAWSExecutionEnv: "AWS_Lambda_java8", + driverutil.EnvVarAWSLambdaFunctionMemorySize: "big", }, want: clientMetadata(bson.D{ {Key: "name", Value: "aws.lambda"}, @@ -167,7 +155,7 @@ func TestHandshakeProse(t *testing.T) { { name: "8. Invalid - AWS_EXECUTION_ENV does not start with \"AWS_Lambda_\"", env: map[string]string{ - envVarAWSExecutionEnv: "EC2", + driverutil.EnvVarAWSExecutionEnv: "EC2", }, want: clientMetadata(nil), }, @@ -184,32 +172,27 @@ func TestHandshakeProse(t *testing.T) { require.NoError(mt, err, "Ping error: %v", err) messages := mt.GetProxiedMessages() + handshakeMessage := messages[:1][0] - // First two messages are handshake messages - for idx, pair := range messages[:2] { - hello := handshake.LegacyHello - // Expect "hello" command name with API version. - if os.Getenv("REQUIRE_API_VERSION") == "true" { - hello = "hello" - } - - assert.Equal(mt, hello, pair.CommandName, "expected and actual command name at index %d are different", idx) + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + } - sent := pair.Sent + assert.Equal(mt, hello, handshakeMessage.CommandName) - // Lookup the "client" field in the command document. - clientVal, err := sent.Command.LookupErr("client") - require.NoError(mt, err, "expected command %s at index %d to contain client field", sent.Command, idx) + // Lookup the "client" field in the command document. + clientVal, err := handshakeMessage.Sent.Command.LookupErr("client") + require.NoError(mt, err, "expected command %s to contain client field", handshakeMessage.Sent.Command) - got, ok := clientVal.DocumentOK() - require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type) + got, ok := clientVal.DocumentOK() + require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type) - wantBytes, err := bson.Marshal(test.want) - require.NoError(mt, err, "error marshaling want document: %v", err) + wantBytes, err := bson.Marshal(test.want) + require.NoError(mt, err, "error marshaling want document: %v", err) - want := bsoncore.Document(wantBytes) - assert.Equal(mt, want, got, "want: %v, got: %v", want, got) - } + want := bsoncore.Document(wantBytes) + assert.Equal(mt, want, got, "want: %v, got: %v", want, got) }) } } diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 21b5395183..b8ac83efce 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -34,8 +34,22 @@ import ( ) const ( - ServerMonitoringModeAuto = connstring.ServerMonitoringModeAuto - ServerMonitoringModePoll = connstring.ServerMonitoringModePoll + // ServerMonitoringModeAuto indicates that the client will behave like "poll" + // mode when running on a FaaS (Function as a Service) platform, or like + // "stream" mode otherwise. The client detects its execution environment by + // following the rules for generating the "client.env" handshake metadata field + // as specified in the MongoDB Handshake specification. This is the default + // mode. + ServerMonitoringModeAuto = connstring.ServerMonitoringModeAuto + + // ServerMonitoringModePoll indicates that the client will periodically check + // the server using a hello or legacy hello command and then sleep for + // heartbeatFrequencyMS milliseconds before running another check. + ServerMonitoringModePoll = connstring.ServerMonitoringModePoll + + // ServerMonitoringModeStream indicates that the client will use a streaming + // protocol when the server supports it. The streaming protocol optimally + // reduces the time it takes for a client to discover server state changes. ServerMonitoringModeStream = connstring.ServerMonitoringModeStream ) @@ -957,20 +971,10 @@ func (c *ClientOptions) SetServerAPIOptions(opts *ServerAPIOptions) *ClientOptio return c } -// SetServerMonitoringMode specifies the server monitoring protocol to use. -// -// Valid modes are: -// - "stream": The client will use a streaming protocol when the server -// supports it. The streaming protocol optimally reduces the time it takes -// for a client to discover server state changes. -// - "poll": The client will periodically check the server using a hello or -// legacy hello command and then sleep for heartbeatFrequencyMS milliseconds -// before running another check. -// - "auto": The client will behave like "poll" mode when running on a FaaS -// (Function as a Service) platform, or like "stream" mode otherwise. The -// client detects its execution environment by following the rules for -// generating the "client.env" handshake metadata field as specified in the -// MongoDB Handshake specification. This is the deafult mode. +// SetServerMonitoringMode specifies the server monitoring protocol to use. See +// the helper constants ServerMonitoringModeAuto, ServerMonitoringModePoll, and +// ServerMonitoringModeStream for more information about valid server +// monitoring modes. func (c *ClientOptions) SetServerMonitoringMode(mode string) *ClientOptions { fmt.Println("mode: ", mode) c.ServerMonitoringMode = &mode diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 136a4f6730..cd43136471 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -22,8 +22,22 @@ import ( ) const ( - ServerMonitoringModeAuto = "auto" - ServerMonitoringModePoll = "poll" + // ServerMonitoringModeAuto indicates that the client will behave like "poll" + // mode when running on a FaaS (Function as a Service) platform, or like + // "stream" mode otherwise. The client detects its execution environment by + // following the rules for generating the "client.env" handshake metadata field + // as specified in the MongoDB Handshake specification. This is the default + // mode. + ServerMonitoringModeAuto = "auto" + + // ServerMonitoringModePoll indicates that the client will periodically check + // the server using a hello or legacy hello command and then sleep for + // heartbeatFrequencyMS milliseconds before running another check. + ServerMonitoringModePoll = "poll" + + // ServerMonitoringModeStream indicates that the client will use a streaming + // protocol when the server supports it. The streaming protocol optimally + // reduces the time it takes for a client to discover server state changes. ServerMonitoringModeStream = "stream" ) @@ -628,6 +642,8 @@ func (p *parser) addHost(host string) error { return nil } +// IsValidServerMonitoringMode will return true if the given string matches a +// valid server monitoring mode. func IsValidServerMonitoringMode(mode string) bool { return mode == ServerMonitoringModeAuto || mode == ServerMonitoringModeStream || diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 5451e2bebb..ee0260290b 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -1,4 +1,4 @@ -// Copyright (C) MongoDB, Inc. 2017-present. +/// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain @@ -807,7 +807,6 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello { // // I.e, streaming must be disabled if: P ∨ (A ∧ F) ∨ (~S) ≡ ~P ∧ (~A ∨ ~F) ∧ S func isStreamable(previousDesc description.Server, srv *Server) bool { - return previousDesc.TopologyVersion != nil srvMonitoringMode := srv.cfg.serverMonitoringMode faas := driverutil.GetFaasEnvName() From fe237bb0079574fb76c859332218d7c49921210d Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:18:30 -0600 Subject: [PATCH 09/22] GODRIVER-2810 Remove debugging tools --- mongo/options/clientoptions.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index b8ac83efce..85c969d5c0 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -976,7 +976,6 @@ func (c *ClientOptions) SetServerAPIOptions(opts *ServerAPIOptions) *ClientOptio // ServerMonitoringModeStream for more information about valid server // monitoring modes. func (c *ClientOptions) SetServerMonitoringMode(mode string) *ClientOptions { - fmt.Println("mode: ", mode) c.ServerMonitoringMode = &mode return c From c343f1258bb5e8b425be5a9cdefa565e6915ba36 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Wed, 6 Sep 2023 18:52:46 -0600 Subject: [PATCH 10/22] GODRIVER-2935 Extend legacy check to GetHandshakeInformation --- x/mongo/driver/operation/hello.go | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index f9200510e1..16d5809130 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -541,8 +541,8 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti // loadBalanced is False. If this is the case, then the drivers MUST use legacy // hello for the first message of the initial handshake with the OP_QUERY // protocol -func isLegacyHandshake(h *Hello) bool { - return h.serverAPI == nil && h.d.Kind() != description.LoadBalanced +func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool { + return srvAPI == nil && deployment.Kind() != description.LoadBalanced } func (h *Hello) createOperation() driver.Operation { @@ -558,7 +558,7 @@ func (h *Hello) createOperation() driver.Operation { ServerAPI: h.serverAPI, } - if isLegacyHandshake(h) { + if isLegacyHandshake(h.serverAPI, h.d) { op.Legacy = driver.LegacyHandshake } @@ -568,18 +568,25 @@ func (h *Hello) createOperation() driver.Operation { // GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant // information about the server. This function implements the driver.Handshaker interface. func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) { - err := driver.Operation{ + deployment := driver.SingleConnectionDeployment{C: c} + + op := driver.Operation{ Clock: h.clock, CommandFn: h.handshakeCommand, - Deployment: driver.SingleConnectionDeployment{C: c}, + Deployment: deployment, Database: "admin", ProcessResponseFn: func(info driver.ResponseInfo) error { h.res = info.ServerResponse return nil }, ServerAPI: h.serverAPI, - }.Execute(ctx) - if err != nil { + } + + if isLegacyHandshake(h.serverAPI, deployment) { + op.Legacy = driver.LegacyHandshake + } + + if err := op.Execute(ctx); err != nil { return driver.HandshakeInformation{}, err } @@ -592,6 +599,9 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, if serverConnectionID, ok := h.res.Lookup("connectionId").AsInt64OK(); ok { info.ServerConnectionID = &serverConnectionID } + + var err error + // Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the // StringSliceFromRawValue call. if saslSupportedMechs, lookupErr := bson.Raw(h.res).LookupErr("saslSupportedMechs"); lookupErr == nil { From 25b76854724e947a013f160dcb3db88e9ccd0da2 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 7 Sep 2023 09:00:54 -0600 Subject: [PATCH 11/22] GODRIVER-2935 Add legacy tests back to auth and client --- mongo/integration/client_test.go | 4 +-- x/mongo/driver/auth/speculative_scram_test.go | 4 +-- x/mongo/driver/auth/speculative_x509_test.go | 4 +-- x/mongo/driver/drivertest/channel_conn.go | 32 +++++++++++++++++++ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 914ca863b7..007427824b 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -733,8 +733,8 @@ func TestClient(t *testing.T) { pair := msgPairs[0] assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, + "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire // version is now known to be >= 6. diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index f2234e227c..a159891adc 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) { // Assert that the driver sent hello with the speculative authentication message. assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d", len(tc.payloads), (conn.Written)) - helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, helloCmd, handshake.LegacyHello) @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 13fdf2b185..cf46de6ffd 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index d2ae8df248..27be4c264d 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte { return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) } +// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message. +func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) { + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + return nil, errors.New("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + return nil, errors.New("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + return nil, errors.New("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + return nil, errors.New("could not read numberToReturn") + } + + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("could not read query") + } + return query, nil +} + // GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message. func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) { var ok bool From 3672bb405b9e3f519286a92dda8c93361908e389 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:04:35 -0600 Subject: [PATCH 12/22] GODRIVER-2810 Add licenses --- internal/driverutil/hello.go | 6 ++++++ x/mongo/driver/topology/server.go | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/internal/driverutil/hello.go b/internal/driverutil/hello.go index b252d43b1f..25e684c2c3 100644 --- a/internal/driverutil/hello.go +++ b/internal/driverutil/hello.go @@ -1,3 +1,9 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + package driverutil import ( diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index ee0260290b..737ce82459 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -1,3 +1,9 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + /// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may From b8cba730f0cd2c958141e24a618c8e31bdf64c02 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:35:02 -0600 Subject: [PATCH 13/22] GODRIVER-2810 Cleanup code --- mongo/options/clientoptions_test.go | 4 ++-- x/mongo/driver/operation.go | 3 ++- x/mongo/driver/topology/server.go | 6 ------ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index e729f5f569..abd42ae068 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -763,8 +763,6 @@ func TestClientOptions(t *testing.T) { t.Run("server monitoring mode validation", func(t *testing.T) { t.Parallel() - // := fmt.Errorf("invalid server monitoring mode: %q", *mode) - testCases := []struct { name string opts *ClientOptions @@ -801,6 +799,8 @@ func TestClientOptions(t *testing.T) { tc := tc // Capture the range variable t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.opts.Validate() assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) }) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 76524bac8f..28e874947b 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1109,7 +1109,6 @@ func (op Operation) createLegacyHandshakeWireMessage( dst []byte, desc description.SelectedServer, ) ([]byte, startedInformation, error) { - var dollarCmd = [...]byte{'.', '$', 'c', 'm', 'd'} var info startedInformation flags := op.secondaryOK(desc) var wmindex int32 @@ -1117,6 +1116,8 @@ func (op Operation) createLegacyHandshakeWireMessage( wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) dst = wiremessage.AppendQueryFlags(dst, flags) + var dollarCmd = [...]byte{'.', '$', 'c', 'm', 'd'} + // FullCollectionName dst = append(dst, op.Database...) dst = append(dst, dollarCmd[:]...) diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 737ce82459..0c9144f5f4 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -4,12 +4,6 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -/// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - package topology import ( From cfb1c661c824fb8de8df1baeac731df61413bc1f Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Fri, 8 Sep 2023 16:25:01 -0600 Subject: [PATCH 14/22] GODRIVER-2810 Re-organize rtt monitor --- x/mongo/driver/topology/rtt_monitor.go | 4 +++ x/mongo/driver/topology/server.go | 43 +++++++++++++------------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 998d2a0253..1a1c37b296 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -51,6 +51,7 @@ type rttMonitor struct { cfg *rttConfig ctx context.Context cancelFn context.CancelFunc + started bool } var _ driver.RTTMonitor = &rttMonitor{} @@ -75,6 +76,7 @@ func newRTTMonitor(cfg *rttConfig) *rttMonitor { func (r *rttMonitor) connect() { r.closeWg.Add(1) + r.started = true go r.start() } @@ -89,6 +91,8 @@ func (r *rttMonitor) start() { var conn *connection defer func() { + r.started = false + if conn != nil { // If the connection exists, we need to wait for it to be connected because // conn.connect() and conn.close() cannot be called concurrently. If the connection diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 0c9144f5f4..e88cff8084 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -246,7 +246,6 @@ func (s *Server) Connect(updateCallback updateTopologyCallback) error { s.updateTopologyCallback.Store(updateCallback) if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced { - s.rttMonitor.connect() s.closewg.Add(1) go s.update() } @@ -655,12 +654,15 @@ func (s *Server) update() { // If the server supports streaming or we're already streaming, we want to move to streaming the next response // without waiting. If the server has transitioned to Unknown from a network error, we want to do another // check without waiting in case it was a transient error and the server isn't actually down. - serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming() transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil && previousDescription.Kind != description.Unknown - if serverSupportsStreaming || connectionIsStreaming || transitionedFromNetworkError { + if isStreamingEnabled(s) && isStreamable(s) && !s.rttMonitor.started { + s.rttMonitor.connect() + } + + if isStreamable(s) || connectionIsStreaming || transitionedFromNetworkError { continue } @@ -796,23 +798,22 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello { ServerAPI(s.cfg.serverAPI) } -// isStreamable returns whether or not we can use the streaming protocol to -// optimally reduces the time it takes for a client to discover server state -// changes. Streaming must be disabled if any of the following are true: -// -// - the client is configured with serverMonitoringMode=poll [P], or -// - the client is configured with serverMonitoringMode=auto [A] and a FaaS -// platform is detected [F], or -// - the server does not support streaming (eg MongoDB <4.4) [S]. -// -// I.e, streaming must be disabled if: P ∨ (A ∧ F) ∨ (~S) ≡ ~P ∧ (~A ∨ ~F) ∧ S -func isStreamable(previousDesc description.Server, srv *Server) bool { - srvMonitoringMode := srv.cfg.serverMonitoringMode - faas := driverutil.GetFaasEnvName() - - return srvMonitoringMode != connstring.ServerMonitoringModePoll && // P - (srvMonitoringMode != connstring.ServerMonitoringModeAuto || faas == "") && // (~A ∨ ~F) - previousDesc.TopologyVersion != nil // S +func isStreamingEnabled(srv *Server) bool { + mode := srv.cfg.serverMonitoringMode + + if mode == connstring.ServerMonitoringModeStream { + return true + } + + if mode == connstring.ServerMonitoringModeAuto { + return driverutil.GetFaasEnvName() == "" + } + + return false +} + +func isStreamable(srv *Server) bool { + return srv.Description().Kind != description.Unknown && srv.Description().TopologyVersion != nil } func (s *Server) check() (description.Server, error) { @@ -850,7 +851,7 @@ func (s *Server) check() (description.Server, error) { heartbeatConn := initConnection{s.conn} baseOperation := s.createBaseOperation(heartbeatConn) previousDescription := s.Description() - streamable := isStreamable(previousDescription, s) + streamable := isStreamingEnabled(s) && isStreamable(s) s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) switch { From 2d54de4f548f3a733daf7cbf809e1d53024898ca Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Fri, 8 Sep 2023 17:50:41 -0600 Subject: [PATCH 15/22] GODRIVER-2810 Bump RTT tests to min 4.4 --- mongo/integration/client_test.go | 9 +++++---- mongo/integration/sdam_prose_test.go | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 038ed25d72..45b18b6537 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -515,7 +515,8 @@ func TestClient(t *testing.T) { assert.Nil(t, err, "unexpected error calling Ping: %v", err) }) - mt.Run("minimum RTT is monitored", func(mt *mtest.T) { + rtt90Opts := mtest.NewOptions().MinServerVersion("4.4") + mt.RunOpts("minimum RTT is monitored", rtt90Opts, func(mt *mtest.T) { mt.Parallel() // Reset the client with a dialer that delays all network round trips by 300ms and set the @@ -555,7 +556,7 @@ func TestClient(t *testing.T) { // Test that if the minimum RTT is greater than the remaining timeout for an operation, the // operation is not sent to the server and no connections are closed. - mt.Run("minimum RTT used to prevent sending requests", func(mt *mtest.T) { + mt.RunOpts("minimum RTT used to prevent sending requests", rtt90Opts, func(mt *mtest.T) { mt.Parallel() // Assert that we can call Ping with a 250ms timeout. @@ -614,7 +615,7 @@ func TestClient(t *testing.T) { assert.Equal(t, 0, closed, "expected no connections to be closed") }) - mt.Run("RTT90 is monitored", func(mt *mtest.T) { + mt.RunOpts("RTT90 is monitored", rtt90Opts, func(mt *mtest.T) { mt.Parallel() // Reset the client with a dialer that delays all network round trips by 300ms and set the @@ -654,7 +655,7 @@ func TestClient(t *testing.T) { // Test that if Timeout is set and the RTT90 is greater than the remaining timeout for an operation, the // operation is not sent to the server, fails with a timeout error, and no connections are closed. - mt.Run("RTT90 used to prevent sending requests", func(mt *mtest.T) { + mt.RunOpts("RTT90 used to prevent sending requests", rtt90Opts, func(mt *mtest.T) { mt.Parallel() // Assert that we can call Ping with a 250ms timeout. diff --git a/mongo/integration/sdam_prose_test.go b/mongo/integration/sdam_prose_test.go index 21b1fea4ba..69b7fdbd5c 100644 --- a/mongo/integration/sdam_prose_test.go +++ b/mongo/integration/sdam_prose_test.go @@ -31,7 +31,8 @@ func TestSDAMProse(t *testing.T) { heartbeatIntervalMtOpts := mtest.NewOptions(). ClientOptions(heartbeatIntervalClientOpts). CreateCollection(false). - ClientType(mtest.Proxy) + ClientType(mtest.Proxy). + MinServerVersion("4.4") // RTT Monitor / Streaming protocol is not supported for versions < 4.4. mt.RunOpts("heartbeats processed more frequently", heartbeatIntervalMtOpts, func(mt *mtest.T) { // Test that setting heartbeat interval to 500ms causes the client to process heartbeats // approximately every 500ms instead of the default 10s. Note that a Client doesn't From 2d5ebd976addec091bbb8e44df722323dffe0591 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Tue, 12 Sep 2023 11:34:45 -0600 Subject: [PATCH 16/22] GODRIVER-2810 Revert unecessary changes --- x/mongo/driver/operation.go | 2 +- x/mongo/driver/topology/server.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 28e874947b..229988e133 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1116,7 +1116,7 @@ func (op Operation) createLegacyHandshakeWireMessage( wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) dst = wiremessage.AppendQueryFlags(dst, flags) - var dollarCmd = [...]byte{'.', '$', 'c', 'm', 'd'} + dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} // FullCollectionName dst = append(dst, op.Database...) diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index e88cff8084..0b278cc8dd 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -1,4 +1,4 @@ -// Copyright (C) MongoDB, Inc. 2023-present. +// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain From 1f32ccd23655b17b22b857ff58386be0e7d871e2 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:57:33 -0600 Subject: [PATCH 17/22] GODRIVER-2810 Make code more dev-friendly --- x/mongo/driver/topology/rtt_monitor.go | 2 -- x/mongo/driver/topology/server.go | 13 +++++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 1a1c37b296..e3e2e04683 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -91,8 +91,6 @@ func (r *rttMonitor) start() { var conn *connection defer func() { - r.started = false - if conn != nil { // If the connection exists, we need to wait for it to be connected because // conn.connect() and conn.close() cannot be called concurrently. If the connection diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 0b278cc8dd..a4aa703e97 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -799,17 +799,14 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello { } func isStreamingEnabled(srv *Server) bool { - mode := srv.cfg.serverMonitoringMode - - if mode == connstring.ServerMonitoringModeStream { + switch srv.cfg.serverMonitoringMode { + case connstring.ServerMonitoringModeStream: return true - } - - if mode == connstring.ServerMonitoringModeAuto { + case connstring.ServerMonitoringModePoll: + return false + default: return driverutil.GetFaasEnvName() == "" } - - return false } func isStreamable(srv *Server) bool { From 6287bab51461e3c760f3b916a50b390b2536b86b Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Thu, 12 Oct 2023 13:57:03 -0600 Subject: [PATCH 18/22] GODRIVER-2810 Ensure polling does not disable CSOT --- mongo/integration/client_test.go | 2 +- x/mongo/driver/topology/server.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 45b18b6537..532cbc5ab1 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -515,7 +515,7 @@ func TestClient(t *testing.T) { assert.Nil(t, err, "unexpected error calling Ping: %v", err) }) - rtt90Opts := mtest.NewOptions().MinServerVersion("4.4") + rtt90Opts := mtest.NewOptions() mt.RunOpts("minimum RTT is monitored", rtt90Opts, func(mt *mtest.T) { mt.Parallel() diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index a4aa703e97..05c290d701 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -851,6 +851,12 @@ func (s *Server) check() (description.Server, error) { streamable := isStreamingEnabled(s) && isStreamable(s) s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) + + // We need to record an RTT sample in the polling case so that if the server + // is < 4.4, or if polling is specified by the user, then the + // RTT-short-circuit feature of CSOT is not disabled. + var addPollingRTTSample bool + switch { case s.conn.getCurrentlyStreaming(): // The connection is already in a streaming state, so we stream the next response. @@ -880,8 +886,14 @@ func (s *Server) check() (description.Server, error) { s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) err = baseOperation.Execute(s.heartbeatCtx) + + addPollingRTTSample = true } + duration = time.Since(start) + if addPollingRTTSample { + s.rttMonitor.addSample(duration) + } if err == nil { tempDesc := baseOperation.Result(s.address) From d2798efc2f0cb06cd72685f5c462f83898126b50 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:13:51 -0600 Subject: [PATCH 19/22] GODRIVER-2810 Bump schema to 1.17 --- .../integration/unified/event_verification.go | 174 +++++++++ mongo/integration/unified/schema_version.go | 2 +- .../unified/testrunner_operation.go | 7 +- .../unified/serverMonitoringMode.json | 358 +++++++++++++++++- .../unified/serverMonitoringMode.yml | 158 ++++++-- 5 files changed, 652 insertions(+), 47 deletions(-) diff --git a/mongo/integration/unified/event_verification.go b/mongo/integration/unified/event_verification.go index 91f7452907..1d54e3fb2a 100644 --- a/mongo/integration/unified/event_verification.go +++ b/mongo/integration/unified/event_verification.go @@ -9,6 +9,7 @@ package unified import ( "bytes" "context" + "errors" "fmt" "go.mongodb.org/mongo-driver/bson" @@ -64,10 +65,37 @@ type cmapEvent struct { } `bson:"poolClearedEvent"` } +type sdamEvent struct { + ServerDescriptionChangedEvent *struct { + NewDescription *struct { + Type *string `bson:"type"` + } `bson:"newDescription"` + + PreviousDescription *struct { + Type *string `bson:"type"` + } `bson:"previousDescription"` + } `bson:"serverDescriptionChangedEvent"` + + ServerHeartbeatStartedEvent *struct { + Awaited *bool `bson:"awaited"` + } `bson:"serverHeartbeatStartedEvent"` + + ServerHeartbeatSucceededEvent *struct { + Awaited *bool `bson:"awaited"` + } `bson:"serverHeartbeatSucceededEvent"` + + ServerHeartbeatFailedEvent *struct { + Awaited *bool `bson:"awaited"` + } `bson:"serverHeartbeatFailedEvent"` + + TopologyDescriptionChangedEvent *struct{} `bson:"topologyDescriptionChangedEvent"` +} + type expectedEvents struct { ClientID string `bson:"client"` CommandEvents []commandMonitoringEvent CMAPEvents []cmapEvent + SDAMEvents []sdamEvent IgnoreExtraEvents *bool } @@ -102,6 +130,8 @@ func (e *expectedEvents) UnmarshalBSON(data []byte) error { target = &e.CommandEvents case "cmap": target = &e.CMAPEvents + case "sdam": + target = &e.SDAMEvents default: return fmt.Errorf("unrecognized 'eventType' value for expectedEvents: %q", temp.EventType) } @@ -127,6 +157,8 @@ func verifyEvents(ctx context.Context, expectedEvents *expectedEvents) error { return verifyCommandEvents(ctx, client, expectedEvents) case expectedEvents.CMAPEvents != nil: return verifyCMAPEvents(client, expectedEvents) + case expectedEvents.SDAMEvents != nil: + return verifySDAMEvents(client, expectedEvents) } return nil } @@ -405,3 +437,145 @@ func stringifyEventsForClient(client *clientEntity) string { return str.String() } + +func getNextServerDescriptionChangedEvent( + events []*event.ServerDescriptionChangedEvent, +) (*event.ServerDescriptionChangedEvent, []*event.ServerDescriptionChangedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no server changed event published") + } + + return events[0], events[1:], nil +} + +func getNextServerHeartbeatStartedEvent( + events []*event.ServerHeartbeatStartedEvent, +) (*event.ServerHeartbeatStartedEvent, []*event.ServerHeartbeatStartedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no heartbeat started event published") + } + + return events[0], events[1:], nil +} + +func getNextServerHeartbeatSucceededEvent( + events []*event.ServerHeartbeatSucceededEvent, +) (*event.ServerHeartbeatSucceededEvent, []*event.ServerHeartbeatSucceededEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no heartbeat succeeded event published") + } + + return events[0], events[:1], nil +} + +func getNextServerHeartbeatFailedEvent( + events []*event.ServerHeartbeatFailedEvent, +) (*event.ServerHeartbeatFailedEvent, []*event.ServerHeartbeatFailedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no heartbeat failed event published") + } + + return events[0], events[:1], nil +} + +func getNextTopologyDescriptionChangedEvent( + events []*event.TopologyDescriptionChangedEvent, +) (*event.TopologyDescriptionChangedEvent, []*event.TopologyDescriptionChangedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no topology description changed event published") + } + + return events[0], events[:1], nil +} + +func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) error { + var ( + changed = client.serverDescriptionChanged + started = client.serverHeartbeatStartedEvent + succeeded = client.serverHeartbeatSucceeded + failed = client.serverHeartbeatFailedEvent + tchanged = client.topologyDescriptionChanged + ) + + vol := func() int { return len(changed) + len(started) + len(succeeded) + len(failed) + len(tchanged) } + + if len(expectedEvents.SDAMEvents) == 0 && vol() != 0 { + return fmt.Errorf("expected no sdam events to be sent but got %s", stringifyEventsForClient(client)) + } + + for idx, evt := range expectedEvents.SDAMEvents { + var err error + + switch { + case evt.ServerDescriptionChangedEvent != nil: + var got *event.ServerDescriptionChangedEvent + if got, changed, err = getNextServerDescriptionChangedEvent(changed); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + prevDesc := evt.ServerDescriptionChangedEvent.NewDescription + + var wantPrevDesc string + if prevDesc != nil && prevDesc.Type != nil { + wantPrevDesc = *prevDesc.Type + } + + gotPrevDesc := got.PreviousDescription.Kind.String() + if gotPrevDesc != wantPrevDesc { + return newEventVerificationError(idx, client, + "expected previous server description %q, got %q", wantPrevDesc, gotPrevDesc) + } + + newDesc := evt.ServerDescriptionChangedEvent.PreviousDescription + + var wantNewDesc string + if newDesc != nil && newDesc.Type != nil { + wantNewDesc = *newDesc.Type + } + + gotNewDesc := got.NewDescription.Kind.String() + if gotNewDesc != wantNewDesc { + return newEventVerificationError(idx, client, + "expected new server description %q, got %q", wantNewDesc, gotNewDesc) + } + case evt.ServerHeartbeatStartedEvent != nil: + var got *event.ServerHeartbeatStartedEvent + if got, started, err = getNextServerHeartbeatStartedEvent(started); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + if want := evt.ServerHeartbeatStartedEvent.Awaited; want != nil && *want != got.Awaited { + return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited) + } + case evt.ServerHeartbeatSucceededEvent != nil: + var got *event.ServerHeartbeatSucceededEvent + if got, succeeded, err = getNextServerHeartbeatSucceededEvent(succeeded); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + if want := evt.ServerHeartbeatSucceededEvent.Awaited; want != nil && *want != got.Awaited { + return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited) + } + case evt.ServerHeartbeatFailedEvent != nil: + var got *event.ServerHeartbeatFailedEvent + if got, failed, err = getNextServerHeartbeatFailedEvent(failed); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + if want := evt.ServerHeartbeatFailedEvent.Awaited; want != nil && *want != got.Awaited { + return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited) + } + case evt.TopologyDescriptionChangedEvent != nil: + if _, tchanged, err = getNextTopologyDescriptionChangedEvent(tchanged); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + } + } + + // Verify that there are no remaining events if ignoreExtraEvents is unset or false. + ignoreExtraEvents := expectedEvents.IgnoreExtraEvents != nil && *expectedEvents.IgnoreExtraEvents + if !ignoreExtraEvents && vol() > 0 { + return fmt.Errorf("extra sdam events published; all events for client: %s", stringifyEventsForClient(client)) + } + return nil +} diff --git a/mongo/integration/unified/schema_version.go b/mongo/integration/unified/schema_version.go index c85a2efa79..9aec89a18d 100644 --- a/mongo/integration/unified/schema_version.go +++ b/mongo/integration/unified/schema_version.go @@ -16,7 +16,7 @@ import ( var ( supportedSchemaVersions = map[int]string{ - 1: "1.16", + 1: "1.17", } ) diff --git a/mongo/integration/unified/testrunner_operation.go b/mongo/integration/unified/testrunner_operation.go index 474c01c88a..297ebbdf5d 100644 --- a/mongo/integration/unified/testrunner_operation.go +++ b/mongo/integration/unified/testrunner_operation.go @@ -19,7 +19,12 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -var waitForEventTimeout = 10 * time.Second +// waitForEventTimeout is the amount of time to wait for an event to occur. The +// maximum amount of time expected for this value is currently 10 seconds, which +// is the amoutn of time that the driver will attempt to wait between streamable +// heartbeats. Increase this value if a new maximum time is expected in another +// operation. +var waitForEventTimeout = 11 * time.Second type loopArgs struct { Operations []*operation `bson:"operations"` diff --git a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json index 520635ba2c..7d681b4f9e 100644 --- a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json +++ b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json @@ -1,9 +1,24 @@ { "description": "serverMonitoringMode", - "schemaVersion": "1.3", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + "topologies": [ + "single", + "sharded", + "sharded-replicaset" + ], + "serverless": "forbid" + } + ], "tests": [ { - "description": "connect with serverMonitoringMode=auto", + "description": "connect with serverMonitoringMode=auto >=4.4", + "runOnRequirements": [ + { + "minServerVersion": "4.4.0" + } + ], "operations": [ { "name": "createEntities", @@ -12,16 +27,22 @@ "entities": [ { "client": { - "id": "client0", + "id": "client", "uriOptions": { "serverMonitoringMode": "auto" - } + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] } }, { "database": { - "id": "dbSdamModeAuto", - "client": "client0", + "id": "db", + "client": "client", "databaseName": "sdam-tests" } } @@ -30,7 +51,7 @@ }, { "name": "runCommand", - "object": "dbSdamModeAuto", + "object": "db", "arguments": { "commandName": "ping", "command": { @@ -40,11 +61,51 @@ "expectResult": { "ok": 1 } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": true + } + } + ] } ] }, { - "description": "connect with serverMonitoringMode=stream", + "description": "connect with serverMonitoringMode=auto <4.4", + "runOnRequirements": [ + { + "maxServerVersion": "4.2.99" + } + ], "operations": [ { "name": "createEntities", @@ -53,16 +114,110 @@ "entities": [ { "client": { - "id": "client1", + "id": "client", + "uriOptions": { + "serverMonitoringMode": "auto", + "heartbeatFrequencyMS": 500 + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + } + ] + } + ] + }, + { + "description": "connect with serverMonitoringMode=stream >=4.4", + "runOnRequirements": [ + { + "minServerVersion": "4.4.0" + } + ], + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", "uriOptions": { "serverMonitoringMode": "stream" - } + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] } }, { "database": { - "id": "dbSdamModeStream", - "client": "client1", + "id": "db", + "client": "client", "databaseName": "sdam-tests" } } @@ -71,7 +226,7 @@ }, { "name": "runCommand", - "object": "dbSdamModeStream", + "object": "db", "arguments": { "commandName": "ping", "command": { @@ -81,6 +236,129 @@ "expectResult": { "ok": 1 } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": true + } + } + ] + } + ] + }, + { + "description": "connect with serverMonitoringMode=stream <4.4", + "runOnRequirements": [ + { + "maxServerVersion": "4.2.99" + } + ], + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "uriOptions": { + "serverMonitoringMode": "stream", + "heartbeatFrequencyMS": 500 + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + } + ] } ] }, @@ -94,16 +372,23 @@ "entities": [ { "client": { - "id": "client2", + "id": "client", "uriOptions": { - "serverMonitoringMode": "poll" - } + "serverMonitoringMode": "poll", + "heartbeatFrequencyMS": 500 + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] } }, { "database": { - "id": "dbSdamModePoll", - "client": "client2", + "id": "db", + "client": "client", "databaseName": "sdam-tests" } } @@ -112,7 +397,7 @@ }, { "name": "runCommand", - "object": "dbSdamModePoll", + "object": "db", "arguments": { "commandName": "ping", "command": { @@ -122,6 +407,41 @@ "expectResult": { "ok": 1 } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + } + ] } ] } diff --git a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml index bb2b2053b7..28c7853d04 100644 --- a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml +++ b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml @@ -1,49 +1,151 @@ description: serverMonitoringMode -schemaVersion: "1.3" - +schemaVersion: "1.17" +# These tests cannot run on replica sets because the order of the expected +# SDAM events are non-deterministic when monitoring multiple servers. +# They also cannot run on Serverless or load balanced clusters where SDAM is disabled. +runOnRequirements: + - topologies: [single, sharded, sharded-replicaset] + serverless: forbid tests: - - description: "connect with serverMonitoringMode=auto" + - description: "connect with serverMonitoringMode=auto >=4.4" + runOnRequirements: + - minServerVersion: "4.4.0" operations: - name: createEntities object: testRunner arguments: entities: - client: - id: &client0 client0 + id: client uriOptions: serverMonitoringMode: "auto" + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent - database: - id: &dbSdamModeAuto dbSdamModeAuto - client: *client0 + id: db + client: client databaseName: sdam-tests - - name: runCommand - object: *dbSdamModeAuto + - &ping + name: runCommand + object: db arguments: commandName: ping command: { ping: 1 } expectResult: { ok: 1 } + # Wait for the second serverHeartbeatStartedEvent to ensure we start streaming. + - &waitForSecondHeartbeatStarted + name: waitForEvent + object: testRunner + arguments: + client: client + event: + serverHeartbeatStartedEvent: {} + count: 2 + expectEvents: &streamingStartedEvents + - client: client + eventType: sdam + ignoreExtraEvents: true + events: + - serverHeartbeatStartedEvent: + awaited: False + - serverHeartbeatSucceededEvent: + awaited: False + - serverHeartbeatStartedEvent: + awaited: True + + - description: "connect with serverMonitoringMode=auto <4.4" + runOnRequirements: + - maxServerVersion: "4.2.99" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "auto" + heartbeatFrequencyMS: 500 + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we do not stream. + - *waitForSecondHeartbeatStarted + expectEvents: &pollingStartedEvents + - client: client + eventType: sdam + ignoreExtraEvents: true + events: + - serverHeartbeatStartedEvent: + awaited: False + - serverHeartbeatSucceededEvent: + awaited: False + - serverHeartbeatStartedEvent: + awaited: False - - description: "connect with serverMonitoringMode=stream" + - description: "connect with serverMonitoringMode=stream >=4.4" + runOnRequirements: + - minServerVersion: "4.4.0" operations: - name: createEntities object: testRunner arguments: entities: - client: - id: &client1 client1 + id: client uriOptions: serverMonitoringMode: "stream" + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent - database: - id: &dbSdamModeStream dbSdamModeStream - client: *client1 + id: db + client: client databaseName: sdam-tests - - name: runCommand - object: *dbSdamModeStream + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we start streaming. + - *waitForSecondHeartbeatStarted + expectEvents: *streamingStartedEvents + + - description: "connect with serverMonitoringMode=stream <4.4" + runOnRequirements: + - maxServerVersion: "4.2.99" + operations: + - name: createEntities + object: testRunner arguments: - commandName: ping - command: { ping: 1 } - expectResult: { ok: 1 } + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "stream" + heartbeatFrequencyMS: 500 + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we do not stream. + - *waitForSecondHeartbeatStarted + expectEvents: *pollingStartedEvents - description: "connect with serverMonitoringMode=poll" operations: @@ -52,16 +154,20 @@ tests: arguments: entities: - client: - id: &client2 client2 + id: client uriOptions: serverMonitoringMode: "poll" + heartbeatFrequencyMS: 500 + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent - database: - id: &dbSdamModePoll dbSdamModePoll - client: *client2 + id: db + client: client databaseName: sdam-tests - - name: runCommand - object: *dbSdamModePoll - arguments: - commandName: ping - command: { ping: 1 } - expectResult: { ok: 1 } + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we do not stream. + - *waitForSecondHeartbeatStarted + expectEvents: *pollingStartedEvents From aaa1b6a130cc74ea290ca0240fa939d90759211c Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Fri, 20 Oct 2023 10:41:51 -0600 Subject: [PATCH 20/22] GODRIVER-2810 Remove unecessary bool logic --- mongo/integration/client_test.go | 9 ++++----- x/mongo/driver/topology/server.go | 13 +++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index fe8e4cc6b0..e3a5b26241 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -515,8 +515,7 @@ func TestClient(t *testing.T) { assert.Nil(t, err, "unexpected error calling Ping: %v", err) }) - rtt90Opts := mtest.NewOptions() - mt.RunOpts("minimum RTT is monitored", rtt90Opts, func(mt *mtest.T) { + mt.Run("minimum RTT is monitored", func(mt *mtest.T) { mt.Parallel() // Reset the client with a dialer that delays all network round trips by 300ms and set the @@ -556,7 +555,7 @@ func TestClient(t *testing.T) { // Test that if the minimum RTT is greater than the remaining timeout for an operation, the // operation is not sent to the server and no connections are closed. - mt.RunOpts("minimum RTT used to prevent sending requests", rtt90Opts, func(mt *mtest.T) { + mt.Run("minimum RTT used to prevent sending requests", func(mt *mtest.T) { mt.Parallel() // Assert that we can call Ping with a 250ms timeout. @@ -615,7 +614,7 @@ func TestClient(t *testing.T) { assert.Equal(t, 0, closed, "expected no connections to be closed") }) - mt.RunOpts("RTT90 is monitored", rtt90Opts, func(mt *mtest.T) { + mt.Run("RTT90 is monitored", func(mt *mtest.T) { mt.Parallel() // Reset the client with a dialer that delays all network round trips by 300ms and set the @@ -655,7 +654,7 @@ func TestClient(t *testing.T) { // Test that if Timeout is set and the RTT90 is greater than the remaining timeout for an operation, the // operation is not sent to the server, fails with a timeout error, and no connections are closed. - mt.RunOpts("RTT90 used to prevent sending requests", rtt90Opts, func(mt *mtest.T) { + mt.Run("RTT90 used to prevent sending requests", func(mt *mtest.T) { mt.Parallel() // Assert that we can call Ping with a 250ms timeout. diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index b945d149fa..47b6ed1c88 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -855,11 +855,6 @@ func (s *Server) check() (description.Server, error) { s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) - // We need to record an RTT sample in the polling case so that if the server - // is < 4.4, or if polling is specified by the user, then the - // RTT-short-circuit feature of CSOT is not disabled. - var addPollingRTTSample bool - switch { case s.conn.getCurrentlyStreaming(): // The connection is already in a streaming state, so we stream the next response. @@ -889,12 +884,14 @@ func (s *Server) check() (description.Server, error) { s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) err = baseOperation.Execute(s.heartbeatCtx) - - addPollingRTTSample = true } duration = time.Since(start) - if addPollingRTTSample { + + // We need to record an RTT sample in the polling case so that if the server + // is < 4.4, or if polling is specified by the user, then the + // RTT-short-circuit feature of CSOT is not disabled. + if !streamable { s.rttMonitor.addSample(duration) } From b686f7cc6b66151c3385fc5b3f44e204a69b2366 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Fri, 27 Oct 2023 14:55:06 -0600 Subject: [PATCH 21/22] GODRIVER-2810 Guard rttMonitor connection --- Dockerfile | 2 ++ x/mongo/driver/topology/rtt_monitor.go | 32 +++++++++++++++++++++----- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index 446ab1d5f2..077967d039 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,6 +28,8 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ tzdata \ gpg \ apt-utils \ + libc6-dev \ + gcc \ make && \ apt-add-repository ppa:longsleep/golang-backports && \ apt-get -qq update && \ diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index e3e2e04683..7308c88f15 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -39,7 +39,12 @@ type rttConfig struct { } type rttMonitor struct { - mu sync.RWMutex // mu guards samples, offset, minRTT, averageRTT, and averageRTTSet + mu sync.RWMutex // mu guards samples, offset, minRTT, averageRTT, and averageRTTSet + + // connMu guards connecting and disconnecting. This is necessary since + // disconnecting will await the cancelation of a started connection. The + // use case for rttMonitor.connect needs to be goroutine safe. + connMu sync.Mutex samples []time.Duration offset int minRTT time.Duration @@ -52,6 +57,7 @@ type rttMonitor struct { ctx context.Context cancelFn context.CancelFunc started bool + done chan struct{} } var _ driver.RTTMonitor = &rttMonitor{} @@ -75,20 +81,34 @@ func newRTTMonitor(cfg *rttConfig) *rttMonitor { } func (r *rttMonitor) connect() { - r.closeWg.Add(1) + r.connMu.Lock() + defer r.connMu.Unlock() + r.started = true - go r.start() + r.closeWg.Add(1) + + go func() { + defer r.closeWg.Done() + + r.start() + }() } func (r *rttMonitor) disconnect() { - // Signal for the routine to stop. + r.connMu.Lock() + defer r.connMu.Unlock() + + if !r.started { + return + } + r.cancelFn() + + // Wait for the existing connection to complete. r.closeWg.Wait() } func (r *rttMonitor) start() { - defer r.closeWg.Done() - var conn *connection defer func() { if conn != nil { From 34a8970a7b4bafbaddf431697c54187563f3cad0 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Fri, 27 Oct 2023 15:15:34 -0600 Subject: [PATCH 22/22] GODRIVER-2810 Remove unused rttMonitor field --- x/mongo/driver/topology/rtt_monitor.go | 1 - 1 file changed, 1 deletion(-) diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 7308c88f15..eacc6bf6d3 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -57,7 +57,6 @@ type rttMonitor struct { ctx context.Context cancelFn context.CancelFunc started bool - done chan struct{} } var _ driver.RTTMonitor = &rttMonitor{}