You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

198 lines
6.5 KiB

import uuid
import json
import asyncio
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import NotificationTemplate, AuditLog
from schemas import NotificationTemplateCreate, NotificationTemplateOut
from config import settings
from dependencies import get_current_user
# Shared router for this module: every endpoint below is mounted under
# /api/notification and tagged "notification".
router = APIRouter(
    tags=["notification"],
    prefix="/api/notification",
)
class WebSocketManager:
    """Registry of open notification WebSockets, keyed by user id.

    Each user id maps to a list of sockets, so one user may hold several
    connections at once. State is kept in a plain dict on the instance, so
    it is local to the current process.
    """

    def __init__(self):
        # user_id -> sockets currently registered for that user
        self.connections: dict[str, list[WebSocket]] = {}

    async def connect(self, user_id: str, ws: "WebSocket"):
        """Accept the WebSocket handshake and register ``ws`` under ``user_id``."""
        await ws.accept()
        self.connections.setdefault(user_id, []).append(ws)

    def disconnect(self, user_id: str, ws: "WebSocket"):
        """Unregister ``ws``; drop the user's entry once it has no sockets left.

        Idempotent: unregistering a socket that is no longer registered (for
        example one already pruned by ``send_to_user`` after a failed send)
        is a no-op instead of raising ``ValueError``.
        """
        sockets = self.connections.get(user_id)
        if sockets is None:
            return
        if ws in sockets:
            sockets.remove(ws)
        if not sockets:
            del self.connections[user_id]

    async def send_to_user(self, user_id: str, message: dict):
        """Send ``message`` as JSON to every socket registered for ``user_id``.

        A socket whose send raises is treated as dead and unregistered.
        Unknown user ids are silently ignored.
        """
        dead = []
        # Iterate over a snapshot: every ``await`` may yield to connect() or
        # disconnect() for the same user, which would mutate the live list
        # mid-loop and make iteration skip sockets.
        for ws in list(self.connections.get(user_id, ())):
            try:
                await ws.send_json(message)
            except Exception:
                dead.append(ws)
        for ws in dead:
            self.disconnect(user_id, ws)

    async def broadcast(self, message: dict):
        """Send ``message`` to every registered socket of every user, one user at a time."""
        # Copy the keys: send_to_user may delete a user's entry while we loop.
        for user_id in list(self.connections):
            await self.send_to_user(user_id, message)

    @property
    def active_count(self) -> int:
        """Total number of registered sockets across all users."""
        return sum(len(sockets) for sockets in self.connections.values())


# Module-wide manager instance shared by the endpoints in this file.
ws_manager = WebSocketManager()
@router.websocket("/ws/{user_id}")
async def notification_websocket(ws: WebSocket, user_id: str):
    """Register a client socket for ``user_id`` and answer its pings.

    Each incoming text frame is parsed as JSON. A JSON object with
    ``"type": "ping"`` is answered with
    ``{"type": "pong", "ts": <datetime.utcnow() ISO string>}``. Every other
    frame (non-JSON text, JSON that is not an object, other types) is
    ignored. The socket is unregistered from ``ws_manager`` however the
    loop ends.

    NOTE(review): unlike the HTTP routes in this module, this endpoint has no
    ``get_current_user`` dependency; ``user_id`` is taken from the URL as-is,
    so any client can subscribe to any user's notifications. Confirm this is
    intended or add authentication.
    """
    await ws_manager.connect(user_id, ws)
    try:
        while True:
            data = await ws.receive_text()
            try:
                msg = json.loads(data)
            except json.JSONDecodeError:
                continue
            # Valid JSON can still be a list/number/string, which has no .get().
            if isinstance(msg, dict) and msg.get("type") == "ping":
                await ws.send_json({"type": "pong", "ts": datetime.utcnow().isoformat()})
    except WebSocketDisconnect:
        pass  # normal client close
    finally:
        # Unregister on every exit path, not only a clean disconnect, so an
        # unexpected error does not leave a stale socket in the manager.
        ws_manager.disconnect(user_id, ws)
