You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

256 lines
10 KiB

import uuid
import time
import json
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import FlowDefinition, FlowVersion, FlowExecution
from schemas import FlowChatMessageRequest
from modules.flow_engine.engine import FlowEngine
from agentscope.message import Msg
from middleware.apikey_auth import authenticate_api_key
from dependencies import get_current_user
import logging
# Module-level logger for this gateway module.
logger = logging.getLogger(__name__)
# Router for the gateway endpoints below; every route is mounted under "/v1"
# and grouped under the "gateway" tag.
gateway_router = APIRouter(prefix="/v1", tags=["gateway"])
async def _resolve_auth(request: Request) -> dict:
    """Authenticate the caller with an API key or a JWT.

    A request whose ``Authorization`` header starts with ``Bearer flow-`` is
    handed to ``authenticate_api_key`` and that result is returned unchanged
    (its errors propagate as raised). Any other request goes through
    ``get_current_user``; on success the user is wrapped as
    ``{"user": <user>, "auth_type": "jwt"}``.

    Raises:
        HTTPException: 401 whenever ``get_current_user`` raises anything.
    """
    is_api_key = request.headers.get("Authorization", "").startswith("Bearer flow-")
    if is_api_key:
        return await authenticate_api_key(request)

    try:
        current = await get_current_user(request)
    except Exception:
        # Every JWT-path failure, expected or not, is reported as a 401.
        raise HTTPException(401, "认证失败: 请使用 Bearer Token 或 API Key")
    return {"user": current, "auth_type": "jwt"}
async def _get_definition_for_execute(flow_id: uuid.UUID, db: AsyncSession) -> dict:
    """Return the definition that should be executed for a published flow.

    The definition stored on the flow's published ``FlowVersion`` is preferred.
    The flow's own ``definition_json`` is used when no published version id is
    set or when the referenced version row cannot be found (logged as a
    warning). The result is always a fresh JSON round-trip copy, so callers
    such as ``FlowEngine`` can mutate it without touching ORM-loaded state.

    Raises:
        HTTPException: 404 if the flow does not exist, 400 if it is not published.
    """
    f = await db.get(FlowDefinition, flow_id)
    if not f:
        raise HTTPException(404, "流不存在")
    if f.status != "published":
        raise HTTPException(400, "流未发布")

    source = f.definition_json
    if f.published_version_id:
        result = await db.execute(
            select(FlowVersion).where(FlowVersion.id == f.published_version_id)
        )
        published = result.scalar_one_or_none()
        if published:
            source = published.definition_json
        else:
            # Dangling pointer: keep the existing fallback to the current
            # definition, but make it visible instead of silently switching.
            logger.warning(
                "Flow %s references missing published version %s; "
                "falling back to the current definition",
                flow_id,
                f.published_version_id,
            )
    # Copy in every branch so callers never share (and mutate) the ORM attribute.
    return json.loads(json.dumps(source))
