Skip to content

Commit

Permalink
Fix passing custom redis client
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek committed Aug 5, 2024
1 parent fe06926 commit 0d69b97
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
1 change: 1 addition & 0 deletions _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func main() {
})
})

// Rate-limit at 50 req/s per IP address.
r.Use(httprate.Limit(
50, time.Second,
httprate.WithKeyByIP(),
Expand Down
4 changes: 2 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ type Config struct {
// the system will use the local counter unless it is explicitly disabled.
FallbackTimeout time.Duration `toml:"fallback_timeout"` // default: 50ms

// Client if supplied will be used and below fields will be ignored.
// Client if supplied will be used and the below fields will be ignored.
//
// NOTE: It's recommended to set short Dial/Read/Write timeouts and disable
// NOTE: It's recommended to set short dial/read/write timeouts and disable
// retries on the client, so the local in-memory fallback can activate quickly.
Client *redis.Client `toml:"-"`
Host string `toml:"host"`
Expand Down
46 changes: 24 additions & 22 deletions httprateredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
cfg.PrefixKey = "httprate"
}
if cfg.FallbackTimeout == 0 {
// Activate local in-memory fallback fairly quickly, as this would slow down all requests.
cfg.FallbackTimeout = 50 * time.Millisecond
}

Expand All @@ -50,29 +51,30 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
}

var maxIdle, maxActive = cfg.MaxIdle, cfg.MaxActive
if maxIdle <= 0 {
maxIdle = 20
}
if maxActive <= 0 {
maxActive = 50
}
if cfg.Client == nil {
maxIdle, maxActive := cfg.MaxIdle, cfg.MaxActive
if maxIdle < 1 {
maxIdle = 20
}
if maxActive < 1 {
maxActive = 50
}

address := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
rc.client = redis.NewClient(&redis.Options{
Addr: address,
Password: cfg.Password,
DB: cfg.DBIndex,
PoolSize: maxActive,
MaxIdleConns: maxIdle,
ClientName: cfg.ClientName,
rc.client = redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password,
DB: cfg.DBIndex,
ClientName: cfg.ClientName,

DialTimeout: cfg.FallbackTimeout,
ReadTimeout: cfg.FallbackTimeout,
WriteTimeout: cfg.FallbackTimeout,
MinIdleConns: 1,
MaxRetries: -1,
})
DialTimeout: cfg.FallbackTimeout,
ReadTimeout: cfg.FallbackTimeout,
WriteTimeout: cfg.FallbackTimeout,
PoolSize: maxActive,
MinIdleConns: 1,
MaxIdleConns: maxIdle,
MaxRetries: -1, // -1 disables retries
})
}

return rc, nil
}
Expand Down Expand Up @@ -109,7 +111,7 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i
var netErr net.Error
if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) {
go c.fallback()
err = c.fallbackCounter.IncrementBy(key, currentWindow, amount) // = nil
err = c.fallbackCounter.IncrementBy(key, currentWindow, amount)
}
}
}()
Expand Down

0 comments on commit 0d69b97

Please sign in to comment.