"""Authentication and authorization dependencies for FastAPI routes.

Provides:
* ``get_current_user`` - resolves the caller from ``request.state`` or a
  bearer JWT and returns a plain ``dict`` describing the user.
* ``require_permission`` - dependency factory enforcing a permission code.
* ``get_db`` - request-scoped database session with commit/rollback.
"""

import jwt
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select

from database import AsyncSessionLocal
from models import User, UserRole, Role, RolePermission, Permission
from config import settings

# auto_error=False lets a request without an Authorization header reach
# get_current_user (credentials is None) so request.state.user can be
# honoured before a 401 is raised.
security = HTTPBearer(auto_error=False)


def _resolve_data_scope(data_scopes) -> str:
    """Return the broadest data scope found in *data_scopes*.

    Precedence is ``"all"``, then ``"subordinate_only"``; anything else
    (including an empty collection) yields ``"self_only"``.
    """
    for scope in ("all", "subordinate_only"):
        if scope in data_scopes:
            return scope
    return "self_only"


async def _load_user_context(user_id) -> dict:
    """Load the user, their roles and permission codes, and build the user dict.

    Raises:
        HTTPException: 401 when no ``User`` row matches *user_id*.
    """
    async with AsyncSessionLocal() as db:
        user_result = await db.execute(select(User).where(User.id == user_id))
        user = user_result.scalar_one_or_none()
        if not user:
            raise HTTPException(401, "用户不存在")

        role_result = await db.execute(
            select(Role).join(UserRole).where(UserRole.user_id == user.id)
        )
        roles = role_result.scalars().all()

        # Fetch the permission codes of every role in one query instead of
        # issuing one query per role. Skipped entirely when the user has no
        # roles so no empty IN clause is sent.
        permissions: list = []
        role_ids = [role.id for role in roles]
        if role_ids:
            perm_result = await db.execute(
                select(Permission.code)
                .join(RolePermission)
                .where(RolePermission.role_id.in_(role_ids))
            )
            permissions = list(perm_result.scalars().all())

        # The dict is built while the session is still open, as before.
        return {
            "id": str(user.id),
            "username": user.username,
            "display_name": user.display_name,
            "department_id": (
                str(user.department_id) if user.department_id else None
            ),
            # NOTE(review): the role query has no ORDER BY, so which role is
            # "first" depends on the row order the database returns.
            "role": roles[0].code if roles else "employee",
            # A permission shared by several roles is reported once.
            "permissions": list(set(permissions)),
            "data_scope": _resolve_data_scope(
                [role.data_scope for role in roles]
            ),
        }


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> dict:
    """Resolve the authenticated user for the current request.

    A truthy ``request.state.user`` is returned as-is. Otherwise the bearer
    token is decoded with ``settings.JWT_SECRET`` / ``settings.JWT_ALGORITHM``
    and the user identified by its ``sub`` claim is loaded from the database.

    Returns:
        A dict with the keys ``id``, ``username``, ``display_name``,
        ``department_id``, ``role``, ``permissions`` and ``data_scope``
        (or whatever object was stored in ``request.state.user``).

    Raises:
        HTTPException: 401 when no token is supplied, the token cannot be
            decoded, it carries no ``sub`` claim, or the user does not exist.
    """
    state_user = getattr(request.state, "user", None)
    if state_user:
        return state_user

    if not credentials:
        raise HTTPException(401, "未提供认证令牌")

    # Only the decode step can raise PyJWTError, so only it sits in the try.
    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except jwt.PyJWTError as exc:
        raise HTTPException(401, "令牌无效或已过期") from exc

    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(401, "令牌无效")

    return await _load_user_context(user_id)


def require_permission(perm_code: str):
    """Build a dependency that only admits users holding *perm_code*.

    The wildcard permission ``"*:*"`` satisfies every check.

    Returns:
        An async dependency that returns the user dict on success and raises
        ``HTTPException`` 403 otherwise.
    """

    async def checker(user: dict = Depends(get_current_user)) -> dict:
        granted = user.get("permissions", [])
        if perm_code not in granted and "*:*" not in granted:
            raise HTTPException(403, f"缺少权限: {perm_code}")
        return user

    return checker


async def get_db():
    """Yield a database session scoped to one request.

    The session is committed after the consumer finishes without error; on
    any ``Exception`` it is rolled back and the exception is re-raised.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise