You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
256 lines
10 KiB
256 lines
10 KiB
import uuid
|
|
import time
|
|
import json
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, Query
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import FlowDefinition, FlowVersion, FlowExecution
|
|
from schemas import FlowChatMessageRequest
|
|
from modules.flow_engine.engine import FlowEngine
|
|
from agentscope.message import Msg
|
|
from middleware.apikey_auth import authenticate_api_key
|
|
from dependencies import get_current_user
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
gateway_router = APIRouter(prefix="/v1", tags=["gateway"])
|
|
|
|
|
|
async def _resolve_auth(request: Request) -> dict:
    """Authenticate the caller with either a flow API key or a user token.

    Requests whose ``Authorization`` header starts with ``Bearer flow-`` are
    delegated to ``authenticate_api_key`` and its result is returned as-is.
    Every other request is resolved through ``get_current_user``.

    Args:
        request: Incoming request whose headers carry the credentials.

    Returns:
        Whatever ``authenticate_api_key`` returns for API keys, or
        ``{"user": <user>, "auth_type": "jwt"}`` for user tokens.

    Raises:
        HTTPException: 401 when ``get_current_user`` fails for any reason.
            Errors raised by ``authenticate_api_key`` propagate unchanged.
    """
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith("Bearer flow-"):
        return await authenticate_api_key(request)

    try:
        user = await get_current_user(request)
    except Exception as exc:
        # Every failure is reported to the client as a 401, so log the real
        # cause here; otherwise it would be lost entirely.
        logger.warning("JWT authentication failed: %s", exc)
        raise HTTPException(401, "认证失败: 请使用 Bearer Token 或 API Key") from exc

    return {"user": user, "auth_type": "jwt"}
|
|
|
|
|
|
async def _get_definition_for_execute(flow_id: uuid.UUID, db: AsyncSession) -> dict:
    """Load the executable definition of a published flow.

    The definition of the published ``FlowVersion`` is preferred; when the
    flow has no ``published_version_id`` or that version row cannot be found,
    the flow's own ``definition_json`` is used instead.

    Args:
        flow_id: Primary key of the ``FlowDefinition`` to execute.
        db: Session used for the lookups.

    Returns:
        A JSON round-tripped copy of the definition, detached from the ORM
        object, so callers may mutate it without touching session state.

    Raises:
        HTTPException: 404 when the flow does not exist, 400 when its status
            is not ``"published"``.
    """
    flow = await db.get(FlowDefinition, flow_id)
    if not flow:
        raise HTTPException(404, "流不存在")
    if flow.status != "published":
        raise HTTPException(400, "流未发布")

    source = flow
    if flow.published_version_id:
        result = await db.execute(
            select(FlowVersion).where(FlowVersion.id == flow.published_version_id)
        )
        published = result.scalar_one_or_none()
        if published:
            source = published

    # Copy on BOTH paths: previously only the published-version path returned
    # a copy, while the fallback handed out the live ORM attribute.
    return json.loads(json.dumps(source.definition_json))
|
|
|
|
|
|
# ============================== 对话型流 ==============================
|
|
|
|
|
|
@gateway_router.post("/chat-messages")
async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)):
    """Run a conversational flow in blocking or streaming mode.

    Fields read from the JSON body: ``query``, ``response_mode`` (defaults to
    ``"blocking"``), ``inputs``, ``user`` (defaults to ``"anonymous"``),
    ``conversation_id`` / ``session_id`` and ``flow_id`` (also accepted
    inside ``inputs``).

    Returns:
        A ``StreamingResponse`` of server-sent events when ``response_mode``
        is ``"streaming"``, otherwise the dict built by ``_chat_blocking``.

    Raises:
        HTTPException: 400 for a malformed body, a missing or malformed
            ``flow_id``; 401/404/400 from the auth and definition helpers;
            500 when blocking execution fails.
    """
    auth = await _resolve_auth(request)

    try:
        body = await request.json()
    except ValueError as exc:
        # JSONDecodeError / UnicodeDecodeError are both ValueError subclasses.
        raise HTTPException(400, "请求体不是合法的 JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")

    query = body.get("query", "")
    response_mode = body.get("response_mode", "blocking")
    inputs = body.get("inputs", {})
    user = body.get("user", "anonymous")
    session_id = body.get("conversation_id", body.get("session_id"))

    # Only consult ``inputs`` for the id when it is actually a mapping.
    nested_flow_id = inputs.get("flow_id") if isinstance(inputs, dict) else None
    flow_id_str = body.get("flow_id") or nested_flow_id
    if not flow_id_str:
        raise HTTPException(400, "缺少 flow_id")

    try:
        flow_id = uuid.UUID(str(flow_id_str))
    except ValueError as exc:
        # A malformed id is a client error, not a server crash.
        raise HTTPException(400, "flow_id 格式无效") from exc

    definition = await _get_definition_for_execute(flow_id, db)
    # The helper returns only the definition dict; the ORM row is still needed
    # for ``id`` / ``version`` when recording the execution.
    f = await db.get(FlowDefinition, flow_id)

    input_text = query
    if inputs:
        extra = json.dumps(inputs, ensure_ascii=False)
        input_text = f"{query}\n\n上下文数据:\n{extra}" if query else extra

    # NOTE(review): assumes the JWT ``user`` value supports ``.get`` — confirm
    # against what ``get_current_user`` returns.
    user_id = "api" if auth.get("auth_type") == "api_key" else auth.get("user", {}).get("id", "api")
    username = user

    if response_mode == "streaming":
        return await _chat_stream(
            flow_id, definition, input_text, user_id, username, f, db, session_id=session_id
        )

    return await _chat_blocking(
        flow_id, definition, input_text, user_id, username, f, db, session_id=session_id
    )


async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
    """Execute a flow to completion and return a single message payload.

    Args:
        flow_id: Id of the flow (kept for signature parity with
            ``_chat_stream``; not read here).
        definition: Flow definition passed to ``FlowEngine``.
        input_text: Text sent to the engine as the user message.
        user_id: Value stored under ``user_id`` in the engine context.
        username: Value stored under ``username`` in the engine context.
        flow: ``FlowDefinition`` row supplying ``id`` and ``version``.
        db: Session the ``FlowExecution`` record is added to.
        session_id: Client-supplied conversation id echoed back as
            ``conversation_id``. Previously this name was read without ever
            being defined, so every call failed with ``NameError``.

    Returns:
        Dict with ``event``, ``id``, ``answer``, ``conversation_id``,
        ``created_at`` and ``metadata``.

    Raises:
        HTTPException: 500 when the engine raises; a failed ``FlowExecution``
            is added to the session first.
    """
    engine = FlowEngine(definition)
    input_msg = Msg(name="user", content=input_text, role="user")
    # NOTE(review): the engine context always gets a fresh session id; the
    # client-supplied one is only echoed in the response — confirm intent.
    context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())}

    start_time = time.time()
    try:
        result_msg = await engine.execute(input_msg, context)
        elapsed_ms = int((time.time() - start_time) * 1000)
        output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg)
    except Exception as e:
        elapsed_ms = int((time.time() - start_time) * 1000)
        # NOTE(review): this row is only added, never committed here; whether
        # it survives the raised HTTPException depends on ``get_db``.
        db.add(FlowExecution(
            flow_id=flow.id, version=flow.version,
            trigger_type="api", input_data={"query": input_text},
            status="failed", latency_ms=elapsed_ms,
            error_message=str(e)[:2000], finished_at=datetime.utcnow(),
        ))
        raise HTTPException(500, f"流执行失败: {str(e)}") from e

    db.add(FlowExecution(
        flow_id=flow.id, version=flow.version,
        trigger_type="api", input_data={"query": input_text},
        output_data={"output": output_text}, status="completed",
        latency_ms=elapsed_ms, finished_at=datetime.utcnow(),
    ))

    return {
        "event": "message",
        "id": str(uuid.uuid4()),
        "answer": output_text,
        "conversation_id": session_id or "",
        "created_at": int(time.time()),
        "metadata": {
            "usage": {"latency_ms": elapsed_ms},
            # Node results are stringified and truncated to 200 characters.
            "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()},
        },
    }


