import asyncio
import json
import logging
import time
import uuid
from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from models import FlowDefinition, FlowVersion, FlowExecution
from schemas import FlowChatMessageRequest
from modules.flow_engine.engine import FlowEngine
from agentscope.message import Msg
from middleware.apikey_auth import authenticate_api_key
from dependencies import get_current_user

logger = logging.getLogger(__name__)

gateway_router = APIRouter(prefix="/v1", tags=["gateway"])

# Maximum characters of each node result echoed back in response metadata.
_NODE_RESULT_PREVIEW_CHARS = 200
# Maximum characters of an error message stored on a FlowExecution row.
_ERROR_MESSAGE_MAX_CHARS = 2000
# Number of answer characters sent per SSE "message" event.
_STREAM_CHUNK_CHARS = 10


# ============================== Shared helpers ==============================


async def _resolve_auth(request: Request) -> dict:
    """Authenticate a gateway request with an API key or a JWT.

    ``Authorization: Bearer flow-...`` headers are delegated to
    ``authenticate_api_key``; anything else is tried as a JWT through
    ``get_current_user``.

    Returns:
        For API keys, whatever ``authenticate_api_key`` returns; for JWTs,
        ``{"user": <user>, "auth_type": "jwt"}``.

    Raises:
        HTTPException: 401 when JWT authentication fails for any reason.
    """
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith("Bearer flow-"):
        return await authenticate_api_key(request)
    try:
        # NOTE(review): assumes get_current_user can be awaited with the raw
        # Request outside FastAPI's dependency injection — confirm signature.
        user = await get_current_user(request)
    except Exception as exc:
        # Every failure is reported as 401; keep the cause for debugging.
        logger.debug("JWT authentication failed: %s", exc)
        raise HTTPException(401, "认证失败: 请使用 Bearer Token 或 API Key") from exc
    return {"user": user, "auth_type": "jwt"}


def _resolve_user_id(auth: dict):
    """Return the caller id placed in the flow execution context.

    API-key callers are recorded as ``"api"``. For JWT callers the user may be
    a mapping or an object exposing ``id`` (the type returned by
    ``get_current_user`` is not visible here), so both are accepted; a missing
    user or id falls back to ``"api"``.
    """
    if auth.get("auth_type") == "api_key":
        return "api"
    user = auth.get("user")
    if isinstance(user, dict):
        return user.get("id", "api")
    user_id = getattr(user, "id", None)
    return "api" if user_id is None else user_id


