"""Notification API: WebSocket push, notification sending, and template CRUD."""

import asyncio
import json
import logging
import uuid
from datetime import datetime

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from models import NotificationTemplate, AuditLog
from schemas import NotificationTemplateCreate, NotificationTemplateOut
from config import settings
from dependencies import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/notification", tags=["notification"])


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    ``datetime.utcnow()`` is kept deliberately: switching to a timezone-aware
    datetime would append ``+00:00`` and change the ``ts`` format that
    connected clients already receive.
    """
    return datetime.utcnow().isoformat()


class WebSocketManager:
    """In-process registry of open WebSocket connections, keyed by user id."""

    def __init__(self):
        # user_id -> every socket currently open for that user.
        self.connections: dict[str, list[WebSocket]] = {}

    async def connect(self, user_id: str, ws: WebSocket) -> None:
        """Accept the handshake and register ``ws`` under ``user_id``."""
        await ws.accept()
        self.connections.setdefault(user_id, []).append(ws)

    def disconnect(self, user_id: str, ws: WebSocket) -> None:
        """Unregister ``ws`` for ``user_id``; a no-op if it is not registered.

        Safe to call more than once for the same socket: ``send_to_user`` may
        already have pruned a dead socket before the endpoint's own cleanup
        runs.
        """
        sockets = self.connections.get(user_id)
        if sockets is None:
            return
        try:
            sockets.remove(ws)
        except ValueError:
            # Already removed by an earlier cleanup path.
            pass
        if not sockets:
            del self.connections[user_id]

    async def send_to_user(self, user_id: str, message: dict) -> None:
        """Send ``message`` as JSON to every socket registered for ``user_id``.

        Sockets whose send raises are unregistered afterwards.
        """
        # Iterate a snapshot: each ``await`` yields to the event loop, during
        # which connect()/disconnect() may mutate the live list.
        sockets = list(self.connections.get(user_id, []))
        dead = []
        for ws in sockets:
            try:
                await ws.send_json(message)
            except Exception:
                dead.append(ws)
        for ws in dead:
            self.disconnect(user_id, ws)

    async def broadcast(self, message: dict) -> None:
        """Send ``message`` to every registered user, one user at a time."""
        for user_id in list(self.connections):
            await self.send_to_user(user_id, message)

    @property
    def active_count(self) -> int:
        """Total number of registered sockets across all users."""
        return sum(len(sockets) for sockets in self.connections.values())


ws_manager = WebSocketManager()


@router.websocket("/ws/{user_id}")
async def notification_websocket(ws: WebSocket, user_id: str):
    """Hold a notification socket open and answer ``ping`` frames with ``pong``.

    NOTE(review): ``user_id`` comes straight from the URL path and no
    authentication is visible in this handler -- confirm whether an upstream
    layer prevents a client from subscribing under another user's id.
    """
    await ws_manager.connect(user_id, ws)
    try:
        while True:
            data = await ws.receive_text()
            try:
                msg = json.loads(data)
            except json.JSONDecodeError:
                # Non-JSON frames are ignored.
                continue
            # Valid JSON need not be an object (e.g. ``[1]`` or ``"x"``);
            # guard before calling ``.get``.
            if isinstance(msg, dict) and msg.get("type") == "ping":
                await ws.send_json({"type": "pong", "ts": _utc_timestamp()})
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister, whatever ended the loop, so a failed handler
        # does not leave a stale socket in the registry.
        ws_manager.disconnect(user_id, ws)


@router.post("/send", dependencies=[Depends(get_current_user)])
async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db)):
    """Push a notification over WebSocket (and optionally WeCom) and audit it.

    Keys read from ``payload``: ``user_id``, ``target_all``, ``title``,
    ``message``, ``type`` and ``push_to_wecom``.

    NOTE(review): when neither ``target_all`` nor ``user_id`` is supplied,
    nothing is delivered but the endpoint still answers "已发送" -- confirm
    whether that case should be rejected instead.
    """
    user_id = payload.get("user_id", "")
    target_all = payload.get("target_all", False)
    title = payload.get("title", "系统通知")
    body = payload.get("message", "")
    notify_type = payload.get("type", "info")

    msg = {
        "type": notify_type,
        "title": title,
        "message": body,
        "ts": _utc_timestamp(),
    }

    if target_all:
        await ws_manager.broadcast(msg)
    elif user_id:
        await ws_manager.send_to_user(user_id, msg)

    if payload.get("push_to_wecom"):
        await _push_to_wecom(title, body, user_id)

    # Record what was actually done: a broadcast is logged as "broadcast"
    # even when the payload also carried a user_id.
    target = "broadcast" if target_all or not user_id else user_id
    audit = AuditLog(
        # Presumably request.state.user is populated by get_current_user;
        # verify against dependencies.py.
        operator_id=uuid.UUID(request.state.user["id"]),
        action="notification.send",
        resource="notification",
        detail={"title": title, "target": target},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()
    return {"code": 200, "message": "已发送"}


@router.get("/templates", response_model=list[NotificationTemplateOut])
async def list_templates(request: Request, db: AsyncSession = Depends(get_db)):
    """Return all notification templates, newest ``created_at`` first."""
    result = await db.execute(
        select(NotificationTemplate).order_by(NotificationTemplate.created_at.desc())
    )
    return result.scalars().all()


@router.post("/templates", response_model=NotificationTemplateOut)
async def create_template(
    req: NotificationTemplateCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Create a notification template.

    Raises:
        HTTPException: 400 if a template with the same ``code`` already exists.
    """
    existing = await db.execute(
        select(NotificationTemplate).where(NotificationTemplate.code == req.code)
    )
    if existing.scalar_one_or_none():
        raise HTTPException(400, "模板编码已存在")

    template = NotificationTemplate(
        name=req.name,
        code=req.code,
        channel=req.channel,
        title_template=req.title_template,
        body_template=req.body_template,
        variables=req.variables,
    )
    db.add(template)
    await db.flush()
    return template