async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
    """Execute a flow and deliver the result as server-sent events.

    The engine runs to completion (bounded by ``engine.FLOW_TIMEOUT_SECONDS``)
    and its text output is then emitted in 10-character ``message`` events,
    framed by ``workflow_started`` and ``message_end`` and terminated by
    ``data: [DONE]``. Failures are reported as an ``error`` event.

    Args:
        flow_id: Id of the flow, reported in the ``workflow_started`` event.
        definition: Flow definition passed to ``FlowEngine``.
        input_text: Text sent to the engine as the user message.
        user_id: Value stored under ``user_id`` in the engine context.
        username: Value stored under ``username`` in the engine context.
        flow: ``FlowDefinition`` row supplying ``id`` and ``version``.
        db: Session the ``FlowExecution`` record is added to.
        session_id: Client-supplied conversation id echoed in ``message_end``.
            Previously read without being defined, which turned every stream
            into an ``error`` event.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    # Local import: the module does not import asyncio at file level.
    import asyncio

    def sse(payload):
        """Format one payload as a server-sent-event data frame."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator():
        engine = FlowEngine(definition)
        context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())}
        input_msg = Msg(name="user", content=input_text, role="user")
        start_time = time.time()
        msg_id = str(uuid.uuid4())

        def record_failure(message):
            """Add a failed execution row, mirroring the blocking path."""
            db.add(FlowExecution(
                flow_id=flow.id, version=flow.version,
                trigger_type="api", input_data={"query": input_text},
                status="failed", latency_ms=int((time.time() - start_time) * 1000),
                error_message=message[:2000], finished_at=datetime.utcnow(),
            ))

        try:
            yield sse({'event': 'workflow_started', 'task_id': msg_id, 'data': {'flow_id': str(flow_id)}})

            result_msg = await asyncio.wait_for(engine.execute(input_msg, context), timeout=engine.FLOW_TIMEOUT_SECONDS)
            output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg)
            elapsed_ms = int((time.time() - start_time) * 1000)

            for i in range(0, len(output_text), 10):
                chunk = output_text[i:i + 10]
                yield sse({'event': 'message', 'task_id': msg_id, 'answer': chunk, 'created_at': int(time.time())})

            yield sse({
                'event': 'message_end',
                'task_id': msg_id,
                'id': msg_id,
                'conversation_id': session_id or '',
                'metadata': {
                    'usage': {'latency_ms': elapsed_ms},
                    'node_results': {k: str(v)[:200] for k, v in context.get('_node_results', {}).items()},
                },
            })

            # NOTE(review): added but not committed here; persistence depends
            # on when ``get_db`` finalises the session relative to streaming.
            db.add(FlowExecution(
                flow_id=flow.id, version=flow.version,
                trigger_type="api", input_data={"query": input_text},
                output_data={"output": output_text},
                status="completed", latency_ms=elapsed_ms,
                finished_at=datetime.utcnow(),
            ))
        except asyncio.TimeoutError:
            record_failure('执行超时')
            yield sse({'event': 'error', 'task_id': msg_id, 'message': '执行超时'})
        except Exception as e:
            record_failure(str(e))
            yield sse({'event': 'error', 'task_id': msg_id, 'message': str(e)})

        # Emitted after the try block rather than in ``finally``: yielding
        # while the generator is being closed or cancelled (client
        # disconnect) is an error for async generators.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
