Skip to content

Commit

Permalink
Merge pull request #14 from vearne/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
vearne authored Sep 4, 2024
2 parents 0f3a0c8 + ce91e7d commit fb3be44
Show file tree
Hide file tree
Showing 25 changed files with 630 additions and 9,006 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ go get github.com/vearne/ratelimit
```
### Usage
#### 1. create redis.Client
with "github.com/go-redis/redis"
with "github.com/redis/go-redis"
Supports both redis master-slave mode and cluster mode
```
client := redis.NewClient(&redis.Options{
Expand Down Expand Up @@ -132,15 +132,17 @@ package main
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
"github.com/vearne/ratelimit"
"github.com/vearne/ratelimit/counter"
"github.com/vearne/ratelimit/tokenbucket"
slog "github.com/vearne/simplelog"
"sync"
"time"
)
func consume(r ratelimit.Limiter, group *sync.WaitGroup,
c *ratelimit.Counter, targetCount int) {
c *counter.Counter, targetCount int) {
defer group.Done()
var ok bool
for {
Expand Down Expand Up @@ -168,7 +170,7 @@ func main() {
DB: 0, // use default DB
})
limiter, err := ratelimit.NewTokenBucketRateLimiter(
limiter, err := tokenbucket.NewTokenBucketRateLimiter(
context.Background(),
client,
"key:token",
Expand All @@ -184,7 +186,7 @@ func main() {
var wg sync.WaitGroup
total := 50
counter := ratelimit.NewCounter()
counter := counter.NewCounter()
start := time.Now()
for i := 0; i < 10; i++ {
wg.Add(1)
Expand All @@ -197,7 +199,7 @@ func main() {
```

### Dependency
[go-redis/redis](https://github.com/go-redis/redis)
[redis/go-redis](https://github.com/redis/go-redis)

### Thanks
The development of the module was inspired by the Reference 1.
Expand Down
14 changes: 8 additions & 6 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ go get github.com/vearne/ratelimit
```
## 用法
### 1. 创建 redis.Client
依赖 "github.com/go-redis/redis"
依赖 "github.com/redis/go-redis"
同时支持redis 主从模式和cluster模式
```
client := redis.NewClient(&redis.Options{
Expand Down Expand Up @@ -136,15 +136,17 @@ package main
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
"github.com/vearne/ratelimit"
"github.com/vearne/ratelimit/counter"
"github.com/vearne/ratelimit/tokenbucket"
slog "github.com/vearne/simplelog"
"sync"
"time"
)
func consume(r ratelimit.Limiter, group *sync.WaitGroup,
c *ratelimit.Counter, targetCount int) {
c *counter.Counter, targetCount int) {
defer group.Done()
var ok bool
for {
Expand Down Expand Up @@ -172,7 +174,7 @@ func main() {
DB: 0, // use default DB
})
limiter, err := ratelimit.NewTokenBucketRateLimiter(
limiter, err := tokenbucket.NewTokenBucketRateLimiter(
context.Background(),
client,
"key:token",
Expand All @@ -188,7 +190,7 @@ func main() {
var wg sync.WaitGroup
total := 50
counter := ratelimit.NewCounter()
counter := counter.NewCounter()
start := time.Now()
for i := 0; i < 10; i++ {
wg.Add(1)
Expand All @@ -200,7 +202,7 @@ func main() {
}
```
### 依赖
[go-redis/redis](https://github.com/go-redis/redis)
[redis/go-redis](https://github.com/redis/go-redis)

### 致谢
模块的开发受到了资料1的启发,在此表示感谢
Expand Down
28 changes: 12 additions & 16 deletions alg.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ return increment
key ->
token_count -> {token_count}
updateTime -> {lastUpdateTime}* 1000000 + {microsecond}
*/
*/

const TokenBucketScript = `
local bucket = KEYS[1]
Expand Down Expand Up @@ -85,14 +85,12 @@ end
return count
`


/*
key Type: string
// updateTime
key -> {lastUpdateTime}* 1000000 + {microsecond}
key Type: string
*/
// updateTime
key -> {lastUpdateTime}* 1000000 + {microsecond}
*/
const LeakyBucketScript = `
local bucket = KEYS[1]
local interval = tonumber(ARGV[1])
Expand All @@ -117,15 +115,13 @@ end
return count
`


var (
algMap map[int]string
AlgMap map[int]string
)


func init(){
algMap = make(map[int]string)
algMap[CounterAlg] = counterScript
algMap[TokenBucketAlg] = TokenBucketScript
algMap[LeakyBucketAlg] = LeakyBucketScript
}
func init() {
AlgMap = make(map[int]string)
AlgMap[CounterAlg] = counterScript
AlgMap[TokenBucketAlg] = TokenBucketScript
AlgMap[LeakyBucketAlg] = LeakyBucketScript
}
2 changes: 1 addition & 1 deletion counter.go → counter/counter.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ratelimit
package counter

import "sync"

Expand Down
49 changes: 33 additions & 16 deletions counter_limiter.go → counter/counter_limiter.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package ratelimit
package counter

import (
"context"
"crypto/sha1"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
"github.com/vearne/ratelimit"
slog "github.com/vearne/simplelog"
"golang.org/x/sync/singleflight"
"golang.org/x/time/rate"
Expand All @@ -14,7 +15,7 @@ import (

//nolint:govet
type CounterLimiter struct {
BaseRateLimiter
ratelimit.BaseRateLimiter
duration time.Duration
throughput int
batchSize int
Expand All @@ -29,9 +30,11 @@ type CounterLimiter struct {
antiDDoSLimiter *rate.Limiter
}

type Option func(*CounterLimiter)

func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string, duration time.Duration,
throughput int,
batchSize int) (Limiter, error) {
batchSize int, opts ...Option) (ratelimit.Limiter, error) {

_, err := client.Ping(ctx).Result()
if err != nil {
Expand All @@ -50,23 +53,35 @@ func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string
return nil, errors.New("batchSize must greater than 0")
}

script := algMap[CounterAlg]
script := ratelimit.AlgMap[ratelimit.CounterAlg]
scriptSHA1 := fmt.Sprintf("%x", sha1.Sum([]byte(script)))

r := CounterLimiter{
BaseRateLimiter: BaseRateLimiter{redisClient: client, scriptSHA1: scriptSHA1, key: key},
BaseRateLimiter: ratelimit.BaseRateLimiter{RedisClient: client, ScriptSHA1: scriptSHA1, Key: key},
duration: duration,
throughput: throughput,
batchSize: batchSize,
N: 0,
AntiDDoS: true,
}
r.interval = duration / time.Duration(throughput)
r.Interval = duration / time.Duration(throughput)

if !r.redisClient.ScriptExists(ctx, r.scriptSHA1).Val()[0] {
r.redisClient.ScriptLoad(ctx, script).Val()
// Loop through each option
for _, opt := range opts {
// Call the option giving the instantiated
opt(&r)
}

values, err := r.RedisClient.ScriptExists(ctx, r.ScriptSHA1).Result()
if err != nil {
return nil, err
}
if !values[0] {
_, err = r.RedisClient.ScriptLoad(ctx, script).Result()
if err != nil {
return nil, err
}
}
// 2x throughput
throughputPerSec := int(float64(throughput) / float64(duration/time.Second))
r.antiDDoSLimiter = rate.NewLimiter(rate.Limit(throughputPerSec*2), throughputPerSec*2)
Expand All @@ -75,8 +90,10 @@ func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string
}

// just for test
func (r *CounterLimiter) WithAntiDDos(antiDDoS bool) {
r.AntiDDoS = antiDDoS
func WithAntiDDos(antiDDoS bool) Option {
return func(r *CounterLimiter) {
r.AntiDDoS = antiDDoS
}
}

func (r *CounterLimiter) tryTakeFromLocal() bool {
Expand All @@ -101,7 +118,7 @@ func (r *CounterLimiter) Wait(ctx context.Context) (err error) {
}

deadline, ok := ctx.Deadline()
minWaitTime := r.interval
minWaitTime := r.Interval

slog.Debug("minWaitTime:%v", minWaitTime)
if ok {
Expand Down Expand Up @@ -143,11 +160,11 @@ func (r *CounterLimiter) Take(ctx context.Context) (bool, error) {
}

// 2. try to get from redis
_, err, _ := r.g.Do(r.key, func() (interface{}, error) {
x, err := r.redisClient.EvalSha(
_, err, _ := r.g.Do(r.Key, func() (interface{}, error) {
x, err := r.RedisClient.EvalSha(
ctx,
r.scriptSHA1,
[]string{r.key},
r.ScriptSHA1,
[]string{r.Key},
int(r.duration/time.Microsecond),
r.throughput,
r.batchSize,
Expand Down
107 changes: 107 additions & 0 deletions counter/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package counter

import (
"context"
"fmt"
"github.com/go-redis/redismock/v9"
"github.com/stretchr/testify/assert"
"log"
"testing"
"time"
)

const (
key = "key:count"
hashVal = "bdbede5669d5e48d6e6c2967aeed2f72f03868ac"
)

func MyMatch(expected, actual []interface{}) error {
expectedStr := fmt.Sprintf("%v", expected)
actualStr := fmt.Sprintf("%v", actual)
if expectedStr == actualStr {
return nil
}
log.Printf("expectedStr:%v, actualStr:%v", expectedStr, actualStr)
return fmt.Errorf("not equal, expectedStr:%s, actualStr:%s", expectedStr, actualStr)
}

func TestTakeFail(t *testing.T) {
db, mock := redismock.NewClientMock()

mock = mock.CustomMatch(MyMatch)
mock.ExpectPing().SetVal("PONG")

mock.ExpectScriptExists(hashVal).SetVal([]bool{true})
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(0))

limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
3,
2,
WithAntiDDos(false))
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}

ok, err := limiter.Take(context.Background())
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}
if !ok {
assert.Equal(t, ok, false)
}
}

func TestTakeSuccess(t *testing.T) {
db, mock := redismock.NewClientMock()

mock = mock.CustomMatch(MyMatch)
mock.ExpectPing().SetVal("PONG")

mock.ExpectScriptExists(hashVal).SetVal([]bool{true})
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(1))

limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
3,
2,
WithAntiDDos(false))
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}

ok, err := limiter.Take(context.Background())
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}
if !ok {
assert.Equal(t, ok, true)
}
}

func TestContextTimeOut(t *testing.T) {
db, mock := redismock.NewClientMock()
mock = mock.CustomMatch(MyMatch)
mock.ExpectPing().SetVal("PONG")

mock.ExpectScriptExists(hashVal).SetVal([]bool{true, true})
for i := 0; i < 1000; i++ {
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(0))
}

limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
3,
2,
WithAntiDDos(false))
if err != nil {
t.Errorf("unexpected error, %v", err)
return
}

waitCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err = limiter.Wait(waitCtx)
assert.Contains(t, err.Error(), "timeout")
}
Loading

0 comments on commit fb3be44

Please sign in to comment.