Browse Source

优化更新记忆模块

master
MSI-7950X\刘泽明 1 day ago
parent
commit
c1bbeb3721
  1. 893
      PLAN9.md
  2. 8
      backend/main.py
  3. 113
      backend/models/__init__.py
  4. 447
      backend/modules/flow_engine/engine.py
  5. 8
      backend/modules/flow_engine/gateway.py
  6. 249
      backend/modules/flow_engine/router.py
  7. 4
      backend/modules/memory/__init__.py
  8. 820
      backend/modules/memory/manager.py
  9. 2
      backend/modules/memory/router.py
  10. 3
      backend/modules/model_provider/__init__.py
  11. 181
      backend/modules/model_provider/router.py
  12. 2
      backend/schemas/__init__.py
  13. 13
      docker-compose.prod.yml
  14. 13
      docker-compose.yml
  15. 15
      frontend/src/api/index.ts
  16. 2
      frontend/src/components/layout/AdminLayout.vue
  17. 6
      frontend/src/router/index.ts
  18. 129
      frontend/src/views/flow/FlowEditor.vue
  19. 14
      frontend/src/views/flow/FlowNode.vue
  20. 127
      frontend/src/views/flow/node-configs/HttpRequestConfig.vue
  21. 80
      frontend/src/views/flow/node-configs/IterationConfig.vue
  22. 24
      frontend/src/views/flow/node-configs/MergeConfig.vue
  23. 120
      frontend/src/views/flow/node-configs/QuestionClassifierConfig.vue
  24. 81
      frontend/src/views/flow/node-configs/QuestionOptimiserConfig.vue
  25. 94
      frontend/src/views/flow/node-configs/TemplateTransformConfig.vue
  26. 123
      frontend/src/views/flow/node-configs/VariableAssignerConfig.vue
  27. 338
      frontend/src/views/model/ModelProviderManager.vue
  28. 6
      init-db/02-add-published-cols.sql
  29. 88
      init-db/03-memory-tables.sql
  30. 30
      init-db/04-model-provider.sql
  31. 7
      init-db/05-flow-mode.sql
  32. 7
      init-db/06-agent-model-ids.sql
  33. 27
      init-db/07-hybrid-search.sql
  34. 18
      init-db/08-flow-templates.sql
  35. 300
      scripts/migrate_memory_redis_to_pg.py

893
PLAN9.md

