import time import asyncio from collections import defaultdict from fastapi import Request, HTTPException from config import settings class RateLimiter: def __init__(self): self._buckets: dict[str, list[float]] = defaultdict(list) self._lock = asyncio.Lock() async def check(self, key: str) -> bool: now = time.time() limit = settings.RATE_LIMIT_PER_MINUTE window = 60.0 async with self._lock: bucket = self._buckets[key] bucket = [t for t in bucket if now - t < window] self._buckets[key] = bucket if len(bucket) >= limit: return False bucket.append(now) return True async def remaining(self, key: str) -> int: now = time.time() async with self._lock: bucket = [t for t in self._buckets.get(key, []) if now - t < 60] return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket)) rate_limiter = RateLimiter() async def rate_limit_middleware(request: Request, call_next): path = request.url.path if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]: return await call_next(request) client_ip = request.client.host if request.client else "unknown" key = f"ratelimit:{client_ip}" if not await rate_limiter.check(key): raise HTTPException(429, "请求过于频繁,请稍后再试") response = await call_next(request) remaining = await rate_limiter.remaining(key) response.headers["X-RateLimit-Remaining"] = str(remaining) return response