You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
325 lines
10 KiB
325 lines
10 KiB
"""通知模块路由。
|
|
|
|
提供实时通知推送功能,支持 WebSocket 连接、消息广播、定向发送。
|
|
支持通知模板管理和企业微信推送集成。
|
|
"""
|
|
import uuid
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import NotificationTemplate, AuditLog
|
|
from schemas import NotificationTemplateCreate, NotificationTemplateOut
|
|
from config import settings
|
|
from dependencies import get_current_user
|
|
|
|
router = APIRouter(prefix="/api/notification", tags=["notification"])
|
|
|
|
|
|
class WebSocketManager:
|
|
"""WebSocket 连接管理器类,管理所有用户的 WebSocket 连接。
|
|
|
|
支持按用户 ID 管理多个连接,提供定向发送和广播功能。
|
|
|
|
Attributes:
|
|
connections: 用户 ID 到 WebSocket 连接列表的映射字典。
|
|
"""
|
|
def __init__(self):
|
|
"""初始化 WebSocket 管理器实例。"""
|
|
self.connections: dict[str, list[WebSocket]] = {}
|
|
|
|
async def connect(self, user_id: str, ws: WebSocket):
|
|
"""接受并注册新的 WebSocket 连接。
|
|
|
|
Args:
|
|
user_id: 用户唯一标识。
|
|
ws: WebSocket 连接对象。
|
|
"""
|
|
await ws.accept()
|
|
if user_id not in self.connections:
|
|
self.connections[user_id] = []
|
|
self.connections[user_id].append(ws)
|
|
|
|
def disconnect(self, user_id: str, ws: WebSocket):
|
|
"""断开并移除指定的 WebSocket 连接。
|
|
|
|
Args:
|
|
user_id: 用户唯一标识。
|
|
ws: 要移除的 WebSocket 连接对象。
|
|
"""
|
|
if user_id in self.connections:
|
|
self.connections[user_id].remove(ws)
|
|
if not self.connections[user_id]:
|
|
del self.connections[user_id]
|
|
|
|
async def send_to_user(self, user_id: str, message: dict):
|
|
"""向指定用户的所有连接发送消息。
|
|
|
|
自动清理失效的连接。
|
|
|
|
Args:
|
|
user_id: 目标用户唯一标识。
|
|
message: 要发送的消息字典。
|
|
"""
|
|
connections = self.connections.get(user_id, [])
|
|
dead = []
|
|
for ws in connections:
|
|
try:
|
|
await ws.send_json(message)
|
|
except Exception:
|
|
dead.append(ws)
|
|
for ws in dead:
|
|
self.disconnect(user_id, ws)
|
|
|
|
async def broadcast(self, message: dict):
|
|
"""向所有在线用户广播消息。
|
|
|
|
Args:
|
|
message: 要广播的消息字典。
|
|
"""
|
|
for user_id in list(self.connections.keys()):
|
|
await self.send_to_user(user_id, message)
|
|
|
|
@property
|
|
def active_count(self) -> int:
|
|
"""获取当前活跃的 WebSocket 连接总数。
|
|
|
|
Returns:
|
|
int: 活跃连接数量。
|
|
"""
|
|
return sum(len(v) for v in self.connections.values())
|
|
|
|
|
|
ws_manager = WebSocketManager() # 全局 WebSocket 管理器单例实例
|
|
|
|
|
|
@router.websocket("/ws/{user_id}")
|
|
async def notification_websocket(ws: WebSocket, user_id: str):
|
|
"""通知 WebSocket 连接处理器。
|
|
|
|
接受客户端的 WebSocket 连接,处理心跳 ping/pong 消息。
|
|
|
|
Args:
|
|
ws: WebSocket 连接对象。
|
|
user_id: 用户唯一标识,从路径参数获取。
|
|
"""
|
|
await ws_manager.connect(user_id, ws)
|
|
try:
|
|
while True:
|
|
data = await ws.receive_text()
|
|
try:
|
|
msg = json.loads(data)
|
|
if msg.get("type") == "ping":
|
|
await ws.send_json({"type": "pong", "ts": datetime.utcnow().isoformat()})
|
|
except json.JSONDecodeError:
|
|
pass
|
|
except WebSocketDisconnect:
|
|
ws_manager.disconnect(user_id, ws)
|
|
|
|
|
|
@router.post("/send")
|
|
async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
|
|
"""发送实时通知,支持定向发送和广播。
|
|
|
|
支持推送到企业微信,并记录审计日志。
|
|
|
|
Args:
|
|
payload: 请求体,包含 user_id、target_all、title、message、type、push_to_wecom 等字段。
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 操作结果响应。
|
|
"""
|
|
user_id = payload.get("user_id", "") # 目标用户 ID
|
|
target_all = payload.get("target_all", False) # 是否广播给所有用户
|
|
title = payload.get("title", "系统通知") # 通知标题
|
|
body = payload.get("message", "") # 通知内容
|
|
notify_type = payload.get("type", "info") # 通知类型
|
|
|
|
msg = {
|
|
"type": notify_type,
|
|
"title": title,
|
|
"message": body,
|
|
"ts": datetime.utcnow().isoformat(),
|
|
}
|
|
|
|
if target_all:
|
|
await ws_manager.broadcast(msg) # 广播给所有在线用户
|
|
elif user_id:
|
|
await ws_manager.send_to_user(user_id, msg) # 定向发送给指定用户
|
|
|
|
# 推送到企业微信
|
|
if payload.get("push_to_wecom"):
|
|
await _push_to_wecom(title, body, user_id)
|
|
|
|
audit = AuditLog(
|
|
operator_id=uuid.UUID(user["id"]),
|
|
action="notification.send",
|
|
resource="notification",
|
|
detail={"title": title, "target": user_id if user_id else "broadcast"},
|
|
ip_address=request.client.host if request.client else None,
|
|
)
|
|
db.add(audit)
|
|
await db.flush()
|
|
|
|
return {"code": 200, "message": "已发送"}
|
|
|
|
|
|
@router.get("/templates", response_model=list[NotificationTemplateOut])
|
|
async def list_templates(request: Request, db: AsyncSession = Depends(get_db)):
|
|
"""列出所有通知模板。
|
|
|
|
Args:
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
list[NotificationTemplateOut]: 通知模板列表。
|
|
"""
|
|
result = await db.execute(
|
|
select(NotificationTemplate).order_by(NotificationTemplate.created_at.desc())
|
|
)
|
|
return result.scalars().all()
|
|
|
|
|
|
@router.post("/templates", response_model=NotificationTemplateOut)
|
|
async def create_template(
|
|
req: NotificationTemplateCreate,
|
|
request: Request,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: dict = Depends(get_current_user),
|
|
):
|
|
"""创建新的通知模板。
|
|
|
|
Args:
|
|
req: 通知模板创建请求体。
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
NotificationTemplateOut: 创建后的通知模板响应。
|
|
|
|
Raises:
|
|
HTTPException: 模板编码已存在时抛出异常。
|
|
"""
|
|
existing = await db.execute(
|
|
select(NotificationTemplate).where(NotificationTemplate.code == req.code)
|
|
)
|
|
if existing.scalar_one_or_none():
|
|
raise HTTPException(400, "模板编码已存在")
|
|
|
|
template = NotificationTemplate(
|
|
name=req.name,
|
|
code=req.code,
|
|
channel=req.channel,
|
|
title_template=req.title_template,
|
|
body_template=req.body_template,
|
|
variables=req.variables,
|
|
)
|
|
db.add(template)
|
|
await db.flush()
|
|
return template
|
|
|
|
|
|
@router.get("/templates/{template_id}", response_model=NotificationTemplateOut)
|
|
async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
|
|
"""获取指定通知模板的详细信息。
|
|
|
|
Args:
|
|
template_id: 通知模板唯一标识 ID。
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
NotificationTemplateOut: 通知模板详细信息。
|
|
|
|
Raises:
|
|
HTTPException: 模板不存在时抛出异常。
|
|
"""
|
|
result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id))
|
|
template = result.scalar_one_or_none()
|
|
if not template:
|
|
raise HTTPException(404, "模板不存在")
|
|
return template
|
|
|
|
|
|
@router.delete("/templates/{template_id}")
|
|
async def delete_template(
|
|
template_id: uuid.UUID, request: Request,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: dict = Depends(get_current_user),
|
|
):
|
|
"""删除指定的通知模板。
|
|
|
|
Args:
|
|
template_id: 通知模板唯一标识 ID。
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 操作结果响应。
|
|
|
|
Raises:
|
|
HTTPException: 模板不存在或为系统模板时抛出异常。
|
|
"""
|
|
result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id))
|
|
template = result.scalar_one_or_none()
|
|
if not template:
|
|
raise HTTPException(404, "模板不存在")
|
|
if template.is_system:
|
|
raise HTTPException(400, "系统模板不可删除")
|
|
await db.delete(template)
|
|
await db.flush()
|
|
return {"code": 200, "message": "已删除"}
|
|
|
|
|
|
@router.get("/ws/stats")
|
|
async def ws_stats():
|
|
"""获取 WebSocket 连接统计信息。
|
|
|
|
Returns:
|
|
dict: 包含活跃连接数的响应数据。
|
|
"""
|
|
return {"code": 200, "data": {"active_connections": ws_manager.active_count}}
|
|
|
|
|
|
async def _push_to_wecom(title: str, body: str, user_id: str):
|
|
"""将通知推送到企业微信。
|
|
|
|
Args:
|
|
title: 通知标题。
|
|
body: 通知内容。
|
|
user_id: 目标企业微信用户 ID。
|
|
"""
|
|
if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET:
|
|
return
|
|
|
|
try:
|
|
import httpx
|
|
async with httpx.AsyncClient() as client:
|
|
token_resp = await client.get(
|
|
"https://qyapi.weixin.qq.com/cgi-bin/gettoken",
|
|
params={"corpid": settings.WECOM_CORP_ID, "corpsecret": settings.WECOM_APP_SECRET},
|
|
)
|
|
token_data = token_resp.json()
|
|
access_token = token_data.get("access_token", "")
|
|
|
|
if access_token and user_id:
|
|
await client.post(
|
|
f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}",
|
|
json={
|
|
"touser": user_id,
|
|
"msgtype": "text",
|
|
"agentid": 0,
|
|
"text": {"content": f"{title}\n{body}"},
|
|
},
|
|
)
|
|
except Exception:
|
|
pass
|
|
|