diff --git a/CHANGELOG.md b/CHANGELOG.md index 53a23358e..f9cb11448 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,7 @@ * [CHANGE] Changed `ShouldLog()` function signature in `middleware.OptionalLogging` interface to `ShouldLog(context.Context) (bool, string)`: the returned `string` contains an optional reason. When reason is valued, `GRPCServerLog` adds `()` suffix to the error. #514 * [CHANGE] Cache: Remove superfluous `cache.RemoteCacheClient` interface and unify all caches using the `cache.Cache` interface. #520 * [CHANGE] Updated the minimum required Go version to 1.21. #540 +* [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538 * [CHANGE] memberlist: Metric `memberlist_client_messages_in_broadcast_queue` is now split into `queue="local"` and `queue="gossip"` values. #539 * [CHANGE] memberlist: Failure to fast-join a cluster via contacting a node is now logged at `info` instead of `debug`. #585 * [CHANGE] `Service.AddListener` and `Manager.AddListener` now return function for stopping the listener. #564 @@ -232,7 +233,9 @@ * [ENHANCEMENT] grpcclient: Support custom gRPC compressors. #583 * [ENHANCEMENT] Adapt `metrics.SendSumOfGaugesPerTenant` to use `metrics.MetricOption`. #584 * [ENHANCEMENT] Cache: Add `.Add()` and `.Set()` methods to cache clients. #591 -* [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538 +* [ENHANCEMENT] Cache: Add `.Advance()` methods to mock cache clients for easier testing of TTLs. #601 +* [ENHANCEMENT] Memberlist: Add concurrency to the transport's WriteTo method. #525 +* [ENHANCEMENT] Memberlist: Notifications can now be processed once per interval specified by `-memberlist.notify-interval` to reduce notify storms in large clusters. #592 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 * [BUGFIX] Ring: `ring_member_ownership_percent` and `ring_tokens_owned` metrics are not updated on scale down. 
#109 diff --git a/cache/mock.go b/cache/mock.go index 4a5dae962..15d95419a 100644 --- a/cache/mock.go +++ b/cache/mock.go @@ -19,10 +19,11 @@ var ( type MockCache struct { mu sync.Mutex cache map[string]Item + now time.Time } func NewMockCache() *MockCache { - c := &MockCache{} + c := &MockCache{now: time.Now()} c.Flush() return c } @@ -30,14 +31,14 @@ func NewMockCache() *MockCache { func (m *MockCache) SetAsync(key string, value []byte, ttl time.Duration) { m.mu.Lock() defer m.mu.Unlock() - m.cache[key] = Item{Data: value, ExpiresAt: time.Now().Add(ttl)} + m.cache[key] = Item{Data: value, ExpiresAt: m.now.Add(ttl)} } func (m *MockCache) SetMultiAsync(data map[string][]byte, ttl time.Duration) { m.mu.Lock() defer m.mu.Unlock() - exp := time.Now().Add(ttl) + exp := m.now.Add(ttl) for key, val := range data { m.cache[key] = Item{Data: val, ExpiresAt: exp} } @@ -46,7 +47,7 @@ func (m *MockCache) SetMultiAsync(data map[string][]byte, ttl time.Duration) { func (m *MockCache) Set(_ context.Context, key string, value []byte, ttl time.Duration) error { m.mu.Lock() defer m.mu.Unlock() - m.cache[key] = Item{Data: value, ExpiresAt: time.Now().Add(ttl)} + m.cache[key] = Item{Data: value, ExpiresAt: m.now.Add(ttl)} return nil } @@ -54,11 +55,11 @@ func (m *MockCache) Add(_ context.Context, key string, value []byte, ttl time.Du m.mu.Lock() defer m.mu.Unlock() - if _, ok := m.cache[key]; ok { + if i, ok := m.cache[key]; ok && i.ExpiresAt.After(m.now) { return ErrNotStored } - m.cache[key] = Item{Data: value, ExpiresAt: time.Now().Add(ttl)} + m.cache[key] = Item{Data: value, ExpiresAt: m.now.Add(ttl)} return nil } @@ -68,7 +69,7 @@ func (m *MockCache) GetMulti(_ context.Context, keys []string, _ ...Option) map[ found := make(map[string][]byte, len(keys)) - now := time.Now() + now := m.now for _, k := range keys { v, ok := m.cache[k] if ok && now.Before(v.ExpiresAt) { @@ -107,6 +108,7 @@ func (m *MockCache) Delete(_ context.Context, key string) error { return nil } +// Flush removes all entries from the cache func (m *MockCache) Flush() { m.mu.Lock() defer m.mu.Unlock() @@ -114,6 +116,14 @@ func (m *MockCache) Flush() { m.cache = map[string]Item{} } +// Advance changes "now" by the given duration +func (m *MockCache) Advance(d time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + + m.now = m.now.Add(d) +} + // InstrumentedMockCache is a mocked cache implementation which also tracks the number // of times its functions are called. type InstrumentedMockCache struct { @@ -172,10 +182,16 @@ func (m *InstrumentedMockCache) GetItems() map[string]Item { return m.cache.GetItems() } +// Flush removes all entries from the cache func (m *InstrumentedMockCache) Flush() { m.cache.Flush() } +// Advance changes "now" by the given duration +func (m *InstrumentedMockCache) Advance(d time.Duration) { + m.cache.Advance(d) +} + func (m *InstrumentedMockCache) CountStoreCalls() int { return int(m.storeCount.Load()) } diff --git a/concurrency/worker.go b/concurrency/worker.go index f40f03348..10a59e600 100644 --- a/concurrency/worker.go +++ b/concurrency/worker.go @@ -5,12 +5,18 @@ package concurrency // If all workers are busy, Go() will spawn a new goroutine to run the workload. 
func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool { p := &ReusableGoroutinesPool{ - jobs: make(chan func()), + jobs: make(chan func()), + closed: make(chan struct{}), } for i := 0; i < size; i++ { go func() { - for f := range p.jobs { - f() + for { + select { + case f := <-p.jobs: + f() + case <-p.closed: + return + } } }() } @@ -18,7 +24,8 @@ func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool { } type ReusableGoroutinesPool struct { - jobs chan func() + jobs chan func() + closed chan struct{} } // Go will run the given function in a worker of the pool. @@ -32,7 +39,9 @@ func (p *ReusableGoroutinesPool) Go(f func()) { } // Close stops the workers of the pool. -// No new Do() calls should be performed after calling Close(). +// No new Go() calls should be performed after calling Close(). // Close does NOT wait for all jobs to finish, it is the caller's responsibility to ensure that in the provided workloads. // Close is intended to be used in tests to ensure that no goroutines are leaked. -func (p *ReusableGoroutinesPool) Close() { close(p.jobs) } +func (p *ReusableGoroutinesPool) Close() { + close(p.closed) +} diff --git a/concurrency/worker_test.go b/concurrency/worker_test.go index 338062055..c8ceef904 100644 --- a/concurrency/worker_test.go +++ b/concurrency/worker_test.go @@ -4,10 +4,12 @@ import ( "regexp" "runtime" "strings" + "sync" "testing" "time" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func TestReusableGoroutinesPool(t *testing.T) { @@ -59,3 +61,29 @@ func TestReusableGoroutinesPool(t *testing.T) { } t.Fatalf("expected %d goroutines after closing, got %d", 0, countGoroutines()) } + +// TestReusableGoroutinesPool_Race tests that Close() and Go() can be called concurrently. +func TestReusableGoroutinesPool_Race(t *testing.T) { + w := NewReusableGoroutinesPool(2) + + var runCountAtomic atomic.Int32 + const maxMsgCount = 10 + + var testWG sync.WaitGroup + testWG.Add(1) + go func() { + defer testWG.Done() + for i := 0; i < maxMsgCount; i++ { + w.Go(func() { + runCountAtomic.Add(1) + }) + time.Sleep(10 * time.Millisecond) + } + }() + time.Sleep(10 * time.Millisecond) + w.Close() // close the pool + testWG.Wait() // wait for the test to finish + + runCt := int(runCountAtomic.Load()) + require.Equal(t, runCt, 10, "expected all functions to run") +} diff --git a/crypto/tls/test/tls_integration_test.go b/crypto/tls/test/tls_integration_test.go index 941cd39e3..c656c7e07 100644 --- a/crypto/tls/test/tls_integration_test.go +++ b/crypto/tls/test/tls_integration_test.go @@ -13,10 +13,12 @@ import ( "path/filepath" "runtime" "strconv" + "strings" "testing" "time" "github.com/gogo/status" + "github.com/hashicorp/go-cleanhttp" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -103,6 +105,7 @@ func newIntegrationClientServer( serv, err := server.New(cfg) require.NoError(t, err) + defer serv.Shutdown() serv.HTTP.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintf(w, "OK") @@ -115,36 +118,42 @@ func newIntegrationClientServer( require.NoError(t, err) }() - // Wait until the server is up and running - assert.Eventually(t, func() bool { - conn, err := net.DialTimeout("tcp", httpAddr.String(), 1*time.Second) - if err != nil { - t.Logf("error dialing http: %v", err) - return false - } - defer conn.Close() - grpcConn, err := net.DialTimeout("tcp", grpcAddr.String(), 1*time.Second) - if err != nil { - t.Logf("error dialing grpc: %v", err) - return 
false - } - defer grpcConn.Close() - return true - }, 2500*time.Millisecond, 1*time.Second, "server is not up") - httpURL := fmt.Sprintf("https://localhost:%d/hello", httpAddr.Port) grpcHost := net.JoinHostPort("localhost", strconv.Itoa(grpcAddr.Port)) for _, tc := range tcs { - tlsClientConfig, err := tc.tlsConfig.GetTLSConfig() - require.NoError(t, err) - // HTTP t.Run("HTTP/"+tc.name, func(t *testing.T) { - transport := &http.Transport{TLSClientConfig: tlsClientConfig} + tlsClientConfig, err := tc.tlsConfig.GetTLSConfig() + require.NoError(t, err) + + transport := cleanhttp.DefaultTransport() + transport.TLSClientConfig = tlsClientConfig client := &http.Client{Transport: transport} - resp, err := client.Get(httpURL) + cancellableCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequestWithContext(cancellableCtx, http.MethodGet, httpURL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + // We retry the request a few times in case of a TCP reset (and we're expecting an error) + // Sometimes, the server resets the connection rather than sending the TLS error + // Seems that even Google have issues with RST flakiness: https://go-review.googlesource.com/c/go/+/527196 + isRST := func(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "connection reset by peer") || strings.Contains(err.Error(), "broken pipe") + } + for i := 0; i < 3 && isRST(err) && tc.httpExpectError != nil; i++ { + time.Sleep(100 * time.Millisecond) + resp, err = client.Do(req) + if err == nil { + defer resp.Body.Close() + } + } if err == nil { defer resp.Body.Close() } @@ -175,16 +184,18 @@ func newIntegrationClientServer( dialOptions = append([]grpc.DialOption{grpc.WithDefaultCallOptions(clientConfig.CallOptions()...)}, dialOptions...) conn, err := grpc.NewClient(grpcHost, dialOptions...) - assert.NoError(t, err, tc.name) - require.NoError(t, err, tc.name) require.NoError(t, err, tc.name) + defer conn.Close() + + cancellableCtx, cancel := context.WithCancel(context.Background()) + defer cancel() client := grpc_health_v1.NewHealthClient(conn) // TODO: Investigate why the client doesn't really receive the // error about the bad certificate from the server side and just // see connection closed instead - resp, err := client.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{}) + resp, err := client.Check(cancellableCtx, &grpc_health_v1.HealthCheckRequest{}) if tc.grpcExpectError != nil { tc.grpcExpectError(t, err) return @@ -194,10 +205,7 @@ func newIntegrationClientServer( assert.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, resp.Status) } }) - } - - serv.Shutdown() } func TestServerWithoutTlsEnabled(t *testing.T) { diff --git a/kv/memberlist/kv.pb.go b/kv/memberlist/kv.pb.go index 4c2eb9265..2080e9789 100644 --- a/kv/memberlist/kv.pb.go +++ b/kv/memberlist/kv.pb.go @@ -76,6 +76,10 @@ type KeyValuePair struct { Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` // ID of the codec used to write the value Codec string `protobuf:"bytes,3,opt,name=codec,proto3" json:"codec,omitempty"` + // Is this Key marked for deletion? + Deleted bool `protobuf:"varint,4,opt,name=deleted,proto3" json:"deleted,omitempty"` + // When was the key last updated? 
+ UpdateTimeMillis int64 `protobuf:"varint,5,opt,name=update_time_millis,json=updateTimeMillis,proto3" json:"update_time_millis,omitempty"` } func (m *KeyValuePair) Reset() { *m = KeyValuePair{} } @@ -131,6 +135,20 @@ func (m *KeyValuePair) GetCodec() string { return "" } +func (m *KeyValuePair) GetDeleted() bool { + if m != nil { + return m.Deleted + } + return false +} + +func (m *KeyValuePair) GetUpdateTimeMillis() int64 { + if m != nil { + return m.UpdateTimeMillis + } + return 0 +} + func init() { proto.RegisterType((*KeyValueStore)(nil), "memberlist.KeyValueStore") proto.RegisterType((*KeyValuePair)(nil), "memberlist.KeyValuePair") @@ -139,22 +157,25 @@ func init() { func init() { proto.RegisterFile("kv.proto", fileDescriptor_2216fe83c9c12408) } var fileDescriptor_2216fe83c9c12408 = []byte{ - // 236 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xc8, 0x2e, 0xd3, 0x2b, - 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0xca, 0x4d, 0xcd, 0x4d, 0x4a, 0x2d, 0xca, 0xc9, 0x2c, 0x2e, - 0x91, 0xd2, 0x4d, 0xcf, 0x2c, 0xc9, 0x28, 0x4d, 0xd2, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xcf, 0x4f, - 0xcf, 0xd7, 0x07, 0x2b, 0x49, 0x2a, 0x4d, 0x03, 0xf3, 0xc0, 0x1c, 0x30, 0x0b, 0xa2, 0x55, 0xc9, - 0x9e, 0x8b, 0xd7, 0x3b, 0xb5, 0x32, 0x2c, 0x31, 0xa7, 0x34, 0x35, 0xb8, 0x24, 0xbf, 0x28, 0x55, - 0x48, 0x8f, 0x8b, 0xb5, 0x20, 0x31, 0xb3, 0xa8, 0x58, 0x82, 0x51, 0x81, 0x59, 0x83, 0xdb, 0x48, - 0x42, 0x0f, 0x61, 0xb6, 0x1e, 0x4c, 0x65, 0x40, 0x62, 0x66, 0x51, 0x10, 0x44, 0x99, 0x92, 0x0f, - 0x17, 0x0f, 0xb2, 0xb0, 0x90, 0x00, 0x17, 0x73, 0x76, 0x6a, 0xa5, 0x04, 0xa3, 0x02, 0xa3, 0x06, - 0x67, 0x10, 0x88, 0x29, 0x24, 0xc2, 0xc5, 0x5a, 0x06, 0x92, 0x96, 0x60, 0x52, 0x60, 0xd4, 0xe0, - 0x09, 0x82, 0x70, 0x40, 0xa2, 0xc9, 0xf9, 0x29, 0xa9, 0xc9, 0x12, 0xcc, 0x60, 0x95, 0x10, 0x8e, - 0x93, 0xc9, 0x85, 0x87, 0x72, 0x0c, 0x37, 0x1e, 0xca, 0x31, 0x7c, 0x78, 0x28, 0xc7, 0xd8, 0xf0, - 0x48, 0x8e, 0x71, 0xc5, 0x23, 0x39, 0xc6, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, - 0xf0, 0x48, 0x8e, 0xf1, 0xc5, 0x23, 0x39, 0x86, 0x0f, 0x8f, 0xe4, 0x18, 0x27, 0x3c, 0x96, 0x63, - 0xb8, 0xf0, 0x58, 0x8e, 0xe1, 0xc6, 0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0x5f, 0x8c, 0x01, 0x01, - 0x00, 0x00, 0xff, 0xff, 0x7a, 0x22, 0xdf, 0xec, 0x12, 0x01, 0x00, 0x00, + // 288 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x4c, 0x90, 0x31, 0x4b, 0xc3, 0x40, + 0x18, 0x86, 0xef, 0x33, 0x56, 0xeb, 0x59, 0xa1, 0x1c, 0x0e, 0x87, 0xc3, 0x47, 0xe8, 0x94, 0x41, + 0x53, 0x50, 0x77, 0xc1, 0x55, 0x04, 0x89, 0xe2, 0x5a, 0x92, 0xe6, 0xb3, 0x1e, 0xcd, 0x71, 0x25, + 0xbd, 0x14, 0xba, 0xf9, 0x13, 0x1c, 0xfc, 0x11, 0xfe, 0x14, 0xc7, 0x8e, 0x1d, 0xed, 0x65, 0x71, + 0xec, 0x4f, 0x90, 0x5c, 0x28, 0xba, 0xbd, 0xcf, 0xfb, 0x3e, 0x77, 0xc3, 0xc7, 0xbb, 0xd3, 0x45, + 0x3c, 0x2b, 0x8d, 0x35, 0x82, 0x6b, 0xd2, 0x19, 0x95, 0x85, 0x9a, 0xdb, 0xb3, 0x8b, 0x89, 0xb2, + 0xaf, 0x55, 0x16, 0x8f, 0x8d, 0x1e, 0x4e, 0xcc, 0xc4, 0x0c, 0xbd, 0x92, 0x55, 0x2f, 0x9e, 0x3c, + 0xf8, 0xd4, 0x3e, 0x1d, 0xdc, 0xf0, 0x93, 0x3b, 0x5a, 0x3e, 0xa7, 0x45, 0x45, 0x8f, 0xd6, 0x94, + 0x24, 0x62, 0xde, 0x99, 0xa5, 0xaa, 0x9c, 0x4b, 0x08, 0x83, 0xe8, 0xf8, 0x52, 0xc6, 0x7f, 0x7f, + 0xc7, 0x3b, 0xf3, 0x21, 0x55, 0x65, 0xd2, 0x6a, 0x83, 0x0f, 0xe0, 0xbd, 0xff, 0xbd, 0xe8, 0xf3, + 0x60, 0x4a, 0x4b, 0x09, 0x21, 0x44, 0x47, 0x49, 0x13, 0xc5, 0x29, 0xef, 0x2c, 0x9a, 0x59, 0xee, + 0x85, 0x10, 0xf5, 0x92, 0x16, 0x9a, 0x76, 0x6c, 0x72, 0x1a, 0xcb, 0xc0, 0x9b, 0x2d, 0x08, 0xc9, + 0x0f, 0x73, 0x2a, 0xc8, 
0x52, 0x2e, 0xf7, 0x43, 0x88, 0xba, 0xc9, 0x0e, 0xc5, 0x39, 0x17, 0xd5, + 0x2c, 0x4f, 0x2d, 0x8d, 0xac, 0xd2, 0x34, 0xd2, 0xaa, 0x28, 0xd4, 0x5c, 0x76, 0x42, 0x88, 0x82, + 0xa4, 0xdf, 0x2e, 0x4f, 0x4a, 0xd3, 0xbd, 0xef, 0x6f, 0xaf, 0x57, 0x1b, 0x64, 0xeb, 0x0d, 0xb2, + 0xed, 0x06, 0xe1, 0xcd, 0x21, 0x7c, 0x3a, 0x84, 0x2f, 0x87, 0xb0, 0x72, 0x08, 0xdf, 0x0e, 0xe1, + 0xc7, 0x21, 0xdb, 0x3a, 0x84, 0xf7, 0x1a, 0xd9, 0xaa, 0x46, 0xb6, 0xae, 0x91, 0x65, 0x07, 0xfe, + 0x28, 0x57, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0xe0, 0x1f, 0xee, 0xce, 0x5b, 0x01, 0x00, 0x00, } func (this *KeyValueStore) Equal(that interface{}) bool { @@ -214,6 +235,12 @@ func (this *KeyValuePair) Equal(that interface{}) bool { if this.Codec != that1.Codec { return false } + if this.Deleted != that1.Deleted { + return false + } + if this.UpdateTimeMillis != that1.UpdateTimeMillis { + return false + } return true } func (this *KeyValueStore) GoString() string { @@ -232,11 +259,13 @@ func (this *KeyValuePair) GoString() string { if this == nil { return "nil" } - s := make([]string, 0, 7) + s := make([]string, 0, 9) s = append(s, "&memberlist.KeyValuePair{") s = append(s, "Key: "+fmt.Sprintf("%#v", this.Key)+",\n") s = append(s, "Value: "+fmt.Sprintf("%#v", this.Value)+",\n") s = append(s, "Codec: "+fmt.Sprintf("%#v", this.Codec)+",\n") + s = append(s, "Deleted: "+fmt.Sprintf("%#v", this.Deleted)+",\n") + s = append(s, "UpdateTimeMillis: "+fmt.Sprintf("%#v", this.UpdateTimeMillis)+",\n") s = append(s, "}") return strings.Join(s, "") } @@ -305,6 +334,21 @@ func (m *KeyValuePair) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if m.UpdateTimeMillis != 0 { + i = encodeVarintKv(dAtA, i, uint64(m.UpdateTimeMillis)) + i-- + dAtA[i] = 0x28 + } + if m.Deleted { + i-- + if m.Deleted { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x20 + } if len(m.Codec) > 0 { i -= len(m.Codec) copy(dAtA[i:], m.Codec) @@ -373,6 +417,12 @@ func (m *KeyValuePair) Size() (n int) { if l > 0 { n += 1 + l + sovKv(uint64(l)) } + if m.Deleted { + n += 2 + } + if m.UpdateTimeMillis != 0 { + n += 1 + sovKv(uint64(m.UpdateTimeMillis)) + } return n } @@ -405,6 +455,8 @@ func (this *KeyValuePair) String() string { `Key:` + fmt.Sprintf("%v", this.Key) + `,`, `Value:` + fmt.Sprintf("%v", this.Value) + `,`, `Codec:` + fmt.Sprintf("%v", this.Codec) + `,`, + `Deleted:` + fmt.Sprintf("%v", this.Deleted) + `,`, + `UpdateTimeMillis:` + fmt.Sprintf("%v", this.UpdateTimeMillis) + `,`, `}`, }, "") return s @@ -631,6 +683,45 @@ func (m *KeyValuePair) Unmarshal(dAtA []byte) error { } m.Codec = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Deleted", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowKv + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.Deleted = bool(v != 0) + case 5: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field UpdateTimeMillis", wireType) + } + m.UpdateTimeMillis = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowKv + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.UpdateTimeMillis |= int64(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipKv(dAtA[iNdEx:]) diff --git a/kv/memberlist/kv.proto b/kv/memberlist/kv.proto index 
cc5f12463..b2e513b07 100644 --- a/kv/memberlist/kv.proto +++ b/kv/memberlist/kv.proto @@ -19,4 +19,9 @@ message KeyValuePair { // ID of the codec used to write the value string codec = 3; + + // Is this Key marked for deletion? + bool deleted = 4; + // When was the key last updated? + int64 update_time_millis = 5; } diff --git a/kv/memberlist/memberlist_client.go b/kv/memberlist/memberlist_client.go index 1d96363fe..a58b98412 100644 --- a/kv/memberlist/memberlist_client.go +++ b/kv/memberlist/memberlist_client.go @@ -72,8 +72,14 @@ func (c *Client) Get(ctx context.Context, key string) (interface{}, error) { } // Delete is part of kv.Client interface. -func (c *Client) Delete(_ context.Context, _ string) error { - return errors.New("memberlist does not support Delete") +func (c *Client) Delete(ctx context.Context, key string) error { + err := c.awaitKVRunningOrStopping(ctx) + if err != nil { + return err + } + + c.kv.Delete(key) + return nil } // CAS is part of kv.Client interface @@ -137,6 +143,7 @@ type KVConfig struct { GossipToTheDeadTime time.Duration `yaml:"gossip_to_dead_nodes_time" category:"advanced"` DeadNodeReclaimTime time.Duration `yaml:"dead_node_reclaim_time" category:"advanced"` EnableCompression bool `yaml:"compression_enabled" category:"advanced"` + NotifyInterval time.Duration `yaml:"notify_interval" category:"advanced"` // ip:port to advertise other cluster members. Used for NAT traversal AdvertiseAddr string `yaml:"advertise_addr"` @@ -154,7 +161,8 @@ type KVConfig struct { RejoinInterval time.Duration `yaml:"rejoin_interval" category:"advanced"` // Remove LEFT ingesters from ring after this timeout. - LeftIngestersTimeout time.Duration `yaml:"left_ingesters_timeout" category:"advanced"` + LeftIngestersTimeout time.Duration `yaml:"left_ingesters_timeout" category:"advanced"` + ObsoleteEntriesTimeout time.Duration `yaml:"obsolete_entries_timeout" category:"advanced"` // Timeout used when leaving the memberlist cluster. LeaveTimeout time.Duration `yaml:"leave_timeout" category:"advanced"` @@ -195,6 +203,7 @@ func (cfg *KVConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix string) { f.DurationVar(&cfg.DeadNodeReclaimTime, prefix+"memberlist.dead-node-reclaim-time", mlDefaults.DeadNodeReclaimTime, "How soon can dead node's name be reclaimed with new address. 0 to disable.") f.IntVar(&cfg.MessageHistoryBufferBytes, prefix+"memberlist.message-history-buffer-bytes", 0, "How much space to use for keeping received and sent messages in memory for troubleshooting (two buffers). 0 to disable.") f.BoolVar(&cfg.EnableCompression, prefix+"memberlist.compression-enabled", mlDefaults.EnableCompression, "Enable message compression. This can be used to reduce bandwidth usage at the cost of slightly more CPU utilization.") + f.DurationVar(&cfg.NotifyInterval, prefix+"memberlist.notify-interval", 0, "How frequently to notify watchers when a key changes. Can reduce CPU activity in large memberlist deployments. 0 to notify without delay.") f.StringVar(&cfg.AdvertiseAddr, prefix+"memberlist.advertise-addr", mlDefaults.AdvertiseAddr, "Gossip address to advertise to other members in the cluster. Used for NAT traversal.") f.IntVar(&cfg.AdvertisePort, prefix+"memberlist.advertise-port", mlDefaults.AdvertisePort, "Gossip port to advertise to other members in the cluster. Used for NAT traversal.") f.StringVar(&cfg.ClusterLabel, prefix+"memberlist.cluster-label", mlDefaults.Label, "The cluster label is an optional string to include in outbound packets and gossip streams. 
Other members in the memberlist cluster will discard any message whose label doesn't match the configured one, unless the 'cluster-label-verification-disabled' configuration option is set to true.") @@ -251,6 +260,10 @@ type KV struct { watchers map[string][]chan string prefixWatchers map[string][]chan string + // Delayed notifications for watchers + notifMu sync.Mutex + keyNotifications map[string]struct{} + // Buffers with sent and received messages. Used for troubleshooting only. // New messages are appended, old messages (based on configured size limit) removed from the front. messagesMu sync.Mutex @@ -324,6 +337,9 @@ type ValueDesc struct { // ID of codec used to write this value. Only used when sending full state. CodecID string + + Deleted bool + UpdateTime time.Time } func (v ValueDesc) Clone() (result ValueDesc) { @@ -338,6 +354,8 @@ type valueUpdate struct { value []byte codec codec.Codec messageSize int + deleted bool + updateTime time.Time } func (v ValueDesc) String() string { @@ -359,17 +377,18 @@ func NewKV(cfg KVConfig, logger log.Logger, dnsProvider DNSProvider, registerer cfg.TCPTransport.MetricsNamespace = cfg.MetricsNamespace mlkv := &KV{ - cfg: cfg, - logger: logger, - registerer: registerer, - provider: dnsProvider, - store: make(map[string]ValueDesc), - codecs: make(map[string]codec.Codec), - watchers: make(map[string][]chan string), - prefixWatchers: make(map[string][]chan string), - workersChannels: make(map[string]chan valueUpdate), - shutdown: make(chan struct{}), - maxCasRetries: maxCasRetries, + cfg: cfg, + logger: logger, + registerer: registerer, + provider: dnsProvider, + store: make(map[string]ValueDesc), + codecs: make(map[string]codec.Codec), + watchers: make(map[string][]chan string), + keyNotifications: make(map[string]struct{}), + prefixWatchers: make(map[string][]chan string), + workersChannels: make(map[string]chan valueUpdate), + shutdown: make(chan struct{}), + maxCasRetries: maxCasRetries, } mlkv.createAndRegisterMetrics() @@ -486,6 +505,13 @@ func (m *KV) running(ctx context.Context) error { return errFailedToJoinCluster } + if m.cfg.NotifyInterval > 0 { + // Start delayed key notifications. + notifTicker := time.NewTicker(m.cfg.NotifyInterval) + defer notifTicker.Stop() + go m.monitorKeyNotifications(ctx, notifTicker.C) + } + var tickerChan <-chan time.Time if m.cfg.RejoinInterval > 0 && len(m.cfg.JoinMembers) > 0 { t := time.NewTicker(m.cfg.RejoinInterval) @@ -494,6 +520,9 @@ func (m *KV) running(ctx context.Context) error { tickerChan = t.C } + obsoleteEntriesTicker := time.NewTicker(m.cfg.PushPullInterval) + defer obsoleteEntriesTicker.Stop() + logger := log.With(m.logger, "phase", "periodic_rejoin") for { select { @@ -507,6 +536,11 @@ func (m *KV) running(ctx context.Context) error { level.Warn(logger).Log("msg", "re-joining memberlist cluster failed", "err", err, "next_try_in", m.cfg.RejoinInterval) } + case <-obsoleteEntriesTicker.C: + // cleanupObsoleteEntries is normally called during push/pull, but if there are no other + // nodes to push/pull with, we can call it periodically to make sure we remove unused entries from memory. + m.cleanupObsoleteEntries() + case <-ctx.Done(): return nil } @@ -905,7 +939,59 @@ func removeWatcherChannel(k string, w chan string, watchers map[string][]chan st } } +// notifyWatchers sends notification to all watchers of given key. If delay is +// enabled, it accumulates them for later sending. 
func (m *KV) notifyWatchers(key string) { + if m.cfg.NotifyInterval <= 0 { + m.notifyWatchersSync(key) + return + } + + m.notifMu.Lock() + defer m.notifMu.Unlock() + m.keyNotifications[key] = struct{}{} +} + +// monitorKeyNotifications sends accumulated notifications to all watchers of +// respective keys when the given channel ticks. +func (m *KV) monitorKeyNotifications(ctx context.Context, tickChan <-chan time.Time) { + if m.cfg.NotifyInterval <= 0 { + panic("sendNotifications called with NotifyInterval <= 0") + } + + for { + select { + case <-tickChan: + m.sendKeyNotifications() + case <-ctx.Done(): + return + } + } +} + +// sendKeyNotifications sends accumulated notifications to watchers of respective keys. +func (m *KV) sendKeyNotifications() { + newNotifs := func() map[string]struct{} { + // Grab and clear accumulated notifications. + m.notifMu.Lock() + defer m.notifMu.Unlock() + + if len(m.keyNotifications) == 0 { + return nil + } + newMap := make(map[string]struct{}) + notifs := m.keyNotifications + m.keyNotifications = newMap + return notifs + } + + for key := range newNotifs() { + m.notifyWatchersSync(key) + } +} + +// notifyWatchersSync immediately sends notification to all watchers of given key. +func (m *KV) notifyWatchersSync(key string) { m.watchersMu.Lock() defer m.watchersMu.Unlock() @@ -939,6 +1025,32 @@ func (m *KV) notifyWatchers(key string) { } } +func (m *KV) Delete(key string) error { + m.storeMu.Lock() + val, ok := m.store[key] + m.storeMu.Unlock() + + if !ok || val.Deleted { + return nil + } + + c := m.GetCodec(val.CodecID) + if c == nil { + return fmt.Errorf("invalid codec: %s", val.CodecID) + } + + change, newver, deleted, updated, err := m.mergeValueForKey(key, nil, false, 0, val.CodecID, true, time.Now()) + if err != nil { + return err + } + + if newver > 0 { + m.notifyWatchers(key) + m.broadcastNewValue(key, change, newver, c, false, deleted, updated) + } + return nil +} + // CAS implements Compare-And-Set/Swap operation. // // CAS expects that value returned by 'f' function implements Mergeable interface. If it doesn't, CAS fails immediately. @@ -969,7 +1081,7 @@ outer: } } - change, newver, retry, err := m.trySingleCas(key, codec, f) + change, newver, retry, deleted, updated, err := m.trySingleCas(key, codec, f) if err != nil { level.Debug(m.logger).Log("msg", "CAS attempt failed", "err", err, "retry", retry) @@ -984,13 +1096,13 @@ outer: m.casSuccesses.Inc() m.notifyWatchers(key) - m.broadcastNewValue(key, change, newver, codec, true) + m.broadcastNewValue(key, change, newver, codec, true, deleted, updated) } return nil } - if lastError == errVersionMismatch { + if errors.Is(lastError, errVersionMismatch) { // this is more likely error than version mismatch. lastError = errTooManyRetries } @@ -1001,50 +1113,50 @@ outer: // returns change, error (or nil, if CAS succeeded), and whether to retry or not. // returns errNoChangeDetected if merge failed to detect change in f's output.
-func (m *KV) trySingleCas(key string, codec codec.Codec, f func(in interface{}) (out interface{}, retry bool, err error)) (Mergeable, uint, bool, error) { +func (m *KV) trySingleCas(key string, codec codec.Codec, f func(in interface{}) (out interface{}, retry bool, err error)) (Mergeable, uint, bool, bool, time.Time, error) { val, ver, err := m.get(key, codec) if err != nil { - return nil, 0, false, fmt.Errorf("failed to get value: %v", err) + return nil, 0, false, false, time.Time{}, fmt.Errorf("failed to get value: %v", err) } out, retry, err := f(val) if err != nil { - return nil, 0, retry, fmt.Errorf("fn returned error: %v", err) + return nil, 0, retry, false, time.Time{}, fmt.Errorf("fn returned error: %v", err) } if out == nil { // no change to be done - return nil, 0, false, nil + return nil, 0, false, false, time.Time{}, nil } // Don't even try incomingValue, ok := out.(Mergeable) if !ok || incomingValue == nil { - return nil, 0, retry, fmt.Errorf("invalid type: %T, expected Mergeable", out) + return nil, 0, retry, false, time.Time{}, fmt.Errorf("invalid type: %T, expected Mergeable", out) } // To support detection of removed items from value, we will only allow CAS operation to // succeed if version hasn't changed, i.e. state hasn't changed since running 'f'. // Supplied function may have kept a reference to the returned "incoming value". // If KV store will keep this value as well, it needs to make a clone. - change, newver, err := m.mergeValueForKey(key, incomingValue, true, ver, codec) + change, newver, deleted, updated, err := m.mergeValueForKey(key, incomingValue, true, ver, codec.CodecID(), false, time.Now()) if err == errVersionMismatch { - return nil, 0, retry, err + return nil, 0, retry, false, time.Time{}, err } if err != nil { - return nil, 0, retry, fmt.Errorf("merge failed: %v", err) + return nil, 0, retry, false, time.Time{}, fmt.Errorf("merge failed: %v", err) } if newver == 0 { // CAS method reacts on this error - return nil, 0, retry, errNoChangeDetected + return nil, 0, retry, deleted, updated, errNoChangeDetected } - return change, newver, retry, nil + return change, newver, retry, deleted, updated, nil } -func (m *KV) broadcastNewValue(key string, change Mergeable, version uint, codec codec.Codec, locallyGenerated bool) { +func (m *KV) broadcastNewValue(key string, change Mergeable, version uint, codec codec.Codec, locallyGenerated bool, deleted bool, updateTime time.Time) { if locallyGenerated && m.State() != services.Running { level.Warn(m.logger).Log("msg", "skipped broadcasting of locally-generated update because memberlist KV is shutting down", "key", key) return @@ -1057,7 +1169,7 @@ func (m *KV) broadcastNewValue(key string, change Mergeable, version uint, codec return } - kvPair := KeyValuePair{Key: key, Value: data, Codec: codec.CodecID()} + kvPair := KeyValuePair{Key: key, Value: data, Codec: codec.CodecID(), Deleted: deleted, UpdateTimeMillis: updateTimeMillis(updateTime)} pairData, err := kvPair.Marshal() if err != nil { level.Error(m.logger).Log("msg", "failed to serialize KV pair", "key", key, "version", version, "err", err) @@ -1134,7 +1246,7 @@ func (m *KV) NotifyMsg(msg []byte) { ch := m.getKeyWorkerChannel(kvPair.Key) select { - case ch <- valueUpdate{value: kvPair.Value, codec: codec, messageSize: len(msg)}: + case ch <- valueUpdate{value: kvPair.Value, codec: codec, messageSize: len(msg), deleted: kvPair.Deleted, updateTime: updateTime(kvPair.UpdateTimeMillis)}: default: m.numberOfDroppedMessages.Inc() level.Warn(m.logger).Log("msg", 
"notify queue full, dropping message", "key", kvPair.Key) @@ -1161,7 +1273,7 @@ func (m *KV) processValueUpdate(workerCh <-chan valueUpdate, key string) { select { case update := <-workerCh: // we have a value update! Let's merge it with our current version for given key - mod, version, err := m.mergeBytesValueForKey(key, update.value, update.codec) + mod, version, deleted, updated, err := m.mergeBytesValueForKey(key, update.value, update.codec, update.deleted, update.updateTime) changes := []string(nil) if mod != nil { @@ -1185,8 +1297,8 @@ func (m *KV) processValueUpdate(workerCh <-chan valueUpdate, key string) { } else if version > 0 { m.notifyWatchers(key) - // Don't resend original message, but only changes. - m.broadcastNewValue(key, mod, version, update.codec, false) + // Don't resend original message, but only changes, if any. + m.broadcastNewValue(key, mod, version, update.codec, false, deleted, updated) } case <-m.shutdown: @@ -1229,6 +1341,8 @@ func (m *KV) LocalState(_ bool) []byte { m.numberOfPulls.Inc() + m.cleanupObsoleteEntries() + m.storeMu.Lock() defer m.storeMu.Unlock() @@ -1260,6 +1374,8 @@ func (m *KV) LocalState(_ bool) []byte { kvPair.Key = key kvPair.Value = encoded kvPair.Codec = val.CodecID + kvPair.Deleted = val.Deleted + kvPair.UpdateTimeMillis = updateTimeMillis(val.UpdateTime) ser, err := kvPair.Marshal() if err != nil { @@ -1342,7 +1458,7 @@ func (m *KV) MergeRemoteState(data []byte, _ bool) { } // we have both key and value, try to merge it with our state - change, newver, err := m.mergeBytesValueForKey(kvPair.Key, kvPair.Value, codec) + change, newver, deleted, updated, err := m.mergeBytesValueForKey(kvPair.Key, kvPair.Value, codec, kvPair.Deleted, updateTime(kvPair.UpdateTimeMillis)) changes := []string(nil) if change != nil { @@ -1361,7 +1477,7 @@ func (m *KV) MergeRemoteState(data []byte, _ bool) { level.Error(m.logger).Log("msg", "failed to store received value", "key", kvPair.Key, "err", err) } else if newver > 0 { m.notifyWatchers(kvPair.Key) - m.broadcastNewValue(kvPair.Key, change, newver, codec, false) + m.broadcastNewValue(kvPair.Key, change, newver, codec, false, deleted, updated) } } @@ -1370,26 +1486,26 @@ func (m *KV) MergeRemoteState(data []byte, _ bool) { } } -func (m *KV) mergeBytesValueForKey(key string, incomingData []byte, codec codec.Codec) (Mergeable, uint, error) { +func (m *KV) mergeBytesValueForKey(key string, incomingData []byte, codec codec.Codec, deleted bool, updateTime time.Time) (Mergeable, uint, bool, time.Time, error) { decodedValue, err := codec.Decode(incomingData) if err != nil { - return nil, 0, fmt.Errorf("failed to decode value: %v", err) + return nil, 0, false, time.Time{}, fmt.Errorf("failed to decode value: %v", err) } incomingValue, ok := decodedValue.(Mergeable) if !ok { - return nil, 0, fmt.Errorf("expected Mergeable, got: %T", decodedValue) + return nil, 0, false, time.Time{}, fmt.Errorf("expected Mergeable, got: %T", decodedValue) } // No need to clone this "incomingValue", since we have just decoded it from bytes, and won't be using it. - return m.mergeValueForKey(key, incomingValue, false, 0, codec) + return m.mergeValueForKey(key, incomingValue, false, 0, codec.CodecID(), deleted, updateTime) } // Merges incoming value with value we have in our store. Returns "a change" that can be sent to other // cluster members to update their state, and new version of the value. // If CAS version is specified, then merging will fail if state has changed already, and errVersionMismatch is reported. 
// If no modification occurred, new version is 0. -func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, incomingValueRequiresClone bool, casVersion uint, codec codec.Codec) (Mergeable, uint, error) { +func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, incomingValueRequiresClone bool, casVersion uint, codecID string, deleted bool, updateTime time.Time) (change Mergeable, newVersion uint, newDeleted bool, newUpdated time.Time, err error) { m.storeMu.Lock() defer m.storeMu.Unlock() @@ -1399,16 +1515,25 @@ func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, incomingValue curr := m.store[key] // if casVersion is 0, then there was no previous value, so we will just do normal merge, without localCAS flag set. if casVersion > 0 && curr.Version != casVersion { - return nil, 0, errVersionMismatch + return nil, 0, false, time.Time{}, errVersionMismatch } result, change, err := computeNewValue(incomingValue, incomingValueRequiresClone, curr.value, casVersion > 0) if err != nil { - return nil, 0, err + return nil, 0, false, time.Time{}, err + } + + newVersion = curr.Version + 1 + newUpdated = curr.UpdateTime + newDeleted = curr.Deleted + + if !updateTime.IsZero() && updateTime.After(newUpdated) { + newUpdated = updateTime + newDeleted = deleted } // No change, don't store it. - if change == nil || len(change.MergeContent()) == 0 { - return nil, 0, nil + if (change == nil || len(change.MergeContent()) == 0) && curr.Deleted == newDeleted { + return nil, 0, curr.Deleted, curr.UpdateTime, nil } if m.cfg.LeftIngestersTimeout > 0 { @@ -1425,22 +1550,23 @@ func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, incomingValue // RemoveTombstones twice with same limit should be noop. change.RemoveTombstones(limit) if len(change.MergeContent()) == 0 { - return nil, 0, nil + return nil, 0, curr.Deleted, curr.UpdateTime, nil } } - newVersion := curr.Version + 1 m.store[key] = ValueDesc{ - value: result, - Version: newVersion, - CodecID: codec.CodecID(), + value: result, + Version: newVersion, + CodecID: codecID, + Deleted: newDeleted, + UpdateTime: newUpdated, } // The "changes" returned by Merge() can contain references to the "result" // state. Therefore, make sure we clone it before releasing the lock. 
change = change.Clone() - return change, newVersion, nil + return change, newVersion, newDeleted, newUpdated, nil } // returns [result, change, error] @@ -1457,6 +1583,10 @@ func computeNewValue(incoming Mergeable, incomingValueRequiresClone bool, oldVal return incoming, incoming, nil } + if incoming == nil { + return oldVal, nil, nil + } + // otherwise we have two mergeables, so merge them change, err := oldVal.Merge(incoming, cas) return oldVal, change, err @@ -1518,6 +1648,17 @@ func (m *KV) deleteSentReceivedMessages() { m.receivedMessagesSize = 0 } +func (m *KV) cleanupObsoleteEntries() { + m.storeMu.Lock() + defer m.storeMu.Unlock() + + for k, v := range m.store { + if v.Deleted && time.Since(v.UpdateTime) > m.cfg.ObsoleteEntriesTimeout { + delete(m.store, k) + } + } +} + func addMessageToBuffer(msgs []Message, size int, limit int, msg Message) ([]Message, int) { msgs = append(msgs, msg) size += msg.Size @@ -1529,3 +1670,17 @@ func addMessageToBuffer(msgs []Message, size int, limit int, msg Message) ([]Mes return msgs, size } + +func updateTime(val int64) time.Time { + if val == 0 { + return time.Time{} + } + return time.UnixMilli(val) +} + +func updateTimeMillis(ts time.Time) int64 { + if ts.IsZero() { + return 0 + } + return ts.UnixMilli() +} diff --git a/kv/memberlist/memberlist_client_test.go b/kv/memberlist/memberlist_client_test.go index 0d14c5f76..f1894e4ae 100644 --- a/kv/memberlist/memberlist_client_test.go +++ b/kv/memberlist/memberlist_client_test.go @@ -255,7 +255,6 @@ func getLocalhostAddrs() []string { func TestBasicGetAndCas(t *testing.T) { c := dataCodec{} - name := "Ing 1" var cfg KVConfig flagext.DefaultValues(&cfg) cfg.TCPTransport = TCPTransportConfig{ @@ -278,6 +277,7 @@ func TestBasicGetAndCas(t *testing.T) { } // Create member in PENDING state, with some tokens + name := "Ing 1" err = cas(kv, key, updateFn(name)) require.NoError(t, err) @@ -590,12 +590,16 @@ func TestMultipleClientsWithMixedLabelsAndExpectFailure(t *testing.T) { // 1) "" // 2) "label1" // 3) "label2" + // 4) "label3" + // 5) "label4" + // // We expect that it won't be possible to build a memberlist cluster with mixed labels. var membersLabel = []string{ "", "label1", "label2", + "label3", + "label4", } configGen := func(i int) KVConfig { @@ -609,7 +613,7 @@ func TestMultipleClientsWithMixedLabelsAndExpectFailure(t *testing.T) { err := testMultipleClientsWithConfigGenerator(t, len(membersLabel), configGen) require.Error(t, err) - require.Contains(t, err.Error(), fmt.Sprintf("expected to see %d members, got", len(membersLabel))) + require.Contains(t, err.Error(), "expected to see at least 2 members, got 1") } func TestMultipleClientsWithMixedLabelsAndClusterLabelVerificationDisabled(t *testing.T) { @@ -658,6 +662,8 @@ func TestMultipleClientsWithSameLabelWithClusterLabelVerification(t *testing.T) } func testMultipleClientsWithConfigGenerator(t *testing.T, members int, configGen func(memberId int) KVConfig) error { + t.Helper() + c := dataCodec{} const key = "ring" var clients []*Client @@ -719,12 +725,10 @@ func testMultipleClientsWithConfigGenerator(t *testing.T, members int, configGen startTime := time.Now() firstKv := clients[0] ctx, cancel := context.WithTimeout(context.Background(), casInterval*3) // Watch for 3x cas intervals.
- updates := 0 - gotMembers := 0 + joinedMembers := 0 firstKv.WatchKey(ctx, key, func(in interface{}) bool { - updates++ - r := in.(*data) + joinedMembers = len(r.Members) minTimestamp, maxTimestamp, avgTimestamp := getTimestamps(r.Members) @@ -733,22 +737,13 @@ func testMultipleClientsWithConfigGenerator(t *testing.T, members int, configGen "tokens, oldest timestamp:", now.Sub(time.Unix(minTimestamp, 0)).String(), "avg timestamp:", now.Sub(time.Unix(avgTimestamp, 0)).String(), "youngest timestamp:", now.Sub(time.Unix(maxTimestamp, 0)).String()) - gotMembers = len(r.Members) return true // yes, keep watching }) cancel() // make linter happy - t.Logf("Ring updates observed: %d", updates) - - // We expect that all members are in the ring - if gotMembers != members { - return fmt.Errorf("expected to see %d members, got %d", members, gotMembers) - } - - if updates < members { - // in general, at least one update from each node. (although that's not necessarily true... - // but typically we get more updates than that anyway) - return fmt.Errorf("expected to see at least %d updates, got %d", members, updates) + if joinedMembers <= 1 { + // expect at least 2 members. Otherwise, this means that the ring has failed to sync. + return fmt.Errorf("expected to see at least 2 members, got %d", joinedMembers) } if err := getClientErr(); err != nil { @@ -758,47 +753,69 @@ func testMultipleClientsWithConfigGenerator(t *testing.T, members int, configGen // Let's check all the clients to see if they have relatively up-to-date information // All of them should at least have all the clients // And same tokens. - allTokens := []uint32(nil) - - for i := 0; i < members; i++ { - kv := clients[i] + check := func() error { + allTokens := []uint32(nil) - r := getData(t, kv, key) - t.Logf("KV %d: number of known members: %d\n", i, len(r.Members)) - if len(r.Members) != members { - return fmt.Errorf("Member %d has only %d members in the ring", i, len(r.Members)) - } + for i := 0; i < members; i++ { + kv := clients[i] - minTimestamp, maxTimestamp, avgTimestamp := getTimestamps(r.Members) - for n, ing := range r.Members { - if ing.State != ACTIVE { - return fmt.Errorf("Member %d: invalid state of member %s in the ring: %v ", i, n, ing.State) + r := getData(t, kv, key) + t.Logf("KV %d: number of known members: %d\n", i, len(r.Members)) + if len(r.Members) != members { + return fmt.Errorf("Member %d has only %d members in the ring", i, len(r.Members)) } - } - now := time.Now() - t.Logf("Member %d: oldest: %v, avg: %v, youngest: %v", i, - now.Sub(time.Unix(minTimestamp, 0)).String(), - now.Sub(time.Unix(avgTimestamp, 0)).String(), - now.Sub(time.Unix(maxTimestamp, 0)).String()) - - tokens := r.getAllTokens() - if allTokens == nil { - allTokens = tokens - t.Logf("Found tokens: %d", len(allTokens)) - } else { - if len(allTokens) != len(tokens) { - return fmt.Errorf("Member %d: Expected %d tokens, got %d", i, len(allTokens), len(tokens)) + + minTimestamp, maxTimestamp, avgTimestamp := getTimestamps(r.Members) + for n, ing := range r.Members { + if ing.State != ACTIVE { + stateStr := "UNKNOWN" + switch ing.State { + case JOINING: + stateStr = "JOINING" + case LEFT: + stateStr = "LEFT" + } + return fmt.Errorf("Member %d: invalid state of member %s in the ring: %s (%v) ", i, n, stateStr, ing.State) + } } + now := time.Now() + t.Logf("Member %d: oldest: %v, avg: %v, youngest: %v", i, + now.Sub(time.Unix(minTimestamp, 0)).String(), + now.Sub(time.Unix(avgTimestamp, 0)).String(), + now.Sub(time.Unix(maxTimestamp, 0)).String()) + + 
tokens := r.getAllTokens() + if allTokens == nil { + allTokens = tokens + t.Logf("Found tokens: %d", len(allTokens)) + } else { + if len(allTokens) != len(tokens) { + return fmt.Errorf("Member %d: Expected %d tokens, got %d", i, len(allTokens), len(tokens)) + } - for ix, tok := range allTokens { - if tok != tokens[ix] { - return fmt.Errorf("Member %d: Tokens at position %d differ: %v, %v", i, ix, tok, tokens[ix]) + for ix, tok := range allTokens { + if tok != tokens[ix] { + return fmt.Errorf("Member %d: Tokens at position %d differ: %v, %v", i, ix, tok, tokens[ix]) + } } } } + + return getClientErr() } - return getClientErr() + // Try this for ~10 seconds. memberlist is eventually consistent, so we may need to wait a bit, especially with `-race`. + for timeout := time.After(10 * time.Second); ; { + select { + case <-timeout: + return check() // return last error + default: + if err := check(); err == nil { + return nil // it passed + } + time.Sleep(100 * time.Millisecond) + } + } } func TestJoinMembersWithRetryBackoff(t *testing.T) { @@ -1669,11 +1686,11 @@ func TestGetBroadcastsPrefersLocalUpdates(t *testing.T) { require.Equal(t, 0, len(kv.GetBroadcasts(0, math.MaxInt32))) // Check that locally-generated broadcast messages will be prioritized and sent out first, even if they are enqueued later or are smaller than other messages in the queue. - kv.broadcastNewValue("non-local", smallUpdate, 1, codec, false) - kv.broadcastNewValue("non-local", bigUpdate, 2, codec, false) - kv.broadcastNewValue("local", smallUpdate, 1, codec, true) - kv.broadcastNewValue("local", bigUpdate, 2, codec, true) - kv.broadcastNewValue("local", mediumUpdate, 3, codec, true) + kv.broadcastNewValue("non-local", smallUpdate, 1, codec, false, false, time.Now()) + kv.broadcastNewValue("non-local", bigUpdate, 2, codec, false, false, time.Now()) + kv.broadcastNewValue("local", smallUpdate, 1, codec, true, false, time.Now()) + kv.broadcastNewValue("local", bigUpdate, 2, codec, true, false, time.Now()) + kv.broadcastNewValue("local", mediumUpdate, 3, codec, true, false, time.Now()) err := testutil.GatherAndCompare(reg, bytes.NewBufferString(` # HELP memberlist_client_messages_in_broadcast_queue Number of user messages in the broadcast queue @@ -1786,3 +1803,119 @@ func marshalState(t *testing.T, kvps ...*KeyValuePair) []byte { return buf.Bytes() } + +func TestNotificationDelay(t *testing.T) { + cfg := KVConfig{} + // We're going to trigger sends manually, so effectively disable the automatic send interval. + const hundredYears = 100 * 365 * 24 * time.Hour + cfg.NotifyInterval = hundredYears + kv := NewKV(cfg, log.NewNopLogger(), &dnsProviderMock{}, prometheus.NewPedanticRegistry()) + + watchChan := make(chan string, 16) + + // Add ourselves as a watcher.
+ kv.watchersMu.Lock() + kv.watchers["foo_123"] = append(kv.watchers["foo_123"], watchChan) + kv.watchers["foo_124"] = append(kv.watchers["foo_124"], watchChan) + kv.watchersMu.Unlock() + + defer func() { + kv.watchersMu.Lock() + removeWatcherChannel("foo_123", watchChan, kv.watchers) + removeWatcherChannel("foo_124", watchChan, kv.watchers) + kv.watchersMu.Unlock() + }() + + verifyNotifs := func(expected map[string]int, comment string) { + observed := make(map[string]int, len(expected)) + for kk := range expected { + observed[kk] = 0 + } + loop: + for { + select { + case k := <-watchChan: + observed[k]++ + default: + break loop + } + } + require.Equal(t, expected, observed, comment) + } + + drainChan := func() { + for { + select { + case <-watchChan: + default: + return + } + } + } + + kv.notifyWatchers("foo_123") + kv.sendKeyNotifications() + verifyNotifs(map[string]int{"foo_123": 1}, "1 change 1 notification") + + // Test coalescing of updates. + drainChan() + verifyNotifs(map[string]int{"foo_123": 0}, "chan drained") + kv.notifyWatchers("foo_123") + verifyNotifs(map[string]int{"foo_123": 0}, "no flush -> no watcher notification") + kv.notifyWatchers("foo_123") + verifyNotifs(map[string]int{"foo_123": 0}, "no flush -> no watcher notification") + kv.notifyWatchers("foo_123") + verifyNotifs(map[string]int{"foo_123": 0}, "no flush -> no watcher notification") + kv.notifyWatchers("foo_123") + verifyNotifs(map[string]int{"foo_123": 0}, "no flush -> no watcher notification") + kv.notifyWatchers("foo_123") + verifyNotifs(map[string]int{"foo_123": 0}, "no flush -> no watcher notification") + kv.notifyWatchers("foo_123") + verifyNotifs(map[string]int{"foo_123": 0}, "no flush -> no watcher notification") + kv.sendKeyNotifications() + verifyNotifs(map[string]int{"foo_123": 1}, "flush should coalesce updates") + + // multiple buffered updates + drainChan() + verifyNotifs(map[string]int{"foo_123": 0}, "chan drained") + kv.notifyWatchers("foo_123") + kv.sendKeyNotifications() + kv.notifyWatchers("foo_123") + kv.sendKeyNotifications() + verifyNotifs(map[string]int{"foo_123": 2}, "two buffered updates") + + // multiple keys + drainChan() + kv.notifyWatchers("foo_123") + kv.notifyWatchers("foo_124") + kv.sendKeyNotifications() + verifyNotifs(map[string]int{"foo_123": 1, "foo_124": 1}, "2 changes 2 notifications") + kv.sendKeyNotifications() + verifyNotifs(map[string]int{"foo_123": 0, "foo_124": 0}, "no new notifications") + + // sendKeyNotifications can be called repeatedly without new updates. + kv.sendKeyNotifications() + kv.sendKeyNotifications() + kv.sendKeyNotifications() + kv.sendKeyNotifications() + + // Finally, exercise the monitor method. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tick := make(chan time.Time) + go kv.monitorKeyNotifications(ctx, tick) + kv.notifyWatchers("foo_123") + tick <- time.Now() + + require.Eventually(t, func() bool { + select { + case k := <-watchChan: + if k != "foo_123" { + panic(fmt.Sprintf("unexpected key: %s", k)) + } + return true + default: // nothing yet. + return false + } + }, 20*time.Second, 100*time.Millisecond) +} diff --git a/kv/memberlist/status.gohtml b/kv/memberlist/status.gohtml index 6f845b6e0..becf3652c 100644 --- a/kv/memberlist/status.gohtml +++ b/kv/memberlist/status.gohtml @@ -22,6 +22,8 @@ Key Codec Version + Deleted + Update Time Actions @@ -32,6 +34,8 @@ {{ $k }} {{ $v.CodecID }} {{ $v.Version }} + {{ $v.Deleted }} + {{ $v.UpdateTime }} json | json-pretty @@ -149,4 +153,4 @@

Message history buffer is disabled, refer to the configuration to enable it in order to troubleshoot the message history.

{{ end }} - \ No newline at end of file + diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index 751ad1163..241d25b71 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -19,7 +19,6 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "go.uber.org/atomic" dstls "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" @@ -52,7 +51,13 @@ type TCPTransportConfig struct { // Timeout for writing packet data. Zero = no timeout. PacketWriteTimeout time.Duration `yaml:"packet_write_timeout" category:"advanced"` - // Transport logs lot of messages at debug level, so it deserves an extra flag for turning it on + // Maximum number of concurrent writes to other nodes. + MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"` + + // Timeout for acquiring one of the concurrent write slots. + AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"` + + // Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on TransportDebug bool `yaml:"-" category:"advanced"` // Where to put custom metrics. nil = don't register. @@ -73,12 +78,19 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.") f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") + f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.") + f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f) } +type writeRequest struct { + b []byte + addr string +} + // TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream // operations ("packet" and "stream" are terms used by memberlist). // It uses a new TCP connections for each operation. There is no connection reuse. @@ -91,7 +103,11 @@ type TCPTransport struct { tcpListeners []net.Listener tlsConfig *tls.Config - shutdown atomic.Int32 + shutdownMu sync.RWMutex + shutdown bool + writeCh chan writeRequest // this channel is protected by shutdownMu + + writeWG sync.WaitGroup advertiseMu sync.RWMutex advertiseAddr string @@ -107,6 +123,7 @@ type TCPTransport struct { sentPackets prometheus.Counter sentPacketsBytes prometheus.Counter sentPacketsErrors prometheus.Counter + droppedPackets prometheus.Counter unknownConnections prometheus.Counter } @@ -119,11 +136,21 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr // Build out the new transport. 
var ok bool + concurrentWrites := config.MaxConcurrentWrites + if concurrentWrites <= 0 { + concurrentWrites = 1 + } t := TCPTransport{ cfg: config, logger: log.With(logger, "component", "memberlist TCPTransport"), packetCh: make(chan *memberlist.Packet), connCh: make(chan net.Conn), + writeCh: make(chan writeRequest), + } + + for i := 0; i < concurrentWrites; i++ { + t.writeWG.Add(1) + go t.writeWorker() } var err error @@ -205,7 +232,10 @@ func (t *TCPTransport) tcpListen(tcpLn net.Listener) { for { conn, err := tcpLn.Accept() if err != nil { - if s := t.shutdown.Load(); s == 1 { + t.shutdownMu.RLock() + isShuttingDown := t.shutdown + t.shutdownMu.RUnlock() + if isShuttingDown { break } @@ -424,29 +454,50 @@ func (t *TCPTransport) getAdvertisedAddr() string { // WriteTo is a packet-oriented interface that fires off the given // payload to the given address. func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { - t.sentPackets.Inc() - t.sentPacketsBytes.Add(float64(len(b))) - - err := t.writeTo(b, addr) - if err != nil { - t.sentPacketsErrors.Inc() - - logLevel := level.Warn(t.logger) - if strings.Contains(err.Error(), "connection refused") { - // The connection refused is a common error that could happen during normal operations when a node - // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. - logLevel = t.debugLog() - } - logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) + t.shutdownMu.RLock() + defer t.shutdownMu.RUnlock() // Unlock at the end to protect the chan + if t.shutdown { + return time.Time{}, errors.New("transport is shutting down") + } + // Send the packet to the write workers + // If this blocks for too long (as configured), abort and log an error. + select { + case <-time.After(t.cfg.AcquireWriterTimeout): + // Dropped packets are not an issue, the memberlist protocol will retry later. + level.Debug(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr) + t.droppedPackets.Inc() // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. return time.Now(), nil + case t.writeCh <- writeRequest{b: b, addr: addr}: + // OK } return time.Now(), nil } +func (t *TCPTransport) writeWorker() { + defer t.writeWG.Done() + for req := range t.writeCh { + b, addr := req.b, req.addr + t.sentPackets.Inc() + t.sentPacketsBytes.Add(float64(len(b))) + err := t.writeTo(b, addr) + if err != nil { + t.sentPacketsErrors.Inc() + + logLevel := level.Warn(t.logger) + if strings.Contains(err.Error(), "connection refused") { + // The connection refused is a common error that could happen during normal operations when a node + // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. + logLevel = t.debugLog() + } + logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) + } + } +} + func (t *TCPTransport) writeTo(b []byte, addr string) error { // Open connection, write packet header and data, data hash, close. Simple. c, err := t.getConnection(addr, t.cfg.PacketDialTimeout) @@ -559,17 +610,31 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn { // Shutdown is called when memberlist is shutting down; this gives the // transport a chance to clean up any listeners. +// This will avoid log spam about errors when we shut down. 
func (t *TCPTransport) Shutdown() error { + t.shutdownMu.Lock() // This will avoid log spam about errors when we shut down. - t.shutdown.Store(1) + if t.shutdown { + t.shutdownMu.Unlock() + return nil // already shut down + } + + // Set the shutdown flag and close the write channel. + t.shutdown = true + close(t.writeCh) + t.shutdownMu.Unlock() // Rip through all the connections and shut them down. for _, conn := range t.tcpListeners { _ = conn.Close() } + // Wait until all write workers have finished. + t.writeWG.Wait() + // Block until all the listener threads have died. t.wg.Wait() + return nil } @@ -618,6 +683,13 @@ func (t *TCPTransport) registerMetrics(registerer prometheus.Registerer) { Help: "Number of errors when receiving memberlist packets", }) + t.droppedPackets = promauto.With(registerer).NewCounter(prometheus.CounterOpts{ + Namespace: t.cfg.MetricsNamespace, + Subsystem: subsystem, + Name: "packets_dropped_total", + Help: "Number of dropped memberlist packets. These packets were not sent due to timeout waiting for a writer.", + }) + t.sentPackets = promauto.With(registerer).NewCounter(prometheus.CounterOpts{ Namespace: t.cfg.MetricsNamespace, Subsystem: subsystem, diff --git a/kv/memberlist/tcp_transport_test.go b/kv/memberlist/tcp_transport_test.go index 310e11ecb..f80bd80bc 100644 --- a/kv/memberlist/tcp_transport_test.go +++ b/kv/memberlist/tcp_transport_test.go @@ -1,7 +1,11 @@ package memberlist import ( + "net" + "strings" + "sync" "testing" + "time" "github.com/go-kit/log" "github.com/prometheus/client_golang/prometheus" @@ -9,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/dskit/concurrency" + "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" ) @@ -39,7 +44,7 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T cfg := TCPTransportConfig{} flagext.DefaultValues(&cfg) - cfg.BindAddrs = []string{"127.0.0.1"} + cfg.BindAddrs = getLocalhostAddrs() cfg.BindPort = 0 if testData.setup != nil { testData.setup(t, &cfg) @@ -51,6 +56,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T _, err = transport.WriteTo([]byte("test"), testData.remoteAddr) require.NoError(t, err) + require.NoError(t, transport.Shutdown()) + if testData.expectedLogs != "" { assert.Contains(t, logs.String(), testData.expectedLogs) } @@ -61,6 +68,90 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T } } +type timeoutReader struct{} + +func (f *timeoutReader) ReadSecret(_ string) ([]byte, error) { + time.Sleep(1 * time.Second) + return nil, nil +} + +func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { + writeCt := 50 + + // Listen for TCP connections on a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.BindAddrs = getLocalhostAddrs() + cfg.MaxConcurrentWrites = writeCt + cfg.PacketDialTimeout = 500 * time.Millisecond + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + // Configure TLS only for writes. 
The dialing should time out (because of the timeoutReader) + transport.cfg.TLSEnabled = true + transport.cfg.TLS = tls.ClientConfig{ + Reader: &timeoutReader{}, + CertPath: "fake", + KeyPath: "fake", + CAPath: "fake", + } + + timeStart := time.Now() + + for i := 0; i < writeCt; i++ { + _, err = transport.WriteTo([]byte("test"), listener.Addr().String()) + require.NoError(t, err) + } + + require.NoError(t, transport.Shutdown()) + + gotErrorCt := strings.Count(logs.String(), "context deadline exceeded") + assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt) + assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)") + assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block") +} + +func TestTCPTransportWriterAcquireTimeout(t *testing.T) { + // Listen for TCP connections on a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.BindAddrs = getLocalhostAddrs() + cfg.MaxConcurrentWrites = 1 + cfg.AcquireWriterTimeout = 1 * time.Millisecond // very short timeout + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + writeCt := 100 + var reqWg sync.WaitGroup + for i := 0; i < writeCt; i++ { + reqWg.Add(1) + go func() { + defer reqWg.Done() + transport.WriteTo([]byte("test"), listener.Addr().String()) // nolint:errcheck + }() + } + reqWg.Wait() + + require.NoError(t, transport.Shutdown()) + gotErrorCt := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message") + assert.Less(t, gotErrorCt, writeCt, "expected to have fewer errors (%d) than total writes (%d). Some writes should pass.", gotErrorCt, writeCt) + assert.NotZero(t, gotErrorCt, "expected errors, got none") +} + func TestFinalAdvertiseAddr(t *testing.T) { tests := map[string]struct { advertiseAddr string diff --git a/ring/ring.go b/ring/ring.go index c8db7da50..d47eb8fe2 100644 --- a/ring/ring.go +++ b/ring/ring.go @@ -215,13 +215,13 @@ type Ring struct { // Number of registered instances per zone. instancesCountPerZone map[string]int - // Nubmber of registered instances with tokens per zone. + // Number of registered instances with tokens per zone. instancesWithTokensCountPerZone map[string]int // Number of registered instances that are writable and have tokens. writableInstancesWithTokensCount int - // Nubmber of registered instances with tokens per zone that are writable. + // Number of registered instances with tokens per zone that are writable. writableInstancesWithTokensCountPerZone map[string]int // Cache of shuffle-sharded subrings per identifier. Invalidated when topology changes.
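For readers who want to see the write-scheduling idea from the transport change above in isolation, the following is a minimal, self-contained sketch (not dskit code) of the same pattern: a fixed pool of worker goroutines draining a channel, a handoff timeout so senders drop a message instead of blocking, and a shutdown that closes the channel exactly once under a mutex before waiting for the workers. Names such as sender, newSender, maxWorkers and acquireTimeout are illustrative only and not part of the dskit API.

package main

import (
	"fmt"
	"sync"
	"time"
)

// sender owns a bounded pool of write workers, mirroring the shape of the
// TCPTransport change above (a channel guarded by a mutex, workers tracked by a WaitGroup).
type sender struct {
	mu       sync.RWMutex
	shutdown bool
	writeCh  chan string // protected by mu
	wg       sync.WaitGroup
}

func newSender(maxWorkers int) *sender {
	if maxWorkers <= 0 {
		maxWorkers = 1
	}
	s := &sender{writeCh: make(chan string)}
	for i := 0; i < maxWorkers; i++ {
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			for msg := range s.writeCh {
				fmt.Println("sending", msg) // stand-in for the real network write
			}
		}()
	}
	return s
}

// send hands the message to a worker, or drops it if no worker picks it up
// within acquireTimeout. It returns false if the message was dropped.
func (s *sender) send(msg string, acquireTimeout time.Duration) bool {
	s.mu.RLock()
	defer s.mu.RUnlock() // hold the read lock so close() can't close writeCh underneath us
	if s.shutdown {
		return false
	}
	select {
	case s.writeCh <- msg:
		return true
	case <-time.After(acquireTimeout):
		return false // dropped; callers treat sends as best-effort
	}
}

// close marks the sender as shut down, closes the channel exactly once,
// and waits for in-flight writes to finish.
func (s *sender) close() {
	s.mu.Lock()
	if !s.shutdown {
		s.shutdown = true
		close(s.writeCh)
	}
	s.mu.Unlock()
	s.wg.Wait()
}

func main() {
	s := newSender(3)
	_ = s.send("hello", 250*time.Millisecond)
	s.close()
}

The timeout-based handoff is why dropped messages are cheap here: the caller never queues work it cannot hand off promptly, which is the same trade-off the memberlist transport makes when it increments its dropped-packets counter and relies on the protocol's retries.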
diff --git a/server/server_test.go b/server/server_test.go index f84abaa51..d5390e7aa 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -132,6 +132,8 @@ func TestTCPv4Network(t *testing.T) { func TestDefaultAddresses(t *testing.T) { var cfg Config cfg.RegisterFlags(flag.NewFlagSet("", flag.ExitOnError)) + cfg.GRPCListenAddress = "localhost" + cfg.HTTPListenAddress = "localhost" cfg.HTTPListenPort = 9090 cfg.MetricsNamespace = "testing_addresses" @@ -949,6 +951,8 @@ func TestGrpcOverProxyProtocol(t *testing.T) { cfg.RegisterFlags(flag.NewFlagSet("", flag.ExitOnError)) cfg.ProxyProtocolEnabled = true // Set this to 0 to have it choose a random port + cfg.HTTPListenAddress = "localhost" + cfg.GRPCListenAddress = "localhost" cfg.HTTPListenPort = 0 fakeSourceIP := "1.2.3.4" diff --git a/spanlogger/spanlogger.go b/spanlogger/spanlogger.go index 8daad995c..f32bce6f6 100644 --- a/spanlogger/spanlogger.go +++ b/spanlogger/spanlogger.go @@ -1,8 +1,13 @@ +// Provenance-includes-location: https://github.com/go-kit/log/blob/main/value.go +// Provenance-includes-license: MIT +// Provenance-includes-copyright: Go kit + package spanlogger import ( "context" "runtime" + "strconv" "strings" "go.uber.org/atomic" // Really just need sync/atomic but there is a lint rule preventing it. @@ -163,9 +168,6 @@ func (s *SpanLogger) getLogger() log.Logger { logger = log.With(logger, "trace_id", traceID) } - // Replace the default valuer for the 'caller' attribute with one that gets the caller of the methods in this file. - logger = log.With(logger, "caller", spanLoggerAwareCaller()) - // If the value has been set by another goroutine, fetch that other value and discard the one we made. if !s.logger.CompareAndSwap(nil, &logger) { pLogger := s.logger.Load() @@ -188,46 +190,64 @@ func (s *SpanLogger) SetSpanAndLogTag(key string, value interface{}) { s.logger.Store(&wrappedLogger) } -// spanLoggerAwareCaller is like log.Caller, but ensures that the caller information is -// that of the caller to SpanLogger, not SpanLogger itself. -func spanLoggerAwareCaller() log.Valuer { - valuer := atomic.NewPointer[log.Valuer](nil) - +// Caller is like github.com/go-kit/log's Caller, but ensures that the caller information is +// that of the caller to SpanLogger (if SpanLogger is being used), not SpanLogger itself. +// +// defaultStackDepth should be the number of stack frames to skip by default, as would be +// passed to github.com/go-kit/log's Caller method. +func Caller(defaultStackDepth int) log.Valuer { return func() interface{} { - // If we've already determined the correct stack depth, use it. - existingValuer := valuer.Load() - if existingValuer != nil { - return (*existingValuer)() - } - - // We haven't been called before, determine the correct stack depth to - // skip the configured logger's internals and the SpanLogger's internals too. - // - // Note that we can't do this in spanLoggerAwareCaller() directly because we - // need to do this when invoked by the configured logger - otherwise we cannot - // measure the stack depth of the logger's internals. - - stackDepth := 3 // log.DefaultCaller uses a stack depth of 3, so start searching for the correct stack depth there. + stackDepth := defaultStackDepth + 1 // +1 to account for this method. + seenSpanLogger := false + pc := make([]uintptr, 1) for { - _, file, _, ok := runtime.Caller(stackDepth) + function, file, line, ok := caller(stackDepth, pc) if !ok { // We've run out of possible stack frames. Give up. 
- valuer.Store(&unknownCaller) - return unknownCaller() + return "" } - if strings.HasSuffix(file, "spanlogger/spanlogger.go") { - stackValuer := log.Caller(stackDepth + 2) // Add one to skip the stack frame for the SpanLogger method, and another to skip the stack frame for the valuer which we'll invoke below. - valuer.Store(&stackValuer) - return stackValuer() + // If we're in a SpanLogger method, we need to continue searching. + // + // Matching on the exact function name like this does mean this will break if we rename or refactor SpanLogger, but + // the tests should catch this. In the worst case scenario, we'll log incorrect caller information, which isn't the + // end of the world. + if function == "github.com/grafana/dskit/spanlogger.(*SpanLogger).Log" || function == "github.com/grafana/dskit/spanlogger.(*SpanLogger).DebugLog" { + seenSpanLogger = true + stackDepth++ + continue } - stackDepth++ + // We need to check for go-kit/log stack frames like this because using log.With, log.WithPrefix or log.WithSuffix + // (including the various level methods like level.Debug, level.Info etc.) to wrap a SpanLogger introduces an + // additional context.Log stack frame that calls into the SpanLogger. This is because the use of SpanLogger + // as the logger means the optimisation to avoid creating a new logger in + // https://github.com/go-kit/log/blob/c7bf81493e581feca11e11a7672b14be3591ca43/log.go#L141-L146 used by those methods + // can't be used, and so the SpanLogger is wrapped in a new logger. + if seenSpanLogger && function == "github.com/go-kit/log.(*context).Log" { + stackDepth++ + continue + } + + return formatCallerInfoForLog(file, line) } } } -var unknownCaller log.Valuer = func() interface{} { - return "" +// caller is like runtime.Caller, but modified to allow reuse of the uintptr slice and return the function name. +func caller(stackDepth int, pc []uintptr) (function string, file string, line int, ok bool) { + n := runtime.Callers(stackDepth+1, pc) + if n < 1 { + return "", "", 0, false + } + + frame, _ := runtime.CallersFrames(pc).Next() + return frame.Function, frame.File, frame.Line, frame.PC != 0 +} + +// This is based on github.com/go-kit/log's Caller, but modified for use by Caller above.
+func formatCallerInfoForLog(file string, line int) string { + idx := strings.LastIndexByte(file, '/') + return file[idx+1:] + ":" + strconv.Itoa(line) } diff --git a/spanlogger/spanlogger_test.go b/spanlogger/spanlogger_test.go index 0e9f7e1d1..fa22a15c8 100644 --- a/spanlogger/spanlogger_test.go +++ b/spanlogger/spanlogger_test.go @@ -7,6 +7,7 @@ import ( "io" "path/filepath" "runtime" + "slices" "strings" "testing" @@ -45,9 +46,6 @@ func TestSpanLogger_CustomLogger(t *testing.T) { } resolver := fakeResolver{} - _, thisFile, thisLineNumber, ok := runtime.Caller(0) - require.True(t, ok) - span, ctx := New(context.Background(), logger, "test", resolver) _ = span.Log("msg", "original spanlogger") @@ -58,9 +56,9 @@ func TestSpanLogger_CustomLogger(t *testing.T) { _ = span.Log("msg", "fallback spanlogger") expect := [][]interface{}{ - {"method", "test", "caller", toCallerInfo(thisFile, thisLineNumber+4), "msg", "original spanlogger"}, - {"caller", toCallerInfo(thisFile, thisLineNumber+7), "msg", "restored spanlogger"}, - {"caller", toCallerInfo(thisFile, thisLineNumber+10), "msg", "fallback spanlogger"}, + {"method", "test", "msg", "original spanlogger"}, + {"msg", "restored spanlogger"}, + {"msg", "fallback spanlogger"}, } require.Equal(t, expect, logged) } @@ -88,9 +86,6 @@ func TestSpanLogger_SetSpanAndLogTag(t *testing.T) { return nil } - _, thisFile, thisLineNumber, ok := runtime.Caller(0) - require.True(t, ok) - spanLogger, _ := New(context.Background(), logger, "the_method", fakeResolver{}) require.NoError(t, spanLogger.Log("msg", "this is the first message")) @@ -110,18 +105,15 @@ func TestSpanLogger_SetSpanAndLogTag(t *testing.T) { expectedLogMessages := [][]interface{}{ { "method", "the_method", - "caller", toCallerInfo(thisFile, thisLineNumber+4), "msg", "this is the first message", }, { "method", "the_method", - "caller", toCallerInfo(thisFile, thisLineNumber+7), "id", "123", "msg", "this is the second message", }, { "method", "the_method", - "caller", toCallerInfo(thisFile, thisLineNumber+10), "id", "123", "more context", "abc", "msg", "this is the third message", @@ -206,7 +198,7 @@ func BenchmarkSpanLoggerWithRealLogger(b *testing.B) { b.Run(name, func(b *testing.B) { buf := bytes.NewBuffer(nil) logger := dskit_log.NewGoKitWithWriter("logfmt", buf) - logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", log.Caller(5)) + logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", Caller(5)) if debugEnabled { logger = level.NewFilter(logger, level.AllowAll()) @@ -222,7 +214,7 @@ func BenchmarkSpanLoggerWithRealLogger(b *testing.B) { resolver := fakeResolver{} sl, _ := New(context.Background(), logger, "test", resolver, "bar") - b.Run("log", func(b *testing.B) { + b.Run("Log", func(b *testing.B) { buf.Reset() b.ResetTimer() @@ -231,7 +223,7 @@ func BenchmarkSpanLoggerWithRealLogger(b *testing.B) { } }) - b.Run("level.debug", func(b *testing.B) { + b.Run("level.Debug", func(b *testing.B) { buf.Reset() b.ResetTimer() @@ -240,7 +232,7 @@ func BenchmarkSpanLoggerWithRealLogger(b *testing.B) { } }) - b.Run("debuglog", func(b *testing.B) { + b.Run("DebugLog", func(b *testing.B) { buf.Reset() b.ResetTimer() @@ -250,7 +242,31 @@ func BenchmarkSpanLoggerWithRealLogger(b *testing.B) { }) }) } +} + +func BenchmarkSpanLoggerAwareCaller(b *testing.B) { + runBenchmark := func(b *testing.B, caller log.Valuer) { + buf := bytes.NewBuffer(nil) + logger := dskit_log.NewGoKitWithWriter("logfmt", buf) + logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", caller) + 
b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = logger.Log("msg", "foo", "more", "data") + } + + } + + const defaultStackDepth = 5 + + b.Run("with go-kit's Caller", func(b *testing.B) { + runBenchmark(b, log.Caller(defaultStackDepth)) + }) + + b.Run("with dskit's spanlogger.Caller", func(b *testing.B) { + runBenchmark(b, Caller(defaultStackDepth)) + }) } // Logger which does nothing and implements the DebugEnabled interface used by SpanLogger. @@ -267,12 +283,12 @@ type loggerWithDebugEnabled struct { func (l loggerWithDebugEnabled) DebugEnabled() bool { return l.debugEnabled } -func TestSpanLogger_CallerInfo(t *testing.T) { +func TestSpanLoggerAwareCaller(t *testing.T) { testCases := map[string]func(w io.Writer) log.Logger{ // This is based on Mimir's default logging configuration: https://github.com/grafana/mimir/blob/50d1c27b4ad82b265ff5a865345bec2d726f64ef/pkg/util/log/log.go#L45-L46 "default logger": func(w io.Writer) log.Logger { logger := dskit_log.NewGoKitWithWriter("logfmt", w) - logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", log.Caller(5)) + logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", Caller(5)) logger = level.NewFilter(logger, level.AllowAll()) return logger }, @@ -280,11 +296,19 @@ func TestSpanLogger_CallerInfo(t *testing.T) { // This is based on Mimir's logging configuration with rate-limiting enabled: https://github.com/grafana/mimir/blob/50d1c27b4ad82b265ff5a865345bec2d726f64ef/pkg/util/log/log.go#L42-L43 "rate-limited logger": func(w io.Writer) log.Logger { logger := dskit_log.NewGoKitWithWriter("logfmt", w) - logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", log.Caller(6)) + logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", Caller(6)) logger = dskit_log.NewRateLimitedLogger(logger, 1000, 1000, nil) logger = level.NewFilter(logger, level.AllowAll()) return logger }, + + "default logger that has been wrapped with further information": func(w io.Writer) log.Logger { + logger := dskit_log.NewGoKitWithWriter("logfmt", w) + logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", Caller(5)) + logger = level.NewFilter(logger, level.AllowAll()) + logger = log.With(logger, "user", "user-1") + return logger + }, } resolver := fakeResolver{} @@ -307,12 +331,15 @@ func TestSpanLogger_CallerInfo(t *testing.T) { return buf, spanLogger, span.(*jaeger.Span) } - requireSpanHasTwoLogLinesWithoutCaller := func(t *testing.T, span *jaeger.Span) { + requireSpanHasTwoLogLinesWithoutCaller := func(t *testing.T, span *jaeger.Span, extraFields ...otlog.Field) { logs := span.Logs() require.Len(t, logs, 2) - require.Equal(t, []otlog.Field{otlog.String("msg", "this is a test")}, logs[0].Fields) - require.Equal(t, []otlog.Field{otlog.String("msg", "this is another test")}, logs[1].Fields) + expectedFields := append(slices.Clone(extraFields), otlog.String("msg", "this is a test")) + require.Equal(t, expectedFields, logs[0].Fields) + + expectedFields = append(slices.Clone(extraFields), otlog.String("msg", "this is another test")) + require.Equal(t, expectedFields, logs[1].Fields) } for name, loggerFactory := range testCases { @@ -326,6 +353,7 @@ func TestSpanLogger_CallerInfo(t *testing.T) { logged := logs.String() require.Contains(t, logged, "caller="+toCallerInfo(thisFile, lineNumberTwoLinesBeforeFirstLogCall+2)) + require.Equalf(t, 1, strings.Count(logged, "caller="), "expected to only have one caller field, but got: %v", logged) logs.Reset() _, _, lineNumberTwoLinesBeforeSecondLogCall, ok := runtime.Caller(0) 
@@ -334,6 +362,7 @@ func TestSpanLogger_CallerInfo(t *testing.T) { logged = logs.String() require.Contains(t, logged, "caller="+toCallerInfo(thisFile, lineNumberTwoLinesBeforeSecondLogCall+2)) + require.Equalf(t, 1, strings.Count(logged, "caller="), "expected to only have one caller field, but got: %v", logged) requireSpanHasTwoLogLinesWithoutCaller(t, span) }) @@ -346,6 +375,7 @@ func TestSpanLogger_CallerInfo(t *testing.T) { logged := logs.String() require.Contains(t, logged, "caller="+toCallerInfo(thisFile, lineNumberTwoLinesBeforeLogCall+2)) + require.Equalf(t, 1, strings.Count(logged, "caller="), "expected to only have one caller field, but got: %v", logged) logs.Reset() _, _, lineNumberTwoLinesBeforeSecondLogCall, ok := runtime.Caller(0) @@ -354,9 +384,33 @@ func TestSpanLogger_CallerInfo(t *testing.T) { logged = logs.String() require.Contains(t, logged, "caller="+toCallerInfo(thisFile, lineNumberTwoLinesBeforeSecondLogCall+2)) + require.Equalf(t, 1, strings.Count(logged, "caller="), "expected to only have one caller field, but got: %v", logged) requireSpanHasTwoLogLinesWithoutCaller(t, span) }) + + t.Run("logging with SpanLogger wrapped in a level", func(t *testing.T) { + logs, spanLogger, span := setupTest(t, loggerFactory) + + _, thisFile, lineNumberTwoLinesBeforeFirstLogCall, ok := runtime.Caller(0) + require.True(t, ok) + _ = level.Info(spanLogger).Log("msg", "this is a test") + + logged := logs.String() + require.Contains(t, logged, "caller="+toCallerInfo(thisFile, lineNumberTwoLinesBeforeFirstLogCall+2)) + require.Equalf(t, 1, strings.Count(logged, "caller="), "expected to only have one caller field, but got: %v", logged) + + logs.Reset() + _, _, lineNumberTwoLinesBeforeSecondLogCall, ok := runtime.Caller(0) + require.True(t, ok) + _ = level.Info(spanLogger).Log("msg", "this is another test") + + logged = logs.String() + require.Contains(t, logged, "caller="+toCallerInfo(thisFile, lineNumberTwoLinesBeforeSecondLogCall+2)) + require.Equalf(t, 1, strings.Count(logged, "caller="), "expected to only have one caller field, but got: %v", logged) + + requireSpanHasTwoLogLinesWithoutCaller(t, span, otlog.String("level", "info")) + }) }) } }
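To make the caller-valuer changes above easier to follow, here is a minimal, self-contained sketch of the underlying idea: a go-kit log.Valuer that walks the stack and skips frames from chosen packages before reporting file:line. It is not the dskit implementation (which additionally matches the specific SpanLogger method names shown in the diff); callerSkipping and skipPrefixes are hypothetical names introduced only for this example.

package main

import (
	"os"
	"runtime"
	"strconv"
	"strings"

	"github.com/go-kit/log"
)

// callerSkipping returns a Valuer that reports the first stack frame, starting at
// defaultStackDepth, whose function does not match any of the given prefixes.
func callerSkipping(defaultStackDepth int, skipPrefixes ...string) log.Valuer {
	return func() interface{} {
		pc := make([]uintptr, 1)
		depth := defaultStackDepth
		for {
			// runtime.Callers(skip+1, ...) mirrors runtime.Caller(skip), but also gives us the function name.
			if n := runtime.Callers(depth+1, pc); n < 1 {
				return "" // ran out of stack frames
			}
			frame, _ := runtime.CallersFrames(pc).Next()

			skip := false
			for _, prefix := range skipPrefixes {
				if strings.HasPrefix(frame.Function, prefix) {
					skip = true
					break
				}
			}
			if skip {
				depth++
				continue
			}

			// Format as "file.go:123", like go-kit's Caller does.
			idx := strings.LastIndexByte(frame.File, '/')
			return frame.File[idx+1:] + ":" + strconv.Itoa(frame.Line)
		}
	}
}

func main() {
	logger := log.NewLogfmtLogger(os.Stdout)
	// Depth 3 matches go-kit's log.DefaultCaller for a directly used logger; frames from
	// go-kit/log itself (for example level wrappers) are skipped rather than reported.
	logger = log.With(logger, "caller", callerSkipping(3, "github.com/go-kit/log"))
	_ = logger.Log("msg", "hello")
}

Skipping by function-name prefix is what lets wrapped loggers (level filters, extra log.With layers) still report the user's call site, which is the behaviour the TestSpanLoggerAwareCaller cases above assert with their single caller= field checks.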