diff --git a/src/ai/backend/manager/api/ratelimit.py b/src/ai/backend/manager/api/ratelimit.py index f6bebc471e0..a18285f5ac0 100644 --- a/src/ai/backend/manager/api/ratelimit.py +++ b/src/ai/backend/manager/api/ratelimit.py @@ -50,15 +50,15 @@ if id_type == "ip" then local rate_limit = tonumber(ARGV[3]) - local score_threshold = rate_limit * 0.8 + local suspicious_ips_maxsize = tonumber(ARGV[4]) + local suspicious_ips_threshold_ratio = tonumber(ARGV[5]) - -- Add IP to suspicious_ips only if count is greater than score_threshold - if rolling_count >= score_threshold then + -- Add the IP address to "suspicious_ips" only if rolling_count is greater than the threshold + if rolling_count >= rate_limit * suspicious_ips_threshold_ratio then redis.call('ZADD', 'suspicious_ips', rolling_count, id_value) - local max_size = 1000 local current_size = redis.call('ZCARD', 'suspicious_ips') - if current_size > max_size then + if current_size > suspicious_ips_maxsize then redis.call('ZREMRANGEBYRANK', 'suspicious_ips', 0, 0) end end @@ -103,23 +103,34 @@ async def rlim_middleware( return response else: root_ctx: RootContext = app["_root.context"] - rate_limit = root_ctx.shared_config["anonymous_ratelimit"] + anonymous_ratelimiter = root_ctx.shared_config["anonymous_ratelimiter"] ip_address = get_client_ip(request) - if not ip_address or rate_limit is None: + if not ip_address or anonymous_ratelimiter is None: # No checks for rate limiting. response = await handler(request) # Arbitrary number for indicating no rate limiting. response.headers["X-RateLimit-Limit"] = "1000" response.headers["X-RateLimit-Remaining"] = "1000" else: + rate_limit, suspicious_ips_maxsize, suspicious_ips_threshold_ratio = ( + anonymous_ratelimiter["rlimit"], + anonymous_ratelimiter["suspicious_ips_maxsize"], + anonymous_ratelimiter["suspicious_ips_threshold_ratio"], + ) ret = await redis_helper.execute_script( rr, "ratelimit", _rlim_script, ["ip", ip_address], - [str(now), str(_rlim_window), str(rate_limit)], + [ + str(now), + str(_rlim_window), + str(rate_limit), + str(suspicious_ips_maxsize), + str(suspicious_ips_threshold_ratio), + ], ) if ret is None: remaining = rate_limit diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py index 2b47a8463e2..83f41738889 100644 --- a/src/ai/backend/manager/config.py +++ b/src/ai/backend/manager/config.py @@ -456,7 +456,12 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]: }, ).allow_extra("*"), t.Key("roundrobin_states", default=None): t.Null | tx.RoundRobinStatesJSONString, - t.Key("anonymous_ratelimit", default=None): t.Null | t.ToInt, + t.Key("anonymous_ratelimiter", default=None): t.Null + | t.Dict({ + t.Key("rlimit"): t.ToInt(), + t.Key("suspicious_ips_maxsize", default=1000): t.Null | t.ToInt(), + t.Key("suspicious_ips_threshold_ratio", default=0.8): t.Null | t.ToFloat(), + }), }).allow_extra("*") _volume_defaults: dict[str, Any] = {