Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): V2 API client ratelimit improvements #519

Merged
merged 12 commits into from
Aug 1, 2024
3 changes: 3 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ func (c *Client) retryHTTPCheck(ctx context.Context, resp *http.Response, err er
if ctx.Err() != nil {
return false, ctx.Err()
}
if err != nil {
return true, err
}

if resp != nil {
if resp.StatusCode == http.StatusTooManyRequests ||
Expand Down
43 changes: 32 additions & 11 deletions client/v2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func NewClientWithConfig(config *Config) (*Client, error) {
},
}
client.http = &retryablehttp.Client{
Backoff: retryablehttp.DefaultBackoff,
Backoff: client.retryHTTPBackoff,
CheckRetry: client.retryHTTPCheck,
ErrorHandler: retryablehttp.PassthroughErrorHandler,
HTTPClient: config.HTTPClient,
Expand Down Expand Up @@ -195,22 +195,43 @@ func (c *Client) newRequest(
return req, err
}

// retryHTTPCheck is a retryablehttp.CheckRetry function that will
// retry on a 429 or any 5xx status code
func (c *Client) retryHTTPCheck(
ctx context.Context,
r *http.Response,
_ error,
err error,
) (bool, error) {
if r == nil || ctx.Err() != nil {
if ctx.Err() != nil {
return false, ctx.Err()
}
if err != nil {
return true, err
}
if r != nil {
if r.StatusCode == http.StatusTooManyRequests || r.StatusCode >= 500 {
return true, nil
}
}
return false, nil
}

switch r.StatusCode {
case http.StatusTooManyRequests:
// TODO: use new retry header timestamps to determine when to retry
return true, nil
case http.StatusBadGateway, http.StatusGatewayTimeout, http.StatusInternalServerError:
return true, nil
default:
return false, nil
// retryHTTPBackoff is a retryablehttp.Backoff function that will
// use a linear backoff for all status codes except 429, which will
// attempt to use the rate limit headers to determine the backoff time
func (c *Client) retryHTTPBackoff(
min, max time.Duration,
attemptNum int,
r *http.Response,
) time.Duration {
if r != nil && r.StatusCode == http.StatusTooManyRequests {
return rateLimitBackoff(min, max, r)
}

// if we've not been rate limited, use a linear backoff
// but increase the minimum and maximum backoff times
// and hand it off to retryablehttp.LinearJitterBackoff
min = 500 * time.Millisecond
max = 950 * time.Millisecond
return retryablehttp.LinearJitterBackoff(min, max, attemptNum, r)
}
109 changes: 109 additions & 0 deletions client/v2/limits.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package v2

import (
"errors"
"fmt"
"math/rand"
"net/http"
"time"

"github.com/dunglas/httpsfv"
)

const (
// HeaderRateLimit is the (draft07) recommended header from the IETF
// on rate limiting.
//
// The value of the header is expected to be a HTTP Structured Field Value (SFV)
// dictionary with the keys "limit", "remaining", and "reset".
//
// Where "limit" is the maximum number of requests allowed in the window,
// "remaining" is the number of requests remaining in the window,
// and "reset" is the number of seconds until the limit resets.
HeaderRateLimit = "Ratelimit"

// HeaderRetryAfter is the RFC7231 header used to indicate when a client should
// retry requests in UTC time.
HeaderRetryAfter = "Retry-After"
)

// rateLimitBackoff calculates the backoff time for a rate limited request
// based on the possible response headers.
// The function will first try to get the reset time from the rate limit header.
//
// If the rate limit header is not present, or the reset time is in the past,
// the function will return a random backoff time between min and max.
func rateLimitBackoff(min, max time.Duration, r *http.Response) time.Duration {
// calculate some jitter for a little extra fuzziness to avoid thundering herds
jitter := time.Duration(rand.Float64() * float64(max-min))

var reset time.Duration
if v := r.Header.Get(HeaderRateLimit); v != "" {
// we currently only care about the reset time
_, _, resetSeconds, err := parseRateLimitHeader(v)
if err == nil {
reset = time.Duration(resetSeconds) * time.Second
}
}
// if we didn't get a reset value from the ratelimit header
// try the retry-after header
if reset == 0 {
if v := r.Header.Get(HeaderRetryAfter); v != "" {
retryTime, err := time.Parse(time.RFC3339, v)
if err == nil {
reset = time.Until(retryTime)
}
}
}

// only update min if the time to wait is longer
if reset > min {
min = reset
}
return min + jitter
}

// parseRateLimitHeader parses the rate limit header into its constituent parts.
//
// The header is expected to be in the format "limit=X, remaining=Y, reset=Z".
// Where:
// - X is the maximum number of requests allowed in the window
// - Y is the number of requests remaining in the window
// - Z is the number of seconds until the limit resets
func parseRateLimitHeader(h string) (limit, remaining, reset int64, err error) {
vals, err := httpsfv.UnmarshalDictionary([]string{h})
if err != nil {
err = errors.New("invalid ratelimit header")
return
}

limit, err = valueFromSFVDictionary[int64](vals, "limit")
if err != nil {
err = fmt.Errorf("could not get \"limit\" from header: %w", err)
return
}
remaining, err = valueFromSFVDictionary[int64](vals, "remaining")
if err != nil {
err = fmt.Errorf("could not get \"remaining\" from header: %w", err)
return
}
reset, err = valueFromSFVDictionary[int64](vals, "reset")
if err != nil {
err = fmt.Errorf("could not get \"reset\" from header: %w", err)
return
}
return
}

func valueFromSFVDictionary[T any](d *httpsfv.Dictionary, key string) (T, error) {
var zero T
k, ok := d.Get(key)
if !ok {
return zero, errors.New("key not found")
}
v, ok := k.(httpsfv.Item).Value.(T)
if !ok {
return zero, fmt.Errorf("value is not a %T", zero)
}
return v, nil
}
208 changes: 208 additions & 0 deletions client/v2/limits_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package v2

import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestClient_rateLimitBackoff(t *testing.T) {
t.Parallel()

now := time.Now().UTC()
// init min and max to zero to remove jitter
min, max := time.Duration(0), time.Duration(0)

tests := []struct {
name string
headerName string
headerValue string
expectedValue time.Duration
}{
{
name: "no header",
expectedValue: min,
},
{
name: "invalid ratelimit header",
headerName: HeaderRateLimit,
headerValue: "foobar",
expectedValue: min,
},
{
name: "invalid retry-after header",
headerName: HeaderRetryAfter,
headerValue: "three hours from now",
expectedValue: min,
},
{
name: "valid ratelimit header",
headerName: HeaderRateLimit,
headerValue: "limit=100, remaining=50, reset=60",
expectedValue: 60 * time.Second,
},
{
name: "valid retry-after header",
headerName: HeaderRetryAfter,
headerValue: now.Add(2 * time.Minute).UTC().Format(time.RFC3339),
expectedValue: 2 * time.Minute,
},
{
name: "negative reset value in ratelimit header",
headerName: HeaderRateLimit,
headerValue: "limit=100, remaining=-1, reset=-10",
expectedValue: min,
},
{
name: "retry-after in the past",
headerName: HeaderRetryAfter,
headerValue: now.Add(-2 * time.Minute).UTC().Format(time.RFC3339),
expectedValue: min,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
w := httptest.NewRecorder()
w.Header().Add(tc.headerName, tc.headerValue)
w.WriteHeader(http.StatusTooManyRequests)

r := rateLimitBackoff(min, max, w.Result())
assert.WithinDuration(t,
now.Add(tc.expectedValue),
now.Add(r),
time.Second,
)
})
}

t.Run("ratelimit header takes precedence", func(t *testing.T) {
w := httptest.NewRecorder()
w.Header().Add(HeaderRateLimit, "limit=100, remaining=50, reset=60")
w.Header().Add(HeaderRetryAfter, now.Add(2*time.Minute).UTC().Format(time.RFC3339))
w.WriteHeader(http.StatusTooManyRequests)

r := rateLimitBackoff(min, max, w.Result())
assert.WithinDuration(t,
now.Add(60*time.Second),
now.Add(r),
time.Second,
)
})

t.Run("reset value is fuzzed with jitter", func(t *testing.T) {
w := httptest.NewRecorder()
w.Header().Add(HeaderRateLimit, "limit=100, remaining=50, reset=60")
w.WriteHeader(http.StatusTooManyRequests)

min = 100 * time.Millisecond
max = 500 * time.Millisecond
r := rateLimitBackoff(min, max, w.Result())

if assert.Greater(t, r, 60*time.Second, "expected backoff to be 60sec+") {
assert.WithinRange(t,
time.Now().Add(r-60*time.Second),
time.Now().Add(min),
time.Now().Add(max),
"jitter not applied correctly",
)
}
})
}

func TestClient_parseRateLimitHeader(t *testing.T) {
t.Parallel()

type expect struct {
limit int64
remaining int64
reset int64
}
tests := []struct {
name string
header string
expect expect
expectErr bool
}{
{
name: "empty",
expectErr: true,
},
{
name: "invalid",
header: "foobar",
expectErr: true,
},
{
name: "valid",
header: "limit=100, remaining=50, reset=60",
expect: expect{
limit: 100,
remaining: 50,
reset: 60,
},
},
{
name: "valid, no spacing",
header: "limit=250,remaining=199,reset=120",
expect: expect{
limit: 250,
remaining: 199,
reset: 120,
},
},
{
name: "mixed up member order",
header: "remaining=50, limit=100, reset=60",
expect: expect{
limit: 100,
remaining: 50,
reset: 60,
},
},
{
name: "additional key, otherwise valid",
header: "limit=100, remaining=50, reset=120, foo=bar",
expect: expect{
limit: 100,
remaining: 50,
reset: 120,
},
},
{
name: "missing member",
header: "limit=100, remaining=50",
expectErr: true,
},
{
name: "wrong type value of member",
header: "limit=100, remaining=50, reset=now",
expectErr: true,
},
{
name: "additional key, otherwise valid",
header: "limit=100, remaining=50, reset=120, foo=bar",
expect: expect{
limit: 100,
remaining: 50,
reset: 120,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
limit, remaining, reset, err := parseRateLimitHeader(tc.header)
if tc.expectErr {
require.Error(t, err, "expected an error")
return
}
require.NoError(t, err, "expected no error")
assert.Equal(t, tc.expect.limit, limit, "limit doesn't match")
assert.Equal(t, tc.expect.remaining, remaining, "remaining doesn't match")
assert.Equal(t, tc.expect.reset, reset, "reset doesn't match")
})
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.21
toolchain go1.21.7

require (
github.com/dunglas/httpsfv v1.0.2
github.com/google/go-querystring v1.1.0
github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320
github.com/hashicorp/go-retryablehttp v0.7.7
Expand Down
Loading