You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
198 lines
6.5 KiB
198 lines
6.5 KiB
import uuid
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import NotificationTemplate, AuditLog
|
|
from schemas import NotificationTemplateCreate, NotificationTemplateOut
|
|
from config import settings
|
|
from dependencies import get_current_user
|
|
|
|
# Router for every notification endpoint in this module; all paths are
# mounted under /api/notification and grouped under the "notification" tag.
router = APIRouter(tags=["notification"], prefix="/api/notification")
|
|
|
|
|
|
class WebSocketManager:
    """Registry of open WebSockets keyed by user id, with fan-out helpers.

    ``connections`` maps a user id to the list of sockets currently
    registered for that user (a user may have several). Users whose last
    socket is removed are dropped from the mapping.
    """

    def __init__(self):
        # user_id -> registered sockets; empty lists are never kept.
        self.connections: dict[str, list[WebSocket]] = {}

    async def connect(self, user_id: str, ws: "WebSocket"):
        """Accept the handshake on *ws*, then register it under *user_id*."""
        await ws.accept()
        self.connections.setdefault(user_id, []).append(ws)

    def disconnect(self, user_id: str, ws: "WebSocket"):
        """Unregister *ws* for *user_id*; safe to call more than once.

        A socket can be pruned by ``send_to_user`` after a failed send and
        then removed again by the endpoint's own cleanup, so an unknown user
        or an already-removed socket is ignored instead of raising.
        """
        sockets = self.connections.get(user_id)
        if sockets is None:
            return
        try:
            sockets.remove(ws)
        except ValueError:
            pass  # already unregistered
        if not sockets:
            self.connections.pop(user_id, None)

    async def send_to_user(self, user_id: str, message: dict):
        """Send *message* as JSON to every socket registered for *user_id*.

        Unknown users are a no-op. Any socket whose send raises is treated
        as dead and unregistered after the loop.
        """
        dead = []
        # Iterate a snapshot: while suspended in ``send_json`` other tasks can
        # connect/disconnect and mutate the live list, which would otherwise
        # make the loop skip sockets.
        for ws in list(self.connections.get(user_id, ())):
            try:
                await ws.send_json(message)
            except Exception:
                # Deliberate best-effort: a broken socket must not stop
                # delivery to the user's other sockets.
                dead.append(ws)
        for ws in dead:
            self.disconnect(user_id, ws)

    async def broadcast(self, message: dict):
        """Send *message* to every registered user, one user after another."""
        # Snapshot the keys: send_to_user may drop users whose sockets died.
        for user_id in list(self.connections):
            await self.send_to_user(user_id, message)

    @property
    def active_count(self) -> int:
        """Total number of registered sockets across all users."""
        return sum(len(v) for v in self.connections.values())
|
|
|
|
|
|
# Module-level registry shared by the WebSocket endpoint and the HTTP send API.
# NOTE(review): this is in-memory, per-process state; with several server
# workers a user connected to one worker is not reachable from another.
ws_manager: WebSocketManager = WebSocketManager()
|
|
|
|
|
|
@router.websocket("/ws/{user_id}")
async def notification_websocket(ws: WebSocket, user_id: str):
    """Keep a notification WebSocket open for *user_id*.

    The socket is registered with ``ws_manager`` so messages sent through it
    reach this client. Incoming text frames are only inspected for a JSON
    object ``{"type": "ping"}``, answered with ``{"type": "pong", "ts": ...}``
    (UTC ISO timestamp); everything else, including non-JSON text, is ignored.
    The socket is always unregistered when the loop ends, however it ends.

    NOTE(review): ``user_id`` comes straight from the path and is not
    authenticated, so any client can subscribe to any user's notifications.
    A token check before ``connect`` is needed.
    """
    await ws_manager.connect(user_id, ws)
    try:
        while True:
            data = await ws.receive_text()
            try:
                msg = json.loads(data)
            except json.JSONDecodeError:
                continue
            # Valid JSON need not be an object ("[1]", "3", "\"x\""); without
            # this guard ``msg.get`` would raise and kill the connection loop.
            if isinstance(msg, dict) and msg.get("type") == "ping":
                await ws.send_json({"type": "pong", "ts": datetime.utcnow().isoformat()})
    except WebSocketDisconnect:
        pass  # normal client close
    finally:
        # Unregister on every exit path (including unexpected errors) so the
        # manager does not keep holding a dead socket.
        ws_manager.disconnect(user_id, ws)