# ============================== 对话型流 ==============================
@gateway_router.post("/chat-messages")
async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)):
    """Run a published conversational flow (``POST /v1/chat-messages``).

    Body fields read by this handler:

    * ``query`` -- user text (default ``""``).
    * ``response_mode`` -- ``"streaming"`` returns SSE; any other value blocks.
    * ``inputs`` -- optional object, serialized to JSON and appended to ``query``.
    * ``user`` -- name placed in the engine context (default ``"anonymous"``).
    * ``conversation_id`` / ``session_id`` -- optional session id, used as the
      engine ``session_id`` and echoed back as ``conversation_id``.
    * ``flow_id`` -- target flow (may also be supplied as ``inputs["flow_id"]``).

    Raises:
        HTTPException: 401 from ``_resolve_auth``; 400 for a non-object body,
            a non-object ``inputs`` or a missing/malformed ``flow_id``; 404/400
            from ``_get_definition_for_execute``.
    """
    auth = await _resolve_auth(request)
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(400, "请求体必须是 JSON 对象") from exc
    if not isinstance(body, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")

    query = body.get("query", "")
    response_mode = body.get("response_mode", "blocking")
    user = body.get("user", "anonymous")
    inputs = body.get("inputs")
    if inputs is None:
        inputs = {}  # "inputs": null is treated like an omitted field
    elif not isinstance(inputs, dict):
        raise HTTPException(400, "inputs 必须是 JSON 对象")

    # Treat an empty conversation_id like an absent one so session_id can apply.
    raw_session = body.get("conversation_id") or body.get("session_id")
    session_id = str(raw_session) if raw_session else None

    flow_id_str = body.get("flow_id") or inputs.get("flow_id")
    if not flow_id_str:
        raise HTTPException(400, "缺少 flow_id")
    try:
        flow_id = uuid.UUID(str(flow_id_str))
    except ValueError as exc:
        # A malformed id is a client error, not a server failure.
        raise HTTPException(400, "flow_id 格式无效") from exc

    definition = await _get_definition_for_execute(flow_id, db)
    f = await db.get(FlowDefinition, flow_id)

    input_text = query
    if inputs:
        extra = json.dumps(inputs, ensure_ascii=False)
        input_text = f"{query}\n\n上下文数据:\n{extra}" if query else extra

    if auth.get("auth_type") == "api_key":
        user_id = "api"
    else:
        # The return type of get_current_user is not visible here; accept
        # either a mapping with "id" or an object exposing ``.id``.
        principal = auth.get("user") or {}
        if isinstance(principal, dict):
            user_id = principal.get("id", "api")
        else:
            user_id = getattr(principal, "id", "api")

    runner = _chat_stream if response_mode == "streaming" else _chat_blocking
    return await runner(flow_id, definition, input_text, user_id, user, f, db, session_id=session_id)


async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
    """Execute a chat flow to completion and return one message payload.

    Args:
        flow_id: id of the flow (used for logging).
        definition: flow definition passed to ``FlowEngine``.
        input_text: text placed in the user ``Msg``.
        user_id: copied into the engine context.
        username: copied into the engine context.
        flow: ``FlowDefinition`` row whose ``id``/``version`` are stored on
            the ``FlowExecution`` record.
        db: session the execution record is added to (not committed here).
        session_id: optional conversation id; a new UUID is used when absent.

    Returns:
        Dict with ``event``, ``id``, ``answer``, ``conversation_id``,
        ``created_at`` and ``metadata`` (latency, truncated node results).

    Raises:
        HTTPException: 504 when the engine exceeds its ``FLOW_TIMEOUT_SECONDS``,
            500 for any other execution error. A failed ``FlowExecution`` is
            added to ``db`` before raising.
    """
    import asyncio

    # Reuse the caller's conversation id so multi-turn sessions can continue;
    # otherwise mint one and return it to the client.
    conversation_id = session_id or str(uuid.uuid4())
    context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": conversation_id}
    start_time = time.time()
    try:
        engine = FlowEngine(definition)
        input_msg = Msg(name="user", content=input_text, role="user")
        # Same limit the streaming path applies; None means "no limit".
        timeout = getattr(engine, "FLOW_TIMEOUT_SECONDS", None)
        result_msg = await asyncio.wait_for(engine.execute(input_msg, context), timeout=timeout)
        output_text = result_msg.get_text_content() if hasattr(result_msg, "get_text_content") else str(result_msg)
    except Exception as e:
        elapsed_ms = int((time.time() - start_time) * 1000)
        timed_out = isinstance(e, asyncio.TimeoutError)
        error_text = "执行超时" if timed_out else str(e)
        logger.exception("Chat flow %s failed", flow_id)
        db.add(FlowExecution(
            flow_id=flow.id, version=flow.version,
            trigger_type="api", input_data={"query": input_text},
            status="failed", latency_ms=elapsed_ms,
            error_message=error_text[:2000], finished_at=datetime.utcnow(),
        ))
        raise HTTPException(504 if timed_out else 500, f"流执行失败: {error_text}") from e

    # Success bookkeeping lives outside the try so a failure here can no longer
    # produce a second, contradictory "failed" record for the same run.
    elapsed_ms = int((time.time() - start_time) * 1000)
    db.add(FlowExecution(
        flow_id=flow.id, version=flow.version,
        trigger_type="api", input_data={"query": input_text},
        output_data={"output": output_text}, status="completed",
        latency_ms=elapsed_ms, finished_at=datetime.utcnow(),
    ))
    return {
        "event": "message",
        "id": str(uuid.uuid4()),
        "answer": output_text,
        "conversation_id": conversation_id,
        "created_at": int(time.time()),
        "metadata": {
            "usage": {"latency_ms": elapsed_ms},
            "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()},
        },
    }


async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
    """Execute a chat flow and deliver the result as Server-Sent Events.

    Event sequence: ``workflow_started``; then either ``message`` chunks of
    at most 10 characters followed by ``message_end``, or a single ``error``
    event; finally ``data: [DONE]``. The engine result is computed in full
    and then chunked, so no ``message`` event is sent before execution ends.

    Args:
        Same as ``_chat_blocking``.

    Returns:
        ``StreamingResponse`` with media type ``text/event-stream``.
    """
    import asyncio

    conversation_id = session_id or str(uuid.uuid4())
    # The generator body runs after this coroutine has returned the response,
    # so read the ORM attributes now rather than lazily from inside it.
    flow_pk, flow_version = flow.id, flow.version

    def sse(payload):
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    def record(**fields):
        # Best effort: bookkeeping must not break a stream already in flight.
        # NOTE(review): the record is only added to ``db``; whether it is ever
        # committed depends on get_db's lifecycle relative to the stream.
        try:
            db.add(FlowExecution(
                flow_id=flow_pk, version=flow_version,
                trigger_type="api", input_data={"query": input_text},
                finished_at=datetime.utcnow(), **fields,
            ))
        except Exception:
            logger.exception("Failed to record streaming execution of flow %s", flow_id)

    async def event_generator():
        msg_id = str(uuid.uuid4())
        start_time = time.time()
        yield sse({"event": "workflow_started", "task_id": msg_id, "data": {"flow_id": str(flow_id)}})
        try:
            engine = FlowEngine(definition)
            context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": conversation_id}
            input_msg = Msg(name="user", content=input_text, role="user")
            timeout = getattr(engine, "FLOW_TIMEOUT_SECONDS", None)
            result_msg = await asyncio.wait_for(engine.execute(input_msg, context), timeout=timeout)
            output_text = result_msg.get_text_content() if hasattr(result_msg, "get_text_content") else str(result_msg)
        except asyncio.TimeoutError:
            record(status="failed", latency_ms=int((time.time() - start_time) * 1000), error_message="执行超时")
            yield sse({"event": "error", "task_id": msg_id, "message": "执行超时"})
        except Exception as e:
            logger.exception("Streaming chat flow %s failed", flow_id)
            record(status="failed", latency_ms=int((time.time() - start_time) * 1000), error_message=str(e)[:2000])
            yield sse({"event": "error", "task_id": msg_id, "message": str(e)})
        else:
            elapsed_ms = int((time.time() - start_time) * 1000)
            # Record before streaming so a mid-stream disconnect still leaves a trace.
            record(output_data={"output": output_text}, status="completed", latency_ms=elapsed_ms)
            for offset in range(0, len(output_text), 10):
                chunk = output_text[offset:offset + 10]
                yield sse({"event": "message", "task_id": msg_id, "answer": chunk, "created_at": int(time.time())})
            yield sse({
                "event": "message_end", "task_id": msg_id, "id": msg_id,
                "conversation_id": conversation_id,
                "metadata": {
                    "usage": {"latency_ms": elapsed_ms},
                    "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()},
                },
            })
        # Deliberately not in ``finally``: yielding while the generator is being
        # closed or cancelled (client disconnect) raises RuntimeError.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )
# ============================== 工作流型流 ==============================
@gateway_router.post("/workflows/run")
async def workflows_run(request: Request, db: AsyncSession = Depends(get_db)):
    """Run a published workflow-type flow (``POST /v1/workflows/run``).

    Body fields read by this handler:

    * ``inputs`` -- optional object; JSON-serialized as the engine input and
      also exposed to the engine as ``context["trigger_data"]``.
    * ``user`` -- name placed in the engine context (default ``"anonymous"``).
    * ``workflow_id`` -- target flow (or ``inputs["workflow_id"]`` /
      ``inputs["flow_id"]``).

    ``response_mode`` is accepted but ignored: execution is always blocking.

    Raises:
        HTTPException: 401 from ``_resolve_auth``; 400 for a non-object body,
            a non-object ``inputs`` or a missing/malformed ``workflow_id``;
            404/400 from ``_get_definition_for_execute``; 504 when the engine
            exceeds its ``FLOW_TIMEOUT_SECONDS``; 500 for other execution errors.
    """
    import asyncio

    auth = await _resolve_auth(request)
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(400, "请求体必须是 JSON 对象") from exc
    if not isinstance(body, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")

    inputs = body.get("inputs")
    if inputs is None:
        inputs = {}  # "inputs": null is treated like an omitted field
    elif not isinstance(inputs, dict):
        raise HTTPException(400, "inputs 必须是 JSON 对象")
    user = body.get("user", "anonymous")

    flow_id_str = body.get("workflow_id") or inputs.get("workflow_id") or inputs.get("flow_id")
    if not flow_id_str:
        raise HTTPException(400, "缺少 workflow_id")
    try:
        flow_id = uuid.UUID(str(flow_id_str))
    except ValueError as exc:
        # A malformed id is a client error, not a server failure.
        raise HTTPException(400, "workflow_id 格式无效") from exc

    definition = await _get_definition_for_execute(flow_id, db)
    f = await db.get(FlowDefinition, flow_id)

    if auth.get("auth_type") == "api_key":
        user_id = "api"
    else:
        # Accept either a mapping with "id" or an object exposing ``.id``.
        principal = auth.get("user") or {}
        if isinstance(principal, dict):
            user_id = principal.get("id", "api")
        else:
            user_id = getattr(principal, "id", "api")

    context = {"user_id": user_id, "username": user, "_node_results": {}, "trigger_data": inputs}
    start_time = time.time()
    try:
        engine = FlowEngine(definition)
        input_msg = Msg(name="user", content=json.dumps(inputs, ensure_ascii=False), role="user")
        # Same limit the streaming chat path applies; None means "no limit".
        timeout = getattr(engine, "FLOW_TIMEOUT_SECONDS", None)
        result_msg = await asyncio.wait_for(engine.execute(input_msg, context), timeout=timeout)
        output_text = result_msg.get_text_content() if hasattr(result_msg, "get_text_content") else str(result_msg)
    except Exception as e:
        elapsed_ms = int((time.time() - start_time) * 1000)
        timed_out = isinstance(e, asyncio.TimeoutError)
        error_text = "执行超时" if timed_out else str(e)
        logger.exception("Workflow %s failed", flow_id)
        db.add(FlowExecution(
            flow_id=f.id, version=f.version,
            trigger_type="api", input_data={"inputs": inputs},
            status="failed", latency_ms=elapsed_ms,
            error_message=error_text[:2000], finished_at=datetime.utcnow(),
        ))
        raise HTTPException(504 if timed_out else 500, f"工作流执行失败: {error_text}") from e

    # Success bookkeeping is outside the try so it cannot also log a failure.
    elapsed_ms = int((time.time() - start_time) * 1000)
    db.add(FlowExecution(
        flow_id=f.id, version=f.version,
        trigger_type="api", input_data={"inputs": inputs},
        output_data={"output": output_text}, status="completed",
        latency_ms=elapsed_ms, finished_at=datetime.utcnow(),
    ))
    return {
        "id": str(uuid.uuid4()),
        "workflow_run_id": str(uuid.uuid4()),
        "data": {
            "outputs": {"text": output_text},
            "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()},
        },
        "metadata": {"latency_ms": elapsed_ms},
    }
# ============================== 参数信息 ==============================
@gateway_router.get("/flows/{flow_id}/parameters")
async def get_flow_parameters(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Describe the inputs and size of a published flow.

    When the executed definition contains at least one node of type
    ``"trigger"``, the fixed ``query``/``session_id`` input variables are
    reported; otherwise ``input_variables`` is empty. ``request`` is unused but
    kept as part of the route signature.

    Returns:
        ``{"code": 200, "data": {"input_variables", "node_count", "edge_count"}}``.

    Raises:
        HTTPException: 404/400 from ``_get_definition_for_execute``.
    """
    # NOTE(review): unlike the execution endpoints this route does not call
    # _resolve_auth, so any caller can inspect a published flow's shape.
    # Confirm that is intentional.
    definition = await _get_definition_for_execute(flow_id, db) or {}
    nodes = definition.get("nodes") or []
    edges = definition.get("edges") or []

    has_trigger = any(isinstance(n, dict) and n.get("type") == "trigger" for n in nodes)
    input_vars = [
        {"name": "query", "type": "string", "description": "用户输入文本", "required": True},
        {"name": "session_id", "type": "string", "description": "会话ID(用于多轮对话)", "required": False},
    ] if has_trigger else []

    return {
        "code": 200,
        "data": {
            "input_variables": input_vars,
            "node_count": len(nodes),
            "edge_count": len(edges),
        },
    }