You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

279 lines
9.5 KiB

"""认证模块路由。
提供用户登录、JWT 令牌生成、企业微信 OAuth 授权、个人信息查询/修改和密码修改等功能。
支持基于用户名密码和企业微信 OAuth 两种认证方式。
"""
import uuid
import secrets
from datetime import datetime, timedelta
import jwt
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
import bcrypt
from database import get_db
from models import User, UserRole, Role, RolePermission, Permission
from schemas import LoginRequest, TokenResponse, UserOut, RoleOut
from config import settings
# OAuth 状态存储,用于防止 CSRF 攻击
_oauth_states: dict[str, float] = {}
_OAUTH_STATE_TTL = 600 # OAuth 状态有效期(秒)
def hash_password(password: str) -> str:
"""对密码进行 bcrypt 哈希加密。
Args:
password: 明文密码字符串。
Returns:
str: bcrypt 加密后的哈希字符串。
"""
return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
router = APIRouter(prefix="/api/auth", tags=["auth"]) # 认证路由前缀
async def get_permission_codes(db: AsyncSession, role_ids: list[uuid.UUID]) -> list[str]:
"""根据角色 ID 列表获取所有关联的权限代码。
Args:
db: 异步数据库会话。
role_ids: 角色 ID 列表。
Returns:
list[str]: 去重后的权限代码列表。
"""
result = await db.execute(
select(Permission.code)
.join(RolePermission)
.where(RolePermission.role_id.in_(role_ids))
)
return list(set(result.scalars().all()))
async def get_user_roles(db: AsyncSession, user_id: uuid.UUID) -> list[RoleOut]:
"""获取用户的所有角色及其权限信息。
查询用户关联的所有角色,并为每个角色查询其关联的权限代码。
Args:
db: 异步数据库会话。
user_id: 用户唯一标识 ID。
Returns:
list[RoleOut]: 角色信息列表,每个角色包含名称、代码、描述和权限代码列表。
"""
result = await db.execute(
select(Role).join(UserRole).where(UserRole.user_id == user_id)
)
roles = result.scalars().all()
out = []
for role in roles:
# 查询该角色关联的所有权限代码
rp_result = await db.execute(
select(Permission.code)
.join(RolePermission)
.where(RolePermission.role_id == role.id)
)
perms = list(rp_result.scalars().all())
out.append(RoleOut(
id=role.id,
name=role.name,
code=role.code,
description=role.description,
is_system=role.is_system,
data_scope=role.data_scope,
permissions=perms,
))
return out
@router.post("/login", response_model=TokenResponse)
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
"""用户登录接口。
验证用户名和密码,验证通过后生成 JWT 令牌并返回用户信息。
Args:
req: 登录请求体,包含用户名和密码。
db: 异步数据库会话。
Returns:
TokenResponse: 包含访问令牌和用户信息的响应。
Raises:
HTTPException: 用户名或密码错误、账户被禁用时抛出异常。
"""
result = await db.execute(select(User).where(User.username == req.username))
user = result.scalar_one_or_none()
if not user or not bcrypt.checkpw(req.password.encode('utf-8'), user.password_hash.encode('utf-8')):
raise HTTPException(401, "用户名或密码错误")
if user.status != "active":
raise HTTPException(403, "账户已被禁用")
roles = await get_user_roles(db, user.id) # 获取用户角色信息
# 生成 JWT 令牌
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
token = jwt.encode(
{"sub": str(user.id), "username": user.username, "exp": expire},
settings.JWT_SECRET,
algorithm=settings.JWT_ALGORITHM,
)
return TokenResponse(
access_token=token,
user=UserOut(
id=user.id, username=user.username, display_name=user.display_name,
email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id,
department_id=user.department_id, position=user.position,
manager_id=user.manager_id, status=user.status,
roles=roles, created_at=user.created_at,
),
)
@router.get("/me", response_model=UserOut)
async def get_me(request: Request, db: AsyncSession = Depends(get_db)):
"""获取当前登录用户的详细信息。
Args:
request: HTTP 请求对象,包含当前用户上下文。
db: 异步数据库会话。
Returns:
UserOut: 当前用户信息,包含角色列表。
Raises:
HTTPException: 用户不存在时抛出异常。
"""
user_ctx = request.state.user # 从请求状态中获取当前用户上下文
result = await db.execute(select(User).where(User.id == user_ctx["id"]))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(404, "用户不存在")
roles = await get_user_roles(db, user.id) # 获取用户角色信息
return UserOut(
id=user.id, username=user.username, display_name=user.display_name,
email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id,
department_id=user.department_id, position=user.position,
manager_id=user.manager_id, status=user.status,
roles=roles, created_at=user.created_at,
)
@router.get("/wecom/oauth-url")
async def get_wecom_oauth_url(request: Request):
"""获取企业微信 OAuth 授权 URL。
生成用于企业微信网页授权登录的 URL,包含随机 state 参数用于防 CSRF 攻击。
Args:
request: HTTP 请求对象,用于获取基础 URL 构建回调地址。
Returns:
dict: 包含 OAuth 授权 URL 和 state 参数的响应数据。
"""
corp_id = settings.WECOM_CORP_ID or ""
if not corp_id:
return {"code": 400, "message": "请先配置 WECOM_CORP_ID"}
base_url = str(request.base_url).rstrip("/")
redirect_uri = f"{base_url}/api/auth/wecom/callback" # OAuth 回调地址
state = secrets.token_urlsafe(32) # 生成随机 state 用于防 CSRF
import time
_oauth_states[state] = time.time() # 存储 state 及其创建时间
# 清理过期的 state
expired = [k for k, v in _oauth_states.items() if time.time() - v > _OAUTH_STATE_TTL]
for k in expired:
del _oauth_states[k]
# 拼接企业微信 OAuth 授权 URL
url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={corp_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_base&state={state}#wechat_redirect"
return {"code": 200, "data": {"url": url, "state": state}}
@router.put("/me")
async def update_me(
request: Request,
payload: dict,
db: AsyncSession = Depends(get_db),
):
"""更新当前用户的个人信息。
支持修改显示名称、邮箱和手机号。
Args:
request: HTTP 请求对象,包含当前用户上下文。
payload: 更新字段字典,可包含 display_name、email、phone。
db: 异步数据库会话。
Returns:
UserOut: 更新后的用户信息。
Raises:
HTTPException: 用户不存在时抛出异常。
"""
user_ctx = request.state.user
result = await db.execute(select(User).where(User.id == user_ctx["id"]))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(404, "用户不存在")
if "display_name" in payload:
user.display_name = payload["display_name"]
if "email" in payload:
user.email = payload["email"]
if "phone" in payload:
user.phone = payload["phone"]
await db.commit()
roles = await get_user_roles(db, user.id)
return UserOut(
id=user.id, username=user.username, display_name=user.display_name,
email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id,
department_id=user.department_id, position=user.position,
manager_id=user.manager_id, status=user.status,
roles=roles, created_at=user.created_at,
)
@router.put("/password")
async def change_password(
request: Request,
payload: dict,
db: AsyncSession = Depends(get_db),
):
"""修改当前用户的登录密码。
需要验证旧密码正确性,新密码至少 6 位。
Args:
request: HTTP 请求对象,包含当前用户上下文。
payload: 包含 old_password 和 new_password 的字典。
db: 异步数据库会话。
Returns:
dict: 修改成功的响应数据。
Raises:
HTTPException: 用户不存在、旧密码错误或新密码长度不足时抛出异常。
"""
user_ctx = request.state.user
result = await db.execute(select(User).where(User.id == user_ctx["id"]))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(404, "用户不存在")
old_pw = payload.get("old_password", "")
new_pw = payload.get("new_password", "")
if not bcrypt.checkpw(old_pw.encode('utf-8'), user.password_hash.encode('utf-8')):
raise HTTPException(400, "当前密码错误")
if len(new_pw) < 6:
raise HTTPException(400, "新密码至少6位")
user.password_hash = hash_password(new_pw) # 更新为新密码哈希
await db.commit()
return {"code": 200, "message": "密码已修改"}