You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

325 lines
10 KiB

"""通知模块路由。
提供实时通知推送功能,支持 WebSocket 连接、消息广播、定向发送。
支持通知模板管理和企业微信推送集成。
"""
import uuid
import json
import asyncio
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import NotificationTemplate, AuditLog
from schemas import NotificationTemplateCreate, NotificationTemplateOut
from config import settings
from dependencies import get_current_user
router = APIRouter(prefix="/api/notification", tags=["notification"])
class WebSocketManager:
"""WebSocket 连接管理器类,管理所有用户的 WebSocket 连接。
支持按用户 ID 管理多个连接,提供定向发送和广播功能。
Attributes:
connections: 用户 ID 到 WebSocket 连接列表的映射字典。
"""
def __init__(self):
"""初始化 WebSocket 管理器实例。"""
self.connections: dict[str, list[WebSocket]] = {}
async def connect(self, user_id: str, ws: WebSocket):
"""接受并注册新的 WebSocket 连接。
Args:
user_id: 用户唯一标识。
ws: WebSocket 连接对象。
"""
await ws.accept()
if user_id not in self.connections:
self.connections[user_id] = []
self.connections[user_id].append(ws)
def disconnect(self, user_id: str, ws: WebSocket):
"""断开并移除指定的 WebSocket 连接。
Args:
user_id: 用户唯一标识。
ws: 要移除的 WebSocket 连接对象。
"""
if user_id in self.connections:
self.connections[user_id].remove(ws)
if not self.connections[user_id]:
del self.connections[user_id]
async def send_to_user(self, user_id: str, message: dict):
"""向指定用户的所有连接发送消息。
自动清理失效的连接。
Args:
user_id: 目标用户唯一标识。
message: 要发送的消息字典。
"""
connections = self.connections.get(user_id, [])
dead = []
for ws in connections:
try:
await ws.send_json(message)
except Exception:
dead.append(ws)
for ws in dead:
self.disconnect(user_id, ws)
async def broadcast(self, message: dict):
"""向所有在线用户广播消息。
Args:
message: 要广播的消息字典。
"""
for user_id in list(self.connections.keys()):
await self.send_to_user(user_id, message)
@property
def active_count(self) -> int:
"""获取当前活跃的 WebSocket 连接总数。
Returns:
int: 活跃连接数量。
"""
return sum(len(v) for v in self.connections.values())
ws_manager = WebSocketManager() # 全局 WebSocket 管理器单例实例
@router.websocket("/ws/{user_id}")
async def notification_websocket(ws: WebSocket, user_id: str):
"""通知 WebSocket 连接处理器。
接受客户端的 WebSocket 连接,处理心跳 ping/pong 消息。
Args:
ws: WebSocket 连接对象。
user_id: 用户唯一标识,从路径参数获取。
"""
await ws_manager.connect(user_id, ws)
try:
while True:
data = await ws.receive_text()
try:
msg = json.loads(data)
if msg.get("type") == "ping":
await ws.send_json({"type": "pong", "ts": datetime.utcnow().isoformat()})
except json.JSONDecodeError:
pass
except WebSocketDisconnect:
ws_manager.disconnect(user_id, ws)
@router.post("/send")
async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
"""发送实时通知,支持定向发送和广播。
支持推送到企业微信,并记录审计日志。
Args:
payload: 请求体,包含 user_id、target_all、title、message、type、push_to_wecom 等字段。
request: HTTP 请求对象。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 操作结果响应。
"""
user_id = payload.get("user_id", "") # 目标用户 ID
target_all = payload.get("target_all", False) # 是否广播给所有用户
title = payload.get("title", "系统通知") # 通知标题
body = payload.get("message", "") # 通知内容
notify_type = payload.get("type", "info") # 通知类型
msg = {
"type": notify_type,
"title": title,
"message": body,
"ts": datetime.utcnow().isoformat(),
}
if target_all:
await ws_manager.broadcast(msg) # 广播给所有在线用户
elif user_id:
await ws_manager.send_to_user(user_id, msg) # 定向发送给指定用户
# 推送到企业微信
if payload.get("push_to_wecom"):
await _push_to_wecom(title, body, user_id)
audit = AuditLog(
operator_id=uuid.UUID(user["id"]),
action="notification.send",
resource="notification",
detail={"title": title, "target": user_id if user_id else "broadcast"},
ip_address=request.client.host if request.client else None,
)
db.add(audit)
await db.flush()
return {"code": 200, "message": "已发送"}
@router.get("/templates", response_model=list[NotificationTemplateOut])
async def list_templates(request: Request, db: AsyncSession = Depends(get_db)):
"""列出所有通知模板。
Args:
request: HTTP 请求对象。
db: 异步数据库会话。
Returns:
list[NotificationTemplateOut]: 通知模板列表。
"""
result = await db.execute(
select(NotificationTemplate).order_by(NotificationTemplate.created_at.desc())
)
return result.scalars().all()
@router.post("/templates", response_model=NotificationTemplateOut)
async def create_template(
req: NotificationTemplateCreate,
request: Request,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
):
"""创建新的通知模板。
Args:
req: 通知模板创建请求体。
request: HTTP 请求对象。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
NotificationTemplateOut: 创建后的通知模板响应。
Raises:
HTTPException: 模板编码已存在时抛出异常。
"""
existing = await db.execute(
select(NotificationTemplate).where(NotificationTemplate.code == req.code)
)
if existing.scalar_one_or_none():
raise HTTPException(400, "模板编码已存在")
template = NotificationTemplate(
name=req.name,
code=req.code,
channel=req.channel,
title_template=req.title_template,
body_template=req.body_template,
variables=req.variables,
)
db.add(template)
await db.flush()
return template
@router.get("/templates/{template_id}", response_model=NotificationTemplateOut)
async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
"""获取指定通知模板的详细信息。
Args:
template_id: 通知模板唯一标识 ID。
request: HTTP 请求对象。
db: 异步数据库会话。
Returns:
NotificationTemplateOut: 通知模板详细信息。
Raises:
HTTPException: 模板不存在时抛出异常。
"""
result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id))
template = result.scalar_one_or_none()
if not template:
raise HTTPException(404, "模板不存在")
return template
@router.delete("/templates/{template_id}")
async def delete_template(
template_id: uuid.UUID, request: Request,
db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user),
):
"""删除指定的通知模板。
Args:
template_id: 通知模板唯一标识 ID。
request: HTTP 请求对象。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 操作结果响应。
Raises:
HTTPException: 模板不存在或为系统模板时抛出异常。
"""
result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id))
template = result.scalar_one_or_none()
if not template:
raise HTTPException(404, "模板不存在")
if template.is_system:
raise HTTPException(400, "系统模板不可删除")
await db.delete(template)
await db.flush()
return {"code": 200, "message": "已删除"}
@router.get("/ws/stats")
async def ws_stats():
"""获取 WebSocket 连接统计信息。
Returns:
dict: 包含活跃连接数的响应数据。
"""
return {"code": 200, "data": {"active_connections": ws_manager.active_count}}
async def _push_to_wecom(title: str, body: str, user_id: str):
"""将通知推送到企业微信。
Args:
title: 通知标题。
body: 通知内容。
user_id: 目标企业微信用户 ID。
"""
if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET:
return
try:
import httpx
async with httpx.AsyncClient() as client:
token_resp = await client.get(
"https://qyapi.weixin.qq.com/cgi-bin/gettoken",
params={"corpid": settings.WECOM_CORP_ID, "corpsecret": settings.WECOM_APP_SECRET},
)
token_data = token_resp.json()
access_token = token_data.get("access_token", "")
if access_token and user_id:
await client.post(
f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}",
json={
"touser": user_id,
"msgtype": "text",
"agentid": 0,
"text": {"content": f"{title}\n{body}"},
},
)
except Exception:
pass