You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

269 lines
9.8 KiB

import uuid
import json
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import FlowDefinition, FlowExecution, User
from schemas import FlowDefinitionCreate, FlowDefinitionUpdate, FlowDefinitionOut, FlowNode, FlowEdge
from modules.flow_engine.engine import FlowEngine
from agentscope.message import Msg
# Router that collects the flow endpoints defined in this module; every route
# registered on it is served under the /api/flow prefix and tagged "flow".
router = APIRouter(prefix="/api/flow", tags=["flow"])
@router.get("/definitions", response_model=list[FlowDefinitionOut])
async def list_flows(request: Request, db: AsyncSession = Depends(get_db)):
    """Return all flow definitions, most recently updated first."""
    query = select(FlowDefinition).order_by(FlowDefinition.updated_at.desc())
    rows = (await db.execute(query)).scalars().all()

    # Map each ORM row onto the response schema field by field.
    listing = []
    for row in rows:
        listing.append(
            FlowDefinitionOut(
                id=row.id,
                name=row.name,
                description=row.description,
                version=row.version,
                status=row.status,
                definition_json=row.definition_json,
                published_to_wecom=row.published_to_wecom,
                created_at=row.created_at,
                updated_at=row.updated_at,
            )
        )
    return listing
@router.get("/definitions/{flow_id}", response_model=FlowDefinitionOut)
async def get_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Fetch a single flow definition by id.

    Raises:
        HTTPException: 404 when no flow definition has the given id.
    """
    lookup = select(FlowDefinition).where(FlowDefinition.id == flow_id)
    record = (await db.execute(lookup)).scalar_one_or_none()
    if not record:
        raise HTTPException(status_code=404, detail="流定义不存在")

    fields = {
        "id": record.id,
        "name": record.name,
        "description": record.description,
        "version": record.version,
        "status": record.status,
        "definition_json": record.definition_json,
        "published_to_wecom": record.published_to_wecom,
        "created_at": record.created_at,
        "updated_at": record.updated_at,
    }
    return FlowDefinitionOut(**fields)
@router.post("/definitions", response_model=FlowDefinitionOut)
async def create_flow(req: FlowDefinitionCreate, request: Request, db: AsyncSession = Depends(get_db)):
    """Create a flow definition owned by the requesting user.

    The nodes, edges and trigger from the request are stored together as the
    flow's ``definition_json``; the creator id comes from ``request.state.user``.
    """
    creator = request.state.user
    serialized_nodes = [node.model_dump() for node in req.nodes]
    serialized_edges = [edge.model_dump() for edge in req.edges]

    record = FlowDefinition(
        name=req.name,
        description=req.description,
        definition_json={
            "nodes": serialized_nodes,
            "edges": serialized_edges,
            "trigger": req.trigger,
        },
        creator_id=uuid.UUID(creator["id"]),
    )
    db.add(record)
    # Flush before building the response.
    # NOTE(review): presumably so generated columns such as the id are
    # populated — confirm against the model definition.
    await db.flush()

    return FlowDefinitionOut(
        id=record.id,
        name=record.name,
        description=record.description,
        version=record.version,
        status=record.status,
        definition_json=record.definition_json,
        published_to_wecom=record.published_to_wecom,
        created_at=record.created_at,
        updated_at=record.updated_at,
    )
# Partially update a flow definition and return its current state.
# Responds with HTTP 404 when no FlowDefinition has the given id.
@router.put("/definitions/{flow_id}", response_model=FlowDefinitionOut)
async def update_flow(
flow_id: uuid.UUID, req: FlowDefinitionUpdate,
request: Request, db: AsyncSession = Depends(get_db),
):
result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id))
flow = result.scalar_one_or_none()
if not flow:
raise HTTPException(404, "流定义不存在")
# Only fields that were actually supplied (not None) overwrite stored values.
if req.name is not None:
flow.name = req.name
if req.description is not None:
flow.description = req.description
# The graph is replaced only when BOTH nodes and edges are supplied; a request
# carrying just one of them (or only a trigger) leaves definition_json untouched.
if req.nodes is not None and req.edges is not None:
flow.definition_json = {
"nodes": [n.model_dump() for n in req.nodes],
"edges": [e.model_dump() for e in req.edges],
# A falsy req.trigger falls back to the trigger already stored on the flow.
"trigger": req.trigger or flow.definition_json.get("trigger", {}),
}
# NOTE(review): indentation was lost in this copy of the file, so it is unclear
# whether the version bump runs on every update or only when the definition
# changes — confirm against the original before re-indenting.
flow.version += 1
return FlowDefinitionOut(
id=flow.id, name=flow.name, description=flow.description,
version=flow.version, status=flow.status,
definition_json=flow.definition_json,
published_to_wecom=flow.published_to_wecom,
created_at=flow.created_at, updated_at=flow.updated_at,
)
@router.delete("/definitions/{flow_id}")
async def delete_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Mark the flow definition with the given id for deletion.

    Raises:
        HTTPException: 404 when no flow definition has the given id.
    """
    lookup = select(FlowDefinition).where(FlowDefinition.id == flow_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="流定义不存在")

    await db.delete(target)
    return {"code": 200, "message": "已删除"}
@router.post("/definitions/{flow_id}/publish")
async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Mark a flow definition as published.

    Sets ``status`` to ``"published"`` and ``published_to_wecom`` to ``True``
    on the stored flow. This handler only updates those two fields.

    Raises:
        HTTPException: 404 when the flow does not exist; 400 when its
            definition contains no nodes (including a missing definition).
    """
    result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id))
    flow = result.scalar_one_or_none()
    if not flow:
        raise HTTPException(404, "流定义不存在")

    # A flow whose definition_json is None is treated like one with no nodes,
    # so the caller gets the 400 below rather than an AttributeError.
    nodes = (flow.definition_json or {}).get("nodes", [])
    if not nodes:
        raise HTTPException(400, "流定义中没有节点")

    flow.status = "published"
    flow.published_to_wecom = True
    return {"code": 200, "message": "流已上架到企微", "data": {"status": "published"}}
@router.post("/definitions/{flow_id}/unpublish")
async def unpublish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Return a flow definition to draft state and clear its WeCom flag.

    Raises:
        HTTPException: 404 when no flow definition has the given id.
    """
    lookup = select(FlowDefinition).where(FlowDefinition.id == flow_id)
    record = (await db.execute(lookup)).scalar_one_or_none()
    if not record:
        raise HTTPException(status_code=404, detail="流定义不存在")

    record.status = "draft"
    record.published_to_wecom = False
    return {"code": 200, "message": "流已下架"}
@router.post("/definitions/{flow_id}/execute")
async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db: AsyncSession = Depends(get_db)):
    """Run a flow once and record the outcome as a FlowExecution row.

    The input text is taken from ``payload["input"]``, falling back to
    ``payload["message"]`` and then to an empty string. ``payload["trigger"]``
    is passed to the engine as ``trigger_data`` and ``payload["trigger_type"]``
    (default ``"manual"``) is stored on the execution record.

    Returns:
        A dict with the flow output text, the per-node results collected in
        the execution context, and the id of the stored execution record.

    Raises:
        HTTPException: 404 when the flow does not exist; 500 when the engine
            (or extraction of its output text) raises.
    """
    result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id))
    flow = result.scalar_one_or_none()
    if not flow:
        raise HTTPException(404, "流定义不存在")

    user_ctx = request.state.user
    input_text = payload.get("input", payload.get("message", ""))
    # Resolved once so the success and failure records are labelled the same way.
    trigger_type = payload.get("trigger_type", "manual")

    engine = FlowEngine(flow.definition_json)
    input_msg = Msg(name="user", content=input_text, role="user")
    context = {
        "user_id": user_ctx["id"],
        "username": user_ctx["username"],
        "trigger_data": payload.get("trigger", {}),
        "_node_results": {},
    }

    # Only the engine run and output extraction are guarded; a persistence
    # error further down must not be recorded as a flow failure.
    try:
        result_msg = await engine.execute(input_msg, context)
        output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg)
    except Exception as e:
        execution = FlowExecution(
            flow_id=flow.id,
            trigger_type=trigger_type,
            trigger_user_id=uuid.UUID(user_ctx["id"]),
            input_data={"input": input_text},
            status="failed",
            finished_at=datetime.utcnow(),
        )
        db.add(execution)
        # NOTE(review): whether this row is persisted depends on how get_db
        # treats the exception raised below — confirm it is not rolled back.
        raise HTTPException(500, f"流执行失败: {str(e)}") from e

    execution = FlowExecution(
        flow_id=flow.id,
        trigger_type=trigger_type,
        trigger_user_id=uuid.UUID(user_ctx["id"]),
        input_data={"input": input_text},
        output_data={"output": output_text},
        status="completed",
        finished_at=datetime.utcnow(),
    )
    db.add(execution)
    # Flush before reading execution.id, mirroring create_flow, so a
    # flush-time default has been applied to the id.
    await db.flush()

    return {
        "code": 200,
        "data": {
            "output": output_text,
            "node_results": context.get("_node_results", {}),
            "execution_id": str(execution.id),
        },
    }
