You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

85 lines
3.1 KiB

import jwt
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from database import AsyncSessionLocal
from models import User, UserRole, Role, RolePermission, Permission
from config import settings
# Bearer-token extractor used as a dependency by get_current_user.
# auto_error=False means a missing/non-Bearer Authorization header yields None
# instead of an automatic error, letting get_current_user fall back to
# request.state.user or raise its own 401.
security = HTTPBearer(auto_error=False)
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the Bearer challenge that RFC 7235 requires."""
    return HTTPException(401, detail, headers={"WWW-Authenticate": "Bearer"})


def _decode_user_id(token: str):
    """Verify *token* with the configured secret/algorithm and return its ``sub`` claim.

    Raises:
        HTTPException: 401 if verification fails (bad signature, expired, malformed)
            or the ``sub`` claim is missing/empty.
    """
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except jwt.PyJWTError as exc:
        raise _unauthorized("令牌无效或已过期") from exc
    user_id = payload.get("sub")
    if not user_id:
        raise _unauthorized("令牌无效")
    return user_id


async def _load_user_context(user_id) -> dict:
    """Load the user, their roles and permission codes, and build the auth context dict.

    Raises:
        HTTPException: 401 if no ``User`` row matches *user_id*.
    """
    async with AsyncSessionLocal() as db:
        user_result = await db.execute(select(User).where(User.id == user_id))
        user = user_result.scalar_one_or_none()
        if not user:
            raise _unauthorized("用户不存在")

        role_result = await db.execute(
            select(Role).join(UserRole).where(UserRole.user_id == user.id)
        )
        roles = role_result.scalars().all()

        permissions: set = set()
        if roles:
            # One IN query for all roles instead of one query per role (N+1).
            perm_result = await db.execute(
                select(Permission.code)
                .join(RolePermission)
                .where(RolePermission.role_id.in_([role.id for role in roles]))
            )
            permissions.update(perm_result.scalars().all())

        # Kept as a list (not a set): membership then relies on ==, which is
        # what the original relied on, even if data_scope is an enum type.
        data_scopes = [role.data_scope for role in roles]
        # Widest scope wins: all > subordinate_only > self_only.
        if "all" in data_scopes:
            data_scope = "all"
        elif "subordinate_only" in data_scopes:
            data_scope = "subordinate_only"
        else:
            data_scope = "self_only"

        return {
            "id": str(user.id),
            "username": user.username,
            "display_name": user.display_name,
            "department_id": str(user.department_id) if user.department_id else None,
            # NOTE(review): the role query has no ORDER BY, so for users with
            # several roles the "primary" role picked here is DB-order dependent.
            "role": roles[0].code if roles else "employee",
            "permissions": list(permissions),
            "data_scope": data_scope,
        }


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> dict:
    """Resolve the authenticated user for the current request.

    If ``request.state.user`` is already set and truthy, it is returned as-is
    (presumably populated by middleware — confirm). Otherwise the Bearer token
    is verified, and the user plus the union of their roles' permissions and
    widest data scope are loaded from the database.

    Returns:
        A dict with keys ``id``, ``username``, ``display_name``,
        ``department_id``, ``role``, ``permissions`` and ``data_scope``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when no token is
            supplied, the token is invalid/expired, lacks ``sub``, or the user
            does not exist.
    """
    cached_user = getattr(request.state, "user", None)
    if cached_user:
        return cached_user
    if not credentials:
        raise _unauthorized("未提供认证令牌")
    user_id = _decode_user_id(credentials.credentials)
    return await _load_user_context(user_id)
def require_permission(perm_code: str):
    """Create a FastAPI dependency that enforces one permission code.

    The returned dependency resolves the caller through ``get_current_user``.
    It lets the request through when the user's ``permissions`` list contains
    either *perm_code* or the ``"*:*"`` wildcard, and raises 403 otherwise.

    Args:
        perm_code: Permission code the protected endpoint requires.

    Returns:
        An async dependency callable that returns the authenticated user dict.
    """

    async def checker(user: dict = Depends(get_current_user)) -> dict:
        granted = user.get("permissions", [])
        if perm_code in granted or "*:*" in granted:
            return user
        raise HTTPException(403, f"缺少权限: {perm_code}")

    return checker
async def get_db():
    """Yield an ``AsyncSessionLocal`` session and finalize it afterwards.

    When control returns to the generator normally, the session is committed.
    Any ``Exception`` raised while the session is in use, including a failure
    of that commit, triggers a rollback and is re-raised.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            # Only reached when the consumer resumed the generator without error.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise