Skip to content

Commit

Permalink
Implement OnError and OnFallbackChange callbacks (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek authored Aug 8, 2024
1 parent 8812af7 commit 61daf03
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 8 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type Config struct {
ClientName string `toml:"client_name"` // default: os.Args[0]
PrefixKey string `toml:"prefix_key"` // default: "httprate"

// OnError lets you subscribe to all runtime Redis errors. Useful for logging/debugging.
OnError func(err error)

// Disable the use of the local in-memory fallback mechanism. When enabled,
// the system will return HTTP 428 for all requests when Redis is down.
FallbackDisabled bool `toml:"fallback_disabled"` // default: false
Expand All @@ -22,6 +25,9 @@ type Config struct {
// the system will use the local counter unless it is explicitly disabled.
FallbackTimeout time.Duration `toml:"fallback_timeout"` // default: 100ms

// OnFallbackChange lets subscribe to local in-memory fallback changes.
OnFallbackChange func(activated bool)

// Client if supplied will be used and the below fields will be ignored.
//
// NOTE: It's recommended to set short dial/read/write timeouts and disable
Expand Down
17 changes: 16 additions & 1 deletion httprateredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,18 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
}

rc := &redisCounter{
prefixKey: cfg.PrefixKey,
prefixKey: cfg.PrefixKey,
onError: func(err error) {},
onFallback: func(activated bool) {},
}
if cfg.OnError != nil {
rc.onError = cfg.OnError
}
if !cfg.FallbackDisabled {
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
if cfg.OnFallbackChange != nil {
rc.onFallback = cfg.OnFallbackChange
}
}

if cfg.Client == nil {
Expand Down Expand Up @@ -89,6 +97,8 @@ type redisCounter struct {
prefixKey string
fallbackActivated atomic.Bool
fallbackCounter httprate.LimitCounter
onError func(err error)
onFallback func(activated bool)
}

var _ httprate.LimitCounter = (*redisCounter)(nil)
Expand Down Expand Up @@ -190,10 +200,12 @@ func (c *redisCounter) shouldFallback(err error) bool {
if err == nil {
return false
}
c.onError(err)

// Activate the local in-memory counter fallback, unless activated by some other goroutine.
alreadyActivated := c.fallbackActivated.Swap(true)
if !alreadyActivated {
c.onFallback(true)
go c.reconnect()
}

Expand All @@ -208,6 +220,9 @@ func (c *redisCounter) reconnect() {
err := c.client.Ping(context.Background()).Err()
if err == nil {
c.fallbackActivated.Store(false)
if c.onFallback != nil {
c.onFallback(false)
}
return
}
}
Expand Down
38 changes: 31 additions & 7 deletions local_fallback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@ func TestLocalFallback(t *testing.T) {
redis, err := miniredis.Run()
redisPort, _ := strconv.Atoi(redis.Port())

var onErrorCalled bool
var onFallbackCalled bool

limitCounter, err := httprateredis.NewRedisLimitCounter(&httprateredis.Config{
Host: redis.Host(),
Port: uint16(redisPort),
MaxIdle: 0,
MaxActive: 1,
ClientName: "httprateredis_test",
PrefixKey: fmt.Sprintf("httprate:test:%v", rand.Int31n(100000)), // Unique Redis key for each test
FallbackTimeout: 200 * time.Millisecond,
Host: redis.Host(),
Port: uint16(redisPort),
MaxIdle: 0,
MaxActive: 1,
ClientName: "httprateredis_test",
PrefixKey: fmt.Sprintf("httprate:test:%v", rand.Int31n(100000)), // Unique Redis key for each test
FallbackTimeout: 200 * time.Millisecond,
OnError: func(err error) { onErrorCalled = true },
OnFallbackChange: func(fallbackActivated bool) { onFallbackCalled = true },
})
if err != nil {
t.Fatalf("redis not available: %v", err)
Expand All @@ -37,6 +42,12 @@ func TestLocalFallback(t *testing.T) {
if limitCounter.IsFallbackActivated() {
t.Error("fallback should not be activated at the beginning")
}
if onErrorCalled {
t.Error("onError() should not be called at the beginning")
}
if onFallbackCalled {
t.Error("onFallback() should not be called before we simulate redis failure")
}

err = limitCounter.IncrementBy("key:fallback", currentWindow, 1)
if err != nil {
Expand All @@ -51,6 +62,12 @@ func TestLocalFallback(t *testing.T) {
if limitCounter.IsFallbackActivated() {
t.Error("fallback should not be activated before we simulate redis failure")
}
if onErrorCalled {
t.Error("onError() should not be called before we simulate redis failure")
}
if onFallbackCalled {
t.Error("onFallback() should not be called before we simulate redis failure")
}

redis.Close()

Expand All @@ -67,4 +84,11 @@ func TestLocalFallback(t *testing.T) {
if !limitCounter.IsFallbackActivated() {
t.Error("fallback should be activated after we simulate redis failure")
}
if !onErrorCalled {
t.Error("onError() should be called after we simulate redis failure")
}
if !onFallbackCalled {
t.Error("onFallback() should be called after we simulate redis failure")
}

}

0 comments on commit 61daf03

Please sign in to comment.