@router.post("/definitions/{flow_id}/test")
async def test_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Statically validate a stored flow definition without executing it.

    Checks that every edge endpoint refers to an existing node id (edges may
    use ``source``/``target`` or ``from``/``to`` keys) and that at least one
    node has type ``"trigger"``.

    Returns:
        A dict whose ``data`` holds ``valid``, ``node_count``, ``edge_count``,
        ``node_types`` and the list of ``issues`` found. ``valid`` is True
        only when ``issues`` is empty.

    Raises:
        HTTPException: 404 when no flow definition has the given id.
    """
    result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id))
    flow = result.scalar_one_or_none()
    if not flow:
        raise HTTPException(404, "流定义不存在")

    # Treat a missing definition as an empty graph instead of crashing.
    definition = flow.definition_json or {}
    nodes = definition.get("nodes", [])
    edges = definition.get("edges", [])

    issues = []

    # Nodes without an "id" key are skipped here; edges that point at them
    # are reported as dangling by the loop below.
    node_ids = {n["id"] for n in nodes if "id" in n}
    for edge in edges:
        source = edge.get("source") or edge.get("from")
        target = edge.get("target") or edge.get("to")
        if source and source not in node_ids:
            issues.append(f"边源节点 {source} 不存在")
        if target and target not in node_ids:
            issues.append(f"边目标节点 {target} 不存在")

    if not any(n.get("type") == "trigger" for n in nodes):
        issues.append("流缺少触发节点")

    validation = {
        # Derived after ALL checks so a missing trigger also marks the flow invalid.
        "valid": not issues,
        "node_count": len(nodes),
        "edge_count": len(edges),
        "node_types": list({n.get("type", "unknown") for n in nodes}),
        "issues": issues,
    }
    return {"code": 200, "data": validation}
@router.get("/market", response_model=list[FlowDefinitionOut])
async def flow_market(request: Request, db: AsyncSession = Depends(get_db)):
    """Return the flow definitions whose status is "published", newest update first."""
    published_only = FlowDefinition.status == "published"
    newest_first = FlowDefinition.updated_at.desc()
    query = select(FlowDefinition).where(published_only).order_by(newest_first)
    rows = (await db.execute(query)).scalars().all()

    catalogue = []
    for row in rows:
        catalogue.append(
            FlowDefinitionOut(
                id=row.id,
                name=row.name,
                description=row.description,
                version=row.version,
                status=row.status,
                definition_json=row.definition_json,
                published_to_wecom=row.published_to_wecom,
                created_at=row.created_at,
                updated_at=row.updated_at,
            )
        )
    return catalogue
@router.get("/executions")
async def list_executions(request: Request, db: AsyncSession = Depends(get_db)):
    """Return up to 100 flow executions, ordered by start time descending.

    Ids and timestamps are rendered as strings; ``finished_at`` is ``None``
    when the stored value is falsy.
    """
    recent = (
        select(FlowExecution)
        .order_by(FlowExecution.started_at.desc())
        .limit(100)
    )
    rows = (await db.execute(recent)).scalars().all()

    items = []
    for run in rows:
        entry = {
            "id": str(run.id),
            "flow_id": str(run.flow_id),
            "trigger_type": run.trigger_type,
            "status": run.status,
            "started_at": str(run.started_at),
        }
        if run.finished_at:
            entry["finished_at"] = str(run.finished_at)
        else:
            entry["finished_at"] = None
        items.append(entry)

    return {"code": 200, "data": items}