From 7fbf1b82ea7ed44915b1db6e2c764ffc986eadc6 Mon Sep 17 00:00:00 2001 From: gene-redpanda <123959009+gene-redpanda@users.noreply.github.com> Date: Mon, 8 Jul 2024 19:35:10 -0500 Subject: [PATCH] fix: cleaned up and fixed various bugs --- Makefile | 1 - redpanda/cloud/ratelimiter.go | 118 ++++++++++------ redpanda/cloud/ratelimiter_test.go | 219 +++++++++++++++++++++-------- 3 files changed, 239 insertions(+), 99 deletions(-) diff --git a/Makefile b/Makefile index b433fe6d..c81aa447 100644 --- a/Makefile +++ b/Makefile @@ -69,7 +69,6 @@ unit_tests: @DEBUG=true \ REDPANDA_CLIENT_ID="dummy_id" \ REDPANDA_CLIENT_SECRET="dummy_secret" \ - TF_ACC=false \ RUN_CLUSTER_TESTS=false \ $(GOCMD) test -short ./... diff --git a/redpanda/cloud/ratelimiter.go b/redpanda/cloud/ratelimiter.go index 53155d86..310a0f43 100644 --- a/redpanda/cloud/ratelimiter.go +++ b/redpanda/cloud/ratelimiter.go @@ -14,8 +14,8 @@ import ( ) const ( - limitPeriod = 60.0 - burstPeriod = 10.0 + limitPeriod = 60.0 // api resets rate limit at 60s + burstPeriod = 10.0 // allows bursts of up to 50 requests to reduce hammering of the api ) type rateLimiter struct { @@ -28,67 +28,103 @@ func newRateLimiter(limit int) *rateLimiter { } } +// parseRateLimit expects a header as defined in https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers func parseRateLimit(header string) (limit, remaining int, reset time.Duration, err error) { - for _, part := range strings.Split(header, ",") { - kv := strings.SplitN(strings.TrimSpace(part), "=", 2) - if len(kv) != 2 { + parts := strings.Split(header, ",") + if len(parts) != 3 { + return 0, 0, 0, fmt.Errorf("invalid rate limit header: %s", header) // we expect limit, remaining and reset + } + + var limitSet, remainingSet, resetSet bool + for _, part := range parts { + kv := strings.Split(part, "=") + if len(kv) != 2 { // invalid header contents skip example "limit=1=remaining" continue } - switch kv[0] { + + key, value := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1]) + intValue, err := strconv.Atoi(value) + if err != nil { + return 0, 0, 0, fmt.Errorf("invalid %s value: %v", key, err) + } + + switch key { case "limit": - limit, err = strconv.Atoi(kv[1]) - if err != nil { - return 0, 0, 0, fmt.Errorf("invalid limit value: %v", err) - } + limit = intValue + limitSet = true case "remaining": - remaining, err = strconv.Atoi(kv[1]) - if err != nil { - return 0, 0, 0, fmt.Errorf("invalid remaining value: %v", err) - } + remaining = intValue + remainingSet = true case "reset": - seconds, err := strconv.Atoi(kv[1]) - if err != nil { - return 0, 0, 0, fmt.Errorf("invalid reset value: %v", err) - } - reset = time.Duration(seconds) * time.Second + reset = time.Duration(intValue) * time.Second + resetSet = true } } + + if !limitSet || !remainingSet || !resetSet { + return 0, 0, 0, fmt.Errorf("missing required rate limit information: %s", header) + } + return limit, remaining, reset, nil } +// Limiter is a grpc.UnaryClientInterceptor that updates the rate limiter based on the rate limit headers returned by the server +// malformed or otherwise incorrect headers are discarded with errors logged but non-halting +// messages without rate limits are considered valid func (r *rateLimiter) Limiter(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { var header metadata.MD if err := invoker(ctx, method, req, reply, cc, append(opts, grpc.Header(&header))...); err != nil { return err } - if rateLimitHeader := header.Get("ratelimit"); len(rateLimitHeader) > 0 { - tflog.Debug(ctx, "parsing rate limit headers") - limit, remaining, reset, err := parseRateLimit(rateLimitHeader[0]) - if err != nil { - return fmt.Errorf("failed to parse rate limit header: %v", err) - } + rateLimitHeader := header.Get("ratelimit") + if len(rateLimitHeader) == 0 { + return nil // no rate limit headers + } - tflog.Debug(ctx, "setting limit and burst") - r.limiter.SetLimit(rate.Limit(limit / limitPeriod)) - r.limiter.SetBurst(limit / burstPeriod) - if remaining == 0 && reset > 0 { - tflog.Warn(ctx, "rate limit exceeded", map[string]any{ - "limit": limit, - "remaining": remaining, - "reset": reset, - }) - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(reset + 1*time.Second): - } - } - tflog.Debug(ctx, "rate limit updated", map[string]any{ + tflog.Debug(ctx, "parsing rate limit headers", map[string]any{ + "header": rateLimitHeader[0], + }) + limit, remaining, reset, err := parseRateLimit(rateLimitHeader[0]) + if err != nil { + // if the parser returns an error we log it but otherwise treat it the same as not having a ratelimit header + tflog.Warn(ctx, "failed to parse rate limit header", map[string]any{ + "error": err.Error(), + "header": rateLimitHeader[0], + }) + // lint:ignore nilerr logging and not returning error as it is not fail worthy + return nil + } + + newLimit := rate.Limit(limit / limitPeriod) + newBurst := limit / burstPeriod + + if r.limiter.Limit() != newLimit || r.limiter.Burst() != newBurst { + tflog.Debug(ctx, "updating rate limiter", map[string]any{ + "new_limit": newLimit, + "new_burst": newBurst, + }) + r.limiter.SetLimit(newLimit) + r.limiter.SetBurst(newBurst) + } + + if remaining == 0 && reset > 0 { + tflog.Warn(ctx, "rate limit exceeded", map[string]any{ "limit": limit, "remaining": remaining, "reset": reset, }) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(reset + 1*time.Second): + } } + + tflog.Debug(ctx, "rate limit updated", map[string]any{ + "limit": limit, + "remaining": remaining, + "reset": reset, + }) return r.limiter.Wait(ctx) } diff --git a/redpanda/cloud/ratelimiter_test.go b/redpanda/cloud/ratelimiter_test.go index 75ae6e28..c91a9c85 100644 --- a/redpanda/cloud/ratelimiter_test.go +++ b/redpanda/cloud/ratelimiter_test.go @@ -3,6 +3,7 @@ package cloud import ( "context" "errors" + "fmt" "math" "testing" "time" @@ -18,85 +19,126 @@ func clampLimit(limit, period float64) rate.Limit { func TestRateLimiter_Limiter(t *testing.T) { tests := []struct { - name string - initialLimit int - headerLimit string - expectedLimit rate.Limit - invokerError error - expectedError error - consecutiveCalls int - expectedMinDuration time.Duration + name string + initialLimit int + headerLimit string + remainingValues []int + expectedLimit rate.Limit + expectedBurst int + invokerError error + expectedError error }{ { - name: "Normal update", - initialLimit: 200, - headerLimit: "limit=200,remaining=75,reset=30", - expectedLimit: clampLimit(200.0, limitPeriod), - invokerError: nil, - expectedError: nil, - consecutiveCalls: 10, - expectedMinDuration: time.Second * 7, + name: "Normal update", + initialLimit: 200, + headerLimit: "limit=200,remaining=%d,reset=30", + remainingValues: []int{75, 74, 73, 72, 71}, + expectedLimit: clampLimit(200.0, limitPeriod), + expectedBurst: int(200 / burstPeriod), + invokerError: nil, + expectedError: nil, }, { - name: "Rate limit exceeded", - initialLimit: 100, - headerLimit: "limit=10,remaining=0,reset=30", - expectedLimit: clampLimit(10.0, limitPeriod), - invokerError: nil, - expectedError: nil, - consecutiveCalls: 5, - expectedMinDuration: time.Second * 29, + name: "Rate limit exceeded", + initialLimit: 100, + headerLimit: "limit=100,remaining=%d,reset=15", + remainingValues: []int{5, 4, 3, 2, 1, 0}, + expectedLimit: clampLimit(100.0, limitPeriod), + expectedBurst: int(100 / burstPeriod), + invokerError: nil, + expectedError: nil, }, { - name: "Invalid header", - initialLimit: 100, - headerLimit: "invalid=header", - expectedLimit: clampLimit(100.0, limitPeriod), - invokerError: nil, - expectedError: errors.New("failed to parse rate limit header: incomplete rate limit header: missing required fields"), - consecutiveCalls: 1, - expectedMinDuration: 0, + name: "Invalid header", + initialLimit: 100, + headerLimit: "invalid=header=asdrf", + remainingValues: []int{0}, + expectedLimit: 1, + expectedBurst: 10, + invokerError: nil, + expectedError: nil, }, { - name: "Invoker error", - initialLimit: 100, - headerLimit: "", - expectedLimit: clampLimit(100.0, limitPeriod), - invokerError: errors.New("invoker error"), - expectedError: errors.New("invoker error"), - consecutiveCalls: 1, - expectedMinDuration: 0, + name: "Missing header element", + initialLimit: 100, + headerLimit: "limit=100,remaining=%d", + remainingValues: []int{5, 4, 3, 2, 1, 0}, + expectedLimit: 1, + expectedBurst: 10, + invokerError: nil, + expectedError: nil, + }, + { + name: "Invoker error", // validates that the invoker mock is working + initialLimit: 100, + headerLimit: "", + remainingValues: []int{0}, + expectedLimit: clampLimit(100.0, limitPeriod), + expectedBurst: int(100 / burstPeriod), + invokerError: errors.New("invoker error"), + expectedError: errors.New("invoker error"), + }, + { + name: "No rate limit headers", + initialLimit: 100, + headerLimit: "", + remainingValues: []int{0}, + expectedLimit: 1, + expectedBurst: 10, + invokerError: nil, + expectedError: nil, + }, + { + name: "Malformed header", + initialLimit: 100, + headerLimit: "limit=monkey,remaining=%d,reset=soon", + remainingValues: []int{0}, + expectedLimit: 1, + expectedBurst: 10, + invokerError: nil, + expectedError: nil, + }, + { + name: "Malformed header missing contents", + initialLimit: 100, + headerLimit: "monkey=1,remaining=%d,reset=soon", + remainingValues: []int{0}, + expectedLimit: 1, + expectedBurst: 10, + invokerError: nil, + expectedError: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rl := newRateLimiter(tt.initialLimit) - - mockInvoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + var lastErr error + for _, remaining := range tt.remainingValues { + var headerLimit string if tt.headerLimit != "" { - header := metadata.MD{ - "ratelimit": []string{tt.headerLimit}, - } - for _, opt := range opts { - if headerOpt, ok := opt.(grpc.HeaderCallOption); ok { - *headerOpt.HeaderAddr = header - break + headerLimit = fmt.Sprintf(tt.headerLimit, remaining) + } + mockInvoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + if tt.headerLimit != "" { + header := metadata.MD{ + "ratelimit": []string{headerLimit}, + } + for _, opt := range opts { + if headerOpt, ok := opt.(grpc.HeaderCallOption); ok { + *headerOpt.HeaderAddr = header + break + } } } + return tt.invokerError } - return tt.invokerError - } - start := time.Now() - var lastErr error - for i := 0; i < tt.consecutiveCalls; i++ { lastErr = rl.Limiter(context.Background(), "/test.Method", nil, nil, nil, mockInvoker) if lastErr != nil { break } } - duration := time.Since(start) if (lastErr != nil) != (tt.expectedError != nil) { t.Errorf("Expected error %v, got %v", tt.expectedError, lastErr) @@ -109,8 +151,71 @@ func TestRateLimiter_Limiter(t *testing.T) { t.Errorf("Expected limit to be %v, got %v", tt.expectedLimit, rl.limiter.Limit()) } - if duration < tt.expectedMinDuration { - t.Errorf("Expected minimum duration of %v, but got %v", tt.expectedMinDuration, duration) + if rl.limiter.Burst() != tt.expectedBurst { + t.Errorf("Expected burst to be %v, got %v", tt.expectedBurst, rl.limiter.Burst()) + } + }) + } +} + +func TestParseRateLimit(t *testing.T) { + tests := []struct { + name string + header string + expectedLimit int + expectedRem int + expectedReset time.Duration + expectError bool + }{ + { + name: "Valid header", + header: "limit=200,remaining=75,reset=30", + expectedLimit: 200, + expectedRem: 75, + expectedReset: 30 * time.Second, + expectError: false, + }, + { + name: "Missing field", + header: "limit=200,remaining=75", + expectError: true, + }, + { + name: "Invalid value", + header: "limit=monkey,remaining=75,reset=30", + expectError: true, + }, + { + name: "Extra field", + header: "limit=200,remaining=75,reset=30,extra=10", + expectedLimit: 200, + expectedRem: 75, + expectedReset: 30 * time.Second, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + limit, remaining, reset, err := parseRateLimit(tt.header) + + if tt.expectError { + if err == nil { + t.Errorf("Expected an error, but got none") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if limit != tt.expectedLimit { + t.Errorf("Expected limit %d, got %d", tt.expectedLimit, limit) + } + if remaining != tt.expectedRem { + t.Errorf("Expected remaining %d, got %d", tt.expectedRem, remaining) + } + if reset != tt.expectedReset { + t.Errorf("Expected reset %v, got %v", tt.expectedReset, reset) + } } }) }