You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
3.1 KiB
85 lines
3.1 KiB
import jwt
|
|
from fastapi import Depends, HTTPException, Request
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from sqlalchemy import select
|
|
from database import AsyncSessionLocal
|
|
from models import User, UserRole, Role, RolePermission, Permission
|
|
from config import settings
|
|
|
|
# Bearer-token extractor shared by the dependencies in this module.
# auto_error=False pairs with get_current_user declaring its credentials as
# optional (`| None`) and raising its own 401 when no token is supplied.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> dict:
    """Resolve the authenticated user for the current request.

    If ``request.state.user`` is already set to a truthy value it is returned
    unchanged. Otherwise the bearer token is decoded with the configured JWT
    secret/algorithm, the user named by the ``sub`` claim is loaded, and the
    user's roles and permission codes are collected.

    Args:
        request: Incoming request; ``request.state.user`` is consulted first.
        credentials: Bearer credentials from the ``security`` scheme, or
            ``None`` when the request carried no token.

    Returns:
        A dict with the keys ``id``, ``username``, ``display_name``,
        ``department_id``, ``role``, ``permissions`` and ``data_scope``.

    Raises:
        HTTPException: 401 when no token is supplied, the token cannot be
            decoded, the ``sub`` claim is missing, or the user does not exist.
    """
    # Reuse a user that was already attached to the request state.
    cached_user = getattr(request.state, "user", None)
    if cached_user:
        return cached_user

    if not credentials:
        raise HTTPException(401, "未提供认证令牌")

    # Keep the try body limited to the only call that can raise PyJWTError.
    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except jwt.PyJWTError as exc:
        raise HTTPException(401, "令牌无效或已过期") from exc

    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(401, "令牌无效")

    async with AsyncSessionLocal() as db:
        user_result = await db.execute(select(User).where(User.id == user_id))
        user = user_result.scalar_one_or_none()
        if not user:
            raise HTTPException(401, "用户不存在")

        role_result = await db.execute(
            select(Role).join(UserRole).where(UserRole.user_id == user.id)
        )
        roles = role_result.scalars().all()

        # Fetch the permission codes of every role in one query instead of
        # issuing one query per role. Skip the query when there are no roles
        # so an empty IN () is never sent.
        permissions = set()
        role_ids = [role.id for role in roles]
        if role_ids:
            perm_result = await db.execute(
                select(Permission.code)
                .join(RolePermission)
                .where(RolePermission.role_id.in_(role_ids))
            )
            permissions.update(perm_result.scalars().all())

        # The widest scope granted by any role wins: all > subordinate_only
        # > self_only (the fallback).
        data_scopes = {role.data_scope for role in roles}
        if "all" in data_scopes:
            data_scope = "all"
        elif "subordinate_only" in data_scopes:
            data_scope = "subordinate_only"
        else:
            data_scope = "self_only"

        # NOTE(review): the role query has no ORDER BY, so which role ends up
        # as roles[0] for a multi-role user is up to the database — confirm
        # whether a deterministic "primary role" is expected here.
        # The dict is built inside the session block so ORM attributes are
        # read while the session is still open.
        return {
            "id": str(user.id),
            "username": user.username,
            "display_name": user.display_name,
            "department_id": str(user.department_id) if user.department_id else None,
            "role": roles[0].code if roles else "employee",
            "permissions": list(permissions),
            "data_scope": data_scope,
        }
|
|
|
|
|
|
def require_permission(perm_code: str):
    """Build a dependency that only admits users holding ``perm_code``.

    Args:
        perm_code: Permission code the caller must have. A user whose
            permission list contains the wildcard ``"*:*"`` is admitted
            regardless of ``perm_code``.

    Returns:
        An async callable that takes the user dict resolved by
        ``get_current_user``, returns it unchanged when access is granted,
        and raises ``HTTPException`` with status 403 otherwise.
    """

    async def checker(user: dict = Depends(get_current_user)) -> dict:
        granted = user.get("permissions", [])
        # Either the exact code or the wildcard entry grants access.
        if perm_code in granted or "*:*" in granted:
            return user
        raise HTTPException(403, f"缺少权限: {perm_code}")

    return checker
|
|
|
|
|
|
async def get_db():
    """Yield a database session wrapped in a commit-or-rollback scope.

    A session is opened from ``AsyncSessionLocal`` and yielded to the caller.
    When the caller finishes without an exception the session is committed;
    if an ``Exception`` is raised (by the caller or by the commit itself) the
    session is rolled back and the exception is re-raised.
    """
    session_scope = AsyncSessionLocal()
    async with session_scope as db:
        try:
            yield db
            # Commit stays inside the try so a failed commit is rolled back
            # by the handler below as well.
            await db.commit()
        except Exception:
            await db.rollback()
            raise
|