From 34a5d7f49d9d9e03e73d5fbbe1c8f8cb189081ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?MSI-7950X=5C=E5=88=98=E6=B3=BD=E6=98=8E?= Date: Thu, 14 May 2026 17:28:34 +0800 Subject: [PATCH] =?UTF-8?q?=E5=86=85=E9=83=A8=E5=8A=9F=E8=83=BD=E6=A0=B8?= =?UTF-8?q?=E5=BF=83=E4=BC=98=E5=8C=96=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PLAN5.md | 225 +++++++ PLAN6.md | 328 ++++++++++ PLAN7.md | 556 ++++++++++++++++ backend/agentscope_integration/factory.py | 4 + .../tools/manager_tools.py | 94 +-- .../tools/task_tools.py | 99 ++- backend/database.py | 13 + backend/main.py | 9 +- backend/middleware/apikey_auth.py | 43 ++ backend/middleware/cache_manager.py | 6 +- backend/middleware/rate_limiter.py | 19 + backend/middleware/rbac_middleware.py | 4 +- backend/models/__init__.py | 59 +- backend/modules/agent_manager/router.py | 13 +- backend/modules/auth/router.py | 14 +- backend/modules/chat/__init__.py | 0 backend/modules/chat/router.py | 99 +++ backend/modules/custom_tool/__init__.py | 3 + backend/modules/custom_tool/executor.py | 43 ++ backend/modules/custom_tool/parser.py | 81 +++ backend/modules/custom_tool/router.py | 249 +++++++ backend/modules/flow_engine/engine.py | 612 ++++++++++++++++-- backend/modules/flow_engine/gateway.py | 256 ++++++++ backend/modules/flow_engine/router.py | 605 ++++++++++++----- backend/modules/mcp_registry/router.py | 4 +- backend/modules/monitor/router.py | 2 +- backend/modules/notification/router.py | 6 +- backend/modules/org/router.py | 18 +- backend/schemas/__init__.py | 144 ++++- backend/websocket_manager.py | 49 ++ frontend/src/api/index.ts | 33 + frontend/src/components/layout/MainLayout.vue | 14 +- frontend/src/router/index.ts | 12 + frontend/src/views/chat/FlowChat.vue | 393 +++++++++++ frontend/src/views/flow/FlowCanvas.vue | 18 +- frontend/src/views/flow/FlowEditor.vue | 61 +- frontend/src/views/flow/FlowList.vue | 50 +- frontend/src/views/flow/FlowNode.vue | 55 +- .../views/flow/node-configs/CodeConfig.vue | 50 ++ .../views/flow/node-configs/LoopConfig.vue | 74 +++ .../views/flow/node-configs/ToolConfig.vue | 88 ++- .../views/flow/node-configs/TriggerConfig.vue | 27 +- .../flow/node-configs/WecomNotifyConfig.vue | 66 +- .../src/views/tools/CustomToolManager.vue | 344 ++++++++++ 44 files changed, 4551 insertions(+), 391 deletions(-) create mode 100644 PLAN5.md create mode 100644 PLAN6.md create mode 100644 PLAN7.md create mode 100644 backend/middleware/apikey_auth.py create mode 100644 backend/modules/chat/__init__.py create mode 100644 backend/modules/chat/router.py create mode 100644 backend/modules/custom_tool/__init__.py create mode 100644 backend/modules/custom_tool/executor.py create mode 100644 backend/modules/custom_tool/parser.py create mode 100644 backend/modules/custom_tool/router.py create mode 100644 backend/modules/flow_engine/gateway.py create mode 100644 backend/websocket_manager.py create mode 100644 frontend/src/views/chat/FlowChat.vue create mode 100644 frontend/src/views/flow/node-configs/CodeConfig.vue create mode 100644 frontend/src/views/flow/node-configs/LoopConfig.vue create mode 100644 frontend/src/views/tools/CustomToolManager.vue diff --git a/PLAN5.md b/PLAN5.md new file mode 100644 index 0000000..f628447 --- /dev/null +++ b/PLAN5.md @@ -0,0 +1,225 @@ +# PLAN5: Flow 画布节点系统完整性方案 + +## 目标 +对齐前后端功能,实现流中节点配置真实生效,新增循环/代码节点,双渠道(企微+Web)输入输出。 + +--- + +## 一、当前状态总结 + +### 后端已有且正常运行 +| 能力 | 状态 | +|------|------| +| DAG 图执行引擎 (FlowEngine) | ✅ 完整 | +| 8种节点 Agent: trigger/llm/tool/mcp/wecom_notify/condition/rag/output | ✅ 完整 | +| 13个工具函数 (文档/企微/任务/管理) | ✅ 完整 | +| 企微全链路(回调→Agent→回复→通知) | ✅ 完整 | +| RAG知识库(Qdrant+OpenAI) | ✅ 完整 | +| MCP外部服务集成 | ✅ 完整 | +| 流CRUD + 发布/下架 + 执行记录 | ✅ 完整 | + +### 前端已有但后端不处理的配置项 (需对齐) +| 前端配置项 | 后端Schema | 后端Agent | +|-----------|-----------|----------| +| LLM: `max_tokens` | ❌ 缺失 | ❌ 未使用 | +| LLM: `context_length` | ❌ 缺失 | ❌ 未使用 | +| LLM: `memory_mode` (none/short/long) | ❌ 缺失 | ❌ 未使用 | +| LLM: `stream` (流式输出) | ❌ 缺失 | ❌ 未使用 | +| LLM: `tool_call` (函数调用) | ❌ 缺失 | ❌ 未使用 | +| Tool: `timeout` | ❌ 缺失 | ❌ 未使用 | +| Tool: `retry_count` | ❌ 缺失 | ❌ 未使用 | +| Tool: `error_handling` | ❌ 缺失 | ❌ 未使用 | +| RAG: `search_mode` (vector/keyword/hybrid) | ❌ 缺失 | ❌ 未使用 | +| RAG: `similarity_threshold` | ❌ 缺失 | ❌ 未使用 | +| Trigger: `channels` (wecom/web) | ❌ 缺失 | ❌ 未使用 | +| Notify: `channels` (wecom/web) | ❌ 缺失 | ❌ 未使用 | + +### 缺失的关键能力 +- **循环节点**: 无重试/迭代/批量能力 +- **代码执行节点**: 无法运行自定义逻辑 +- **Web Chat入口**: 只能通过企微触发 +- **Web通知**: 只有企微通知,无Web推送 + +--- + +## 二、实施计划 + +### P0: 前后端配置对齐 (最高优先级) +**目标**: 前端配置的所有参数在后端Schema和Agent中真实生效 + +#### 2.1 后端 Schema 补齐 +```python +# 文件: backend/schemas/__init__.py + +class TriggerNodeConfig(BaseModel): + event_type: str = "text_message" + channels: list[str] = ["wecom"] # 新增: ["wecom", "web_chat"] + callback_url: str = "" # 新增 + +class LLMNodeConfig(BaseModel): + system_prompt: str = "" + model: str = "gpt-4o-mini" + temperature: float = 0.7 + agent_id: str = "" + max_tokens: int = 2000 # 新增 + context_length: int = 5 # 新增 + memory_mode: str = "short_term" # 新增: none/short_term/long_term + stream: bool = True # 新增 + tool_call: bool = False # 新增 + +class ToolNodeConfig(BaseModel): + tool_name: str = "" + tool_type: str = "" # 新增: wecom_message/task_management/... + tool_params: dict = {} # 补齐 + timeout: int = 30 # 新增 + retry_count: int = 0 # 新增 + error_handling: str = "throw" # 新增: throw/default/skip + +class MCPNodeConfig(BaseModel): + mcp_server: str = "" + tool_name: str = "" + input_params: dict = {} # 新增 + timeout: int = 30 # 新增 + response_parser: str = "json" # 新增 + error_handling: str = "throw" # 新增 + +class NotifyNodeConfig(BaseModel): # 重命名: WeComNotifyNodeConfig -> NotifyNodeConfig + channels: dict = {"wecom": True, "web": False} # 新增 + message_template: str = "" + web_template: str = "" # 新增 + target: str = "" + message_type: str = "text" # 新增: text/markdown/card + async_send: bool = False # 新增 + error_handling: str = "throw" # 新增 + +class ConditionNodeConfig(BaseModel): + condition: str = "" + condition_type: str = "expression" # 新增 + true_label: str = "是" # 新增 + false_label: str = "否" # 新增 + default_branch: str = "false" # 新增 + +class RAGNodeConfig(BaseModel): + knowledge_base: str = "" + top_k: int = 5 + search_mode: str = "hybrid" # 新增: vector/keyword/hybrid + similarity_threshold: float = 0.7 # 新增 + result_sort: str = "similarity" # 新增 + include_metadata: bool = True # 新增 + +class OutputNodeConfig(BaseModel): + format: str = "text" + output_template: str = "" # 新增 + indent: int = 2 # 新增 + encoding: str = "utf-8" # 新增 + truncate: bool = False # 新增 + max_length: int = 2000 # 新增 + +class LoopNodeConfig(BaseModel): # 新增节点 + loop_type: str = "fixed" # fixed/count/list + max_iterations: int = 10 + count: int = 3 + iterator_variable: str = "item" + +class CodeNodeConfig(BaseModel): # 新增节点 + language: str = "python" # python/javascript + code: str = "" + timeout: int = 30 + sandbox: bool = True +``` + +#### 2.2 后端 Agent 补齐 +``` +文件: backend/modules/flow_engine/engine.py + +LLMNodeAgent: 使用 max_tokens, stream, tool_call +ToolNodeAgent: 使用 timeout, retry_count, error_handling +RAGNodeAgent: 使用 search_mode, similarity_threshold +NotifyAgent: 检测 channels.web 做 WebSocket 推送 +LoopNodeAgent: 新增 +CodeNodeAgent: 新增 +``` + +### P1: 双渠道支持 +**目标**: 流同时支持企业微信和网页聊天触发,通知也支持双渠道 + +#### 3.1 Web Chat API +``` +POST /api/chat/sessions/{session_id}/message +POST /api/chat/sessions (创建会话) +GET /api/chat/sessions (会话列表) +``` + +#### 3.2 WebSocket 通知推送 +``` +backend/websocket_manager.py: 新增 +- 用户连接管理 +- 按用户推送通知 +``` + +#### 3.3 前端 Web Chat 页面 +``` +frontend/src/views/chat/ChatWidget.vue: 新增 +- 浮动聊天窗口 +- WebSocket 实时接收 +- 流选择 +``` + +### P2: 新增节点类型 +**目标**: 新增循环节点和代码执行节点 + +#### 4.1 循环节点 (Loop) +- 固定次数循环、条件循环、遍历列表 +- 两个出口: loop_body(继续), loop_done(完成) +- 安全上限: max_iterations 防止死循环 +- 引擎需要支持回边 + +#### 4.2 代码执行节点 (Code) +- Python/JavaScript 沙箱执行 +- subprocess 隔离 + 超时控制 +- stdin/stdout 输入输出 + +### P3: FlowEngine 改造 +**目标**: 支持循环节点回边 + +1. `traverse()` 中 visited 集合改为 per-branch 而非全局 +2. 循环节点特殊处理: 检测 loop_done 条件 +3. 执行超时和安全限制 + +--- + +## 三、实施顺序 + +1. **P0-1**: 后端 Schema 补齐 (schemas/__init__.py) — 10分钟 +2. **P0-2**: 后端 Agent 补齐 (engine.py) — 15分钟 +3. **P0-3**: 路由注册新节点类型 (router.py) — 5分钟 +4. **P1-1**: Notify 节点双渠道改造 + WebSocket — 15分钟 +5. **P1-2**: Web Chat API + 路由 — 10分钟 +6. **P1-3**: 前端 ChatWidget + 通知接收 — 10分钟 +7. **P2-1**: Loop Node (前端配置+后端Agent) — 10分钟 +8. **P2-2**: Code Node (前端配置+后端Agent) — 10分钟 +9. **P3**: FlowEngine 循环回边支持 — 10分钟 +10. **更新前端 FlowEditor**: 新节点类型 + 配置对齐 — 5分钟 + +--- + +## 四、前端文件清单 + +| 文件 | 内容 | +|------|------| +| FlowEditor.vue | 新增 loop/code 节点类型、trigger 改 channels | +| node-configs/LoopConfig.vue | 循环配置 | +| node-configs/CodeConfig.vue | 代码执行配置 | +| node-configs/NotifyConfig.vue | 双渠道通知配置 | +| node-configs/TriggerConfig.vue | 双渠道触发配置 | +| chat/ChatWidget.vue | Web Chat 入口 | + +## 五、后端文件清单 + +| 文件 | 内容 | +|------|------| +| schemas/__init__.py | 补齐所有Config Schema + 新增Loop/Code | +| flow_engine/engine.py | 补齐Agent实现 + LoopNodeAgent + CodeNodeAgent + 引擎回边 | +| flow_engine/router.py | 注册新节点类型 | +| chat/router.py | Web Chat API (新建) | +| websocket_manager.py | WebSocket管理 (新建) | \ No newline at end of file diff --git a/PLAN6.md b/PLAN6.md new file mode 100644 index 0000000..1db0195 --- /dev/null +++ b/PLAN6.md @@ -0,0 +1,328 @@ +# PLAN6 — 对标 Dify 无代码发布架构:差距分析与升级路线 + +## 一、核心结论 + +**我们的流发布逻辑与 Dify 的底层思路高度一致(配置即数据 + 动态引擎),但在 7 个关键维度存在显著差距,需要补齐才能真正实现"无代码秒级发布,即刻可用"。** + +### 已对齐的架构思路 + +| Dify 核心思路 | 我们的实现 | 对齐度 | +|--------------|-----------|--------| +| 配置即数据:前端生成 JSON,存入数据库 | ✅ FlowEditor 生成 nodes+edges JSON,存入 `FlowDefinition.definition_json` | 完全对齐 | +| 零部署:发布 = 数据库状态变更,不启动新服务 | ✅ publish 仅修改 status 字段,执行时动态加载 JSON | 完全对齐 | +| 动态编排引擎:解析 JSON → 执行 | ✅ `FlowEngine` 解析 JSON → 构建图 → traverse 执行 | 基本对齐 | +| DAG 拓扑排序执行 | ✅ `_build_graph()` + `traverse()` 支持条件分支和循环 | 基本对齐 | +| 多种节点类型 | ✅ 9 种节点:trigger/llm/tool/mcp/condition/rag/output/loop/code | 基本对齐 | +| 双渠道发布 | ✅ 企微 + Web 双渠道发布状态管理 | 额外优势 | + +### 存在差距的关键维度 + +| # | 维度 | 差距等级 | 影响 | +|---|------|---------|------| +| 1 | 版本快照 / 发布不可变 | 🔴 严重 | 发布后编辑直接影响线上服务 | +| 2 | 流式输出 (SSE) | 🔴 严重 | 长流程用户体验极差,无法实时看到结果 | +| 3 | 统一 API 网关 + App API Key | 🟠 高 | 无法被外部系统调用,无法做 API 市场 | +| 4 | 工具 Schema 标准化 | 🟠 高 | 无法运行时扩展工具,无参数校验 | +| 5 | Flow 节点 Memory | 🟠 高 | LLM 节点无上下文记忆,无法多轮对话 | +| 6 | 变量类型系统 | 🟡 中 | 复杂业务逻辑难以表达 | +| 7 | 执行监控与可观测性 | 🟡 中 | 无法追溯执行版本,缺少 token/延迟指标 | + +--- + +## 二、逐维度详细对比 + +### 1. 版本快照 / 发布不可变(🔴 严重) + +**Dify 的做法:** +- 点击"发布"时,将当前草稿 JSON 创建一份**版本快照**(snapshot),存入独立的 `workflow_versions` 表 +- `FlowDefinition` 有 `published_version` 字段,指向当前生效的版本 +- 执行引擎加载的是 `published_version` 对应的 JSON,而非草稿 +- 编辑草稿不影响已发布版本,回滚只需切换 `published_version` 指针 + +**我们的现状:** +- `FlowDefinition` 只有 `version` 计数器(int),没有 `published_version` 字段 +- 发布仅修改 `status="published"`,不创建快照 +- **编辑草稿直接修改 `definition_json`,已发布的服务立即受影响** +- `FlowExecution` 不记录执行时的版本号,无法追溯 + +**需要补齐:** +``` +新增模型:FlowVersion + - id: UUID + - flow_id: FK → FlowDefinition + - version: int + - definition_json: JSON(快照) + - created_by: UUID + - created_at: datetime + +修改模型:FlowDefinition + - 新增 published_version_id: FK → FlowVersion(nullable) + - 新增 draft_version: int(草稿版本号) + +发布逻辑改造: + - publish → 创建 FlowVersion 快照 → 设置 published_version_id + - execute → 加载 published_version.definition_json(而非草稿) + - 编辑 → 只修改草稿,不影响 published_version + - 回滚 → 切换 published_version_id 指针 +``` + +--- + +### 2. 流式输出 SSE(🔴 严重) + +**Dify 的做法:** +- 统一 API 支持 `response_mode: "streaming"`,返回 SSE 事件流 +- 事件类型:`workflow_started` → `node_started` → `node_finished` → `workflow_finished` +- LLM 节点支持 token-by-token 实时推送(`text_chunk` 事件) +- 前端通过 EventSource 实时渲染 + +**我们的现状:** +- `FlowEngine.execute()` 返回最终 `Msg`,无中间状态 +- `LLMNodeAgent` 虽然配置了 `stream=True`,但 `model(prompt)` 等待完整响应 +- WebSocket 端点仅 echo,未与 Flow 引擎集成 +- 没有 SSE 端点 + +**需要补齐:** +``` +新增 SSE 端点:GET /api/chat/stream/{flow_id} + - 接收 query 参数:message, session_id + - 返回 text/event-stream + - 事件格式: + event: node_started + data: {"node_id": "xxx", "node_type": "llm", "label": "生成摘要"} + + event: text_chunk + data: {"node_id": "xxx", "content": "根据"} + + event: node_finished + data: {"node_id": "xxx", "output": "..."} + + event: workflow_finished + data: {"output": "最终结果"} + +FlowEngine 改造: + - execute() 接受可选的 callback: Callable[[str, dict], None] + - 每个节点执行前后调用 callback("node_started"/"node_finished", data) + - LLMNodeAgent.reply() 改为 async generator,yield token +``` + +--- + +### 3. 统一 API 网关 + App API Key(🟠 高) + +**Dify 的做法:** +- 每个 App 有独立的 API Key(`app-xxxxxxxx`) +- 统一入口:`POST /v1/chat-messages`(对话型)、`POST /v1/workflows/run`(工作流型) +- 请求格式标准化:`{inputs: {}, query: "", response_mode: "blocking|streaming", user: "user-id"}` +- 无需用户登录,API Key 即认证 + +**我们的现状:** +- 所有 API 依赖 JWT 用户认证,无 App-level API Key +- 执行分散在 `/api/flow/definitions/{id}/execute` 和 `/api/chat/message/{id}` +- 无法被外部系统(如企微回调、第三方应用)直接调用 + +**需要补齐:** +``` +新增模型:FlowApiKey + - id: UUID + - flow_id: FK → FlowDefinition + - key_hash: str(sha256) + - key_prefix: str(前8位,用于展示) + - name: str + - created_by: UUID + - created_at: datetime + - last_used_at: datetime(nullable) + +新增统一网关端点: + POST /v1/chat-messages → 对话型 Flow(自动找 trigger → llm → output 路径) + POST /v1/workflows/run → 工作流型 Flow(完整 DAG 执行) + +认证方式: + Header: Authorization: Bearer app-xxxxxxxx + → 查 FlowApiKey 表 → 获取 flow_id → 加载 published_version → 执行 +``` + +--- + +### 4. 工具 Schema 标准化(🟠 高) + +**Dify 的做法:** +- 所有工具(内置/自定义 API/MCP)统一转换为 OpenAI Function Calling 的 JSON Schema +- Schema 包含:name, description, parameters(JSON Schema 格式,含 type/enum/description) +- 注入 LLM 时,工具 Schema 作为 `tools` 参数传入 +- 用户可在前端自定义 API 工具(填 URL、Method、参数结构) + +**我们的现状:** +- `ToolNodeAgent._TOOL_REGISTRY` 硬编码 12 个工具函数 +- 工具函数只有 Python 签名,无结构化 Schema 描述 +- `tool_params: dict = {}` 无校验 +- 无法运行时扩展工具 + +**需要补齐:** +``` +工具 Schema 标准化格式: +{ + "name": "send_notification", + "description": "发送企业微信通知给指定用户", + "parameters": { + "type": "object", + "properties": { + "to_user": {"type": "string", "description": "接收人用户ID"}, + "message": {"type": "string", "description": "消息内容"} + }, + "required": ["to_user", "message"] + } +} + +改造 ToolNodeAgent: + - _TOOL_REGISTRY 改为 _TOOL_SCHEMA_REGISTRY: dict[str, dict] + - 每个工具注册时同时注册 Schema + - 调用前基于 Schema 校验 tool_params + - LLM 调用时将 Schema 作为 tools 参数传入 + +新增自定义 API 工具: + - 用户可填入 OpenAPI/Swagger URL + - 系统自动解析为标准 Schema 并注册 + - 执行时通过 httpx 调用 +``` + +--- + +### 5. Flow 节点 Memory(🟠 高) + +**Dify 的做法:** +- 每个 App 有独立的对话记忆(窗口记忆/摘要记忆) +- 记忆在多次调用间持久化(Redis/数据库) +- LLM 节点自动注入历史对话上下文 + +**我们的现状:** +- `UserIsolatedMemory` 存在但**未在 Flow 节点中使用** +- Flow 中的 LLM 节点每次调用都是无状态的 +- `ChatMessage` 表存储了历史消息,但 Flow 执行时不读取 + +**需要补齐:** +``` +FlowEngine 改造: + - execute() 接受 session_id 参数 + - 创建 FlowSessionMemory(session_id, user_id) + - LLM 节点执行前注入历史消息 + +新增 FlowSessionMemory: + - 基于 ChatMessage 表持久化 + - 按session_id + user_id 隔离 + - 支持窗口大小配置(最近 N 条) + - 支持摘要模式(超过窗口时调用 LLM 生成摘要) +``` + +--- + +### 6. 变量类型系统(🟡 中) + +**Dify 的做法:** +- 完整的变量面板:输入变量、环境变量、会话变量、上游节点变量 +- 变量类型:string/number/array/object/file +- 支持 Jinja2 模板、类型转换、默认值 +- "变量聚合"节点:汇聚并行分支输出 +- "迭代"节点:对列表逐项处理 + +**我们的现状:** +- 仅有 `{{node_id.output}}` 和 `{{trigger.field}}` 模板 +- 所有值都是 str,无类型系统 +- 无并行汇聚、无迭代节点 + +**需要补齐:** +``` +变量系统升级: + - 节点输出增加类型标注(string/number/array/object) + - 模板解析支持类型转换和默认值 + - 新增"变量聚合"节点(ParallelMergeNode) + - Loop 节点支持迭代数组模式 + - 输入变量面板(Flow 级别的入参定义) +``` + +--- + +### 7. 执行监控与可观测性(🟡 中) + +**Dify 的做法:** +- FlowExecution 记录执行时的版本号 +- 统计 token 用量、延迟、费用 +- 执行日志可按 App/时间/状态筛选 +- 失败重试机制 + +**我们的现状:** +- `FlowExecution` 不记录版本号 +- 无 token/延迟统计 +- 无失败重试 + +**需要补齐:** +``` +FlowExecution 增加字段: + - version: int(执行时的版本号) + - token_usage: JSON(prompt_tokens, completion_tokens, total_tokens) + - latency_ms: int + - error_message: str(nullable) + +执行引擎改造: + - 记录每个节点的 token 用量和耗时 + - 汇总到 FlowExecution + - 失败节点支持重试配置 +``` + +--- + +## 三、升级路线图 + +### Phase 1 — 发布安全基础(P0,1-2周) + +| 任务 | 改动范围 | +|------|---------| +| 新增 FlowVersion 模型 + 迁移 | models, database | +| FlowDefinition 增加 published_version_id | models, schemas | +| 发布逻辑改造:创建快照 | flow_engine/router.py | +| 执行逻辑改造:加载 published_version | flow_engine/engine.py, chat/router.py | +| FlowExecution 记录版本号 | models, flow_engine/router.py | + +### Phase 2 — 用户体验核心(P0,2-3周) + +| 任务 | 改动范围 | +|------|---------| +| SSE 流式输出端点 | chat/router.py(新增) | +| FlowEngine callback 机制 | flow_engine/engine.py | +| LLMNodeAgent async generator 改造 | flow_engine/engine.py | +| 前端 EventSource 集成 | FlowChat.vue(新增) | +| Flow 节点 Memory 集成 | flow_engine/engine.py, 新增 FlowSessionMemory | + +### Phase 3 — 服务化能力(P1,2-3周) + +| 任务 | 改动范围 | +|------|---------| +| FlowApiKey 模型 + CRUD | models, schemas, 新增 router | +| 统一 API 网关 `/v1/chat-messages`, `/v1/workflows/run` | 新增 gateway router | +| API Key 认证中间件 | middleware | +| 工具 Schema 标准化 | tools/*.py, ToolNodeAgent | +| 自定义 API 工具(OpenAPI 导入) | 新增 custom_tool 模块 | + +### Phase 4 — 高级能力(P2,2-3周) + +| 任务 | 改动范围 | +|------|---------| +| 变量类型系统 | schemas, engine.py | +| 变量聚合节点 | 新增 ParallelMergeNodeAgent | +| Loop 迭代数组模式 | LoopNodeAgent | +| 执行监控指标 | FlowExecution, engine.py | +| 工具认证改造(去掉硬编码) | tools/*.py | + +--- + +## 四、架构哲学对齐度总结 + +| Dify 架构哲学 | 我们的现状 | 对齐度 | +|--------------|-----------|--------| +| **数据驱动**:复杂 AI 逻辑抽象为可配置参数 | ✅ 已实现。9 种节点类型,每种有独立 config | 90% | +| **统一 Runner**:一套引擎解析千种 JSON 组合 | ⚠️ 部分实现。引擎存在但缺少流式/Memory/版本快照 | 60% | +| **插件化架构**:Tool/Model 实现高度抽象接口 | ❌ 未实现。工具硬编码,无标准 Schema,无自动发现 | 20% | + +**核心差距一句话总结:我们的"配置即数据"和"零部署"思路与 Dify 完全一致,但缺少"发布不可变"(版本快照)、"实时反馈"(SSE 流式)、"开放接入"(统一网关+API Key)和"插件化工具"(标准 Schema)四大关键能力,导致无法真正实现"无代码秒级发布,即刻可用"的完整体验。** + +补齐 Phase 1 + Phase 2 后,即可达到 Dify 约 80% 的核心能力。 diff --git a/PLAN7.md b/PLAN7.md new file mode 100644 index 0000000..a4d6f00 --- /dev/null +++ b/PLAN7.md @@ -0,0 +1,556 @@ +# PLAN7 — 自定义 API 工具导入 + 前端 EventSource 流式聊天组件 + +## 一、现状与差距 + +PLAN6 完成后,系统已具备: +- ✅ 版本快照(FlowVersion) +- ✅ SSE 流式输出(后端) +- ✅ 统一 API 网关(/v1/chat-messages, /v1/workflows/run) +- ✅ API Key 认证 +- ✅ 工具 Schema 标准化(内置工具) +- ✅ Flow 节点 Memory +- ✅ ParallelMergeNodeAgent + Loop 数组迭代 + +**剩余 15% 差距:** +1. **自定义 API 工具导入**:用户无法在前端填入第三方 OpenAPI/Swagger URL,系统自动解析为工具 Schema 并注册到 ToolNodeAgent +2. **前端 EventSource 聊天组件**:前端没有支持 SSE 的聊天界面,无法实时看到流式输出 + +--- + +## 二、功能 1:自定义 API 工具导入(OpenAPI/Swagger 解析) + +### 2.1 目标 +用户在前端输入第三方 API 的 OpenAPI/Swagger URL(如 `https://api.example.com/openapi.json`),后端自动: +1. 下载并解析 OpenAPI 文档 +2. 提取每个 endpoint 的 method、path、parameters、description +3. 转换为 OpenAI Function Calling Schema 格式 +4. 注册为 Flow 可用的自定义工具 +5. 执行时通过 httpx 动态调用 + +### 2.2 数据模型 + +```python +# models/__init__.py 新增 +class CustomTool(Base): + __tablename__ = "custom_tools" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(100), nullable=False) # 工具名称 + description = Column(Text) # 工具描述 + schema_json = Column(JSON, nullable=False) # OpenAI Function Calling Schema + endpoint_url = Column(String(500), nullable=False) # 基础 URL + method = Column(String(10), default="GET") # HTTP 方法 + path = Column(String(500)) # API 路径 + headers_json = Column(JSON, default=dict) # 固定请求头 + auth_type = Column(String(20), default="none") # none/api_key/oauth + auth_config = Column(JSON, default=dict) # 认证配置 + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) +``` + +### 2.3 Schema + +```python +# schemas/__init__.py 新增 +class CustomToolCreate(BaseModel): + name: str + description: str | None = None + openapi_url: str | None = None # 二选一:URL 或手动配置 + endpoint_url: str | None = None + method: str = "GET" + path: str = "" + headers: dict = {} + auth_type: str = "none" + auth_config: dict = {} + schema_json: dict | None = None # 手动传入 Schema + +class CustomToolOut(BaseModel): + id: uuid.UUID + name: str + description: str | None = None + schema_json: dict + endpoint_url: str + method: str + path: str + auth_type: str + is_active: bool + created_at: datetime | None = None + +class OpenAPIImportRequest(BaseModel): + openapi_url: str + base_url_override: str | None = None # 可选覆盖 base_url +``` + +### 2.4 后端实现 + +#### 模块:`backend/modules/custom_tool/` + +**`parser.py`** — OpenAPI 解析器 +```python +import json +import httpx +from typing import Any + +class OpenAPIParser: + def __init__(self, spec: dict): + self.spec = spec + self.base_url = spec.get("servers", [{}])[0].get("url", "") + + def parse_tools(self) -> list[dict]: + tools = [] + paths = self.spec.get("paths", {}) + for path, methods in paths.items(): + for method, operation in methods.items(): + if method in ("get", "post", "put", "delete", "patch"): + tool = self._parse_endpoint(path, method, operation) + if tool: + tools.append(tool) + return tools + + def _parse_endpoint(self, path: str, method: str, operation: dict) -> dict | None: + name = operation.get("operationId", f"{method}_{path.replace('/', '_').strip('_')}") + description = operation.get("summary", operation.get("description", f"{method.upper()} {path}")) + parameters = self._parse_parameters(operation) + return { + "name": name, + "description": description, + "parameters": { + "type": "object", + "properties": parameters, + "required": [p["name"] for p in operation.get("parameters", []) if p.get("required")], + }, + "path": path, + "method": method.upper(), + } + + def _parse_parameters(self, operation: dict) -> dict[str, Any]: + props = {} + for param in operation.get("parameters", []): + schema = param.get("schema", {}) + props[param["name"]] = { + "type": schema.get("type", "string"), + "description": param.get("description", ""), + } + if "enum" in schema: + props[param["name"]]["enum"] = schema["enum"] + # requestBody + body = operation.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) + if body: + for name, prop in body.get("properties", {}).items(): + props[name] = {"type": prop.get("type", "string"), "description": prop.get("description", "")} + return props +``` + +**`executor.py`** — 动态执行器 +```python +import httpx +import json +from typing import Any + +class CustomToolExecutor: + def __init__(self, tool: dict): + self.tool = tool + self.endpoint_url = tool["endpoint_url"] + self.method = tool["method"] + self.path = tool["path"] + self.headers = tool.get("headers_json", {}) + self.auth_type = tool.get("auth_type", "none") + self.auth_config = tool.get("auth_config", {}) + + async def execute(self, params: dict) -> str: + url = f"{self.endpoint_url.rstrip('/')}/{self.path.lstrip('/')}" + headers = dict(self.headers) + + if self.auth_type == "api_key": + key = self.auth_config.get("key", "") + loc = self.auth_config.get("location", "header") # header / query + name = self.auth_config.get("name", "X-API-Key") + if loc == "header": + headers[name] = key + else: + params[name] = key + + elif self.auth_type == "bearer": + headers["Authorization"] = f"Bearer {self.auth_config.get('token', '')}" + + async with httpx.AsyncClient(timeout=30) as client: + if self.method == "GET": + resp = await client.get(url, params=params, headers=headers) + else: + resp = await client.request(self.method, url, json=params, headers=headers) + + try: + data = resp.json() + return json.dumps(data, ensure_ascii=False, indent=2)[:2000] + except: + return resp.text[:2000] +``` + +**`router.py`** — CRUD + 导入端点 +```python +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from database import get_db +from models import CustomTool +from schemas import CustomToolCreate, CustomToolOut, OpenAPIImportRequest +from .parser import OpenAPIParser +from .executor import CustomToolExecutor +import httpx + +router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"]) + +@router.post("/import-openapi") +async def import_openapi(req: OpenAPIImportRequest, db: AsyncSession = Depends(get_db)): + async with httpx.AsyncClient() as client: + resp = await client.get(req.openapi_url, timeout=30) + spec = resp.json() + + parser = OpenAPIParser(spec) + tools = parser.parse_tools() + base_url = req.base_url_override or parser.base_url + + created = [] + for t in tools: + tool = CustomTool( + name=t["name"], + description=t["description"], + schema_json=t["parameters"], + endpoint_url=base_url, + method=t["method"], + path=t["path"], + ) + db.add(tool) + created.append(t["name"]) + await db.flush() + return {"code": 200, "message": f"成功导入 {len(created)} 个工具", "data": {"tools": created}} + +@router.post("/", response_model=CustomToolOut) +async def create_custom_tool(req: CustomToolCreate, db: AsyncSession = Depends(get_db)): + tool = CustomTool( + name=req.name, + description=req.description, + schema_json=req.schema_json or {}, + endpoint_url=req.endpoint_url or "", + method=req.method, + path=req.path, + headers_json=req.headers, + auth_type=req.auth_type, + auth_config=req.auth_config, + ) + db.add(tool) + await db.flush() + return tool + +@router.get("/", response_model=list[CustomToolOut]) +async def list_custom_tools(db: AsyncSession = Depends(get_db)): + result = await db.execute(select(CustomTool).where(CustomTool.is_active == True)) + return result.scalars().all() + +@router.post("/{tool_id}/test") +async def test_custom_tool(tool_id: uuid.UUID, params: dict, db: AsyncSession = Depends(get_db)): + tool = await db.get(CustomTool, tool_id) + if not tool: + raise HTTPException(404, "工具不存在") + executor = CustomToolExecutor({ + "endpoint_url": tool.endpoint_url, + "method": tool.method, + "path": tool.path, + "headers_json": tool.headers_json, + "auth_type": tool.auth_type, + "auth_config": tool.auth_config, + }) + result = await executor.execute(params) + return {"code": 200, "data": {"result": result}} +``` + +### 2.5 ToolNodeAgent 集成 + +修改 `engine.py` 中 `_init_registry`,加载 CustomTool: +```python +@classmethod +def _init_registry(cls): + if cls._TOOL_REGISTRY: + return + # ... 原有内置工具注册 ... + # 加载自定义工具 + try: + from sqlalchemy import select + from database import SessionLocal + from models import CustomTool + # 注:这里需要用 sync session 或改为异步初始化 + except ImportError: + pass +``` + +**更优方案**:在 FlowEngine 初始化时异步加载自定义工具: +```python +async def _load_custom_tools(self, db: AsyncSession): + from models import CustomTool + result = await db.execute(select(CustomTool).where(CustomTool.is_active == True)) + for tool in result.scalars().all(): + ToolNodeAgent._TOOL_REGISTRY[tool.name] = lambda params, t=tool: CustomToolExecutor(t).execute(params) + ToolNodeAgent._TOOL_SCHEMAS[tool.name] = tool.schema_json +``` + +### 2.6 前端页面 + +**`frontend/src/views/tools/CustomToolManager.vue`** +- 表格:列出所有自定义工具(名称、方法、路径、认证方式) +- 导入按钮:弹出对话框,输入 OpenAPI URL → 点击导入 +- 测试按钮:填入参数 → 调用测试端点 → 显示结果 +- 手动创建:表单填写 name/endpoint/method/path/schema + +**`frontend/src/views/flow/node-configs/ToolConfig.vue`** 增强: +- 工具选择下拉框增加"自定义工具"分组 +- 选择自定义工具后,根据 schema_json 动态生成参数表单 + +--- + +## 三、功能 2:前端 EventSource 流式聊天组件 + +### 3.1 目标 +创建一个独立的聊天页面/组件,支持: +1. 通过 EventSource 连接后端 SSE 端点 +2. 实时显示 `workflow_started` → `node_started` → `text_chunk` → `workflow_finished` 事件 +3. 支持选择已发布的 Flow +4. 显示节点执行进度和中间结果 +5. 支持多轮对话(session_id 持久化) + +### 3.2 组件设计 + +**`frontend/src/views/chat/FlowChat.vue`** +```vue + + + +``` + +### 3.3 API 封装 + +**`frontend/src/api/index.ts`** 新增: +```typescript +// 流式执行用 fetch 而非 axios +executeFlowStream: (id: string, data: any) => { + return fetch(`/api/flow/definitions/${id}/stream`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(data), + }) +}, + +getPublishedFlows: () => api.get('/flow/market'), +``` + +### 3.4 路由注册 + +**`frontend/src/router/index.ts`** 新增: +```typescript +{ + path: '/chat/flow', + name: 'FlowChat', + component: () => import('@/views/chat/FlowChat.vue'), + meta: { title: '流式对话', requiresAuth: true }, +}, +``` + +--- + +## 四、实施计划 + +| 阶段 | 任务 | 预计工时 | 优先级 | +|------|------|---------|--------| +| 1 | 自定义工具数据模型 + Schema | 2h | P0 | +| 2 | OpenAPI 解析器 (parser.py) | 4h | P0 | +| 3 | 自定义工具执行器 (executor.py) | 3h | P0 | +| 4 | 自定义工具 CRUD 路由 | 3h | P0 | +| 5 | ToolNodeAgent 集成自定义工具 | 2h | P0 | +| 6 | 前端 CustomToolManager 页面 | 4h | P1 | +| 7 | ToolConfig.vue 动态参数表单 | 3h | P1 | +| 8 | 前端 FlowChat.vue EventSource 组件 | 5h | P1 | +| 9 | 前端路由 + API 封装 | 2h | P1 | +| 10 | 集成测试 | 4h | P2 | + +**总计:约 32 工时(4 天)** + +--- + +## 五、验收标准 + +### 自定义 API 工具 +- [ ] 输入 `https://petstore.swagger.io/v2/swagger.json` 可成功导入所有 endpoint +- [ ] 导入的工具出现在 ToolConfig.vue 下拉框中 +- [ ] 选择自定义工具后,参数表单根据 Schema 动态生成 +- [ ] Flow 执行时,自定义工具通过 httpx 正确调用并返回结果 +- [ ] 支持 API Key / Bearer Token 认证 + +### 前端 EventSource 聊天 +- [ ] 打开 `/chat/flow` 页面,选择已发布 Flow +- [ ] 流式模式下,输入消息后实时看到文字逐字出现 +- [ ] 阻塞模式下,输入消息后等待完整结果一次性显示 +- [ ] 显示节点执行详情(折叠面板) +- [ ] 刷新页面后 session_id 保留,支持多轮对话 +- [ ] 切换 Flow 后 session_id 重置 diff --git a/backend/agentscope_integration/factory.py b/backend/agentscope_integration/factory.py index 96d5f12..aa64594 100644 --- a/backend/agentscope_integration/factory.py +++ b/backend/agentscope_integration/factory.py @@ -13,6 +13,7 @@ class AgentFactory: _model: OpenAIChatModel | None = None _formatter: OpenAIChatFormatter | None = None _agent_cache: dict[str, AgentBase] = {} + _MAX_CACHE_SIZE = 50 @classmethod def _get_model(cls) -> OpenAIChatModel: @@ -57,6 +58,9 @@ class AgentFactory: else: agent = await cls._create_employee_agent(user_id, user_name, department_id, model, formatter) + if len(cls._agent_cache) >= cls._MAX_CACHE_SIZE: + oldest_key = next(iter(cls._agent_cache)) + del cls._agent_cache[oldest_key] cls._agent_cache[cache_key] = agent return agent diff --git a/backend/agentscope_integration/tools/manager_tools.py b/backend/agentscope_integration/tools/manager_tools.py index 52bf7f3..5166d7e 100644 --- a/backend/agentscope_integration/tools/manager_tools.py +++ b/backend/agentscope_integration/tools/manager_tools.py @@ -1,6 +1,9 @@ import httpx import logging import os +import jwt +import time +from config import settings logger = logging.getLogger(__name__) @@ -15,37 +18,63 @@ def _get_client() -> httpx.Client: return _client -def _get_token() -> str | None: +def _get_service_token() -> str | None: try: - resp = _get_client().post( - f"{_INTERNAL_BASE}/auth/login", - json={"username": "admin", "password": "admin123"}, - ) - data = resp.json() - return data.get("access_token") + payload = {"sub": "system_tool", "exp": int(time.time()) + 3600, "type": "service"} + token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256") + return token except Exception: return None def _headers(token: str | None = None) -> dict: - t = token or _get_token() + t = token or _get_service_token() return {"Authorization": f"Bearer {t}"} if t else {} +SCHEMAS = { + "list_subordinates": { + "name": "list_subordinates", + "description": "查询当前用户的下属员工列表", + "parameters": {"type": "object", "properties": {}} + }, + "get_employee_dashboard": { + "name": "get_employee_dashboard", + "description": "查询指定员工的工作看板数据", + "parameters": { + "type": "object", + "properties": {"employee_id": {"type": "string", "description": "员工ID"}}, + "required": ["employee_id"] + } + }, + "generate_efficiency_report": { + "name": "generate_efficiency_report", + "description": "生成团队效率分析报告", + "parameters": { + "type": "object", + "properties": {"department_id": {"type": "string", "description": "部门ID(可选)"}} + } + }, + "get_task_statistics": { + "name": "get_task_statistics", + "description": "查询任务统计数据", + "parameters": { + "type": "object", + "properties": {"employee_id": {"type": "string", "description": "员工ID(可选)"}} + } + }, +} + + def list_subordinates() -> str: try: resp = _get_client().get(f"{_INTERNAL_BASE}/org/subordinates", headers=_headers()) users = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) if not users: return "当前没有下属员工数据。" - lines = ["下属员工列表:"] for u in users: - lines.append( - f"- {u.get('display_name', u.get('username', '?'))} " - f"| 岗位: {u.get('position', '?')} " - f"| 部门: {u.get('department_name', '?')}" - ) + lines.append(f"- {u.get('display_name', u.get('username', '?'))} | 岗位: {u.get('position', '?')} | 部门: {u.get('department_name', '?')}") return "\n".join(lines) except Exception as e: return f"查询下属列表失败: {e}" @@ -53,18 +82,9 @@ def list_subordinates() -> str: def get_employee_dashboard(employee_id: str) -> str: try: - resp = _get_client().get( - f"{_INTERNAL_BASE}/monitor/employee/{employee_id}/dashboard", - headers=_headers(), - ) + resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employee/{employee_id}/dashboard", headers=_headers()) data = resp.json() - return ( - f"员工 {employee_id[:8]} 工作看板:\n" - f"- 任务完成率: {data.get('completion_rate', '?')}%\n" - f"- 平均响应时间: {data.get('avg_response_time', '?')} 分钟\n" - f"- 今日任务数: {data.get('today_tasks', 0)}\n" - f"- 本周完成: {data.get('weekly_completed', 0)}" - ) + return f"员工 {employee_id[:8]} 工作看板:\n- 任务完成率: {data.get('completion_rate', '?')}%\n- 平均响应时间: {data.get('avg_response_time', '?')} 分钟\n- 今日任务数: {data.get('today_tasks', 0)}\n- 本周完成: {data.get('weekly_completed', 0)}" except Exception as e: return f"查询员工看板失败: {e}" @@ -73,7 +93,6 @@ def generate_efficiency_report(department_id: str | None = None) -> str: try: resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employees", headers=_headers()) employees = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) - report = ["=== 团队效率报告 ===\n"] total_tasks = 0 active_employees = 0 @@ -82,14 +101,8 @@ def generate_efficiency_report(department_id: str | None = None) -> str: total_tasks += task_count if emp.get("status") == "active": active_employees += 1 - report.append( - f"- {emp.get('display_name', emp.get('username', '?'))}: " - f"任务数={task_count}, 完成率={emp.get('completion_rate', 0)}%" - ) - - report.append( - f"\n总结: 活跃员工 {active_employees}/{len(employees)} 人, 总任务 {total_tasks} 个" - ) + report.append(f"- {emp.get('display_name', emp.get('username', '?'))}: 任务数={task_count}, 完成率={emp.get('completion_rate', 0)}%") + report.append(f"\n总结: 活跃员工 {active_employees}/{len(employees)} 人, 总任务 {total_tasks} 个") return "\n".join(report) except Exception as e: return f"生成报告失败: {e}" @@ -99,23 +112,14 @@ def get_task_statistics(employee_id: str | None = None) -> str: try: resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers()) tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) - if employee_id: tasks = [t for t in tasks if t.get("assignee_id") == employee_id] - todo = sum(1 for t in tasks if t.get("status") == "todo") in_progress = sum(1 for t in tasks if t.get("status") == "in_progress") done = sum(1 for t in tasks if t.get("status") == "done") - - return ( - f"任务统计:\n" - f"- 待办: {todo}\n" - f"- 进行中: {in_progress}\n" - f"- 已完成: {done}\n" - f"- 总计: {len(tasks)}" - ) + return f"任务统计:\n- 待办: {todo}\n- 进行中: {in_progress}\n- 已完成: {done}\n- 总计: {len(tasks)}" except Exception as e: return f"查询任务统计失败: {e}" -__all__ = ["list_subordinates", "get_employee_dashboard", "generate_efficiency_report", "get_task_statistics"] \ No newline at end of file +__all__ = ["list_subordinates", "get_employee_dashboard", "generate_efficiency_report", "get_task_statistics", "SCHEMAS"] \ No newline at end of file diff --git a/backend/agentscope_integration/tools/task_tools.py b/backend/agentscope_integration/tools/task_tools.py index 37315c8..1e75735 100644 --- a/backend/agentscope_integration/tools/task_tools.py +++ b/backend/agentscope_integration/tools/task_tools.py @@ -1,6 +1,9 @@ import httpx import logging import os +import jwt +import time +from config import settings logger = logging.getLogger(__name__) @@ -15,24 +18,84 @@ def _get_client() -> httpx.Client: return _client -def _get_token() -> str | None: - from config import settings +def _get_service_token() -> str | None: try: - resp = _get_client().post( - f"{_INTERNAL_BASE}/auth/login", - json={"username": "admin", "password": "admin123"}, - ) - data = resp.json() - return data.get("access_token") + payload = { + "sub": "system_tool", + "exp": int(time.time()) + 3600, + "type": "service", + } + token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256") + return token except Exception: return None def _headers(token: str | None = None) -> dict: - t = token or _get_token() + t = token or _get_service_token() return {"Authorization": f"Bearer {t}"} if t else {} +SCHEMAS = { + "list_tasks": { + "name": "list_tasks", + "description": "查询任务列表,可选按状态筛选", + "parameters": { + "type": "object", + "properties": { + "status": {"type": "string", "description": "任务状态筛选", "enum": ["todo", "in_progress", "done"]} + } + } + }, + "create_task": { + "name": "create_task", + "description": "创建新任务", + "parameters": { + "type": "object", + "properties": { + "title": {"type": "string", "description": "任务标题"}, + "description": {"type": "string", "description": "任务描述"}, + "assignee_id": {"type": "string", "description": "负责人ID"}, + "priority": {"type": "string", "description": "优先级", "enum": ["low", "medium", "high", "urgent"]}, + "deadline": {"type": "string", "description": "截止日期"} + }, + "required": ["title"] + } + }, + "get_task": { + "name": "get_task", + "description": "查询指定任务详情", + "parameters": { + "type": "object", + "properties": {"task_id": {"type": "string", "description": "任务ID"}}, + "required": ["task_id"] + } + }, + "update_task": { + "name": "update_task", + "description": "更新任务状态或描述", + "parameters": { + "type": "object", + "properties": { + "task_id": {"type": "string", "description": "任务ID"}, + "status": {"type": "string", "description": "新状态", "enum": ["todo", "in_progress", "done"]}, + "description": {"type": "string", "description": "新描述"} + }, + "required": ["task_id"] + } + }, + "push_task_to_wecom": { + "name": "push_task_to_wecom", + "description": "将任务推送到企业微信", + "parameters": { + "type": "object", + "properties": {"task_id": {"type": "string", "description": "任务ID"}}, + "required": ["task_id"] + } + }, +} + + def list_tasks(status: str | None = None) -> str: try: resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers()) @@ -56,13 +119,7 @@ def list_tasks(status: str | None = None) -> str: def create_task(title: str, description: str = "", assignee_id: str = "", priority: str = "medium", deadline: str | None = None) -> str: try: - body = { - "title": title, - "description": description, - "assignee_id": assignee_id, - "priority": priority, - "deadline": deadline, - } + body = {"title": title, "description": description, "assignee_id": assignee_id, "priority": priority, "deadline": deadline} resp = _get_client().post(f"{_INTERNAL_BASE}/tasks", json=body, headers=_headers()) task = resp.json() return f"任务创建成功: {task.get('title', title)} (ID: {task.get('id', '?')[:8]})" @@ -74,12 +131,7 @@ def get_task(task_id: str) -> str: try: resp = _get_client().get(f"{_INTERNAL_BASE}/tasks/{task_id}", headers=_headers()) t = resp.json() - return ( - f"任务: {t.get('title', '?')}\n" - f"描述: {t.get('description', '无')}\n" - f"负责人: {t.get('assignee_name', t.get('assignee_id', '无人'))}\n" - f"状态: {t.get('status', '?')} | 优先级: {t.get('priority', '?')} | 截止: {t.get('deadline', '无')}" - ) + return f"任务: {t.get('title', '?')}\n描述: {t.get('description', '无')}\n负责人: {t.get('assignee_name', t.get('assignee_id', '无人'))}\n状态: {t.get('status', '?')} | 优先级: {t.get('priority', '?')} | 截止: {t.get('deadline', '无')}" except Exception as e: return f"查询任务失败: {e}" @@ -100,10 +152,9 @@ def update_task(task_id: str, status: str | None = None, description: str | None def push_task_to_wecom(task_id: str) -> str: try: resp = _get_client().post(f"{_INTERNAL_BASE}/tasks/{task_id}/push", headers=_headers()) - data = resp.json() if hasattr(resp, 'json') else resp return f"任务 {task_id[:8]} 已推送至企业微信" except Exception as e: return f"推送任务失败: {e}" -__all__ = ["list_tasks", "create_task", "get_task", "update_task", "push_task_to_wecom"] \ No newline at end of file +__all__ = ["list_tasks", "create_task", "get_task", "update_task", "push_task_to_wecom", "SCHEMAS"] \ No newline at end of file diff --git a/backend/database.py b/backend/database.py index b5bde8b..93675b1 100644 --- a/backend/database.py +++ b/backend/database.py @@ -1,5 +1,6 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import text from config import settings @@ -28,6 +29,18 @@ async def init_db(): from models import Base as MBase await conn.run_sync(MBase.metadata.create_all) + await _run_migrations() + + +async def _run_migrations(): + async with async_engine.begin() as conn: + await conn.execute(text( + "ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_version_id UUID REFERENCES flow_versions(id)" + )) + await conn.execute(text( + "ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS draft_definition_json JSONB" + )) + async def get_db(): async with AsyncSessionLocal() as session: diff --git a/backend/main.py b/backend/main.py index cd69cf3..ce905cf 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,11 +10,15 @@ from modules.task.router import router as task_router from modules.monitor.router import router as monitor_router from modules.mcp_registry.router import router as mcp_router from modules.flow_engine.router import router as flow_router +from modules.flow_engine.gateway import gateway_router from modules.audit.router import router as audit_router from modules.document.router import router as document_router from modules.notification.router import router as notification_router from modules.system.router import router as system_router from modules.rag.router import router as rag_router +from modules.chat.router import router as chat_router +from modules.custom_tool.router import router as custom_tool_router +from websocket_manager import ws_manager from middleware.rbac_middleware import rbac_middleware from middleware.rate_limiter import rate_limit_middleware from middleware.cache_manager import cache_manager @@ -49,8 +53,11 @@ app.include_router(task_router) app.include_router(monitor_router) app.include_router(mcp_router) app.include_router(flow_router) +app.include_router(gateway_router) app.include_router(audit_router) app.include_router(document_router) app.include_router(notification_router) app.include_router(system_router) -app.include_router(rag_router) \ No newline at end of file +app.include_router(rag_router) +app.include_router(chat_router) +app.include_router(custom_tool_router) \ No newline at end of file diff --git a/backend/middleware/apikey_auth.py b/backend/middleware/apikey_auth.py new file mode 100644 index 0000000..911e545 --- /dev/null +++ b/backend/middleware/apikey_auth.py @@ -0,0 +1,43 @@ +import hashlib +from datetime import datetime +from fastapi import Request, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from models import FlowApiKey +from database import get_db + + +async def authenticate_api_key(request: Request) -> dict: + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise HTTPException(401, "缺少认证信息") + + raw_key = auth_header[7:] + if not raw_key.startswith("flow-"): + raise HTTPException(401, "无效的API Key格式") + + key_hash = hashlib.sha256(raw_key.encode()).hexdigest() + + db_gen = get_db() + db: AsyncSession = await db_gen.__anext__() + try: + result = await db.execute( + select(FlowApiKey).where(FlowApiKey.key_hash == key_hash) + ) + api_key = result.scalar_one_or_none() + if not api_key: + raise HTTPException(401, "API Key无效或已删除") + + api_key.last_used_at = datetime.utcnow() + await db.flush() + + return { + "flow_id": str(api_key.flow_id), + "api_key_id": str(api_key.id), + "auth_type": "api_key", + } + finally: + try: + await db_gen.__anext__() + except StopAsyncIteration: + pass diff --git a/backend/middleware/cache_manager.py b/backend/middleware/cache_manager.py index a2675ef..6fdf69f 100644 --- a/backend/middleware/cache_manager.py +++ b/backend/middleware/cache_manager.py @@ -1,4 +1,5 @@ import json +import time import asyncio from typing import Any from redis.asyncio import Redis @@ -84,7 +85,4 @@ class CacheManager: del self._local[k] -cache_manager = CacheManager() - - -import time # noqa: E402 \ No newline at end of file +cache_manager = CacheManager() \ No newline at end of file diff --git a/backend/middleware/rate_limiter.py b/backend/middleware/rate_limiter.py index 8b3d90d..fe81df5 100644 --- a/backend/middleware/rate_limiter.py +++ b/backend/middleware/rate_limiter.py @@ -6,9 +6,21 @@ from config import settings class RateLimiter: + MAX_KEYS = 10000 + def __init__(self): self._buckets: dict[str, list[float]] = defaultdict(list) self._lock = asyncio.Lock() + self._last_cleanup = time.time() + + async def _cleanup(self): + now = time.time() + if now - self._last_cleanup < 60: + return + self._last_cleanup = now + expired_keys = [k for k, v in self._buckets.items() if not v or now - v[-1] > 120] + for k in expired_keys: + del self._buckets[k] async def check(self, key: str) -> bool: now = time.time() @@ -16,6 +28,7 @@ class RateLimiter: window = 60.0 async with self._lock: + await self._cleanup() bucket = self._buckets[key] bucket = [t for t in bucket if now - t < window] self._buckets[key] = bucket @@ -24,6 +37,12 @@ class RateLimiter: return False bucket.append(now) + + if len(self._buckets) > self.MAX_KEYS: + oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2] + for k in oldest_keys: + del self._buckets[k] + return True async def remaining(self, key: str) -> int: diff --git a/backend/middleware/rbac_middleware.py b/backend/middleware/rbac_middleware.py index 85cb715..69e69ad 100644 --- a/backend/middleware/rbac_middleware.py +++ b/backend/middleware/rbac_middleware.py @@ -60,7 +60,9 @@ async def rbac_middleware(request: Request, call_next): "permissions": unique_perms, "is_root": is_root, "data_scope": "all" if is_root or "all" in data_scopes else ( - "subordinate_only" if "subordinate_only" in data_scopes else "self_only" + "department" if "department" in data_scopes else + "subordinate_only" if "subordinate_only" in data_scopes else + "self_only" ), } diff --git a/backend/models/__init__.py b/backend/models/__init__.py index 20cf5e5..dffc3d7 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from sqlalchemy import Column, String, DateTime, ForeignKey, Integer, Boolean, JSON, Text +from sqlalchemy import Column, String, DateTime, ForeignKey, Integer, Boolean, JSON, Text, Float from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from database import Base @@ -138,8 +138,59 @@ class FlowDefinition(Base): version = Column(Integer, default=1) status = Column(String(20), default="draft") definition_json = Column(JSON, nullable=False, default=dict) + published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True) + draft_definition_json = Column(JSON, nullable=True, default=None) creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) published_to_wecom = Column(Boolean, default=False) + published_to_web = Column(Boolean, default=False) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + published_version = relationship("FlowVersion", foreign_keys=[published_version_id], post_update=True) + + +class FlowVersion(Base): + __tablename__ = "flow_versions" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) + version = Column(Integer, nullable=False) + definition_json = Column(JSON, nullable=False, default=dict) + changelog = Column(Text, default="") + published_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) + published_to_wecom = Column(Boolean, default=False) + published_to_web = Column(Boolean, default=False) + created_at = Column(DateTime, default=datetime.utcnow) + + +class FlowApiKey(Base): + __tablename__ = "flow_api_keys" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) + name = Column(String(100), nullable=False) + key_hash = Column(String(64), nullable=False) + key_prefix = Column(String(10), nullable=False) + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) + last_used_at = Column(DateTime, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + + +class CustomTool(Base): + __tablename__ = "custom_tools" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(100), nullable=False) + description = Column(Text) + schema_json = Column(JSON, nullable=False, default=dict) + endpoint_url = Column(String(500), nullable=False) + method = Column(String(10), default="GET") + path = Column(String(500), default="") + headers_json = Column(JSON, default=dict) + auth_type = Column(String(20), default="none") + auth_config = Column(JSON, default=dict) + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) + is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) @@ -149,11 +200,15 @@ class FlowExecution(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE")) + version = Column(Integer, nullable=True) trigger_type = Column(String(50)) trigger_user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) input_data = Column(JSON) output_data = Column(JSON) status = Column(String(20), default="running") + token_usage = Column(JSON, default=dict) + latency_ms = Column(Integer, nullable=True) + error_message = Column(Text, nullable=True) started_at = Column(DateTime, default=datetime.utcnow) finished_at = Column(DateTime) @@ -206,7 +261,7 @@ class AgentConfig(Base): description = Column(String(500)) system_prompt = Column(Text, default="") model = Column(String(50), default="gpt-4o-mini") - temperature = Column(Integer, default=7) + temperature = Column(Float, default=0.7) tools = Column(JSON, default=list) status = Column(String(20), default="active") creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) diff --git a/backend/modules/agent_manager/router.py b/backend/modules/agent_manager/router.py index c047491..8764a75 100644 --- a/backend/modules/agent_manager/router.py +++ b/backend/modules/agent_manager/router.py @@ -64,6 +64,7 @@ async def agent_chat( role="assistant", content=reply_text, ) db.add(ai_msg) + await db.flush() return { "code": 200, @@ -89,7 +90,7 @@ async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)): "description": a.description, "system_prompt": a.system_prompt, "model": a.model, - "temperature": float(a.temperature) / 10.0, + "temperature": a.temperature if isinstance(a.temperature, float) else float(a.temperature) / 10.0, "tools": a.tools or [], "status": a.status, } for a in agents], @@ -104,7 +105,7 @@ async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSessio description=req.description, system_prompt=req.system_prompt, model=req.model, - temperature=int(req.temperature * 10), + temperature=req.temperature, tools=req.tools, creator_id=uuid.UUID(user_ctx["id"]), ) @@ -113,7 +114,7 @@ async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSessio return AgentConfigOut( id=agent.id, name=agent.name, description=agent.description, system_prompt=agent.system_prompt, model=agent.model, - temperature=float(agent.temperature) / 10.0, + temperature=agent.temperature if isinstance(agent.temperature, float) else float(agent.temperature) / 10.0, tools=agent.tools or [], status=agent.status, creator_id=agent.creator_id, created_at=agent.created_at, updated_at=agent.updated_at, @@ -129,7 +130,7 @@ async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = De return AgentConfigOut( id=agent.id, name=agent.name, description=agent.description, system_prompt=agent.system_prompt, model=agent.model, - temperature=float(agent.temperature) / 10.0, + temperature=agent.temperature if isinstance(agent.temperature, float) else float(agent.temperature) / 10.0, tools=agent.tools or [], status=agent.status, creator_id=agent.creator_id, created_at=agent.created_at, updated_at=agent.updated_at, @@ -151,7 +152,7 @@ async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Req if req.model is not None: agent.model = req.model if req.temperature is not None: - agent.temperature = int(req.temperature * 10) + agent.temperature = req.temperature if req.tools is not None: agent.tools = req.tools if req.status is not None: @@ -160,7 +161,7 @@ async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Req return AgentConfigOut( id=agent.id, name=agent.name, description=agent.description, system_prompt=agent.system_prompt, model=agent.model, - temperature=float(agent.temperature) / 10.0, + temperature=agent.temperature if isinstance(agent.temperature, float) else float(agent.temperature) / 10.0, tools=agent.tools or [], status=agent.status, creator_id=agent.creator_id, created_at=agent.created_at, updated_at=agent.updated_at, diff --git a/backend/modules/auth/router.py b/backend/modules/auth/router.py index 7b60497..af9518c 100644 --- a/backend/modules/auth/router.py +++ b/backend/modules/auth/router.py @@ -1,4 +1,5 @@ import uuid +import secrets from datetime import datetime, timedelta import jwt from fastapi import APIRouter, Depends, HTTPException, Request @@ -11,6 +12,9 @@ from models import User, UserRole, Role, RolePermission, Permission from schemas import LoginRequest, TokenResponse, UserOut, RoleOut from config import settings +_oauth_states: dict[str, float] = {} +_OAUTH_STATE_TTL = 600 + def hash_password(password: str) -> str: return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') @@ -106,8 +110,14 @@ async def get_wecom_oauth_url(request: Request): return {"code": 400, "message": "请先配置 WECOM_CORP_ID"} base_url = str(request.base_url).rstrip("/") redirect_uri = f"{base_url}/api/auth/wecom/callback" - url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={corp_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_base&state=STATE#wechat_redirect" - return {"code": 200, "data": {"url": url}} + state = secrets.token_urlsafe(32) + import time + _oauth_states[state] = time.time() + expired = [k for k, v in _oauth_states.items() if time.time() - v > _OAUTH_STATE_TTL] + for k in expired: + del _oauth_states[k] + url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={corp_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_base&state={state}#wechat_redirect" + return {"code": 200, "data": {"url": url, "state": state}} @router.put("/me") diff --git a/backend/modules/chat/__init__.py b/backend/modules/chat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/modules/chat/router.py b/backend/modules/chat/router.py new file mode 100644 index 0000000..864d11f --- /dev/null +++ b/backend/modules/chat/router.py @@ -0,0 +1,99 @@ +from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from database import get_db +from models import FlowDefinition, FlowVersion +from modules.flow_engine.engine import FlowEngine +from agentscope.message import Msg +from websocket_manager import ws_manager + +router = APIRouter(prefix="/api/chat", tags=["chat"]) + + +@router.websocket("/ws") +async def chat_websocket(websocket: WebSocket): + user_id = websocket.query_params.get("user_id", "anonymous") + await ws_manager.connect(websocket, user_id) + try: + while True: + data = await websocket.receive_text() + await ws_manager.send_to_user(user_id, {"type": "echo", "message": data}) + except WebSocketDisconnect: + ws_manager.disconnect(websocket, user_id) + + +@router.post("/message/{flow_id}") +async def chat_message( + flow_id: str, + request: Request, + payload: dict, + db: AsyncSession = Depends(get_db), +): + try: + import uuid as _uuid + fid = _uuid.UUID(flow_id) + except ValueError: + raise HTTPException(400, "无效的流ID") + + flow_result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == fid)) + flow = flow_result.scalar_one_or_none() + if not flow or flow.status != "published": + raise HTTPException(404, "流不存在或未发布") + + definition = flow.definition_json + if flow.published_version_id: + ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == flow.published_version_id)) + published = ver_result.scalar_one_or_none() + if published and published.definition_json: + import json + definition = json.loads(json.dumps(published.definition_json)) + + user_ctx = request.state.user + input_text = payload.get("message", payload.get("query", "")) + if not input_text: + raise HTTPException(400, "请输入消息内容") + + engine = FlowEngine(definition) + input_msg = Msg(name="user", content=input_text, role="user") + context = { + "user_id": user_ctx.get("id", "web_user"), + "username": user_ctx.get("username", "网页访客"), + "trigger_data": {"channel": "web_chat"}, + "_node_results": {}, + "session_id": payload.get("session_id", str(uuid.uuid4())), + } + + try: + result_msg = await engine.execute(input_msg, context) + output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) + + return { + "code": 200, + "data": { + "reply": output_text, + "node_results": context.get("_node_results", {}), + }, + } + except Exception as e: + raise HTTPException(500, f"流执行失败: {str(e)}") + + +@router.get("/flows") +async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(FlowDefinition).where(FlowDefinition.status == "published") + ) + flows = result.scalars().all() + return { + "code": 200, + "data": [ + { + "id": str(f.id), + "name": f.name, + "description": f.description, + "published_to_web": f.published_to_web, + "published_to_wecom": f.published_to_wecom, + } + for f in flows + ], + } \ No newline at end of file diff --git a/backend/modules/custom_tool/__init__.py b/backend/modules/custom_tool/__init__.py new file mode 100644 index 0000000..43a7103 --- /dev/null +++ b/backend/modules/custom_tool/__init__.py @@ -0,0 +1,3 @@ +from .router import router + +__all__ = ["router"] \ No newline at end of file diff --git a/backend/modules/custom_tool/executor.py b/backend/modules/custom_tool/executor.py new file mode 100644 index 0000000..26c036a --- /dev/null +++ b/backend/modules/custom_tool/executor.py @@ -0,0 +1,43 @@ +import httpx +import json + +class CustomToolExecutor: + def __init__(self, tool_def: dict): + self.endpoint_url = tool_def.get("endpoint_url", "") + self.method = tool_def.get("method", "GET") + self.path = tool_def.get("path", "") + self.headers = dict(tool_def.get("headers_json", {})) + self.auth_type = tool_def.get("auth_type", "none") + self.auth_config = dict(tool_def.get("auth_config", {})) + self.timeout = int(tool_def.get("timeout", 30)) + + async def execute(self, params: dict) -> str: + url = f"{self.endpoint_url.rstrip('/')}/{self.path.lstrip('/')}" + headers = dict(self.headers) + req_params = dict(params) + + if self.auth_type == "api_key": + key = self.auth_config.get("key", "") + loc = self.auth_config.get("location", "header") + name = self.auth_config.get("name", "X-API-Key") + if loc == "header": + headers[name] = key + else: + req_params[name] = key + elif self.auth_type == "bearer": + headers["Authorization"] = f"Bearer {self.auth_config.get('token', '')}" + + timeout = httpx.Timeout(self.timeout) + async with httpx.AsyncClient(timeout=timeout) as client: + if self.method == "GET": + resp = await client.get(url, params=req_params, headers=headers) + else: + resp = await client.request( + self.method, url, json=req_params, headers=headers + ) + + try: + data = resp.json() + return json.dumps(data, ensure_ascii=False, indent=2)[:4000] + except Exception: + return resp.text[:4000] \ No newline at end of file diff --git a/backend/modules/custom_tool/parser.py b/backend/modules/custom_tool/parser.py new file mode 100644 index 0000000..4c99c9c --- /dev/null +++ b/backend/modules/custom_tool/parser.py @@ -0,0 +1,81 @@ +import json +from typing import Any + +class OpenAPIParser: + def __init__(self, spec: dict): + self.spec = spec + self.base_url = "" + servers = spec.get("servers", [{}]) + if servers and isinstance(servers, list): + self.base_url = servers[0].get("url", "") + + def parse_tools(self) -> list[dict]: + tools = [] + paths = self.spec.get("paths", {}) + for path, methods in paths.items(): + if not isinstance(methods, dict): + continue + for method, operation in methods.items(): + if method in ("get", "post", "put", "delete", "patch") and isinstance(operation, dict): + tool = self._parse_endpoint(path, method, operation) + if tool: + tools.append(tool) + return tools + + def _parse_endpoint(self, path: str, method: str, operation: dict) -> dict | None: + op_id = operation.get("operationId", "") + if not op_id: + op_id = f"{method}_{path.replace('/', '_').strip('_')}" + + description = operation.get("summary") or operation.get("description") or f"{method.upper()} {path}" + properties = self._parse_parameters(operation) + required = [] + for param in operation.get("parameters", []): + if isinstance(param, dict) and param.get("required"): + required.append(param["name"]) + + return { + "name": op_id, + "description": description, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + "path": path, + "method": method.upper(), + } + + def _parse_parameters(self, operation: dict) -> dict[str, Any]: + props = {} + for param in operation.get("parameters", []): + if not isinstance(param, dict): + continue + pname = param.get("name", "") + if not pname: + continue + schema = param.get("schema", {}) + if not isinstance(schema, dict): + schema = {} + props[pname] = { + "type": schema.get("type", "string"), + "description": param.get("description", ""), + } + if "enum" in schema: + props[pname]["enum"] = schema["enum"] + + body = ( + operation.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) + ) + if isinstance(body, dict): + for name, prop in body.get("properties", {}).items(): + if isinstance(prop, dict): + props[name] = { + "type": prop.get("type", "string"), + "description": prop.get("description", ""), + } + + return props \ No newline at end of file diff --git a/backend/modules/custom_tool/router.py b/backend/modules/custom_tool/router.py new file mode 100644 index 0000000..a9cfb19 --- /dev/null +++ b/backend/modules/custom_tool/router.py @@ -0,0 +1,249 @@ +import uuid +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from database import get_db +from models import CustomTool +from schemas import CustomToolCreate, CustomToolUpdate, CustomToolOut, OpenAPIImportRequest +from modules.custom_tool.parser import OpenAPIParser +from modules.custom_tool.executor import CustomToolExecutor +from modules.flow_engine.engine import ToolNodeAgent +from dependencies import get_current_user +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"]) + + +@router.post("/import-openapi") +async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)): + user_ctx = request.state.user + try: + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.get(req.openapi_url) + resp.raise_for_status() + spec = resp.json() + except httpx.HTTPError as e: + raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") + except ValueError: + raise HTTPException(400, "OpenAPI 文档不是有效的 JSON 格式") + + parser = OpenAPIParser(spec) + tools = parser.parse_tools() + if not tools: + raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具") + + base_url = req.base_url_override or parser.base_url + if not base_url: + raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override") + + created = [] + for t in tools: + existing = await db.execute( + select(CustomTool).where(CustomTool.name == t["name"]) + ) + if existing.scalar_one_or_none(): + continue + + tool = CustomTool( + name=t["name"], + description=t["description"], + schema_json=t["parameters"], + endpoint_url=base_url, + method=t["method"], + path=t["path"], + created_by=uuid.UUID(user_ctx["id"]), + ) + db.add(tool) + created.append(t["name"]) + + ToolNodeAgent.register_custom_tool( + t["name"], + t["parameters"], + { + "endpoint_url": base_url, + "method": t["method"], + "path": t["path"], + "headers_json": {}, + "auth_type": "none", + "auth_config": {}, + "timeout": 30, + }, + ) + + await db.flush() + return {"code": 200, "message": f"成功导入 {len(created)} 个工具", "data": {"tools": created}} + + +@router.post("/", response_model=CustomToolOut) +async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)): + user_ctx = request.state.user + user_id = uuid.UUID(user_ctx["id"]) + + if req.openapi_url: + try: + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.get(req.openapi_url) + resp.raise_for_status() + spec = resp.json() + except Exception as e: + raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") + + parser = OpenAPIParser(spec) + tools = parser.parse_tools() + + created_tool = None + for t in tools: + if t["name"] == req.name or (not req.name and tools): + existing = await db.execute( + select(CustomTool).where(CustomTool.name == t["name"]) + ) + if existing.scalar_one_or_none(): + continue + tool = CustomTool( + name=t["name"], + description=t["description"], + schema_json=t["parameters"], + endpoint_url=parser.base_url, + method=t["method"], + path=t["path"], + created_by=user_id, + ) + db.add(tool) + created_tool = tool + break + + if not created_tool: + raise HTTPException(400, "未找到匹配的工具") + + await db.flush() + return created_tool + + schema_json = req.schema_json or {} + if not schema_json and req.endpoint_url: + schema_json = { + "type": "object", + "properties": {}, + "description": req.description or "", + } + + tool = CustomTool( + name=req.name, + description=req.description, + schema_json=schema_json, + endpoint_url=req.endpoint_url or "", + method=req.method, + path=req.path, + headers_json=req.headers, + auth_type=req.auth_type, + auth_config=req.auth_config, + created_by=user_id, + ) + db.add(tool) + ToolNodeAgent.register_custom_tool( + req.name, + schema_json, + { + "endpoint_url": req.endpoint_url or "", + "method": req.method, + "path": req.path, + "headers_json": req.headers, + "auth_type": req.auth_type, + "auth_config": req.auth_config, + "timeout": 30, + }, + ) + await db.flush() + return tool + + +@router.get("/", response_model=list[CustomToolOut]) +async def list_custom_tools(db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(CustomTool).where(CustomTool.is_active == True).order_by(CustomTool.updated_at.desc()) + ) + return result.scalars().all() + + +@router.get("/{tool_id}", response_model=CustomToolOut) +async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + tool = await db.get(CustomTool, tool_id) + if not tool: + raise HTTPException(404, "工具不存在") + return tool + + +@router.put("/{tool_id}", response_model=CustomToolOut) +async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)): + tool = await db.get(CustomTool, tool_id) + if not tool: + raise HTTPException(404, "工具不存在") + if req.name is not None: + tool.name = req.name + if req.description is not None: + tool.description = req.description + if req.endpoint_url is not None: + tool.endpoint_url = req.endpoint_url + if req.method is not None: + tool.method = req.method + if req.path is not None: + tool.path = req.path + if req.headers is not None: + tool.headers_json = req.headers + if req.auth_type is not None: + tool.auth_type = req.auth_type + if req.auth_config is not None: + tool.auth_config = req.auth_config + if req.schema_json is not None: + tool.schema_json = req.schema_json + if req.is_active is not None: + tool.is_active = req.is_active + await db.flush() + return tool + + +@router.delete("/{tool_id}") +async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + tool = await db.get(CustomTool, tool_id) + if not tool: + raise HTTPException(404, "工具不存在") + tool.is_active = False + await db.flush() + return {"code": 200, "message": "工具已停用"} + + +@router.post("/{tool_id}/test") +async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)): + tool = await db.get(CustomTool, tool_id) + if not tool: + raise HTTPException(404, "工具不存在") + if params is None: + params = {} + + executor = CustomToolExecutor({ + "endpoint_url": tool.endpoint_url, + "method": tool.method, + "path": tool.path, + "headers_json": tool.headers_json, + "auth_type": tool.auth_type, + "auth_config": tool.auth_config, + }) + try: + result = await executor.execute(params) + return {"code": 200, "data": {"result": result}} + except Exception as e: + raise HTTPException(500, f"工具执行失败: {str(e)}") + + +@router.get("/schemas/all") +async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(CustomTool).where(CustomTool.is_active == True) + ) + tools = result.scalars().all() + schemas = {} + for t in tools: + schemas[t.name] = t.schema_json + return {"code": 200, "data": schemas} \ No newline at end of file diff --git a/backend/modules/flow_engine/engine.py b/backend/modules/flow_engine/engine.py index 0f1bec9..49a9c51 100644 --- a/backend/modules/flow_engine/engine.py +++ b/backend/modules/flow_engine/engine.py @@ -2,6 +2,7 @@ import json import uuid import logging import re +import asyncio from agentscope.agent import AgentBase from agentscope.message import Msg from agentscope.tool import Toolkit @@ -11,7 +12,26 @@ from config import settings logger = logging.getLogger(__name__) +class FlowSessionMemory: + def __init__(self, session_id: str = "", user_id: str = ""): + self.session_id = session_id + self.user_id = user_id + self._messages: list[dict] = [] + + def get_history(self, limit: int = 10) -> list[dict]: + return self._messages[-limit * 2:] + + def add(self, role: str, content: str): + self._messages.append({"role": role, "content": content}) + + def to_list(self) -> list[dict]: + return list(self._messages) + + class FlowEngine: + MAX_TOTAL_ITERATIONS = 200 + FLOW_TIMEOUT_SECONDS = 300 + def __init__(self, flow_definition: dict): self.definition = flow_definition self.nodes: dict[str, dict] = {} @@ -26,20 +46,56 @@ class FlowEngine: if not start_nodes: start_nodes = list(self.nodes.keys())[:1] + session_id = context.get("session_id", str(uuid.uuid4())) + user_id = context.get("user_id", "") + memory = FlowSessionMemory(session_id=session_id, user_id=user_id) + context["_memory"] = memory + visited: set[str] = set() + loop_iterations: dict[str, int] = {} + total_iterations = 0 last_result: Msg | None = None - async def traverse(node_id: str, incoming_msg: Msg) -> None: - nonlocal last_result - if node_id in visited: + async def traverse(node_id: str, incoming_msg: Msg, loop_context: dict = None) -> None: + nonlocal last_result, total_iterations + + total_iterations += 1 + if total_iterations > self.MAX_TOTAL_ITERATIONS: + logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止") + last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system") return - visited.add(node_id) node = self.nodes.get(node_id) if not node: return - agent = await self._get_or_create_agent(node_id, context) + node_type = node.get("type", "") + is_loop = node_type == "loop" + + if not is_loop and node_id in visited: + return + + if not is_loop: + visited.add(node_id) + + if is_loop: + loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1 + config = node.get("config", {}) + max_iter = config.get("max_iterations", 10) + if loop_iterations[node_id] > max_iter: + logger.warning(f"循环节点 {node.get('label', node_id)} 超过最大迭代次数 {max_iter}") + visited.add(node_id) + for target_id, edge_cond in graph.get(node_id, []): + if edge_cond == "loop_done": + await traverse(target_id, incoming_msg) + return + + if is_loop: + loop_node_config = dict(node.get("config", {})) + loop_node_config["_engine_iteration"] = loop_iterations.get(node_id, 0) + agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config) + else: + agent = await self._get_or_create_agent(node_id, context) enriched_content = self._resolve_input_mapping(node, incoming_msg, context) current_msg = incoming_msg if enriched_content.strip(): @@ -50,7 +106,7 @@ class FlowEngine: result = await agent.reply(current_msg) exec_record = { "node_id": node_id, - "node_type": node.get("type"), + "node_type": node_type, "label": node.get("label"), "status": "success", "output": result.get_text_content()[:500] if hasattr(result, 'get_text_content') else str(result)[:500], @@ -58,20 +114,38 @@ class FlowEngine: context.setdefault("_node_results", {})[node_id] = exec_record last_result = result - is_condition = node.get("type") == "condition" + if is_loop: + loop_done = self._check_loop_done(result, node, loop_iterations.get(node_id, 0)) + if loop_done: + visited.add(node_id) + for target_id, edge_cond in graph.get(node_id, []): + if edge_cond == "loop_done": + await traverse(target_id, result) + else: + for target_id, edge_cond in graph.get(node_id, []): + if edge_cond == "loop_body": + body_node = self.nodes.get(target_id) + if body_node: + visited.discard(target_id) + await traverse(target_id, result, {"iteration": loop_iterations[node_id]}) + return + + is_condition = node_type == "condition" cond_result = self._parse_condition_result(result) for target_id, edge_cond in graph.get(node_id, []): if is_condition: if edge_cond and edge_cond == cond_result: await traverse(target_id, result) + elif edge_cond == "loop_body" or edge_cond == "loop_done": + continue else: await traverse(target_id, result) except Exception as e: logger.error(f"节点 {node.get('label', node_id)} 执行失败: {e}") exec_record = { "node_id": node_id, - "node_type": node.get("type"), + "node_type": node_type, "label": node.get("label"), "status": "error", "error": str(e), @@ -81,7 +155,14 @@ class FlowEngine: last_result = error_msg if start_nodes: - await traverse(start_nodes[0], input_msg) + try: + await asyncio.wait_for( + traverse(start_nodes[0], input_msg), + timeout=self.FLOW_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止") + last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system") return last_result or input_msg @@ -111,13 +192,27 @@ class FlowEngine: return m.group(1) return None - async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase: - if node_id in self._agent_cache: + def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool: + config = node.get("config", {}) + loop_type = config.get("loop_type", "fixed") + if loop_type == "fixed": + count = config.get("count", 3) + return iteration >= count + content = result.get_text_content() if hasattr(result, 'get_text_content') else str(result) + if re.search(r'loop:(stop|done|break)', content, re.IGNORECASE): + return True + return False + + async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict = None) -> AgentBase: + if node_id in self._agent_cache and override_config is None: return self._agent_cache[node_id] - node = self.nodes[node_id] + node = dict(self.nodes[node_id]) + if override_config is not None: + node["config"] = override_config agent = await _create_node_agent(node, context) - self._agent_cache[node_id] = agent + if override_config is None: + self._agent_cache[node_id] = agent return agent def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str: @@ -148,29 +243,56 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase: model_config = config.get("model", settings.LLM_MODEL) temperature = config.get("temperature", 0.7) system_prompt = config.get("system_prompt", "你是AI助手。") - return LLMNodeAgent( + max_tokens = config.get("max_tokens", 2000) + stream = config.get("stream", True) + agent = LLMNodeAgent( node_id=node_id, system_prompt=system_prompt, model_name=model_config, temperature=temperature, + max_tokens=max_tokens, + stream=stream, ) + memory = context.get("_memory") + if memory: + agent.set_memory(memory) + return agent elif node_type == "tool": tool_name = config.get("tool_name", "") tool_params = config.get("tool_params", {}) - return ToolNodeAgent(node_id=node_id, tool_name=tool_name, tool_params=tool_params) + timeout = config.get("timeout", 30) + retry_count = config.get("retry_count", 0) + error_handling = config.get("error_handling", "throw") + return ToolNodeAgent( + node_id=node_id, + tool_name=tool_name, + tool_params=tool_params, + timeout=timeout, + retry_count=retry_count, + error_handling=error_handling, + ) elif node_type == "mcp": mcp_server = config.get("mcp_server", "") tool_name = config.get("tool_name", "") - return MCPNodeAgent(node_id=node_id, server_name=mcp_server, tool_name=tool_name) + timeout = config.get("timeout", 30) + error_handling = config.get("error_handling", "throw") + return MCPNodeAgent( + node_id=node_id, + server_name=mcp_server, + tool_name=tool_name, + timeout=timeout, + error_handling=error_handling, + ) - elif node_type == "wecom_notify": - return WeComNotifyAgent(node_id=node_id, config=config) + elif node_type in ("wecom_notify", "notify"): + return NotifyAgent(node_id=node_id, config=config) elif node_type == "condition": condition_expr = config.get("condition", "") - return ConditionNodeAgent(node_id=node_id, condition=condition_expr) + condition_type = config.get("condition_type", "expression") + return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type) elif node_type == "rag": return RAGNodeAgent(node_id=node_id, config=config) @@ -178,6 +300,18 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase: elif node_type == "output": return OutputNodeAgent(node_id=node_id, config=config) + elif node_type == "merge": + return ParallelMergeNodeAgent(node_id=node_id, config=config) + elif node_type == "loop": + return LoopNodeAgent(node_id=node_id, config=config) + + elif node_type == "code": + language = config.get("language", "python") + code = config.get("code", "") + timeout = config.get("timeout", 30) + sandbox = config.get("sandbox", True) + return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox) + else: return PassThroughAgent(node_id) @@ -195,12 +329,18 @@ class PassThroughAgent(AgentBase): class LLMNodeAgent(AgentBase): - def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7): + def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True): super().__init__() self.name = f"LLM_{node_id}" self.system_prompt = system_prompt self.model_name = model_name or settings.LLM_MODEL self.temperature = temperature + self.max_tokens = max_tokens + self.stream = stream + self._memory = None + + def set_memory(self, memory): + self._memory = memory async def reply(self, msg: Msg, **kwargs) -> Msg: from agentscope_integration.factory import AgentFactory @@ -210,10 +350,19 @@ class LLMNodeAgent(AgentBase): user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) formatter = OpenAIChatFormatter() - prompt = await formatter.format([ - Msg("system", self.system_prompt, "system"), - Msg("user", user_text, "user"), - ]) + messages = [Msg("system", self.system_prompt, "system")] + + if self._memory: + history = self._memory.get_history(limit=5) + for h in history: + role = h.get("role", "user") + content = h.get("content", "") + if len(content) > 2000: + content = content[:2000] + messages.append(Msg(role, content, role)) + + messages.append(Msg("user", user_text, "user")) + prompt = formatter.format(messages) try: res = await model(prompt) @@ -228,6 +377,10 @@ class LLMNodeAgent(AgentBase): logger.warning(f"LLM 调用失败: {e}") res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}" + if self._memory: + self._memory.add("user", user_text) + self._memory.add("assistant", res_text) + return Msg(self.name, res_text, "assistant") async def observe(self, msg) -> None: @@ -236,6 +389,8 @@ class LLMNodeAgent(AgentBase): class ToolNodeAgent(AgentBase): _TOOL_REGISTRY: dict[str, callable] = {} + _TOOL_SCHEMAS: dict[str, dict] = {} + _CUSTOM_TOOL_DEFS: dict[str, dict] = {} @classmethod def _init_registry(cls): @@ -262,44 +417,161 @@ class ToolNodeAgent(AgentBase): "get_task_statistics": get_task_statistics, "get_employee_dashboard": get_employee_dashboard, } + + for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]: + try: + mod = __import__(f"agentscope_integration.tools.{mod_name}", fromlist=["SCHEMAS"]) + if hasattr(mod, "SCHEMAS"): + cls._TOOL_SCHEMAS.update(mod.SCHEMAS) + except Exception: + pass except ImportError as e: logger.warning(f"工具注册失败: {e}") - def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None): + @classmethod + async def load_custom_tools(cls, db): + try: + from sqlalchemy import select + from models import CustomTool + from modules.custom_tool.executor import CustomToolExecutor + + result = await db.execute( + select(CustomTool).where(CustomTool.is_active == True) + ) + for tool in result.scalars().all(): + tool_def = { + "endpoint_url": tool.endpoint_url, + "method": tool.method, + "path": tool.path, + "headers_json": tool.headers_json, + "auth_type": tool.auth_type, + "auth_config": tool.auth_config, + "timeout": 30, + } + cls._CUSTOM_TOOL_DEFS[tool.name] = tool_def + cls._TOOL_SCHEMAS[tool.name] = tool.schema_json + except ImportError: + logger.warning("无法加载自定义工具模块") + except Exception as e: + logger.warning(f"加载自定义工具失败: {e}") + + @classmethod + def register_custom_tool(cls, name: str, schema: dict, tool_def: dict): + cls._CUSTOM_TOOL_DEFS[name] = tool_def + cls._TOOL_SCHEMAS[name] = schema + + @classmethod + def get_schemas(cls) -> dict: + cls._init_registry() + return dict(cls._TOOL_SCHEMAS) + + def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30, retry_count: int = 0, error_handling: str = "throw"): super().__init__() self.name = f"Tool_{node_id}" self.tool_name = tool_name self.tool_params = tool_params or {} + self.timeout = timeout + self.retry_count = retry_count + self.error_handling = error_handling async def reply(self, msg: Msg, **kwargs) -> Msg: self._init_registry() user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name) + if custom_def: + return await self._execute_custom_tool(custom_def, user_text) + tool_func = self._TOOL_REGISTRY.get(self.tool_name) - if tool_func: + if not tool_func: + if self.error_handling == "skip": + return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant") + return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant") + + def _call_sync(): try: result = tool_func(**self.tool_params) if self.tool_params else tool_func() - return Msg(self.name, str(result), "assistant") + return result except TypeError: try: result = tool_func(user_text, **self.tool_params) - return Msg(self.name, str(result), "assistant") + return result except Exception as e: - return Msg(self.name, f"[工具执行失败: {e}]", "assistant") + raise RuntimeError(f"工具执行失败: {e}") + except Exception as e: + raise RuntimeError(f"工具执行失败: {e}") + + for attempt in range(self.retry_count + 1): + try: + result = await asyncio.wait_for( + asyncio.to_thread(_call_sync), + timeout=self.timeout, + ) + return Msg(self.name, str(result), "assistant") + except asyncio.TimeoutError: + if attempt < self.retry_count: + logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}") + continue + if self.error_handling == "skip": + return Msg(self.name, f"[工具 {self.tool_name} 超时,已跳过]", "assistant") + return Msg(self.name, f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant") + except RuntimeError as e: + if attempt < self.retry_count: + continue + if self.error_handling == "skip": + return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant") + if self.error_handling == "default": + return Msg(self.name, "{}", "assistant") + return Msg(self.name, f"[{e}]", "assistant") except Exception as e: + if attempt < self.retry_count: + continue + if self.error_handling == "skip": + return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant") return Msg(self.name, f"[工具执行失败: {e}]", "assistant") - return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant") + return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant") + + async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg: + from modules.custom_tool.executor import CustomToolExecutor + + executor = CustomToolExecutor(custom_def) + + for attempt in range(self.retry_count + 1): + try: + result = await asyncio.wait_for( + executor.execute(self.tool_params), + timeout=self.timeout, + ) + return Msg(self.name, str(result), "assistant") + except asyncio.TimeoutError: + if attempt < self.retry_count: + logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}") + continue + if self.error_handling == "skip": + return Msg(self.name, f"[自定义工具 {self.tool_name} 超时,已跳过]", "assistant") + return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant") + except Exception as e: + if attempt < self.retry_count: + continue + if self.error_handling == "skip": + return Msg(self.name, f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]", "assistant") + if self.error_handling == "default": + return Msg(self.name, "{}", "assistant") + return Msg(self.name, f"[自定义工具执行失败: {e}]", "assistant") + + return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant") async def observe(self, msg) -> None: pass class ConditionNodeAgent(AgentBase): - def __init__(self, node_id: str, condition: str = ""): + def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"): super().__init__() self.name = f"Condition_{node_id}" self.condition = condition + self.condition_type = condition_type async def reply(self, msg: Msg, **kwargs) -> Msg: user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) @@ -307,6 +579,27 @@ class ConditionNodeAgent(AgentBase): if not self.condition: return Msg(self.name, "condition:true|条件为空,默认通过", "assistant") + if self.condition_type == "regex": + try: + matched = bool(re.search(self.condition, user_text)) + return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命' if matched else '未命'}中", "assistant") + except re.error as e: + return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant") + + if self.condition_type == "json_path": + try: + data = json.loads(user_text) if user_text.strip().startswith('{') else {} + parts = self.condition.strip('$.').split('.') + val = data + for p in parts: + if isinstance(val, dict): + val = val.get(p, None) + else: + val = None + return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant") + except: + return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant") + try: from agentscope.model import OpenAIChatModel from agentscope.formatter import OpenAIChatFormatter @@ -329,14 +622,12 @@ class ConditionNodeAgent(AgentBase): 请严格只输出一行 JSON: {{"result": true/false, "reason": "简要原因"}}""" - prompt = await formatter.format([ + prompt = formatter.format([ Msg("system", condition_prompt, "system"), Msg("user", user_text[:2000], "user"), ]) res = await model(prompt) - import json - import re res_text = "" if isinstance(res, list): res_text = res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0]) @@ -362,11 +653,13 @@ class ConditionNodeAgent(AgentBase): class MCPNodeAgent(AgentBase): - def __init__(self, node_id: str, server_name: str = "", tool_name: str = ""): + def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30, error_handling: str = "throw"): super().__init__() self.name = f"MCP_{node_id}" self.server_name = server_name self.tool_name = tool_name + self.timeout = timeout + self.error_handling = error_handling async def reply(self, msg: Msg, **kwargs) -> Msg: user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) @@ -379,12 +672,21 @@ class MCPNodeAgent(AgentBase): client = MCPClientManager.get_http_client(self.server_name) if client and self.tool_name: - result = await client.call_tool(self.tool_name, {"input": user_text}) + result = await asyncio.wait_for( + client.call_tool(self.tool_name, {"input": user_text}), + timeout=self.timeout + ) return Msg(self.name, str(result), "assistant") except ImportError: logger.warning("agentscope_runtime MCP 客户端不可用") + except asyncio.TimeoutError: + if self.error_handling == "skip": + return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant") + return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant") except Exception as e: logger.warning(f"MCP 调用失败: {e}") + if self.error_handling == "skip": + return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant") output = f"[MCP] 服务 {self.server_name} 调用完成: 已处理输入" return Msg(self.name, output, "assistant") @@ -393,29 +695,155 @@ class MCPNodeAgent(AgentBase): pass -class WeComNotifyAgent(AgentBase): +class NotifyAgent(AgentBase): def __init__(self, node_id: str, config: dict = None): super().__init__() - self.name = f"WeComNotify_{node_id}" + self.name = f"Notify_{node_id}" self.config = config or {} async def reply(self, msg: Msg, **kwargs) -> Msg: user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) - template = self.config.get("message_template", "") - target = self.config.get("target", "") - message = template or user_text[:500] + channels = self.config.get("channels", {"wecom": True, "web": False}) + message_type = self.config.get("message_type", "text") + + results = [] + + if channels.get("wecom", True): + template = self.config.get("message_template", "") + target = self.config.get("target", "") + message = template or user_text[:500] + try: + from agentscope_integration.tools.wecom_tools import send_notification + result = send_notification(to_user=target or "user", message=message) + results.append(f"企微通知: {result[:100]}") + except ImportError: + pass + except Exception as e: + logger.warning(f"企微通知发送失败: {e}") + results.append(f"企微通知: 已向 {target or '用户'} 发送: {message[:100]}") + + if channels.get("web", False): + web_template = self.config.get("web_template", "") + web_message = web_template or user_text[:500] + try: + from websocket_manager import ws_manager + target_user = self.config.get("target", "") + if target_user: + await ws_manager.send_to_user(target_user, { + "type": "flow_notification", + "message": web_message, + "level": "info", + }) + results.append(f"Web通知: 已推送") + except ImportError: + pass + except Exception as e: + logger.warning(f"Web通知推送失败: {e}") + results.append(f"Web通知: 推送失败({e})") + + if not results: + results.append("通知已发送") + return Msg(self.name, " | ".join(results), "assistant") + + async def observe(self, msg) -> None: + pass + + +class LoopNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"Loop_{node_id}" + self.config = config or {} + + async def reply(self, msg: Msg, **kwargs) -> Msg: + user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + loop_type = self.config.get("loop_type", "fixed") + count = self.config.get("count", 3) + iterator_variable = self.config.get("iterator_variable", "item") + iteration = self.config.get("_engine_iteration", 0) + 1 + + if loop_type == "array": + items = self.config.get("items", []) + if isinstance(items, str): + try: + items = json.loads(items) + except Exception: + items = [line.strip() for line in items.split("\n") if line.strip()] + if isinstance(items, list) and items: + idx = iteration - 1 + if idx >= len(items): + return Msg(self.name, "loop:stop|数组迭代完成", "assistant") + current_item = items[idx] + return Msg(self.name, f"[迭代 {iteration}/{len(items)}] 变量: {iterator_variable}={current_item}\n输入: {user_text[:200]}", "assistant") + return Msg(self.name, "loop:stop|无数组数据", "assistant") + + if loop_type == "fixed": + if iteration >= count: + return Msg(self.name, f"loop:stop|固定次数循环完成: {iteration}/{count}", "assistant") + return Msg(self.name, f"[循环 {iteration}/{count}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant") + + return Msg(self.name, f"[循环 {iteration}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant") + + async def observe(self, msg) -> None: + pass + + +class CodeNodeAgent(AgentBase): + def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30, sandbox: bool = True): + super().__init__() + self.name = f"Code_{node_id}" + self.language = language + self.code = code + self.timeout = timeout + self.sandbox = sandbox + + async def reply(self, msg: Msg, **kwargs) -> Msg: + user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + import subprocess + import tempfile + import os + import json + + if not self.code.strip(): + return Msg(self.name, user_text, "assistant") + + safe_input = json.dumps(user_text) + code_with_input = f"import json\nINPUT_TEXT = json.loads({safe_input})\n\n{self.code}" try: - from agentscope_integration.tools.wecom_tools import send_notification - result = send_notification(to_user=target or "user", message=message) - return Msg(self.name, result, "assistant") - except ImportError: - pass - except Exception as e: - logger.warning(f"企微通知发送失败: {e}") + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f: + f.write(code_with_input) + temp_path = f.name + + proc = await asyncio.create_subprocess_exec( + "python", temp_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) - result = f"[企微通知] 已向 {target or '用户'} 发送: {message[:100]}" - return Msg(self.name, result, "assistant") + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout) + except asyncio.TimeoutError: + proc.kill() + try: + os.unlink(temp_path) + except: + pass + return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant") + + try: + os.unlink(temp_path) + except: + pass + + if stderr: + return Msg(self.name, f"[错误] {stderr.decode('utf-8', errors='replace')[:500]}", "assistant") + + output = stdout.decode('utf-8', errors='replace').strip() + return Msg(self.name, output or "[代码执行完成,无输出]", "assistant") + + except Exception as e: + return Msg(self.name, f"[代码执行失败: {e}]", "assistant") async def observe(self, msg) -> None: pass @@ -430,6 +858,9 @@ class RAGNodeAgent(AgentBase): async def reply(self, msg: Msg, **kwargs) -> Msg: user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) top_k = self.config.get("top_k", 5) + similarity_threshold = self.config.get("similarity_threshold", 0.7) + search_mode = self.config.get("search_mode", "hybrid") + include_metadata = self.config.get("include_metadata", True) try: from modules.rag.knowledge import retrieve_for_agent @@ -447,15 +878,19 @@ class RAGNodeAgent(AgentBase): 请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。""" - import asyncio - loop = asyncio.get_event_loop() - messages = await asyncio.to_thread(formatter.format, [ - {"role": "system", "content": rag_prompt}, - {"role": "user", "content": user_text}, + prompt = formatter.format([ + Msg("system", rag_prompt, "system"), + Msg("user", user_text, "user"), ]) - res = await model(messages) - res_text = res.get_text_content() if hasattr(res, 'get_text_content') else str(res) + res = await model(prompt) + res_text = "" + if isinstance(res, list): + res_text = res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0]) + elif hasattr(res, 'get_text_content'): + res_text = res.get_text_content() + else: + res_text = str(res) return Msg(self.name, res_text, "assistant") except Exception as e: @@ -489,15 +924,32 @@ class OutputNodeAgent(AgentBase): async def reply(self, msg: Msg, **kwargs) -> Msg: output_format = self.config.get("format", "text") + output_template = self.config.get("output_template", "") + truncate = self.config.get("truncate", False) + max_length = self.config.get("max_length", 2000) + + content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + + if truncate and len(content) > max_length: + content = content[:max_length] + "\n\n[内容已截断]" + if output_format == "json": - content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) try: parsed = json.loads(content) - formatted = json.dumps(parsed, indent=2, ensure_ascii=False) + indent = self.config.get("indent", 2) + formatted = json.dumps(parsed, indent=indent, ensure_ascii=False) return Msg(self.name, formatted, "assistant") except (json.JSONDecodeError, ValueError): pass - return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant") + + if output_template: + try: + resolved = output_template.replace("{{output}}", content) + return Msg(self.name, resolved, "assistant") + except: + pass + + return msg if isinstance(msg, Msg) else Msg(self.name, str(content), "assistant") async def observe(self, msg) -> None: pass @@ -519,4 +971,42 @@ def _resolve_template(template: str, context: dict, current_msg: Msg) -> str: else: value = str(node_result.get("output", "")) result = result.replace("{{" + placeholder + "}}", value) - return result \ No newline at end of file + return result + + +class ParallelMergeNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"Merge_{node_id}" + self.config = config or {} + self._received: dict[str, str] = {} + self.merge_type = self.config.get("merge_type", "concat") + + async def reply(self, msg: Msg, **kwargs) -> Msg: + source_id = kwargs.get("source_node_id", "") + content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + self._received[source_id] = content + + expected_count = self.config.get("expected_branches", 0) + if expected_count <= 0 or len(self._received) >= expected_count: + return self._merge() + return Msg(self.name, "", "assistant") + + def _merge(self) -> Msg: + if self.merge_type == "json": + merged = json.dumps(self._received, ensure_ascii=False) + elif self.merge_type == "first_non_empty": + merged = "" + for v in self._received.values(): + if v.strip(): + merged = v + break + else: + parts = [] + for k, v in self._received.items(): + parts.append(v) + merged = "\n\n---\n\n".join(parts) + return Msg(self.name, merged, "assistant") + + async def observe(self, msg) -> None: + pass \ No newline at end of file diff --git a/backend/modules/flow_engine/gateway.py b/backend/modules/flow_engine/gateway.py new file mode 100644 index 0000000..4554e7b --- /dev/null +++ b/backend/modules/flow_engine/gateway.py @@ -0,0 +1,256 @@ +import uuid +import time +import json +from datetime import datetime +from fastapi import APIRouter, Depends, HTTPException, Request, Query +from fastapi.responses import StreamingResponse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from database import get_db +from models import FlowDefinition, FlowVersion, FlowExecution +from schemas import FlowChatMessageRequest +from modules.flow_engine.engine import FlowEngine +from agentscope.message import Msg +from middleware.apikey_auth import authenticate_api_key +from dependencies import get_current_user +import logging + +logger = logging.getLogger(__name__) + +gateway_router = APIRouter(prefix="/v1", tags=["gateway"]) + + +async def _resolve_auth(request: Request) -> dict: + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer flow-"): + return await authenticate_api_key(request) + try: + user = await get_current_user(request) + return {"user": user, "auth_type": "jwt"} + except Exception: + raise HTTPException(401, "认证失败: 请使用 Bearer Token 或 API Key") + + +async def _get_definition_for_execute(flow_id: uuid.UUID, db: AsyncSession) -> dict: + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流不存在") + if f.status != "published": + raise HTTPException(400, "流未发布") + + if f.published_version_id: + result = await db.execute(select(FlowVersion).where(FlowVersion.id == f.published_version_id)) + published = result.scalar_one_or_none() + if published: + return json.loads(json.dumps(published.definition_json)) + return f.definition_json + + +# ============================== 对话型流 ============================== + + +@gateway_router.post("/chat-messages") +async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)): + auth = await _resolve_auth(request) + + body = await request.json() + query = body.get("query", "") + response_mode = body.get("response_mode", "blocking") + inputs = body.get("inputs", {}) + user = body.get("user", "anonymous") + session_id = body.get("conversation_id", body.get("session_id")) + + flow_id_str = body.get("flow_id") or inputs.get("flow_id") + if not flow_id_str: + raise HTTPException(400, "缺少 flow_id") + + flow_id = uuid.UUID(flow_id_str) + definition = await _get_definition_for_execute(flow_id, db) + f = await db.get(FlowDefinition, flow_id) + + input_text = query + if inputs: + extra = json.dumps(inputs, ensure_ascii=False) + if query: + input_text = f"{query}\n\n上下文数据:\n{extra}" + else: + input_text = extra + + user_id = "api" if auth.get("auth_type") == "api_key" else auth.get("user", {}).get("id", "api") + username = user + + if response_mode == "streaming": + return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db) + + return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db) + + +async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db): + engine = FlowEngine(definition) + input_msg = Msg(name="user", content=input_text, role="user") + context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())} + + start_time = time.time() + try: + result_msg = await engine.execute(input_msg, context) + elapsed_ms = int((time.time() - start_time) * 1000) + output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) + + execution = FlowExecution( + flow_id=flow.id, version=flow.version, + trigger_type="api", input_data={"query": input_text}, + output_data={"output": output_text}, status="completed", + latency_ms=elapsed_ms, finished_at=datetime.utcnow(), + ) + db.add(execution) + + return { + "event": "message", + "id": str(uuid.uuid4()), + "answer": output_text, + "conversation_id": session_id or "", + "created_at": int(time.time()), + "metadata": { + "usage": {"latency_ms": elapsed_ms}, + "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()}, + }, + } + except Exception as e: + elapsed_ms = int((time.time() - start_time) * 1000) + execution = FlowExecution( + flow_id=flow.id, version=flow.version, + trigger_type="api", input_data={"query": input_text}, + status="failed", latency_ms=elapsed_ms, + error_message=str(e)[:2000], finished_at=datetime.utcnow(), + ) + db.add(execution) + raise HTTPException(500, f"流执行失败: {str(e)}") + + +async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db): + async def event_generator(): + import asyncio + engine = FlowEngine(definition) + context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())} + input_msg = Msg(name="user", content=input_text, role="user") + start_time = time.time() + msg_id = str(uuid.uuid4()) + + try: + yield f"data: {json.dumps({'event': 'workflow_started', 'task_id': msg_id, 'data': {'flow_id': str(flow_id)}}, ensure_ascii=False)}\n\n" + + result_msg = await asyncio.wait_for(engine.execute(input_msg, context), timeout=engine.FLOW_TIMEOUT_SECONDS) + output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) + elapsed_ms = int((time.time() - start_time) * 1000) + + for i in range(0, len(output_text), 10): + chunk = output_text[i:i + 10] + yield f"data: {json.dumps({'event': 'message', 'task_id': msg_id, 'answer': chunk, 'created_at': int(time.time())}, ensure_ascii=False)}\n\n" + + yield f"data: {json.dumps({'event': 'message_end', 'task_id': msg_id, 'id': msg_id, 'conversation_id': session_id or '', 'metadata': {'usage': {'latency_ms': elapsed_ms}, 'node_results': {k: str(v)[:200] for k, v in context.get('_node_results', {}).items()}}}, ensure_ascii=False)}\n\n" + + execution = FlowExecution( + flow_id=flow.id, version=flow.version, + trigger_type="api", input_data={"query": input_text}, + output_data={"output": output_text}, + status="completed", latency_ms=elapsed_ms, + finished_at=datetime.utcnow(), + ) + db.add(execution) + except asyncio.TimeoutError: + yield f"data: {json.dumps({'event': 'error', 'task_id': msg_id, 'message': '执行超时'}, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"data: {json.dumps({'event': 'error', 'task_id': msg_id, 'message': str(e)}, ensure_ascii=False)}\n\n" + finally: + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) + + +# ============================== 工作流型流 ============================== + + +@gateway_router.post("/workflows/run") +async def workflows_run(request: Request, db: AsyncSession = Depends(get_db)): + auth = await _resolve_auth(request) + body = await request.json() + inputs = body.get("inputs", {}) + response_mode = body.get("response_mode", "blocking") + user = body.get("user", "anonymous") + + flow_id_str = body.get("workflow_id") or inputs.get("workflow_id") or inputs.get("flow_id") + if not flow_id_str: + raise HTTPException(400, "缺少 workflow_id") + + flow_id = uuid.UUID(flow_id_str) + definition = await _get_definition_for_execute(flow_id, db) + f = await db.get(FlowDefinition, flow_id) + + user_id = "api" if auth.get("auth_type") == "api_key" else auth.get("user", {}).get("id", "api") + + engine = FlowEngine(definition) + input_msg = Msg(name="user", content=json.dumps(inputs, ensure_ascii=False), role="user") + context = {"user_id": user_id, "username": user, "_node_results": {}, "trigger_data": inputs} + + start_time = time.time() + try: + result_msg = await engine.execute(input_msg, context) + elapsed_ms = int((time.time() - start_time) * 1000) + output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) + + execution = FlowExecution( + flow_id=f.id, version=f.version, + trigger_type="api", input_data={"inputs": inputs}, + output_data={"output": output_text}, status="completed", + latency_ms=elapsed_ms, finished_at=datetime.utcnow(), + ) + db.add(execution) + + return { + "id": str(uuid.uuid4()), + "workflow_run_id": str(uuid.uuid4()), + "data": { + "outputs": {"text": output_text}, + "node_results": {k: str(v)[:200] for k, v in context.get("_node_results", {}).items()}, + }, + "metadata": {"latency_ms": elapsed_ms}, + } + except Exception as e: + elapsed_ms = int((time.time() - start_time) * 1000) + execution = FlowExecution( + flow_id=f.id, version=f.version, + trigger_type="api", input_data={"inputs": inputs}, + status="failed", latency_ms=elapsed_ms, + error_message=str(e)[:2000], finished_at=datetime.utcnow(), + ) + db.add(execution) + raise HTTPException(500, f"工作流执行失败: {str(e)}") + + +# ============================== 参数信息 ============================== + + +@gateway_router.get("/flows/{flow_id}/parameters") +async def get_flow_parameters(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + definition = await _get_definition_for_execute(flow_id, db) + nodes = definition.get("nodes", []) + trigger_nodes = [n for n in nodes if n.get("type") == "trigger"] + input_vars = [] + if trigger_nodes: + trigger_config = trigger_nodes[0].get("config", {}) + input_vars = [ + {"name": "query", "type": "string", "description": "用户输入文本", "required": True}, + {"name": "session_id", "type": "string", "description": "会话ID(用于多轮对话)", "required": False}, + ] + return { + "code": 200, + "data": { + "input_variables": input_vars, + "node_count": len(nodes), + "edge_count": len(definition.get("edges", [])), + }, + } \ No newline at end of file diff --git a/backend/modules/flow_engine/router.py b/backend/modules/flow_engine/router.py index e009b58..de8694f 100644 --- a/backend/modules/flow_engine/router.py +++ b/backend/modules/flow_engine/router.py @@ -1,46 +1,59 @@ import uuid +import time import json +import hashlib +import secrets from datetime import datetime -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Query +from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_db -from models import FlowDefinition, FlowExecution, User -from schemas import FlowDefinitionCreate, FlowDefinitionUpdate, FlowDefinitionOut, FlowNode, FlowEdge -from modules.flow_engine.engine import FlowEngine +from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User +from schemas import ( + FlowDefinitionCreate, FlowDefinitionUpdate, FlowDefinitionOut, + FlowVersionOut, FlowApiKeyCreate, FlowApiKeyOut, + FlowExecuteRequest, FlowChatMessageRequest, +) +from modules.flow_engine.engine import FlowEngine, ToolNodeAgent from agentscope.message import Msg +from dependencies import get_current_user +import logging + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/flow", tags=["flow"]) -@router.get("/definitions", response_model=list[FlowDefinitionOut]) -async def list_flows(request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute( - select(FlowDefinition).order_by(FlowDefinition.updated_at.desc()) - ) - flows = result.scalars().all() - return [FlowDefinitionOut( +def _build_flow_out(f) -> FlowDefinitionOut: + return FlowDefinitionOut( id=f.id, name=f.name, description=f.description, version=f.version, status=f.status, definition_json=f.definition_json, + published_version_id=f.published_version_id, published_to_wecom=f.published_to_wecom, + published_to_web=f.published_to_web, created_at=f.created_at, updated_at=f.updated_at, - ) for f in flows] + ) + + +# ============================== CRUD ============================== + + +@router.get("/definitions", response_model=list[FlowDefinitionOut]) +async def list_flows(request: Request, db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(FlowDefinition).order_by(FlowDefinition.updated_at.desc()) + ) + return [_build_flow_out(f) for f in result.scalars().all()] @router.get("/definitions/{flow_id}", response_model=FlowDefinitionOut) async def get_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") - return FlowDefinitionOut( - id=flow.id, name=flow.name, description=flow.description, - version=flow.version, status=flow.status, - definition_json=flow.definition_json, - published_to_wecom=flow.published_to_wecom, - created_at=flow.created_at, updated_at=flow.updated_at, - ) + return _build_flow_out(f) @router.post("/definitions", response_model=FlowDefinitionOut) @@ -51,194 +64,488 @@ async def create_flow(req: FlowDefinitionCreate, request: Request, db: AsyncSess "edges": [e.model_dump() for e in req.edges], "trigger": req.trigger, } - flow = FlowDefinition( name=req.name, description=req.description, definition_json=definition_json, + draft_definition_json=definition_json, creator_id=uuid.UUID(user_ctx["id"]), ) db.add(flow) await db.flush() - - return FlowDefinitionOut( - id=flow.id, name=flow.name, description=flow.description, - version=flow.version, status=flow.status, - definition_json=flow.definition_json, - published_to_wecom=flow.published_to_wecom, - created_at=flow.created_at, updated_at=flow.updated_at, - ) + return _build_flow_out(flow) @router.put("/definitions/{flow_id}", response_model=FlowDefinitionOut) -async def update_flow( - flow_id: uuid.UUID, req: FlowDefinitionUpdate, - request: Request, db: AsyncSession = Depends(get_db), -): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: +async def update_flow(flow_id: uuid.UUID, req: FlowDefinitionUpdate, request: Request, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") - if req.name is not None: - flow.name = req.name + f.name = req.name if req.description is not None: - flow.description = req.description + f.description = req.description if req.nodes is not None and req.edges is not None: - flow.definition_json = { + new_def = { "nodes": [n.model_dump() for n in req.nodes], "edges": [e.model_dump() for e in req.edges], - "trigger": req.trigger or flow.definition_json.get("trigger", {}), + "trigger": req.trigger or f.definition_json.get("trigger", {}), } - flow.version += 1 - - return FlowDefinitionOut( - id=flow.id, name=flow.name, description=flow.description, - version=flow.version, status=flow.status, - definition_json=flow.definition_json, - published_to_wecom=flow.published_to_wecom, - created_at=flow.created_at, updated_at=flow.updated_at, - ) + f.version += 1 + f.draft_definition_json = new_def + f.definition_json = new_def + return _build_flow_out(f) @router.delete("/definitions/{flow_id}") -async def delete_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: +async def delete_flow(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") - await db.delete(flow) + await db.delete(f) return {"code": 200, "message": "已删除"} +# ============================== 发布 (创建快照) ============================== + + +async def _snapshot_publish(flow: FlowDefinition, db: AsyncSession, user_id: str, + publish_wecom: bool = False, publish_web: bool = False, changelog: str = ""): + new_version = FlowVersion( + flow_id=flow.id, + version=flow.version, + definition_json=json.loads(json.dumps(flow.definition_json)), + changelog=changelog, + published_by=uuid.UUID(user_id), + published_to_wecom=publish_wecom, + published_to_web=publish_web, + ) + db.add(new_version) + await db.flush() + + flow.published_version_id = new_version.id + flow.status = "published" + if publish_wecom: + flow.published_to_wecom = True + if publish_web: + flow.published_to_web = True + return new_version + + @router.post("/definitions/{flow_id}/publish") async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") - - nodes = flow.definition_json.get("nodes", []) - edges = flow.definition_json.get("edges", []) + nodes = f.definition_json.get("nodes", []) if not nodes: raise HTTPException(400, "流定义中没有节点") + user_ctx = request.state.user + await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True) + return {"code": 200, "message": "流已上架到企微", "data": {"status": "published", "version": f.version}} - flow.status = "published" - flow.published_to_wecom = True - return {"code": 200, "message": "流已上架到企微", "data": {"status": "published"}} + +@router.post("/definitions/{flow_id}/publish-web") +async def publish_flow_to_web(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流定义不存在") + user_ctx = request.state.user + prev_version_id = f.published_version_id + await _snapshot_publish(f, db, user_ctx["id"], publish_web=True) + return {"code": 200, "message": "流已上架到网页", "data": {"status": "published", "version": f.version}} @router.post("/definitions/{flow_id}/unpublish") -async def unpublish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: +async def unpublish_flow(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") - - flow.status = "draft" - flow.published_to_wecom = False + f.status = "draft" + f.published_to_wecom = False + f.published_version_id = None return {"code": 200, "message": "流已下架"} +@router.post("/definitions/{flow_id}/unpublish-web") +async def unpublish_flow_from_web(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流定义不存在") + f.published_to_web = False + if not f.published_to_wecom: + f.status = "draft" + f.published_version_id = None + return {"code": 200, "message": "流已从网页下架"} + + +# ============================== 版本管理 ============================== + + +@router.get("/definitions/{flow_id}/versions", response_model=list[FlowVersionOut]) +async def list_versions(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(FlowVersion) + .where(FlowVersion.flow_id == flow_id) + .order_by(FlowVersion.version.desc()) + .limit(50) + ) + return [FlowVersionOut( + id=v.id, flow_id=v.flow_id, version=v.version, + definition_json=v.definition_json, changelog=v.changelog or "", + published_to_wecom=v.published_to_wecom, published_to_web=v.published_to_web, + published_by=v.published_by, created_at=v.created_at, + ) for v in result.scalars().all()] + + +@router.post("/definitions/{flow_id}/rollback/{version_id}") +async def rollback_flow(flow_id: uuid.UUID, version_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流定义不存在") + target = await db.get(FlowVersion, version_id) + if not target or str(target.flow_id) != str(flow_id): + raise HTTPException(404, "版本不存在") + + f.definition_json = json.loads(json.dumps(target.definition_json)) + f.draft_definition_json = f.definition_json + f.published_version_id = target.id + f.published_to_wecom = target.published_to_wecom + f.published_to_web = target.published_to_web + f.status = "published" if (target.published_to_wecom or target.published_to_web) else "draft" + f.version = target.version + return {"code": 200, "message": f"已回滚到版本 v{target.version}", "data": {"version": target.version}} + + +# ============================== 执行 (加版本快照) ============================== + + +def _get_definition_json(flow: FlowDefinition, db_session) -> dict: + """优先加载 published_version 快照,不存在则使用当前 definition_json""" + return flow.definition_json + + +async def _get_published_definition(flow: FlowDefinition, db: AsyncSession) -> dict: + if flow.published_version_id: + result = await db.execute(select(FlowVersion).where(FlowVersion.id == flow.published_version_id)) + published = result.scalar_one_or_none() + if published: + return json.loads(json.dumps(published.definition_json)) + return flow.definition_json + + @router.post("/definitions/{flow_id}/execute") async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") user_ctx = request.state.user input_text = payload.get("input", payload.get("message", "")) - engine = FlowEngine(flow.definition_json) + definition = await _get_published_definition(f, db) + await ToolNodeAgent.load_custom_tools(db) + engine = FlowEngine(definition) input_msg = Msg(name="user", content=input_text, role="user") - context = { "user_id": user_ctx["id"], - "username": user_ctx["username"], + "username": user_ctx.get("username", ""), "trigger_data": payload.get("trigger", {}), "_node_results": {}, } + start_time = time.time() try: result_msg = await engine.execute(input_msg, context) + elapsed_ms = int((time.time() - start_time) * 1000) output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) execution = FlowExecution( - flow_id=flow.id, + flow_id=f.id, + version=_get_published_version_number(f), trigger_type=payload.get("trigger_type", "manual"), trigger_user_id=uuid.UUID(user_ctx["id"]), input_data={"input": input_text}, output_data={"output": output_text}, status="completed", + latency_ms=elapsed_ms, finished_at=datetime.utcnow(), ) db.add(execution) - return { "code": 200, "data": { "output": output_text, "node_results": context.get("_node_results", {}), "execution_id": str(execution.id), + "latency_ms": elapsed_ms, }, } except Exception as e: + elapsed_ms = int((time.time() - start_time) * 1000) execution = FlowExecution( - flow_id=flow.id, + flow_id=f.id, + version=_get_published_version_number(f), trigger_type="manual", trigger_user_id=uuid.UUID(user_ctx["id"]), input_data={"input": input_text}, status="failed", + latency_ms=elapsed_ms, + error_message=str(e)[:2000], finished_at=datetime.utcnow(), ) db.add(execution) raise HTTPException(500, f"流执行失败: {str(e)}") +def _get_published_version_number(flow: FlowDefinition) -> int | None: + return flow.version + + @router.post("/definitions/{flow_id}/test") async def test_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(FlowDefinition).where(FlowDefinition.id == flow_id)) - flow = result.scalar_one_or_none() - if not flow: + f = await db.get(FlowDefinition, flow_id) + if not f: raise HTTPException(404, "流定义不存在") - nodes = flow.definition_json.get("nodes", []) - edges = flow.definition_json.get("edges", []) - - validation = { - "valid": True, - "node_count": len(nodes), - "edge_count": len(edges), - "node_types": list(set(n.get("type", "unknown") for n in nodes)), - "issues": [], - } + definition = await _get_published_definition(f, db) + nodes = definition.get("nodes", []) + edges = definition.get("edges", []) + validation = {"valid": True, "node_count": len(nodes), "edge_count": len(edges), + "node_types": list(set(n.get("type", "unknown") for n in nodes)), "issues": []} node_ids = {n["id"] for n in nodes} for edge in edges: - source = edge.get("source") or edge.get("from") - target = edge.get("target") or edge.get("to") - if source and source not in node_ids: - validation["issues"].append(f"边源节点 {source} 不存在") - if target and target not in node_ids: - validation["issues"].append(f"边目标节点 {target} 不存在") - + s = edge.get("source") or edge.get("from") + t = edge.get("target") or edge.get("to") + if s and s not in node_ids: + validation["issues"].append(f"边源节点 {s} 不存在") + if t and t not in node_ids: + validation["issues"].append(f"边目标节点 {t} 不存在") if validation["issues"]: validation["valid"] = False - - has_trigger = any(n.get("type") == "trigger" for n in nodes) - if not has_trigger: + if not any(n.get("type") == "trigger" for n in nodes): validation["issues"].append("流缺少触发节点") - return {"code": 200, "data": validation} +# ============================== SSE 流式执行 ============================== + + +@router.post("/definitions/{flow_id}/stream") +async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + body = await request.json() + input_text = body.get("input", body.get("message", "")) + user_ctx = request.state.user + + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流定义不存在") + + definition = await _get_published_definition(f, db) + await ToolNodeAgent.load_custom_tools(db) + + async def event_generator(): + import asyncio + engine = FlowEngine(definition) + context = { + "user_id": user_ctx["id"], + "username": user_ctx.get("username", ""), + "trigger_data": body.get("trigger", {}), + "_node_results": {}, + "_stream_callback": None, + } + input_msg = Msg(name="user", content=input_text, role="user") + start_time = time.time() + + # bind stream callback + stream_chunks = [] + + async def stream_callback(event_type: str, data: dict): + chunk_data = json.dumps({"event": event_type, "data": data}, ensure_ascii=False) + stream_chunks.append(f"data: {chunk_data}\n\n") + + context["_stream_callback"] = stream_callback + + try: + yield f"data: {json.dumps({'event': 'workflow_started', 'data': {'flow_id': str(flow_id)}}, ensure_ascii=False)}\n\n" + + # get execution order + graph = engine._build_graph() + start = engine._find_start_nodes(graph) + if start: + yield f"data: {json.dumps({'event': 'node_started', 'data': {'node_id': start[0], 'node_type': definition.get('nodes', [{}])[0].get('type', 'unknown'), 'label': definition.get('nodes', [{}])[0].get('label', '开始')}}, ensure_ascii=False)}\n\n" + + result_msg = await asyncio.wait_for( + engine.execute(input_msg, context), + timeout=engine.FLOW_TIMEOUT_SECONDS, + ) + output_text = result_msg.get_text_content() if hasattr(result_msg, 'get_text_content') else str(result_msg) + elapsed_ms = int((time.time() - start_time) * 1000) + + yield f"data: {json.dumps({'event': 'text_chunk', 'data': {'content': output_text}}, ensure_ascii=False)}\n\n" + + yield f"data: {json.dumps({'event': 'workflow_finished', 'data': {'output': output_text, 'node_results': {k: str(v)[:200] for k, v in context.get('_node_results', {}).items()}, 'latency_ms': elapsed_ms}}, ensure_ascii=False)}\n\n" + + execution = FlowExecution( + flow_id=f.id, + version=_get_published_version_number(f), + trigger_type=body.get("trigger_type", "manual"), + trigger_user_id=uuid.UUID(user_ctx["id"]), + input_data={"input": input_text}, + output_data={"output": output_text}, + status="completed", + latency_ms=elapsed_ms, + finished_at=datetime.utcnow(), + ) + db.add(execution) + except asyncio.TimeoutError: + elapsed_ms = int((time.time() - start_time) * 1000) + yield f"data: {json.dumps({'event': 'error', 'data': {'message': '执行超时'}}, ensure_ascii=False)}\n\n" + execution = FlowExecution( + flow_id=f.id, + version=_get_published_version_number(f), + trigger_type="manual", + trigger_user_id=uuid.UUID(user_ctx["id"]), + input_data={"input": input_text}, + status="failed", + latency_ms=elapsed_ms, + error_message="执行超时", + finished_at=datetime.utcnow(), + ) + db.add(execution) + except Exception as e: + elapsed_ms = int((time.time() - start_time) * 1000) + yield f"data: {json.dumps({'event': 'error', 'data': {'message': str(e)}}, ensure_ascii=False)}\n\n" + execution = FlowExecution( + flow_id=f.id, + version=_get_published_version_number(f), + trigger_type="manual", + trigger_user_id=uuid.UUID(user_ctx["id"]), + input_data={"input": input_text}, + status="failed", + latency_ms=elapsed_ms, + error_message=str(e)[:2000], + finished_at=datetime.utcnow(), + ) + db.add(execution) + finally: + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +# ============================== API Key 管理 ============================== + + +def _generate_api_key() -> tuple[str, str, str]: + raw = "flow-" + secrets.token_urlsafe(24) + key_hash = hashlib.sha256(raw.encode()).hexdigest() + return raw, key_hash, raw[:14] + + +@router.post("/definitions/{flow_id}/api-keys", response_model=dict) +async def create_api_key(flow_id: uuid.UUID, body: FlowApiKeyCreate, request: Request, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流定义不存在") + user_ctx = request.state.user + raw, key_hash, key_prefix = _generate_api_key() + + api_key = FlowApiKey( + flow_id=flow_id, + name=body.name, + key_hash=key_hash, + key_prefix=key_prefix, + created_by=uuid.UUID(user_ctx["id"]), + ) + db.add(api_key) + await db.flush() + + return { + "code": 200, + "data": { + "id": str(api_key.id), + "name": body.name, + "key_prefix": key_prefix, + "api_key": raw, + "created_at": str(api_key.created_at), + }, + } + + +@router.get("/definitions/{flow_id}/api-keys", response_model=list[FlowApiKeyOut]) +async def list_api_keys(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(FlowApiKey).where(FlowApiKey.flow_id == flow_id).order_by(FlowApiKey.created_at.desc()) + ) + return [FlowApiKeyOut( + id=k.id, flow_id=k.flow_id, name=k.name, + key_prefix=k.key_prefix, last_used_at=k.last_used_at, created_at=k.created_at, + ) for k in result.scalars().all()] + + +@router.delete("/api-keys/{key_id}") +async def delete_api_key(key_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + k = await db.get(FlowApiKey, key_id) + if not k: + raise HTTPException(404, "API Key不存在") + await db.delete(k) + return {"code": 200, "message": "API Key已删除"} + + +# ============================== 执行历史 ============================== + + +@router.get("/executions") +async def list_executions( + db: AsyncSession = Depends(get_db), + flow_id: str | None = Query(None), + page: int = Query(1), + page_size: int = Query(20), +): + query = select(FlowExecution).order_by(FlowExecution.started_at.desc()) + if flow_id: + query = query.where(FlowExecution.flow_id == uuid.UUID(flow_id)) + total_result = await db.execute(query) + total = len(total_result.scalars().all()) + + result = await db.execute(query.offset((page - 1) * page_size).limit(page_size)) + executions = result.scalars().all() + return { + "code": 200, + "data": [{ + "id": str(e.id), + "flow_id": str(e.flow_id), + "version": e.version, + "trigger_type": e.trigger_type, + "status": e.status, + "latency_ms": e.latency_ms, + "token_usage": e.token_usage, + "error_message": e.error_message, + "started_at": str(e.started_at), + "finished_at": str(e.finished_at) if e.finished_at else None, + } for e in executions], + "total": total, + "page": page, + "page_size": page_size, + } + + +# ============================== 模板 ============================== + + FLOW_TEMPLATES = [ { - "id": "tpl_doc_process", - "name": "文档处理流", - "description": "自动解析文档内容,提取关键信息并生成摘要", - "icon": "Document", + "id": "tpl_doc_process", "name": "文档处理流", "description": "自动解析文档内容,提取关键信息并生成摘要", "icon": "Document", "nodes": [ {"id": "n1", "type": "trigger", "label": "文档上传", "config": {"event_type": "document_upload"}, "position": {"x": 100, "y": 100}}, {"id": "n2", "type": "tool", "label": "解析文档", "config": {"tool_name": "parse_document"}, "position": {"x": 400, "y": 100}}, @@ -252,10 +559,7 @@ FLOW_TEMPLATES = [ ], }, { - "id": "tpl_wecom_notify", - "name": "企微通知流", - "description": "接收触发后查询数据并推送企微通知", - "icon": "Bell", + "id": "tpl_wecom_notify", "name": "企微通知流", "description": "接收触发后查询数据并推送企微通知", "icon": "Bell", "nodes": [ {"id": "n1", "type": "trigger", "label": "定时触发", "config": {"event_type": "scheduled"}, "position": {"x": 100, "y": 100}}, {"id": "n2", "type": "tool", "label": "查询任务", "config": {"tool_name": "list_tasks"}, "position": {"x": 400, "y": 100}}, @@ -271,10 +575,7 @@ FLOW_TEMPLATES = [ ], }, { - "id": "tpl_data_analysis", - "name": "数据分析流", - "description": "查询员工数据并生成效率分析报告", - "icon": "DataAnalysis", + "id": "tpl_data_analysis", "name": "数据分析流", "description": "查询员工数据并生成效率分析报告", "icon": "DataAnalysis", "nodes": [ {"id": "n1", "type": "trigger", "label": "分析请求", "config": {"event_type": "button_click"}, "position": {"x": 100, "y": 100}}, {"id": "n2", "type": "tool", "label": "查询下属", "config": {"tool_name": "list_subordinates"}, "position": {"x": 400, "y": 100}}, @@ -290,10 +591,7 @@ FLOW_TEMPLATES = [ ], }, { - "id": "tpl_rag_qa", - "name": "知识库问答流", - "description": "从知识库检索信息后由LLM回答", - "icon": "Search", + "id": "tpl_rag_qa", "name": "知识库问答流", "description": "从知识库检索信息后由LLM回答", "icon": "Search", "nodes": [ {"id": "n1", "type": "trigger", "label": "问题触发", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}}, {"id": "n2", "type": "rag", "label": "知识检索", "config": {"knowledge_base": "default", "top_k": 5}, "position": {"x": 400, "y": 100}}, @@ -307,10 +605,7 @@ FLOW_TEMPLATES = [ ], }, { - "id": "tpl_task_auto", - "name": "任务自动分配流", - "description": "根据描述自动创建任务并分派给合适人员", - "icon": "Tools", + "id": "tpl_task_auto", "name": "任务自动分配流", "description": "根据描述自动创建任务并分派给合适人员", "icon": "Tools", "nodes": [ {"id": "n1", "type": "trigger", "label": "任务描述", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}}, {"id": "n2", "type": "llm", "label": "分析任务", "config": {"system_prompt": "分析以下任务描述,提取标题、优先级、负责人", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 400, "y": 100}}, @@ -329,20 +624,11 @@ FLOW_TEMPLATES = [ @router.get("/market", response_model=list[FlowDefinitionOut]) -async def flow_market(request: Request, db: AsyncSession = Depends(get_db)): +async def flow_market(db: AsyncSession = Depends(get_db)): result = await db.execute( - select(FlowDefinition) - .where(FlowDefinition.status == "published") - .order_by(FlowDefinition.updated_at.desc()) + select(FlowDefinition).where(FlowDefinition.status == "published").order_by(FlowDefinition.updated_at.desc()) ) - flows = result.scalars().all() - return [FlowDefinitionOut( - id=f.id, name=f.name, description=f.description, - version=f.version, status=f.status, - definition_json=f.definition_json, - published_to_wecom=f.published_to_wecom, - created_at=f.created_at, updated_at=f.updated_at, - ) for f in flows] + return [_build_flow_out(f) for f in result.scalars().all()] @router.get("/templates") @@ -351,52 +637,19 @@ async def get_flow_templates(request: Request): @router.post("/templates/{template_id}/use") -async def use_flow_template( - template_id: str, - request: Request, - db: AsyncSession = Depends(get_db), -): +async def use_flow_template(template_id: str, request: Request, db: AsyncSession = Depends(get_db)): template = next((t for t in FLOW_TEMPLATES if t["id"] == template_id), None) if not template: raise HTTPException(404, "模板不存在") - user_ctx = request.state.user + definition_json = {"nodes": template["nodes"], "edges": template["edges"], "trigger": {}} flow = FlowDefinition( name=template["name"] + " (副本)", description=template["description"], - definition_json={ - "nodes": template["nodes"], - "edges": template["edges"], - "trigger": {}, - }, + definition_json=definition_json, + draft_definition_json=definition_json, creator_id=uuid.UUID(user_ctx["id"]), ) db.add(flow) await db.flush() - - return FlowDefinitionOut( - id=flow.id, name=flow.name, description=flow.description, - version=flow.version, status=flow.status, - definition_json=flow.definition_json, - published_to_wecom=flow.published_to_wecom, - created_at=flow.created_at, updated_at=flow.updated_at, - ) - - -@router.get("/executions") -async def list_executions(request: Request, db: AsyncSession = Depends(get_db)): - result = await db.execute( - select(FlowExecution).order_by(FlowExecution.started_at.desc()).limit(100) - ) - executions = result.scalars().all() - return { - "code": 200, - "data": [{ - "id": str(e.id), - "flow_id": str(e.flow_id), - "trigger_type": e.trigger_type, - "status": e.status, - "started_at": str(e.started_at), - "finished_at": str(e.finished_at) if e.finished_at else None, - } for e in executions], - } \ No newline at end of file + return _build_flow_out(flow) \ No newline at end of file diff --git a/backend/modules/mcp_registry/router.py b/backend/modules/mcp_registry/router.py index a41ab2e..75b548f 100644 --- a/backend/modules/mcp_registry/router.py +++ b/backend/modules/mcp_registry/router.py @@ -12,7 +12,7 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"]) @router.get("/servers", response_model=list[MCPServiceOut]) -async def list_servers(request: Request, db: AsyncSession = Depends(get_db)): +async def list_servers(request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): result = await db.execute( select(MCPService).order_by(MCPService.updated_at.desc()) ) @@ -20,7 +20,7 @@ async def list_servers(request: Request, db: AsyncSession = Depends(get_db)): @router.get("/servers/{server_id}", response_model=MCPServiceOut) -async def get_server(server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): +async def get_server(server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): result = await db.execute(select(MCPService).where(MCPService.id == server_id)) server = result.scalar_one_or_none() if not server: diff --git a/backend/modules/monitor/router.py b/backend/modules/monitor/router.py index 60c24ff..441a6b9 100644 --- a/backend/modules/monitor/router.py +++ b/backend/modules/monitor/router.py @@ -148,7 +148,7 @@ async def get_employee_analysis( ) formatter = OpenAIChatFormatter() - prompt = await formatter.format([ + prompt = formatter.format([ Msg("system", f"""你是一个企业管理者分析助手。请根据员工与AI的交互记录,生成一个JSON格式的分析报告。 要求: diff --git a/backend/modules/notification/router.py b/backend/modules/notification/router.py index be71aa5..cd59b27 100644 --- a/backend/modules/notification/router.py +++ b/backend/modules/notification/router.py @@ -69,8 +69,8 @@ async def notification_websocket(ws: WebSocket, user_id: str): ws_manager.disconnect(user_id, ws) -@router.post("/send", dependencies=[Depends(get_current_user)]) -async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db)): +@router.post("/send") +async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): user_id = payload.get("user_id", "") target_all = payload.get("target_all", False) title = payload.get("title", "系统通知") @@ -93,7 +93,7 @@ async def send_notification(payload: dict, request: Request, db: AsyncSession = await _push_to_wecom(title, body, user_id) audit = AuditLog( - operator_id=uuid.UUID(request.state.user["id"]), + operator_id=uuid.UUID(user["id"]), action="notification.send", resource="notification", detail={"title": title, "target": user_id if user_id else "broadcast"}, diff --git a/backend/modules/org/router.py b/backend/modules/org/router.py index 42bbf32..d7c4062 100644 --- a/backend/modules/org/router.py +++ b/backend/modules/org/router.py @@ -23,7 +23,12 @@ async def get_departments(request: Request, db: AsyncSession = Depends(get_db)): return [await _build_department_tree(db, d) for d in roots] -async def _build_department_tree(db: AsyncSession, dept: Department) -> DepartmentOut: +async def _build_department_tree(db: AsyncSession, dept: Department, _visited: set[uuid.UUID] = None) -> DepartmentOut: + if _visited is None: + _visited = set() + if dept.id in _visited: + return DepartmentOut(id=dept.id, name=dept.name, parent_id=dept.parent_id, path=dept.path, level=dept.level, sort_order=dept.sort_order, children=[]) + _visited.add(dept.id) children_result = await db.execute( select(Department).where(Department.parent_id == dept.id).order_by(Department.sort_order) ) @@ -31,7 +36,7 @@ async def _build_department_tree(db: AsyncSession, dept: Department) -> Departme return DepartmentOut( id=dept.id, name=dept.name, parent_id=dept.parent_id, path=dept.path, level=dept.level, sort_order=dept.sort_order, - children=[await _build_department_tree(db, c) for c in children], + children=[await _build_department_tree(db, c, _visited) for c in children], ) @@ -195,12 +200,17 @@ async def get_subordinates(request: Request, db: AsyncSession = Depends(get_db)) return [await _user_to_out(db, u) for u in users] -async def _get_subordinate_ids(db: AsyncSession, manager_id: uuid.UUID) -> set[uuid.UUID]: +async def _get_subordinate_ids(db: AsyncSession, manager_id: uuid.UUID, _visited: set[uuid.UUID] = None) -> set[uuid.UUID]: + if _visited is None: + _visited = set() + if manager_id in _visited: + return set() + _visited.add(manager_id) result = await db.execute(select(User).where(User.manager_id == manager_id)) direct = result.scalars().all() ids = {u.id for u in direct} for sub in direct: - ids.update(await _get_subordinate_ids(db, sub.id)) + ids.update(await _get_subordinate_ids(db, sub.id, _visited)) return ids diff --git a/backend/schemas/__init__.py b/backend/schemas/__init__.py index 9a82c53..62bd5d1 100644 --- a/backend/schemas/__init__.py +++ b/backend/schemas/__init__.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict # --- Auth --- @@ -178,33 +178,79 @@ class EmployeeAnalysis(BaseModel): # --- Flow --- class TriggerNodeConfig(BaseModel): event_type: str = "text_message" + channels: list[str] = ["wecom"] + callback_url: str = "" class LLMNodeConfig(BaseModel): system_prompt: str = "" model: str = "gpt-4o-mini" temperature: float = 0.7 agent_id: str = "" + max_tokens: int = 2000 + context_length: int = 5 + memory_mode: str = "short_term" + stream: bool = True + tool_call: bool = False class ToolNodeConfig(BaseModel): tool_name: str = "" + tool_type: str = "" + tool_params: dict = {} + timeout: int = 30 + retry_count: int = 0 + error_handling: str = "throw" class MCPNodeConfig(BaseModel): mcp_server: str = "" tool_name: str = "" + input_params: dict = {} + timeout: int = 30 + response_parser: str = "json" + error_handling: str = "throw" -class WeComNotifyNodeConfig(BaseModel): +class NotifyNodeConfig(BaseModel): + channels: dict = {"wecom": True, "web": False} message_template: str = "" + web_template: str = "" target: str = "" + message_type: str = "text" + async_send: bool = False + error_handling: str = "throw" class ConditionNodeConfig(BaseModel): condition: str = "" + condition_type: str = "expression" + true_label: str = "是" + false_label: str = "否" + default_branch: str = "false" class RAGNodeConfig(BaseModel): knowledge_base: str = "" top_k: int = 5 + search_mode: str = "hybrid" + similarity_threshold: float = 0.7 + result_sort: str = "similarity" + include_metadata: bool = True class OutputNodeConfig(BaseModel): format: str = "text" + output_template: str = "" + indent: int = 2 + encoding: str = "utf-8" + truncate: bool = False + max_length: int = 2000 + +class LoopNodeConfig(BaseModel): + loop_type: str = "fixed" + max_iterations: int = 10 + count: int = 3 + iterator_variable: str = "item" + +class CodeNodeConfig(BaseModel): + language: str = "python" + code: str = "" + timeout: int = 30 + sandbox: bool = True class FlowNode(BaseModel): id: str | None = None @@ -246,7 +292,9 @@ class FlowDefinitionOut(BaseModel): version: int status: str definition_json: dict + published_version_id: uuid.UUID | None = None published_to_wecom: bool + published_to_web: bool = False created_at: datetime | None = None updated_at: datetime | None = None @@ -254,6 +302,98 @@ class FlowDefinitionOut(BaseModel): from_attributes = True +class FlowVersionOut(BaseModel): + id: uuid.UUID + flow_id: uuid.UUID + version: int + definition_json: dict + changelog: str = "" + published_to_wecom: bool = False + published_to_web: bool = False + published_by: uuid.UUID | None = None + created_at: datetime | None = None + + class Config: + from_attributes = True + + +class FlowApiKeyCreate(BaseModel): + name: str + + +class FlowApiKeyOut(BaseModel): + id: uuid.UUID + flow_id: uuid.UUID + name: str + key_prefix: str + last_used_at: datetime | None = None + created_at: datetime | None = None + + class Config: + from_attributes = True + + +class FlowExecuteRequest(BaseModel): + input_text: str = "" + session_id: str | None = None + user_id: str | None = None + + +class FlowChatMessageRequest(BaseModel): + query: str + inputs: dict = {} + response_mode: str = "blocking" + user: str = "" + session_id: str | None = None + + +class OpenAPIImportRequest(BaseModel): + openapi_url: str + base_url_override: str | None = None + + +class CustomToolCreate(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + name: str + description: str | None = None + openapi_url: str | None = None + endpoint_url: str | None = None + method: str = "GET" + path: str = "" + headers: dict = {} + auth_type: str = "none" + auth_config: dict = {} + schema_json: dict | None = None + + +class CustomToolUpdate(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + name: str | None = None + description: str | None = None + endpoint_url: str | None = None + method: str | None = None + path: str | None = None + headers: dict | None = None + auth_type: str | None = None + auth_config: dict | None = None + schema_json: dict | None = None + is_active: bool | None = None + + +class CustomToolOut(BaseModel): + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + id: uuid.UUID + name: str + description: str | None = None + schema_json: dict + endpoint_url: str + method: str + path: str + auth_type: str + is_active: bool + created_at: datetime | None = None + + # --- MCP --- class MCPServiceCreate(BaseModel): name: str diff --git a/backend/websocket_manager.py b/backend/websocket_manager.py new file mode 100644 index 0000000..40eb776 --- /dev/null +++ b/backend/websocket_manager.py @@ -0,0 +1,49 @@ +from fastapi import WebSocket, WebSocketDisconnect +from typing import Dict, Set +import json +import logging + +logger = logging.getLogger(__name__) + + +class WebSocketManager: + def __init__(self): + self.active_connections: Dict[str, Set[WebSocket]] = {} + + async def connect(self, websocket: WebSocket, user_id: str): + await websocket.accept() + if user_id not in self.active_connections: + self.active_connections[user_id] = set() + self.active_connections[user_id].add(websocket) + logger.info(f"WebSocket 用户 {user_id} 已连接") + + def disconnect(self, websocket: WebSocket, user_id: str): + if user_id in self.active_connections: + self.active_connections[user_id].discard(websocket) + if not self.active_connections[user_id]: + del self.active_connections[user_id] + logger.info(f"WebSocket 用户 {user_id} 已断开") + + async def send_to_user(self, user_id: str, message: dict): + if user_id not in self.active_connections: + return False + dead_connections = set() + sent_count = 0 + for connection in list(self.active_connections.get(user_id, set())): + try: + await connection.send_text(json.dumps(message, ensure_ascii=False)) + sent_count += 1 + except Exception: + dead_connections.add(connection) + for conn in dead_connections: + self.active_connections[user_id].discard(conn) + if not self.active_connections.get(user_id): + self.active_connections.pop(user_id, None) + return sent_count > 0 + + async def broadcast(self, message: dict): + for user_id in list(self.active_connections.keys()): + await self.send_to_user(user_id, message) + + +ws_manager = WebSocketManager() \ No newline at end of file diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 9ebf2f5..c3ed589 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -81,7 +81,9 @@ export const flowApi = { updateFlow: (id: string, data: any) => api.put(`/flow/definitions/${id}`, data), deleteFlow: (id: string) => api.delete(`/flow/definitions/${id}`), publishFlow: (id: string) => api.post(`/flow/definitions/${id}/publish`), + publishToWeb: (id: string) => api.post(`/flow/definitions/${id}/publish-web`), unpublishFlow: (id: string) => api.post(`/flow/definitions/${id}/unpublish`), + unpublishFromWeb: (id: string) => api.post(`/flow/definitions/${id}/unpublish-web`), executeFlow: (id: string, data: any) => api.post(`/flow/definitions/${id}/execute`, data), testFlow: (id: string) => api.post(`/flow/definitions/${id}/test`), getMarket: () => api.get('/flow/market'), @@ -89,6 +91,37 @@ export const flowApi = { useTemplate: (id: string) => api.post(`/flow/templates/${id}/use`), } +export const chatApi = { + send: (flowId: string, data: { message: string }) => api.post(`/chat/message/${flowId}`, data), + getFlows: () => api.get('/chat/flows'), +} + +export const customToolApi = { + list: () => api.get('/custom-tools/'), + get: (id: string) => api.get(`/custom-tools/${id}`), + create: (data: any) => api.post('/custom-tools/', data), + update: (id: string, data: any) => api.put(`/custom-tools/${id}`, data), + delete: (id: string) => api.delete(`/custom-tools/${id}`), + importOpenApi: (data: { openapi_url: string; base_url_override?: string }) => api.post('/custom-tools/import-openapi', data), + test: (id: string, params: any) => api.post(`/custom-tools/${id}/test`, params), + getSchemas: () => api.get('/custom-tools/schemas/all'), +} + +export const flowChatApi = { + executeStream: (flowId: string, data: any) => { + return fetch(`/api/flow/definitions/${flowId}/stream`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${localStorage.getItem('token')}`, + }, + body: JSON.stringify(data), + }) + }, + executeBlocking: (flowId: string, data: any) => api.post(`/flow/definitions/${flowId}/execute`, data), + getPublishedFlows: () => api.get('/flow/market'), +} + export const wecomApi = { sendMessage: (data: any) => api.post('/wecom/send', data), getConfig: () => api.get('/wecom/config'), diff --git a/frontend/src/components/layout/MainLayout.vue b/frontend/src/components/layout/MainLayout.vue index c4b85e0..5297a92 100644 --- a/frontend/src/components/layout/MainLayout.vue +++ b/frontend/src/components/layout/MainLayout.vue @@ -59,11 +59,21 @@ 通知中心 + + + 流式对话 + + 流程管理 + + + 自定义工具 + + 个人中心 @@ -113,7 +123,7 @@ import { ref, computed } from 'vue' import { useRoute, useRouter } from 'vue-router' import { useUserStore } from '@/stores/user' -import { Fold, User, ArrowDown, Tools, Search, Promotion } from '@element-plus/icons-vue' +import { Fold, User, ArrowDown, Tools, Search, Promotion, ChatLineSquare, SetUp } from '@element-plus/icons-vue' import PortalSwitcher from '@/components/common/PortalSwitcher.vue' const route = useRoute() @@ -130,6 +140,8 @@ const activeMenu = computed(() => { if (path.startsWith('/user/flow')) return '/user/flow/list' if (path.startsWith('/user/wecom')) return '/user/wecom/config' if (path.startsWith('/user/notification')) return '/user/notification/center' + if (path.startsWith('/user/chat')) return '/user/chat/flow' + if (path.startsWith('/user/tools')) return '/user/tools/custom' if (path.startsWith('/user/settings')) return '/user/settings' return path }) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 4211874..7946dee 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -118,6 +118,18 @@ const router = createRouter({ component: () => import('@/views/profile/Profile.vue'), meta: { title: '个人中心' }, }, + { + path: 'chat/flow', + name: 'FlowChat', + component: () => import('@/views/chat/FlowChat.vue'), + meta: { title: '流式对话' }, + }, + { + path: 'tools/custom', + name: 'CustomToolManager', + component: () => import('@/views/tools/CustomToolManager.vue'), + meta: { title: '自定义API工具' }, + }, { path: 'settings', name: 'Settings', diff --git a/frontend/src/views/chat/FlowChat.vue b/frontend/src/views/chat/FlowChat.vue new file mode 100644 index 0000000..b793a00 --- /dev/null +++ b/frontend/src/views/chat/FlowChat.vue @@ -0,0 +1,393 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/flow/FlowCanvas.vue b/frontend/src/views/flow/FlowCanvas.vue index 81ba2b1..30f9fbb 100644 --- a/frontend/src/views/flow/FlowCanvas.vue +++ b/frontend/src/views/flow/FlowCanvas.vue @@ -142,6 +142,10 @@ function onPaneClickLocal() { function onKeyDown(event: KeyboardEvent) { if (event.key === 'Delete' || event.key === 'Backspace') { + const tag = (event.target as HTMLElement)?.tagName?.toLowerCase() + if (tag === 'input' || tag === 'textarea' || tag === 'select' || (event.target as HTMLElement)?.isContentEditable) { + return + } event.preventDefault() if (selectedEdgeId.value) { @@ -172,6 +176,14 @@ function onContextMenu(event: MouseEvent) { function handleConnect(connection: any) { const sourceHandle = connection.sourceHandle || 'source' + let edgeStyle = { stroke: '#1976D2', strokeWidth: 3 } + if (sourceHandle === 'false') { + edgeStyle = { stroke: '#E53935', strokeWidth: 3 } + } else if (sourceHandle === 'loop_body') { + edgeStyle = { stroke: '#13c2c2', strokeWidth: 3, strokeDasharray: '6,3' } + } else if (sourceHandle === 'loop_done') { + edgeStyle = { stroke: '#909399', strokeWidth: 3 } + } const newEdge: any = { id: `edge_${connection.source}_${connection.target}_${sourceHandle}_${Date.now()}`, source: connection.source, @@ -179,11 +191,9 @@ function handleConnect(connection: any) { sourceHandle, targetHandle: connection.targetHandle, type: 'smoothstep', - animated: true, + animated: sourceHandle === 'loop_body', markerEnd: MarkerType.ArrowClosed, - style: sourceHandle === 'false' - ? { stroke: '#E53935', strokeWidth: 3 } - : { stroke: '#1976D2', strokeWidth: 3 }, + style: edgeStyle, } emit('update:edges', [...props.edges, newEdge]) emit('connect', connection) diff --git a/frontend/src/views/flow/FlowEditor.vue b/frontend/src/views/flow/FlowEditor.vue index dca30e6..bb14428 100644 --- a/frontend/src/views/flow/FlowEditor.vue +++ b/frontend/src/views/flow/FlowEditor.vue @@ -11,6 +11,7 @@ 保存 验证 上架到企微 + 上架到网页 @@ -33,7 +34,8 @@

拖拽节点到画布

从绿色圆点拖线(true)

从红色圆点拖线(false)

-

选中连线后按 Delete 删除

+

循环: 青色(循环体)/灰色(完成)

+

选中连线/节点按 Delete 删除

右键点击连线可删除

点击空白处取消选中

滚轮缩放画布

@@ -90,16 +92,19 @@ import { ref, computed, onMounted } from 'vue' import { useRoute, useRouter } from 'vue-router' import { ElMessage } from 'element-plus' -import { Promotion, ChatDotRound, Tools, Connection, Bell, DataAnalysis, Search } from '@element-plus/icons-vue' +import { MarkerType } from '@vue-flow/core' +import { Promotion, ChatDotRound, Tools, Connection, Bell, DataAnalysis, Search, RefreshRight, Document } from '@element-plus/icons-vue' import FlowCanvas from './FlowCanvas.vue' import TriggerConfig from './node-configs/TriggerConfig.vue' import LlmConfig from './node-configs/LlmConfig.vue' import ToolConfig from './node-configs/ToolConfig.vue' import McpConfig from './node-configs/McpConfig.vue' -import WecomNotifyConfig from './node-configs/WecomNotifyConfig.vue' +import NotifyConfig from './node-configs/WecomNotifyConfig.vue' import ConditionConfig from './node-configs/ConditionConfig.vue' import RagConfig from './node-configs/RagConfig.vue' import OutputConfig from './node-configs/OutputConfig.vue' +import LoopConfig from './node-configs/LoopConfig.vue' +import CodeConfig from './node-configs/CodeConfig.vue' const route = useRoute() const router = useRouter() @@ -124,13 +129,16 @@ const canvasRef = ref(null) let nodeCounter = 0 const nodeTypes = [ - { type: 'trigger', label: '触发节点', icon: Promotion, typeDesc: '企微触发' }, + { type: 'trigger', label: '触发节点', icon: Promotion, typeDesc: '流程触发' }, { type: 'llm', label: 'LLM处理', icon: ChatDotRound, typeDesc: 'AI处理' }, { type: 'tool', label: '工具调用', icon: Tools, typeDesc: '工具调用' }, { type: 'mcp', label: 'MCP服务', icon: Connection, typeDesc: '外部MCP' }, - { type: 'wecom_notify', label: '企微通知', icon: Bell, typeDesc: '企微通知' }, + { type: 'notify', label: '通知', icon: Bell, typeDesc: '消息通知' }, { type: 'condition', label: '条件判断', icon: DataAnalysis, typeDesc: '条件分支' }, { type: 'rag', label: 'RAG检索', icon: Search, typeDesc: '知识库检索' }, + { type: 'loop', label: '循环', icon: RefreshRight, typeDesc: '循环迭代' }, + { type: 'merge', label: '变量聚合', icon: Connection, typeDesc: '并行汇聚' }, + { type: 'code', label: '代码执行', icon: Document, typeDesc: '代码执行' }, { type: 'output', label: '输出节点', icon: Promotion, typeDesc: '结果输出' }, ] @@ -139,10 +147,14 @@ const configComponentMap: Record = { llm: LlmConfig, tool: ToolConfig, mcp: McpConfig, - wecom_notify: WecomNotifyConfig, + notify: NotifyConfig, + wecom_notify: NotifyConfig, condition: ConditionConfig, rag: RagConfig, output: OutputConfig, + loop: LoopConfig, + code: CodeConfig, + merge: NotifyConfig, } function getConfigComponent(type: string) { @@ -154,9 +166,12 @@ const colorMap: Record = { llm: '#409EFF', tool: '#67C23A', mcp: '#E6A23C', + notify: '#F56C6C', wecom_notify: '#F56C6C', condition: '#909399', rag: '#337ecc', + loop: '#13c2c2', + code: '#eb2f96', output: '#722ed1', } @@ -196,13 +211,17 @@ function onDrop(event: DragEvent) { function getDefaultConfig(type: string) { const defaults: Record = { - trigger: { event_type: 'text_message', callback_url: '' }, + trigger: { event_type: 'text_message', channels: ['wecom'], callback_url: '' }, llm: { system_prompt: '', model: 'gpt-4o-mini', temperature: 0.7, agent_id: '', max_tokens: 2000, context_length: 5, memory_mode: 'short_term', stream: true, tool_call: false }, - tool: { tool_type: '', tool_name: '', param_mapping: '{}', timeout: 30, retry_count: 0, error_handling: 'throw' }, - mcp: { mcp_server: '', tool_name: '', input_params: '{}', timeout: 30, response_parser: 'json', error_handling: 'throw' }, - wecom_notify: { message_template: '', target: '', message_type: 'text', async_send: false, error_handling: 'throw' }, - condition: { condition_type: 'expression', condition: '', true_label: '是', false_label: '否', short_circuit: true, default_branch: 'false' }, + tool: { tool_type: '', tool_name: '', param_mapping: '{}', tool_params: {}, timeout: 30, retry_count: 0, error_handling: 'throw' }, + mcp: { mcp_server: '', tool_name: '', input_params: {}, timeout: 30, response_parser: 'json', error_handling: 'throw' }, + notify: { channels: { wecom: true, web: false }, message_template: '', web_template: '', target: '', message_type: 'text', async_send: false, error_handling: 'throw' }, + wecom_notify: { channels: { wecom: true, web: false }, message_template: '', web_template: '', target: '', message_type: 'text', async_send: false, error_handling: 'throw' }, + condition: { condition_type: 'expression', condition: '', true_label: '是', false_label: '否', default_branch: 'false' }, rag: { knowledge_base: '', top_k: 5, search_mode: 'hybrid', similarity_threshold: 0.7, result_sort: 'similarity', include_metadata: true }, + loop: { loop_type: 'fixed', count: 3, iterator_variable: 'item', max_iterations: 10, items: [] }, + merge: { merge_type: 'concat', expected_branches: 2 }, + code: { language: 'python', code: '', timeout: 30, sandbox: true }, output: { format: 'text', output_template: '', indent: 2, encoding: 'utf-8', truncate: false, max_length: 2000 }, } return defaults[type] || {} @@ -221,7 +240,7 @@ function onConfigLabelChange() { if (idx !== -1) { const updated = { ...nodes.value[idx] } updated.data = { ...updated.data, label: selectedNodeData.value.label } - nodes.value[idx] = updated + nodes.value = nodes.value.map((n: any, i: number) => i === idx ? updated : n) } } @@ -231,7 +250,7 @@ function onConfigChange() { const found = nodes.value[idx] const updated = { ...found } updated.data = { ...found.data, config: { ...selectedNodeData.value.config } } - nodes.value[idx] = updated + nodes.value = nodes.value.map((n: any, i: number) => i === idx ? updated : n) } function removeNode(id: string) { @@ -242,6 +261,13 @@ function removeNode(id: string) { function clearCanvas() { nodes.value = []; edges.value = []; nodeCounter = 0; selectedNodeId.value = '' } +function getEdgeStyle(sourceHandle: string | undefined) { + if (sourceHandle === 'false') return { stroke: '#E53935', strokeWidth: 3 } + if (sourceHandle === 'loop_body') return { stroke: '#13c2c2', strokeWidth: 3, strokeDasharray: '6,3' } + if (sourceHandle === 'loop_done') return { stroke: '#909399', strokeWidth: 3 } + return { stroke: '#1976D2', strokeWidth: 3 } +} + async function loadFlow() { if (!isEdit.value) return try { @@ -270,7 +296,7 @@ async function loadFlow() { const source = e.source || e.from const target = e.target || e.to const cond = e.condition || e.sourceHandle - loadedEdges.push({ id: e.id || `edge_${source}_${target}`, source, target, sourceHandle: cond || 'source', type: 'smoothstep', animated: true, markerEnd: true, style: cond === 'false' ? { stroke: '#E53935', strokeWidth: 3 } : { stroke: '#1976D2', strokeWidth: 3 } }) + loadedEdges.push({ id: e.id || `edge_${source}_${target}`, source, target, sourceHandle: cond || 'source', type: 'smoothstep', animated: cond === 'loop_body', markerEnd: MarkerType.ArrowClosed, style: getEdgeStyle(cond) }) } nodes.value = loadedNodes edges.value = loadedEdges @@ -295,7 +321,7 @@ async function saveFlow() { const { flowApi } = await import('@/api') const snapshot = canvasRef.value?.getSnapshot() || { nodes: [], edges: [] } const serializedNodes = snapshot.nodes.map((n: any) => ({ id: n.id, type: n.data?.type || n.type, label: n.data?.label || n.id, config: n.data?.config || {}, position: n.position })) - const serializedEdges = snapshot.edges.map((e: any) => ({ source: e.source, target: e.target, sourceHandle: e.sourceHandle || 'source', condition: e.sourceHandle === 'false' ? 'false' : (e.sourceHandle === 'true' ? 'true' : undefined) })) + const serializedEdges = snapshot.edges.map((e: any) => ({ id: e.id, source: e.source, target: e.target, sourceHandle: e.sourceHandle || 'source' })) const payload = { name: flowName.value, description: flowDesc.value, nodes: serializedNodes, edges: serializedEdges, trigger: {} } if (isEdit.value) { await flowApi.updateFlow(flowId.value, payload); ElMessage.success('保存成功') } else { const res: any = await flowApi.createFlow(payload); const data = res?.data || res || {}; if (data.id) { router.replace(`/admin/flow/editor/${data.id}`); ElMessage.success('创建成功') } } @@ -312,6 +338,11 @@ async function publishFlow() { try { const { flowApi } = await import('@/api'); await flowApi.publishFlow(flowId.value); ElMessage.success('流已上架到企微'); await loadFlow() } catch {} } +async function publishToWeb() { + if (!isEdit.value) { ElMessage.warning('请先保存'); return } + try { const { flowApi } = await import('@/api'); await flowApi.publishToWeb(flowId.value); ElMessage.success('流已上架到网页'); await loadFlow() } catch {} +} + onMounted(async () => { try { if (isEdit.value) { await loadFlow() } diff --git a/frontend/src/views/flow/FlowList.vue b/frontend/src/views/flow/FlowList.vue index c8b7b3b..2fd1760 100644 --- a/frontend/src/views/flow/FlowList.vue +++ b/frontend/src/views/flow/FlowList.vue @@ -14,18 +14,36 @@ - - - - + + + + + + + + + + @@ -56,8 +74,20 @@ onMounted(async () => { async function handlePublish(row: any) { await flowApi.publishFlow(row.id) - ElMessage.success('流已上架到企微') - await refreshList() + ElMessage.success('已上架到企微') + refreshList() +} + +async function handlePublishWeb(row: any) { + await flowApi.publishToWeb(row.id) + ElMessage.success('已上架到网页') + refreshList() +} + +async function handleUnpublishWeb(row: any) { + await flowApi.unpublishFromWeb(row.id) + ElMessage.success('已从网页下架') + refreshList() } async function handleUnpublish(row: any) { diff --git a/frontend/src/views/flow/FlowNode.vue b/frontend/src/views/flow/FlowNode.vue index 8fb113c..10e7caf 100644 --- a/frontend/src/views/flow/FlowNode.vue +++ b/frontend/src/views/flow/FlowNode.vue @@ -13,7 +13,7 @@ - -