You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

128 lines
4.7 KiB

"""速率限制中间件模块。
提供基于令牌桶算法的 HTTP 请求速率限制功能。
采用内存中的滑动窗口机制,限制每个 IP 地址在指定时间窗口内的请求数量。
"""
import time
import asyncio
from collections import defaultdict
from fastapi import Request, HTTPException
from config import settings
class RateLimiter:
"""内存速率限制器类,使用滑动窗口算法限制请求频率。
为每个唯一键(通常是 IP 地址)维护一个时间戳列表,
在每次请求时清理过期时间戳并检查是否超过限制。
Attributes:
MAX_KEYS: 最大缓存的键数量,防止内存无限增长。
_buckets: 滑动窗口桶,存储每个键的请求时间戳列表。
_lock: 异步锁,保证并发安全。
_last_cleanup: 上次清理缓存的时间戳。
"""
MAX_KEYS = 10000
def __init__(self):
"""初始化速率限制器实例。"""
self._buckets: dict[str, list[float]] = defaultdict(list) # 滑动窗口桶:{key: [timestamp, ...]}
self._lock = asyncio.Lock() # 异步锁,保证并发安全
self._last_cleanup = time.time() # 上次清理缓存的时间戳
async def _cleanup(self):
"""清理过期和空闲的键,释放内存空间。
仅在距上次清理超过 60 秒时执行实际清理操作。
删除空桶或最后一个请求超过 120 秒的桶。
"""
now = time.time()
if now - self._last_cleanup < 60:
return
self._last_cleanup = now
expired_keys = [k for k, v in self._buckets.items() if not v or now - v[-1] > 120]
for k in expired_keys:
del self._buckets[k]
async def check(self, key: str) -> bool:
"""检查指定键是否允许新的请求。
使用滑动窗口算法,清理窗口外的时间戳后检查是否超过限制。
如果超过限制则拒绝请求,否则记录当前时间戳并允许通过。
Args:
key: 速率限制键(通常为 IP 地址)。
Returns:
bool: 允许请求返回 True,拒绝请求返回 False。
"""
now = time.time()
limit = settings.RATE_LIMIT_PER_MINUTE # 每分钟请求限制数
window = 60.0 # 时间窗口(秒)
async with self._lock:
await self._cleanup()
bucket = self._buckets[key]
bucket = [t for t in bucket if now - t < window] # 过滤窗口内的时间戳
self._buckets[key] = bucket
if len(bucket) >= limit:
return False # 超过限制,拒绝请求
bucket.append(now) # 记录当前请求时间戳
# 当缓存键数量超过上限时,淘汰最旧的键
if len(self._buckets) > self.MAX_KEYS:
oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2]
for k in oldest_keys:
del self._buckets[k]
return True
async def remaining(self, key: str) -> int:
"""获取指定键剩余的请求次数。
Args:
key: 速率限制键。
Returns:
int: 当前时间窗口内剩余的请求次数。
"""
now = time.time()
async with self._lock:
bucket = [t for t in self._buckets.get(key, []) if now - t < 60]
return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket))
rate_limiter = RateLimiter() # 全局速率限制器单例实例
async def rate_limit_middleware(request: Request, call_next):
"""速率限制中间件。
对每个 HTTP 请求进行速率限制检查:
1. 跳过公开路径(健康检查、登录等)
2. 基于客户端 IP 地址进行速率限制
3. 在响应头中添加剩余请求次数信息
Args:
request: 当前 HTTP 请求对象。
call_next: 下一个中间件或路由处理函数。
Returns:
Response: 如果未超限则返回后续处理结果,否则返回 429 错误响应。
"""
path = request.url.path
if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]:
return await call_next(request)
client_ip = request.client.host if request.client else "unknown" # 客户端 IP 地址
key = f"ratelimit:{client_ip}" # 速率限制键
if not await rate_limiter.check(key):
raise HTTPException(429, "请求过于频繁,请稍后再试")
response = await call_next(request)
remaining = await rate_limiter.remaining(key)
response.headers["X-RateLimit-Remaining"] = str(remaining) # 响应头中添加剩余请求次数
return response