@ -0,0 +1,893 @@
# PLAN9 — 全面升级计划:记忆增强 · 流重构 · 模型管理 · 节点扩充
> 日期:2026-05-15
> 状态:规划中,待确认后实施
---
## 一、记忆管理升级:吸收 Agent-Memory 分层架构
### 1.1 现状分析
**我们当前方案(PLAN8 已实施)**:
- 纯 Redis 存储:List 存消息、Hash 存元数据、String 存摘要
- 透明中间件:在 `flow_engine/router.py``execute_flow` / `execute_flow_stream` 中注入记忆
- LLMNodeAgent.reply() 从 `context["_memory_context"]` 读取历史摘要 + 最近消息
- 7 天 TTL 自动过期,异步记录不阻塞响应
**Agent-Memory 项目(ali-agentscope-src/Agent-Memory)**:
- TypeScript 实现,以 OpenClaw 插件形式运行
- 四层记忆金字塔:L0(原始对话) → L1(结构化原子) → L2(场景块) → L3(用户画像)
- 双写机制:JSONL 持久化 + SQLite 向量检索
- 混合搜索:BM25 关键词 + Embedding 向量 + RRF 融合
- Context Offload:长任务上下文溢出时 Mermaid 符号化压缩
- Pipeline 调度:每 N 轮对话触发 L1 提取,延迟触发 L2/L3
### 1.2 对比评估
| 维度 | 我们方案 | Agent-Memory | 评价 |
|------|---------|-------------|------|
| **存储** | Redis(纯内存) | SQLite + JSONL(磁盘) | 我们快但不适合海量历史;他们慢但可持久化大量数据 |
| **记忆分层** | 无分层(平铺消息列表) | L0/L1/L2/L3 四层金字塔 | **最大差距**:我们没有结构化提取,全是原始消息 |
| **检索方式** | 全量最近 N 条 + 摘要 | BM25 + 向量 + RRF 混合搜索 | 我们只能按时间取最近消息,无法按语义检索 |
| **去重** | 无 | L1 提取后 batchDedup(store/update/merge/skip) | 我们会重复存储相同信息 |
| **用户画像** | 未实现 | L3 Persona 自动生成 + 增量更新 | 他们有完整的画像系统 |
| **上下文溢出** | 无处理 | Context Offload + Mermaid 压缩 | 我们长对话会直接截断,丢失信息 |
| **接入方式** | Python 原生集成 | TypeScript 插件,需 OpenClaw/Hermes 宿主 | **不能直接部署接入**,语言和架构不兼容 |
| **性能** | 毫秒级(Redis) | 秒级(SQLite + LLM 提取) | 我们快但浅;他们慢但深 |
### 1.3 结论
**不能直接部署接入**:Agent-Memory 是 TypeScript 项目,依赖 OpenClaw/Hermes 宿主环境,与我们的 Python/FastAPI 架构完全不兼容。
**应吸收其核心设计思想**,但存储底座需要重新选型——详见 1.4 节分析。
### 1.4 记忆存储底座选型:Redis vs MongoDB vs PostgreSQL
#### 1.4.1 当前 Redis 方案的数据结构与访问模式
```
当前 Redis 键结构:
mem:{uid}:{fid}:{sid}:messages → List(LPUSH 写入,LRANGE 读取最近 N 条)
mem:{uid}:{fid}:{sid}:meta → Hash(HSET 写入,HGETALL 读取)
mem:{uid}:{fid}:{sid}:summary → String(SETEX 写入,GET 读取)
mem:{uid}:sessions → Set(SADD 写入,SMEMBERS 读取)
访问模式:
写入:record_exchange() → pipeline 批量 LPUSH + HSET + SADD(每次对话 1 次)
读取:inject_memory() → LRANGE + GET(每次对话 1 次,延迟 < 1ms
删除:delete_session() → KEYS + DEL(低频,用户主动清除)
列表:list_user_sessions() → SMEMBERS + KEYS + HGETALL(低频,管理页面)
```
#### 1.4.2 三种数据库四维度对比
**维度一:内存占用与性能**
| 指标 | Redis | MongoDB | PostgreSQL |
|------|-------|---------|-----------|
| **写入延迟** | < 1ms纯内存 | 2-5msWAL + 内存映射 | 3-8msWAL + fsync |
| **读取延迟** | < 1ms | 1-3msWiredTiger 缓存命中 | 1-5msshared_buffers 命中 |
| **inject_memory 延迟** | ~1ms | ~5ms | ~8ms |
| **内存消耗(1万会话×40条消息)** | ~800MB(纯内存,无压缩) | ~200MB(磁盘 + 缓存热数据) | ~150MB(磁盘 + shared_buffers) |
| **内存消耗(100万会话)** | ~80GB(不可行,需分片) | ~20GB(缓存 + 磁盘冷数据淘汰) | ~15GB(磁盘为主,热数据缓存) |
| **大规模场景性能** | 受内存上限约束,超内存即 OOM | 缓存冷热分离,可支撑 TB 级 | 磁盘为主,可支撑 PB 级 |
| **向量检索性能** | 需 Redis Stack(RediSearch),10万向量 ~50ms | 原生向量索引(Atlas Vector Search),10万向量 ~30ms | pgvector 扩展,10万向量 ~40ms |
| **全文检索性能** | 需 Redis Stack(FTS 模块),非标准 | 原生文本索引,成熟 | tsvector + GIN 索引,成熟 |
**关键结论**:
- Redis 在小规模(< 10万会话下性能最优但内存成本线性增长
- MongoDB 和 PostgreSQL 在大规模场景下更优,冷数据自动落盘不占内存
- 向量检索三者性能接近,但 Redis 需额外安装 Stack 模块
**维度二:数据模型适配性**
| 指标 | Redis | MongoDB | PostgreSQL |
|------|-------|---------|-----------|
| **消息列表(L0)** | List(天然有序,但无结构化查询) | 嵌入文档数组或独立文档(灵活) | 关系表 + 时间索引(标准) |
| **结构化原子(L1)** | Hash + Sorted Set(需手动管理) | 文档(天然 JSON,灵活 schema) | JSONB 列(灵活 + 可索引) |
| **场景块(L2)** | Hash(需手动序列化) | 文档(嵌套结构自然表达) | JSONB 或独立表 |
| **用户画像(L3)** | String(纯文本,无查询能力) | 文档(结构化,可按字段查询) | JSONB 或独立表(可按字段查询) |
| **向量存储** | 需 Redis Stack(非标准部署) | 原生支持(Atlas Vector Search) | pgvector 扩展(成熟) |
| **数据迁移复杂度** | 基准(当前方案) | 中(需新建 MongoDB 连接 + 集合设计) | **低**(已有 PostgreSQL 连接池,复用 SQLAlchemy) |
| **Schema 演进** | 无 Schema(灵活但无约束) | 无 Schema(灵活,可加验证) | 有 Schema(需迁移 SQL,但类型安全) |
**关键结论**:
- MongoDB 对文档型记忆数据最自然(JSON 原生存储,无需 ORM 映射)
- PostgreSQL 迁移成本最低(项目已用 SQLAlchemy + asyncpg,复用现有连接池)
- Redis 对 L0 消息列表操作最自然(LPUSH/LRANGE),但对 L1/L2/L3 结构化数据支持弱
**维度三:扩展性与维护成本**
| 指标 | Redis | MongoDB | PostgreSQL |
|------|-------|---------|-----------|
| **水平扩展** | Redis Cluster(分片复杂,跨 slot 操作受限) | 原生分片(Shard Key 自动路由) | 读写分离 + 分区表(Citus 扩展) |
| **运维复杂度** | 低(单实例简单,Cluster 复杂) | 中(副本集 + 分片需监控) | **最低**(项目已有 PostgreSQL 运维) |
| **新增基础设施** | 无(已部署) | **需新增 MongoDB 服务** | **无需新增**(复用现有 PostgreSQL) |
| **备份恢复** | RDB/AOF(需配置持久化策略) | mongodump/oplog | pg_dump/WAL 归档(**已有**) |
| **监控工具** | redis-cli INFO | MongoDB Compass/Cloud Manager | pg_stat_statements(**已有**) |
| **长期维护成本** | 高(内存持续增长需扩容) | 中 | **低**(磁盘扩容便宜,运维体系成熟) |
**关键结论**:
- **PostgreSQL 运维成本最低**:项目已有 PostgreSQL 实例、连接池、备份策略,无需新增基础设施
- MongoDB 需要额外部署和维护一套数据库服务
- Redis 长期成本最高:内存是磁盘的 10-50 倍价格
**维度四:事务与一致性需求**
| 指标 | Redis | MongoDB | PostgreSQL |
|------|-------|---------|-----------|
| **事务支持** | Pipeline(非原子,MULTI 仅单 key 原子) | 4.0+ 多文档事务(副本集) | 完整 ACID 事务 |
| **记忆数据事务需求** | 低(单次写入可容忍部分失败) | 低 | 低 |
| **模型配置事务需求** | 不涉及 | 不涉及 | 高(需与现有业务表关联) |
| **数据一致性** | 最终一致(AOF 异步刷盘可能丢 1s 数据) | 强一致(副本集 w:majority) | 强一致(WAL 同步刷盘) |
| **记忆丢失风险** | AOF 异步刷盘时宕机可能丢失 | 副本集保证不丢 | WAL 保证不丢 |
**关键结论**:
- 记忆数据对事务要求低(丢失一条消息可接受),三种数据库均满足
- 但 PostgreSQL 的强一致性在模型配置等业务数据上更有优势
- Redis AOF 异步模式下有 1 秒数据丢失窗口,对"不失忆"要求有风险
#### 1.4.3 综合评估与选型决策
| 评估维度 | Redis | MongoDB | PostgreSQL | 权重 |
|---------|-------|---------|-----------|------|
| 性能(小规模) | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ | 20% |
| 性能(大规模) | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 25% |
| 数据模型适配 | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 20% |
| 迁移成本 | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | 15% |
| 运维成本 | ⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | 10% |
| 事务/一致性 | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | 10% |
| **加权总分** | **3.35** | **3.25** | **3.95** | — |
#### 1.4.4 选型结论:PostgreSQL + Redis 混合方案
**选择 PostgreSQL 作为记忆主存储,Redis 保留为热数据缓存层**。
理由:
1. **零新增基础设施**:项目已有 PostgreSQL(asyncpg + SQLAlchemy),无需部署新服务
2. **迁移成本最低**:复用现有连接池、ORM、迁移框架,新增表即可
3. **PGVector 原生支持**:向量检索无需额外模块,`CREATE EXTENSION vector` 即可
4. **JSONB 灵活 + 可索引**:L1/L2/L3 结构化数据用 JSONB 存储,支持 GIN 索引按字段查询
5. **强一致性保证不失忆**:WAL 同步刷盘,无 Redis AOF 的 1 秒丢失窗口
6. **大规模成本优势**:磁盘存储比内存便宜 10-50 倍,100 万会话仅需 ~15GB 磁盘
**Redis 保留角色**:
- 热数据缓存:最近 10 条消息缓存到 Redis(inject_memory 时直接读缓存,延迟 < 1ms
- 速率限制:已有 cache_manager 的限流功能
- 会话状态:当前活跃会话的临时状态
#### 1.4.5 PostgreSQL 记忆数据表设计
```sql
-- L0: 原始对话消息
CREATE TABLE memory_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id),
flow_id UUID NOT NULL REFERENCES flow_definitions(id),
session_id UUID NOT NULL,
role VARCHAR(20) NOT NULL, -- "user" / "assistant"
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_memory_messages_session ON memory_messages(user_id, flow_id, session_id, created_at DESC);
CREATE INDEX idx_memory_messages_user_flow ON memory_messages(user_id, flow_id, created_at DESC);
-- L1: 结构化记忆原子
CREATE TABLE memory_atoms (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id),
flow_id UUID REFERENCES flow_definitions(id), -- NULL 表示全局
atom_type VARCHAR(20) NOT NULL, -- "persona" / "episodic" / "instruction"
content TEXT NOT NULL,
priority SMALLINT DEFAULT 50, -- 0-100,越高越核心
source_session_id UUID, -- 来源会话
metadata JSONB DEFAULT '{}', -- 扩展元数据
embedding vector(1536), -- 向量(需 pgvector 扩展)
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_memory_atoms_user ON memory_atoms(user_id, atom_type, priority DESC);
CREATE INDEX idx_memory_atoms_embedding ON memory_atoms USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
-- L2: 场景块
CREATE TABLE memory_scenes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id),
flow_id UUID REFERENCES flow_definitions(id),
scene_name VARCHAR(200) NOT NULL,
summary TEXT NOT NULL,
heat INTEGER DEFAULT 0, -- 热度(访问频率)
content JSONB DEFAULT '{}', -- 场景完整内容
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_memory_scenes_user ON memory_scenes(user_id, flow_id, heat DESC);
-- L3: 用户画像
CREATE TABLE memory_personas (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL UNIQUE REFERENCES users(id), -- 每用户一条
content JSONB NOT NULL DEFAULT '{}', -- 结构化画像
raw_text TEXT DEFAULT '', -- 原始画像文本(注入 LLM 用)
version INTEGER DEFAULT 1,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- 会话元数据
CREATE TABLE memory_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id),
flow_id UUID NOT NULL REFERENCES flow_definitions(id),
session_id UUID NOT NULL,
flow_name VARCHAR(200) DEFAULT '',
message_count INTEGER DEFAULT 0,
last_active_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(user_id, flow_id, session_id)
);
CREATE INDEX idx_memory_sessions_user ON memory_sessions(user_id, last_active_at DESC);
```
#### 1.4.6 MemoryManager 改造方案
```python
class MemoryManager:
def __init__(self, db_session_factory, redis: Redis):
self.db_session_factory = db_session_factory # AsyncSessionLocal
self.redis = redis # 热数据缓存
async def inject_memory(self, user_id, flow_id, session_id, context):
# 1. 先查 Redis 缓存(最近 10 条消息)
cached = await self._get_cached_messages(user_id, flow_id, session_id)
if cached:
recent_messages = cached
else:
# 2. 缓存未命中,查 PostgreSQL
recent_messages = await self._query_recent_messages(user_id, flow_id, session_id)
# 3. 回填 Redis 缓存
await self._cache_messages(user_id, flow_id, session_id, recent_messages)
# 4. 查 L1 原子(PostgreSQL)
atoms = await self._query_relevant_atoms(user_id, flow_id, context.get("input", ""))
# 5. 查 L3 画像(PostgreSQL)
persona = await self._query_persona(user_id)
context["_memory_context"] = {
"recent_messages": recent_messages,
"atoms": atoms,
"persona": persona,
"session_id": session_id,
}
async def record_exchange(self, user_id, flow_id, session_id, user_msg, assistant_msg, flow_name=""):
ts = datetime.utcnow()
# 1. 写 PostgreSQL(主存储,保证持久化)
async with self.db_session_factory() as session:
session.add(MemoryMessage(
user_id=user_id, flow_id=flow_id, session_id=session_id,
role="user", content=user_msg, created_at=ts,
))
session.add(MemoryMessage(
user_id=user_id, flow_id=flow_id, session_id=session_id,
role="assistant", content=assistant_msg, created_at=ts,
))
await session.commit()
# 2. 更新 Redis 缓存(热数据)
await self._cache_append_message(user_id, flow_id, session_id, [
{"role": "user", "content": user_msg, "ts": ts.isoformat()},
{"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()},
])
# 3. 异步触发 L1 提取
asyncio.create_task(self._maybe_extract_atoms(user_id, flow_id, session_id))
```
#### 1.4.7 迁移步骤
```
Step 1: 创建 PostgreSQL 记忆表(init-db/03-memory-tables.sql)
Step 2: 安装 pgvector 扩展(init-db/03-memory-tables.sql 中 CREATE EXTENSION IF NOT EXISTS vector)
Step 3: 改造 MemoryManager,PostgreSQL 为主存储,Redis 为缓存
Step 4: 数据迁移脚本:Redis → PostgreSQL(一次性,读取 Redis 中现有记忆数据写入 PG)
Step 5: 验证全链路:inject_memory / record_exchange / delete_session
Step 6: 确认无误后,Redis 中的记忆键可设置短 TTL 自然过期
```
#### 1.4.8 迁移 SQL 文件
详见 `init-db/03-memory-tables.sql`(实施时创建),包含:
- `CREATE EXTENSION IF NOT EXISTS vector`
- 上述 5 张表的 CREATE TABLE + INDEX
- 幂等执行(IF NOT EXISTS)
---
## 二、全面升级流管理(对标 Dify)
### 2.1 现状问题
1. **节点类型不足**:当前 11 种节点(trigger/llm/tool/mcp/notify/condition/rag/loop/merge/code/output),Dify 有 15+ 种
2. **记忆不是默认行为**:虽然 PLAN8 实现了透明中间件,但缺乏 Dify 那样的"对话型 vs 工作流型"区分
3. **变量系统原始**:仅 `{{node_id.output}}` 模板,无类型、无作用域、无聚合
4. **版本管理不完善**:有 FlowVersion 快照但缺少完整的草稿/发布分离流程
### 2.2 流类型区分(Chatflow vs Workflow)
**Dify 的核心设计**:区分两种应用模式,记忆机制不同
| 模式 | 记忆 | 典型场景 | 对应我们的 |
|------|------|---------|-----------|
| **Chatflow(对话型)** | 自动维护对话历史,LLM 节点自动注入上下文 | 客服、助手、问答 | FlowChat.vue 使用的流 |
| **Workflow(工作流型)** | 无自动记忆,每次执行独立 | 数据处理、批量任务、API 调用 | API 网关调用的流 |
**改造方案**:
```python
# FlowDefinition 新增字段
flow_mode = Column(String(20), default="chatflow") # "chatflow" | "workflow"
# 记忆中间件逻辑调整
if f.flow_mode == "chatflow":
await mm.inject_memory(...) # 注入记忆
asyncio.create_task(mm.record_exchange(...)) # 记录对话
else:
pass # workflow 模式不注入记忆
```
前端 FlowEditor.vue 新增流类型选择(创建流时选择,创建后不可更改)。
### 2.3 记忆默认启用策略
**核心原则**:所有 Chatflow 类型的流,记忆管理是默认行为,不需要用户在流中添加节点。
```
用户输入 → router.py
├─ Chatflow 模式:
│ ├─ 执行前:inject_memory() → 检索历史 + 画像 → 注入 context
│ ├─ LLM 调用:LLMNodeAgent 自动读取 _memory_context
│ └─ 执行后:record_exchange() → 异步记录 + L1 提取
└─ Workflow 模式:
└─ 直接执行,无记忆注入
```
### 2.4 流创建流程增强
**当前**:创建流 → 编辑画布 → 发布
**升级后**:
```
创建流
├─ 选择流类型(Chatflow / Workflow)
├─ 选择模板(可选,从模板市场选择)
└─ 进入编辑器
编辑画布
├─ 拖拽节点
├─ 配置节点参数
└─ 连线
发布
├─ 拓扑完整性检查(无孤立节点、起始/结束节点存在)
├─ 必填参数校验
├─ 创建版本快照(不可变)
└─ 更新 published_version_id 指针
```
### 2.5 流市场 vs 流列表的关系梳理
**当前问题**:流市场的"已上架工作流列表"和一级菜单"流列表"内容重叠,用户困惑。
**梳理方案**:
| 页面 | 定位 | 数据来源 | 操作 |
|------|------|---------|------|
| **流列表**(管理端) | 我创建/我管理的所有流 | `GET /api/flow/definitions` | 编辑、删除、发布/下架、上架到市场 |
| **流市场**(用户端) | 已上架到市场的公开流 | `GET /api/flow/market` | 查看详情、安装/使用、评价 |
| **模板中心**(创建时) | 系统预置 + 社区贡献的模板 | `GET /api/flow/templates` | 一键使用模板创建新流 |
**关键区分**:
- 流列表 = 私有工作区,看到的是自己管理的流
- 流市场 = 公开商店,看到的是别人发布到市场的流
- 模板中心 = 快速起步,创建新流时的入口
---
## 三、智能体管理改造:OpenAI-API-Compatible 模型管理
### 3.1 现状问题
当前 `AgentConfig` 模型只有一个 `model` 字段(String),无法区分模型类型,无法管理 Embedding/Rerank 模型,无法支持多供应商。
```python
# 当前 AgentConfig
model = Column(String(50), default="gpt-4o-mini") # 只能存一个模型名
```
### 3.2 Dify 的模型供应商架构
Dify 采用三层架构:
```
ModelProvider(供应商抽象层)
├── 模型类型:LLM / Embedding / Rerank / TTS / Speech-to-Text
├── 供应商实例:OpenAI / Anthropic / ZhipuAI / Ollama / OpenAI-API-Compatible
└── 模型实例:gpt-4o / text-embedding-3-small / cohere-rerank
```
**OpenAI-API-Compatible 模式**:任何实现了 OpenAI API 格式的服务都能即插即用,只需配置 base_url + api_key + model_name。
### 3.3 改造方案
#### 数据模型
```python
# 新增:模型供应商表
class ModelProvider(Base):
__tablename__ = "model_providers"
id = Column(UUID, primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False) # "OpenAI" / "智谱AI" / "本地Ollama"
provider_type = Column(String(50), nullable=False) # "openai" / "zhipu" / "ollama" / "openai_compatible"
base_url = Column(String(500)) # API 端点
api_key = Column(Text) # 加密存储
extra_config = Column(JSON, default=dict) # 供应商特有配置
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
# 新增:模型实例表
class ModelInstance(Base):
__tablename__ = "model_instances"
id = Column(UUID, primary_key=True, default=uuid.uuid4)
provider_id = Column(UUID, ForeignKey("model_providers.id"))
model_name = Column(String(100), nullable=False) # "gpt-4o" / "embedding-3" / "bge-rerank"
model_type = Column(String(30), nullable=False) # "llm" / "embedding" / "rerank"
display_name = Column(String(200)) # "GPT-4o" / "文本嵌入v3"
capabilities = Column(JSON, default=dict) # {"vision": true, "function_calling": true, "max_tokens": 128000}
default_params = Column(JSON, default=dict) # {"temperature": 0.7, "top_p": 1.0}
is_default = Column(Boolean, default=False) # 是否为该类型的默认模型
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
```
#### 各模型类型的配置参数
| 模型类型 | 通用参数 | 特有参数 |
|---------|---------|---------|
| **LLM** | model, temperature, top_p, max_tokens, stream | vision(多模态)、function_calling、response_format |
| **Embedding** | model, dimensions | encoding_format, input_type |
| **Rerank** | model, top_n | query, documents, return_documents |
#### API 端点设计
```
# 供应商管理
POST /api/model-providers/ # 添加供应商
GET /api/model-providers/ # 列出供应商
PUT /api/model-providers/{id} # 更新供应商
DELETE /api/model-providers/{id} # 删除供应商
POST /api/model-providers/{id}/test # 测试连通性
# 模型实例管理
POST /api/model-providers/{id}/models/ # 添加模型
GET /api/model-instances/ # 列出所有模型(支持 ?type=llm 筛选)
PUT /api/model-instances/{id} # 更新模型
DELETE /api/model-instances/{id} # 删除模型
POST /api/model-instances/{id}/test # 测试模型调用
# 默认模型设置
PUT /api/model-defaults/ # 设置各类型默认模型
GET /api/model-defaults/ # 获取各类型默认模型
```
#### 前端页面
新增 `ModelProviderManager.vue`(管理端):
- 供应商列表(卡片式,显示名称、类型、状态、模型数量)
- 添加供应商表单(选择类型 → 填写 base_url/api_key → 自动检测可用模型)
- 模型实例列表(按类型 Tab 分组:LLM / Embedding / Rerank)
- 每个模型可设置默认参数、是否为默认模型
- 测试按钮:发送测试请求验证连通性
#### AgentConfig 关联调整
```python
# AgentConfig 改造
class AgentConfig(Base):
# ...现有字段保留...
model = Column(String(50), default="gpt-4o-mini") # 保留兼容
model_instance_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True) # 新增:关联模型实例
embedding_model_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True) # 新增:关联嵌入模型
```
#### 引擎层适配
```python
# engine.py 中 LLMNodeAgent 改造
class LLMNodeAgent(AgentBase):
async def reply(self, msg, **kwargs):
context = kwargs.get("context", {})
config = self.config
# 优先使用 model_instance_id 获取模型配置
model_instance_id = config.get("model_instance_id")
if model_instance_id:
provider, model = await self._resolve_model(model_instance_id)
base_url = provider.base_url
api_key = provider.api_key
model_name = model.model_name
else:
# fallback 到旧逻辑
base_url = settings.OPENAI_BASE_URL
api_key = settings.OPENAI_API_KEY
model_name = config.get("model", "gpt-4o-mini")
```
### 3.4 实施优先级
| 步骤 | 内容 | 优先级 |
|------|------|--------|
| Step 1 | 创建 ModelProvider + ModelInstance 数据模型 + 迁移 SQL | **P0** |
| Step 2 | 实现供应商/模型 CRUD API | **P0** |
| Step 3 | 前端 ModelProviderManager.vue 页面 | **P0** |
| Step 4 | LLMNodeAgent 适配 model_instance_id | P1 |
| Step 5 | RAGNodeAgent 适配 embedding_model_id | P1 |
| Step 6 | AgentConfig 关联 model_instance_id | P1 |
| Step 7 | 供应商自动检测可用模型 | P2 |
---
## 四、流编辑器节点扩充(对标 Dify)
### 4.1 现有节点 vs Dify 节点对比
| 节点 | 我们有 | Dify 有 | 差距说明 |
|------|:------:|:-------:|---------|
| 触发/开始 | ✅ trigger | ✅ start | 我们多了企微触发,Dify 只有输入变量 |
| LLM | ✅ llm | ✅ llm | 基本对齐 |
| 工具调用 | ✅ tool | ✅ tool | 基本对齐 |
| MCP | ✅ mcp | ❌ | 我们独有 |
| 通知 | ✅ notify | ❌ | 我们独有(企微通知) |
| 条件分支 | ✅ condition | ✅ if-else | 基本对齐 |
| RAG检索 | ✅ rag | ✅ knowledge-retrieval | 基本对齐 |
| 循环 | ✅ loop | ✅ iteration | Dify 是数组迭代,我们是通用循环 |
| 变量聚合 | ✅ merge | ✅ variable-assigner | 基本对齐 |
| 代码执行 | ✅ code | ✅ code | 基本对齐 |
| 输出/结束 | ✅ output | ✅ end | 基本对齐 |
| **HTTP 请求** | ❌ | ✅ http-request | **缺失**:调用外部 API |
| **问题分类器** | ❌ | ✅ question-classifier | **缺失**:意图路由 |
| **模板转换** | ❌ | ✅ template-transform | **缺失**:Jinja2 格式化 |
| **变量赋值** | ❌ | ✅ variable-assigner | **缺失**:运行时变量操作 |
| **迭代** | ❌ | ✅ iteration | **缺失**:数组逐项处理(与 loop 不同) |
| **问题优化** | ❌ | ✅ question-optimiser | **缺失**:检索前 query 改写 |
### 4.2 新增节点方案
#### P0 — 必须新增(使用频率高)
**1. HTTP 请求节点(http_request)**
```
功能:发送 HTTP 请求(GET/POST/PUT/DELETE/PATCH)
配置参数:
- method: 请求方法
- url: 请求地址(支持变量模板)
- headers: 请求头(JSON)
- body: 请求体(支持变量模板)
- auth_type: 认证方式(none/api_key/bearer/basic/oauth2)
- auth_config: 认证配置
- timeout: 超时时间(秒)
- retry_count: 重试次数
输出:
- status_code: 状态码
- headers: 响应头
- body: 响应体(JSON 解析)
- raw: 原始文本
前端:HttpRequestConfig.vue
```
**2. 问题分类器节点(question_classifier)**
```
功能:基于 LLM 对用户输入进行意图分类,路由到不同分支
配置参数:
- model: 使用的 LLM(可选,默认用系统默认)
- categories: 分类列表 [{name, description}]
- instruction: 分类指令(补充说明)
输出:
- category: 分类结果
- confidence: 置信度
- 多出口:每个分类一个出口
前端:QuestionClassifierConfig.vue
引擎:调用 LLM 做分类,根据结果走不同分支(类似 condition 但基于 LLM)
```
**3. 变量赋值节点(variable_assigner)**
```
功能:在运行时对变量进行赋值/修改操作
配置参数:
- assignments: [{target_var, source_type, source_value}]
- source_type: "constant" / "upstream_output" / "template" / "expression"
输出:
- 变量值更新到 context 中
前端:VariableAssignerConfig.vue
```
#### P1 — 建议新增(提升体验)
**4. 模板转换节点(template_transform)**
```
功能:使用 Jinja2 模板语法对变量进行格式化/拼接/转换
配置参数:
- template: Jinja2 模板字符串
- output_type: 输出类型(string/json/array)
输出:
- rendered: 渲染后的文本
前端:TemplateTransformConfig.vue
引擎:使用 jinja2 库渲染模板
```
**5. 迭代节点(iteration)**
```
功能:对列表/数组逐项处理,内部可嵌套子工作流
与 loop 的区别:
- loop:固定次数或条件循环,处理同一逻辑
- iteration:遍历数组,每次迭代处理一个元素,输出结果数组
配置参数:
- input_array: 输入数组(变量引用)
- output_variable: 每次迭代的输出变量名
- max_iterations: 最大迭代次数
输出:
- output: 结果数组
前端:IterationConfig.vue
```
**6. 问题优化节点(question_optimiser)**
```
功能:在 RAG 检索前对用户 query 进行改写/扩展,提升检索召回率
配置参数:
- model: 使用的 LLM
- strategy: 优化策略(rewrite/expand/decompose)
- instruction: 自定义优化指令
输出:
- optimized_query: 优化后的查询
- original_query: 原始查询
前端:QuestionOptimiserConfig.vue
```
### 4.3 节点前端配置组件清单
| 新增节点 | 配置组件 | 复杂度 |
|---------|---------|--------|
| http_request | HttpRequestConfig.vue | 中 |
| question_classifier | QuestionClassifierConfig.vue | 中 |
| variable_assigner | VariableAssignerConfig.vue | 低 |
| template_transform | TemplateTransformConfig.vue | 低 |
| iteration | IterationConfig.vue | 中 |
| question_optimiser | QuestionOptimiserConfig.vue | 低 |
### 4.4 引擎层改造
```python
# engine.py _create_agent() 新增分支
elif node_type == "http_request":
return HttpRequestNodeAgent(...)
elif node_type == "question_classifier":
return QuestionClassifierNodeAgent(...)
elif node_type == "variable_assigner":
return VariableAssignerNodeAgent(...)
elif node_type == "template_transform":
return TemplateTransformNodeAgent(...)
elif node_type == "iteration":
return IterationNodeAgent(...)
elif node_type == "question_optimiser":
return QuestionOptimiserNodeAgent(...)
```
---
## 五、版本管理与模板系统完善
### 5.1 当前版本管理现状
- ✅ FlowVersion 模型已存在
- ✅ publish_flow / unpublish_flow / rollback_flow API 已实现
- ✅ 执行时加载 published_version 的 definition_json
- ❌ 缺少前端版本管理界面(版本列表、对比、回滚按钮)
- ❌ 缺少发布前的完整性校验
- ❌ 缺少草稿自动保存
### 5.2 完善方案
**前端版本管理**:
- FlowEditor.vue 工具栏增加"版本历史"按钮
- 版本历史弹窗:显示版本列表(版本号、发布时间、发布人、变更说明)
- 支持查看历史版本的定义 JSON
- 支持回滚到指定版本
- 发布时弹出"变更说明"输入框
**发布前校验**:
```python
async def _validate_before_publish(definition: dict) -> list[str]:
errors = []
nodes = definition.get("nodes", [])
edges = definition.get("edges", [])
# 1. 必须有起始节点
if not any(n.get("type") == "trigger" for n in nodes):
errors.append("缺少触发/起始节点")
# 2. Chatflow 必须有 LLM 节点
if flow_mode == "chatflow" and not any(n.get("type") == "llm" for n in nodes):
errors.append("对话型流必须包含至少一个 LLM 节点")
# 3. 无孤立节点
connected_ids = set()
for e in edges:
connected_ids.add(e["source"])
connected_ids.add(e["target"])
for n in nodes:
if n["id"] not in connected_ids and len(nodes) > 1:
errors.append(f"节点 '{n.get('label', n['id'])}' 未连接")
return errors
```
**草稿自动保存**:
- 前端 FlowEditor.vue 每 30 秒自动保存草稿到 `draft_definition_json`
- 切换流时检测未保存草稿,提示用户
### 5.3 模板系统增强
**当前**:硬编码 2 个模板(文档处理流、企微通知流)
**升级后**:
- 模板存储到数据库(新增 `flow_templates` 表)
- 支持管理员创建模板(从已有流"另存为模板")
- 模板分类:客服对话 / 文档处理 / 数据分析 / 通知推送 / 自定义
- 创建流时显示模板选择页面(可选,也可从空白创建)
---
## 六、实施路线图
### Phase 1 — 基础架构升级(P0,2-3 周)
| 任务 | 涉及文件 | 依赖 |
|------|---------|------|
| 1.1 创建 PostgreSQL 记忆表 + pgvector 扩展 | init-db/03-memory-tables.sql, models/\_\_init\_\_.py | 无 |
| 1.2 改造 MemoryManager:PG 主存储 + Redis 缓存 | memory/manager.py | 1.1 |
| 1.3 Redis → PostgreSQL 数据迁移脚本 | scripts/migrate_memory_redis_to_pg.py | 1.2 |
| 1.4 创建 ModelProvider + ModelInstance 数据模型 | models/\_\_init\_\_.py, init-db/04-model-provider.sql | 无 |
| 1.5 实现供应商/模型 CRUD API | modules/model_provider/router.py | 1.4 |
| 1.6 前端 ModelProviderManager.vue | frontend/src/views/model/ | 1.5 |
| 1.7 FlowDefinition 新增 flow_mode 字段 | models/\_\_init\_\_.py, init-db/ | 无 |
| 1.8 记忆中间件按 flow_mode 区分 | flow_engine/router.py | 1.7 |
| 1.9 新增 HTTP 请求节点 | engine.py, HttpRequestConfig.vue | 无 |
| 1.10 新增问题分类器节点 | engine.py, QuestionClassifierConfig.vue | 无 |
| 1.11 新增变量赋值节点 | engine.py, VariableAssignerConfig.vue | 无 |
### Phase 2 — 记忆分层 + 模型适配(P1,2-3 周)
| 任务 | 涉及文件 | 依赖 |
|------|---------|------|
| 2.1 MemoryManager 增加 L1 结构化提取(PG 存储) | memory/manager.py | Phase 1 |
| 2.2 MemoryManager 增加 L1 去重 | memory/manager.py | 2.1 |
| 2.3 LLMNodeAgent 适配 model_instance_id | engine.py | 1.5 |
| 2.4 RAGNodeAgent 适配 embedding_model_id | engine.py | 1.5 |
| 2.5 AgentConfig 关联 model_instance_id(向后兼容) | models/\_\_init\_\_.py, agent_manager/ | 1.5 |
| 2.6 新增模板转换节点 | engine.py, TemplateTransformConfig.vue | 无 |
| 2.7 新增迭代节点 | engine.py, IterationConfig.vue | 无 |
| 2.8 发布前校验 + 版本管理前端 | flow_engine/router.py, FlowEditor.vue | 无 |
### Phase 3 — 画像 + 检索优化(P2,2-3 周)
| 任务 | 涉及文件 | 依赖 |
|------|---------|------|
| 3.1 L2 场景块提取(PG 存储) | memory/manager.py | Phase 2 |
| 3.2 L3 用户画像生成(PG 存储) | memory/manager.py | 3.1 |
| 3.3 混合检索(pgvector 向量 + tsvector 全文 + RRF) | memory/manager.py | Embedding 模型 |
| 3.4 新增问题优化节点 | engine.py, QuestionOptimiserConfig.vue | 3.3 |
| 3.5 模板系统数据库化 | flow_templates 表, router.py | 无 |
| 3.6 草稿自动保存 | FlowEditor.vue, router.py | 无 |
---
## 七、向后兼容性保障
### 7.1 AgentConfig 向后兼容
```python
# AgentConfig 改造:保留 model 字段,新增可选字段
class AgentConfig(Base):
# 现有字段保留不变
model = Column(String(50), default="gpt-4o-mini") # 保留!旧数据自动兼容
# 新增可选字段(nullable=True,旧记录自动为 NULL)
model_instance_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)
embedding_model_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)
```
**兼容逻辑**:
```python
# engine.py 中 LLMNodeAgent 获取模型配置
def _resolve_model_config(self, config: dict) -> dict:
model_instance_id = config.get("model_instance_id")
if model_instance_id:
# 新路径:从 ModelInstance 获取完整配置
return {
"base_url": provider.base_url,
"api_key": provider.api_key,
"model": model.model_name,
"params": model.default_params,
}
else:
# 旧路径:fallback 到 AgentConfig.model + 全局 settings
return {
"base_url": settings.LLM_API_BASE,
"api_key": settings.LLM_API_KEY,
"model": config.get("model", "gpt-4o-mini"),
"params": {},
}
```
### 7.2 记忆存储迁移兼容
```python
# MemoryManager 同时支持 Redis 和 PostgreSQL
class MemoryManager:
def __init__(self, db_session_factory, redis: Redis):
self.db_session_factory = db_session_factory
self.redis = redis
async def inject_memory(self, user_id, flow_id, session_id, context):
# 优先从 Redis 缓存读取(兼容旧数据)
cached = await self._get_cached_messages(user_id, flow_id, session_id)
if cached:
recent_messages = cached
else:
# 缓存未命中,查 PostgreSQL
recent_messages = await self._query_recent_messages(user_id, flow_id, session_id)
if recent_messages:
# 回填 Redis 缓存
await self._cache_messages(user_id, flow_id, session_id, recent_messages)
# ...后续逻辑
```
### 7.3 数据库变更管理规范
所有涉及表结构改动的变更,必须:
1. 编写对应的 SQL 迁移文件到 `init-db/` 目录
2. SQL 文件使用 `IF NOT EXISTS` / `IF NOT EXISTS` 保证幂等
3. 新增列必须设置 `DEFAULT` 值或 `NULLABLE`,确保旧数据兼容
4. 迁移文件按序号命名:`01-init.sql` → `02-add-published-cols.sql``03-memory-tables.sql``04-model-provider.sql`
---
## 八、风险与注意事项
1. **pgvector 扩展安装**:需确认 Docker 镜像中的 PostgreSQL 是否包含 pgvector 扩展。如未包含,需在 Dockerfile 中添加 `RUN apt-get install postgresql-15-pgvector` 或使用 `pgvector/pgvector:pg15` 镜像
2. **Redis → PG 数据迁移**:迁移期间需短暂停写,建议在低峰期执行。迁移脚本需处理 Redis 中 JSON 序列化格式与 PG 表结构的映射
3. **LLM 调用成本**:L1 提取和去重都需要额外 LLM 调用,建议使用低成本模型(如 gpt-4o-mini)并控制触发频率
4. **模型供应商 API Key 安全**:必须加密存储,不能明文写入数据库。建议使用 `cryptography.fernet` 对称加密
5. **向后兼容**:AgentConfig.model 字段保留,model_instance_id 为可选(nullable=True),旧数据自动兼容,无需迁移
6. **节点类型注册**:新增节点需同时更新前端 nodeTypes 数组和后端 _create_agent() 分支,缺一不可
7. **流类型不可变**:flow_mode 在创建时确定,后续不可更改(Chatflow 和 Workflow 的引擎逻辑差异大)
8. **PostgreSQL 连接池**:记忆表查询复用现有连接池(pool_size=20),需监控连接数是否够用,必要时调整
9. **Redis 缓存一致性**:PG 写入成功但 Redis 缓存更新失败时,下次读取会缓存未命中走 PG 查询,自动修复,无需额外处理

8
backend/main.py

@ -19,22 +19,23 @@ from modules.rag.router import router as rag_router
from modules.chat.router import router as chat_router
from modules.custom_tool.router import router as custom_tool_router
from modules.memory.router import router as memory_router
from modules.memory.manager import init_memory_manager
from modules.memory.manager import init_memory_manager, get_memory_manager
from modules.model_provider.router import router as model_provider_router
from websocket_manager import ws_manager
from middleware.rbac_middleware import rbac_middleware
from middleware.rate_limiter import rate_limit_middleware
from middleware.cache_manager import cache_manager
from database import AsyncSessionLocal
@asynccontextmanager
async def lifespan(app: AgentApp):
await init_db()
await cache_manager.connect()
await init_memory_manager()
await init_memory_manager(AsyncSessionLocal)
yield
await cache_manager.disconnect()
try:
from modules.memory.manager import get_memory_manager
mm = get_memory_manager()
await mm.redis.close()
except Exception:
@ -71,3 +72,4 @@ app.include_router(rag_router)
app.include_router(chat_router)
app.include_router(custom_tool_router)
app.include_router(memory_router)
app.include_router(model_provider_router)

113
backend/models/__init__.py

@ -141,6 +141,7 @@ class FlowDefinition(Base):
published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True)
draft_definition_json = Column(JSON, nullable=True, default=None)
creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id"))
flow_mode = Column(String(20), default="chatflow")
published_to_wecom = Column(Boolean, default=False)
published_to_web = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
@ -176,6 +177,23 @@ class FlowApiKey(Base):
created_at = Column(DateTime, default=datetime.utcnow)
class FlowTemplate(Base):
__tablename__ = "flow_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(200), nullable=False)
description = Column(Text, default="")
category = Column(String(50), default="")
definition_json = Column(JSON, nullable=False, default=dict)
icon = Column(String(50), default="")
sort_order = Column(Integer, default=0)
is_builtin = Column(Boolean, default=False)
usage_count = Column(Integer, default=0)
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class CustomTool(Base):
__tablename__ = "custom_tools"
@ -261,6 +279,8 @@ class AgentConfig(Base):
description = Column(String(500))
system_prompt = Column(Text, default="")
model = Column(String(50), default="gpt-4o-mini")
model_instance_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True)
embedding_model_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True)
temperature = Column(Float, default=0.7)
tools = Column(JSON, default=list)
status = Column(String(20), default="active")
@ -280,3 +300,96 @@ class AuditLog(Base):
detail = Column(JSON, default=dict)
ip_address = Column(String(50))
created_at = Column(DateTime, default=datetime.utcnow)
class MemoryMessage(Base):
__tablename__ = "memory_messages"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False)
session_id = Column(UUID(as_uuid=True), nullable=False)
role = Column(String(20), nullable=False)
content = Column(Text, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow)
class MemoryAtom(Base):
__tablename__ = "memory_atoms"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True)
atom_type = Column(String(20), nullable=False)
content = Column(Text, nullable=False)
priority = Column(Integer, default=50)
source_session_id = Column(UUID(as_uuid=True), nullable=True)
metadata = Column(JSON, default=dict)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class MemoryScene(Base):
__tablename__ = "memory_scenes"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True)
scene_name = Column(String(200), nullable=False)
summary = Column(Text, nullable=False)
heat = Column(Integer, default=0)
content = Column(JSON, default=dict)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class MemoryPersona(Base):
__tablename__ = "memory_personas"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True)
content = Column(JSON, default=dict, nullable=False)
raw_text = Column(Text, default="")
version = Column(Integer, default=1)
updated_at = Column(DateTime, default=datetime.utcnow)
class MemorySession(Base):
__tablename__ = "memory_sessions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False)
session_id = Column(UUID(as_uuid=True), nullable=False)
flow_name = Column(String(200), default="")
message_count = Column(Integer, default=0)
last_active_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime, default=datetime.utcnow)
class ModelProvider(Base):
__tablename__ = "model_providers"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
provider_type = Column(String(50), nullable=False)
base_url = Column(String(500))
api_key = Column(Text)
extra_config = Column(JSON, default=dict)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
class ModelInstance(Base):
__tablename__ = "model_instances"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
provider_id = Column(UUID(as_uuid=True), ForeignKey("model_providers.id", ondelete="CASCADE"))
model_name = Column(String(100), nullable=False)
model_type = Column(String(30), nullable=False)
display_name = Column(String(200))
capabilities = Column(JSON, default=dict)
default_params = Column(JSON, default=dict)
is_default = Column(Boolean, default=False)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)

