You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

102 lines
4.1 KiB

import jwt
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from database import AsyncSessionLocal
from models import User, UserRole, Role, RolePermission, Permission
from config import settings
# HTTP Bearer token scheme for the Authorization header. auto_error=False lets
# the dependency resolve to None when no credentials are sent (hence the
# `| None` on get_current_user's `credentials`), so the 401 is raised by our code.
security = HTTPBearer(auto_error=False)
def _resolve_data_scope(scopes: list) -> str:
    """Collapse the data scopes of several roles into the broadest one.

    Precedence is "all" > "subordinate_only" > "self_only"; any other value
    (including None, or an empty list when the user has no roles) falls back
    to "self_only".
    """
    if "all" in scopes:
        return "all"
    if "subordinate_only" in scopes:
        return "subordinate_only"
    return "self_only"


def _decode_user_id(token: str) -> str:
    """Verify a bearer JWT and return its ``sub`` claim (the user id).

    Raises:
        HTTPException: 401 if PyJWT rejects the token (any ``PyJWTError``),
            or if the decoded payload has no truthy ``sub`` claim.
    """
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except jwt.PyJWTError as exc:
        # Chain the cause so the underlying JWT failure stays visible in logs.
        raise HTTPException(401, "令牌无效或已过期") from exc
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(401, "令牌无效")
    return user_id


async def _load_user_context(user_id: str) -> dict:
    """Load a user with their roles and permissions and build the auth dict.

    Raises:
        HTTPException: 401 if no user with ``user_id`` exists.
    """
    async with AsyncSessionLocal() as db:
        user_result = await db.execute(select(User).where(User.id == user_id))
        user = user_result.scalar_one_or_none()
        if not user:
            raise HTTPException(401, "用户不存在")

        # Every role linked to the user through the UserRole association.
        role_result = await db.execute(
            select(Role).join(UserRole).where(UserRole.user_id == user.id)
        )
        roles = role_result.scalars().all()

        # Split by type: platform roles (admin console) vs. positions (enterprise AI).
        platform_roles = {role.code for role in roles if role.role_type == "platform"}
        position_roles = {role.code for role in roles if role.role_type != "platform"}
        # NOTE(review): data scopes are merged across platform and position roles
        # alike — confirm this is intended.
        data_scopes = [role.data_scope for role in roles]

        # A single query for the permissions of all roles (previously one query
        # per role). Skipped entirely when the user has no roles.
        permissions: set = set()
        if roles:
            perm_result = await db.execute(
                select(Permission.code)
                .join(RolePermission)
                .where(RolePermission.role_id.in_([role.id for role in roles]))
            )
            permissions = set(perm_result.scalars().all())

        return {
            "id": str(user.id),
            "username": user.username,
            "display_name": user.display_name,
            "department_id": str(user.department_id) if user.department_id else None,
            "platform_roles": list(platform_roles),
            "positions": list(position_roles),
            "permissions": list(permissions),
            "data_scope": _resolve_data_scope(data_scopes),
        }


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> dict:
    """Return the current user together with their roles and permissions.

    If ``request.state.user`` is already set and truthy (e.g. pre-filled by
    RBAC middleware), it is returned unchanged. Otherwise the bearer token is
    decoded and the user, roles and permissions are loaded from the database.

    Returns:
        A dict with keys ``id``, ``username``, ``display_name``,
        ``department_id``, ``platform_roles``, ``positions``, ``permissions``
        and ``data_scope`` (one of "all", "subordinate_only", "self_only").

    Raises:
        HTTPException: 401 when no token is supplied, the token is invalid or
            expired, it lacks a ``sub`` claim, or the user does not exist.
    """
    cached_user = getattr(request.state, "user", None)
    if cached_user:
        return cached_user
    if not credentials:
        raise HTTPException(401, "未提供认证令牌")
    user_id = _decode_user_id(credentials.credentials)
    return await _load_user_context(user_id)
def require_permission(perm_code: str):
    """Build a FastAPI dependency that requires ``perm_code`` on the current user.

    The returned dependency resolves the user via ``get_current_user`` and
    returns the user dict if its ``permissions`` contain ``perm_code`` or the
    wildcard ``"*:*"``.

    Raises (from the returned dependency):
        HTTPException: 403 if neither permission is granted.
    """

    async def checker(user: dict = Depends(get_current_user)) -> dict:
        # Look the list up once; `or []` also covers a user dict (e.g. one
        # pre-filled by middleware) whose "permissions" is present but None,
        # which would otherwise raise TypeError instead of returning 403.
        granted = user.get("permissions") or []
        # "*:*" is treated as a wildcard granting every permission.
        if perm_code not in granted and "*:*" not in granted:
            raise HTTPException(403, f"缺少权限: {perm_code}")
        return user

    return checker
async def get_db():
    """Yield a per-request async DB session for FastAPI dependency injection.

    After the consumer is done with the session, it is committed. If anything
    raises while the session is in use or during the commit itself, the
    session is rolled back and the exception is re-raised. The session is
    closed when the ``async with`` block exits.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            # Commit stays inside the try so a failing commit is rolled back too.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise