From 760fb8b09105859010be6a9c48460ca47beab899 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Wed, 7 Aug 2024 16:13:20 +0200 Subject: [PATCH] Fallback on all Redis errors --- httprateredis.go | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/httprateredis.go b/httprateredis.go index a9ca6f0..ebf321d 100644 --- a/httprateredis.go +++ b/httprateredis.go @@ -2,9 +2,7 @@ package httprateredis import ( "context" - "errors" "fmt" - "net" "os" "path/filepath" "strconv" @@ -107,13 +105,8 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i return c.fallbackCounter.IncrementBy(key, currentWindow, amount) } defer func() { - if err != nil { - // On redis network error, fallback to local in-memory counter. - var netErr net.Error - if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) { - c.fallback() - err = c.fallbackCounter.IncrementBy(key, currentWindow, amount) - } + if c.shouldFallback(err) { + err = c.fallbackCounter.IncrementBy(key, currentWindow, amount) } }() } @@ -147,13 +140,8 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time) return c.fallbackCounter.Get(key, currentWindow, previousWindow) } defer func() { - if err != nil { - // On redis network error, fallback to local in-memory counter. - var netErr net.Error - if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) { - c.fallback() - curr, prev, err = c.fallbackCounter.Get(key, currentWindow, previousWindow) - } + if c.shouldFallback(err) { + curr, prev, err = c.fallbackCounter.Get(key, currentWindow, previousWindow) } }() } @@ -189,25 +177,30 @@ func (c *redisCounter) IsFallbackActivated() bool { return c.fallbackActivated.Load() } -func (c *redisCounter) fallback() { - // Activate the in-memory counter fallback, unless activated by some other goroutine. - fallbackAlreadyActivated := c.fallbackActivated.Swap(true) - if fallbackAlreadyActivated { - return +func (c *redisCounter) shouldFallback(err error) bool { + if err == nil { + return false } - go c.reconnect() + // Activate the local in-memory counter fallback, unless activated by some other goroutine. + alreadyActivated := c.fallbackActivated.Swap(true) + if !alreadyActivated { + go c.reconnect() + } + + return true } func (c *redisCounter) reconnect() { // Try to re-connect to redis every 200ms. for { + time.Sleep(200 * time.Millisecond) + err := c.client.Ping(context.Background()).Err() if err == nil { c.fallbackActivated.Store(false) return } - time.Sleep(200 * time.Millisecond) } }