You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.7 KiB
128 lines
4.7 KiB
"""速率限制中间件模块。
|
|
|
|
提供基于令牌桶算法的 HTTP 请求速率限制功能。
|
|
采用内存中的滑动窗口机制,限制每个 IP 地址在指定时间窗口内的请求数量。
|
|
"""
|
|
import time
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from fastapi import Request, HTTPException
|
|
from config import settings
|
|
|
|
|
|
class RateLimiter:
|
|
"""内存速率限制器类,使用滑动窗口算法限制请求频率。
|
|
|
|
为每个唯一键(通常是 IP 地址)维护一个时间戳列表,
|
|
在每次请求时清理过期时间戳并检查是否超过限制。
|
|
|
|
Attributes:
|
|
MAX_KEYS: 最大缓存的键数量,防止内存无限增长。
|
|
_buckets: 滑动窗口桶,存储每个键的请求时间戳列表。
|
|
_lock: 异步锁,保证并发安全。
|
|
_last_cleanup: 上次清理缓存的时间戳。
|
|
"""
|
|
MAX_KEYS = 10000
|
|
|
|
def __init__(self):
|
|
"""初始化速率限制器实例。"""
|
|
self._buckets: dict[str, list[float]] = defaultdict(list) # 滑动窗口桶:{key: [timestamp, ...]}
|
|
self._lock = asyncio.Lock() # 异步锁,保证并发安全
|
|
self._last_cleanup = time.time() # 上次清理缓存的时间戳
|
|
|
|
async def _cleanup(self):
|
|
"""清理过期和空闲的键,释放内存空间。
|
|
|
|
仅在距上次清理超过 60 秒时执行实际清理操作。
|
|
删除空桶或最后一个请求超过 120 秒的桶。
|
|
"""
|
|
now = time.time()
|
|
if now - self._last_cleanup < 60:
|
|
return
|
|
self._last_cleanup = now
|
|
expired_keys = [k for k, v in self._buckets.items() if not v or now - v[-1] > 120]
|
|
for k in expired_keys:
|
|
del self._buckets[k]
|
|
|
|
async def check(self, key: str) -> bool:
|
|
"""检查指定键是否允许新的请求。
|
|
|
|
使用滑动窗口算法,清理窗口外的时间戳后检查是否超过限制。
|
|
如果超过限制则拒绝请求,否则记录当前时间戳并允许通过。
|
|
|
|
Args:
|
|
key: 速率限制键(通常为 IP 地址)。
|
|
|
|
Returns:
|
|
bool: 允许请求返回 True,拒绝请求返回 False。
|
|
"""
|
|
now = time.time()
|
|
limit = settings.RATE_LIMIT_PER_MINUTE # 每分钟请求限制数
|
|
window = 60.0 # 时间窗口(秒)
|
|
|
|
async with self._lock:
|
|
await self._cleanup()
|
|
bucket = self._buckets[key]
|
|
bucket = [t for t in bucket if now - t < window] # 过滤窗口内的时间戳
|
|
self._buckets[key] = bucket
|
|
|
|
if len(bucket) >= limit:
|
|
return False # 超过限制,拒绝请求
|
|
|
|
bucket.append(now) # 记录当前请求时间戳
|
|
|
|
# 当缓存键数量超过上限时,淘汰最旧的键
|
|
if len(self._buckets) > self.MAX_KEYS:
|
|
oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2]
|
|
for k in oldest_keys:
|
|
del self._buckets[k]
|
|
|
|
return True
|
|
|
|
async def remaining(self, key: str) -> int:
|
|
"""获取指定键剩余的请求次数。
|
|
|
|
Args:
|
|
key: 速率限制键。
|
|
|
|
Returns:
|
|
int: 当前时间窗口内剩余的请求次数。
|
|
"""
|
|
now = time.time()
|
|
async with self._lock:
|
|
bucket = [t for t in self._buckets.get(key, []) if now - t < 60]
|
|
return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket))
|
|
|
|
|
|
rate_limiter = RateLimiter() # 全局速率限制器单例实例
|
|
|
|
|
|
async def rate_limit_middleware(request: Request, call_next):
|
|
"""速率限制中间件。
|
|
|
|
对每个 HTTP 请求进行速率限制检查:
|
|
1. 跳过公开路径(健康检查、登录等)
|
|
2. 基于客户端 IP 地址进行速率限制
|
|
3. 在响应头中添加剩余请求次数信息
|
|
|
|
Args:
|
|
request: 当前 HTTP 请求对象。
|
|
call_next: 下一个中间件或路由处理函数。
|
|
|
|
Returns:
|
|
Response: 如果未超限则返回后续处理结果,否则返回 429 错误响应。
|
|
"""
|
|
path = request.url.path
|
|
if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]:
|
|
return await call_next(request)
|
|
|
|
client_ip = request.client.host if request.client else "unknown" # 客户端 IP 地址
|
|
key = f"ratelimit:{client_ip}" # 速率限制键
|
|
|
|
if not await rate_limiter.check(key):
|
|
raise HTTPException(429, "请求过于频繁,请稍后再试")
|
|
|
|
response = await call_next(request)
|
|
remaining = await rate_limiter.remaining(key)
|
|
response.headers["X-RateLimit-Remaining"] = str(remaining) # 响应头中添加剩余请求次数
|
|
return response
|
|
|