@router.post("/send")
async def send_notification(
    payload: dict,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Deliver a notification over WebSocket (optionally WeCom) and audit it.

    Recognised ``payload`` keys, all optional:
        user_id: recipient of a targeted WebSocket push.
        target_all: when truthy, broadcast to every registered socket instead
            of the targeted ``user_id`` WebSocket push.
        title, message, type: notification content; default to
            "系统通知", "" and "info" respectively.
        push_to_wecom: when truthy, also forward title/body to WeCom for
            ``user_id``.

    An ``AuditLog`` row (operator = ``user["id"]`` parsed as a UUID) is added
    and flushed; this function does not commit.

    Raises:
        HTTPException(400): neither ``target_all`` nor ``user_id`` is set, so
            there is nobody to notify.
    """
    user_id = payload.get("user_id", "")
    target_all = payload.get("target_all", False)
    if not target_all and not user_id:
        # Reject instead of reporting success (and auditing a "broadcast")
        # for a request that would notify no one.
        raise HTTPException(400, "缺少通知目标")
    title = payload.get("title", "系统通知")
    body = payload.get("message", "")
    notify_type = payload.get("type", "info")
    msg = {
        "type": notify_type,
        "title": title,
        "message": body,
        "ts": datetime.utcnow().isoformat(),
    }
    if target_all:
        await ws_manager.broadcast(msg)
    else:
        await ws_manager.send_to_user(user_id, msg)
    if payload.get("push_to_wecom"):
        await _push_to_wecom(title, body, user_id)
    # Audit what was actually sent: target_all takes precedence over user_id
    # above, so a broadcast must be recorded as such even if user_id is set.
    target = "broadcast" if target_all else user_id
    audit = AuditLog(
        operator_id=uuid.UUID(user["id"]),
        action="notification.send",
        resource="notification",
        detail={"title": title, "target": target},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()
    return {"code": 200, "message": "已发送"}
@router.get("/templates", response_model=list[NotificationTemplateOut])
async def list_templates(request: Request, db: AsyncSession = Depends(get_db)):
    """Return all notification templates ordered by ``created_at``, newest first.

    NOTE(review): this route has no ``get_current_user`` dependency, unlike
    the mutating routes; confirm public read access is intended.
    """
    newest_first = NotificationTemplate.created_at.desc()
    stmt = select(NotificationTemplate).order_by(newest_first)
    rows = await db.execute(stmt)
    return rows.scalars().all()
@router.post("/templates", response_model=NotificationTemplateOut)
async def create_template(
    req: NotificationTemplateCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Create a notification template whose ``code`` is not already taken.

    ``user`` is not read in the body; the dependency presumably serves only to
    require an authenticated caller (verify against ``get_current_user``).
    The new row is flushed, not committed.

    Raises:
        HTTPException(400): a template with ``req.code`` already exists.
    """
    duplicate_query = select(NotificationTemplate).where(NotificationTemplate.code == req.code)
    clash = (await db.execute(duplicate_query)).scalar_one_or_none()
    if clash:
        raise HTTPException(400, "模板编码已存在")

    fields = {
        "name": req.name,
        "code": req.code,
        "channel": req.channel,
        "title_template": req.title_template,
        "body_template": req.body_template,
        "variables": req.variables,
    }
    new_template = NotificationTemplate(**fields)
    db.add(new_template)
    await db.flush()
    return new_template
@router.get("/templates/{template_id}", response_model=NotificationTemplateOut)
async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Fetch one notification template by its id.

    Raises:
        HTTPException(404): no template has ``template_id``.
    """
    lookup = select(NotificationTemplate).where(NotificationTemplate.id == template_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(404, "模板不存在")
@router.delete("/templates/{template_id}")
async def delete_template(
    template_id: uuid.UUID, request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Delete a non-system notification template (flushed, not committed).

    Raises:
        HTTPException(404): no template has ``template_id``.
        HTTPException(400): the template is flagged ``is_system``.
    """
    lookup = select(NotificationTemplate).where(NotificationTemplate.id == template_id)
    doomed = (await db.execute(lookup)).scalar_one_or_none()
    if not doomed:
        raise HTTPException(404, "模板不存在")
    if doomed.is_system:
        # System-owned templates are protected from deletion.
        raise HTTPException(400, "系统模板不可删除")

    await db.delete(doomed)
    await db.flush()
    return dict(code=200, message="已删除")
@router.get("/ws/stats")
async def ws_stats():
    """Report how many WebSocket connections ``ws_manager`` currently holds."""
    stats = {"active_connections": ws_manager.active_count}
    return {"code": 200, "data": stats}
async def _push_to_wecom(title: str, body: str, user_id: str):
if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET:
return
try:
import httpx
async with httpx.AsyncClient() as client:
token_resp = await client.get(
"https://qyapi.weixin.qq.com/cgi-bin/gettoken",
params={"corpid": settings.WECOM_CORP_ID, "corpsecret": settings.WECOM_APP_SECRET},
)
token_data = token_resp.json()
access_token = token_data.get("access_token", "")
if access_token and user_id:
await client.post(
f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}",
json={
"touser": user_id,
"msgtype": "text",
"agentid": 0,
"text": {"content": f"{title}\n{body}"},
},
)
except Exception:
pass