"""Chat API: a WebSocket echo channel and HTTP entry points for running published flows."""

import json
import logging
import uuid

from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from models import FlowDefinition, FlowVersion
from modules.flow_engine.engine import FlowEngine
from agentscope.message import Msg
from websocket_manager import ws_manager

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/chat", tags=["chat"])


@router.websocket("/ws")
async def chat_websocket(websocket: WebSocket):
    """Register the socket with ``ws_manager`` and echo every received text frame.

    The user is identified by the ``user_id`` query parameter (default
    ``"anonymous"``). Each text message is sent back to that user as
    ``{"type": "echo", "message": <text>}``.
    """
    user_id = websocket.query_params.get("user_id", "anonymous")
    await ws_manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            await ws_manager.send_to_user(user_id, {"type": "echo", "message": data})
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    finally:
        # Unregister on *any* exit path so a failed send or other error does
        # not leave a stale connection registered in ws_manager.
        ws_manager.disconnect(websocket, user_id)


@router.post("/message/{flow_id}")
async def chat_message(
    flow_id: str,
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
):
    """Run a published flow on one chat message and return its reply.

    Args:
        flow_id: UUID string of the ``FlowDefinition`` to run.
        request: Incoming request; ``request.state.user`` (if set) supplies
            the caller's ``id`` / ``username``.
        payload: JSON body. The text is taken from ``message`` (falling back to
            ``query``); ``session_id`` is optional and a random UUID4 string is
            generated when the key is absent.
        db: Async database session.

    Returns:
        ``{"code": 200, "data": {"reply": str, "node_results": dict}}``.

    Raises:
        HTTPException: 400 for a malformed ``flow_id`` or empty message,
            404 when the flow is missing or not published, 500 when the flow
            engine raises.
    """
    try:
        fid = uuid.UUID(flow_id)
    except ValueError:
        raise HTTPException(400, "无效的流ID") from None

    flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid))
    flow = flow_result.scalar_one_or_none()
    if not flow or flow.status != "published":
        raise HTTPException(404, "流不存在或未发布")

    # Prefer the definition snapshot of the published version when one exists.
    definition = flow.definition_json
    published_version_id = getattr(flow, "published_version_id", None)
    if published_version_id:
        ver_result = await db.execute(
            select(FlowVersion).where(FlowVersion.id == published_version_id)
        )
        published = ver_result.scalar_one_or_none()
        if published and published.definition_json:
            # JSON round-trip yields an independent deep copy, so the engine
            # cannot mutate the ORM-loaded definition in place.
            definition = json.loads(json.dumps(published.definition_json))

    # request.state raises AttributeError for unset attributes; tolerate a
    # missing/None user so the defaults below apply instead of a 500.
    user_ctx = getattr(request.state, "user", None) or {}

    input_text = payload.get("message", payload.get("query", ""))
    if not input_text:
        raise HTTPException(400, "请输入消息内容")

    engine = FlowEngine(definition)
    input_msg = Msg(name="user", content=input_text, role="user")
    # Only generate a new session id when the client did not send one.
    session_id = payload["session_id"] if "session_id" in payload else str(uuid.uuid4())
    context = {
        "user_id": user_ctx.get("id", "web_user"),
        "username": user_ctx.get("username", "网页访客"),
        "trigger_data": {"channel": "web_chat"},
        "_node_results": {},
        "session_id": session_id,
    }

    try:
        result_msg = await engine.execute(input_msg, context)
    except Exception as e:
        # Boundary handler: the engine runs arbitrary flow nodes; log the
        # traceback and surface a 500 with the cause chained.
        logger.exception("Flow %s execution failed", flow_id)
        raise HTTPException(500, f"流执行失败: {str(e)}") from e

    output_text = (
        result_msg.get_text_content()
        if hasattr(result_msg, "get_text_content")
        else str(result_msg)
    )
    return {
        "code": 200,
        "data": {
            "reply": output_text,
            "node_results": context.get("_node_results", {}),
        },
    }


@router.get("/flows")
async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)):
    """List every flow whose status is ``"published"``.

    Returns:
        ``{"code": 200, "data": [...]}`` where each item carries the flow's
        ``id`` (as str), ``name``, ``description`` and its
        ``published_to_web`` / ``published_to_wecom`` flags.
    """
    result = await db.execute(
        select(FlowDefinition).where(FlowDefinition.status == "published")
    )
    flows = result.scalars().all()
    return {
        "code": 200,
        "data": [
            {
                "id": str(f.id),
                "name": f.name,
                "description": f.description,
                "published_to_web": f.published_to_web,
                "published_to_wecom": f.published_to_wecom,
            }
            for f in flows
        ],
    }