447
backend/modules/flow_engine/engine.py

@ -3,6 +3,7 @@ import uuid
import logging
import re
import asyncio
import httpx
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
@ -12,6 +13,35 @@ from config import settings
logger = logging.getLogger(__name__)
async def _resolve_model_instance(model_instance_id: str) -> dict | None:
try:
from database import AsyncSessionLocal
from sqlalchemy import text
import uuid as _uuid
uid = _uuid.UUID(model_instance_id)
async with AsyncSessionLocal() as db:
result = await db.execute(
text("""
SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key
FROM model_instances mi
JOIN model_providers mp ON mi.provider_id = mp.id
WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true
"""),
{"id": uid},
)
row = result.fetchone()
if row:
return {
"model": row[0],
"base_url": row[2],
"api_key": row[3],
"params": row[1] or {},
}
except Exception:
pass
return None
class FlowSessionMemory:
def __init__(self, session_id: str = "", user_id: str = ""):
self.session_id = session_id
@ -241,6 +271,15 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
elif node_type == "llm":
model_config = config.get("model", settings.LLM_MODEL)
model_instance_id = config.get("model_instance_id")
base_url = settings.LLM_API_BASE
api_key = settings.LLM_API_KEY
if model_instance_id:
resolved = await _resolve_model_instance(model_instance_id)
if resolved:
model_config = resolved.get("model", model_config)
base_url = resolved.get("base_url", base_url)
api_key = resolved.get("api_key", api_key)
temperature = config.get("temperature", 0.7)
system_prompt = config.get("system_prompt", "你是AI助手。")
max_tokens = config.get("max_tokens", 2000)
@ -254,6 +293,8 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
max_tokens=max_tokens,
stream=stream,
stream_callback=stream_cb,
base_url=base_url,
api_key=api_key,
)
memory = context.get("_memory")
if memory:
@ -297,6 +338,12 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type)
elif node_type == "rag":
model_instance_id = config.get("model_instance_id")
if model_instance_id:
resolved = await _resolve_model_instance(model_instance_id)
if resolved:
config = dict(config)
config["_resolved_model"] = resolved
return RAGNodeAgent(node_id=node_id, config=config)
elif node_type == "output":
@ -314,6 +361,24 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
sandbox = config.get("sandbox", True)
return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox)
elif node_type == "http_request":
return HttpRequestNodeAgent(node_id=node_id, config=config)
elif node_type == "question_classifier":
return QuestionClassifierNodeAgent(node_id=node_id, config=config)
elif node_type == "variable_assigner":
return VariableAssignerNodeAgent(node_id=node_id, config=config, context=context)
elif node_type == "template_transform":
return TemplateTransformNodeAgent(node_id=node_id, config=config)
elif node_type == "iteration":
return IterationNodeAgent(node_id=node_id, config=config)
elif node_type == "question_optimiser":
return QuestionOptimiserNodeAgent(node_id=node_id, config=config)
else:
return PassThroughAgent(node_id)
@ -330,8 +395,178 @@ class PassThroughAgent(AgentBase):
pass
class HttpRequestNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()
self.name = f"HttpRequest_{node_id}"
self.config = config or {}
async def reply(self, msg: Msg, **kwargs) -> Msg:
method = self.config.get("method", "GET").upper()
url = self.config.get("url", "")
headers = self.config.get("headers", {})
body = self.config.get("body", "")
auth_type = self.config.get("auth_type", "none")
auth_config = self.config.get("auth_config", {})
timeout = self.config.get("timeout", 30)
retry_count = self.config.get("retry_count", 0)
if isinstance(headers, str):
try:
headers = json.loads(headers)
except (json.JSONDecodeError, ValueError):
headers = {}
request_headers = dict(headers)
if auth_type == "bearer" and auth_config.get("token"):
request_headers["Authorization"] = f"Bearer {auth_config['token']}"
elif auth_type == "api_key" and auth_config.get("api_key"):
key_name = auth_config.get("key_name", "X-API-Key")
request_headers[key_name] = auth_config["api_key"]
elif auth_type == "basic" and auth_config.get("username"):
import base64
credentials = f"{auth_config['username']}:{auth_config.get('password', '')}"
request_headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}"
last_error = None
for attempt in range(retry_count + 1):
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
request_kwargs = {"headers": request_headers}
if method in ("POST", "PUT", "PATCH"):
if isinstance(body, dict):
request_kwargs["json"] = body
elif body and isinstance(body, str):
try:
request_kwargs["json"] = json.loads(body)
except (json.JSONDecodeError, ValueError):
request_kwargs["content"] = body
response = await client.request(method, url, **request_kwargs)
resp_status = response.status_code
resp_text = response.text
try:
resp_body = json.loads(resp_text)
except (json.JSONDecodeError, ValueError):
resp_body = resp_text
result = json.dumps({
"status_code": resp_status,
"body": resp_body,
}, ensure_ascii=False)
return Msg(self.name, result, "assistant")
except Exception as e:
last_error = str(e)
if attempt < retry_count:
await asyncio.sleep(1)
error_result = json.dumps({
"status_code": 0,
"error": last_error or "unknown error",
"body": None,
}, ensure_ascii=False)
return Msg(self.name, error_result, "assistant")
async def observe(self, msg) -> None:
pass
class QuestionClassifierNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()
self.name = f"Classifier_{node_id}"
self.config = config or {}
async def reply(self, msg: Msg, **kwargs) -> Msg:
categories = self.config.get("categories", [])
instruction = self.config.get("instruction", "")
model_name = self.config.get("model", settings.LLM_MODEL)
temperature = self.config.get("temperature", 0.3)
user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
if not categories:
return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant")
category_desc = "\n".join([f'{c.get("name")}: {c.get("description", "")}' for c in categories])
prompt = f"""请对以下用户输入进行意图分类。
分类选项
{category_desc}
{instruction}
用户输入{user_text}
请只返回一个JSON对象格式为{{"category": "分类名称", "confidence": 0.0~1.0}}"""
try:
import openai
client = openai.AsyncOpenAI(
base_url=settings.LLM_API_BASE,
api_key=settings.LLM_API_KEY,
)
response = await client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
temperature=temperature,
max_tokens=200,
)
result_text = response.choices[0].message.content.strip()
try:
parsed = json.loads(result_text)
return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
except (json.JSONDecodeError, ValueError):
return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}), "assistant")
except Exception as e:
logger.error(f"QuestionClassifier error: {e}")
return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}), "assistant")
async def observe(self, msg) -> None:
pass
class VariableAssignerNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None, context: dict = None):
super().__init__()
self.name = f"VarAssign_{node_id}"
self.config = config or {}
self._context = context or {}
async def reply(self, msg: Msg, **kwargs) -> Msg:
assignments = self.config.get("assignments", [])
results = {}
for assignment in assignments:
target_var = assignment.get("target_var", "")
source_type = assignment.get("source_type", "constant")
source_value = assignment.get("source_value", "")
if source_type == "constant":
value = source_value
elif source_type == "upstream_output":
value = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
elif source_type == "template":
value = _resolve_template(source_value, self._context, msg)
elif source_type == "expression":
try:
safe_locals = {"msg": msg, "context": self._context}
value = eval(source_value, {"__builtins__": {}}, safe_locals)
value = str(value)
except Exception as e:
value = f"[expression error: {e}]"
else:
value = source_value
results[target_var] = value
output = json.dumps(results, ensure_ascii=False)
return Msg(self.name, output, "assistant")
async def observe(self, msg) -> None:
pass
class LLMNodeAgent(AgentBase):
def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None):
def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None, base_url: str = "", api_key: str = ""):
super().__init__()
self.name = f"LLM_{node_id}"
self.system_prompt = system_prompt
@ -340,6 +575,8 @@ class LLMNodeAgent(AgentBase):
self.max_tokens = max_tokens
self.stream = stream
self.stream_callback = stream_callback
self.base_url = base_url or settings.LLM_API_BASE
self.api_key = api_key or settings.LLM_API_KEY
self._memory = None
def set_memory(self, memory):
@ -390,29 +627,40 @@ class LLMNodeAgent(AgentBase):
return Msg(self.name, res_text, "assistant")
async def _blocking_llm_call(self, messages: list[dict]) -> str:
from agentscope_integration.factory import AgentFactory
from agentscope.formatter import OpenAIChatFormatter
import httpx
model = AgentFactory._get_model()
formatter = OpenAIChatFormatter()
scope_msgs = []
for m in messages:
scope_msgs.append(Msg(m["role"], m["content"], m["role"]))
prompt = formatter.format(scope_msgs)
api_base = self.base_url.rstrip("/")
api_key = self.api_key
model_name = self.model_name or settings.LLM_MODEL
url = f"{api_base}/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
body = {
"model": model_name,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"stream": False,
}
res = await model(prompt)
if isinstance(res, list):
return res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0])
elif hasattr(res, 'get_text_content'):
return res.get_text_content()
return str(res)
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
resp = await client.post(url, json=body, headers=headers)
data = resp.json()
return data.get("choices", [{}])[0].get("message", {}).get("content", "")
except Exception as e:
logger.warning(f"LLM 阻塞调用失败: {e}")
return f"[LLM 调用失败: {e}]"
async def _stream_llm_call(self, messages: list[dict]) -> str:
import httpx
import json
api_base = settings.LLM_API_BASE.rstrip("/")
api_key = settings.LLM_API_KEY
api_base = self.base_url.rstrip("/")
api_key = self.api_key
model_name = self.model_name or settings.LLM_MODEL
url = f"{api_base}/chat/completions"
@ -977,11 +1225,12 @@ class RAGNodeAgent(AgentBase):
def _get_model(self):
from agentscope.model import OpenAIChatModel
resolved = self.config.get("_resolved_model", {})
return OpenAIChatModel(
config_name=f"rag_{self.name}",
model_name=settings.LLM_MODEL,
api_key=settings.LLM_API_KEY,
api_base=settings.LLM_API_BASE,
model_name=resolved.get("model", settings.LLM_MODEL),
api_key=resolved.get("api_key", settings.LLM_API_KEY),
api_base=resolved.get("base_url", settings.LLM_API_BASE),
)
def _get_formatter(self):
@ -1050,6 +1299,164 @@ def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
return result
class TemplateTransformNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()
self.name = f"Template_{node_id}"
self.config = config or {}
async def reply(self, msg: Msg, **kwargs) -> Msg:
template = self.config.get("template", "")
output_type = self.config.get("output_type", "string")
context = kwargs.get("context", {})
user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
try:
rendered = _resolve_template(template, context, msg)
if not rendered:
rendered = template
rendered = rendered.replace("{{input}}", user_text)
except Exception:
rendered = template.replace("{{input}}", user_text) if "{{input}}" in template else user_text
if output_type == "json":
try:
parsed = json.loads(rendered)
return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
except (json.JSONDecodeError, ValueError):
pass
return Msg(self.name, rendered, "assistant")
async def observe(self, msg) -> None:
pass
class IterationNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()
self.name = f"Iteration_{node_id}"
self.config = config or {}
self._results: list[str] = []
async def reply(self, msg: Msg, **kwargs) -> Msg:
input_array_source = self.config.get("input_array_source", "")
max_iterations = self.config.get("max_iterations", 20)
context = kwargs.get("context", {})
items = []
if input_array_source:
resolved = _resolve_template(input_array_source, context, msg)
try:
items = json.loads(resolved)
if not isinstance(items, list):
items = [resolved]
except (json.JSONDecodeError, ValueError):
items = [resolved] if resolved else []
else:
user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
try:
items = json.loads(user_text)
if not isinstance(items, list):
items = [user_text]
except (json.JSONDecodeError, ValueError):
items = [item.strip() for item in user_text.split("\n") if item.strip()]
if not items:
items = [user_text]
items = items[:max_iterations]
results = []
for i, item in enumerate(items):
results.append({
"index": i,
"item": item,
})
output = json.dumps(results, ensure_ascii=False)
return Msg(self.name, output, "assistant")
def get_iteration_items(self) -> list:
return self._results
async def observe(self, msg) -> None:
pass
class QuestionOptimiserNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()
self.name = f"QOpt_{node_id}"
self.config = config or {}
async def reply(self, msg: Msg, **kwargs) -> Msg:
optimization_type = self.config.get("optimization_type", "rewrite")
model_name = self.config.get("model", settings.LLM_MODEL)
context = kwargs.get("context", {})
user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
history = context.get("_memory", {}).get("recent_messages", [])
persona = context.get("_memory", {}).get("persona", {})
atoms = context.get("_memory", {}).get("atoms", [])
persona_text = persona.get("raw_text", "")
atoms_text = "\n".join([a.get("content", "") for a in atoms[:5]]) if atoms else ""
history_text = "\n".join(
f"{m['role']}: {m['content'][:200]}" for m in history[-6:]
) if history else ""
context_parts = []
if persona_text:
context_parts.append(f"用户画像: {persona_text}")
if atoms_text:
context_parts.append(f"已知信息: {atoms_text}")
if history_text:
context_parts.append(f"近期对话: {history_text}")
context_block = "\n".join(context_parts) if context_parts else ""
if optimization_type == "rewrite":
prompt = f"""{context_block}
原始问题: {user_text}
请将以上问题进行优化改写使其更清晰具体完整补充可能缺失的上下文信息
只返回优化后的问题不要其他内容"""
elif optimization_type == "expand":
prompt = f"""{context_block}
简短问题: {user_text}
请将以上问题扩展为更详细的版本添加必要的背景和细节
只返回扩展后的问题不要其他内容"""
else:
return Msg(self.name, user_text, "assistant")
try:
import httpx
api_base = settings.LLM_API_BASE.rstrip("/")
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{api_base}/chat/completions",
json={
"model": model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 300,
"temperature": 0.3,
},
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
)
data = resp.json()
result = data.get("choices", [{}])[0].get("message", {}).get("content", user_text)
return Msg(self.name, result.strip(), "assistant")
except Exception as e:
logger.warning(f"问题优化失败: {e}")
return Msg(self.name, user_text, "assistant")
async def observe(self, msg) -> None:
pass
class ParallelMergeNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()

8
backend/modules/flow_engine/gateway.py

@ -80,12 +80,12 @@ async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)):
username = user
if response_mode == "streaming":
return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db)
return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db, session_id)
return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db)
return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db, session_id)
async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db):
async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
engine = FlowEngine(definition)
input_msg = Msg(name="user", content=input_text, role="user")
context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())}
@ -127,7 +127,7 @@ async def _chat_blocking(flow_id, definition, input_text, user_id, username, flo
raise HTTPException(500, f"流执行失败: {str(e)}")
async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db):
async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
async def event_generator():
import asyncio
engine = FlowEngine(definition)

249
backend/modules/flow_engine/router.py