async def _read_json_body(request: Request) -> dict:
    """Parse the request body as a JSON object.

    Raises:
        HTTPException: 400 if the body is not valid JSON or not an object.
    """
    try:
        body = await request.json()
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(400, "请求体不是合法的 JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "请求体必须是 JSON 对象")
    return body


def _parse_flow_id(raw, field: str) -> uuid.UUID:
    """Convert a client-supplied flow id into a UUID.

    Args:
        raw: The value read from the request body (may be missing/empty).
        field: Field name used in the error message.

    Raises:
        HTTPException: 400 if the id is missing or not a valid UUID.
    """
    if not raw:
        raise HTTPException(400, f"缺少 {field}")
    try:
        return uuid.UUID(str(raw))
    except ValueError:
        raise HTTPException(400, f"{field} 格式无效") from None


def _copy_definition(definition):
    """Deep-copy a JSON-compatible definition so the engine cannot mutate ORM state."""
    return json.loads(json.dumps(definition))


async def _load_flow_for_execute(flow_id: uuid.UUID, db: AsyncSession) -> tuple:
    """Load a published flow and the definition that should be executed.

    The definition of the published FlowVersion is preferred; if the flow has
    no published version id, or that version row is missing, the flow's own
    ``definition_json`` is used instead. The returned definition is a copy.

    Returns:
        ``(flow, definition)``.

    Raises:
        HTTPException: 404 if the flow does not exist, 400 if unpublished.
    """
    flow = await db.get(FlowDefinition, flow_id)
    if not flow:
        raise HTTPException(404, "流不存在")
    if flow.status != "published":
        raise HTTPException(400, "流未发布")
    if flow.published_version_id:
        result = await db.execute(
            select(FlowVersion).where(FlowVersion.id == flow.published_version_id)
        )
        published = result.scalar_one_or_none()
        if published:
            return flow, _copy_definition(published.definition_json)
        logger.warning(
            "Published version %s of flow %s not found; falling back to flow definition",
            flow.published_version_id,
            flow_id,
        )
    return flow, _copy_definition(flow.definition_json)


async def _get_definition_for_execute(flow_id: uuid.UUID, db: AsyncSession) -> dict:
    """Return a copy of the definition to execute for a published flow.

    Raises:
        HTTPException: 404 if the flow does not exist, 400 if unpublished.
    """
    _, definition = await _load_flow_for_execute(flow_id, db)
    return definition


def _elapsed_ms(start_time: float) -> int:
    """Milliseconds elapsed since ``start_time`` (a ``time.time()`` value)."""
    return int((time.time() - start_time) * 1000)


def _message_text(result_msg) -> str:
    """Extract the text of an engine result, falling back to ``str()``."""
    if hasattr(result_msg, "get_text_content"):
        text = result_msg.get_text_content()
        # get_text_content may yield None when the message has no text blocks.
        return "" if text is None else text
    return str(result_msg)


def _summarize_node_results(context: dict) -> dict:
    """Stringify and truncate the per-node results collected in ``context``."""
    return {
        key: str(value)[:_NODE_RESULT_PREVIEW_CHARS]
        for key, value in context.get("_node_results", {}).items()
    }


def _record_execution(
    db: AsyncSession,
    flow: FlowDefinition,
    *,
    input_data: dict,
    status: str,
    latency_ms: int,
    output_data=None,
    error_message=None,
) -> None:
    """Stage a FlowExecution row for ``flow`` on ``db`` (added, not committed).

    ``output_data`` / ``error_message`` are only passed to the model when
    given, so any model-side defaults for those columns still apply otherwise.
    """
    fields = {
        "flow_id": flow.id,
        "version": flow.version,
        "trigger_type": "api",
        "input_data": input_data,
        "status": status,
        "latency_ms": latency_ms,
        "finished_at": datetime.utcnow(),
    }
    if output_data is not None:
        fields["output_data"] = output_data
    if error_message is not None:
        fields["error_message"] = error_message[:_ERROR_MESSAGE_MAX_CHARS]
    # NOTE(review): nothing here commits; persistence presumably relies on
    # get_db committing when the request finishes — confirm, including for
    # failure paths that raise HTTPException right after this call.
    db.add(FlowExecution(**fields))


def _sse(payload: dict) -> str:
    """Format ``payload`` as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _build_chat_input(query, inputs) -> str:
    """Combine the user query with JSON-encoded ``inputs`` as extra context."""
    if not inputs:
        return query
    extra = json.dumps(inputs, ensure_ascii=False)
    return f"{query}\n\n上下文数据:\n{extra}" if query else extra


# ============================== Conversational flows ==============================


@gateway_router.post("/chat-messages")
async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)):
    """Run a conversational flow in blocking or streaming mode.

    Body fields: ``query``, ``response_mode`` ("blocking" | "streaming"),
    ``inputs``, ``user``, ``conversation_id`` / ``session_id`` and
    ``flow_id`` (also accepted inside ``inputs``).
    """
    auth = await _resolve_auth(request)
    body = await _read_json_body(request)

    query = body.get("query") or ""
    response_mode = body.get("response_mode", "blocking")
    inputs = body.get("inputs")
    if inputs is None:
        inputs = {}
    username = body.get("user", "anonymous")
    session_id = body.get("conversation_id", body.get("session_id"))

    # Only a dict-shaped `inputs` can carry a fallback flow_id.
    lookup = inputs if isinstance(inputs, dict) else {}
    flow_id = _parse_flow_id(body.get("flow_id") or lookup.get("flow_id"), "flow_id")
    flow, definition = await _load_flow_for_execute(flow_id, db)

    input_text = _build_chat_input(query, inputs)
    user_id = _resolve_user_id(auth)

    handler = _chat_stream if response_mode == "streaming" else _chat_blocking
    return await handler(
        flow_id, definition, input_text, user_id, username, flow, db, session_id=session_id
    )


async def _chat_blocking(
    flow_id, definition, input_text, user_id, username, flow, db, session_id=None
):
    """Execute a chat flow and return the complete answer as one JSON payload.

    ``session_id`` is the client's conversation id; when absent a new one is
    generated. It is placed in the engine context and echoed back as
    ``conversation_id``.

    Raises:
        HTTPException: 500 if the flow execution fails (a failed
        FlowExecution row is staged first).
    """
    conversation_id = str(session_id) if session_id else str(uuid.uuid4())
    engine = FlowEngine(definition)
    input_msg = Msg(name="user", content=input_text, role="user")
    context = {
        "user_id": user_id,
        "username": username,
        "_node_results": {},
        "session_id": conversation_id,
    }
    start_time = time.time()
    try:
        result_msg = await engine.execute(input_msg, context)
        elapsed_ms = _elapsed_ms(start_time)
        output_text = _message_text(result_msg)
    except Exception as e:
        _record_execution(
            db,
            flow,
            input_data={"query": input_text},
            status="failed",
            latency_ms=_elapsed_ms(start_time),
            error_message=str(e),
        )
        raise HTTPException(500, f"流执行失败: {str(e)}") from e

    _record_execution(
        db,
        flow,
        input_data={"query": input_text},
        status="completed",
        latency_ms=elapsed_ms,
        output_data={"output": output_text},
    )
    return {
        "event": "message",
        "id": str(uuid.uuid4()),
        "answer": output_text,
        "conversation_id": conversation_id,
        "created_at": int(time.time()),
        "metadata": {
            "usage": {"latency_ms": elapsed_ms},
            "node_results": _summarize_node_results(context),
        },
    }


async def _chat_stream(
    flow_id, definition, input_text, user_id, username, flow, db, session_id=None
):
    """Execute a chat flow and return its answer as a Server-Sent Events stream.

    Events: ``workflow_started``, then either ``message`` chunks followed by
    ``message_end``, or a single ``error``; the stream always ends with
    ``data: [DONE]`` unless the client disconnects first. The execution is
    bounded by ``engine.FLOW_TIMEOUT_SECONDS``.
    """
    conversation_id = str(session_id) if session_id else str(uuid.uuid4())

    async def event_generator():
        # NOTE(review): this generator runs while the response is being sent;
        # whether `db` is still open then depends on get_db / FastAPI's
        # dependency-exit timing — confirm the staged rows are persisted.
        msg_id = str(uuid.uuid4())
        start_time = time.time()
        yield _sse({"event": "workflow_started", "task_id": msg_id, "data": {"flow_id": str(flow_id)}})
        try:
            engine = FlowEngine(definition)
            context = {
                "user_id": user_id,
                "username": username,
                "_node_results": {},
                "session_id": conversation_id,
            }
            input_msg = Msg(name="user", content=input_text, role="user")
            result_msg = await asyncio.wait_for(
                engine.execute(input_msg, context), timeout=engine.FLOW_TIMEOUT_SECONDS
            )
            output_text = _message_text(result_msg)
            elapsed_ms = _elapsed_ms(start_time)
        except asyncio.TimeoutError:
            _record_execution(
                db,
                flow,
                input_data={"query": input_text},
                status="failed",
                latency_ms=_elapsed_ms(start_time),
                error_message="执行超时",
            )
            yield _sse({"event": "error", "task_id": msg_id, "message": "执行超时"})
        except Exception as e:
            _record_execution(
                db,
                flow,
                input_data={"query": input_text},
                status="failed",
                latency_ms=_elapsed_ms(start_time),
                error_message=str(e),
            )
            yield _sse({"event": "error", "task_id": msg_id, "message": str(e)})
        else:
            # Record before streaming so a client disconnect mid-stream does
            # not lose the record of a completed execution.
            _record_execution(
                db,
                flow,
                input_data={"query": input_text},
                status="completed",
                latency_ms=elapsed_ms,
                output_data={"output": output_text},
            )
            for i in range(0, len(output_text), _STREAM_CHUNK_CHARS):
                yield _sse({
                    "event": "message",
                    "task_id": msg_id,
                    "answer": output_text[i:i + _STREAM_CHUNK_CHARS],
                    "created_at": int(time.time()),
                })
            yield _sse({
                "event": "message_end",
                "task_id": msg_id,
                "id": msg_id,
                "conversation_id": conversation_id,
                "metadata": {
                    "usage": {"latency_ms": elapsed_ms},
                    "node_results": _summarize_node_results(context),
                },
            })
        # Not in a `finally`: yielding while the generator is being closed
        # (client disconnect -> aclose/GeneratorExit) raises RuntimeError.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
    )


# ============================== Workflow flows ==============================


@gateway_router.post("/workflows/run")
async def workflows_run(request: Request, db: AsyncSession = Depends(get_db)):
    """Run a workflow flow with structured ``inputs`` and return its output.

    The workflow id is read from ``workflow_id`` in the body, or from
    ``workflow_id`` / ``flow_id`` inside ``inputs``. Only blocking execution
    is implemented; a ``response_mode`` field in the body is ignored.

    Raises:
        HTTPException: 400 for a missing/invalid id or body, 404/400 from the
        flow lookup, 500 if the execution fails.
    """
    auth = await _resolve_auth(request)
    body = await _read_json_body(request)

    inputs = body.get("inputs")
    if inputs is None:
        inputs = {}
    username = body.get("user", "anonymous")

    lookup = inputs if isinstance(inputs, dict) else {}
    flow_id = _parse_flow_id(
        body.get("workflow_id") or lookup.get("workflow_id") or lookup.get("flow_id"),
        "workflow_id",
    )
    flow, definition = await _load_flow_for_execute(flow_id, db)
    user_id = _resolve_user_id(auth)

    engine = FlowEngine(definition)
    input_msg = Msg(name="user", content=json.dumps(inputs, ensure_ascii=False), role="user")
    context = {"user_id": user_id, "username": username, "_node_results": {}, "trigger_data": inputs}
    start_time = time.time()
    try:
        result_msg = await engine.execute(input_msg, context)
        elapsed_ms = _elapsed_ms(start_time)
        output_text = _message_text(result_msg)
    except Exception as e:
        _record_execution(
            db,
            flow,
            input_data={"inputs": inputs},
            status="failed",
            latency_ms=_elapsed_ms(start_time),
            error_message=str(e),
        )
        raise HTTPException(500, f"工作流执行失败: {str(e)}") from e

    _record_execution(
        db,
        flow,
        input_data={"inputs": inputs},
        status="completed",
        latency_ms=elapsed_ms,
        output_data={"output": output_text},
    )
    return {
        "id": str(uuid.uuid4()),
        "workflow_run_id": str(uuid.uuid4()),
        "data": {
            "outputs": {"text": output_text},
            "node_results": _summarize_node_results(context),
        },
        "metadata": {"latency_ms": elapsed_ms},
    }


# ============================== Parameter info ==============================


@gateway_router.get("/flows/{flow_id}/parameters")
async def get_flow_parameters(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Describe the input variables and size of a published flow.

    Requires the same authentication as the execution endpoints. Input
    variables are only listed when the flow contains a ``trigger`` node.
    """
    await _resolve_auth(request)
    definition = await _get_definition_for_execute(flow_id, db) or {}
    nodes = definition.get("nodes", [])

    if any(node.get("type") == "trigger" for node in nodes):
        input_vars = [
            {"name": "query", "type": "string", "description": "用户输入文本", "required": True},
            {"name": "session_id", "type": "string", "description": "会话ID(用于多轮对话)", "required": False},
        ]
    else:
        input_vars = []

    return {
        "code": 200,
        "data": {
            "input_variables": input_vars,
            "node_count": len(nodes),
            "edge_count": len(definition.get("edges", [])),
        },
    }