You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
2.2 KiB
72 lines
2.2 KiB
import time
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from fastapi import Request, HTTPException
|
|
from config import settings
|
|
|
|
|
|
class RateLimiter:
|
|
MAX_KEYS = 10000
|
|
|
|
def __init__(self):
|
|
self._buckets: dict[str, list[float]] = defaultdict(list)
|
|
self._lock = asyncio.Lock()
|
|
self._last_cleanup = time.time()
|
|
|
|
async def _cleanup(self):
|
|
now = time.time()
|
|
if now - self._last_cleanup < 60:
|
|
return
|
|
self._last_cleanup = now
|
|
expired_keys = [k for k, v in self._buckets.items() if not v or now - v[-1] > 120]
|
|
for k in expired_keys:
|
|
del self._buckets[k]
|
|
|
|
async def check(self, key: str) -> bool:
|
|
now = time.time()
|
|
limit = settings.RATE_LIMIT_PER_MINUTE
|
|
window = 60.0
|
|
|
|
async with self._lock:
|
|
await self._cleanup()
|
|
bucket = self._buckets[key]
|
|
bucket = [t for t in bucket if now - t < window]
|
|
self._buckets[key] = bucket
|
|
|
|
if len(bucket) >= limit:
|
|
return False
|
|
|
|
bucket.append(now)
|
|
|
|
if len(self._buckets) > self.MAX_KEYS:
|
|
oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2]
|
|
for k in oldest_keys:
|
|
del self._buckets[k]
|
|
|
|
return True
|
|
|
|
async def remaining(self, key: str) -> int:
|
|
now = time.time()
|
|
async with self._lock:
|
|
bucket = [t for t in self._buckets.get(key, []) if now - t < 60]
|
|
return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket))
|
|
|
|
|
|
# Module-level limiter instance; rate_limit_middleware counts every
# non-exempt request against it.
rate_limiter: RateLimiter = RateLimiter()
|
|
|
|
|
|
async def rate_limit_middleware(request: Request, call_next):
    """Apply per-client-IP rate limiting to incoming HTTP requests.

    Requests to a fixed set of exempt paths are passed straight through.
    Every other request is counted against ``rate_limiter`` under the key
    ``ratelimit:<client ip>``. A request that exceeds the limit receives a
    429 JSON response and is not forwarded. Forwarded responses get an
    ``X-RateLimit-Remaining`` header.

    Args:
        request: The incoming request.
        call_next: Forwards the request down the stack and returns its
            response (presumably Starlette's ``call_next`` from
            ``@app.middleware("http")`` — confirm at the registration site).

    Returns:
        The downstream response, or a 429 ``JSONResponse`` when rate limited.
    """
    # Local import keeps this block self-contained.
    from fastapi.responses import JSONResponse

    # NOTE(review): /api/auth/login is exempt, which leaves the credential
    # endpoint without brute-force protection from this limiter — confirm it
    # is throttled elsewhere.
    if request.url.path in ("/health", "/api/auth/login", "/docs", "/openapi.json"):
        return await call_next(request)

    # NOTE(review): behind a reverse proxy this is the proxy's address, so all
    # clients would share one bucket — confirm deployment / forwarded headers.
    client_ip = request.client.host if request.client else "unknown"
    key = f"ratelimit:{client_ip}"

    if not await rate_limiter.check(key):
        # Return the 429 directly instead of raising HTTPException. In
        # FastAPI/Starlette the exception handlers sit inside the middleware
        # stack, so an HTTPException raised here is not converted to a 429
        # and surfaces as a 500 instead. The body matches FastAPI's default
        # HTTPException shape ({"detail": ...}).
        return JSONResponse(
            status_code=429,
            content={"detail": "请求过于频繁,请稍后再试"},
            headers={"X-RateLimit-Remaining": "0"},
        )

    response = await call_next(request)
    remaining = await rate_limiter.remaining(key)
    response.headers["X-RateLimit-Remaining"] = str(remaining)
    return response
|