You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.4 KiB
100 lines
3.4 KiB
import json
import logging
import uuid

from agentscope.message import Msg
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from models import FlowDefinition, FlowVersion
from modules.flow_engine.engine import FlowEngine
from websocket_manager import ws_manager
|
|
|
|
# Router for the chat endpoints; each route path declared on it below is
# relative to the "/api/chat" prefix and grouped under the "chat" tag.
router = APIRouter(tags=["chat"], prefix="/api/chat")
|
|
|
|
|
|
@router.websocket("/ws")
async def chat_websocket(websocket: WebSocket):
    """Echo each received text frame back through ``ws_manager``.

    The caller is identified by the ``user_id`` query parameter
    (``"anonymous"`` when absent). The socket is registered with
    ``ws_manager.connect`` before the receive loop starts and is
    unregistered with ``ws_manager.disconnect`` on every exit path: a
    normal client disconnect and any unexpected error alike.
    """
    # NOTE(review): user_id comes straight from the query string with no
    # authentication, so a client can register under any user's id and
    # receive messages addressed to that user. Confirm whether this endpoint
    # should share the HTTP routes' auth.
    user_id = websocket.query_params.get("user_id", "anonymous")
    await ws_manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            # Addressed by user id rather than by this socket; presumably
            # reaches every connection registered for user_id (TODO confirm).
            await ws_manager.send_to_user(user_id, {"type": "echo", "message": data})
    except WebSocketDisconnect:
        # Client closed the connection normally; cleanup happens below.
        pass
    finally:
        # Unregister on every exit path so an unexpected error (from
        # receive_text or send_to_user) cannot leave a stale registration.
        ws_manager.disconnect(websocket, user_id)
|
|
|
|
|
|
async def _load_published_definition(db: AsyncSession, fid: uuid.UUID):
    """Return the flow definition to execute for the published flow ``fid``.

    Prefers the definition stored on the flow's published ``FlowVersion``
    (returned as a JSON round-trip copy). Falls back to the flow's own
    ``definition_json`` when no published version is linked or the linked
    version is missing or has an empty definition.

    Raises:
        HTTPException(404): the flow does not exist or is not published.
    """
    flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid))
    flow = flow_result.scalar_one_or_none()
    if flow is None or flow.status != "published":
        raise HTTPException(404, "流不存在或未发布")

    # NOTE(review): this fallback serves the flow's current definition_json,
    # which may contain unpublished edits when published_version_id is unset
    # or dangling. Confirm this is intended.
    definition = flow.definition_json
    published_version_id = getattr(flow, "published_version_id", None)
    if published_version_id:
        ver_result = await db.execute(
            select(FlowVersion).where(FlowVersion.id == published_version_id)
        )
        published = ver_result.scalar_one_or_none()
        if published and published.definition_json:
            # JSON round-trip acts as a deep copy, so the engine never works
            # on the ORM-loaded object itself.
            definition = json.loads(json.dumps(published.definition_json))
    return definition


@router.post("/message/{flow_id}")
async def chat_message(
    flow_id: str,
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
):
    """Run a published flow on one chat message and return its reply.

    ``flow_id`` must be a UUID string. ``payload`` is the request body; the
    message text is read from ``"message"``, falling back to ``"query"``.
    An optional ``"session_id"`` is forwarded to the engine context, and a
    new one is generated when it is missing, null or empty.

    Returns:
        ``{"code": 200, "data": {"reply": <text>, "node_results": <dict>}}``.

    Raises:
        HTTPException(400): ``flow_id`` is not a UUID, or the message is empty.
        HTTPException(404): the flow does not exist or is not published.
        HTTPException(500): the flow engine raised while executing.
    """
    try:
        fid = uuid.UUID(flow_id)
    except ValueError:
        raise HTTPException(400, "无效的流ID") from None

    definition = await _load_published_definition(db, fid)

    # The defaults below ("web_user" / web visitor) show anonymous callers are
    # expected, so a missing or empty request.state.user must not crash with
    # AttributeError. Presumably set by auth middleware as a dict; verify.
    user_ctx = getattr(request.state, "user", None) or {}

    input_text = payload.get("message", payload.get("query", ""))
    if not input_text:
        raise HTTPException(400, "请输入消息内容")

    engine = FlowEngine(definition)
    input_msg = Msg(name="user", content=input_text, role="user")
    context = {
        "user_id": user_ctx.get("id", "web_user"),
        "username": user_ctx.get("username", "网页访客"),
        "trigger_data": {"channel": "web_chat"},
        # Presumably populated by the engine during execute; returned below.
        "_node_results": {},
        # Start a fresh session whenever the client did not send a usable id.
        "session_id": payload.get("session_id") or str(uuid.uuid4()),
    }

    try:
        result_msg = await engine.execute(input_msg, context)
        output_text = (
            result_msg.get_text_content()
            if hasattr(result_msg, "get_text_content")
            else str(result_msg)
        )
    except Exception as e:
        # Request boundary: FastAPI does not log HTTPException, so record the
        # traceback here before turning the failure into a 500 response.
        logging.getLogger(__name__).exception("Flow %s execution failed", flow_id)
        raise HTTPException(500, f"流执行失败: {str(e)}") from e

    return {
        "code": 200,
        "data": {
            "reply": output_text,
            "node_results": context.get("_node_results", {}),
        },
    }
|
|
|
|
|
|
@router.get("/flows")
async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)):
    """List every flow whose status is ``"published"``.

    ``request`` is accepted but not read here. Returns
    ``{"code": 200, "data": [...]}``. Each item carries the flow's id as a
    string, its name and description, and its ``published_to_web`` /
    ``published_to_wecom`` flags.
    """
    published_only = select(FlowDefinition).where(FlowDefinition.status == "published")
    rows = (await db.execute(published_only)).scalars().all()

    def summarize(flow):
        # One list entry per flow, keys in the same order as the response schema.
        return {
            "id": str(flow.id),
            "name": flow.name,
            "description": flow.description,
            "published_to_web": flow.published_to_web,
            "published_to_wecom": flow.published_to_wecom,
        }

    return {"code": 200, "data": [summarize(flow) for flow in rows]}
|