You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

140 lines
4.9 KiB

"""对话模块路由。
提供基于流程的聊天功能,支持 WebSocket 实时通信和 HTTP 消息发送。
可以执行已发布的 AI 流程并将结果返回给客户端。
"""
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import FlowDefinition, FlowVersion
from modules.flow_engine.engine import FlowEngine
from agentscope.message import Msg
from websocket_manager import ws_manager
router = APIRouter(prefix="/api/chat", tags=["chat"])
@router.websocket("/ws")
async def chat_websocket(websocket: WebSocket):
"""WebSocket 聊天连接处理器。
接受客户端的 WebSocket 连接并将其注册到 WebSocket 管理器中。
持续接收消息并回显给发送者,断开时自动清理连接。
Args:
websocket: WebSocket 连接对象。
"""
user_id = websocket.query_params.get("user_id", "anonymous") # 从查询参数获取用户 ID
await ws_manager.connect(websocket, user_id)
try:
while True:
data = await websocket.receive_text()
await ws_manager.send_to_user(user_id, {"type": "echo", "message": data})
except WebSocketDisconnect:
ws_manager.disconnect(websocket, user_id)
@router.post("/message/{flow_id}")
async def chat_message(
flow_id: str,
request: Request,
payload: dict,
db: AsyncSession = Depends(get_db),
):
"""向指定的已发布流程发送消息并获取 AI 回复。
加载流程定义后使用 FlowEngine 执行,将用户消息作为输入,
返回流程执行结果。
Args:
flow_id: 流程定义的唯一标识 ID。
request: HTTP 请求对象,用于获取当前用户信息。
payload: 请求体,包含 message 或 query 字段作为输入文本。
db: 异步数据库会话。
Returns:
dict: 包含 AI 回复和节点执行结果的响应数据。
Raises:
HTTPException: 流程不存在、未发布或执行失败时抛出异常。
"""
try:
import uuid as _uuid
fid = _uuid.UUID(flow_id)
except ValueError:
raise HTTPException(400, "无效的流ID")
flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid))
flow = flow_result.scalar_one_or_none()
if not flow or flow.status != "published":
raise HTTPException(404, "流不存在或未发布")
definition = flow.definition_json # 流程定义 JSON
# 如果有已发布版本,优先使用版本的定义
published_version_id = getattr(flow, 'published_version_id', None)
if published_version_id:
ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == published_version_id))
published = ver_result.scalar_one_or_none()
if published and published.definition_json:
import json
definition = json.loads(json.dumps(published.definition_json))
user_ctx = request.state.user
input_text = payload.get("message", payload.get("query", "")) # 用户输入文本
if not input_text:
raise HTTPException(400, "请输入消息内容")
engine = FlowEngine(definition)
input_msg = Msg(name="user", content=input_text, role="user")
context = {
"user_id": user_ctx.get("id", "web_user"),
"username": user_ctx.get("username", "网页访客"),
"trigger_data": {"channel": "web_chat"}, # 触发渠道为网页聊天
"_node_results": {}, # 存储各节点的执行结果
"session_id": payload.get("session_id", str(_uuid.uuid4())),
}
try:
result_msg = await engine.execute(input_msg, context)
output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg)
return {
"code": 200,
"data": {
"reply": output_text,
"node_results": context.get("_node_results", {}),
},
}
except Exception as e:
raise HTTPException(500, f"流执行失败: {str(e)}")
@router.get("/flows")
async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)):
"""列出所有已发布的流程,供聊天界面选择使用。
Args:
request: HTTP 请求对象。
db: 异步数据库会话。
Returns:
dict: 包含已发布流程列表的响应数据。
"""
result = await db.execute(
select(FlowDefinition).where(FlowDefinition.status == "published")
)
flows = result.scalars().all()
return {
"code": 200,
"data": [
{
"id": str(f.id),
"name": f.name,
"description": f.description,
"published_to_web": f.published_to_web,
"published_to_wecom": f.published_to_wecom,
}
for f in flows
],
}