You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
140 lines
4.9 KiB
140 lines
4.9 KiB
"""对话模块路由。
|
|
|
|
提供基于流程的聊天功能,支持 WebSocket 实时通信和 HTTP 消息发送。
|
|
可以执行已发布的 AI 流程并将结果返回给客户端。
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import FlowDefinition, FlowVersion
|
|
from modules.flow_engine.engine import FlowEngine
|
|
from agentscope.message import Msg
|
|
from websocket_manager import ws_manager
|
|
|
|
router = APIRouter(prefix="/api/chat", tags=["chat"])
|
|
|
|
|
|
@router.websocket("/ws")
|
|
async def chat_websocket(websocket: WebSocket):
|
|
"""WebSocket 聊天连接处理器。
|
|
|
|
接受客户端的 WebSocket 连接并将其注册到 WebSocket 管理器中。
|
|
持续接收消息并回显给发送者,断开时自动清理连接。
|
|
|
|
Args:
|
|
websocket: WebSocket 连接对象。
|
|
"""
|
|
user_id = websocket.query_params.get("user_id", "anonymous") # 从查询参数获取用户 ID
|
|
await ws_manager.connect(websocket, user_id)
|
|
try:
|
|
while True:
|
|
data = await websocket.receive_text()
|
|
await ws_manager.send_to_user(user_id, {"type": "echo", "message": data})
|
|
except WebSocketDisconnect:
|
|
ws_manager.disconnect(websocket, user_id)
|
|
|
|
|
|
@router.post("/message/{flow_id}")
|
|
async def chat_message(
|
|
flow_id: str,
|
|
request: Request,
|
|
payload: dict,
|
|
db: AsyncSession = Depends(get_db),
|
|
):
|
|
"""向指定的已发布流程发送消息并获取 AI 回复。
|
|
|
|
加载流程定义后使用 FlowEngine 执行,将用户消息作为输入,
|
|
返回流程执行结果。
|
|
|
|
Args:
|
|
flow_id: 流程定义的唯一标识 ID。
|
|
request: HTTP 请求对象,用于获取当前用户信息。
|
|
payload: 请求体,包含 message 或 query 字段作为输入文本。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
dict: 包含 AI 回复和节点执行结果的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 流程不存在、未发布或执行失败时抛出异常。
|
|
"""
|
|
try:
|
|
import uuid as _uuid
|
|
fid = _uuid.UUID(flow_id)
|
|
except ValueError:
|
|
raise HTTPException(400, "无效的流ID")
|
|
|
|
flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid))
|
|
flow = flow_result.scalar_one_or_none()
|
|
if not flow or flow.status != "published":
|
|
raise HTTPException(404, "流不存在或未发布")
|
|
|
|
definition = flow.definition_json # 流程定义 JSON
|
|
# 如果有已发布版本,优先使用版本的定义
|
|
published_version_id = getattr(flow, 'published_version_id', None)
|
|
if published_version_id:
|
|
ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == published_version_id))
|
|
published = ver_result.scalar_one_or_none()
|
|
if published and published.definition_json:
|
|
import json
|
|
definition = json.loads(json.dumps(published.definition_json))
|
|
|
|
user_ctx = request.state.user
|
|
input_text = payload.get("message", payload.get("query", "")) # 用户输入文本
|
|
if not input_text:
|
|
raise HTTPException(400, "请输入消息内容")
|
|
|
|
engine = FlowEngine(definition)
|
|
input_msg = Msg(name="user", content=input_text, role="user")
|
|
context = {
|
|
"user_id": user_ctx.get("id", "web_user"),
|
|
"username": user_ctx.get("username", "网页访客"),
|
|
"trigger_data": {"channel": "web_chat"}, # 触发渠道为网页聊天
|
|
"_node_results": {}, # 存储各节点的执行结果
|
|
"session_id": payload.get("session_id", str(_uuid.uuid4())),
|
|
}
|
|
|
|
try:
|
|
result_msg = await engine.execute(input_msg, context)
|
|
output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg)
|
|
|
|
return {
|
|
"code": 200,
|
|
"data": {
|
|
"reply": output_text,
|
|
"node_results": context.get("_node_results", {}),
|
|
},
|
|
}
|
|
except Exception as e:
|
|
raise HTTPException(500, f"流执行失败: {str(e)}")
|
|
|
|
|
|
@router.get("/flows")
|
|
async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)):
|
|
"""列出所有已发布的流程,供聊天界面选择使用。
|
|
|
|
Args:
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
dict: 包含已发布流程列表的响应数据。
|
|
"""
|
|
result = await db.execute(
|
|
select(FlowDefinition).where(FlowDefinition.status == "published")
|
|
)
|
|
flows = result.scalars().all()
|
|
return {
|
|
"code": 200,
|
|
"data": [
|
|
{
|
|
"id": str(f.id),
|
|
"name": f.name,
|
|
"description": f.description,
|
|
"published_to_web": f.published_to_web,
|
|
"published_to_wecom": f.published_to_wecom,
|
|
}
|
|
for f in flows
|
|
],
|
|
}
|
|
|