import time
import math
import asyncio
from collections import defaultdict

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse

from config import settings


class RateLimiter:
    """In-memory sliding-window rate limiter keyed by an arbitrary string.

    Each key maps to the timestamps (from ``time.monotonic()``) of the requests
    accepted for it inside the current window. State is held in this object
    only, so separate worker processes each keep their own counts.
    """

    # Number of tracked keys above which bulk eviction is triggered.
    MAX_KEYS = 10000
    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60.0
    # Minimum time between idle-key sweeps, in seconds.
    CLEANUP_INTERVAL = 60.0

    def __init__(self):
        self._buckets: dict[str, list[float]] = defaultdict(list)
        self._lock = asyncio.Lock()
        # Monotonic clock: wall-clock jumps (NTP, manual changes) must not
        # shrink or stretch the window.
        self._last_cleanup = time.monotonic()

    async def _cleanup(self):
        """Drop keys with no accepted request in the last two windows.

        Does nothing if the previous sweep ran less than CLEANUP_INTERVAL ago.
        Callers hold ``self._lock``.
        """
        now = time.monotonic()
        if now - self._last_cleanup < self.CLEANUP_INTERVAL:
            return
        self._last_cleanup = now
        idle_after = 2 * self.WINDOW_SECONDS
        expired_keys = [
            k for k, v in self._buckets.items() if not v or now - v[-1] > idle_after
        ]
        for k in expired_keys:
            del self._buckets[k]

    def _prune(self, key: str, now: float) -> list[float]:
        """Keep only *key*'s timestamps still inside the window; store and return them."""
        bucket = [t for t in self._buckets[key] if now - t < self.WINDOW_SECONDS]
        self._buckets[key] = bucket
        return bucket

    def _evict_least_recent(self):
        """Shrink the table to about MAX_KEYS // 2 keys, dropping the least recently active."""
        excess = len(self._buckets) - self.MAX_KEYS // 2
        # Sort by *last* activity so busy clients (and the key just recorded,
        # whose last timestamp is the newest) are kept; empty buckets go first.
        by_last_seen = sorted(
            self._buckets,
            key=lambda k: self._buckets[k][-1] if self._buckets[k] else float("-inf"),
        )
        for k in by_last_seen[:excess]:
            del self._buckets[k]

    async def check(self, key: str) -> bool:
        """Try to record one request for *key*.

        Returns True and records the request when fewer than
        ``settings.RATE_LIMIT_PER_MINUTE`` requests were accepted for *key* in
        the last WINDOW_SECONDS. Otherwise returns False without recording.
        """
        limit = settings.RATE_LIMIT_PER_MINUTE
        async with self._lock:
            # Read the clock inside the lock so timestamps are appended in
            # non-decreasing order and each bucket stays sorted.
            now = time.monotonic()
            await self._cleanup()
            bucket = self._prune(key, now)
            if len(bucket) >= limit:
                return False
            bucket.append(now)
            if len(self._buckets) > self.MAX_KEYS:
                self._evict_least_recent()
            return True

    async def remaining(self, key: str) -> int:
        """Return how many more requests *key* may make in the current window (>= 0)."""
        async with self._lock:
            now = time.monotonic()
            recent = sum(
                1 for t in self._buckets.get(key, ()) if now - t < self.WINDOW_SECONDS
            )
        return max(0, settings.RATE_LIMIT_PER_MINUTE - recent)

    async def retry_after(self, key: str) -> int:
        """Return whole seconds (>= 1) until *key*'s oldest in-window request expires.

        If *key* has no in-window requests, returns the full window length.
        """
        async with self._lock:
            now = time.monotonic()
            recent = [
                t for t in self._buckets.get(key, ()) if now - t < self.WINDOW_SECONDS
            ]
        wait = (min(recent) + self.WINDOW_SECONDS - now) if recent else self.WINDOW_SECONDS
        return max(1, math.ceil(wait))


rate_limiter = RateLimiter()

# Paths that bypass rate limiting entirely.
# NOTE(review): /api/auth/login is exempt, which leaves the credential-checking
# endpoint open to brute force unless it is throttled elsewhere — confirm.
_EXEMPT_PATHS = frozenset({"/health", "/api/auth/login", "/docs", "/openapi.json"})


async def rate_limit_middleware(request: Request, call_next):
    """HTTP middleware that enforces the per-client-IP request limit.

    Exempt paths are passed straight through. Over-limit requests get a 429
    JSON response (``{"detail": ...}``) with a ``Retry-After`` header. Accepted
    requests get an ``X-RateLimit-Remaining`` header on their response.
    """
    if request.url.path in _EXEMPT_PATHS:
        return await call_next(request)

    # NOTE(review): behind a reverse proxy this is the proxy's address, so all
    # clients would share one bucket — confirm deployment / forwarded-header setup.
    client_ip = request.client.host if request.client else "unknown"
    key = f"ratelimit:{client_ip}"

    if not await rate_limiter.check(key):
        # Build the 429 response here: an HTTPException raised from HTTP
        # middleware bypasses FastAPI's exception handlers and becomes a 500.
        return JSONResponse(
            status_code=429,
            content={"detail": "请求过于频繁,请稍后再试"},
            headers={"Retry-After": str(await rate_limiter.retry_after(key))},
        )

    response = await call_next(request)
    remaining = await rate_limiter.remaining(key)
    response.headers["X-RateLimit-Remaining"] = str(remaining)
    return response