@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User
from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User, MemoryMessage, FlowTemplate
from schemas import (
FlowDefinitionCreate, FlowDefinitionUpdate, FlowDefinitionOut,
FlowVersionOut, FlowApiKeyCreate, FlowApiKeyOut,
@ -34,6 +34,7 @@ def _build_flow_out(f) -> FlowDefinitionOut:
published_version_id=getattr(f, 'published_version_id', None),
published_to_wecom=getattr(f, 'published_to_wecom', False),
published_to_web=getattr(f, 'published_to_web', False),
flow_mode=getattr(f, 'flow_mode', 'chatflow') or 'chatflow',
created_at=f.created_at, updated_at=f.updated_at,
)
@ -71,6 +72,7 @@ async def create_flow(req: FlowDefinitionCreate, request: Request, db: AsyncSess
definition_json=definition_json,
draft_definition_json=definition_json,
creator_id=uuid.UUID(user_ctx["id"]),
flow_mode=req.flow_mode,
)
db.add(flow)
await db.flush()
@ -133,6 +135,43 @@ async def _snapshot_publish(flow: FlowDefinition, db: AsyncSession, user_id: str
return new_version
def validate_flow_definition(definition: dict, flow_mode: str = "chatflow") -> list[str]:
errors = []
nodes = definition.get("nodes", [])
edges = definition.get("edges", [])
if not nodes:
errors.append("流定义中没有节点")
return errors
if not any(n.get("type") == "trigger" for n in nodes):
errors.append("缺少触发/起始节点 (trigger)")
if flow_mode == "chatflow" and not any(n.get("type") == "llm" for n in nodes):
errors.append("对话型流必须包含至少一个 LLM 节点")
node_ids = {n["id"] for n in nodes}
connected_ids = set()
for e in edges:
connected_ids.add(e.get("source", ""))
connected_ids.add(e.get("target", ""))
for n in nodes:
if n["id"] not in connected_ids and len(nodes) > 1:
errors.append(f"节点 '{n.get('label', n['id'])}' 未连接")
return errors
@router.post("/definitions/{flow_id}/validate")
async def validate_flow(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
f = await db.get(FlowDefinition, flow_id)
if not f:
raise HTTPException(404, "流定义不存在")
flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
errors = validate_flow_definition(f.definition_json, flow_mode)
return {"code": 200, "valid": len(errors) == 0, "errors": errors}
@router.post("/definitions/{flow_id}/publish")
async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
f = await db.get(FlowDefinition, flow_id)
@ -141,8 +180,18 @@ async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession =
nodes = f.definition_json.get("nodes", [])
if not nodes:
raise HTTPException(400, "流定义中没有节点")
flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
errors = validate_flow_definition(f.definition_json, flow_mode)
if errors:
raise HTTPException(400, f"流校验失败: {'; '.join(errors)}")
user_ctx = request.state.user
await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True)
body = {}
try:
body = await request.json()
except Exception:
pass
changelog = body.get("changelog", "")
await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True, changelog=changelog)
return {"code": 200, "message": "流已上架到企微", "data": {"status": "published", "version": f.version}}
@ -261,12 +310,14 @@ async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db:
try:
from modules.memory.manager import get_memory_manager
mm = get_memory_manager()
await mm.inject_memory(
user_id=user_ctx["id"],
flow_id=str(flow_id),
session_id=session_id,
context=context,
)
flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
if flow_mode == 'chatflow':
await mm.inject_memory(
user_id=user_ctx["id"],
flow_id=str(flow_id),
session_id=session_id,
context=context,
)
except Exception as e:
logger.debug(f"记忆注入跳过: {e}")
@ -279,14 +330,15 @@ async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db:
try:
from modules.memory.manager import get_memory_manager
mm = get_memory_manager()
asyncio.create_task(mm.record_exchange(
user_id=user_ctx["id"],
flow_id=str(flow_id),
session_id=session_id,
user_msg=input_text,
assistant_msg=output_text,
flow_name=f.name,
))
if flow_mode == 'chatflow':
asyncio.create_task(mm.record_exchange(
user_id=user_ctx["id"],
flow_id=str(flow_id),
session_id=session_id,
user_msg=input_text,
assistant_msg=output_text,
flow_name=f.name,
))
except Exception as e:
logger.debug(f"记忆记录跳过: {e}")
@ -383,6 +435,8 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
except Exception:
mm_stream = None
flow_mode_stream = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
async def event_generator():
engine = FlowEngine(definition)
context = {
@ -394,7 +448,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
"_stream_callback": None,
}
if mm_stream:
if mm_stream and flow_mode_stream == 'chatflow':
try:
await mm_stream.inject_memory(
user_id=user_ctx["id"],
@ -476,7 +530,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
yield f"data: {json.dumps({'event': 'workflow_finished', 'data': {'output': output_text, 'session_id': session_id, 'node_results': {k: str(v)[:200] for k, v in context.get('_node_results', {}).items()}, 'latency_ms': elapsed_ms}}, ensure_ascii=False)}\n\n"
if mm_stream:
if mm_stream and flow_mode_stream == 'chatflow':
try:
asyncio.create_task(mm_stream.record_exchange(
user_id=user_ctx["id"],
@ -501,8 +555,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
finished_at=datetime.utcnow(),
)
db.add(execution)
finally:
yield "data: [DONE]\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@ -615,86 +668,6 @@ async def list_executions(
# ============================== 模板 ==============================
FLOW_TEMPLATES = [
{
"id": "tpl_doc_process", "name": "文档处理流", "description": "自动解析文档内容,提取关键信息并生成摘要", "icon": "Document",
"nodes": [
{"id": "n1", "type": "trigger", "label": "文档上传", "config": {"event_type": "document_upload"}, "position": {"x": 100, "y": 100}},
{"id": "n2", "type": "tool", "label": "解析文档", "config": {"tool_name": "parse_document"}, "position": {"x": 400, "y": 100}},
{"id": "n3", "type": "llm", "label": "生成摘要", "config": {"system_prompt": "请为以下文档内容生成简洁摘要", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 700, "y": 100}},
{"id": "n4", "type": "output", "label": "输出结果", "config": {"format": "text"}, "position": {"x": 1000, "y": 100}},
],
"edges": [
{"source": "n1", "target": "n2", "sourceHandle": "source"},
{"source": "n2", "target": "n3", "sourceHandle": "source"},
{"source": "n3", "target": "n4", "sourceHandle": "source"},
],
},
{
"id": "tpl_wecom_notify", "name": "企微通知流", "description": "接收触发后查询数据并推送企微通知", "icon": "Bell",
"nodes": [
{"id": "n1", "type": "trigger", "label": "定时触发", "config": {"event_type": "scheduled"}, "position": {"x": 100, "y": 100}},
{"id": "n2", "type": "tool", "label": "查询任务", "config": {"tool_name": "list_tasks"}, "position": {"x": 400, "y": 100}},
{"id": "n3", "type": "condition", "label": "有待办任务?", "config": {"condition": "tasks.length > 0"}, "position": {"x": 700, "y": 100}},
{"id": "n4", "type": "wecom_notify", "label": "推送通知", "config": {"message_template": "您有{{tasks.length}}条待办任务", "target": "@all"}, "position": {"x": 1000, "y": 50}},
{"id": "n5", "type": "output", "label": "无任务", "config": {"format": "text"}, "position": {"x": 1000, "y": 200}},
],
"edges": [
{"source": "n1", "target": "n2", "sourceHandle": "source"},
{"source": "n2", "target": "n3", "sourceHandle": "source"},
{"source": "n3", "target": "n4", "sourceHandle": "true"},
{"source": "n3", "target": "n5", "sourceHandle": "false"},
],
},
{
"id": "tpl_data_analysis", "name": "数据分析流", "description": "查询员工数据并生成效率分析报告", "icon": "DataAnalysis",
"nodes": [
{"id": "n1", "type": "trigger", "label": "分析请求", "config": {"event_type": "button_click"}, "position": {"x": 100, "y": 100}},
{"id": "n2", "type": "tool", "label": "查询下属", "config": {"tool_name": "list_subordinates"}, "position": {"x": 400, "y": 100}},
{"id": "n3", "type": "tool", "label": "统计数据", "config": {"tool_name": "get_task_statistics"}, "position": {"x": 700, "y": 100}},
{"id": "n4", "type": "llm", "label": "生成报告", "config": {"system_prompt": "基于以下数据生成团队效率分析报告", "model": "gpt-4o", "temperature": 0.7}, "position": {"x": 1000, "y": 100}},
{"id": "n5", "type": "output", "label": "报告输出", "config": {"format": "json"}, "position": {"x": 1300, "y": 100}},
],
"edges": [
{"source": "n1", "target": "n2", "sourceHandle": "source"},
{"source": "n2", "target": "n3", "sourceHandle": "source"},
{"source": "n3", "target": "n4", "sourceHandle": "source"},
{"source": "n4", "target": "n5", "sourceHandle": "source"},
],
},
{
"id": "tpl_rag_qa", "name": "知识库问答流", "description": "从知识库检索信息后由LLM回答", "icon": "Search",
"nodes": [
{"id": "n1", "type": "trigger", "label": "问题触发", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}},
{"id": "n2", "type": "rag", "label": "知识检索", "config": {"knowledge_base": "default", "top_k": 5}, "position": {"x": 400, "y": 100}},
{"id": "n3", "type": "llm", "label": "生成回答", "config": {"system_prompt": "基于知识库检索结果回答用户问题", "model": "gpt-4o-mini", "temperature": 0.3}, "position": {"x": 700, "y": 100}},
{"id": "n4", "type": "output", "label": "输出答案", "config": {"format": "text"}, "position": {"x": 1000, "y": 100}},
],
"edges": [
{"source": "n1", "target": "n2", "sourceHandle": "source"},
{"source": "n2", "target": "n3", "sourceHandle": "source"},
{"source": "n3", "target": "n4", "sourceHandle": "source"},
],
},
{
"id": "tpl_task_auto", "name": "任务自动分配流", "description": "根据描述自动创建任务并分派给合适人员", "icon": "Tools",
"nodes": [
{"id": "n1", "type": "trigger", "label": "任务描述", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}},
{"id": "n2", "type": "llm", "label": "分析任务", "config": {"system_prompt": "分析以下任务描述,提取标题、优先级、负责人", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 400, "y": 100}},
{"id": "n3", "type": "tool", "label": "创建任务", "config": {"tool_name": "create_task"}, "position": {"x": 700, "y": 100}},
{"id": "n4", "type": "wecom_notify", "label": "通知负责人", "config": {"message_template": "您有新任务: {{task_title}}", "target": "@all"}, "position": {"x": 1000, "y": 100}},
{"id": "n5", "type": "output", "label": "完成", "config": {"format": "text"}, "position": {"x": 1300, "y": 100}},
],
"edges": [
{"source": "n1", "target": "n2", "sourceHandle": "source"},
{"source": "n2", "target": "n3", "sourceHandle": "source"},
{"source": "n3", "target": "n4", "sourceHandle": "source"},
{"source": "n4", "target": "n5", "sourceHandle": "source"},
],
},
]
@router.get("/market", response_model=list[FlowDefinitionOut])
async def flow_market(db: AsyncSession = Depends(get_db)):
result = await db.execute(
@ -704,24 +677,80 @@ async def flow_market(db: AsyncSession = Depends(get_db)):
@router.get("/templates")
async def get_flow_templates(request: Request):
return {"code": 200, "data": FLOW_TEMPLATES}
async def get_flow_templates(db: AsyncSession = Depends(get_db)):
result = await db.execute(
select(FlowTemplate).order_by(FlowTemplate.sort_order.asc(), FlowTemplate.created_at.desc())
)
templates = result.scalars().all()
return {"code": 200, "data": [
{
"id": str(t.id), "name": t.name, "description": t.description,
"category": t.category, "definition_json": t.definition_json,
"icon": t.icon, "is_builtin": t.is_builtin, "usage_count": t.usage_count,
}
for t in templates
]}
@router.post("/templates")
async def create_flow_template(request: Request, db: AsyncSession = Depends(get_db)):
body = await request.json()
user_ctx = request.state.user
ft = FlowTemplate(
name=body.get("name", ""),
description=body.get("description", ""),
category=body.get("category", "general"),
definition_json=body.get("definition_json", {}),
icon=body.get("icon", ""),
sort_order=body.get("sort_order", 0),
created_by=uuid.UUID(user_ctx["id"]),
)
db.add(ft)
await db.commit()
await db.refresh(ft)
return {"code": 200, "data": {"id": str(ft.id), "name": ft.name}}
@router.put("/templates/{template_id}")
async def update_flow_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
ft = await db.get(FlowTemplate, template_id)
if not ft:
raise HTTPException(404, "模板不存在")
body = await request.json()
for field in ("name", "description", "category", "definition_json", "icon", "sort_order"):
if field in body:
setattr(ft, field, body[field])
await db.commit()
return {"code": 200, "data": {"id": str(ft.id)}}
@router.delete("/templates/{template_id}")
async def delete_flow_template(template_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
ft = await db.get(FlowTemplate, template_id)
if not ft:
raise HTTPException(404, "模板不存在")
if ft.is_builtin:
raise HTTPException(400, "内置模板不可删除")
await db.delete(ft)
await db.commit()
return {"code": 200, "message": "模板已删除"}
@router.post("/templates/{template_id}/use")
async def use_flow_template(template_id: str, request: Request, db: AsyncSession = Depends(get_db)):
template = next((t for t in FLOW_TEMPLATES if t["id"] == template_id), None)
if not template:
async def use_flow_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
ft = await db.get(FlowTemplate, template_id)
if not ft:
raise HTTPException(404, "模板不存在")
ft.usage_count = (ft.usage_count or 0) + 1
user_ctx = request.state.user
definition_json = {"nodes": template["nodes"], "edges": template["edges"], "trigger": {}}
flow = FlowDefinition(
name=template["name"] + " (副本)",
description=template["description"],
definition_json=definition_json,
draft_definition_json=definition_json,
name=ft.name + " (副本)",
description=ft.description or "",
definition_json=ft.definition_json,
draft_definition_json=ft.definition_json,
creator_id=uuid.UUID(user_ctx["id"]),
)
db.add(flow)
await db.flush()
await db.commit()
return _build_flow_out(flow)

4
backend/modules/memory/__init__.py

@ -1,4 +1,4 @@
from .manager import MemoryManager, get_memory_manager
from .manager import MemoryManager, get_memory_manager, init_memory_manager
from .router import router
__all__ = ["MemoryManager", "get_memory_manager", "router"]
__all__ = ["MemoryManager", "get_memory_manager", "init_memory_manager", "router"]

820
backend/modules/memory/manager.py

@ -1,8 +1,13 @@
import json
import asyncio
import uuid
import logging
from datetime import datetime
from datetime import datetime, timezone
from typing import Callable
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
logger = logging.getLogger(__name__)
@ -17,21 +22,27 @@ def get_memory_manager() -> "MemoryManager":
return _memory_manager
async def init_memory_manager():
async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
global _memory_manager
redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
await redis.ping()
_memory_manager = MemoryManager(redis)
_memory_manager = MemoryManager(db_factory, redis)
class MemoryManager:
KEY_PREFIX = "mem"
DEFAULT_TTL = 604800
SESSION_INDEX_TTL = 2592000
MAX_HISTORY = 40
REDIS_CACHE_SIZE = 10
REDIS_CACHE_TTL = 300
SUMMARY_CACHE_KEY = "mem:summary"
MSG_CACHE_KEY = "mem:cache:msgs"
ATOM_EXTRACT_EVERY = 10
SCENE_EXTRACT_EVERY = 50
PERSONA_UPDATE_EVERY = 30
def __init__(self, redis: Redis):
def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
self.db_factory = db_factory
self.redis = redis
self._extract_tasks: dict[str, asyncio.Task] = {}
async def inject_memory(
self,
@ -40,12 +51,27 @@ class MemoryManager:
session_id: str,
context: dict,
):
messages = await self._get_recent_messages(user_id, flow_id, session_id)
summary = await self._get_summary(user_id, flow_id, session_id)
uid = uuid.UUID(user_id)
fid = uuid.UUID(flow_id)
sid = uuid.UUID(session_id)
cached = await self._redis_get_recent(uid, sid)
if cached:
recent_messages = cached
else:
recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY)
if recent_messages:
await self._redis_set_recent(uid, sid, recent_messages)
summary = await self._get_summary(uid, fid, sid)
atoms = await self._get_relevant_atoms(uid, fid)
persona = await self._get_persona(uid)
context["_memory_context"] = {
"recent_messages": list(reversed(messages)),
"recent_messages": recent_messages,
"summary": summary,
"atoms": atoms,
"persona": persona,
"session_id": session_id,
}
@ -58,115 +84,238 @@ class MemoryManager:
assistant_msg: str,
flow_name: str = "",
):
key = self._msg_key(user_id, flow_id, session_id)
ts = datetime.utcnow().isoformat()
uid = uuid.UUID(user_id)
fid = uuid.UUID(flow_id)
sid = uuid.UUID(session_id)
ts = datetime.now(timezone.utc)
try:
async with self.redis.pipeline() as pipe:
pipe.lpush(key,
json.dumps({"role": "assistant", "content": assistant_msg, "ts": ts}, ensure_ascii=False),
json.dumps({"role": "user", "content": user_msg, "ts": ts}, ensure_ascii=False),
async with self.db_factory() as db:
user_row = {
"id": uuid.uuid4(),
"user_id": uid,
"flow_id": fid,
"session_id": sid,
"role": "user",
"content": user_msg,
"created_at": ts,
}
asst_row = {
"id": uuid.uuid4(),
"user_id": uid,
"flow_id": fid,
"session_id": sid,
"role": "assistant",
"content": assistant_msg,
"created_at": ts,
}
await db.execute(
text("""
INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
"""),
user_row,
)
await db.execute(
text("""
INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
"""),
asst_row,
)
pipe.ltrim(key, 0, self.MAX_HISTORY - 1)
pipe.expire(key, self.DEFAULT_TTL)
pipe.hset(self._meta_key(user_id, flow_id, session_id), mapping={
session_row = {
"id": uuid.uuid4(),
"user_id": uid,
"flow_id": fid,
"session_id": sid,
"flow_name": flow_name,
"last_active_at": ts,
})
pipe.expire(self._meta_key(user_id, flow_id, session_id), self.DEFAULT_TTL)
"created_at": ts,
}
await db.execute(
text("""
INSERT INTO memory_sessions (id, user_id, flow_id, session_id, flow_name, message_count, last_active_at, created_at)
VALUES (:id, :user_id, :flow_id, :session_id, :flow_name, 2, :last_active_at, :created_at)
ON CONFLICT (user_id, flow_id, session_id) DO UPDATE SET
message_count = memory_sessions.message_count + 2,
last_active_at = EXCLUDED.last_active_at,
flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name)
"""),
session_row,
)
pipe.sadd(f"{self.KEY_PREFIX}:{user_id}:sessions", session_id)
pipe.expire(f"{self.KEY_PREFIX}:{user_id}:sessions", self.SESSION_INDEX_TTL)
await db.commit()
except Exception as e:
logger.warning(f"记录记忆失败(PG): {e}")
return
await pipe.execute()
try:
new_msgs = [
{"role": "user", "content": user_msg, "ts": ts.isoformat()},
{"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()},
]
await self._redis_append_recent(uid, sid, new_msgs)
except Exception as e:
logger.warning(f"记录记忆失败: {e}")
logger.debug(f"Redis缓存更新失败: {e}")
asyncio.create_task(self._maybe_summarize(user_id, flow_id, session_id))
asyncio.create_task(self._maybe_summarize(uid, fid, sid))
asyncio.create_task(self._maybe_extract_atoms(uid, fid, sid))
asyncio.create_task(self._maybe_extract_scenes(uid, fid))
asyncio.create_task(self._maybe_update_persona(uid, fid))
async def get_conversation_history(
self, user_id: str, flow_id: str, session_id: str, limit: int = 20
) -> list[dict]:
messages = await self._get_recent_messages(user_id, flow_id, session_id, limit)
return list(reversed(messages))
uid = uuid.UUID(user_id)
sid = uuid.UUID(session_id)
fid = uuid.UUID(flow_id) if flow_id else None
return await self._pg_get_recent(uid, fid, sid, limit)
async def delete_session(self, user_id: str, session_id: str):
uid = uuid.UUID(user_id)
sid = uuid.UUID(session_id)
try:
patterns = await self.redis.keys(f"{self.KEY_PREFIX}:{user_id}:*:{session_id}:*")
async with self.redis.pipeline() as pipe:
if patterns:
pipe.delete(*patterns)
pipe.srem(f"{self.KEY_PREFIX}:{user_id}:sessions", session_id)
await pipe.execute()
async with self.db_factory() as db:
await db.execute(
text("DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
{"uid": uid, "sid": sid},
)
await db.execute(
text("DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid"),
{"uid": uid, "sid": sid},
)
await db.commit()
except Exception as e:
logger.warning(f"清除记忆失败: {e}")
logger.warning(f"清除记忆失败(PG): {e}")
try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
await self.redis.delete(cache_key)
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
await self.redis.delete(summary_key)
except Exception:
pass
async def list_user_sessions(self, user_id: str) -> list[dict]:
uid = uuid.UUID(user_id)
try:
session_ids = await self.redis.smembers(f"{self.KEY_PREFIX}:{user_id}:sessions")
async with self.db_factory() as db:
result = await db.execute(
text("""
SELECT session_id, flow_id, flow_name, last_active_at
FROM memory_sessions
WHERE user_id = :uid
ORDER BY last_active_at DESC
LIMIT 100
"""),
{"uid": uid},
)
rows = result.fetchall()
return [
{
"session_id": str(row[0]),
"flow_id": str(row[1]),
"flow_name": row[2] or "",
"last_active_at": row[3].isoformat() if row[3] else "",
}
for row in rows
]
except Exception:
return []
sessions = []
for sid in session_ids:
try:
keys = await self.redis.keys(f"{self.KEY_PREFIX}:{user_id}:*:{sid}:meta")
for k in keys:
meta = await self.redis.hgetall(k)
parts = k.split(":")
flow_id = parts[2] if len(parts) > 2 else ""
sessions.append({
"session_id": sid,
"flow_id": flow_id,
"flow_name": meta.get("flow_name", ""),
"last_active_at": meta.get("last_active_at", ""),
})
except Exception:
continue
async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]:
try:
async with self.db_factory() as db:
if fid:
result = await db.execute(
text("""
SELECT role, content, created_at
FROM memory_messages
WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid
ORDER BY created_at DESC
LIMIT :limit
"""),
{"uid": uid, "fid": fid, "sid": sid, "limit": limit},
)
else:
result = await db.execute(
text("""
SELECT role, content, created_at
FROM memory_messages
WHERE user_id = :uid AND session_id = :sid
ORDER BY created_at DESC
LIMIT :limit
"""),
{"uid": uid, "sid": sid, "limit": limit},
)
rows = result.fetchall()
return [
{"role": row[0], "content": row[1], "ts": row[2].isoformat() if row[2] else ""}
for row in reversed(rows)
]
except Exception:
return []
return sorted(sessions, key=lambda s: s.get("last_active_at", ""), reverse=True)
async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
raw = await self.redis.get(cache_key)
if raw:
return json.loads(raw)
except Exception:
pass
return None
async def _get_recent_messages(
self, user_id: str, flow_id: str, session_id: str, limit: int = None
) -> list[dict]:
limit = limit or self.MAX_HISTORY
try:
key = self._msg_key(user_id, flow_id, session_id)
raw = await self.redis.lrange(key, 0, limit - 1)
result = []
for m in raw:
try:
result.append(json.loads(m))
except json.JSONDecodeError:
continue
return result
async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
top = messages[-self.REDIS_CACHE_SIZE:]
await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(top, ensure_ascii=False))
except Exception:
return []
pass
async def _get_summary(self, user_id: str, flow_id: str, session_id: str) -> str:
async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
try:
key = f"{self.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary"
val = await self.redis.get(key)
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
existing = await self._redis_get_recent(uid, sid) or []
combined = existing + new_msgs
recent = combined[-self.REDIS_CACHE_SIZE:]
await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(recent, ensure_ascii=False))
except Exception:
pass
async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
try:
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
val = await self.redis.get(summary_key)
return val or ""
except Exception:
return ""
async def _maybe_summarize(self, user_id: str, flow_id: str, session_id: str):
async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
try:
key = self._msg_key(user_id, flow_id, session_id)
count = await self.redis.llen(key)
if count < 30:
return
summary_key = f"{self.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary"
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
existing = await self.redis.get(summary_key)
if existing:
return
recent = await self._get_recent_messages(user_id, flow_id, session_id, 20)
count = 0
async with self.db_factory() as db:
result = await db.execute(
text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
{"uid": uid, "sid": sid},
)
row = result.fetchone()
count = row[0] if row else 0
if count < 30:
return
recent = await self._pg_get_recent(uid, fid, sid, 20)
dialogue = "\n".join(
f"{m['role']}: {m['content'][:500]}" for m in reversed(recent[:10])
f"{m['role']}: {m['content'][:500]}" for m in recent[-10:]
)
import httpx
@ -192,10 +341,511 @@ class MemoryManager:
except Exception:
pass
@staticmethod
def _msg_key(user_id: str, flow_id: str, session_id: str) -> str:
return f"mem:{user_id}:{flow_id}:{session_id}:messages"
async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
try:
async with self.db_factory() as db:
result = await db.execute(
text("""
SELECT id, atom_type, content, priority, metadata
FROM memory_atoms
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
ORDER BY priority DESC, updated_at DESC
LIMIT 20
"""),
{"uid": uid, "fid": fid},
)
rows = result.fetchall()
return [
{"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]}
for row in rows
]
except Exception:
return []
async def _get_persona(self, uid: uuid.UUID) -> dict:
try:
async with self.db_factory() as db:
result = await db.execute(
text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"),
{"uid": uid},
)
row = result.fetchone()
if row:
return {"content": row[0] or {}, "raw_text": row[1] or "", "version": row[2] or 1}
except Exception:
pass
return {}
async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
try:
task_key = f"extract_atoms:{uid}:{fid}"
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
return
count = 0
async with self.db_factory() as db:
result = await db.execute(
text("SELECT message_count FROM memory_sessions WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"),
{"uid": uid, "fid": fid, "sid": sid},
)
row = result.fetchone()
count = row[0] if row else 0
last_extract = await db.execute(
text("""
SELECT COUNT(*) FROM memory_atoms
WHERE user_id = :uid AND flow_id = :fid AND metadata->>'source_session_id' = :sid_str
"""),
{"uid": uid, "fid": fid, "sid_str": str(sid)},
)
last_count = last_extract.fetchone()[0]
if count < self.ATOM_EXTRACT_EVERY or count <= last_count * self.ATOM_EXTRACT_EVERY:
return
recent = await self._pg_get_recent(uid, fid, sid, 20)
dialogue = "\n".join(
f"{m['role']}: {m['content'][:400]}" for m in recent[-15:]
)
task = asyncio.create_task(self._do_extract_atoms(uid, fid, sid, dialogue))
self._extract_tasks[task_key] = task
except Exception:
pass
async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
try:
prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。
对话内容
{dialogue}
请以JSON数组格式返回每个原子包含以下字段
- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求)
- content: 原子的具体内容一句话描述
- priority: 0-100的整数表示重要性越高越核心
格式示例
[
{{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}},
{{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}}
]
只返回JSON数组不要其他内容如果没有可提取的信息返回空数组[]"""
import httpx
api_base = settings.LLM_API_BASE.rstrip("/")
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(
f"{api_base}/chat/completions",
json={
"model": settings.LLM_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 800,
"temperature": 0.3,
},
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
)
data = resp.json()
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]")
try:
atoms = json.loads(result_text)
if not isinstance(atoms, list):
return
except (json.JSONDecodeError, ValueError):
return
await self._dedup_and_store_atoms(uid, fid, sid, atoms)
except Exception as e:
logger.warning(f"L1原子提取失败: {e}")
async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
try:
async with self.db_factory() as db:
existing = await db.execute(
text("""
SELECT id, content FROM memory_atoms
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
"""),
{"uid": uid, "fid": fid},
)
existing_atoms = [{"id": row[0], "content": row[1]} for row in existing.fetchall()]
for atom in atoms:
content = atom.get("content", "").strip()
if not content:
continue
is_dup = False
for ex in existing_atoms:
if self._text_similarity(content, ex["content"]) > 0.75:
existing_atoms.remove(ex)
await db.execute(
text("""
UPDATE memory_atoms
SET priority = GREATEST(priority, :priority),
updated_at = NOW(),
metadata = metadata || :meta
WHERE id = :id
"""),
{
"id": ex["id"],
"priority": atom.get("priority", 50),
"meta": json.dumps({"last_source_session_id": str(sid)}),
},
)
is_dup = True
break
if not is_dup:
await db.execute(
text("""
INSERT INTO memory_atoms (user_id, flow_id, atom_type, content, priority, source_session_id, metadata, created_at, updated_at)
VALUES (:uid, :fid, :atype, :content, :priority, :sid, :meta, NOW(), NOW())
"""),
{
"uid": uid,
"fid": fid,
"atype": atom.get("atom_type", "episodic"),
"content": content,
"priority": atom.get("priority", 50),
"sid": sid,
"meta": json.dumps({"source_session_id": str(sid)}),
},
)
await db.commit()
except Exception as e:
logger.warning(f"L1原子存储失败: {e}")
@staticmethod
def _meta_key(user_id: str, flow_id: str, session_id: str) -> str:
return f"mem:{user_id}:{flow_id}:{session_id}:meta"
def _text_similarity(a: str, b: str) -> float:
a_words = set(a.lower().split())
b_words = set(b.lower().split())
if not a_words or not b_words:
return 0.0
intersection = a_words & b_words
union = a_words | b_words
return len(intersection) / len(union)
async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
try:
task_key = f"extract_scenes:{uid}:{fid}"
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
return
count = 0
async with self.db_factory() as db:
result = await db.execute(
text("SELECT COUNT(*) FROM memory_atoms WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
{"uid": uid, "fid": fid},
)
row = result.fetchone()
count = row[0] if row else 0
scene_count_result = await db.execute(
text("SELECT COUNT(*) FROM memory_scenes WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
{"uid": uid, "fid": fid},
)
scene_count = scene_count_result.fetchone()[0]
if count < self.SCENE_EXTRACT_EVERY:
return
if scene_count > 0:
atoms_result = await db.execute(
text("""
SELECT updated_at FROM memory_scenes
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
ORDER BY updated_at DESC LIMIT 1
"""),
{"uid": uid, "fid": fid},
)
latest_scene = atoms_result.fetchone()
if latest_scene:
from datetime import timezone, timedelta
ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone.utc)
if ago < timedelta(hours=12):
return
atoms_result = await db.execute(
text("""
SELECT atom_type, content, priority FROM memory_atoms
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
ORDER BY priority DESC LIMIT 30
"""),
{"uid": uid, "fid": fid},
)
atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in atoms_result.fetchall()]
if len(atoms) < 10:
return
task = asyncio.create_task(self._do_extract_scenes(uid, fid, atoms))
self._extract_tasks[task_key] = task
except Exception:
pass
async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
try:
atoms_text = "\n".join(
f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
)
prompt = f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。
记忆原子
{atoms_text}
请以JSON数组格式返回每个场景块包含
- scene_name: 场景名称简短标签
- summary: 场景摘要2-3句话总结该场景的关键信息
格式示例
[
{{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}}
]
只返回JSON数组不要其他内容"""
import httpx
api_base = settings.LLM_API_BASE.rstrip("/")
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(
f"{api_base}/chat/completions",
json={
"model": settings.LLM_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 500,
"temperature": 0.3,
},
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
)
data = resp.json()
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]")
try:
scenes = json.loads(result_text)
if not isinstance(scenes, list):
return
except (json.JSONDecodeError, ValueError):
return
async with self.db_factory() as db:
for scene in scenes:
await db.execute(
text("""
INSERT INTO memory_scenes (user_id, flow_id, scene_name, summary, heat, content, created_at, updated_at)
VALUES (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW())
"""),
{
"uid": uid,
"fid": fid,
"name": scene.get("scene_name", ""),
"summary": scene.get("summary", ""),
"content": json.dumps(scene, ensure_ascii=False),
},
)
await db.commit()
except Exception as e:
logger.warning(f"L2场景提取失败: {e}")
async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
try:
task_key = f"update_persona:{uid}"
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
return
msg_count = 0
async with self.db_factory() as db:
result = await db.execute(
text("SELECT message_count FROM memory_sessions WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"),
{"uid": uid},
)
row = result.fetchone()
msg_count = row[0] if row else 0
persona_result = await db.execute(
text("SELECT version, updated_at FROM memory_personas WHERE user_id = :uid"),
{"uid": uid},
)
persona_row = persona_result.fetchone()
if persona_row:
from datetime import timezone, timedelta
ago = datetime.now(timezone.utc) - persona_row[1].replace(tzinfo=timezone.utc)
if ago < timedelta(hours=6):
return
if msg_count < self.PERSONA_UPDATE_EVERY:
return
persona_atoms = []
async with self.db_factory() as db:
atoms_result = await db.execute(
text("""
SELECT content FROM memory_atoms
WHERE user_id = :uid AND atom_type = 'persona'
ORDER BY priority DESC, updated_at DESC LIMIT 30
"""),
{"uid": uid},
)
persona_atoms = [r[0] for r in atoms_result.fetchall()]
persona_text = "\n".join(persona_atoms[:15]) if persona_atoms else "暂无用户信息"
task = asyncio.create_task(self._do_update_persona(uid, persona_text, persona_row.version + 1 if persona_row else 1))
self._extract_tasks[task_key] = task
except Exception:
pass
async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
try:
prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。
用户信息
{persona_text}
请返回一个JSON对象包含以下字段
- name: 用户姓名未知则"未知"
- role: 角色/职位
- preferences: 偏好/习惯列表
- skills: 技能/专长列表
- personality: 性格描述
- raw_text: 综合画像的文本描述2-3句话
只返回JSON对象不要其他内容"""
import httpx
api_base = settings.LLM_API_BASE.rstrip("/")
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(
f"{api_base}/chat/completions",
json={
"model": settings.LLM_MODEL,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 500,
"temperature": 0.3,
},
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
)
data = resp.json()
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "{}")
try:
persona = json.loads(result_text)
except (json.JSONDecodeError, ValueError):
persona = {"raw_text": result_text}
async with self.db_factory() as db:
existing = await db.execute(
text("SELECT id FROM memory_personas WHERE user_id = :uid"),
{"uid": uid},
)
if existing.fetchone():
await db.execute(
text("""
UPDATE memory_personas
SET content = :content, raw_text = :raw_text, version = :version, updated_at = NOW()
WHERE user_id = :uid
"""),
{
"uid": uid,
"content": json.dumps(persona, ensure_ascii=False),
"raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)),
"version": version,
},
)
else:
await db.execute(
text("""
INSERT INTO memory_personas (user_id, content, raw_text, version, updated_at)
VALUES (:uid, :content, :raw_text, :version, NOW())
"""),
{
"uid": uid,
"content": json.dumps(persona, ensure_ascii=False),
"raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)),
"version": version,
},
)
await db.commit()
except Exception as e:
logger.warning(f"L3画像更新失败: {e}")
async def search_memory(
self,
uid: uuid.UUID,
query: str,
fid: uuid.UUID = None,
top_k: int = 10,
embedding: list[float] = None,
) -> list[dict]:
results = []
try:
async with self.db_factory() as db:
if embedding and len(embedding) > 0:
emb_str = "[" + ",".join(str(v) for v in embedding) + "]"
vector_results = await db.execute(
text(f"""
SELECT id, atom_type, content, priority,
(1.0 - (embedding <=> :emb::vector)) AS similarity
FROM memory_atoms
WHERE user_id = :uid
{"""AND (flow_id = :fid OR flow_id IS NULL)""" if fid else ""}
AND embedding IS NOT NULL
ORDER BY embedding <=> :emb::vector
LIMIT :limit
"""),
{"uid": uid, "fid": fid, "emb": emb_str, "limit": top_k * 2} if fid
else {"uid": uid, "emb": emb_str, "limit": top_k * 2},
)
for row in vector_results.fetchall():
results.append({
"id": str(row[0]), "type": row[1], "content": row[2],
"priority": row[3], "vector_score": float(row[4]),
})
text_results = await db.execute(
text(f"""
SELECT id, atom_type, content, priority,
ts_rank(content_tsv, plainto_tsquery('simple', :query)) AS text_score
FROM memory_atoms
WHERE user_id = :uid
{"""AND (flow_id = :fid OR flow_id IS NULL)""" if fid else ""}
AND content_tsv @@ plainto_tsquery('simple', :query)
ORDER BY text_score DESC
LIMIT :limit
"""),
{"uid": uid, "fid": fid, "query": query, "limit": top_k * 2} if fid
else {"uid": uid, "query": query, "limit": top_k * 2},
)
for row in text_results.fetchall():
results.append({
"id": str(row[0]), "type": row[1], "content": row[2],
"priority": row[3], "text_score": float(row[4]),
})
rrf_scores: dict[str, float] = {}
k = 60
for rank, r in enumerate(sorted(results, key=lambda x: x.get("vector_score", 0), reverse=True)):
if "vector_score" in r:
rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1)
for rank, r in enumerate(sorted(results, key=lambda x: x.get("text_score", 0), reverse=True)):
if "text_score" in r:
rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1)
seen = set()
merged = []
for rid, score in sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True):
for r in results:
if r["id"] == rid and rid not in seen:
seen.add(rid)
r["rrf_score"] = score
merged.append(r)
break
merged.sort(key=lambda x: x.get("rrf_score", 0), reverse=True)
return merged[:top_k]
except Exception as e:
logger.warning(f"混合检索失败: {e}")
return results[:top_k] if results else []

2
backend/modules/memory/router.py

@ -13,7 +13,7 @@ async def list_sessions(request: Request, user=Depends(get_current_user)):
@router.get("/sessions/{session_id}")
async def get_session(session_id: str, flow_id: str = "", request: Request, user=Depends(get_current_user)):
async def get_session(session_id: str, request: Request, flow_id: str = "", user=Depends(get_current_user)):
mm = get_memory_manager()
history = await mm.get_conversation_history(
user_id=str(user.id),

3
backend/modules/model_provider/__init__.py

@ -0,0 +1,3 @@
from .router import router
__all__ = ["router"]

181
backend/modules/model_provider/router.py

@ -0,0 +1,181 @@
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import ModelProvider, ModelInstance
from dependencies import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
@router.get("")
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
result = await db.execute(
select(ModelProvider).order_by(ModelProvider.created_at.desc())
)
return {
"code": 200,
"data": [
{
"id": str(p.id),
"name": p.name,
"provider_type": p.provider_type,
"base_url": p.base_url,
"is_active": p.is_active,
"created_at": p.created_at.isoformat() if p.created_at else "",
}
for p in result.scalars().all()
],
}
@router.post("")
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
existing = await db.execute(
select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", ""))
)
if existing.scalars().first():
raise HTTPException(400, "相同 base_url 的供应商已存在")
p = ModelProvider(
name=payload["name"],
provider_type=payload["provider_type"],
base_url=payload.get("base_url", ""),
api_key=payload.get("api_key", ""),
extra_config=payload.get("extra_config", {}),
)
db.add(p)
await db.commit()
return {"code": 200, "data": {"id": str(p.id)}}
@router.put("/{provider_id}")
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
p.name = payload.get("name", p.name)
p.base_url = payload.get("base_url", p.base_url)
p.api_key = payload.get("api_key", p.api_key)
p.provider_type = payload.get("provider_type", p.provider_type)
p.extra_config = payload.get("extra_config", p.extra_config)
p.is_active = payload.get("is_active", p.is_active)
await db.commit()
return {"code": 200, "data": {"id": str(p.id)}}
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
await db.delete(p)
await db.commit()
return {"code": 200, "message": "已删除"}
@router.get("/models/all")
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
result = await db.execute(
select(ModelInstance)
.where(ModelInstance.is_active == True)
.order_by(ModelInstance.model_type, ModelInstance.model_name)
)
return {
"code": 200,
"data": [
{
"id": str(m.id),
"model_name": m.model_name,
"model_type": m.model_type,
"display_name": m.display_name,
}
for m in result.scalars().all()
],
}
@router.get("/{provider_id}/models")
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
result = await db.execute(
select(ModelInstance)
.where(ModelInstance.provider_id == uuid.UUID(provider_id))
.order_by(ModelInstance.model_type, ModelInstance.model_name)
)
return {
"code": 200,
"data": [
{
"id": str(m.id),
"provider_id": str(m.provider_id),
"model_name": m.model_name,
"model_type": m.model_type,
"display_name": m.display_name,
"capabilities": m.capabilities,
"default_params": m.default_params,
"is_default": m.is_default,
"is_active": m.is_active,
}
for m in result.scalars().all()
],
}
@router.post("/{provider_id}/models")
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
existing = await db.execute(
select(ModelInstance).where(
ModelInstance.provider_id == uuid.UUID(provider_id),
ModelInstance.model_name == payload["model_name"],
)
)
if existing.scalars().first():
raise HTTPException(400, "相同名称的模型已存在")
m = ModelInstance(
provider_id=uuid.UUID(provider_id),
model_name=payload["model_name"],
model_type=payload["model_type"],
display_name=payload.get("display_name", payload["model_name"]),
capabilities=payload.get("capabilities", {}),
default_params=payload.get("default_params", {}),
is_default=payload.get("is_default", False),
)
db.add(m)
await db.commit()
return {"code": 200, "data": {"id": str(m.id)}}
@router.put("/{provider_id}/models/{model_id}")
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在")
m.model_name = payload.get("model_name", m.model_name)
m.model_type = payload.get("model_type", m.model_type)
m.display_name = payload.get("display_name", m.display_name)
m.capabilities = payload.get("capabilities", m.capabilities)
m.default_params = payload.get("default_params", m.default_params)
m.is_default = payload.get("is_default", m.is_default)
m.is_active = payload.get("is_active", m.is_active)
await db.commit()
return {"code": 200, "data": {"id": str(m.id)}}
@router.delete("/{provider_id}/models/{model_id}")
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在")
await db.delete(m)
await db.commit()
return {"code": 200, "message": "已删除"}

2
backend/schemas/__init__.py

@ -275,6 +275,7 @@ class FlowDefinitionCreate(BaseModel):
trigger: dict = {}
nodes: list[FlowNode]
edges: list[FlowEdge]
flow_mode: str = "chatflow"
class FlowDefinitionUpdate(BaseModel):
@ -295,6 +296,7 @@ class FlowDefinitionOut(BaseModel):
published_version_id: uuid.UUID | None = None
published_to_wecom: bool
published_to_web: bool = False
flow_mode: str = "chatflow"
created_at: datetime | None = None
updated_at: datetime | None = None

13
docker-compose.prod.yml

@ -3,6 +3,19 @@ services:
image: postgres:16-alpine
container_name: ent-postgres
restart: always
entrypoint: >
sh -c '
if [ ! -f /usr/local/share/postgresql/extension/vector.control ]; then
echo ">>> pgvector 未安装,正在编译安装(仅首次约 30 秒)..."
apk add --no-cache --virtual .pg_build git build-base postgresql16-dev clang llvm-dev > /dev/null 2>&1
wget -qO- https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz | tar xz -C /tmp
cd /tmp/pgvector-0.7.4 && make -j2 > /dev/null 2>&1 && make install > /dev/null 2>&1
cd / && rm -rf /tmp/pgvector-0.7.4
apk del .pg_build > /dev/null 2>&1
echo ">>> pgvector 安装完成"
fi
exec docker-entrypoint.sh postgres
'
environment:
POSTGRES_USER: enterprise
POSTGRES_PASSWORD: enterprise123

13
docker-compose.yml

@ -12,6 +12,19 @@ services:
image: postgres:16-alpine
container_name: ent-postgres
restart: always
entrypoint: >
sh -c '
if [ ! -f /usr/local/share/postgresql/extension/vector.control ]; then
echo ">>> pgvector 未安装,正在编译安装(仅首次约 30 秒)..."
apk add --no-cache --virtual .pg_build git build-base postgresql16-dev clang llvm-dev > /dev/null 2>&1
wget -qO- https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz | tar xz -C /tmp
cd /tmp/pgvector-0.7.4 && make -j2 > /dev/null 2>&1 && make install > /dev/null 2>&1
cd / && rm -rf /tmp/pgvector-0.7.4
apk del .pg_build > /dev/null 2>&1
echo ">>> pgvector 安装完成"
fi
exec docker-entrypoint.sh postgres
'
environment:
POSTGRES_USER: enterprise
POSTGRES_PASSWORD: enterprise123

15
frontend/src/api/index.ts

@ -89,6 +89,8 @@ export const flowApi = {
getMarket: () => api.get('/flow/market'),
getTemplates: () => api.get('/flow/templates'),
useTemplate: (id: string) => api.post(`/flow/templates/${id}/use`),
getVersions: (id: string) => api.get(`/flow/definitions/${id}/versions`),
rollbackVersion: (flowId: string, versionId: string) => api.post(`/flow/definitions/${flowId}/rollback/${versionId}`),
}
export const chatApi = {
@ -192,3 +194,16 @@ export const ragApi = {
deleteDocument: (source: string) => api.delete(`/rag/documents/${encodeURIComponent(source)}`),
getStats: () => api.get('/rag/stats'),
}
export const modelProviderApi = {
getProviders: () => api.get('/model-providers'),
getProvider: (id: string) => api.get(`/model-providers/${id}`),
createProvider: (data: any) => api.post('/model-providers', data),
updateProvider: (id: string, data: any) => api.put(`/model-providers/${id}`, data),
deleteProvider: (id: string) => api.delete(`/model-providers/${id}`),
getModels: (providerId: string) => api.get(`/model-providers/${providerId}/models`),
getAllModels: () => api.get('/model-providers/models/all'),
createModel: (providerId: string, data: any) => api.post(`/model-providers/${providerId}/models`, data),
updateModel: (providerId: string, modelId: string, data: any) => api.put(`/model-providers/${providerId}/models/${modelId}`, data),
deleteModel: (providerId: string, modelId: string) => api.delete(`/model-providers/${providerId}/models/${modelId}`),
}

2
frontend/src/components/layout/AdminLayout.vue

@ -52,6 +52,7 @@
<el-icon><Cpu /></el-icon>
<span>AI能力配置</span>
</template>
<el-menu-item index="/admin/model/providers">模型供应商管理</el-menu-item>
<el-menu-item index="/admin/rag/knowledge">知识库管理</el-menu-item>
<el-menu-item index="/admin/agent/list">智能体管理</el-menu-item>
<el-menu-item index="/admin/wecom/config">企微机器人配置</el-menu-item>
@ -133,6 +134,7 @@ const activeMenu = computed(() => {
if (path.startsWith('/admin/rag')) return path
if (path.startsWith('/admin/agent')) return path
if (path.startsWith('/admin/wecom')) return path
if (path.startsWith('/admin/model')) return path
if (path.startsWith('/admin/task')) return path
if (path.startsWith('/admin/monitor')) return path
if (path.startsWith('/admin/audit')) return path

6
frontend/src/router/index.ts

@ -167,6 +167,12 @@ const router = createRouter({
component: () => import('@/views/wecom/BotConfig.vue'),
meta: { title: '企微机器人配置', perms: ['admin:access'] },
},
{
path: 'model/providers',
name: 'AdminModelProviders',
component: () => import('@/views/model/ModelProviderManager.vue'),
meta: { title: '模型供应商管理', perms: ['admin:access'] },
},
{
path: 'agent/list',
name: 'AdminAgentList',

129
frontend/src/views/flow/FlowEditor.vue

@ -8,10 +8,15 @@
<div class="editor-toolbar">
<el-input v-model="flowName" placeholder="流名称" style="width: 200px" />
<el-input v-model="flowDesc" placeholder="描述" style="width: 300px; margin-left: 12px" />
<el-select v-model="flowMode" style="width: 140px; margin-left: 12px" :disabled="isEdit">
<el-option label="对话型 (Chatflow)" value="chatflow" />
<el-option label="工作流型 (Workflow)" value="workflow" />
</el-select>
<el-button type="primary" @click="saveFlow" :loading="saving">保存</el-button>
<el-button @click="testFlow">验证</el-button>
<el-button v-if="isEdit" type="success" @click="publishFlow">上架到企微</el-button>
<el-button v-if="isEdit" type="primary" @click="publishToWeb">上架到网页</el-button>
<el-button v-if="isEdit" @click="showVersionHistory">版本历史</el-button>
</div>
</el-card>
@ -85,15 +90,42 @@
</el-form>
</div>
</div>
<el-dialog v-model="versionDialogVisible" title="版本历史" width="700px">
<el-table :data="versions" v-loading="loadingVersions" max-height="400">
<el-table-column prop="version" label="版本" width="80" />
<el-table-column prop="changelog" label="变更说明" min-width="200">
<template #default="{ row }">
<span v-if="row.changelog">{{ row.changelog }}</span>
<span v-else style="color:#999">-</span>
</template>
</el-table-column>
<el-table-column label="时间" width="170">
<template #default="{ row }">
{{ formatTime(row.created_at) }}
</template>
</el-table-column>
<el-table-column label="操作" width="120">
<template #default="{ row }">
<el-button size="small" type="primary" link @click="rollbackVersion(row.id)">
回滚到此版本
</el-button>
</template>
</el-table-column>
</el-table>
<div v-if="!versions.length && !loadingVersions" style="text-align:center;padding:24px;color:#999">
暂无版本记录
</div>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { ref, computed, onMounted, onBeforeUnmount } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { ElMessage } from 'element-plus'
import { ElMessage, ElMessageBox } from 'element-plus'
import { MarkerType } from '@vue-flow/core'
import { Promotion, ChatDotRound, Tools, Connection, Bell, DataAnalysis, Search, RefreshRight, Document } from '@element-plus/icons-vue'
import { Promotion, ChatDotRound, Tools, Connection, Bell, DataAnalysis, Search, RefreshRight, Document, Link, Operation, Edit, Sunny } from '@element-plus/icons-vue'
import FlowCanvas from './FlowCanvas.vue'
import TriggerConfig from './node-configs/TriggerConfig.vue'
import LlmConfig from './node-configs/LlmConfig.vue'
@ -105,6 +137,13 @@ import RagConfig from './node-configs/RagConfig.vue'
import OutputConfig from './node-configs/OutputConfig.vue'
import LoopConfig from './node-configs/LoopConfig.vue'
import CodeConfig from './node-configs/CodeConfig.vue'
import HttpRequestConfig from './node-configs/HttpRequestConfig.vue'
import QuestionClassifierConfig from './node-configs/QuestionClassifierConfig.vue'
import VariableAssignerConfig from './node-configs/VariableAssignerConfig.vue'
import TemplateTransformConfig from './node-configs/TemplateTransformConfig.vue'
import IterationConfig from './node-configs/IterationConfig.vue'
import QuestionOptimiserConfig from './node-configs/QuestionOptimiserConfig.vue'
import MergeConfig from './node-configs/MergeConfig.vue'
const route = useRoute()
const router = useRouter()
@ -116,7 +155,11 @@ const initError = ref('')
const flowName = ref('新工作流')
const flowDesc = ref('')
const flowStatus = ref('')
const flowMode = ref('chatflow')
const saving = ref(false)
const versionDialogVisible = ref(false)
const versions = ref<any[]>([])
const loadingVersions = ref(false)
const selectedNodeId = ref('')
const selectedNodeData = ref<any>({})
const mcpServers = ref<any[]>([])
@ -140,6 +183,12 @@ const nodeTypes = [
{ type: 'merge', label: '变量聚合', icon: Connection, typeDesc: '并行汇聚' },
{ type: 'code', label: '代码执行', icon: Document, typeDesc: '代码执行' },
{ type: 'output', label: '输出节点', icon: Promotion, typeDesc: '结果输出' },
{ type: 'http_request', label: 'HTTP请求', icon: Link, typeDesc: '外部API调用' },
{ type: 'question_classifier', label: '问题分类', icon: Operation, typeDesc: '意图分类路由' },
{ type: 'variable_assigner', label: '变量赋值', icon: Edit, typeDesc: '变量操作' },
{ type: 'template_transform', label: '模板转换', icon: Document, typeDesc: '格式转换' },
{ type: 'iteration', label: '迭代处理', icon: RefreshRight, typeDesc: '数组遍历' },
{ type: 'question_optimiser', label: '问题优化', icon: Sunny, typeDesc: '查询优化' },
]
const configComponentMap: Record<string, any> = {
@ -154,7 +203,13 @@ const configComponentMap: Record<string, any> = {
output: OutputConfig,
loop: LoopConfig,
code: CodeConfig,
merge: NotifyConfig,
merge: MergeConfig,
http_request: HttpRequestConfig,
question_classifier: QuestionClassifierConfig,
variable_assigner: VariableAssignerConfig,
template_transform: TemplateTransformConfig,
iteration: IterationConfig,
question_optimiser: QuestionOptimiserConfig,
}
function getConfigComponent(type: string) {
@ -173,6 +228,12 @@ const colorMap: Record<string, string> = {
loop: '#13c2c2',
code: '#eb2f96',
output: '#722ed1',
http_request: '#2d8cf0',
question_classifier: '#ff9900',
variable_assigner: '#19be6b',
template_transform: '#9c27b0',
iteration: '#ff5722',
question_optimiser: '#e6a23c',
}
const selectedNode = computed(() => nodes.value.find((n: any) => n.id === selectedNodeId.value) || null)
@ -221,8 +282,14 @@ function getDefaultConfig(type: string) {
rag: { knowledge_base: '', top_k: 5, search_mode: 'hybrid', similarity_threshold: 0.7, result_sort: 'similarity', include_metadata: true },
loop: { loop_type: 'fixed', count: 3, iterator_variable: 'item', max_iterations: 10, items: [] },
merge: { merge_type: 'concat', expected_branches: 2 },
code: { language: 'python', code: '', timeout: 30, sandbox: true },
output: { format: 'text', output_template: '', indent: 2, encoding: 'utf-8', truncate: false, max_length: 2000 },
code: { language: 'python', code: '', timeout: 30, sandbox: true },
output: { format: 'text', output_template: '', indent: 2, encoding: 'utf-8', truncate: false, max_length: 2000 },
http_request: { method: 'GET', url: '', headers: '{}', body: '', auth_type: 'none', auth_config: {}, timeout: 30, retry_count: 0 },
question_classifier: { categories: [], instruction: '', model: '', temperature: 0.3 },
variable_assigner: { assignments: [] },
template_transform: { template: '', output_type: 'string' },
iteration: { input_array_source: '', max_iterations: 20 },
question_optimiser: { optimization_type: 'rewrite', model: '' },
}
return defaults[type] || {}
}
@ -277,6 +344,7 @@ async function loadFlow() {
flowName.value = flow.name || ''
flowDesc.value = flow.description || ''
flowStatus.value = flow.status || ''
flowMode.value = flow.flow_mode || 'chatflow'
const definition = flow.definition_json || {}
const loadedNodes: any[] = []
const loadedEdges: any[] = []
@ -322,7 +390,7 @@ async function saveFlow() {
const snapshot = canvasRef.value?.getSnapshot() || { nodes: [], edges: [] }
const serializedNodes = snapshot.nodes.map((n: any) => ({ id: n.id, type: n.data?.type || n.type, label: n.data?.label || n.id, config: n.data?.config || {}, position: n.position }))
const serializedEdges = snapshot.edges.map((e: any) => ({ id: e.id, source: e.source, target: e.target, sourceHandle: e.sourceHandle || 'source' }))
const payload = { name: flowName.value, description: flowDesc.value, nodes: serializedNodes, edges: serializedEdges, trigger: {} }
const payload = { name: flowName.value, description: flowDesc.value, nodes: serializedNodes, edges: serializedEdges, trigger: {}, flow_mode: flowMode.value }
if (isEdit.value) { await flowApi.updateFlow(flowId.value, payload); ElMessage.success('保存成功') }
else { const res: any = await flowApi.createFlow(payload); const data = res?.data || res || {}; if (data.id) { router.replace(`/admin/flow/editor/${data.id}`); ElMessage.success('创建成功') } }
} finally { saving.value = false }
@ -343,15 +411,62 @@ async function publishToWeb() {
try { const { flowApi } = await import('@/api'); await flowApi.publishToWeb(flowId.value); ElMessage.success('流已上架到网页'); await loadFlow() } catch {}
}
async function showVersionHistory() {
versionDialogVisible.value = true
await loadVersions()
}
async function loadVersions() {
loadingVersions.value = true
try {
const { flowApi } = await import('@/api')
const res: any = await flowApi.getVersions(flowId.value)
versions.value = Array.isArray(res?.data) ? res.data : (Array.isArray(res) ? res : [])
} catch { versions.value = [] }
finally { loadingVersions.value = false }
}
async function rollbackVersion(versionId: string) {
try {
await ElMessageBox.confirm('回滚将覆盖当前定义,确定继续?', '确认回滚', { type: 'warning' })
const { flowApi } = await import('@/api')
await flowApi.rollbackVersion(flowId.value, versionId)
ElMessage.success('已回滚到该版本')
versionDialogVisible.value = false
await loadFlow()
} catch {}
}
function formatTime(ts: string) {
if (!ts) return '-'
try { return new Date(ts).toLocaleString() } catch { return ts }
}
onMounted(async () => {
try {
if (isEdit.value) { await loadFlow() }
await Promise.allSettled([loadMcpServers(), loadAgents()])
startAutoSave()
} catch (e: any) {
console.error('FlowEditor init error:', e)
initError.value = '初始化失败: ' + (e?.message || '未知错误')
}
})
let autoSaveTimer: ReturnType<typeof setInterval> | null = null
function startAutoSave() {
if (autoSaveTimer) clearInterval(autoSaveTimer)
autoSaveTimer = setInterval(() => {
if (flowName.value && isEdit.value) {
saveFlow().catch(() => {})
}
}, 30000)
}
onBeforeUnmount(() => {
if (autoSaveTimer) clearInterval(autoSaveTimer)
})
</script>
<style scoped>

14
frontend/src/views/flow/FlowNode.vue

@ -13,7 +13,7 @@
<Handle type="target" :position="Position.Left" id="target" class="node-handle" />
</template>
<template v-if="data?.type === 'trigger' || data?.type === 'llm' || data?.type === 'tool' || data?.type === 'mcp' || data?.type === 'rag' || data?.type === 'code' || data?.type === 'merge'">
<template v-if="data?.type === 'trigger' || data?.type === 'llm' || data?.type === 'tool' || data?.type === 'mcp' || data?.type === 'rag' || data?.type === 'code' || data?.type === 'merge' || data?.type === 'http_request' || data?.type === 'question_classifier' || data?.type === 'variable_assigner' || data?.type === 'template_transform' || data?.type === 'iteration' || data?.type === 'question_optimiser'">
<Handle type="source" :position="Position.Right" id="source" class="node-handle" />
</template>
@ -84,6 +84,18 @@ const configSummary = computed(() => {
return cfg.language ? `${cfg.language}` : ''
case 'output':
return cfg.format ? `格式: ${cfg.format.toUpperCase()}` : ''
case 'http_request':
return cfg.method && cfg.url ? `${cfg.method} ${truncate(cfg.url, 20)}` : (cfg.method || 'HTTP')
case 'question_classifier':
return (cfg.categories || []).length ? `${(cfg.categories || []).length}个分类` : '未配置分类'
case 'variable_assigner':
return (cfg.assignments || []).length ? `${(cfg.assignments || []).length}条规则` : '未配置规则'
case 'template_transform':
return cfg.template ? `模板: ${truncate(cfg.template, 15)}` : '未配置模板'
case 'iteration':
return cfg.max_iterations ? `最多${cfg.max_iterations}` : '数组遍历'
case 'question_optimiser':
return cfg.optimization_type === 'expand' ? '扩展细化' : '改写优化'
default:
return ''
}

127
frontend/src/views/flow/node-configs/HttpRequestConfig.vue

@ -0,0 +1,127 @@
<template>
<div class="node-config">
<el-divider content-position="left">请求配置</el-divider>
<el-form-item label="请求方法">
<el-select :model-value="modelValue.method || 'GET'" @change="update('method', $event)">
<el-option label="GET" value="GET" />
<el-option label="POST" value="POST" />
<el-option label="PUT" value="PUT" />
<el-option label="DELETE" value="DELETE" />
<el-option label="PATCH" value="PATCH" />
</el-select>
</el-form-item>
<el-form-item label="请求地址">
<el-input :model-value="modelValue.url" @input="(e: any) => update('url', e)" placeholder="https://api.example.com/endpoint" />
<div class="field-hint">支持变量模板: {{node_id.output}}</div>
</el-form-item>
<el-divider content-position="left">请求头</el-divider>
<el-form-item label="Headers">
<el-input
:model-value="modelValue.headersText"
@input="(e: any) => { update('headersText', e); try { update('headers', JSON.parse(e)) } catch {} }"
type="textarea"
:rows="3"
placeholder='{"Content-Type": "application/json"}'
/>
<div class="field-hint">JSON格式每行一个键值对</div>
</el-form-item>
<el-divider content-position="left">请求体 (POST/PUT/PATCH)</el-divider>
<el-form-item label="Body类型">
<el-select :model-value="modelValue.body_type || 'json'" @change="update('body_type', $event)">
<el-option label="JSON" value="json" />
<el-option label="原始文本" value="raw" />
<el-option label="表单" value="form" />
</el-select>
</el-form-item>
<el-form-item label="请求体">
<el-input
:model-value="modelValue.body"
@input="(e: any) => update('body', e)"
type="textarea"
:rows="4"
placeholder='{"key": "value"}'
/>
</el-form-item>
<el-divider content-position="left">认证配置</el-divider>
<el-form-item label="认证方式">
<el-select :model-value="modelValue.auth_type || 'none'" @change="update('auth_type', $event)">
<el-option label="无认证" value="none" />
<el-option label="Bearer Token" value="bearer" />
<el-option label="API Key" value="api_key" />
<el-option label="Basic Auth" value="basic" />
</el-select>
</el-form-item>
<template v-if="modelValue.auth_type === 'bearer'">
<el-form-item label="Token">
<el-input :model-value="(modelValue.auth_config || {}).token" @input="updateAuth('token', $event)" placeholder="your-bearer-token" />
</el-form-item>
</template>
<template v-if="modelValue.auth_type === 'api_key'">
<el-form-item label="Key名称">
<el-input :model-value="(modelValue.auth_config || {}).key_name || 'X-API-Key'" @input="updateAuth('key_name', $event)" placeholder="X-API-Key" />
</el-form-item>
<el-form-item label="API Key">
<el-input :model-value="(modelValue.auth_config || {}).api_key" @input="updateAuth('api_key', $event)" placeholder="your-api-key" />
</el-form-item>
</template>
<template v-if="modelValue.auth_type === 'basic'">
<el-form-item label="用户名">
<el-input :model-value="(modelValue.auth_config || {}).username" @input="updateAuth('username', $event)" placeholder="username" />
</el-form-item>
<el-form-item label="密码">
<el-input :model-value="(modelValue.auth_config || {}).password" @input="updateAuth('password', $event)" type="password" placeholder="password" />
</el-form-item>
</template>
<el-divider content-position="left">高级选项</el-divider>
<el-form-item label="超时时间(秒)">
<el-input-number :model-value="modelValue.timeout || 30" @change="update('timeout', $event)" :min="1" :max="300" />
</el-form-item>
<el-form-item label="重试次数">
<el-input-number :model-value="modelValue.retry_count || 0" @change="update('retry_count', $event)" :min="0" :max="5" />
</el-form-item>
</div>
</template>
<script setup lang="ts">
const props = defineProps<{
modelValue: any
}>()
const emit = defineEmits(['change', 'update:modelValue'])
function update(key: string, val: any) {
emit('change')
emit('update:modelValue', { ...props.modelValue, [key]: val })
}
function updateAuth(key: string, val: any) {
emit('change')
emit('update:modelValue', {
...props.modelValue,
auth_config: { ...(props.modelValue.auth_config || {}), [key]: val }
})
}
</script>
<style scoped>
.field-hint {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
</style>

80
frontend/src/views/flow/node-configs/IterationConfig.vue

@ -0,0 +1,80 @@
<template>
<div class="node-config">
<el-divider content-position="left">迭代配置</el-divider>
<el-form-item label="数据来源">
<el-select :model-value="modelValue.input_array_source_type || 'auto'" @change="update('input_array_source_type', $event)">
<el-option label="自动解析上游输入" value="auto" />
<el-option label="指定节点输出" value="node_output" />
<el-option label="模板变量" value="template" />
</el-select>
<div class="field-hint">
自动模式尝试将上游输入解析为JSON数组失败则按行分割
</div>
</el-form-item>
<el-form-item v-if="modelValue.input_array_source_type === 'template'" label="数组变量">
<el-input :model-value="modelValue.input_array_source" @input="(e: any) => update('input_array_source', e)" placeholder="如:{{rag_node.items}}" />
<div class="field-hint">支持 {"{{node_id.output}}"} 变量</div>
</el-form-item>
<el-divider content-position="left">迭代限制</el-divider>
<el-form-item label="最大迭代次数">
<el-input-number :model-value="modelValue.max_iterations || 20" @change="update('max_iterations', $event)" :min="1" :max="100" />
</el-form-item>
<el-form-item label="输出格式">
<el-select :model-value="modelValue.output_format || 'json_array'" @change="update('output_format', $event)">
<el-option label="JSON数组" value="json_array" />
<el-option label="逐条文本" value="text_lines" />
</el-select>
</el-form-item>
<el-divider content-position="left">使用说明</el-divider>
<div class="help-box">
<p>迭代节点将对输入数组中的每个元素执行一次下游节点的处理</p>
<p class="help-tip">
将迭代节点连接到需要循环处理的节点每次迭代会将数组中一个元素作为输入传递
</p>
<p class="help-tip">
输出为包含每次迭代结果的对象数组: [{"index":0, "item":..., "result":...}]
</p>
</div>
</div>
</template>
<script setup lang="ts">
const props = defineProps<{
modelValue: any
}>()
const emit = defineEmits(['change', 'update:modelValue'])
function update(key: string, val: any) {
emit('change')
emit('update:modelValue', { ...props.modelValue, [key]: val })
}
</script>
<style scoped>
.field-hint {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
.help-box {
background: #f5f7fa;
border-radius: 6px;
padding: 12px;
font-size: 13px;
color: #606266;
line-height: 1.6;
}
.help-tip {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
</style>

24
frontend/src/views/flow/node-configs/MergeConfig.vue

@ -0,0 +1,24 @@
<template>
<div class="merge-config">
<el-form-item label="合并方式">
<el-select v-model="config.merge_type" @change="onChange">
<el-option label="拼接合并" value="concat" />
<el-option label="JSON合并" value="json" />
<el-option label="首个非空" value="first_non_empty" />
</el-select>
</el-form-item>
<el-form-item label="预期分支数">
<el-input-number v-model="config.expected_branches" :min="0" :max="20" @change="onChange" />
<div style="font-size: 12px; color: #999; margin-top: 4px;">设为0时自动等待所有分支完成</div>
</el-form-item>
</div>
</template>
<script setup lang="ts">
const config = defineModel<any>({ required: true })
const emit = defineEmits(['change'])
function onChange() {
emit('change')
}
</script>

120
frontend/src/views/flow/node-configs/QuestionClassifierConfig.vue

@ -0,0 +1,120 @@
<template>
<div class="node-config">
<el-divider content-position="left">分类配置</el-divider>
<el-form-item label="模型">
<el-input :model-value="modelValue.model" @input="(e: any) => update('model', e)" placeholder="默认使用系统LLM模型" />
<div class="field-hint">留空则使用系统默认模型</div>
</el-form-item>
<el-form-item label="温度">
<el-slider :model-value="modelValue.temperature ?? 0.3" @change="update('temperature', $event)" :min="0" :max="1" :step="0.1" show-input />
<div class="field-hint">较低温度使分类结果更稳定</div>
</el-form-item>
<el-form-item label="分类指令">
<el-input
:model-value="modelValue.instruction"
@input="(e: any) => update('instruction', e)"
type="textarea"
:rows="2"
placeholder="如:根据用户意图将其分类到对应的业务场景"
/>
</el-form-item>
<el-divider content-position="left">
<span>分类选项</span>
<el-button size="small" type="primary" link @click="addCategory" style="margin-left:8px">+ 添加</el-button>
</el-divider>
<div v-if="categories.length === 0" class="empty-hint">
暂无分类请点击"添加"按钮
</div>
<div v-for="(cat, index) in categories" :key="index" class="category-item">
<div class="category-header">
<span class="category-index">分类 {{ index + 1 }}</span>
<el-button size="small" type="danger" link @click="removeCategory(index)">
<el-icon><Delete /></el-icon>
</el-button>
</div>
<el-form-item label="名称">
<el-input :model-value="cat.name" @input="(e: any) => updateCategory(index, 'name', e)" placeholder="如:订单查询" />
</el-form-item>
<el-form-item label="描述">
<el-input :model-value="cat.description" @input="(e: any) => updateCategory(index, 'description', e)" placeholder="如:用户想查询订单状态、物流等" />
</el-form-item>
</div>
<el-divider content-position="left">输出配置</el-divider>
<el-form-item label="输出包含置信度">
<el-switch :model-value="modelValue.output_confidence ?? true" @change="update('output_confidence', $event)" />
</el-form-item>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { Delete } from '@element-plus/icons-vue'
const props = defineProps<{
modelValue: any
}>()
const emit = defineEmits(['change', 'update:modelValue'])
const categories = computed(() => props.modelValue.categories || [])
function update(key: string, val: any) {
emit('change')
emit('update:modelValue', { ...props.modelValue, [key]: val })
}
function updateCategory(index: number, key: string, val: any) {
const newList = [...categories.value]
newList[index] = { ...newList[index], [key]: val }
update('categories', newList)
}
function addCategory() {
const newList = [...categories.value, { name: '', description: '' }]
update('categories', newList)
}
function removeCategory(index: number) {
const newList = categories.value.filter((_: any, i: number) => i !== index)
update('categories', newList)
}
</script>
<style scoped>
.field-hint {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
.empty-hint {
text-align: center;
color: #909399;
font-size: 13px;
padding: 12px 0;
}
.category-item {
border: 1px solid #e4e7ed;
border-radius: 6px;
padding: 8px 12px;
margin-bottom: 12px;
}
.category-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 4px;
}
.category-index {
font-size: 13px;
font-weight: 500;
color: #606266;
}
</style>

81
frontend/src/views/flow/node-configs/QuestionOptimiserConfig.vue

@ -0,0 +1,81 @@
<template>
<div class="node-config">
<el-divider content-position="left">优化配置</el-divider>
<el-form-item label="优化类型">
<el-select :model-value="modelValue.optimization_type || 'rewrite'" @change="update('optimization_type', $event)">
<el-option label="改写优化" value="rewrite">
<span>改写优化 <span class="opt-desc">- 使其更清晰具体</span></span>
</el-option>
<el-option label="扩展细化" value="expand">
<span>扩展细化 <span class="opt-desc">- 添加细节和背景</span></span>
</el-option>
</el-select>
</el-form-item>
<el-form-item label="模型">
<el-input :model-value="modelValue.model" @input="(e: any) => update('model', e)" placeholder="默认使用系统LLM模型" />
<div class="field-hint">留空则使用系统默认模型</div>
</el-form-item>
<el-divider content-position="left">上下文增强</el-divider>
<el-form-item label="引入用户画像">
<el-switch :model-value="modelValue.include_persona ?? true" @change="update('include_persona', $event)" />
</el-form-item>
<el-form-item label="引入记忆原子">
<el-switch :model-value="modelValue.include_atoms ?? true" @change="update('include_atoms', $event)" />
</el-form-item>
<el-form-item label="引入近期对话">
<el-switch :model-value="modelValue.include_history ?? true" @change="update('include_history', $event)" />
</el-form-item>
<el-divider content-position="left">使用说明</el-divider>
<div class="help-box">
<p><strong>改写优化</strong>将模糊不完整的问题改写为清晰具体的版本自动补充缺失的上下文</p>
<p><strong>扩展细化</strong>将简短的问题扩展为更详细的描述增加背景和细节信息</p>
<p class="help-tip">支持从记忆系统自动获取用户画像记忆原子和近期对话作为优化上下文</p>
</div>
</div>
</template>
<script setup lang="ts">
const props = defineProps<{
modelValue: any
}>()
const emit = defineEmits(['change', 'update:modelValue'])
function update(key: string, val: any) {
emit('change')
emit('update:modelValue', { ...props.modelValue, [key]: val })
}
</script>
<style scoped>
.field-hint {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
.opt-desc {
font-size: 11px;
color: #909399;
}
.help-box {
background: #f5f7fa;
border-radius: 6px;
padding: 12px;
font-size: 13px;
color: #606266;
line-height: 1.6;
}
.help-tip {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
</style>

94
frontend/src/views/flow/node-configs/TemplateTransformConfig.vue

@ -0,0 +1,94 @@
<template>
<div class="node-config">
<el-divider content-position="left">模板配置</el-divider>
<el-form-item label="模板内容">
<el-input
:model-value="modelValue.template"
@input="(e: any) => update('template', e)"
type="textarea"
:rows="6"
placeholder="用户说:{{input}}&#10;解析结果:{{llm_node.output}}"
/>
<div class="field-hint">
支持变量语法: {"{{input}}"} (上游输入), {"{{node_id.output}}"} (节点输出),
{"{{node_id.field}}"} (节点字段)
</div>
</el-form-item>
<el-divider content-position="left">输出配置</el-divider>
<el-form-item label="输出类型">
<el-select :model-value="modelValue.output_type || 'string'" @change="update('output_type', $event)">
<el-option label="字符串" value="string" />
<el-option label="JSON" value="json" />
<el-option label="数组" value="array" />
</el-select>
</el-form-item>
<el-divider content-position="left">模板示例</el-divider>
<div class="example-box">
<p class="example-title">常用模板</p>
<div class="example-item" @click="setTemplate('请帮我处理:{{input}}')">
<code>请帮我处理{"{{input}}"}</code>
</div>
<div class="example-item" @click="setTemplate('根据{{rag_node.output}},回答:{{input}}')">
<code>根据{"{{rag_node.output}}"}回答{"{{input}}"}</code>
</div>
<div class="example-item" @click="setTemplate('{\"query\": \"{{input}}\", \"context\": \"{{trigger.data}}\"}')">
<code>{"{"}"query": "{"{{input}}"}", "context": "{"{{trigger.data}}"}"{"}"}</code>
</div>
</div>
</div>
</template>
<script setup lang="ts">
const props = defineProps<{
modelValue: any
}>()
const emit = defineEmits(['change', 'update:modelValue'])
function update(key: string, val: any) {
emit('change')
emit('update:modelValue', { ...props.modelValue, [key]: val })
}
function setTemplate(text: string) {
update('template', text)
}
</script>
<style scoped>
.field-hint {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
.example-box {
background: #f5f7fa;
border-radius: 6px;
padding: 12px;
}
.example-title {
font-size: 12px;
color: #909399;
margin-bottom: 8px;
}
.example-item {
cursor: pointer;
padding: 4px 8px;
border-radius: 4px;
margin-bottom: 4px;
font-size: 13px;
background: #fff;
}
.example-item:hover {
background: #ecf5ff;
}
.example-item code {
font-family: monospace;
font-size: 12px;
}
</style>

123
frontend/src/views/flow/node-configs/VariableAssignerConfig.vue

@ -0,0 +1,123 @@
<template>
<div class="node-config">
<el-divider content-position="left">
<span>变量赋值</span>
<el-button size="small" type="primary" link @click="addAssignment" style="margin-left:8px">+ 添加</el-button>
</el-divider>
<div v-if="assignments.length === 0" class="empty-hint">
暂无赋值规则请点击"添加"按钮
</div>
<div v-for="(item, index) in assignments" :key="index" class="assign-item">
<div class="assign-header">
<span class="assign-index">规则 {{ index + 1 }}</span>
<el-button size="small" type="danger" link @click="removeAssignment(index)">
<el-icon><Delete /></el-icon>
</el-button>
</div>
<el-form-item label="目标变量名">
<el-input :model-value="item.target_var" @input="(e: any) => updateAssignment(index, 'target_var', e)" placeholder="如:customer_name" />
</el-form-item>
<el-form-item label="来源类型">
<el-select :model-value="item.source_type || 'constant'" @change="updateAssignment(index, 'source_type', $event)">
<el-option label="常量" value="constant" />
<el-option label="上游节点输出" value="upstream_output" />
<el-option label="模板变量" value="template" />
<el-option label="表达式" value="expression" />
</el-select>
</el-form-item>
<el-form-item v-if="item.source_type === 'constant' || !item.source_type || item.source_type === 'constant'" label="值">
<el-input :model-value="item.source_value" @input="(e: any) => updateAssignment(index, 'source_value', e)" placeholder="常量值" />
</el-form-item>
<el-form-item v-if="item.source_type === 'template'" label="模板">
<el-input :model-value="item.source_value" @input="(e: any) => updateAssignment(index, 'source_value', e)" placeholder="如:{{llm_node.output}}" />
<div class="field-hint">支持 {{node_id.output}} 变量语法</div>
</el-form-item>
<el-form-item v-if="item.source_type === 'expression'" label="表达式">
<el-input :model-value="item.source_value" @input="(e: any) => updateAssignment(index, 'source_value', e)" placeholder="如:str(len(input_data))" />
<div class="field-hint">Python表达式可用变量: msg, context</div>
</el-form-item>
<el-form-item v-if="item.source_type === 'upstream_output'" label="说明">
<span class="field-hint">将自动获取上游节点的输出内容作为值</span>
</el-form-item>
</div>
<el-divider content-position="left">高级选项</el-divider>
<el-form-item label="覆盖已有变量">
<el-switch :model-value="modelValue.overwrite ?? true" @change="update('overwrite', $event)" />
</el-form-item>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { Delete } from '@element-plus/icons-vue'
const props = defineProps<{
modelValue: any
}>()
const emit = defineEmits(['change', 'update:modelValue'])
const assignments = computed(() => props.modelValue.assignments || [])
function update(key: string, val: any) {
emit('change')
emit('update:modelValue', { ...props.modelValue, [key]: val })
}
function updateAssignment(index: number, key: string, val: any) {
const newList = [...assignments.value]
newList[index] = { ...newList[index], [key]: val }
update('assignments', newList)
}
function addAssignment() {
const newList = [...assignments.value, { target_var: '', source_type: 'constant', source_value: '' }]
update('assignments', newList)
}
function removeAssignment(index: number) {
const newList = assignments.value.filter((_: any, i: number) => i !== index)
update('assignments', newList)
}
</script>
<style scoped>
.field-hint {
font-size: 12px;
color: #909399;
margin-top: 4px;
}
.empty-hint {
text-align: center;
color: #909399;
font-size: 13px;
padding: 12px 0;
}
.assign-item {
border: 1px solid #e4e7ed;
border-radius: 6px;
padding: 8px 12px;
margin-bottom: 12px;
}
.assign-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 4px;
}
.assign-index {
font-size: 13px;
font-weight: 500;
color: #606266;
}
</style>

338
frontend/src/views/model/ModelProviderManager.vue

@ -0,0 +1,338 @@
<template>
<div class="model-provider-manager">
<h2 style="margin-bottom: 16px;">模型供应商管理</h2>
<div style="margin-bottom: 16px; display: flex; gap: 12px;">
<el-button type="primary" @click="openProviderDialog()">添加供应商</el-button>
</div>
<div v-if="!selectedProvider">
<el-table :data="providers" v-loading="loading" style="width: 100%">
<el-table-column prop="name" label="供应商名称" width="180" />
<el-table-column prop="provider_type" label="类型" width="120">
<template #default="{ row }">
<el-tag size="small">{{ row.provider_type }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="base_url" label="API 端点" show-overflow-tooltip />
<el-table-column prop="is_active" label="状态" width="80">
<template #default="{ row }">
<el-tag :type="row.is_active ? 'success' : 'info'" size="small">
{{ row.is_active ? '启用' : '禁用' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作" width="280">
<template #default="{ row }">
<el-button size="small" @click="selectProvider(row)">管理模型</el-button>
<el-button size="small" @click="openProviderDialog(row)">编辑</el-button>
<el-button size="small" type="danger" @click="deleteProvider(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</div>
<div v-else>
<div style="margin-bottom: 12px; display: flex; gap: 12px; align-items: center;">
<el-button @click="selectedProvider = null"> 返回供应商列表</el-button>
<span style="font-size: 16px; font-weight: bold;">{{ selectedProvider.name }} - 模型管理</span>
<el-button type="primary" size="small" @click="openModelDialog()">添加模型</el-button>
</div>
<el-tabs v-model="modelTab">
<el-tab-pane label="LLM" name="llm" />
<el-tab-pane label="Embedding" name="embedding" />
<el-tab-pane label="Rerank" name="rerank" />
</el-tabs>
<el-table :data="filteredModels" v-loading="modelLoading" style="width: 100%">
<el-table-column prop="model_name" label="模型名" width="200" />
<el-table-column prop="display_name" label="显示名称" width="200" />
<el-table-column prop="model_type" label="类型" width="100">
<template #default="{ row }">
<el-tag size="small">{{ row.model_type }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="is_default" label="默认" width="80">
<template #default="{ row }">
<el-tag :type="row.is_default ? 'success' : 'info'" size="small">
{{ row.is_default ? '是' : '否' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作" width="180">
<template #default="{ row }">
<el-button size="small" @click="openModelDialog(row)">编辑</el-button>
<el-button size="small" type="danger" @click="deleteModel(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</div>
<el-dialog v-model="providerDialogVisible" :title="editingProviderId ? '编辑供应商' : '添加供应商'" width="550px">
<el-form :model="providerForm" label-width="100px">
<el-form-item label="供应商名称" required>
<el-input v-model="providerForm.name" placeholder="如 OpenAI / 智谱AI / 本地Ollama" />
</el-form-item>
<el-form-item label="类型" required>
<el-select v-model="providerForm.provider_type" style="width: 100%">
<el-option label="OpenAI兼容" value="openai_compatible" />
<el-option label="OpenAI" value="openai" />
<el-option label="智谱AI" value="zhipu" />
<el-option label="Ollama" value="ollama" />
<el-option label="DeepSeek" value="deepseek" />
</el-select>
</el-form-item>
<el-form-item label="API 端点">
<el-input v-model="providerForm.base_url" placeholder="如 https://api.openai.com/v1" />
</el-form-item>
<el-form-item label="API Key">
<el-input v-model="providerForm.api_key" type="password" placeholder="密钥(加密存储)" show-password />
</el-form-item>
<el-form-item label="启用">
<el-switch v-model="providerForm.is_active" />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="providerDialogVisible = false">取消</el-button>
<el-button type="primary" @click="saveProvider" :loading="saving">保存</el-button>
</template>
</el-dialog>
<el-dialog v-model="modelDialogVisible" :title="editingModelId ? '编辑模型' : '添加模型'" width="550px">
<el-form :model="modelForm" label-width="100px">
<el-form-item label="模型类型" required>
<el-select v-model="modelForm.model_type" style="width: 100%" :disabled="!!editingModelId">
<el-option label="LLM(大语言模型)" value="llm" />
<el-option label="Embedding(嵌入模型)" value="embedding" />
<el-option label="Rerank(重排序模型)" value="rerank" />
</el-select>
</el-form-item>
<el-form-item label="模型名" required>
<el-input v-model="modelForm.model_name" placeholder="如 gpt-4o / text-embedding-3-small" />
</el-form-item>
<el-form-item label="显示名称">
<el-input v-model="modelForm.display_name" placeholder="如 GPT-4o" />
</el-form-item>
<el-form-item label="设为默认">
<el-switch v-model="modelForm.is_default" />
<span style="margin-left: 8px; color: #999; font-size: 12px;">该类型下的默认模型</span>
</el-form-item>
<el-form-item label="能力配置" v-if="modelForm.model_type === 'llm'">
<div style="width: 100%;">
<el-checkbox v-model="llmVision" style="margin-right: 16px;">支持 Vision</el-checkbox>
<el-checkbox v-model="llmFunctionCalling">支持 Function Calling</el-checkbox>
</div>
</el-form-item>
<el-form-item label="默认参数">
<el-input v-model="defaultParamsJson" type="textarea" :rows="3" placeholder='{"temperature": 0.7, "max_tokens": 4096}' />
</el-form-item>
<el-form-item label="启用">
<el-switch v-model="modelForm.is_active" />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="modelDialogVisible = false">取消</el-button>
<el-button type="primary" @click="saveModel" :loading="modelSaving">保存</el-button>
</template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import api from '@/api'
const providers = ref<any[]>([])
const loading = ref(false)
const saving = ref(false)
const providerDialogVisible = ref(false)
const editingProviderId = ref('')
const providerForm = reactive({
name: '',
provider_type: 'openai_compatible',
base_url: '',
api_key: '',
is_active: true,
})
const selectedProvider = ref<any>(null)
const models = ref<any[]>([])
const modelLoading = ref(false)
const modelSaving = ref(false)
const modelDialogVisible = ref(false)
const editingModelId = ref('')
const modelTab = ref('llm')
const llmVision = ref(false)
const llmFunctionCalling = ref(false)
const defaultParamsJson = ref('')
const modelForm = reactive({
model_name: '',
model_type: 'llm',
display_name: '',
is_default: false,
is_active: true,
})
const filteredModels = computed(() => models.value.filter((m: any) => m.model_type === modelTab.value))
function resetProviderForm() {
providerForm.name = ''
providerForm.provider_type = 'openai_compatible'
providerForm.base_url = ''
providerForm.api_key = ''
providerForm.is_active = true
}
function resetModelForm() {
modelForm.model_name = ''
modelForm.model_type = 'llm'
modelForm.display_name = ''
modelForm.is_default = false
modelForm.is_active = true
llmVision.value = false
llmFunctionCalling.value = false
defaultParamsJson.value = ''
}
function openProviderDialog(row?: any) {
if (row) {
editingProviderId.value = row.id
providerForm.name = row.name
providerForm.provider_type = row.provider_type
providerForm.base_url = row.base_url
providerForm.api_key = row.api_key || ''
providerForm.is_active = row.is_active
} else {
editingProviderId.value = ''
resetProviderForm()
}
providerDialogVisible.value = true
}
async function loadProviders() {
loading.value = true
try {
const res: any = await api.get('/model-providers')
providers.value = res?.data || res || []
} finally {
loading.value = false
}
}
async function saveProvider() {
saving.value = true
try {
const data = { ...providerForm }
if (editingProviderId.value) {
await api.put(`/model-providers/${editingProviderId.value}`, data)
ElMessage.success('供应商已更新')
} else {
await api.post('/model-providers', data)
ElMessage.success('供应商已添加')
}
providerDialogVisible.value = false
await loadProviders()
} catch {
// interceptor handles error
} finally {
saving.value = false
}
}
async function deleteProvider(id: string) {
try {
await ElMessageBox.confirm('删除供应商将同时删除其下所有模型配置,确定继续?', '确认删除', { type: 'warning' })
await api.delete(`/model-providers/${id}`)
ElMessage.success('已删除')
if (selectedProvider.value?.id === id) selectedProvider.value = null
await loadProviders()
} catch { /* cancelled */ }
}
async function selectProvider(provider: any) {
selectedProvider.value = provider
await loadModels(provider.id)
}
async function loadModels(providerId: string) {
modelLoading.value = true
try {
const res: any = await api.get(`/model-providers/${providerId}/models`)
models.value = res?.data || res || []
} finally {
modelLoading.value = false
}
}
function openModelDialog(row?: any) {
if (row) {
editingModelId.value = row.id
modelForm.model_name = row.model_name
modelForm.model_type = row.model_type
modelForm.display_name = row.display_name || ''
modelForm.is_default = row.is_default
modelForm.is_active = row.is_active
const caps = row.capabilities || {}
llmVision.value = !!caps.vision
llmFunctionCalling.value = !!caps.function_calling
defaultParamsJson.value = row.default_params ? JSON.stringify(row.default_params, null, 2) : ''
} else {
editingModelId.value = ''
resetModelForm()
}
modelDialogVisible.value = true
}
async function saveModel() {
if (!selectedProvider.value) return
modelSaving.value = true
try {
let defaultParams = {}
try { defaultParams = JSON.parse(defaultParamsJson.value || '{}') } catch { /* keep empty */ }
const capabilities: any = {}
if (modelForm.model_type === 'llm') {
capabilities.vision = llmVision.value
capabilities.function_calling = llmFunctionCalling.value
}
const data = {
...modelForm,
capabilities,
default_params: defaultParams,
}
if (editingModelId.value) {
await api.put(`/model-providers/${selectedProvider.value.id}/models/${editingModelId.value}`, data)
ElMessage.success('模型已更新')
} else {
await api.post(`/model-providers/${selectedProvider.value.id}/models`, data)
ElMessage.success('模型已添加')
}
modelDialogVisible.value = false
await loadModels(selectedProvider.value.id)
} catch {
// interceptor handles error
} finally {
modelSaving.value = false
}
}
async function deleteModel(id: string) {
if (!selectedProvider.value) return
try {
await ElMessageBox.confirm('确定删除此模型配置?', '确认删除', { type: 'warning' })
await api.delete(`/model-providers/${selectedProvider.value.id}/models/${id}`)
ElMessage.success('已删除')
await loadModels(selectedProvider.value.id)
} catch { /* cancelled */ }
}
onMounted(() => {
loadProviders()
})
</script>

6
init-db/02-add-published-cols.sql

@ -0,0 +1,6 @@
-- 02-migration: 补充 flow_definitions 缺失列
-- 幂等执行,IF NOT EXISTS 重复运行不会报错
ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_to_wecom BOOLEAN DEFAULT FALSE;
ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_to_web BOOLEAN DEFAULT FALSE;

88
init-db/03-memory-tables.sql

@ -0,0 +1,88 @@
-- 03-memory-tables.sql
-- 记忆管理模块:PostgreSQL 主存储 + Redis 缓存层
-- 幂等执行,IF NOT EXISTS
-- pgvector 扩展(容器启动时已通过 docker-compose entrypoint 自动安装)
CREATE EXTENSION IF NOT EXISTS vector;
-- ============================================================
-- L0: 原始对话消息
-- ============================================================
CREATE TABLE IF NOT EXISTS memory_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
flow_id UUID NOT NULL REFERENCES flow_definitions(id) ON DELETE CASCADE,
session_id UUID NOT NULL,
role VARCHAR(20) NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_memory_messages_session ON memory_messages(user_id, flow_id, session_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_memory_messages_user_flow ON memory_messages(user_id, flow_id, created_at DESC);
-- ============================================================
-- L1: 结构化记忆原子
-- ============================================================
CREATE TABLE IF NOT EXISTS memory_atoms (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
flow_id UUID REFERENCES flow_definitions(id) ON DELETE SET NULL,
atom_type VARCHAR(20) NOT NULL,
content TEXT NOT NULL,
priority SMALLINT DEFAULT 50,
source_session_id UUID,
metadata JSONB DEFAULT '{}',
embedding vector(1536),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_memory_atoms_user ON memory_atoms(user_id, atom_type, priority DESC);
CREATE INDEX IF NOT EXISTS idx_memory_atoms_embedding ON memory_atoms USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
-- ============================================================
-- L2: 场景块
-- ============================================================
CREATE TABLE IF NOT EXISTS memory_scenes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
flow_id UUID REFERENCES flow_definitions(id) ON DELETE SET NULL,
scene_name VARCHAR(200) NOT NULL,
summary TEXT NOT NULL,
heat INTEGER DEFAULT 0,
content JSONB DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_memory_scenes_user ON memory_scenes(user_id, flow_id, heat DESC);
-- ============================================================
-- L3: 用户画像
-- ============================================================
CREATE TABLE IF NOT EXISTS memory_personas (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL UNIQUE REFERENCES users(id) ON DELETE CASCADE,
content JSONB NOT NULL DEFAULT '{}',
raw_text TEXT DEFAULT '',
version INTEGER DEFAULT 1,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- ============================================================
-- 会话元数据
-- ============================================================
CREATE TABLE IF NOT EXISTS memory_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
flow_id UUID NOT NULL REFERENCES flow_definitions(id) ON DELETE CASCADE,
session_id UUID NOT NULL,
flow_name VARCHAR(200) DEFAULT '',
message_count INTEGER DEFAULT 0,
last_active_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(user_id, flow_id, session_id)
);
CREATE INDEX IF NOT EXISTS idx_memory_sessions_user ON memory_sessions(user_id, last_active_at DESC);

30
init-db/04-model-provider.sql

@ -0,0 +1,30 @@
-- 04-model-provider.sql
-- OpenAI-API-Compatible 模型供应商管理
-- 支持 LLM / Embedding / Rerank 三种模型类型,每种独立配置
CREATE TABLE IF NOT EXISTS model_providers (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(100) NOT NULL,
provider_type VARCHAR(50) NOT NULL,
base_url VARCHAR(500),
api_key TEXT,
extra_config JSONB DEFAULT '{}',
is_active BOOLEAN DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS model_instances (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
provider_id UUID NOT NULL REFERENCES model_providers(id) ON DELETE CASCADE,
model_name VARCHAR(100) NOT NULL,
model_type VARCHAR(30) NOT NULL,
display_name VARCHAR(200),
capabilities JSONB DEFAULT '{}',
default_params JSONB DEFAULT '{}',
is_default BOOLEAN DEFAULT FALSE,
is_active BOOLEAN DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_model_instances_type ON model_instances(model_type, is_active);
CREATE INDEX IF NOT EXISTS idx_model_instances_provider ON model_instances(provider_id);

7
init-db/05-flow-mode.sql

@ -0,0 +1,7 @@
-- 05-flow-mode.sql
-- FlowDefinition 新增 flow_mode 字段,区分对话型和工作流型
ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS flow_mode VARCHAR(20) DEFAULT 'chatflow';
-- 已有数据默认设为 chatflow(对话型),保持向后兼容
UPDATE flow_definitions SET flow_mode = 'chatflow' WHERE flow_mode IS NULL;

7
init-db/06-agent-model-ids.sql

@ -0,0 +1,7 @@
-- 04-agent-model-ids.sql
-- AgentConfig 新增 model_instance_id 和 embedding_model_id 可选外键
-- nullable=True,旧记录自动为 NULL,完全向后兼容
ALTER TABLE agent_configs ADD COLUMN IF NOT EXISTS model_instance_id UUID REFERENCES model_instances(id);
ALTER TABLE agent_configs ADD COLUMN IF NOT EXISTS embedding_model_id UUID REFERENCES model_instances(id);

27
init-db/07-hybrid-search.sql

@ -0,0 +1,27 @@
-- 为 memory_atoms 表添加全文搜索索引,支持混合检索
CREATE EXTENSION IF NOT EXISTS vector;
ALTER TABLE memory_atoms ADD COLUMN IF NOT EXISTS content_tsv tsvector;
CREATE INDEX IF NOT EXISTS idx_memory_atoms_content_tsv
ON memory_atoms USING GIN (content_tsv);
CREATE OR REPLACE FUNCTION update_memory_atoms_tsv() RETURNS trigger AS $$
BEGIN
NEW.content_tsv := to_tsvector('simple', COALESCE(NEW.content, ''));
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_trigger WHERE tgname = 'trg_memory_atoms_tsv'
) THEN
CREATE TRIGGER trg_memory_atoms_tsv
BEFORE INSERT OR UPDATE ON memory_atoms
FOR EACH ROW EXECUTE FUNCTION update_memory_atoms_tsv();
END IF;
END $$;
UPDATE memory_atoms SET content_tsv = to_tsvector('simple', COALESCE(content, ''));

18
init-db/08-flow-templates.sql

@ -0,0 +1,18 @@
-- 流模板系统
CREATE TABLE IF NOT EXISTS flow_templates (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name VARCHAR(200) NOT NULL,
description TEXT,
category VARCHAR(50),
definition_json JSONB NOT NULL DEFAULT '{}',
icon VARCHAR(50),
sort_order INTEGER DEFAULT 0,
is_builtin BOOLEAN DEFAULT false,
usage_count INTEGER DEFAULT 0,
created_by UUID REFERENCES users(id),
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_flow_templates_category ON flow_templates(category);
CREATE INDEX IF NOT EXISTS idx_flow_templates_sort ON flow_templates(sort_order);

300
scripts/migrate_memory_redis_to_pg.py

@ -0,0 +1,300 @@
"""Redis → PostgreSQL 记忆数据迁移脚本
Redis 中现有记忆数据消息摘要迁移到 PostgreSQL memory_messages
使用方式:
python scripts/migrate_memory_redis_to_pg.py [--dry-run] [--user-id UUID]
选项:
--dry-run 仅扫描不实际写入
--user-id 只迁移指定用户的数据
--batch-size 每批写入条数 (默认100)
"""
import asyncio
import json
import uuid
import sys
import os
from datetime import datetime
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from config import settings
async def scan_redis_keys(redis: Redis, pattern: str) -> list[str]:
keys = []
cursor = 0
while True:
cursor, batch = await redis.scan(cursor, match=pattern, count=200)
keys.extend(batch)
if cursor == 0:
break
return keys
def parse_key_info(key: str) -> dict | None:
"""从 Redis key 中解析 user_id, flow_id, session_id"""
parts = key.split(":")
info = {}
uid_idx = None
fid_idx = None
sid_idx = None
for i, p in enumerate(parts):
try:
val = uuid.UUID(p)
if uid_idx is None:
uid_idx = i
info["user_id"] = val
elif sid_idx is None:
sid_idx = i
info["session_id"] = val
elif fid_idx is None:
fid_idx = i
info["flow_id"] = val
except ValueError:
if p == "messages" and uid_idx is not None and sid_idx is not None and fid_idx is None:
fid_idx = -1
continue
if not info.get("user_id"):
return None
return info
async def migrate_old_format_keys(redis: Redis, engine, dry_run: bool):
"""迁移旧格式: mem:{uid}:{fid}:{sid}:messages (List) 和 mem:{uid}:{fid}:{sid}:meta (Hash)"""
print(">>> 扫描旧格式 Redis 记忆键 ...")
keys = await scan_redis_keys(redis, "mem:*:messages")
keys += await scan_redis_keys(redis, "mem:*:meta")
print(f" 找到 {len(keys)} 个旧格式键")
migrated = 0
session_info = {}
for key in keys:
key_type = await redis.type(key)
info = parse_key_info(key)
if not info:
print(f" [跳过] 无法解析键: {key}")
continue
uid = info.get("user_id")
sid = info.get("session_id")
if key_type == "hash":
meta = await redis.hgetall(key)
session_info[str(sid)] = meta
elif key_type == "list":
messages = await redis.lrange(key, 0, -1)
for raw in messages:
try:
msg = json.loads(raw)
role = msg.get("role", "user")
content = msg.get("content", "")
ts_str = msg.get("ts", msg.get("timestamp", ""))
created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow()
if not dry_run:
async with AsyncSession(engine) as session:
await session.execute(
text("""
INSERT INTO memory_messages (user_id, flow_id, session_id, role, content, created_at)
VALUES (:uid, :fid, :sid, :role, :content, :ts)
ON CONFLICT DO NOTHING
"""),
{
"uid": uid,
"fid": info.get("flow_id"),
"sid": sid,
"role": role,
"content": content,
"ts": created_at,
},
)
await session.commit()
migrated += 1
except (json.JSONDecodeError, Exception) as e:
print(f" [错误] 解析消息失败: {raw[:80]}... -> {e}")
return migrated
async def migrate_new_format_cache(redis: Redis, engine, dry_run: bool):
"""迁移新格式缓存: mem:cache:msgs:{uid}:{sid} (String JSON array)"""
print(">>> 扫描新格式 Redis 消息缓存 ...")
keys = await scan_redis_keys(redis, "mem:cache:msgs:*")
print(f" 找到 {len(keys)} 个缓存键")
migrated = 0
for key in keys:
parts = key.split(":")
if len(parts) < 5:
continue
try:
uid = uuid.UUID(parts[3])
sid = uuid.UUID(parts[4])
except (ValueError, IndexError):
continue
raw = await redis.get(key)
if not raw:
continue
try:
messages = json.loads(raw)
except json.JSONDecodeError:
continue
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
ts_str = msg.get("ts", "")
created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow()
if not dry_run:
async with AsyncSession(engine) as session:
await session.execute(
text("""
INSERT INTO memory_messages (user_id, session_id, role, content, created_at)
VALUES (:uid, :sid, :role, :content, :ts)
ON CONFLICT DO NOTHING
"""),
{"uid": uid, "sid": sid, "role": role, "content": content, "ts": created_at},
)
await session.commit()
migrated += 1
return migrated
async def migrate_summaries(redis: Redis, engine, dry_run: bool):
"""迁移 Redis 摘要到 PostgreSQL memory_atoms"""
print(">>> 扫描 Redis 摘要缓存 ...")
keys = await scan_redis_keys(redis, "mem:summary:*")
print(f" 找到 {len(keys)} 个摘要键")
migrated = 0
for key in keys:
parts = key.split(":")
if len(parts) < 4:
continue
try:
uid = uuid.UUID(parts[2])
sid = uuid.UUID(parts[3])
except (ValueError, IndexError):
continue
summary = await redis.get(key)
if not summary or len(summary.strip()) < 10:
continue
if not dry_run:
async with AsyncSession(engine) as session:
result = await session.execute(
text("SELECT id FROM memory_atoms WHERE user_id = :uid AND source_session_id = :sid AND atom_type = 'summary'"),
{"uid": uid, "sid": sid},
)
existing = result.fetchone()
if existing:
await session.execute(
text("UPDATE memory_atoms SET content = :content, updated_at = NOW() WHERE id = :id"),
{"content": summary, "id": existing[0]},
)
else:
await session.execute(
text("""
INSERT INTO memory_atoms (user_id, atom_type, content, priority, source_session_id, created_at, updated_at)
VALUES (:uid, 'summary', :content, 60, :sid, NOW(), NOW())
"""),
{"uid": uid, "content": summary, "sid": sid},
)
await session.commit()
migrated += 1
return migrated
async def migrate_session_list(redis: Redis, engine, dry_run: bool):
"""迁移 mem:{uid}:sessions Set 中的会话列表"""
print(">>> 扫描会话列表 (mem:*:sessions)...")
keys = await scan_redis_keys(redis, "mem:*:sessions")
print(f" 找到 {len(keys)} 个会话列表键")
migrated = 0
for key in keys:
parts = key.split(":")
if len(parts) < 3:
continue
try:
uid = uuid.UUID(parts[1])
except ValueError:
continue
sessions = await redis.smembers(key)
for sid_str in sessions:
try:
sid = uuid.UUID(sid_str)
except ValueError:
continue
if not dry_run:
async with AsyncSession(engine) as session:
await session.execute(
text("""
INSERT INTO memory_sessions (user_id, session_id, last_active_at, created_at)
VALUES (:uid, :sid, NOW(), NOW())
ON CONFLICT (user_id, session_id) DO NOTHING
"""),
{"uid": uid, "sid": sid},
)
await session.commit()
migrated += 1
return migrated
async def main():
dry_run = "--dry-run" in sys.argv
batch_size = 100
for i, arg in enumerate(sys.argv):
if arg == "--batch-size" and i + 1 < len(sys.argv):
batch_size = int(sys.argv[i + 1])
print("=" * 60)
print("Redis → PostgreSQL 记忆数据迁移")
print(f"模式: {'试运行(不写入)' if dry_run else '正式迁移'}")
print(f"数据库: {settings.DATABASE_URL}")
print(f"Redis: {settings.REDIS_URL}")
print("=" * 60)
engine = create_async_engine(settings.DATABASE_URL, echo=False)
redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
try:
await redis.ping()
print("Redis 连接成功")
total = 0
total += await migrate_old_format_keys(redis, engine, dry_run)
total += await migrate_new_format_cache(redis, engine, dry_run)
total += await migrate_summaries(redis, engine, dry_run)
total += await migrate_session_list(redis, engine, dry_run)
print(f"\n迁移完成!共处理 {total} 条记录")
if dry_run:
print("提示: 使用不带 --dry-run 参数运行以实际写入数据")
finally:
await redis.aclose()
await engine.dispose()
if __name__ == "__main__":
asyncio.run(main())
Loading…
Cancel
Save