# ============================== 工作流型流 ==============================
|
|
|
|
|
|
@gateway_router.post("/workflows/run")
async def workflows_run(request: Request, db: AsyncSession = Depends(get_db)):
    """Run a workflow-type flow to completion and return its output.

    Fields read from the JSON body: ``inputs``, ``user`` (defaults to
    ``"anonymous"``) and ``workflow_id`` (also accepted inside ``inputs`` as
    ``workflow_id`` or ``flow_id``). The serialized ``inputs`` become the
    engine's user message and are also exposed as ``trigger_data`` in the
    engine context.

    Returns:
        Dict with ``id``, ``workflow_run_id``, ``data`` (``outputs`` and
        truncated ``node_results``) and ``metadata``.

    Raises:
        HTTPException: 400 for a malformed body, a missing or malformed
            ``workflow_id``; 401/404/400 from the auth and definition
            helpers; 500 when the engine raises (a failed ``FlowExecution``
            is added to the session first).
    """
    auth = await _resolve_auth(request)

    try:
        body = await request.json()
    except ValueError as exc:
        # JSONDecodeError / UnicodeDecodeError are both ValueError subclasses.
        raise HTTPException(400, "请求体不是合法的 JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")

    inputs = body.get("inputs", {})
    user = body.get("user", "anonymous")

    # Only consult ``inputs`` for the id when it is actually a mapping.
    nested_id = None
    if isinstance(inputs, dict):
        nested_id = inputs.get("workflow_id") or inputs.get("flow_id")
    flow_id_str = body.get("workflow_id") or nested_id
    if not flow_id_str:
        raise HTTPException(400, "缺少 workflow_id")

    try:
        flow_id = uuid.UUID(str(flow_id_str))
    except ValueError as exc:
        # A malformed id is a client error, not a server crash.
        raise HTTPException(400, "workflow_id 格式无效") from exc

    definition = await _get_definition_for_execute(flow_id, db)
    # The helper returns only the definition dict; the ORM row is still needed
    # for ``id`` / ``version`` when recording the execution.
    f = await db.get(FlowDefinition, flow_id)

    # NOTE(review): assumes the JWT ``user`` value supports ``.get`` — confirm
    # against what ``get_current_user`` returns.
    user_id = "api" if auth.get("auth_type") == "api_key" else auth.get("user", {}).get("id", "api")

    engine = FlowEngine(definition)
    input_msg = Msg(name="user", content=json.dumps(inputs, ensure_ascii=False), role="user")
    context = {"user_id": user_id, "username": user, "_node_results": {}, "trigger_data": inputs}

    start_time = time.time()
    try:
        result_msg = await engine.execute(input_msg, context)
        elapsed_ms = int((time.time() - start_time) * 1000)
        output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg)
    except Exception as e:
        elapsed_ms = int((time.time() - start_time) * 1000)
        # NOTE(review): this row is only added, never committed here; whether
        # it survives the raised HTTPException depends on ``get_db``.
        db.add(FlowExecution(
            flow_id=f.id, version=f.version,
            trigger_type="api", input_data={"inputs": inputs},
            status="failed", latency_ms=elapsed_ms,
            error_message=str(e)[:2000], finished_at=datetime.utcnow(),
        ))
        raise HTTPException(500, f"工作流执行失败: {str(e)}") from e

    db.add(FlowExecution(
        flow_id=f.id, version=f.version,
        trigger_type="api", input_data={"inputs": inputs},
        output_data={"output": output_text}, status="completed",
        latency_ms=elapsed_ms, finished_at=datetime.utcnow(),
    ))

    return {
        "id": str(uuid.uuid4()),
        "workflow_run_id": str(uuid.uuid4()),
        "data": {
            "outputs": {"text": output_text},
            # Node results are stringified and truncated to 200 characters.
            "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()},
        },
        "metadata": {"latency_ms": elapsed_ms},
    }
|
|
|
|
|
|
# ============================== 参数信息 ==============================
|
|
|
|
|
|
@gateway_router.get("/flows/{flow_id}/parameters")
async def get_flow_parameters(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Describe the input variables and size of a published flow.

    When the definition contains at least one node of type ``"trigger"``, a
    fixed pair of variables (``query`` and ``session_id``) is reported;
    otherwise the variable list is empty. The trigger node's own ``config``
    is not consulted.

    Args:
        flow_id: Id of the flow to describe.
        request: Incoming request (currently unused).
        db: Session used to load the definition.

    Returns:
        ``{"code": 200, "data": {...}}`` with ``input_variables``,
        ``node_count`` and ``edge_count``.

    Raises:
        HTTPException: 404/400 from ``_get_definition_for_execute``.
    """
    # NOTE(review): unlike the other gateway endpoints this one never calls
    # ``_resolve_auth`` — confirm that it is meant to be public.
    definition = await _get_definition_for_execute(flow_id, db)
    nodes = definition.get("nodes", [])

    has_trigger = any(node.get("type") == "trigger" for node in nodes)
    input_vars = []
    if has_trigger:
        input_vars = [
            {"name": "query", "type": "string", "description": "用户输入文本", "required": True},
            {"name": "session_id", "type": "string", "description": "会话ID(用于多轮对话)", "required": False},
        ]

    return {
        "code": 200,
        "data": {
            "input_variables": input_vars,
            "node_count": len(nodes),
            "edge_count": len(definition.get("edges", [])),
        },
    }
|