From 3da990232580ffaa0a3082b1f11639220e53d490 Mon Sep 17 00:00:00 2001 From: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Date: Mon, 14 Aug 2023 14:33:38 -0600 Subject: [PATCH] GODRIVER-2572 Clean up code per review --- mongo/collection.go | 50 +++++++------ mongo/description/server_selector.go | 7 +- mongo/integration/unified/client_entity.go | 14 ++-- mongo/integration/unified/event.go | 16 ++--- .../unified/testrunner_operation.go | 6 +- x/mongo/driver/topology/topology.go | 72 ++++++++++--------- 6 files changed, 82 insertions(+), 83 deletions(-) diff --git a/mongo/collection.go b/mongo/collection.go index bb5399e16b..6abbea9792 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1870,49 +1870,47 @@ func (coll *Collection) drop(ctx context.Context) error { type pinnedServerSelector struct { stringer fmt.Stringer - fn description.ServerSelectorFunc + fallback description.ServerSelector + session *session.Client } func (pss pinnedServerSelector) String() string { + if pss.stringer == nil { + return "" + } + return pss.stringer.String() } func (pss pinnedServerSelector) SelectServer( t description.Topology, - s []description.Server, + svrs []description.Server, ) ([]description.Server, error) { - return pss.fn(t, s) -} - -// pinnedSelectorFunc makes a selector for a pinned session with a pinned -// server. It attempts to do server selection on the pinned server, but if that -// fails, it will go through a list of default selectors. -func pinnedSelectorFunc(sess *session.Client, defaultSelector description.ServerSelector) description.ServerSelectorFunc { - return func(t description.Topology, svrs []description.Server) ([]description.Server, error) { - if sess != nil && sess.PinnedServer != nil { - // If there is a pinned server, try to find it in the list of candidates. - for _, candidate := range svrs { - if candidate.Addr == sess.PinnedServer.Addr { - return []description.Server{candidate}, nil - } + if pss.session != nil && pss.session.PinnedServer != nil { + // If there is a pinned server, try to find it in the list of candidates. + for _, candidate := range svrs { + if candidate.Addr == pss.session.PinnedServer.Addr { + return []description.Server{candidate}, nil } - - return nil, nil } - return defaultSelector.SelectServer(t, svrs) + return nil, nil } + + return pss.fallback.SelectServer(t, svrs) } -func makePinnedSelector(sess *session.Client, defaultSelector description.ServerSelector) description.ServerSelector { - if srvSelectorStringer, ok := defaultSelector.(fmt.Stringer); ok { - return pinnedServerSelector{ - stringer: srvSelectorStringer, - fn: pinnedSelectorFunc(sess, defaultSelector), - } +func makePinnedSelector(sess *session.Client, fallback description.ServerSelector) description.ServerSelector { + pss := pinnedServerSelector{ + session: sess, + fallback: fallback, + } + + if srvSelectorStringer, ok := fallback.(fmt.Stringer); ok { + pss.stringer = srvSelectorStringer } - return pinnedSelectorFunc(sess, defaultSelector) + return pss } func makeReadPrefSelector(sess *session.Client, selector description.ServerSelector, localThreshold time.Duration) description.ServerSelector { diff --git a/mongo/description/server_selector.go b/mongo/description/server_selector.go index 76f14de769..aee1f050cb 100644 --- a/mongo/description/server_selector.go +++ b/mongo/description/server_selector.go @@ -57,15 +57,15 @@ type compositeSelector struct { } func (cs *compositeSelector) info() serverSelectorInfo { - cssInfo := &serverSelectorInfo{Type: "compositeSelector"} + csInfo := serverSelectorInfo{Type: "compositeSelector"} for _, sel := range cs.selectors { if getter, ok := sel.(serverSelectorInfoGetter); ok { - cssInfo.Selectors = append(cssInfo.Selectors, getter.info()) + csInfo.Selectors = append(csInfo.Selectors, getter.info()) } } - return *cssInfo + return csInfo } // String returns the JSON string representation of the compositeSelector. @@ -204,7 +204,6 @@ func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector { rp: rp, isOutputAggregate: false, } - } func (selector readPrefServerSelector) info() serverSelectorInfo { diff --git a/mongo/integration/unified/client_entity.go b/mongo/integration/unified/client_entity.go index 1bab1951ec..e63c891039 100644 --- a/mongo/integration/unified/client_entity.go +++ b/mongo/integration/unified/client_entity.go @@ -62,7 +62,7 @@ type clientEntity struct { observedEvents map[monitoringEventType]struct{} storedEvents map[monitoringEventType][]string eventsCount map[monitoringEventType]int32 - serverDescriptionChangedEventsCount map[serverDescriptionChangedEvent]int32 + serverDescriptionChangedEventsCount map[serverDescriptionChangedEventInfo]int32 eventsCountLock sync.RWMutex serverDescriptionChangedEventsCountLock sync.RWMutex @@ -89,7 +89,7 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp observedEvents: make(map[monitoringEventType]struct{}), storedEvents: make(map[monitoringEventType][]string), eventsCount: make(map[monitoringEventType]int32), - serverDescriptionChangedEventsCount: make(map[serverDescriptionChangedEvent]int32), + serverDescriptionChangedEventsCount: make(map[serverDescriptionChangedEventInfo]int32), entityMap: em, observeSensitiveCommands: entityOptions.ObserveSensitiveCommands, } @@ -302,7 +302,7 @@ func (c *clientEntity) addEventsCount(eventType monitoringEventType) { c.eventsCount[eventType]++ } -func (c *clientEntity) addServerDescriptionChangedEventCount(evt serverDescriptionChangedEvent) { +func (c *clientEntity) addServerDescriptionChangedEventCount(evt serverDescriptionChangedEventInfo) { c.serverDescriptionChangedEventsCountLock.Lock() defer c.serverDescriptionChangedEventsCountLock.Unlock() @@ -316,7 +316,7 @@ func (c *clientEntity) getEventCount(eventType monitoringEventType) int32 { return c.eventsCount[eventType] } -func (c *clientEntity) getServerDescriptionChangedEventCount(evt serverDescriptionChangedEvent) int32 { +func (c *clientEntity) getServerDescriptionChangedEventCount(evt serverDescriptionChangedEventInfo) int32 { c.serverDescriptionChangedEventsCountLock.Lock() defer c.serverDescriptionChangedEventsCountLock.Unlock() @@ -475,15 +475,15 @@ func (c *clientEntity) processServerDescriptionChangedEvent(evt *event.ServerDes return } - if _, ok := c.observedEvents[serverDescriptionChangedInfo]; ok { + if _, ok := c.observedEvents[serverDescriptionChangedEvent]; ok { c.serverDescriptionChanged = append(c.serverDescriptionChanged, evt) } // Record object-specific unified spec test data on an event. - c.addServerDescriptionChangedEventCount(*newServerDescriptionChangedEvent(evt)) + c.addServerDescriptionChangedEventCount(*newServerDescriptionChangedEventInfo(evt)) // Record the event generally. - c.addEventsCount(serverDescriptionChangedInfo) + c.addEventsCount(serverDescriptionChangedEvent) } func (c *clientEntity) processServerHeartbeatFailedEvent(evt *event.ServerHeartbeatFailedEvent) { diff --git a/mongo/integration/unified/event.go b/mongo/integration/unified/event.go index 799be0361d..a6da6bf327 100644 --- a/mongo/integration/unified/event.go +++ b/mongo/integration/unified/event.go @@ -31,7 +31,7 @@ const ( connectionCheckOutFailedEvent monitoringEventType = "ConnectionCheckOutFailedEvent" connectionCheckedOutEvent monitoringEventType = "ConnectionCheckedOutEvent" connectionCheckedInEvent monitoringEventType = "ConnectionCheckedInEvent" - serverDescriptionChangedInfo monitoringEventType = "ServerDescriptionChangedEvent" + serverDescriptionChangedEvent monitoringEventType = "ServerDescriptionChangedEvent" serverHeartbeatFailedEvent monitoringEventType = "ServerHeartbeatFailedEvent" serverHeartbeatStartedEvent monitoringEventType = "ServerHeartbeatStartedEvent" serverHeartbeatSucceededEvent monitoringEventType = "ServerHeartbeatSucceededEvent" @@ -69,7 +69,7 @@ func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool) case "connectioncheckedinevent": return connectionCheckedInEvent, true case "serverdescriptionchangedevent": - return serverDescriptionChangedInfo, true + return serverDescriptionChangedEvent, true case "serverheartbeatfailedevent": return serverHeartbeatFailedEvent, true case "serverheartbeatstartedevent": @@ -119,9 +119,9 @@ type serverDescription struct { Type string } -// serverDescriptionChangedEvent represents an event generated when the server +// serverDescriptionChangedEventInfo represents an event generated when the server // description changes. -type serverDescriptionChangedEvent struct { +type serverDescriptionChangedEventInfo struct { // NewDescription corresponds to the server description as it was after // the change that triggered this event. NewDescription serverDescription @@ -131,10 +131,10 @@ type serverDescriptionChangedEvent struct { PreviousDescription serverDescription } -// newServerDescriptionChangedEvent returns a new serverDescriptionChangedEvent +// newServerDescriptionChangedEventInfo returns a new serverDescriptionChangedEvent // instance for the given event. -func newServerDescriptionChangedEvent(evt *event.ServerDescriptionChangedEvent) *serverDescriptionChangedEvent { - return &serverDescriptionChangedEvent{ +func newServerDescriptionChangedEventInfo(evt *event.ServerDescriptionChangedEvent) *serverDescriptionChangedEventInfo { + return &serverDescriptionChangedEventInfo{ NewDescription: serverDescription{ Type: evt.NewDescription.Kind.String(), }, @@ -146,7 +146,7 @@ func newServerDescriptionChangedEvent(evt *event.ServerDescriptionChangedEvent) // UnmarshalBSON unmarshals the event from BSON, used when trying to create the // expected event from a unified spec test. -func (evt *serverDescriptionChangedEvent) UnmarshalBSON(data []byte) error { +func (evt *serverDescriptionChangedEventInfo) UnmarshalBSON(data []byte) error { if len(data) == 0 { return nil } diff --git a/mongo/integration/unified/testrunner_operation.go b/mongo/integration/unified/testrunner_operation.go index 5a62142ee7..474c01c88a 100644 --- a/mongo/integration/unified/testrunner_operation.go +++ b/mongo/integration/unified/testrunner_operation.go @@ -302,10 +302,10 @@ func getServerDescriptionChangedEventCount(client *clientEntity, raw bson.Raw) i // If the document has no values, then we assume that the UST only // intends to check that the event happened. if values, _ := raw.Values(); len(values) == 0 { - return client.getEventCount(serverDescriptionChangedInfo) + return client.getEventCount(serverDescriptionChangedEvent) } - var expectedEvt serverDescriptionChangedEvent + var expectedEvt serverDescriptionChangedEventInfo if err := bson.Unmarshal(raw, &expectedEvt); err != nil { return 0 } @@ -323,7 +323,7 @@ func (args waitForEventArguments) eventCompleted(client *clientEntity) bool { } switch eventType { - case serverDescriptionChangedInfo: + case serverDescriptionChangedEvent: if getServerDescriptionChangedEventCount(client, eventDoc) < args.Count { return false } diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index 5d2102646b..b0683021ee 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -109,8 +109,10 @@ type Topology struct { id primitive.ObjectID } -var _ driver.Deployment = &Topology{} -var _ driver.Subscriber = &Topology{} +var ( + _ driver.Deployment = &Topology{} + _ driver.Subscriber = &Topology{} +) type serverSelectionState struct { selector description.ServerSelector @@ -174,12 +176,12 @@ func logTopologyMessage(topo *Topology, msg string, keysAndValues ...interface{} }, keysAndValues...)...) } -func mustLogServerSelectionMessage(topo *Topology, level logger.Level) bool { +func mustLogServerSelection(topo *Topology, level logger.Level) bool { return topo.cfg.logger != nil && topo.cfg.logger.LevelComponentEnabled( level, logger.ComponentServerSelection) } -func logServerSelectionMessage( +func logServerSelection( ctx context.Context, level logger.Level, topo *Topology, @@ -208,7 +210,7 @@ func logServerSelectionMessage( }, keysAndValues...)...) } -func logServerSelectionSucceededMessage( +func logServerSelectionSucceeded( ctx context.Context, topo *Topology, srvSelector description.ServerSelector, @@ -222,18 +224,18 @@ func logServerSelectionSucceededMessage( portInt64, _ := strconv.ParseInt(port, 10, 32) - logServerSelectionMessage(ctx, logger.LevelDebug, topo, logger.ServerSelectionSucceeded, srvSelector, + logServerSelection(ctx, logger.LevelDebug, topo, logger.ServerSelectionSucceeded, srvSelector, logger.KeyServerHost, host, logger.KeyServerPort, portInt64) } -func logServerSelectionFailedMessage( +func logServerSelectionFailed( ctx context.Context, topo *Topology, srvSelector description.ServerSelector, err error, ) { - logServerSelectionMessage(ctx, logger.LevelDebug, topo, logger.ServerSelectionFailed, srvSelector, + logServerSelection(ctx, logger.LevelDebug, topo, logger.ServerSelectionFailed, srvSelector, logger.KeyFailure, err.Error()) } @@ -464,8 +466,8 @@ func (t *Topology) RequestImmediateCheck() { // parent context is done. func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) { if atomic.LoadInt64(&t.state) != topologyConnected { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionFailedMessage(ctx, t, ss, ErrTopologyClosed) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionFailed(ctx, t, ss, ErrTopologyClosed) } return nil, ErrTopologyClosed @@ -489,8 +491,8 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect var selectErr error if !doneOnce { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionMessage(ctx, logger.LevelDebug, t, logger.ServerSelectionStarted, ss) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelection(ctx, logger.LevelDebug, t, logger.ServerSelectionStarted, ss) } // for the first pass, select a server from the current description. @@ -504,8 +506,8 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect var err error sub, err = t.Subscribe() if err != nil { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionFailedMessage(ctx, t, ss, err) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionFailed(ctx, t, ss, err) } return nil, err @@ -516,8 +518,8 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, selectionState) } if selectErr != nil { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionFailedMessage(ctx, t, ss, selectErr) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionFailed(ctx, t, ss, selectErr) } return nil, selectErr @@ -525,11 +527,11 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect if len(suitable) == 0 { // try again if there are no servers available - if mustLogServerSelectionMessage(t, logger.LevelInfo) { + if mustLogServerSelection(t, logger.LevelInfo) { elapsed := time.Since(startTime) remainingTimeMS := t.cfg.ServerSelectionTimeout - elapsed - logServerSelectionMessage(ctx, logger.LevelInfo, t, logger.ServerSelectionWaiting, ss, + logServerSelection(ctx, logger.LevelInfo, t, logger.ServerSelectionWaiting, ss, logger.KeyRemainingTimeMS, remainingTimeMS.Milliseconds()) } @@ -541,8 +543,8 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect if len(suitable) == 1 { server, err := t.FindServer(suitable[0]) if err != nil { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionFailedMessage(ctx, t, ss, err) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionFailed(ctx, t, ss, err) } return nil, err @@ -551,8 +553,8 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect continue } - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionSucceededMessage(ctx, t, ss, server) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionSucceeded(ctx, t, ss, server) } return server, nil @@ -563,16 +565,16 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect desc1, desc2 := pick2(suitable) server1, err := t.FindServer(desc1) if err != nil { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionFailedMessage(ctx, t, ss, err) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionFailed(ctx, t, ss, err) } return nil, err } server2, err := t.FindServer(desc2) if err != nil { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionFailedMessage(ctx, t, ss, err) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionFailed(ctx, t, ss, err) } return nil, err @@ -588,14 +590,14 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect } if server1 != nil { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionSucceededMessage(ctx, t, ss, server1) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionSucceeded(ctx, t, ss, server1) } return server1, nil } - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionSucceededMessage(ctx, t, ss, server2) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionSucceeded(ctx, t, ss, server2) } return server2, nil @@ -605,15 +607,15 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect // We use in-use connections as an analog for in-progress operations because they are almost // always the same value for a given server. if server1.OperationCount() < server2.OperationCount() { - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionSucceededMessage(ctx, t, ss, server1) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionSucceeded(ctx, t, ss, server1) } return server1, nil } - if mustLogServerSelectionMessage(t, logger.LevelDebug) { - logServerSelectionSucceededMessage(ctx, t, ss, server2) + if mustLogServerSelection(t, logger.LevelDebug) { + logServerSelectionSucceeded(ctx, t, ss, server2) } return server2, nil } @@ -824,7 +826,7 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { t.fsm.addServer(addr) } - //store new description + // store new description newDesc := description.Topology{ Kind: t.fsm.Kind, Servers: t.fsm.Servers,