You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
3.8 KiB
95 lines
3.8 KiB
import jwt
|
|
from fastapi import Depends, HTTPException, Request
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from sqlalchemy import select
|
|
from database import AsyncSessionLocal
|
|
from models import User, UserRole, Role, RolePermission, Permission
|
|
from config import settings
|
|
|
|
# HTTP Bearer token authentication scheme. auto_error=False lets a request
# without an Authorization header reach the dependency with credentials=None
# instead of being rejected by the scheme itself, so get_current_user can
# decide how to respond.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> dict:
    """Return the current user together with role and permission data.

    ``request.state.user`` is returned as-is when it is set and truthy (it is
    expected to be pre-populated by the RBAC middleware). Otherwise the bearer
    token is decoded as a JWT, the user named by its ``sub`` claim is loaded
    from the database, and the user's roles and permission codes are queried.

    Args:
        request: The incoming request; only ``request.state.user`` is read.
        credentials: Bearer credentials extracted by ``security``, or ``None``
            when no Authorization header was supplied.

    Returns:
        A dict with the keys ``id``, ``username``, ``display_name``,
        ``department_id``, ``role``, ``permissions`` and ``data_scope``.

    Raises:
        HTTPException: 401 when no credentials are supplied, the token cannot
            be decoded, the token has no ``sub`` claim, or no user matches it.
    """
    cached_user = getattr(request.state, "user", None)
    if cached_user:
        return cached_user

    if not credentials:
        raise HTTPException(401, "未提供认证令牌")

    # Only the decode step can raise PyJWTError, so the try body is limited to
    # it; database errors below propagate unchanged.
    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except jwt.PyJWTError as exc:
        raise HTTPException(401, "令牌无效或已过期") from exc

    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(401, "令牌无效")

    async with AsyncSessionLocal() as db:
        user_result = await db.execute(select(User).where(User.id == user_id))
        user = user_result.scalar_one_or_none()
        if not user:
            raise HTTPException(401, "用户不存在")

        # All roles attached to the user.
        role_result = await db.execute(
            select(Role).join(UserRole).where(UserRole.user_id == user.id)
        )
        roles = role_result.scalars().all()

        # Fetch the permission codes of every role in one query instead of
        # issuing one query per role.
        role_ids = [role.id for role in roles]
        permissions = set()
        if role_ids:
            perm_result = await db.execute(
                select(Permission.code)
                .join(RolePermission)
                .where(RolePermission.role_id.in_(role_ids))
            )
            permissions.update(perm_result.scalars().all())

        # The widest scope granted by any role wins:
        # "all" > "subordinate_only" > "self_only" (the fallback).
        data_scopes = [role.data_scope for role in roles]
        if "all" in data_scopes:
            data_scope = "all"
        elif "subordinate_only" in data_scopes:
            data_scope = "subordinate_only"
        else:
            data_scope = "self_only"

        return {
            "id": str(user.id),
            "username": user.username,
            "display_name": user.display_name,
            "department_id": str(user.department_id) if user.department_id else None,
            # NOTE(review): the role query has no ORDER BY, so which role is
            # "first" for a multi-role user depends on database row order.
            "role": roles[0].code if roles else "employee",
            "permissions": list(permissions),
            "data_scope": data_scope,
        }
|
|
|
|
|
|
def require_permission(perm_code: str):
    """Build a FastAPI dependency that enforces a single permission code.

    The returned coroutine resolves the current user through
    ``get_current_user`` and returns that user dict when its ``permissions``
    list contains either ``perm_code`` or the wildcard ``"*:*"``; otherwise it
    raises ``HTTPException`` with status 403.
    """

    async def checker(user: dict = Depends(get_current_user)) -> dict:
        granted = user.get("permissions", [])
        if perm_code in granted or "*:*" in granted:
            return user
        raise HTTPException(403, f"缺少权限: {perm_code}")

    return checker
|
|
|
|
|
|
async def get_db():
    """Yield an async database session for use as a FastAPI dependency.

    The session is committed once the consumer hands control back without an
    error. If the consumer or the commit itself raises an ``Exception``, the
    session is rolled back and the exception is re-raised.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            # Commit stays inside the try so a failed commit is rolled back too.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|