@router.get("/templates/{template_id}", response_model=NotificationTemplateOut)
async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Return one template by id.

    Raises:
        HTTPException: 404 if no template has that id.
    """
    result = await db.execute(
        select(NotificationTemplate).where(NotificationTemplate.id == template_id)
    )
    template = result.scalar_one_or_none()
    if not template:
        raise HTTPException(404, "模板不存在")
    return template


@router.delete("/templates/{template_id}")
async def delete_template(
    template_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Delete a template by id.

    Raises:
        HTTPException: 404 if no template has that id; 400 if the template's
            ``is_system`` flag is set.
    """
    result = await db.execute(
        select(NotificationTemplate).where(NotificationTemplate.id == template_id)
    )
    template = result.scalar_one_or_none()
    if not template:
        raise HTTPException(404, "模板不存在")
    if template.is_system:
        raise HTTPException(400, "系统模板不可删除")
    await db.delete(template)
    await db.flush()
    return {"code": 200, "message": "已删除"}


@router.get("/ws/stats")
async def ws_stats():
    """Report how many WebSocket connections this process currently holds."""
    return {"code": 200, "data": {"active_connections": ws_manager.active_count}}


async def _push_to_wecom(title: str, body: str, user_id: str):
    """Best-effort push of a text message to one WeCom user.

    Returns without doing anything when the WeCom credentials are not
    configured or ``user_id`` is empty. Failures are logged and never raised,
    so a WeCom problem cannot fail the calling request.
    """
    if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET:
        return
    if not user_id:
        # No recipient, so skip the token request entirely.
        return
    try:
        import httpx

        async with httpx.AsyncClient() as client:
            token_resp = await client.get(
                "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
                params={"corpid": settings.WECOM_CORP_ID, "corpsecret": settings.WECOM_APP_SECRET},
            )
            token_data = token_resp.json()
            access_token = token_data.get("access_token", "")
            if not access_token:
                logger.warning(
                    "WeCom push skipped: no access token (errcode=%s)",
                    token_data.get("errcode"),
                )
                return
            send_resp = await client.post(
                "https://qyapi.weixin.qq.com/cgi-bin/message/send",
                params={"access_token": access_token},
                json={
                    "touser": user_id,
                    "msgtype": "text",
                    # NOTE(review): agentid is hard-coded to 0 -- confirm this
                    # matches the WeCom application the credentials belong to.
                    "agentid": 0,
                    "text": {"content": f"{title}\n{body}"},
                },
            )
            send_data = send_resp.json()
            if isinstance(send_data, dict) and send_data.get("errcode"):
                logger.warning("WeCom push rejected (errcode=%s)", send_data.get("errcode"))
    except Exception as exc:
        # Deliberately best-effort. Log only the exception class: the request
        # URLs carry the corp secret / access token and must not reach logs.
        logger.warning("WeCom push failed: %s", type(exc).__name__)