You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

99 lines
3.4 KiB

import json
import uuid

from agentscope.message import Msg
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from models import FlowDefinition, FlowVersion
from modules.flow_engine.engine import FlowEngine
from websocket_manager import ws_manager
# Router for the chat endpoints: every route below is mounted under /api/chat
# and tagged "chat" in the generated OpenAPI schema.
router = APIRouter(
    prefix="/api/chat",
    tags=["chat"],
)
@router.websocket("/ws")
async def chat_websocket(websocket: WebSocket):
    """Echo WebSocket endpoint registered with ``ws_manager`` under a user id.

    The user id is read from the ``user_id`` query parameter (default
    ``"anonymous"``). Each received text frame is passed to
    ``ws_manager.send_to_user`` as ``{"type": "echo", "message": <text>}``.
    The socket is unregistered from ``ws_manager`` on every exit path.
    """
    # NOTE(review): user_id is client-supplied and not authenticated here, so a
    # client can register under any user id — confirm auth happens elsewhere.
    user_id = websocket.query_params.get("user_id", "anonymous")
    await ws_manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            await ws_manager.send_to_user(user_id, {"type": "echo", "message": data})
    except WebSocketDisconnect:
        # Normal client-initiated close; nothing to propagate.
        pass
    finally:
        # Unregister on every exit path (not only WebSocketDisconnect) so an
        # unexpected error does not leave a stale connection in ws_manager.
        ws_manager.disconnect(websocket, user_id)
@router.post("/message/{flow_id}")
async def chat_message(
    flow_id: str,
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
):
    """Run one chat turn through a published flow and return its reply.

    Args:
        flow_id: UUID string of the ``FlowDefinition`` to execute.
        request: Incoming request; ``request.state.user`` (a dict, when set)
            supplies the caller's ``id`` and ``username``. Missing values fall
            back to ``"web_user"`` / ``"网页访客"``.
        payload: JSON body. The user text is taken from ``"message"``, falling
            back to ``"query"``. An optional ``"session_id"`` is forwarded to
            the engine context; a fresh UUID4 string is used when absent.
        db: Async database session.

    Returns:
        ``{"code": 200, "data": {"reply": <text>, "node_results": <dict>}}``.

    Raises:
        HTTPException: 400 for a malformed ``flow_id`` or empty input,
            404 if the flow is missing or not published,
            500 if building or executing the flow fails.
    """
    try:
        fid = uuid.UUID(flow_id)
    except ValueError:
        raise HTTPException(400, "无效的流ID") from None

    flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid))
    flow = flow_result.scalar_one_or_none()
    if not flow or flow.status != "published":
        raise HTTPException(404, "流不存在或未发布")
    # NOTE(review): this web-chat endpoint does not check flow.published_to_web
    # (exposed by list_chat_flows) — confirm WeCom-only flows should be reachable.

    # Prefer the pinned published version's definition; otherwise fall back to
    # the flow's own definition_json.
    source_definition = flow.definition_json
    if flow.published_version_id:
        ver_result = await db.execute(
            select(FlowVersion).where(FlowVersion.id == flow.published_version_id)
        )
        published = ver_result.scalar_one_or_none()
        if published and published.definition_json:
            source_definition = published.definition_json
    # Deep-copy via a JSON round-trip on both paths so the engine never mutates
    # the JSON object held by the ORM instance.
    definition = json.loads(json.dumps(source_definition))

    # request.state has no "user" attribute unless middleware set it.
    user_ctx = getattr(request.state, "user", None) or {}

    # Fall back to "query" when "message" is missing OR empty; reject blank input.
    input_text = payload.get("message") or payload.get("query")
    if not input_text or (isinstance(input_text, str) and not input_text.strip()):
        raise HTTPException(400, "请输入消息内容")

    # Generate a session id only when the client did not send one (the old
    # eager default also referenced an unbound name `uuid`).
    session_id = payload.get("session_id")
    if session_id is None:
        session_id = str(uuid.uuid4())

    input_msg = Msg(name="user", content=input_text, role="user")
    context = {
        "user_id": user_ctx.get("id", "web_user"),
        "username": user_ctx.get("username", "网页访客"),
        "trigger_data": {"channel": "web_chat"},
        "_node_results": {},
        "session_id": session_id,
    }

    try:
        # Engine construction is inside the try so a bad definition is reported
        # as a flow-execution failure rather than an unhandled error.
        engine = FlowEngine(definition)
        result_msg = await engine.execute(input_msg, context)
        output_text = (
            result_msg.get_text_content()
            if hasattr(result_msg, "get_text_content")
            else str(result_msg)
        )
    except Exception as e:
        # NOTE(review): the raw exception text is returned to the client; consider
        # logging it server-side and returning a generic message instead.
        raise HTTPException(500, f"流执行失败: {e}") from e

    return {
        "code": 200,
        "data": {
            "reply": output_text,
            "node_results": context.get("_node_results", {}),
        },
    }
def _flow_summary(flow: FlowDefinition) -> dict:
    """Build the public summary dict for one flow as served by ``/flows``."""
    return {
        "id": str(flow.id),
        "name": flow.name,
        "description": flow.description,
        "published_to_web": flow.published_to_web,
        "published_to_wecom": flow.published_to_wecom,
    }


@router.get("/flows")
async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)):
    """List every flow whose status is ``"published"``.

    ``request`` is accepted but unused. Each item carries the flow id (as a
    string), name, description, and its web / WeCom publish flags.
    """
    is_published = FlowDefinition.status == "published"
    rows = await db.execute(select(FlowDefinition).where(is_published))
    summaries = [_flow_summary(flow) for flow in rows.scalars().all()]
    return {"code": 200, "data": summaries}