"""对话模块路由。 提供基于流程的聊天功能,支持 WebSocket 实时通信和 HTTP 消息发送。 可以执行已发布的 AI 流程并将结果返回给客户端。 """ from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_db from models import FlowDefinition, FlowVersion from modules.flow_engine.engine import FlowEngine from agentscope.message import Msg from websocket_manager import ws_manager router = APIRouter(prefix="/api/chat", tags=["chat"]) @router.websocket("/ws") async def chat_websocket(websocket: WebSocket): """WebSocket 聊天连接处理器。 接受客户端的 WebSocket 连接并将其注册到 WebSocket 管理器中。 持续接收消息并回显给发送者,断开时自动清理连接。 Args: websocket: WebSocket 连接对象。 """ user_id = websocket.query_params.get("user_id", "anonymous") # 从查询参数获取用户 ID await ws_manager.connect(websocket, user_id) try: while True: data = await websocket.receive_text() await ws_manager.send_to_user(user_id, {"type": "echo", "message": data}) except WebSocketDisconnect: ws_manager.disconnect(websocket, user_id) @router.post("/message/{flow_id}") async def chat_message( flow_id: str, request: Request, payload: dict, db: AsyncSession = Depends(get_db), ): """向指定的已发布流程发送消息并获取 AI 回复。 加载流程定义后使用 FlowEngine 执行,将用户消息作为输入, 返回流程执行结果。 Args: flow_id: 流程定义的唯一标识 ID。 request: HTTP 请求对象,用于获取当前用户信息。 payload: 请求体,包含 message 或 query 字段作为输入文本。 db: 异步数据库会话。 Returns: dict: 包含 AI 回复和节点执行结果的响应数据。 Raises: HTTPException: 流程不存在、未发布或执行失败时抛出异常。 """ try: import uuid as _uuid fid = _uuid.UUID(flow_id) except ValueError: raise HTTPException(400, "无效的流ID") flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid)) flow = flow_result.scalar_one_or_none() if not flow or flow.status != "published": raise HTTPException(404, "流不存在或未发布") definition = flow.definition_json # 流程定义 JSON # 如果有已发布版本,优先使用版本的定义 published_version_id = getattr(flow, 'published_version_id', None) if published_version_id: ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == published_version_id)) published = ver_result.scalar_one_or_none() if published and published.definition_json: import json definition = json.loads(json.dumps(published.definition_json)) user_ctx = request.state.user input_text = payload.get("message", payload.get("query", "")) # 用户输入文本 if not input_text: raise HTTPException(400, "请输入消息内容") engine = FlowEngine(definition) input_msg = Msg(name="user", content=input_text, role="user") context = { "user_id": user_ctx.get("id", "web_user"), "username": user_ctx.get("username", "网页访客"), "trigger_data": {"channel": "web_chat"}, # 触发渠道为网页聊天 "_node_results": {}, # 存储各节点的执行结果 "session_id": payload.get("session_id", str(_uuid.uuid4())), } try: result_msg = await engine.execute(input_msg, context) output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) return { "code": 200, "data": { "reply": output_text, "node_results": context.get("_node_results", {}), }, } except Exception as e: raise HTTPException(500, f"流执行失败: {str(e)}") @router.get("/flows") async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)): """列出所有已发布的流程,供聊天界面选择使用。 Args: request: HTTP 请求对象。 db: 异步数据库会话。 Returns: dict: 包含已发布流程列表的响应数据。 """ result = await db.execute( select(FlowDefinition).where(FlowDefinition.status == "published") ) flows = result.scalars().all() return { "code": 200, "data": [ { "id": str(f.id), "name": f.name, "description": f.description, "published_to_web": f.published_to_web, "published_to_wecom": f.published_to_wecom, } for f in flows ], }