|
|
|
|
|
|
@router.post("/send", dependencies=[Depends(get_current_user)])
async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db)):
    """Push an ad-hoc notification over WebSocket, optionally also to WeCom.

    Payload keys (all optional, read with defaults):
      - ``user_id``: recipient when ``target_all`` is falsy.
      - ``target_all``: broadcast to every connected user instead.
      - ``title`` (default "系统通知", i.e. "System notification"),
        ``message`` (body text), ``type`` (default "info").
      - ``push_to_wecom``: also push via WeCom to ``user_id``.

    An ``AuditLog`` row recording the title and the actual target is added
    and flushed.

    Raises:
        HTTPException(400): neither ``target_all`` nor ``user_id`` is set, so
            there would be nobody to deliver to.
    """
    user_id = payload.get("user_id", "")
    target_all = payload.get("target_all", False)
    title = payload.get("title", "系统通知")
    body = payload.get("message", "")
    notify_type = payload.get("type", "info")

    # Without a target nothing would be delivered; reject rather than
    # reporting success and auditing a broadcast that never happened.
    if not target_all and not user_id:
        raise HTTPException(400, "未指定通知对象")

    msg = {
        "type": notify_type,
        "title": title,
        "message": body,
        "ts": datetime.utcnow().isoformat(),
    }

    if target_all:
        await ws_manager.broadcast(msg)
    else:
        await ws_manager.send_to_user(user_id, msg)

    if payload.get("push_to_wecom"):
        await _push_to_wecom(title, body, user_id)

    audit = AuditLog(
        # NOTE(review): assumes get_current_user stores the caller on
        # request.state.user (its return value is not used here) — confirm.
        operator_id=uuid.UUID(request.state.user["id"]),
        action="notification.send",
        resource="notification",
        # Record what was actually done: a broadcast takes precedence over user_id.
        detail={"title": title, "target": "broadcast" if target_all else user_id},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()

    return {"code": 200, "message": "已发送"}
|
|
|
|
|
|
@router.get("/templates", response_model=list[NotificationTemplateOut])
async def list_templates(request: Request, db: AsyncSession = Depends(get_db)):
    """Return every notification template, most recently created first."""
    newest_first = NotificationTemplate.created_at.desc()
    stmt = select(NotificationTemplate).order_by(newest_first)
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.post("/templates", response_model=NotificationTemplateOut)
async def create_template(
    req: NotificationTemplateCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Create and flush a new notification template built from *req*.

    The endpoint depends on ``get_current_user`` (presumably enforcing
    authentication). A template whose ``code`` is already taken is rejected.

    Raises:
        HTTPException(400): a template with ``req.code`` already exists.
    """
    duplicate_query = select(NotificationTemplate).where(NotificationTemplate.code == req.code)
    clash = (await db.execute(duplicate_query)).scalar_one_or_none()
    if clash:
        raise HTTPException(400, "模板编码已存在")

    fields = {
        "name": req.name,
        "code": req.code,
        "channel": req.channel,
        "title_template": req.title_template,
        "body_template": req.body_template,
        "variables": req.variables,
    }
    template = NotificationTemplate(**fields)
    db.add(template)
    await db.flush()
    return template
|
|
|
|
|
|
@router.get("/templates/{template_id}", response_model=NotificationTemplateOut)
async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Look up a single notification template by its id.

    Raises:
        HTTPException(404): no template has ``template_id``.
    """
    by_id = select(NotificationTemplate).where(NotificationTemplate.id == template_id)
    found = (await db.execute(by_id)).scalar_one_or_none()
    if not found:
        raise HTTPException(404, "模板不存在")
    return found
|
|
|
|
|
|
@router.delete("/templates/{template_id}")
async def delete_template(
    template_id: uuid.UUID, request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Delete a non-system notification template by id.

    Raises:
        HTTPException(404): no template has ``template_id``.
        HTTPException(400): the template is flagged ``is_system``.
    """
    by_id = select(NotificationTemplate).where(NotificationTemplate.id == template_id)
    doomed = (await db.execute(by_id)).scalar_one_or_none()

    if not doomed:
        raise HTTPException(404, "模板不存在")
    # Built-in templates are protected from deletion.
    if doomed.is_system:
        raise HTTPException(400, "系统模板不可删除")

    await db.delete(doomed)
    await db.flush()
    return {"code": 200, "message": "已删除"}
|
|
|
|
|
|
@router.get("/ws/stats")
async def ws_stats():
    """Report how many notification WebSockets are currently registered."""
    stats = {"active_connections": ws_manager.active_count}
    return {"code": 200, "data": stats}
|
|
|
|
|
|
async def _push_to_wecom(title: str, body: str, user_id: str):
    """Best-effort push of a WeCom text message ``"title\\nbody"`` to *user_id*.

    Returns immediately when WeCom credentials are not configured or when no
    *user_id* is given. Errors never propagate to the caller; they are logged
    (API error codes as warnings, exceptions with traceback).

    NOTE(review): a fresh access token is requested on every call; WeCom
    tokens are reusable until they expire and gettoken is rate-limited, so
    caching the token is worth considering.
    """
    if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET:
        return
    if not user_id:
        # No recipient (e.g. a broadcast): skip the pointless token round-trip.
        return

    import logging

    logger = logging.getLogger(__name__)
    # message/send needs the id of the sending WeCom application; the old
    # hard-coded 0 is kept only as a fallback when no setting exists.
    # NOTE(review): WECOM_AGENT_ID is an assumed settings name — add/confirm in config.
    agent_id = getattr(settings, "WECOM_AGENT_ID", 0)

    try:
        import httpx

        async with httpx.AsyncClient() as client:
            token_resp = await client.get(
                "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
                params={"corpid": settings.WECOM_CORP_ID, "corpsecret": settings.WECOM_APP_SECRET},
            )
            token_data = token_resp.json()
            access_token = token_data.get("access_token", "")
            if not access_token:
                logger.warning(
                    "WeCom gettoken failed: errcode=%s errmsg=%s",
                    token_data.get("errcode"),
                    token_data.get("errmsg"),
                )
                return

            send_resp = await client.post(
                "https://qyapi.weixin.qq.com/cgi-bin/message/send",
                # Let httpx encode the query string instead of f-string building.
                params={"access_token": access_token},
                json={
                    "touser": user_id,
                    "msgtype": "text",
                    "agentid": agent_id,
                    "text": {"content": f"{title}\n{body}"},
                },
            )
            send_data = send_resp.json()
            # WeCom reports failures in the body (errcode != 0), not via HTTP status.
            if send_data.get("errcode", 0) != 0:
                logger.warning(
                    "WeCom message/send to %s failed: errcode=%s errmsg=%s",
                    user_id,
                    send_data.get("errcode"),
                    send_data.get("errmsg"),
                )
    except Exception:
        # Deliberately best-effort: a WeCom outage must not fail the caller,
        # but it should be visible in the logs.
        logger.exception("WeCom push to %s failed", user_id)
|