diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf02cf0b9..158b000c9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
 ### Added
+- Add a LatencyAwarePolicy that prioritizes hosts with a lower average latency
 
 ### Changed
diff --git a/policies.go b/policies.go
index 1157da87b..f206f4199 100644
--- a/policies.go
+++ b/policies.go
@@ -313,7 +313,7 @@ type HostSelectionPolicy interface {
 // selection policy.
 type SelectedHost interface {
 	Info() *HostInfo
-	Mark(error)
+	Mark(error, uint64)
 }
 
 type selectedHost HostInfo
@@ -322,7 +322,7 @@ func (host *selectedHost) Info() *HostInfo {
 	return (*HostInfo)(host)
 }
 
-func (host *selectedHost) Mark(err error) {}
+func (host *selectedHost) Mark(err error, latency uint64) {}
 
 // NextHost is an iteration function over picked hosts
 type NextHost func() SelectedHost
@@ -817,7 +817,7 @@ func (host selectedHostPoolHost) Info() *HostInfo {
 	return host.info
 }
 
-func (host selectedHostPoolHost) Mark(err error) {
+func (host selectedHostPoolHost) Mark(err error, latency uint64) {
 	ip := host.info.ConnectAddress().String()
 
 	host.policy.mu.RLock()
@@ -981,6 +981,284 @@ func (d *rackAwareRR) Pick(q ExecutableQuery) NextHost {
 	return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get())
 }
 
+const (
+	// minMeasures is the minimum number of samples required before a host's average latency is trusted.
+	minMeasures = 50
+	// exclusionThreshold is the multiple of the best average latency above which a host is deprioritized.
+	exclusionThreshold = 2
+)
+
+var (
+	// scaleInNano (100ms in nanoseconds) controls how quickly older samples decay from the average.
+	scaleInNano = uint64(time.Millisecond.Nanoseconds() * 100)
+	// retryPeriodNanos is how long a host may go without fresh samples before it is handed to queries again.
+	retryPeriodNanos = uint64(time.Second.Nanoseconds() * 10)
+	// updateMinAvgRate is how often the cluster-wide minimum average latency is recomputed.
+	updateMinAvgRate = time.Millisecond * 100
+
+	// hostLatencyThresholdToAccount is the number of warm-up samples ignored before averaging starts.
+	hostLatencyThresholdToAccount = uint64(30 * minMeasures / 100)
+)
+
+// LatencyAwarePolicy is a host selection policy which prioritizes and returns hosts with smaller latencies.
+// It collects the latencies of the queries to each Cassandra node and maintains a per-node average latency score.
+// Nodes that are slower than the best performing node by more than a configurable threshold will be deprioritized
+// in the query plan.
+func LatencyAwarePolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
+	if fallback == nil {
+		panic("LatencyAwarePolicy should have a fallback HostSelectionPolicy")
+	}
+	return &latencyAwarePolicy{latencies: make(map[string]*hostLatencyStat), fallback: fallback,
+		stopUpdateMinAvgChan: make(chan struct{})}
+}
+
+type latencyAwarePolicy struct {
+	latencies map[string]*hostLatencyStat
+	mu        sync.RWMutex
+
+	minAvg uint64
+	maMu   sync.RWMutex
+
+	fallback HostSelectionPolicy
+
+	startUpdateMinAvgLatencyOnce sync.Once
+	stopUpdateMinAvgChan         chan struct{}
+}
+
+func (l *latencyAwarePolicy) IsLocal(host *HostInfo) bool { return l.fallback.IsLocal(host) }
+func (l *latencyAwarePolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
+	l.fallback.KeyspaceChanged(update)
+}
+func (l *latencyAwarePolicy) Init(session *Session) {
+	l.fallback.Init(session)
+}
+func (l *latencyAwarePolicy) SetPartitioner(partitioner string) {
+	l.fallback.SetPartitioner(partitioner)
+}
+
+// hostLatencyStat keeps an exponentially decaying average of the query latencies observed for one host.
+type hostLatencyStat struct {
+	mu sync.Mutex
+
+	thresholdToAccount uint64
+	scale              uint64
+	timestamp          int64
+	average            uint64
+	numMeasure         uint64
+
+	host *HostInfo
+}
+
+func (h *hostLatencyStat) addNewLatency(latency uint64) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	curTimestamp := time.Now().UnixNano()
+
+	h.numMeasure++
+	// Ignore the first warm-up samples before starting to average.
+	if h.numMeasure < h.thresholdToAccount {
+		h.timestamp = curTimestamp
+		h.average = 0
+		return
+	}
+
+	if h.average == 0 {
+		h.timestamp = curTimestamp
+		h.average = latency
+		return
+	}
+
+	delay := curTimestamp - h.timestamp
+	// A non-positive delay should be rare; just discard the new sample.
+	if delay <= 0 {
+		return
+	}
+
+	h.timestamp = curTimestamp
+	// Exponentially decaying average: the longer the gap since the previous sample,
+	// the less weight the previous average keeps.
+	scaledDelay := float64(delay) / float64(h.scale)
+	prevWeight := math.Log(scaledDelay+1) / scaledDelay
+	newAverage := (1.0-prevWeight)*float64(latency) + prevWeight*float64(h.average)
+	h.average = uint64(newAverage)
+}
+
+func (l *latencyAwarePolicy) AddHost(host *HostInfo) {
+	ip := host.ConnectAddress().String()
+
+	l.mu.RLock()
+	if h, ok := l.latencies[ip]; ok && h != nil {
+		l.mu.RUnlock()
+		return
+	}
+	l.mu.RUnlock()
+
+	l.mu.Lock()
+	l.latencies[ip] = &hostLatencyStat{thresholdToAccount: hostLatencyThresholdToAccount,
+		scale: scaleInNano, timestamp: time.Now().UnixNano(), average: 0, host: host}
+	l.mu.Unlock()
+
+	l.fallback.AddHost(host)
+}
+
+func (l *latencyAwarePolicy) RemoveHost(host *HostInfo) {
+	ip := host.ConnectAddress().String()
+
+	l.mu.RLock()
+	if _, ok := l.latencies[ip]; !ok {
+		l.mu.RUnlock()
+		return
+	}
+	l.mu.RUnlock()
+
+	l.mu.Lock()
+	delete(l.latencies, ip)
+	l.mu.Unlock()
+
+	l.fallback.RemoveHost(host)
+}
+
+func (l *latencyAwarePolicy) HostUp(host *HostInfo) {
+	l.AddHost(host)
+}
+
+func (l *latencyAwarePolicy) HostDown(host *HostInfo) {
+	l.RemoveHost(host)
+}
+
+func (l *latencyAwarePolicy) Pick(qry ExecutableQuery) NextHost {
+	if len(l.latencies) == 0 {
+		return l.fallback.Pick(qry)
+	}
+	// Start the background goroutine that periodically refreshes minAvg.
+	l.startUpdateMinAvgLatencyOnce.Do(func() {
+		go l.updateMinAvgLatency()
+	})
+
+	fallbackIter := l.fallback.Pick(qry)
+
+	return func() SelectedHost {
+		l.maMu.RLock()
+		minAvg := l.minAvg
+		l.maMu.RUnlock()
+
+		l.mu.RLock()
+		defer l.mu.RUnlock()
+
+		skipped := make([]*HostInfo, 0, len(l.latencies))
+		now := uint64(time.Now().UnixNano())
+
+		for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
+			ip := fallbackHost.Info().ConnectAddress().String()
+
+			stat, ok := l.latencies[ip]
+			if !ok {
+				// No latency data for this host yet, defer to the fallback's choice.
+				return selectedLatencyAwareHost{policy: l, info: fallbackHost.Info()}
+			}
+
+			stat.mu.Lock()
+			average, numMeasure, timestamp := stat.average, stat.numMeasure, stat.timestamp
+			stat.mu.Unlock()
+
+			// Use the host as-is when there is not enough data to judge it, or when its
+			// stats are stale enough that it should be retried.
+			if minAvg == 0 || numMeasure < minMeasures || now-uint64(timestamp) > retryPeriodNanos {
+				return selectedLatencyAwareHost{policy: l, info: stat.host}
+			}
+
+			// Accept the host if its average latency is within exclusionThreshold times
+			// the best average seen across the cluster, otherwise deprioritize it.
+			if average > 0 && average <= minAvg*exclusionThreshold {
+				return selectedLatencyAwareHost{policy: l, info: stat.host}
+			}
+			skipped = append(skipped, stat.host)
+		}
+
+		if len(skipped) == 0 {
+			return nil
+		}
+		// Every candidate was deprioritized; fall back to a random one of them.
+		return selectedLatencyAwareHost{policy: l, info: skipped[rand.Intn(len(skipped))]}
+	}
+}
+
+func (l *latencyAwarePolicy) updateHostLatency(ip string, latency uint64) {
+	l.mu.RLock()
+	defer l.mu.RUnlock()
+
+	if _, ok := l.latencies[ip]; !ok {
+		return
+	}
+
+	l.latencies[ip].addNewLatency(latency)
+}
+
+func (l *latencyAwarePolicy) updateMinAvgLatency() {
+	// Periodically recompute the smallest average latency across all hosts.
+	updateTicker := time.NewTicker(updateMinAvgRate)
+	defer updateTicker.Stop()
+
+	for {
+		select {
+		case <-updateTicker.C:
+			maxUint64 := uint64(math.MaxUint64)
+			minLatency := maxUint64
+			l.mu.RLock()
+			now := time.Now().UnixNano()
+			for _, stat := range l.latencies {
+				if stat == nil {
+					continue
+				}
+				stat.mu.Lock()
+				average, numMeasure, timestamp := stat.average, stat.numMeasure, stat.timestamp
+				stat.mu.Unlock()
+				// Only consider hosts with enough recent measurements.
+				if average > 0 && numMeasure >= minMeasures && uint64(now-timestamp) <= retryPeriodNanos &&
+					average < minLatency {
+					minLatency = average
+				}
+			}
+			l.mu.RUnlock()
+
+			if minLatency != maxUint64 {
+				l.maMu.Lock()
+				l.minAvg = minLatency
+				l.maMu.Unlock()
+			}
+		case <-l.stopUpdateMinAvgChan:
+			return
+		}
+	}
+}
+
+func (l *latencyAwarePolicy) stopUpdateMinAvgLatency() {
+	// Closing the channel is enough to stop the updater goroutine and, unlike a send,
+	// it does not block if the goroutine was never started.
+	close(l.stopUpdateMinAvgChan)
+}
+
+// selectedLatencyAwareHost is a host returned by the latencyAwarePolicy and
+// implements the SelectedHost interface.
+type selectedLatencyAwareHost struct {
+	policy *latencyAwarePolicy
+	info   *HostInfo
+}
+
+func (host selectedLatencyAwareHost) Info() *HostInfo {
+	return host.info
+}
+
+func (host selectedLatencyAwareHost) Mark(err error, latency uint64) {
+	if !host.shouldConsiderNewLatency(err) {
+		return
+	}
+	// updateHostLatency is a no-op if the host was removed between pick and mark.
+	host.policy.updateHostLatency(host.info.ConnectAddress().String(), latency)
+}
+
+func (host selectedLatencyAwareHost) shouldConsiderNewLatency(err error) bool {
+	if err == nil {
+		return true
+	}
+	// Server-side error responses (server error, overloaded, bootstrapping, unprepared,
+	// invalid) are not treated as latency measurements.
+	var errFrame errorFrame
+	if errors.As(err, &errFrame) {
+		if errFrame.code == ErrCodeServer || errFrame.code == ErrCodeOverloaded ||
+			errFrame.code == ErrCodeBootstrapping || errFrame.code == ErrCodeUnprepared ||
+			errFrame.code == ErrCodeInvalid {
+			return false
+		}
+	}
+	return true
+}
+
 // ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
 // each host connects during session initialization, the Ready method will be
 // called. If you only need a single Host to be up you can wrap a
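For reviewers, here is a small standalone sketch (illustrative only, not part of the patch) of the weighting that `addNewLatency` applies above: the previous average keeps a weight of `ln(scaledDelay+1)/scaledDelay`, where `scaledDelay` is the gap since the last sample divided by the 100 ms scale, so a host's old average decays faster the longer it goes without new samples.

```go
package main

import (
	"fmt"
	"math"
	"time"
)

// prevWeight mirrors the weighting in hostLatencyStat.addNewLatency: it is the
// share of the new average that comes from the previous average when a sample
// arrives `gap` after the last one, using the patch's 100ms scale.
func prevWeight(gap, scale time.Duration) float64 {
	scaledDelay := float64(gap) / float64(scale)
	return math.Log(scaledDelay+1) / scaledDelay
}

func main() {
	scale := 100 * time.Millisecond
	for _, gap := range []time.Duration{
		10 * time.Millisecond, 100 * time.Millisecond, time.Second, 10 * time.Second,
	} {
		w := prevWeight(gap, scale)
		fmt.Printf("gap=%-6v old-average weight=%.2f new-sample weight=%.2f\n", gap, w, 1-w)
	}
}
```

With a 100 ms gap the old average keeps roughly 69% of its weight; after a 1 s gap it keeps only about 24%.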
diff --git a/policies_test.go b/policies_test.go
index 231c2a7e2..bc6c848e5 100644
--- a/policies_test.go
+++ b/policies_test.go
@@ -160,25 +160,25 @@ func TestHostPolicy_HostPool(t *testing.T) {
 	if actualA.Info().HostID() != "0" {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actualA.Info().HostID())
 	}
-	actualA.Mark(nil)
+	actualA.Mark(nil, 0)
 
 	actualB := iter()
 	if actualB.Info().HostID() != "1" {
 		t.Errorf("Expected hosts[1] but was hosts[%s]", actualB.Info().HostID())
 	}
-	actualB.Mark(fmt.Errorf("error"))
+	actualB.Mark(fmt.Errorf("error"), 0)
 
 	actualC := iter()
 	if actualC.Info().HostID() != "0" {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actualC.Info().HostID())
 	}
-	actualC.Mark(nil)
+	actualC.Mark(nil, 0)
 
 	actualD := iter()
 	if actualD.Info().HostID() != "0" {
 		t.Errorf("Expected hosts[0] but was hosts[%s]", actualD.Info().HostID())
 	}
-	actualD.Mark(nil)
+	actualD.Mark(nil, 0)
 }
 
 func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) {
@@ -847,3 +847,119 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) {
 	expectHosts(t, "non-local DC", iter, "0", "1", "4", "5", "8", "9")
 	expectNoMoreHosts(t, iter)
 }
+
+func TestHostPolicy_LatencyAware(t *testing.T) {
+	lap := LatencyAwarePolicy(RoundRobinHostPolicy())
+	if lap == nil {
+		t.Errorf("Expected a non-nil policy when creating the LatencyAwarePolicy")
+		return
+	}
+
+	hosts := [...]*HostInfo{
+		{hostId: "0", connectAddress: net.ParseIP("10.0.0.1")},
+		{hostId: "1", connectAddress: net.ParseIP("10.0.0.2")},
+		{hostId: "2", connectAddress: net.ParseIP("10.0.0.3")},
+		{hostId: "3", connectAddress: net.ParseIP("10.0.0.4")},
+	}
+
+	for _, host := range hosts {
+		lap.AddHost(host)
+	}
+
+	lap.SetPartitioner("OrderedPartitioner")
+	iterA := lap.Pick(nil)
+	ha := iterA()
+	if ha == nil {
+		t.Errorf("Expected a non-nil host when picking from the query plan")
+	}
+
+	targetIP := "10.0.0.2"
+	targetIPLatency := uint64(1000)
+	hostsLatencies := map[string]uint64{"10.0.0.1": 3100, targetIP: targetIPLatency, "10.0.0.3": 3000, "10.0.0.4": 4000}
+	if l, ok := lap.(*latencyAwarePolicy); ok {
+		for h, stat := range l.latencies {
+			stat.mu.Lock()
+			stat.numMeasure = minMeasures + 1
+			stat.timestamp = time.Now().UnixNano() - 10
+			stat.mu.Unlock()
+			l.updateHostLatency(h, hostsLatencies[h])
+			if stat.average == 0 {
+				t.Errorf("Expected host [%s] average latency to be positive after the 1st update", h)
+			}
+			l.updateHostLatency(h, hostsLatencies[h])
+			if stat.average == 0 {
+				t.Errorf("Expected host [%s] average latency to be positive after the 2nd update", h)
+			}
+		}
+		l.maMu.Lock()
+		l.minAvg = targetIPLatency
+		l.maMu.Unlock()
+		iterB := lap.Pick(nil)
+		h := iterB()
+		hIP := h.Info().ConnectAddress().String()
+		if hIP != targetIP {
+			t.Errorf("Expected the host with the smallest latency but got [%s]", hIP)
+		}
+
+		unstableHost := hosts[0]
+		unstableHostIP := unstableHost.ConnectAddress().String()
+		lap.HostDown(unstableHost)
+		if _, ok := l.latencies[unstableHostIP]; ok {
+			t.Errorf("Expected the host %s to be gone after going down", unstableHostIP)
+		}
+		lap.HostUp(unstableHost)
+		if _, ok := l.latencies[unstableHostIP]; !ok {
+			t.Errorf("Expected the host %s to be back after coming up", unstableHostIP)
+		}
+		// wait for at least one min avg update before stopping the updater
+		time.Sleep(updateMinAvgRate * 2)
+		l.stopUpdateMinAvgLatency()
+	} else {
+		t.Errorf("Expected the policy to be of type *latencyAwarePolicy")
+	}
+
+	for _, host := range hosts {
+		lap.RemoveHost(host)
+	}
+}
+
+func TestHostPolicy_LatencyAware_Panic(t *testing.T) {
+	defer func() {
+		if r := recover(); r == nil {
+			t.Errorf("Expected a panic when creating the LatencyAwarePolicy without a fallback policy")
+		}
+	}()
+	_ = LatencyAwarePolicy(nil)
+}
+
+func TestHostPolicy_LatencyAware_SelectedLatencyAwareHost(t *testing.T) {
+	hostIP := "10.0.0.1"
+	hostInfo := &HostInfo{hostId: "0", connectAddress: net.ParseIP(hostIP)}
+	hostLatencies := map[string]*hostLatencyStat{hostIP: {host: hostInfo}}
+	lap := &latencyAwarePolicy{latencies: hostLatencies, fallback: RoundRobinHostPolicy(),
+		stopUpdateMinAvgChan: make(chan struct{})}
+	h := selectedLatencyAwareHost{
+		policy: lap,
+		info:   hostInfo}
+	if h.Info() == nil {
+		t.Errorf("Expected the host info to be non-nil")
+	}
+
+	if ok := h.shouldConsiderNewLatency(nil); !ok {
+		t.Errorf("Expected shouldConsiderNewLatency to return true when there is no error")
+	}
+	h.Mark(errorFrame{code: ErrCodeServer}, 1000)
+	if lap.latencies[hostIP].timestamp != 0 {
+		t.Errorf("Expected the timestamp to stay zero for ErrCodeServer")
+	}
+	h.Mark(errorFrame{code: ErrCodeReadTimeout}, 1000)
+	if lap.latencies[hostIP].timestamp == 0 {
+		t.Errorf("Expected the timestamp to be updated for ErrCodeReadTimeout")
+	}
+
+	lap.RemoveHost(hostInfo)
+	h.Mark(errorFrame{code: ErrCodeServer}, 1000)
+	if _, ok := lap.latencies[hostIP]; ok {
+		t.Errorf("Expected the host to be gone after removal")
+	}
+}
diff --git a/query_executor.go b/query_executor.go
index fb68b07f2..71a3a281c 100644
--- a/query_executor.go
+++ b/query_executor.go
@@ -151,17 +151,20 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne
 			continue
 		}
 
+		start := time.Now()
 		iter = q.attemptQuery(ctx, qry, conn)
+		latency := uint64(time.Since(start).Nanoseconds())
 		iter.host = selectedHost.Info()
 		// Update host
 		switch iter.err {
 		case context.Canceled, context.DeadlineExceeded, ErrNotFound:
 			// those errors represents logical errors, they should not count
 			// toward removing a node from the pool
-			selectedHost.Mark(nil)
+			selectedHost.Mark(nil, latency)
 			return iter
 		default:
-			selectedHost.Mark(iter.err)
+			selectedHost.Mark(iter.err, latency)
 		}
 
 		// Exit if the query was successful
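For reviewers, a minimal usage sketch of how the new policy might be composed with an existing fallback, assuming the upstream `github.com/gocql/gocql` import path and the existing `ClusterConfig.PoolConfig.HostSelectionPolicy` hook; the keyspace and contact points are hypothetical and this is not part of the patch.

```go
package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	cluster := gocql.NewCluster("10.0.0.1", "10.0.0.2", "10.0.0.3")
	cluster.Keyspace = "example" // hypothetical keyspace

	// Wrap a token-aware, round-robin plan with the latency-aware policy: the
	// fallback produces the candidate order, and the wrapper deprioritizes hosts
	// whose average latency exceeds exclusionThreshold times the cluster minimum.
	cluster.PoolConfig.HostSelectionPolicy = gocql.LatencyAwarePolicy(
		gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()),
	)

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}
```

Because the constructor panics on a nil fallback, the policy is always a wrapper that reorders another policy's query plan rather than a standalone policy.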