"""速率限制中间件模块。 提供基于令牌桶算法的 HTTP 请求速率限制功能。 采用内存中的滑动窗口机制,限制每个 IP 地址在指定时间窗口内的请求数量。 """ import time import asyncio from collections import defaultdict from fastapi import Request, HTTPException from config import settings class RateLimiter: """内存速率限制器类,使用滑动窗口算法限制请求频率。 为每个唯一键(通常是 IP 地址)维护一个时间戳列表, 在每次请求时清理过期时间戳并检查是否超过限制。 Attributes: MAX_KEYS: 最大缓存的键数量,防止内存无限增长。 _buckets: 滑动窗口桶,存储每个键的请求时间戳列表。 _lock: 异步锁,保证并发安全。 _last_cleanup: 上次清理缓存的时间戳。 """ MAX_KEYS = 10000 def __init__(self): """初始化速率限制器实例。""" self._buckets: dict[str, list[float]] = defaultdict(list) # 滑动窗口桶:{key: [timestamp, ...]} self._lock = asyncio.Lock() # 异步锁,保证并发安全 self._last_cleanup = time.time() # 上次清理缓存的时间戳 async def _cleanup(self): """清理过期和空闲的键,释放内存空间。 仅在距上次清理超过 60 秒时执行实际清理操作。 删除空桶或最后一个请求超过 120 秒的桶。 """ now = time.time() if now - self._last_cleanup < 60: return self._last_cleanup = now expired_keys = [k for k, v in self._buckets.items() if not v or now - v[-1] > 120] for k in expired_keys: del self._buckets[k] async def check(self, key: str) -> bool: """检查指定键是否允许新的请求。 使用滑动窗口算法,清理窗口外的时间戳后检查是否超过限制。 如果超过限制则拒绝请求,否则记录当前时间戳并允许通过。 Args: key: 速率限制键(通常为 IP 地址)。 Returns: bool: 允许请求返回 True,拒绝请求返回 False。 """ now = time.time() limit = settings.RATE_LIMIT_PER_MINUTE # 每分钟请求限制数 window = 60.0 # 时间窗口(秒) async with self._lock: await self._cleanup() bucket = self._buckets[key] bucket = [t for t in bucket if now - t < window] # 过滤窗口内的时间戳 self._buckets[key] = bucket if len(bucket) >= limit: return False # 超过限制,拒绝请求 bucket.append(now) # 记录当前请求时间戳 # 当缓存键数量超过上限时,淘汰最旧的键 if len(self._buckets) > self.MAX_KEYS: oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2] for k in oldest_keys: del self._buckets[k] return True async def remaining(self, key: str) -> int: """获取指定键剩余的请求次数。 Args: key: 速率限制键。 Returns: int: 当前时间窗口内剩余的请求次数。 """ now = time.time() async with self._lock: bucket = [t for t in self._buckets.get(key, []) if now - t < 60] return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket)) rate_limiter = RateLimiter() # 全局速率限制器单例实例 async def rate_limit_middleware(request: Request, call_next): """速率限制中间件。 对每个 HTTP 请求进行速率限制检查: 1. 跳过公开路径(健康检查、登录等) 2. 基于客户端 IP 地址进行速率限制 3. 在响应头中添加剩余请求次数信息 Args: request: 当前 HTTP 请求对象。 call_next: 下一个中间件或路由处理函数。 Returns: Response: 如果未超限则返回后续处理结果,否则返回 429 错误响应。 """ path = request.url.path if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]: return await call_next(request) client_ip = request.client.host if request.client else "unknown" # 客户端 IP 地址 key = f"ratelimit:{client_ip}" # 速率限制键 if not await rate_limiter.check(key): raise HTTPException(429, "请求过于频繁,请稍后再试") response = await call_next(request) remaining = await rate_limiter.remaining(key) response.headers["X-RateLimit-Remaining"] = str(remaining) # 响应头中添加剩余请求次数 return response