diff --git a/PLAN9.md b/PLAN9.md
new file mode 100644
index 0000000..56212ec
--- /dev/null
+++ b/PLAN9.md
@@ -0,0 +1,893 @@
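+# PLAN9: Comprehensive Upgrade Plan (Memory Enhancement · Flow Refactoring · Model Management · Node Expansion)
+
+> Date: 2026-05-15
+> Status: Planning; implementation pending confirmation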
+
+---
+
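+## 1. Memory Management Upgrade: Adopting Agent-Memory's Layered Architecture
+
+### 1.1 Current State
+
+**Our current approach (implemented in PLAN8)**:
+- Redis-only storage: a List for messages, a Hash for metadata, a String for the summary
+- Transparent middleware: memory is injected in `execute_flow` / `execute_flow_stream` in `flow_engine/router.py`
+- `LLMNodeAgent.reply()` reads the history summary and recent messages from `context["_memory_context"]`
+- 7-day TTL for automatic expiry; recording is asynchronous and does not block the response
+
+**The Agent-Memory project (ali-agentscope-src/Agent-Memory)**:
+- Implemented in TypeScript; runs as an OpenClaw plugin
+- Four-layer memory pyramid: L0 (raw conversation) → L1 (structured atoms) → L2 (scene blocks) → L3 (user persona)
+- Dual-write mechanism: JSONL persistence + SQLite vector retrieval
+- Hybrid search: BM25 keywords + embedding vectors + RRF fusion
+- Context Offload: when a long task overflows the context, it is compressed into Mermaid notation
+- Pipeline scheduling: L1 extraction is triggered every N conversation turns; L2/L3 are triggered with a delay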
+
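+### 1.2 Comparison
+
+| Dimension | Our approach | Agent-Memory | Assessment |
+|------|---------|-------------|------|
+| **Storage** | Redis (in-memory) | SQLite + JSONL (on disk) | Ours is fast but unsuited to very large histories; theirs is slower but can persist large volumes |
+| **Memory layering** | None (flat message list) | L0/L1/L2/L3 four-layer pyramid | **Biggest gap**: we do no structured extraction; everything is raw messages |
+| **Retrieval** | Most recent N messages + summary | BM25 + vector + RRF hybrid search | We can only fetch recent messages by time, not by semantics |
+| **Deduplication** | None | batchDedup after L1 extraction (store/update/merge/skip) | We store the same information repeatedly |
+| **User persona** | Not implemented | L3 persona, generated automatically and updated incrementally | They have a complete persona system |
+| **Context overflow** | Not handled | Context Offload + Mermaid compression | Our long conversations are simply truncated, losing information |
+| **Integration** | Native Python integration | TypeScript plugin; requires an OpenClaw/Hermes host | **Cannot be deployed directly**: language and architecture are incompatible |
+| **Performance** | Milliseconds (Redis) | Seconds (SQLite + LLM extraction) | Ours is fast but shallow; theirs is slow but deep |
+
+### 1.3 Conclusion
+
+**It cannot be deployed and integrated directly**: Agent-Memory is a TypeScript project that depends on an OpenClaw/Hermes host environment, which is incompatible with our Python/FastAPI architecture.
+
+**We should adopt its core design ideas**, but the storage foundation needs to be re-selected; see the analysis in Section 1.4.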
+
+### 1.4 Choosing the Memory Storage Foundation: Redis vs MongoDB vs PostgreSQL
+
+#### 1.4.1 Data Structures and Access Patterns of the Current Redis Approach
+
+```
+Current Redis key layout:
+  mem:{uid}:{fid}:{sid}:messages → List (written with LPUSH; most recent N read with LRANGE)
+  mem:{uid}:{fid}:{sid}:meta → Hash (written with HSET; read with HGETALL)
+  mem:{uid}:{fid}:{sid}:summary → String (written with SETEX; read with GET)
+  mem:{uid}:sessions → Set (written with SADD; read with SMEMBERS)
+
+Access patterns:
+  Write:  record_exchange() → pipelined LPUSH + HSET + SADD (once per exchange)
+  Read:   inject_memory() → LRANGE + GET (once per exchange, latency < 1ms)
+  Delete: delete_session() → KEYS + DEL (infrequent; user-initiated clearing)
+  List:   list_user_sessions() → SMEMBERS + KEYS + HGETALL (infrequent; admin page)
+```
+
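+#### 1.4.2 Four-Dimension Comparison of the Three Databases
+
+**Dimension 1: Memory footprint and performance** (approximate figures)
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Write latency** | < 1ms (in-memory) | 2-5ms (journal + WiredTiger cache) | 3-8ms (WAL + fsync) |
+| **Read latency** | < 1ms | 1-3ms (WiredTiger cache hit) | 1-5ms (shared_buffers hit) |
+| **inject_memory latency** | ~1ms | ~5ms | ~8ms |
+| **Memory use (10,000 sessions × 40 messages)** | ~800MB (all in memory, uncompressed) | ~200MB (disk + cached hot data) | ~150MB (disk + shared_buffers) |
+| **Memory use (1,000,000 sessions)** | ~80GB (impractical on one node; needs sharding) | ~20GB (cache; cold data evicted to disk) | ~15GB (mostly on disk; hot data cached) |
+| **Performance at scale** | Bounded by memory; exceeding it causes OOM (or eviction/write errors when `maxmemory` is set) | Hot/cold separation through the cache; can handle TB-scale data | Disk-based; scales to TB-level data and beyond |
+| **Vector search** | Requires Redis Stack (RediSearch); ~50ms over 100k vectors | Atlas Vector Search (MongoDB Atlas); ~30ms over 100k vectors | pgvector extension; ~40ms over 100k vectors |
+| **Full-text search** | Requires Redis Stack (RediSearch), a non-default deployment | Built-in text indexes, mature | tsvector + GIN index, mature |
+
+**Key conclusions**:
+- Redis performs best at small scale (< 100k sessions), but its memory cost grows linearly
+- MongoDB and PostgreSQL are better at large scale, because cold data lives on disk rather than in memory
+- Vector search performance is similar across the three, but Redis needs the Stack modules installed separately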
+
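+**Dimension 2: Data model fit**
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Message list (L0)** | List (naturally ordered, but no structured queries) | Embedded document array or standalone documents (flexible) | Relational table + time index (standard) |
+| **Structured atoms (L1)** | Hash + Sorted Set (managed by hand) | Documents (native JSON, flexible schema) | JSONB column (flexible and indexable) |
+| **Scene blocks (L2)** | Hash (manual serialization) | Documents (nested structures expressed naturally) | JSONB or a separate table |
+| **User persona (L3)** | String (plain text, not queryable) | Document (structured, queryable by field) | JSONB or a separate table (queryable by field) |
+| **Vector storage** | Requires Redis Stack (non-default deployment) | Atlas Vector Search (MongoDB Atlas) | pgvector extension (mature) |
+| **Migration complexity** | Baseline (current approach) | Medium (new MongoDB connection + collection design) | **Low** (a PostgreSQL connection pool already exists; reuse SQLAlchemy) |
+| **Schema evolution** | Schemaless (flexible but unconstrained) | Schemaless (flexible; validation can be added) | Schema-based (needs migration SQL, but type-safe) |
+
+**Key conclusions**:
+- MongoDB is the most natural fit for document-shaped memory data (native JSON storage, no ORM mapping)
+- PostgreSQL has the lowest migration cost (the project already uses SQLAlchemy + asyncpg and can reuse the existing connection pool)
+- Redis is the most natural fit for the L0 message list (LPUSH/LRANGE), but supports the structured L1/L2/L3 data poorly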
+
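+**Dimension 3: Scalability and maintenance cost**
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Horizontal scaling** | Redis Cluster (complex sharding; cross-slot operations are restricted) | Native sharding (automatic routing by shard key) | Read/write splitting + partitioned tables (or the Citus extension) |
+| **Operational complexity** | Low (a single instance is simple; Cluster is complex) | Medium (replica sets + sharding need monitoring) | **Lowest** (the project already operates PostgreSQL) |
+| **New infrastructure** | None (already deployed) | **A new MongoDB service is required** | **None** (reuse the existing PostgreSQL) |
+| **Backup and restore** | RDB/AOF (a persistence policy must be configured) | mongodump/oplog | pg_dump/WAL archiving (**already in place**) |
+| **Monitoring tools** | redis-cli INFO | MongoDB Compass/Cloud Manager | pg_stat_statements (**already in place**) |
+| **Long-term maintenance cost** | High (memory keeps growing and must be scaled up) | Medium | **Low** (disk is cheap to expand; operations are mature) |
+
+**Key conclusions**:
+- **PostgreSQL has the lowest operational cost**: the project already has a PostgreSQL instance, connection pool, and backup strategy, so no new infrastructure is needed
+- MongoDB would require deploying and maintaining an additional database service
+- Redis has the highest long-term cost: memory is 10-50× the price of disk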
+
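+**Dimension 4: Transaction and consistency requirements**
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Transaction support** | MULTI/EXEC (atomic across keys, but no rollback; a plain pipeline is not atomic) | Multi-document transactions since 4.0 (on replica sets) | Full ACID transactions |
+| **Transaction needs of memory data** | Low (a partially failed write is tolerable) | Low | Low |
+| **Transaction needs of model configuration** | Not applicable | Not applicable | High (must relate to existing business tables) |
+| **Durability / consistency** | With `appendfsync everysec`, a crash can lose about 1s of writes | Strong (replica set with w:majority) | Strong (WAL flushed synchronously on commit) |
+| **Risk of memory loss** | Possible on a crash while AOF is fsynced asynchronously | Majority-acknowledged writes survive a failover | Committed writes are protected by the WAL |
+
+**Key conclusions**:
+- Memory data has low transaction requirements (losing one message is acceptable); all three databases are sufficient
+- PostgreSQL's strong consistency is nonetheless an advantage for business data such as model configuration
+- With `appendfsync everysec`, Redis AOF has a data-loss window of about 1 second, which is a risk for the "never forget" requirement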
+
+#### 1.4.3 Overall Assessment and Selection Decision
+
+| Criterion | Redis | MongoDB | PostgreSQL | Weight |
+|---------|-------|---------|-----------|------|
+| Performance (small scale) | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ | 20% |
+| Performance (large scale) | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 25% |
+| Data model fit | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 20% |
+| Migration cost | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | 15% |
+| Operational cost | ⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | 10% |
+| Transactions/consistency | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | 10% |
+| **Weighted total** | **3.15** | **3.50** | **4.00** | 100% |
+
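+#### 1.4.4 Selection Conclusion: A Hybrid PostgreSQL + Redis Approach
+
+**Use PostgreSQL as the primary memory store and keep Redis as the hot-data cache layer**.
+
+Reasons:
+1. **No new infrastructure**: the project already runs PostgreSQL (asyncpg + SQLAlchemy), so no new service has to be deployed
+2. **Lowest migration cost**: reuse the existing connection pool, ORM, and migration framework; only new tables are added
+3. **pgvector support**: vector search needs only the pgvector extension; once it is installed on the server, `CREATE EXTENSION vector` enables it
+4. **JSONB is flexible and indexable**: the structured L1/L2/L3 data is stored as JSONB, with GIN indexes for field-level queries
+5. **Strong durability, so memories are not lost**: the WAL is flushed synchronously on commit, without the ~1-second loss window of Redis AOF
+6. **Cost advantage at scale**: disk storage is 10-50× cheaper than memory; 1 million sessions need only about 15GB, mostly on disk
+
+**Roles Redis keeps**:
+- Hot-data cache: the most recent 10 messages are cached in Redis (inject_memory reads the cache directly, latency < 1ms)
+- Rate limiting: the existing rate-limit feature of cache_manager
+- Session state: temporary state of currently active sessions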
+
+#### 1.4.5 PostgreSQL Memory Table Design
+
+```sql
+-- L0: raw conversation messages
+CREATE TABLE IF NOT EXISTS memory_messages (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID NOT NULL REFERENCES flow_definitions(id),
+    session_id UUID NOT NULL,
+    role VARCHAR(20) NOT NULL, -- 'user' / 'assistant'
+    content TEXT NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_memory_messages_session ON memory_messages(user_id, flow_id, session_id, created_at DESC);
+CREATE INDEX IF NOT EXISTS idx_memory_messages_user_flow ON memory_messages(user_id, flow_id, created_at DESC);
+
+-- L1: structured memory atoms
+CREATE TABLE IF NOT EXISTS memory_atoms (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID REFERENCES flow_definitions(id), -- NULL means global
+    atom_type VARCHAR(20) NOT NULL, -- 'persona' / 'episodic' / 'instruction'
+    content TEXT NOT NULL,
+    priority SMALLINT DEFAULT 50, -- 0-100; higher means more central
+    source_session_id UUID, -- originating session
+    metadata JSONB DEFAULT '{}', -- extension metadata
+    embedding vector(1536), -- embedding vector (requires the pgvector extension)
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_memory_atoms_user ON memory_atoms(user_id, atom_type, priority DESC);
+CREATE INDEX IF NOT EXISTS idx_memory_atoms_embedding ON memory_atoms USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
+
+-- L2: scene blocks
+CREATE TABLE IF NOT EXISTS memory_scenes (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID REFERENCES flow_definitions(id),
+    scene_name VARCHAR(200) NOT NULL,
+    summary TEXT NOT NULL,
+    heat INTEGER DEFAULT 0, -- heat (access frequency)
+    content JSONB DEFAULT '{}', -- full scene content
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_memory_scenes_user ON memory_scenes(user_id, flow_id, heat DESC);
+
+-- L3: user persona
+CREATE TABLE IF NOT EXISTS memory_personas (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL UNIQUE REFERENCES users(id), -- one row per user
+    content JSONB NOT NULL DEFAULT '{}', -- structured persona
+    raw_text TEXT DEFAULT '', -- raw persona text (injected into the LLM)
+    version INTEGER DEFAULT 1,
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+-- Session metadata
+CREATE TABLE IF NOT EXISTS memory_sessions (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID NOT NULL REFERENCES flow_definitions(id),
+    session_id UUID NOT NULL,
+    flow_name VARCHAR(200) DEFAULT '',
+    message_count INTEGER DEFAULT 0,
+    last_active_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    UNIQUE(user_id, flow_id, session_id)
+);
+
+CREATE INDEX IF NOT EXISTS idx_memory_sessions_user ON memory_sessions(user_id, last_active_at DESC);
+```
+
+#### 1.4.6 MemoryManager Refactoring Plan
+
+```python
+import asyncio
+from datetime import datetime, timedelta
+
+class MemoryManager:
+    def __init__(self, db_session_factory, redis: Redis):
+        self.db_session_factory = db_session_factory  # AsyncSessionLocal
+        self.redis = redis  # hot-data cache
+
+    async def inject_memory(self, user_id, flow_id, session_id, context):
+        # 1. Check the Redis cache first (most recent 10 messages)
+        cached = await self._get_cached_messages(user_id, flow_id, session_id)
+        if cached:
+            recent_messages = cached
+        else:
+            # 2. Cache miss: query PostgreSQL
+            recent_messages = await self._query_recent_messages(user_id, flow_id, session_id)
+            # 3. Backfill the Redis cache
+            await self._cache_messages(user_id, flow_id, session_id, recent_messages)
+        # 4. Query L1 atoms (PostgreSQL)
+        atoms = await self._query_relevant_atoms(user_id, flow_id, context.get("input", ""))
+        # 5. Query the L3 persona (PostgreSQL)
+        persona = await self._query_persona(user_id)
+
+        context["_memory_context"] = {
+            "recent_messages": recent_messages,
+            "atoms": atoms,
+            "persona": persona,
+            "session_id": session_id,
+        }
+
+    async def record_exchange(self, user_id, flow_id, session_id, user_msg, assistant_msg, flow_name=""):
+        ts = datetime.utcnow()
+        reply_ts = ts + timedelta(microseconds=1)  # keeps the reply ordered after the user message
+        # 1. Write to PostgreSQL (primary store, guarantees persistence)
+        async with self.db_session_factory() as session:
+            session.add(MemoryMessage(
+                user_id=user_id, flow_id=flow_id, session_id=session_id,
+                role="user", content=user_msg, created_at=ts,
+            ))
+            session.add(MemoryMessage(
+                user_id=user_id, flow_id=flow_id, session_id=session_id,
+                role="assistant", content=assistant_msg, created_at=reply_ts,
+            ))
+            await session.commit()
+
+        # 2. Update the Redis cache (hot data)
+        await self._cache_append_message(user_id, flow_id, session_id, [
+            {"role": "user", "content": user_msg, "ts": ts.isoformat()},
+            {"role": "assistant", "content": assistant_msg, "ts": reply_ts.isoformat()},
+        ])
+
+        # 3. Trigger L1 extraction asynchronously
+        asyncio.create_task(self._maybe_extract_atoms(user_id, flow_id, session_id))
+```
+
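+#### 1.4.7 Migration Steps
+
+```
+Step 1: Create the PostgreSQL memory tables (init-db/03-memory-tables.sql)
+Step 2: Enable the pgvector extension (CREATE EXTENSION IF NOT EXISTS vector in init-db/03-memory-tables.sql)
+Step 3: Refactor MemoryManager: PostgreSQL as the primary store, Redis as the cache
+Step 4: Data migration script: Redis → PostgreSQL (one-off; reads existing memory data from Redis and writes it to PG)
+Step 5: Verify the full path: inject_memory / record_exchange / delete_session
+Step 6: Once verified, give the memory keys in Redis a short TTL and let them expire naturally
+```
+
+#### 1.4.8 Migration SQL File
+
+See `init-db/03-memory-tables.sql` (created during implementation), which contains:
+- `CREATE EXTENSION IF NOT EXISTS vector`
+- CREATE TABLE + INDEX statements for the 5 tables above
+- Idempotent execution (IF NOT EXISTS)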
+
+---
+
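+## 2. Comprehensive Flow Management Upgrade (Benchmarked Against Dify)
+
+### 2.1 Current Problems
+
+1. **Too few node types**: we currently have 11 node types (trigger/llm/tool/mcp/notify/condition/rag/loop/merge/code/output); Dify has 15+
+2. **Memory is not a default behavior**: PLAN8 implemented the transparent middleware, but we lack Dify's distinction between conversational and workflow apps
+3. **Primitive variable system**: only `{{node_id.output}}` templates; no types, scopes, or aggregation
+4. **Incomplete version management**: FlowVersion snapshots exist, but there is no complete draft/publish separation
+
+### 2.2 Flow Types (Chatflow vs Workflow)
+
+**Dify's core design**: two application modes with different memory behavior
+
+| Mode | Memory | Typical scenarios | Our equivalent |
+|------|------|---------|-----------|
+| **Chatflow (conversational)** | Conversation history is maintained automatically; LLM nodes get the context injected automatically | Customer service, assistants, Q&A | Flows used by FlowChat.vue |
+| **Workflow** | No automatic memory; every execution is independent | Data processing, batch jobs, API calls | Flows invoked through the API gateway |
+
+**Change plan**:
+
+```python
+# New field on FlowDefinition
+flow_mode = Column(String(20), default="chatflow")  # "chatflow" | "workflow"
+
+# Adjusted memory-middleware logic
+if f.flow_mode == "chatflow":
+    await mm.inject_memory(...)  # inject memory
+    asyncio.create_task(mm.record_exchange(...))  # record the exchange
+else:
+    pass  # workflow mode: no memory injection
+```
+
+The frontend FlowEditor.vue gains a flow-type selector (chosen when the flow is created; immutable afterwards).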
+
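+### 2.3 Memory Enabled by Default
+
+**Core principle**: for every Chatflow-type flow, memory management is a default behavior; users do not need to add a node to the flow.
+
+```
+User input → router.py
+    │
+    ├─ Chatflow mode:
+    │   ├─ Before execution: inject_memory() → retrieve history + persona → inject into context
+    │   ├─ LLM call: LLMNodeAgent reads _memory_context automatically
+    │   └─ After execution: record_exchange() → asynchronous recording + L1 extraction
+    │
+    └─ Workflow mode:
+        └─ Execute directly, no memory injection
+```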
+
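+### 2.4 Enhanced Flow Creation Process
+
+**Current**: create flow → edit canvas → publish
+
+**After the upgrade**:
+```
+Create flow
+  ├─ Choose the flow type (Chatflow / Workflow)
+  ├─ Choose a template (optional, from the template marketplace)
+  └─ Enter the editor
+
+Edit canvas
+  ├─ Drag nodes
+  ├─ Configure node parameters
+  └─ Connect edges
+
+Publish
+  ├─ Topology integrity check (no isolated nodes; start/end nodes present)
+  ├─ Required-parameter validation
+  ├─ Create a version snapshot (immutable)
+  └─ Update the published_version_id pointer
+```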
+
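+### 2.5 Clarifying Flow Marketplace vs Flow List
+
+**Current problem**: the "published workflows" list in the flow marketplace overlaps with the top-level "Flow List" menu, which confuses users.
+
+**Proposed split**:
+
+| Page | Purpose | Data source | Actions |
+|------|------|---------|------|
+| **Flow List** (admin side) | All flows I created or manage | `GET /api/flow/definitions` | Edit, delete, publish/unpublish, list on the marketplace |
+| **Flow Marketplace** (user side) | Public flows listed on the marketplace | `GET /api/flow/market` | View details, install/use, rate |
+| **Template Center** (at creation) | Built-in and community-contributed templates | `GET /api/flow/templates` | Create a new flow from a template in one click |
+
+**Key distinctions**:
+- Flow List = private workspace, showing the flows you manage
+- Flow Marketplace = public store, showing flows that others have published to the marketplace
+- Template Center = quick start, the entry point when creating a new flow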
+
+---
+
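+## 3. Agent Management Refactoring: OpenAI-API-Compatible Model Management
+
+### 3.1 Current Problems
+
+The current `AgentConfig` model has a single `model` field (String), so it cannot distinguish model types, manage Embedding/Rerank models, or support multiple providers.
+
+```python
+# Current AgentConfig
+model = Column(String(50), default="gpt-4o-mini")  # can only hold one model name
+```
+
+### 3.2 Dify's Model Provider Architecture
+
+Dify uses a three-layer architecture:
+
+```
+ModelProvider (provider abstraction layer)
+  ├── Model types: LLM / Embedding / Rerank / TTS / Speech-to-Text
+  ├── Provider instances: OpenAI / Anthropic / ZhipuAI / Ollama / OpenAI-API-Compatible
+  └── Model instances: gpt-4o / text-embedding-3-small / cohere-rerank
+```
+
+**OpenAI-API-Compatible mode**: any service that implements the OpenAI API format is plug-and-play; only base_url + api_key + model_name need to be configured.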
+
+### 3.3 Refactoring Plan
+
+#### Data Model
+
+```python
+# New: model provider table
+class ModelProvider(Base):
+    __tablename__ = "model_providers"
+
+    id = Column(UUID, primary_key=True, default=uuid.uuid4)
+    name = Column(String(100), nullable=False)  # "OpenAI" / "ZhipuAI" / "Local Ollama"
+    provider_type = Column(String(50), nullable=False)  # "openai" / "zhipu" / "ollama" / "openai_compatible"
+    base_url = Column(String(500))  # API endpoint
+    api_key = Column(Text)  # stored encrypted
+    extra_config = Column(JSON, default=dict)  # provider-specific configuration
+    is_active = Column(Boolean, default=True)
+    created_at = Column(DateTime, default=datetime.utcnow)
+
+
+# New: model instance table
+class ModelInstance(Base):
+    __tablename__ = "model_instances"
+
+    id = Column(UUID, primary_key=True, default=uuid.uuid4)
+    provider_id = Column(UUID, ForeignKey("model_providers.id"))
+    model_name = Column(String(100), nullable=False)  # "gpt-4o" / "embedding-3" / "bge-rerank"
+    model_type = Column(String(30), nullable=False)  # "llm" / "embedding" / "rerank"
+    display_name = Column(String(200))  # "GPT-4o" / "Text Embedding v3"
+    capabilities = Column(JSON, default=dict)  # {"vision": true, "function_calling": true, "max_tokens": 128000}
+    default_params = Column(JSON, default=dict)  # {"temperature": 0.7, "top_p": 1.0}
+    is_default = Column(Boolean, default=False)  # whether this is the default model for its type
+    is_active = Column(Boolean, default=True)
+    created_at = Column(DateTime, default=datetime.utcnow)
+```
+
+#### Configuration Parameters per Model Type
+
+| Model type | Common parameters | Type-specific parameters |
+|---------|---------|---------|
+| **LLM** | model, temperature, top_p, max_tokens, stream | vision (multimodal), function_calling, response_format |
+| **Embedding** | model, dimensions | encoding_format, input_type |
+| **Rerank** | model, top_n | query, documents, return_documents |
+
+#### API Endpoint Design
+
+```
+# Provider management
+POST   /api/model-providers/              # add a provider
+GET    /api/model-providers/              # list providers
+PUT    /api/model-providers/{id}          # update a provider
+DELETE /api/model-providers/{id}          # delete a provider
+POST   /api/model-providers/{id}/test     # test connectivity
+
+# Model instance management
+POST   /api/model-providers/{id}/models/  # add a model
+GET    /api/model-instances/              # list all models (supports ?type=llm filtering)
+PUT    /api/model-instances/{id}          # update a model
+DELETE /api/model-instances/{id}          # delete a model
+POST   /api/model-instances/{id}/test     # test a model call
+
+# Default model settings
+PUT    /api/model-defaults/               # set the default model for each type
+GET    /api/model-defaults/               # get the default model for each type
+```
+
+#### Frontend Page
+
+Add `ModelProviderManager.vue` (admin side):
+- Provider list (cards showing name, type, status, and model count)
+- Add-provider form (choose type → enter base_url/api_key → auto-detect available models)
+- Model instance list (grouped into tabs by type: LLM / Embedding / Rerank)
+- Each model can have default parameters and be marked as the default model
+- Test button: sends a test request to verify connectivity
+
+#### AgentConfig Association Changes
+
+```python
+# AgentConfig changes
+class AgentConfig(Base):
+    # ...existing fields are kept...
+    model = Column(String(50), default="gpt-4o-mini")  # kept for compatibility
+    model_instance_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)  # new: linked model instance
+    embedding_model_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)  # new: linked embedding model
+```
+
+#### Engine-Layer Adaptation
+
+```python
+# LLMNodeAgent changes in engine.py
+class LLMNodeAgent(AgentBase):
+    async def reply(self, msg, **kwargs):
+        context = kwargs.get("context", {})
+        config = self.config
+
+        # Prefer model_instance_id when resolving the model configuration
+        model_instance_id = config.get("model_instance_id")
+        if model_instance_id:
+            provider, model = await self._resolve_model(model_instance_id)
+            base_url = provider.base_url
+            api_key = provider.api_key
+            model_name = model.model_name
+        else:
+            # fall back to the legacy logic
+            base_url = settings.OPENAI_BASE_URL
+            api_key = settings.OPENAI_API_KEY
+            model_name = config.get("model", "gpt-4o-mini")
+```
+
+### 3.4 Implementation Priorities
+
+| Step | Content | Priority |
+|------|------|--------|
+| Step 1 | Create the ModelProvider + ModelInstance data models + migration SQL | **P0** |
+| Step 2 | Implement the provider/model CRUD API | **P0** |
+| Step 3 | Frontend ModelProviderManager.vue page | **P0** |
+| Step 4 | Adapt LLMNodeAgent to model_instance_id | P1 |
+| Step 5 | Adapt RAGNodeAgent to embedding_model_id | P1 |
+| Step 6 | Link AgentConfig to model_instance_id | P1 |
+| Step 7 | Auto-detect a provider's available models | P2 |
+
+---
+
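+## 4. Flow Editor Node Expansion (Benchmarked Against Dify)
+
+### 4.1 Existing Nodes vs Dify Nodes
+
+| Node | We have | Dify has | Gap |
+|------|:------:|:-------:|---------|
+| Trigger/Start | ✅ trigger | ✅ start | We additionally support a WeCom trigger; Dify only has input variables |
+| LLM | ✅ llm | ✅ llm | Roughly equivalent |
+| Tool call | ✅ tool | ✅ tool | Roughly equivalent |
+| MCP | ✅ mcp | ❌ | Unique to us |
+| Notification | ✅ notify | ❌ | Unique to us (WeCom notifications) |
+| Conditional branch | ✅ condition | ✅ if-else | Roughly equivalent |
+| RAG retrieval | ✅ rag | ✅ knowledge-retrieval | Roughly equivalent |
+| Loop | ✅ loop | ✅ iteration | Dify iterates over arrays; ours is a general-purpose loop |
+| Variable aggregation | ✅ merge | ✅ variable-aggregator | Roughly equivalent |
+| Code execution | ✅ code | ✅ code | Roughly equivalent |
+| Output/End | ✅ output | ✅ end | Roughly equivalent |
+| **HTTP request** | ❌ | ✅ http-request | **Missing**: calling external APIs |
+| **Question classifier** | ❌ | ✅ question-classifier | **Missing**: intent routing |
+| **Template transform** | ❌ | ✅ template-transform | **Missing**: Jinja2 formatting |
+| **Variable assignment** | ❌ | ✅ variable-assigner | **Missing**: runtime variable operations |
+| **Iteration** | ❌ | ✅ iteration | **Missing**: per-item array processing (different from loop) |
+| **Question optimiser** | ❌ | ✅ question-optimiser | **Missing**: query rewriting before retrieval |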
+
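+### 4.2 New Node Plan
+
+#### P0: Must add (frequently used)
+
+**1. HTTP request node (http_request)**
+
+```
+Function: send HTTP requests (GET/POST/PUT/DELETE/PATCH)
+Configuration parameters:
+  - method: request method
+  - url: request URL (supports variable templates)
+  - headers: request headers (JSON)
+  - body: request body (supports variable templates)
+  - auth_type: authentication method (none/api_key/bearer/basic/oauth2)
+  - auth_config: authentication configuration
+  - timeout: timeout (seconds)
+  - retry_count: number of retries
+Outputs:
+  - status_code: status code
+  - headers: response headers
+  - body: response body (parsed as JSON)
+  - raw: raw text
+Frontend: HttpRequestConfig.vue
+```
+
+**2. Question classifier node (question_classifier)**
+
+```
+Function: classify the intent of the user input with an LLM and route to different branches
+Configuration parameters:
+  - model: LLM to use (optional; defaults to the system default)
+  - categories: category list [{name, description}]
+  - instruction: classification instruction (supplementary notes)
+Outputs:
+  - category: classification result
+  - confidence: confidence score
+  - multiple exits: one exit per category
+Frontend: QuestionClassifierConfig.vue
+Engine: calls the LLM to classify, then follows a branch based on the result (like condition, but LLM-based)
+```
+
+**3. Variable assigner node (variable_assigner)**
+
+```
+Function: assign or modify variables at runtime
+Configuration parameters:
+  - assignments: [{target_var, source_type, source_value}]
+  - source_type: "constant" / "upstream_output" / "template" / "expression"
+Outputs:
+  - variable values are updated in context
+Frontend: VariableAssignerConfig.vue
+```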
+
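+#### P1: Recommended (improves the experience)
+
+**4. Template transform node (template_transform)**
+
+```
+Function: format, concatenate, or transform variables with Jinja2 template syntax
+Configuration parameters:
+  - template: Jinja2 template string
+  - output_type: output type (string/json/array)
+Outputs:
+  - rendered: rendered text
+Frontend: TemplateTransformConfig.vue
+Engine: renders the template with the jinja2 library
+```
+
+**5. Iteration node (iteration)**
+
+```
+Function: process a list/array item by item; a sub-workflow can be nested inside
+Difference from loop:
+  - loop: fixed-count or conditional loop that repeats the same logic
+  - iteration: walks an array, handles one element per iteration, and outputs a result array
+Configuration parameters:
+  - input_array: input array (variable reference)
+  - output_variable: name of the per-iteration output variable
+  - max_iterations: maximum number of iterations
+Outputs:
+  - output: result array
+Frontend: IterationConfig.vue
+```
+
+**6. Question optimiser node (question_optimiser)**
+
+```
+Function: rewrite/expand the user query before RAG retrieval to improve recall
+Configuration parameters:
+  - model: LLM to use
+  - strategy: optimisation strategy (rewrite/expand/decompose)
+  - instruction: custom optimisation instruction
+Outputs:
+  - optimized_query: the optimised query
+  - original_query: the original query
+Frontend: QuestionOptimiserConfig.vue
+```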
+
+### 4.3 Frontend Configuration Components for New Nodes
+
+| New node | Configuration component | Complexity |
+|---------|---------|--------|
+| http_request | HttpRequestConfig.vue | Medium |
+| question_classifier | QuestionClassifierConfig.vue | Medium |
+| variable_assigner | VariableAssignerConfig.vue | Low |
+| template_transform | TemplateTransformConfig.vue | Low |
+| iteration | IterationConfig.vue | Medium |
+| question_optimiser | QuestionOptimiserConfig.vue | Low |
+
+### 4.4 Engine-Layer Changes
+
+```python
+# New branches in _create_agent() in engine.py
+elif node_type == "http_request":
+    return HttpRequestNodeAgent(...)
+elif node_type == "question_classifier":
+    return QuestionClassifierNodeAgent(...)
+elif node_type == "variable_assigner":
+    return VariableAssignerNodeAgent(...)
+elif node_type == "template_transform":
+    return TemplateTransformNodeAgent(...)
+elif node_type == "iteration":
+    return IterationNodeAgent(...)
+elif node_type == "question_optimiser":
+    return QuestionOptimiserNodeAgent(...)
+```
+
+---
+
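+## 5. Completing Version Management and the Template System
+
+### 5.1 Current State of Version Management
+
+- ✅ The FlowVersion model exists
+- ✅ The publish_flow / unpublish_flow / rollback_flow APIs are implemented
+- ✅ Execution loads the definition_json of the published_version
+- ❌ No frontend version-management UI (version list, diff, rollback button)
+- ❌ No integrity validation before publishing
+- ❌ No draft auto-save
+
+### 5.2 Improvement Plan
+
+**Frontend version management**:
+- Add a "Version history" button to the FlowEditor.vue toolbar
+- Version history dialog: lists versions (version number, publish time, publisher, change notes)
+- Support viewing the definition JSON of a historical version
+- Support rolling back to a specific version
+- Prompt for "change notes" when publishing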
+
+**Pre-publish validation**:
+```python
+async def _validate_before_publish(definition: dict, flow_mode: str) -> list[str]:
+    errors = []
+    nodes = definition.get("nodes", [])
+    edges = definition.get("edges", [])
+
+    # 1. A start node is required
+    if not any(n.get("type") == "trigger" for n in nodes):
+        errors.append("Missing trigger/start node")
+
+    # 2. A Chatflow must contain an LLM node
+    if flow_mode == "chatflow" and not any(n.get("type") == "llm" for n in nodes):
+        errors.append("A conversational flow must contain at least one LLM node")
+
+    # 3. No isolated nodes
+    connected_ids = set()
+    for e in edges:
+        connected_ids.add(e["source"])
+        connected_ids.add(e["target"])
+    for n in nodes:
+        if n["id"] not in connected_ids and len(nodes) > 1:
+            errors.append(f"Node '{n.get('label', n['id'])}' is not connected")
+
+    return errors
+```
+
+**Draft auto-save**:
+- The frontend FlowEditor.vue saves the draft to `draft_definition_json` every 30 seconds
+- When switching flows, detect unsaved drafts and prompt the user
+
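+### 5.3 Template System Enhancements
+
+**Current**: 2 hard-coded templates (document processing flow, WeCom notification flow)
+
+**After the upgrade**:
+- Templates are stored in the database (new `flow_templates` table)
+- Administrators can create templates ("Save as template" from an existing flow)
+- Template categories: customer-service chat / document processing / data analysis / notifications / custom
+- A template selection page is shown when creating a flow (optional; starting from blank is also possible)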
+
+---
+
+## 6. Implementation Roadmap
+
+### Phase 1: Foundation upgrade (P0, 2-3 weeks)
+
+| Task | Files involved | Depends on |
+|------|---------|------|
+| 1.1 Create the PostgreSQL memory tables + pgvector extension | init-db/03-memory-tables.sql, models/\_\_init\_\_.py | None |
+| 1.2 Refactor MemoryManager: PG primary store + Redis cache | memory/manager.py | 1.1 |
+| 1.3 Redis → PostgreSQL data migration script | scripts/migrate_memory_redis_to_pg.py | 1.2 |
+| 1.4 Create the ModelProvider + ModelInstance data models | models/\_\_init\_\_.py, init-db/04-model-provider.sql | None |
+| 1.5 Implement the provider/model CRUD API | modules/model_provider/router.py | 1.4 |
+| 1.6 Frontend ModelProviderManager.vue | frontend/src/views/model/ | 1.5 |
+| 1.7 Add the flow_mode field to FlowDefinition | models/\_\_init\_\_.py, init-db/ | None |
+| 1.8 Make the memory middleware depend on flow_mode | flow_engine/router.py | 1.7 |
+| 1.9 Add the HTTP request node | engine.py, HttpRequestConfig.vue | None |
+| 1.10 Add the question classifier node | engine.py, QuestionClassifierConfig.vue | None |
+| 1.11 Add the variable assigner node | engine.py, VariableAssignerConfig.vue | None |
+
+### Phase 2: Memory layering + model adaptation (P1, 2-3 weeks)
+
+| Task | Files involved | Depends on |
+|------|---------|------|
+| 2.1 Add L1 structured extraction to MemoryManager (stored in PG) | memory/manager.py | Phase 1 |
+| 2.2 Add L1 deduplication to MemoryManager | memory/manager.py | 2.1 |
+| 2.3 Adapt LLMNodeAgent to model_instance_id | engine.py | 1.5 |
+| 2.4 Adapt RAGNodeAgent to embedding_model_id | engine.py | 1.5 |
+| 2.5 Link AgentConfig to model_instance_id (backward compatible) | models/\_\_init\_\_.py, agent_manager/ | 1.5 |
+| 2.6 Add the template transform node | engine.py, TemplateTransformConfig.vue | None |
+| 2.7 Add the iteration node | engine.py, IterationConfig.vue | None |
+| 2.8 Pre-publish validation + version management frontend | flow_engine/router.py, FlowEditor.vue | None |
+
+### Phase 3: Persona + retrieval optimisation (P2, 2-3 weeks)
+
+| Task | Files involved | Depends on |
+|------|---------|------|
+| 3.1 L2 scene block extraction (stored in PG) | memory/manager.py | Phase 2 |
+| 3.2 L3 user persona generation (stored in PG) | memory/manager.py | 3.1 |
+| 3.3 Hybrid retrieval (pgvector vectors + tsvector full text + RRF) | memory/manager.py | Embedding model |
+| 3.4 Add the question optimiser node | engine.py, QuestionOptimiserConfig.vue | 3.3 |
+| 3.5 Move the template system into the database | flow_templates table, router.py | None |
+| 3.6 Draft auto-save | FlowEditor.vue, router.py | None |
+
+---
+
+## 7. Backward Compatibility Guarantees
+
+### 7.1 AgentConfig Backward Compatibility
+
+```python
+# AgentConfig changes: keep the model field, add optional fields
+class AgentConfig(Base):
+    # Existing fields stay unchanged
+    model = Column(String(50), default="gpt-4o-mini")  # kept, so old data stays compatible
+
+    # New optional fields (nullable=True; existing rows get NULL)
+    model_instance_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)
+    embedding_model_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)
+```
+
+**Compatibility logic**:
+```python
+# How LLMNodeAgent obtains its model configuration in engine.py
+async def _resolve_model_config(self, config: dict) -> dict:
+    model_instance_id = config.get("model_instance_id")
+    if model_instance_id:
+        # New path: load the full configuration from ModelInstance
+        provider, model = await self._resolve_model(model_instance_id)
+        return {
+            "base_url": provider.base_url,
+            "api_key": provider.api_key,
+            "model": model.model_name,
+            "params": model.default_params,
+        }
+    else:  # Legacy path: fall back to AgentConfig.model + global settings
+        return {
+            "base_url": settings.LLM_API_BASE,
+            "api_key": settings.LLM_API_KEY,
+            "model": config.get("model", "gpt-4o-mini"),
+            "params": {},
+        }
+```
+
+### 7.2 Memory Storage Migration Compatibility
+
+```python
+# MemoryManager supports both Redis and PostgreSQL
+class MemoryManager:
+    def __init__(self, db_session_factory, redis: Redis):
+        self.db_session_factory = db_session_factory
+        self.redis = redis
+
+    async def inject_memory(self, user_id, flow_id, session_id, context):
+        # Read from the Redis cache first (compatible with old data)
+        cached = await self._get_cached_messages(user_id, flow_id, session_id)
+        if cached:
+            recent_messages = cached
+        else:
+            # Cache miss: query PostgreSQL
+            recent_messages = await self._query_recent_messages(user_id, flow_id, session_id)
+            if recent_messages:
+                # Backfill the Redis cache
+                await self._cache_messages(user_id, flow_id, session_id, recent_messages)
+
+        # ...remaining logic
+```
+
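+### 7.3 Database Change Management Rules
+
+Every change that touches the table structure must:
+1. Add a matching SQL migration file to the `init-db/` directory
+2. Use `IF NOT EXISTS` / `IF EXISTS` in the SQL file so that it is idempotent
+3. Give new columns a `DEFAULT` value or allow `NULL`, so that old data stays compatible
+4. Name migration files by sequence number: `01-init.sql` → `02-add-published-cols.sql` → `03-memory-tables.sql` → `04-model-provider.sql`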
+
+---
+
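+## 8. Risks and Caveats
+
+1. **Installing the pgvector extension**: confirm whether the PostgreSQL in the Docker image includes pgvector. If not, add `RUN apt-get install postgresql-15-pgvector` to the Dockerfile or use the `pgvector/pgvector:pg15` image
+2. **Redis → PG data migration**: writes must pause briefly during migration, so run it off-peak. The migration script has to map the JSON serialization format in Redis onto the PG table structure
+3. **LLM call cost**: L1 extraction and deduplication both need extra LLM calls; use a low-cost model (such as gpt-4o-mini) and limit how often they are triggered
+4. **Provider API key security**: keys must be stored encrypted, never in plaintext in the database. Symmetric encryption with `cryptography.fernet` is recommended
+5. **Backward compatibility**: the AgentConfig.model field is kept and model_instance_id is optional (nullable=True), so existing rows remain compatible and need no data migration
+6. **Node type registration**: a new node must be added to both the frontend nodeTypes array and the backend _create_agent() branches; neither can be omitted
+7. **Flow type is immutable**: flow_mode is fixed at creation and cannot be changed later (the engine logic for Chatflow and Workflow differs substantially)
+8. **PostgreSQL connection pool**: memory-table queries reuse the existing pool (pool_size=20); monitor whether there are enough connections and adjust if necessary
+9. **Redis cache consistency**: if the PG write succeeds but the Redis cache update fails, the cached list is stale; delete the cache key on failure (or rely on a short TTL) so that the next read misses, falls back to PG, and backfills the cache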
diff --git a/backend/main.py b/backend/main.py
index 723d341..1edb916 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -19,22 +19,23 @@ from modules.rag.router import router as rag_router
from modules.chat.router import router as chat_router
from modules.custom_tool.router import router as custom_tool_router
from modules.memory.router import router as memory_router
-from modules.memory.manager import init_memory_manager
+from modules.memory.manager import init_memory_manager, get_memory_manager
+from modules.model_provider.router import router as model_provider_router
from websocket_manager import ws_manager
from middleware.rbac_middleware import rbac_middleware
from middleware.rate_limiter import rate_limit_middleware
from middleware.cache_manager import cache_manager
+from database import AsyncSessionLocal
@asynccontextmanager
async def lifespan(app: AgentApp):
await init_db()
await cache_manager.connect()
- await init_memory_manager()
+ await init_memory_manager(AsyncSessionLocal)
yield
await cache_manager.disconnect()
try:
- from modules.memory.manager import get_memory_manager
mm = get_memory_manager()
await mm.redis.close()
except Exception:
@@ -70,4 +71,5 @@ app.include_router(system_router)
app.include_router(rag_router)
app.include_router(chat_router)
app.include_router(custom_tool_router)
-app.include_router(memory_router)
\ No newline at end of file
+app.include_router(memory_router)
+app.include_router(model_provider_router)
\ No newline at end of file
diff --git a/backend/models/__init__.py b/backend/models/__init__.py
index dffc3d7..6ff3fea 100644
--- a/backend/models/__init__.py
+++ b/backend/models/__init__.py
@@ -141,6 +141,7 @@ class FlowDefinition(Base):
published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True)
draft_definition_json = Column(JSON, nullable=True, default=None)
creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id"))
+ flow_mode = Column(String(20), default="chatflow")
published_to_wecom = Column(Boolean, default=False)
published_to_web = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
@@ -176,6 +177,23 @@ class FlowApiKey(Base):
created_at = Column(DateTime, default=datetime.utcnow)
+class FlowTemplate(Base):
+ __tablename__ = "flow_templates"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ name = Column(String(200), nullable=False)
+ description = Column(Text, default="")
+ category = Column(String(50), default="")
+ definition_json = Column(JSON, nullable=False, default=dict)
+ icon = Column(String(50), default="")
+ sort_order = Column(Integer, default=0)
+ is_builtin = Column(Boolean, default=False)
+ usage_count = Column(Integer, default=0)
+ created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
+ created_at = Column(DateTime, default=datetime.utcnow)
+ updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+
class CustomTool(Base):
__tablename__ = "custom_tools"
@@ -261,6 +279,8 @@ class AgentConfig(Base):
description = Column(String(500))
system_prompt = Column(Text, default="")
model = Column(String(50), default="gpt-4o-mini")
+ model_instance_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True)
+ embedding_model_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True)
temperature = Column(Float, default=0.7)
tools = Column(JSON, default=list)
status = Column(String(20), default="active")
@@ -279,4 +299,97 @@ class AuditLog(Base):
resource_id = Column(String(100))
detail = Column(JSON, default=dict)
ip_address = Column(String(50))
+ created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class MemoryMessage(Base):
+ __tablename__ = "memory_messages"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
+ flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False)
+ session_id = Column(UUID(as_uuid=True), nullable=False)
+ role = Column(String(20), nullable=False)
+ content = Column(Text, nullable=False)
+ created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class MemoryAtom(Base):
+ __tablename__ = "memory_atoms"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
+ flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True)
+ atom_type = Column(String(20), nullable=False)
+ content = Column(Text, nullable=False)
+ priority = Column(Integer, default=50)
+ source_session_id = Column(UUID(as_uuid=True), nullable=True)
+    meta = Column("metadata", JSON, default=dict)  # "metadata" is a reserved attribute name in SQLAlchemy Declarative
+ created_at = Column(DateTime, default=datetime.utcnow)
+ updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+
+class MemoryScene(Base):
+ __tablename__ = "memory_scenes"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
+ flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True)
+ scene_name = Column(String(200), nullable=False)
+ summary = Column(Text, nullable=False)
+ heat = Column(Integer, default=0)
+ content = Column(JSON, default=dict)
+ created_at = Column(DateTime, default=datetime.utcnow)
+ updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+
+class MemoryPersona(Base):
+ __tablename__ = "memory_personas"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True)
+ content = Column(JSON, default=dict, nullable=False)
+ raw_text = Column(Text, default="")
+ version = Column(Integer, default=1)
+ updated_at = Column(DateTime, default=datetime.utcnow)
+
+
+class MemorySession(Base):
+ __tablename__ = "memory_sessions"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
+ flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False)
+ session_id = Column(UUID(as_uuid=True), nullable=False)
+ flow_name = Column(String(200), default="")
+ message_count = Column(Integer, default=0)
+ last_active_at = Column(DateTime, default=datetime.utcnow)
+ created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class ModelProvider(Base):
+ __tablename__ = "model_providers"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ name = Column(String(100), nullable=False)
+ provider_type = Column(String(50), nullable=False)
+ base_url = Column(String(500))
+ api_key = Column(Text)
+ extra_config = Column(JSON, default=dict)
+ is_active = Column(Boolean, default=True)
+ created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class ModelInstance(Base):
+ __tablename__ = "model_instances"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+ provider_id = Column(UUID(as_uuid=True), ForeignKey("model_providers.id", ondelete="CASCADE"))
+ model_name = Column(String(100), nullable=False)
+ model_type = Column(String(30), nullable=False)
+ display_name = Column(String(200))
+ capabilities = Column(JSON, default=dict)
+ default_params = Column(JSON, default=dict)
+ is_default = Column(Boolean, default=False)
+ is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
\ No newline at end of file
diff --git a/backend/modules/flow_engine/engine.py b/backend/modules/flow_engine/engine.py
index 2106f25..bda58d7 100644
--- a/backend/modules/flow_engine/engine.py
+++ b/backend/modules/flow_engine/engine.py
@@ -3,6 +3,7 @@ import uuid
import logging
import re
import asyncio
+import httpx
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
@@ -12,6 +13,35 @@ from config import settings
logger = logging.getLogger(__name__)
+async def _resolve_model_instance(model_instance_id: str) -> dict | None:
+ try:
+ from database import AsyncSessionLocal
+ from sqlalchemy import text
+ import uuid as _uuid
+ uid = _uuid.UUID(model_instance_id)
+ async with AsyncSessionLocal() as db:
+ result = await db.execute(
+ text("""
+ SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key
+ FROM model_instances mi
+ JOIN model_providers mp ON mi.provider_id = mp.id
+ WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true
+ """),
+ {"id": uid},
+ )
+ row = result.fetchone()
+ if row:
+ return {
+ "model": row[0],
+ "base_url": row[2],
+ "api_key": row[3],
+ "params": row[1] or {},
+ }
+    except Exception as e:
+        logger.warning(f"Failed to resolve model instance {model_instance_id}: {e}")
+ return None
+
+
class FlowSessionMemory:
def __init__(self, session_id: str = "", user_id: str = ""):
self.session_id = session_id
@@ -241,6 +271,15 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
elif node_type == "llm":
model_config = config.get("model", settings.LLM_MODEL)
+ model_instance_id = config.get("model_instance_id")
+ base_url = settings.LLM_API_BASE
+ api_key = settings.LLM_API_KEY
+ if model_instance_id:
+ resolved = await _resolve_model_instance(model_instance_id)
+ if resolved:
+                model_config = resolved.get("model") or model_config
+                base_url = resolved.get("base_url") or base_url
+                api_key = resolved.get("api_key") or api_key
temperature = config.get("temperature", 0.7)
system_prompt = config.get("system_prompt", "你是AI助手。")
max_tokens = config.get("max_tokens", 2000)
@@ -254,6 +293,8 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
max_tokens=max_tokens,
stream=stream,
stream_callback=stream_cb,
+ base_url=base_url,
+ api_key=api_key,
)
memory = context.get("_memory")
if memory:
@@ -297,6 +338,12 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type)
elif node_type == "rag":
+ model_instance_id = config.get("model_instance_id")
+ if model_instance_id:
+ resolved = await _resolve_model_instance(model_instance_id)
+ if resolved:
+ config = dict(config)
+ config["_resolved_model"] = resolved
return RAGNodeAgent(node_id=node_id, config=config)
elif node_type == "output":
@@ -314,6 +361,24 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase:
sandbox = config.get("sandbox", True)
return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox)
+ elif node_type == "http_request":
+ return HttpRequestNodeAgent(node_id=node_id, config=config)
+
+ elif node_type == "question_classifier":
+ return QuestionClassifierNodeAgent(node_id=node_id, config=config)
+
+ elif node_type == "variable_assigner":
+ return VariableAssignerNodeAgent(node_id=node_id, config=config, context=context)
+
+ elif node_type == "template_transform":
+ return TemplateTransformNodeAgent(node_id=node_id, config=config)
+
+ elif node_type == "iteration":
+ return IterationNodeAgent(node_id=node_id, config=config)
+
+ elif node_type == "question_optimiser":
+ return QuestionOptimiserNodeAgent(node_id=node_id, config=config)
+
else:
return PassThroughAgent(node_id)
@@ -330,8 +395,178 @@ class PassThroughAgent(AgentBase):
pass
+class HttpRequestNodeAgent(AgentBase):
+ def __init__(self, node_id: str, config: dict = None):
+ super().__init__()
+ self.name = f"HttpRequest_{node_id}"
+ self.config = config or {}
+
+ async def reply(self, msg: Msg, **kwargs) -> Msg:
+ method = self.config.get("method", "GET").upper()
+ url = self.config.get("url", "")
+ headers = self.config.get("headers", {})
+ body = self.config.get("body", "")
+ auth_type = self.config.get("auth_type", "none")
+ auth_config = self.config.get("auth_config", {})
+ timeout = self.config.get("timeout", 30)
+ retry_count = self.config.get("retry_count", 0)
+
+ if isinstance(headers, str):
+ try:
+ headers = json.loads(headers)
+ except (json.JSONDecodeError, ValueError):
+ headers = {}
+
+ request_headers = dict(headers)
+ if auth_type == "bearer" and auth_config.get("token"):
+ request_headers["Authorization"] = f"Bearer {auth_config['token']}"
+ elif auth_type == "api_key" and auth_config.get("api_key"):
+ key_name = auth_config.get("key_name", "X-API-Key")
+ request_headers[key_name] = auth_config["api_key"]
+ elif auth_type == "basic" and auth_config.get("username"):
+ import base64
+ credentials = f"{auth_config['username']}:{auth_config.get('password', '')}"
+ request_headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}"
+
+ last_error = None
+ for attempt in range(retry_count + 1):
+ try:
+ async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
+ request_kwargs = {"headers": request_headers}
+ if method in ("POST", "PUT", "PATCH"):
+ if isinstance(body, dict):
+ request_kwargs["json"] = body
+ elif body and isinstance(body, str):
+ try:
+ request_kwargs["json"] = json.loads(body)
+ except (json.JSONDecodeError, ValueError):
+ request_kwargs["content"] = body
+ response = await client.request(method, url, **request_kwargs)
+ resp_status = response.status_code
+ resp_text = response.text
+ try:
+ resp_body = json.loads(resp_text)
+ except (json.JSONDecodeError, ValueError):
+ resp_body = resp_text
+
+ result = json.dumps({
+ "status_code": resp_status,
+ "body": resp_body,
+ }, ensure_ascii=False)
+ return Msg(self.name, result, "assistant")
+ except Exception as e:
+ last_error = str(e)
+ if attempt < retry_count:
+ await asyncio.sleep(1)
+
+ error_result = json.dumps({
+ "status_code": 0,
+ "error": last_error or "unknown error",
+ "body": None,
+ }, ensure_ascii=False)
+ return Msg(self.name, error_result, "assistant")
+
+ async def observe(self, msg) -> None:
+ pass
+
+
+class QuestionClassifierNodeAgent(AgentBase):
+ def __init__(self, node_id: str, config: dict = None):
+ super().__init__()
+ self.name = f"Classifier_{node_id}"
+ self.config = config or {}
+
+ async def reply(self, msg: Msg, **kwargs) -> Msg:
+ categories = self.config.get("categories", [])
+ instruction = self.config.get("instruction", "")
+ model_name = self.config.get("model", settings.LLM_MODEL)
+ temperature = self.config.get("temperature", 0.3)
+
+ user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
+
+ if not categories:
+ return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant")
+
+ category_desc = "\n".join([f'{c.get("name")}: {c.get("description", "")}' for c in categories])
+ prompt = f"""请对以下用户输入进行意图分类。
+
+分类选项:
+{category_desc}
+
+{instruction}
+
+用户输入:{user_text}
+
+请只返回一个JSON对象,格式为:{{"category": "分类名称", "confidence": 0.0~1.0}}"""
+
+ try:
+ import openai
+ client = openai.AsyncOpenAI(
+ base_url=settings.LLM_API_BASE,
+ api_key=settings.LLM_API_KEY,
+ )
+ response = await client.chat.completions.create(
+ model=model_name,
+ messages=[{"role": "user", "content": prompt}],
+ temperature=temperature,
+ max_tokens=200,
+ )
+ result_text = response.choices[0].message.content.strip()
+ try:
+ parsed = json.loads(result_text)
+ return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
+ except (json.JSONDecodeError, ValueError):
+ return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}), "assistant")
+ except Exception as e:
+ logger.error(f"QuestionClassifier error: {e}")
+ return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}), "assistant")
+
+ async def observe(self, msg) -> None:
+ pass
+
+
+class VariableAssignerNodeAgent(AgentBase):
+ def __init__(self, node_id: str, config: dict = None, context: dict = None):
+ super().__init__()
+ self.name = f"VarAssign_{node_id}"
+ self.config = config or {}
+ self._context = context or {}
+
+ async def reply(self, msg: Msg, **kwargs) -> Msg:
+ assignments = self.config.get("assignments", [])
+ results = {}
+ for assignment in assignments:
+ target_var = assignment.get("target_var", "")
+ source_type = assignment.get("source_type", "constant")
+ source_value = assignment.get("source_value", "")
+
+ if source_type == "constant":
+ value = source_value
+ elif source_type == "upstream_output":
+ value = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
+ elif source_type == "template":
+ value = _resolve_template(source_value, self._context, msg)
+ elif source_type == "expression":
+ try:
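+                    # NOTE: emptying __builtins__ does not sandbox eval(); only trusted expressions should reach this branch.
+                    safe_locals = {"msg": msg, "context": self._context}
+                    value = eval(source_value, {"__builtins__": {}}, safe_locals)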
+ value = str(value)
+ except Exception as e:
+ value = f"[expression error: {e}]"
+ else:
+ value = source_value
+
+ results[target_var] = value
+
+ output = json.dumps(results, ensure_ascii=False)
+ return Msg(self.name, output, "assistant")
+
+ async def observe(self, msg) -> None:
+ pass
+
+
class LLMNodeAgent(AgentBase):
- def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None):
+ def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None, base_url: str = "", api_key: str = ""):
super().__init__()
self.name = f"LLM_{node_id}"
self.system_prompt = system_prompt
@@ -340,6 +575,8 @@ class LLMNodeAgent(AgentBase):
self.max_tokens = max_tokens
self.stream = stream
self.stream_callback = stream_callback
+ self.base_url = base_url or settings.LLM_API_BASE
+ self.api_key = api_key or settings.LLM_API_KEY
self._memory = None
def set_memory(self, memory):
@@ -390,29 +627,40 @@ class LLMNodeAgent(AgentBase):
return Msg(self.name, res_text, "assistant")
async def _blocking_llm_call(self, messages: list[dict]) -> str:
- from agentscope_integration.factory import AgentFactory
- from agentscope.formatter import OpenAIChatFormatter
+ import httpx
- model = AgentFactory._get_model()
- formatter = OpenAIChatFormatter()
- scope_msgs = []
- for m in messages:
- scope_msgs.append(Msg(m["role"], m["content"], m["role"]))
- prompt = formatter.format(scope_msgs)
+ api_base = self.base_url.rstrip("/")
+ api_key = self.api_key
+ model_name = self.model_name or settings.LLM_MODEL
+
+ url = f"{api_base}/chat/completions"
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ }
+ body = {
+ "model": model_name,
+ "messages": messages,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "stream": False,
+ }
- res = await model(prompt)
- if isinstance(res, list):
- return res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0])
- elif hasattr(res, 'get_text_content'):
- return res.get_text_content()
- return str(res)
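+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
+                resp = await client.post(url, json=body, headers=headers)
+                # Surface HTTP errors (401, 429, 5xx) instead of silently returning an empty reply.
+                resp.raise_for_status()
+                data = resp.json()
+                # "content" can be null in the response; always return a string.
+                return data.get("choices", [{}])[0].get("message", {}).get("content") or ""
+        except Exception as e:
+            logger.warning(f"LLM 阻塞调用失败: {e}")
+            return f"[LLM 调用失败: {e}]"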
async def _stream_llm_call(self, messages: list[dict]) -> str:
import httpx
import json
- api_base = settings.LLM_API_BASE.rstrip("/")
- api_key = settings.LLM_API_KEY
+ api_base = self.base_url.rstrip("/")
+ api_key = self.api_key
model_name = self.model_name or settings.LLM_MODEL
url = f"{api_base}/chat/completions"
@@ -977,11 +1225,12 @@ class RAGNodeAgent(AgentBase):
def _get_model(self):
from agentscope.model import OpenAIChatModel
+ resolved = self.config.get("_resolved_model", {})
return OpenAIChatModel(
config_name=f"rag_{self.name}",
- model_name=settings.LLM_MODEL,
- api_key=settings.LLM_API_KEY,
- api_base=settings.LLM_API_BASE,
+            model_name=resolved.get("model") or settings.LLM_MODEL,
+            api_key=resolved.get("api_key") or settings.LLM_API_KEY,
+            api_base=resolved.get("base_url") or settings.LLM_API_BASE,
)
def _get_formatter(self):
@@ -1050,6 +1299,164 @@ def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
return result
+class TemplateTransformNodeAgent(AgentBase):
+ def __init__(self, node_id: str, config: dict = None):
+ super().__init__()
+ self.name = f"Template_{node_id}"
+ self.config = config or {}
+
+ async def reply(self, msg: Msg, **kwargs) -> Msg:
+ template = self.config.get("template", "")
+ output_type = self.config.get("output_type", "string")
+ context = kwargs.get("context", {})
+
+ user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
+
+ try:
+ rendered = _resolve_template(template, context, msg)
+ if not rendered:
+ rendered = template
+ rendered = rendered.replace("{{input}}", user_text)
+ except Exception:
+ rendered = template.replace("{{input}}", user_text) if "{{input}}" in template else user_text
+
+ if output_type == "json":
+ try:
+ parsed = json.loads(rendered)
+ return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ return Msg(self.name, rendered, "assistant")
+
+ async def observe(self, msg) -> None:
+ pass
+
+
+class IterationNodeAgent(AgentBase):
+ def __init__(self, node_id: str, config: dict = None):
+ super().__init__()
+ self.name = f"Iteration_{node_id}"
+ self.config = config or {}
+        self._results: list[dict] = []
+
+ async def reply(self, msg: Msg, **kwargs) -> Msg:
+ input_array_source = self.config.get("input_array_source", "")
+ max_iterations = self.config.get("max_iterations", 20)
+ context = kwargs.get("context", {})
+
+ items = []
+ if input_array_source:
+ resolved = _resolve_template(input_array_source, context, msg)
+ try:
+ items = json.loads(resolved)
+ if not isinstance(items, list):
+ items = [resolved]
+ except (json.JSONDecodeError, ValueError):
+ items = [resolved] if resolved else []
+ else:
+ user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
+ try:
+ items = json.loads(user_text)
+ if not isinstance(items, list):
+ items = [user_text]
+ except (json.JSONDecodeError, ValueError):
+ items = [item.strip() for item in user_text.split("\n") if item.strip()]
+ if not items:
+ items = [user_text]
+
+ items = items[:max_iterations]
+        # Keep the items on the instance so get_iteration_items() reflects this run.
+        results = [{"index": i, "item": item} for i, item in enumerate(items)]
+        self._results = results
+
+ output = json.dumps(results, ensure_ascii=False)
+ return Msg(self.name, output, "assistant")
+
+ def get_iteration_items(self) -> list:
+ return self._results
+
+ async def observe(self, msg) -> None:
+ pass
+
+
+class QuestionOptimiserNodeAgent(AgentBase):
+ def __init__(self, node_id: str, config: dict = None):
+ super().__init__()
+ self.name = f"QOpt_{node_id}"
+ self.config = config or {}
+
+ async def reply(self, msg: Msg, **kwargs) -> Msg:
+ optimization_type = self.config.get("optimization_type", "rewrite")
+ model_name = self.config.get("model", settings.LLM_MODEL)
+ context = kwargs.get("context", {})
+
+ user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
+ history = context.get("_memory", {}).get("recent_messages", [])
+ persona = context.get("_memory", {}).get("persona", {})
+ atoms = context.get("_memory", {}).get("atoms", [])
+
+ persona_text = persona.get("raw_text", "")
+ atoms_text = "\n".join([a.get("content", "") for a in atoms[:5]]) if atoms else ""
+ history_text = "\n".join(
+ f"{m['role']}: {m['content'][:200]}" for m in history[-6:]
+ ) if history else ""
+
+ context_parts = []
+ if persona_text:
+ context_parts.append(f"用户画像: {persona_text}")
+ if atoms_text:
+ context_parts.append(f"已知信息: {atoms_text}")
+ if history_text:
+ context_parts.append(f"近期对话: {history_text}")
+
+ context_block = "\n".join(context_parts) if context_parts else ""
+
+ if optimization_type == "rewrite":
+ prompt = f"""{context_block}
+
+原始问题: {user_text}
+
+请将以上问题进行优化改写,使其更清晰、具体、完整。补充可能缺失的上下文信息。
+只返回优化后的问题,不要其他内容。"""
+ elif optimization_type == "expand":
+ prompt = f"""{context_block}
+
+简短问题: {user_text}
+
+请将以上问题扩展为更详细的版本,添加必要的背景和细节。
+只返回扩展后的问题,不要其他内容。"""
+ else:
+ return Msg(self.name, user_text, "assistant")
+
+ try:
+ import httpx
+ api_base = settings.LLM_API_BASE.rstrip("/")
+ async with httpx.AsyncClient(timeout=30) as client:
+ resp = await client.post(
+ f"{api_base}/chat/completions",
+ json={
+ "model": model_name,
+ "messages": [{"role": "user", "content": prompt}],
+ "max_tokens": 300,
+ "temperature": 0.3,
+ },
+ headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
+ )
+ data = resp.json()
+ result = data.get("choices", [{}])[0].get("message", {}).get("content", user_text)
+ return Msg(self.name, result.strip(), "assistant")
+ except Exception as e:
+ logger.warning(f"问题优化失败: {e}")
+ return Msg(self.name, user_text, "assistant")
+
+ async def observe(self, msg) -> None:
+ pass
+
+
class ParallelMergeNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None):
super().__init__()
diff --git a/backend/modules/flow_engine/gateway.py b/backend/modules/flow_engine/gateway.py
index 4554e7b..a895407 100644
--- a/backend/modules/flow_engine/gateway.py
+++ b/backend/modules/flow_engine/gateway.py
@@ -80,12 +80,12 @@ async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)):
username = user
if response_mode == "streaming":
- return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db)
+ return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db, session_id)
- return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db)
+ return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db, session_id)
-async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db):
+async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
engine = FlowEngine(definition)
input_msg = Msg(name="user", content=input_text, role="user")
context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())}
@@ -127,7 +127,7 @@ async def _chat_blocking(flow_id, definition, input_text, user_id, username, flo
raise HTTPException(500, f"流执行失败: {str(e)}")
-async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db):
+async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db, session_id=None):
async def event_generator():
import asyncio
engine = FlowEngine(definition)
diff --git a/backend/modules/flow_engine/router.py b/backend/modules/flow_engine/router.py
index c45956e..1ea6867 100644
--- a/backend/modules/flow_engine/router.py
+++ b/backend/modules/flow_engine/router.py
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
-from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User
+from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User, MemoryMessage, FlowTemplate
from schemas import (
FlowDefinitionCreate, FlowDefinitionUpdate, FlowDefinitionOut,
FlowVersionOut, FlowApiKeyCreate, FlowApiKeyOut,
@@ -34,6 +34,7 @@ def _build_flow_out(f) -> FlowDefinitionOut:
published_version_id=getattr(f, 'published_version_id', None),
published_to_wecom=getattr(f, 'published_to_wecom', False),
published_to_web=getattr(f, 'published_to_web', False),
+ flow_mode=getattr(f, 'flow_mode', 'chatflow') or 'chatflow',
created_at=f.created_at, updated_at=f.updated_at,
)
@@ -71,6 +72,7 @@ async def create_flow(req: FlowDefinitionCreate, request: Request, db: AsyncSess
definition_json=definition_json,
draft_definition_json=definition_json,
creator_id=uuid.UUID(user_ctx["id"]),
+ flow_mode=req.flow_mode,
)
db.add(flow)
await db.flush()
@@ -133,6 +135,43 @@ async def _snapshot_publish(flow: FlowDefinition, db: AsyncSession, user_id: str
return new_version
+def validate_flow_definition(definition: dict, flow_mode: str = "chatflow") -> list[str]:
+ errors = []
+ nodes = definition.get("nodes", [])
+ edges = definition.get("edges", [])
+
+ if not nodes:
+ errors.append("流定义中没有节点")
+ return errors
+
+ if not any(n.get("type") == "trigger" for n in nodes):
+ errors.append("缺少触发/起始节点 (trigger)")
+
+ if flow_mode == "chatflow" and not any(n.get("type") == "llm" for n in nodes):
+ errors.append("对话型流必须包含至少一个 LLM 节点")
+
+ node_ids = {n["id"] for n in nodes}
+ connected_ids = set()
+ for e in edges:
+ connected_ids.add(e.get("source", ""))
+ connected_ids.add(e.get("target", ""))
+ for n in nodes:
+ if n["id"] not in connected_ids and len(nodes) > 1:
+ errors.append(f"节点 '{n.get('label', n['id'])}' 未连接")
+
+ return errors
+
+
+@router.post("/definitions/{flow_id}/validate")
+async def validate_flow(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
+ f = await db.get(FlowDefinition, flow_id)
+ if not f:
+ raise HTTPException(404, "流定义不存在")
+ flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
+ errors = validate_flow_definition(f.definition_json, flow_mode)
+ return {"code": 200, "valid": len(errors) == 0, "errors": errors}
+
+
@router.post("/definitions/{flow_id}/publish")
async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
f = await db.get(FlowDefinition, flow_id)
@@ -141,8 +180,18 @@ async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession =
nodes = f.definition_json.get("nodes", [])
if not nodes:
raise HTTPException(400, "流定义中没有节点")
+ flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
+ errors = validate_flow_definition(f.definition_json, flow_mode)
+ if errors:
+ raise HTTPException(400, f"流校验失败: {'; '.join(errors)}")
user_ctx = request.state.user
- await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True)
+ body = {}
+ try:
+ body = await request.json()
+ except Exception:
+ pass
+    changelog = body.get("changelog", "") if isinstance(body, dict) else ""
+ await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True, changelog=changelog)
return {"code": 200, "message": "流已上架到企微", "data": {"status": "published", "version": f.version}}
@@ -261,12 +310,14 @@ async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db:
try:
from modules.memory.manager import get_memory_manager
mm = get_memory_manager()
- await mm.inject_memory(
- user_id=user_ctx["id"],
- flow_id=str(flow_id),
- session_id=session_id,
- context=context,
- )
+ flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
+ if flow_mode == 'chatflow':
+ await mm.inject_memory(
+ user_id=user_ctx["id"],
+ flow_id=str(flow_id),
+ session_id=session_id,
+ context=context,
+ )
except Exception as e:
logger.debug(f"记忆注入跳过: {e}")
@@ -279,14 +330,15 @@ async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db:
try:
from modules.memory.manager import get_memory_manager
mm = get_memory_manager()
- asyncio.create_task(mm.record_exchange(
- user_id=user_ctx["id"],
- flow_id=str(flow_id),
- session_id=session_id,
- user_msg=input_text,
- assistant_msg=output_text,
- flow_name=f.name,
- ))
+        if (getattr(f, 'flow_mode', 'chatflow') or 'chatflow') == 'chatflow':
+ asyncio.create_task(mm.record_exchange(
+ user_id=user_ctx["id"],
+ flow_id=str(flow_id),
+ session_id=session_id,
+ user_msg=input_text,
+ assistant_msg=output_text,
+ flow_name=f.name,
+ ))
except Exception as e:
logger.debug(f"记忆记录跳过: {e}")
@@ -383,6 +435,8 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
except Exception:
mm_stream = None
+ flow_mode_stream = getattr(f, 'flow_mode', 'chatflow') or 'chatflow'
+
async def event_generator():
engine = FlowEngine(definition)
context = {
@@ -394,7 +448,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
"_stream_callback": None,
}
- if mm_stream:
+ if mm_stream and flow_mode_stream == 'chatflow':
try:
await mm_stream.inject_memory(
user_id=user_ctx["id"],
@@ -476,7 +530,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
yield f"data: {json.dumps({'event': 'workflow_finished', 'data': {'output': output_text, 'session_id': session_id, 'node_results': {k: str(v)[:200] for k, v in context.get('_node_results', {}).items()}, 'latency_ms': elapsed_ms}}, ensure_ascii=False)}\n\n"
- if mm_stream:
+ if mm_stream and flow_mode_stream == 'chatflow':
try:
asyncio.create_task(mm_stream.record_exchange(
user_id=user_ctx["id"],
@@ -501,8 +555,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes
finished_at=datetime.utcnow(),
)
db.add(execution)
- finally:
- yield "data: [DONE]\n\n"
+ yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -615,86 +668,6 @@ async def list_executions(
# ============================== 模板 ==============================
-FLOW_TEMPLATES = [
- {
- "id": "tpl_doc_process", "name": "文档处理流", "description": "自动解析文档内容,提取关键信息并生成摘要", "icon": "Document",
- "nodes": [
- {"id": "n1", "type": "trigger", "label": "文档上传", "config": {"event_type": "document_upload"}, "position": {"x": 100, "y": 100}},
- {"id": "n2", "type": "tool", "label": "解析文档", "config": {"tool_name": "parse_document"}, "position": {"x": 400, "y": 100}},
- {"id": "n3", "type": "llm", "label": "生成摘要", "config": {"system_prompt": "请为以下文档内容生成简洁摘要", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 700, "y": 100}},
- {"id": "n4", "type": "output", "label": "输出结果", "config": {"format": "text"}, "position": {"x": 1000, "y": 100}},
- ],
- "edges": [
- {"source": "n1", "target": "n2", "sourceHandle": "source"},
- {"source": "n2", "target": "n3", "sourceHandle": "source"},
- {"source": "n3", "target": "n4", "sourceHandle": "source"},
- ],
- },
- {
- "id": "tpl_wecom_notify", "name": "企微通知流", "description": "接收触发后查询数据并推送企微通知", "icon": "Bell",
- "nodes": [
- {"id": "n1", "type": "trigger", "label": "定时触发", "config": {"event_type": "scheduled"}, "position": {"x": 100, "y": 100}},
- {"id": "n2", "type": "tool", "label": "查询任务", "config": {"tool_name": "list_tasks"}, "position": {"x": 400, "y": 100}},
- {"id": "n3", "type": "condition", "label": "有待办任务?", "config": {"condition": "tasks.length > 0"}, "position": {"x": 700, "y": 100}},
- {"id": "n4", "type": "wecom_notify", "label": "推送通知", "config": {"message_template": "您有{{tasks.length}}条待办任务", "target": "@all"}, "position": {"x": 1000, "y": 50}},
- {"id": "n5", "type": "output", "label": "无任务", "config": {"format": "text"}, "position": {"x": 1000, "y": 200}},
- ],
- "edges": [
- {"source": "n1", "target": "n2", "sourceHandle": "source"},
- {"source": "n2", "target": "n3", "sourceHandle": "source"},
- {"source": "n3", "target": "n4", "sourceHandle": "true"},
- {"source": "n3", "target": "n5", "sourceHandle": "false"},
- ],
- },
- {
- "id": "tpl_data_analysis", "name": "数据分析流", "description": "查询员工数据并生成效率分析报告", "icon": "DataAnalysis",
- "nodes": [
- {"id": "n1", "type": "trigger", "label": "分析请求", "config": {"event_type": "button_click"}, "position": {"x": 100, "y": 100}},
- {"id": "n2", "type": "tool", "label": "查询下属", "config": {"tool_name": "list_subordinates"}, "position": {"x": 400, "y": 100}},
- {"id": "n3", "type": "tool", "label": "统计数据", "config": {"tool_name": "get_task_statistics"}, "position": {"x": 700, "y": 100}},
- {"id": "n4", "type": "llm", "label": "生成报告", "config": {"system_prompt": "基于以下数据生成团队效率分析报告", "model": "gpt-4o", "temperature": 0.7}, "position": {"x": 1000, "y": 100}},
- {"id": "n5", "type": "output", "label": "报告输出", "config": {"format": "json"}, "position": {"x": 1300, "y": 100}},
- ],
- "edges": [
- {"source": "n1", "target": "n2", "sourceHandle": "source"},
- {"source": "n2", "target": "n3", "sourceHandle": "source"},
- {"source": "n3", "target": "n4", "sourceHandle": "source"},
- {"source": "n4", "target": "n5", "sourceHandle": "source"},
- ],
- },
- {
- "id": "tpl_rag_qa", "name": "知识库问答流", "description": "从知识库检索信息后由LLM回答", "icon": "Search",
- "nodes": [
- {"id": "n1", "type": "trigger", "label": "问题触发", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}},
- {"id": "n2", "type": "rag", "label": "知识检索", "config": {"knowledge_base": "default", "top_k": 5}, "position": {"x": 400, "y": 100}},
- {"id": "n3", "type": "llm", "label": "生成回答", "config": {"system_prompt": "基于知识库检索结果回答用户问题", "model": "gpt-4o-mini", "temperature": 0.3}, "position": {"x": 700, "y": 100}},
- {"id": "n4", "type": "output", "label": "输出答案", "config": {"format": "text"}, "position": {"x": 1000, "y": 100}},
- ],
- "edges": [
- {"source": "n1", "target": "n2", "sourceHandle": "source"},
- {"source": "n2", "target": "n3", "sourceHandle": "source"},
- {"source": "n3", "target": "n4", "sourceHandle": "source"},
- ],
- },
- {
- "id": "tpl_task_auto", "name": "任务自动分配流", "description": "根据描述自动创建任务并分派给合适人员", "icon": "Tools",
- "nodes": [
- {"id": "n1", "type": "trigger", "label": "任务描述", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}},
- {"id": "n2", "type": "llm", "label": "分析任务", "config": {"system_prompt": "分析以下任务描述,提取标题、优先级、负责人", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 400, "y": 100}},
- {"id": "n3", "type": "tool", "label": "创建任务", "config": {"tool_name": "create_task"}, "position": {"x": 700, "y": 100}},
- {"id": "n4", "type": "wecom_notify", "label": "通知负责人", "config": {"message_template": "您有新任务: {{task_title}}", "target": "@all"}, "position": {"x": 1000, "y": 100}},
- {"id": "n5", "type": "output", "label": "完成", "config": {"format": "text"}, "position": {"x": 1300, "y": 100}},
- ],
- "edges": [
- {"source": "n1", "target": "n2", "sourceHandle": "source"},
- {"source": "n2", "target": "n3", "sourceHandle": "source"},
- {"source": "n3", "target": "n4", "sourceHandle": "source"},
- {"source": "n4", "target": "n5", "sourceHandle": "source"},
- ],
- },
-]
-
-
@router.get("/market", response_model=list[FlowDefinitionOut])
async def flow_market(db: AsyncSession = Depends(get_db)):
result = await db.execute(
@@ -704,24 +677,80 @@ async def flow_market(db: AsyncSession = Depends(get_db)):
@router.get("/templates")
-async def get_flow_templates(request: Request):
- return {"code": 200, "data": FLOW_TEMPLATES}
+async def get_flow_templates(db: AsyncSession = Depends(get_db)):
+ result = await db.execute(
+ select(FlowTemplate).order_by(FlowTemplate.sort_order.asc(), FlowTemplate.created_at.desc())
+ )
+ templates = result.scalars().all()
+ return {"code": 200, "data": [
+ {
+ "id": str(t.id), "name": t.name, "description": t.description,
+ "category": t.category, "definition_json": t.definition_json,
+ "icon": t.icon, "is_builtin": t.is_builtin, "usage_count": t.usage_count,
+ }
+ for t in templates
+ ]}
+
+
+@router.post("/templates")
+async def create_flow_template(request: Request, db: AsyncSession = Depends(get_db)):
+ body = await request.json()
+ user_ctx = request.state.user
+ ft = FlowTemplate(
+ name=body.get("name", ""),
+ description=body.get("description", ""),
+ category=body.get("category", "general"),
+ definition_json=body.get("definition_json", {}),
+ icon=body.get("icon", ""),
+ sort_order=body.get("sort_order", 0),
+ created_by=uuid.UUID(user_ctx["id"]),
+ )
+ db.add(ft)
+ await db.commit()
+ await db.refresh(ft)
+ return {"code": 200, "data": {"id": str(ft.id), "name": ft.name}}
+
+
+@router.put("/templates/{template_id}")
+async def update_flow_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
+ ft = await db.get(FlowTemplate, template_id)
+ if not ft:
+ raise HTTPException(404, "模板不存在")
+ body = await request.json()
+ for field in ("name", "description", "category", "definition_json", "icon", "sort_order"):
+ if field in body:
+ setattr(ft, field, body[field])
+ await db.commit()
+ return {"code": 200, "data": {"id": str(ft.id)}}
+
+
+@router.delete("/templates/{template_id}")
+async def delete_flow_template(template_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
+ ft = await db.get(FlowTemplate, template_id)
+ if not ft:
+ raise HTTPException(404, "模板不存在")
+ if ft.is_builtin:
+ raise HTTPException(400, "内置模板不可删除")
+ await db.delete(ft)
+ await db.commit()
+ return {"code": 200, "message": "模板已删除"}
@router.post("/templates/{template_id}/use")
-async def use_flow_template(template_id: str, request: Request, db: AsyncSession = Depends(get_db)):
- template = next((t for t in FLOW_TEMPLATES if t["id"] == template_id), None)
- if not template:
+async def use_flow_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
+ ft = await db.get(FlowTemplate, template_id)
+ if not ft:
raise HTTPException(404, "模板不存在")
+ ft.usage_count = (ft.usage_count or 0) + 1
user_ctx = request.state.user
- definition_json = {"nodes": template["nodes"], "edges": template["edges"], "trigger": {}}
flow = FlowDefinition(
- name=template["name"] + " (副本)",
- description=template["description"],
- definition_json=definition_json,
- draft_definition_json=definition_json,
+ name=ft.name + " (副本)",
+ description=ft.description or "",
+ definition_json=ft.definition_json,
+ draft_definition_json=ft.definition_json,
creator_id=uuid.UUID(user_ctx["id"]),
)
db.add(flow)
await db.flush()
+ await db.commit()
return _build_flow_out(flow)
\ No newline at end of file
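
Review note: this change drops the hard-coded `FLOW_TEMPLATES` list in favor of `FlowTemplate` rows, so the built-in templates presumably need to be seeded into the table somewhere. Below is a minimal sketch of that conversion. It assumes each row keeps the same `{"nodes", "edges", "trigger"}` shape that the old `use_flow_template` built. The name `legacy_template_to_row` and the `category` / `is_builtin` defaults are hypothetical and not part of this diff.

```python
def legacy_template_to_row(template: dict, sort_order: int = 0) -> dict:
    """Map one entry of the removed FLOW_TEMPLATES list to FlowTemplate column values.

    Hypothetical helper: the column names mirror the fields used by the new
    endpoints, but the seeding step itself is an assumption.
    """
    return {
        "name": template["name"],
        "description": template.get("description", ""),
        "category": "general",          # assumed default, same as create_flow_template
        "icon": template.get("icon", ""),
        "is_builtin": True,             # built-ins are protected from DELETE
        "sort_order": sort_order,
        # Same shape the old use_flow_template assembled on the fly.
        "definition_json": {
            "nodes": template["nodes"],
            "edges": template["edges"],
            "trigger": {},
        },
    }


legacy = {
    "id": "tpl_rag_qa",
    "name": "知识库问答流",
    "description": "从知识库检索信息后由LLM回答",
    "icon": "Search",
    "nodes": [{"id": "n1", "type": "trigger"}],
    "edges": [],
}
row = legacy_template_to_row(legacy, sort_order=2)
```

The legacy string id (`tpl_rag_qa`) is intentionally not carried over, because the new routes take `template_id: uuid.UUID`; any client that still sends the old ids would need to be updated.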
diff --git a/backend/modules/memory/__init__.py b/backend/modules/memory/__init__.py
index c0d3a31..0872066 100644
--- a/backend/modules/memory/__init__.py
+++ b/backend/modules/memory/__init__.py
@@ -1,4 +1,4 @@
-from .manager import MemoryManager, get_memory_manager
+from .manager import MemoryManager, get_memory_manager, init_memory_manager
from .router import router
-__all__ = ["MemoryManager", "get_memory_manager", "router"]
\ No newline at end of file
+__all__ = ["MemoryManager", "get_memory_manager", "init_memory_manager", "router"]
\ No newline at end of file
diff --git a/backend/modules/memory/manager.py b/backend/modules/memory/manager.py
index 7ae0298..e18669f 100644
--- a/backend/modules/memory/manager.py
+++ b/backend/modules/memory/manager.py
@@ -1,8 +1,13 @@
import json
import asyncio
+import uuid
import logging
-from datetime import datetime
+from datetime import datetime, timezone
+from typing import Callable
+
from redis.asyncio import Redis
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
logger = logging.getLogger(__name__)
@@ -17,21 +22,27 @@ def get_memory_manager() -> "MemoryManager":
return _memory_manager
-async def init_memory_manager():
+async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
global _memory_manager
redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
await redis.ping()
- _memory_manager = MemoryManager(redis)
+ _memory_manager = MemoryManager(db_factory, redis)
class MemoryManager:
- KEY_PREFIX = "mem"
- DEFAULT_TTL = 604800
- SESSION_INDEX_TTL = 2592000
MAX_HISTORY = 40
+ REDIS_CACHE_SIZE = 10
+ REDIS_CACHE_TTL = 300
+ SUMMARY_CACHE_KEY = "mem:summary"
+ MSG_CACHE_KEY = "mem:cache:msgs"
+ ATOM_EXTRACT_EVERY = 10
+ SCENE_EXTRACT_EVERY = 50
+ PERSONA_UPDATE_EVERY = 30
- def __init__(self, redis: Redis):
+ def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
+ self.db_factory = db_factory
self.redis = redis
+ self._extract_tasks: dict[str, asyncio.Task] = {}
async def inject_memory(
self,
@@ -40,12 +51,27 @@ class MemoryManager:
session_id: str,
context: dict,
):
- messages = await self._get_recent_messages(user_id, flow_id, session_id)
- summary = await self._get_summary(user_id, flow_id, session_id)
+ uid = uuid.UUID(user_id)
+ fid = uuid.UUID(flow_id)
+ sid = uuid.UUID(session_id)
+
+ cached = await self._redis_get_recent(uid, sid)
+ if cached:
+ recent_messages = cached
+ else:
+ recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY)
+ if recent_messages:
+ await self._redis_set_recent(uid, sid, recent_messages)
+
+ summary = await self._get_summary(uid, fid, sid)
+ atoms = await self._get_relevant_atoms(uid, fid)
+ persona = await self._get_persona(uid)
context["_memory_context"] = {
- "recent_messages": list(reversed(messages)),
+ "recent_messages": recent_messages,
"summary": summary,
+ "atoms": atoms,
+ "persona": persona,
"session_id": session_id,
}
@@ -58,115 +84,238 @@ class MemoryManager:
assistant_msg: str,
flow_name: str = "",
):
- key = self._msg_key(user_id, flow_id, session_id)
- ts = datetime.utcnow().isoformat()
+ uid = uuid.UUID(user_id)
+ fid = uuid.UUID(flow_id)
+ sid = uuid.UUID(session_id)
+ ts = datetime.now(timezone.utc)
try:
- async with self.redis.pipeline() as pipe:
- pipe.lpush(key,
- json.dumps({"role": "assistant", "content": assistant_msg, "ts": ts}, ensure_ascii=False),
- json.dumps({"role": "user", "content": user_msg, "ts": ts}, ensure_ascii=False),
+ async with self.db_factory() as db:
+ user_row = {
+ "id": uuid.uuid4(),
+ "user_id": uid,
+ "flow_id": fid,
+ "session_id": sid,
+ "role": "user",
+ "content": user_msg,
+ "created_at": ts,
+ }
+ asst_row = {
+ "id": uuid.uuid4(),
+ "user_id": uid,
+ "flow_id": fid,
+ "session_id": sid,
+ "role": "assistant",
+ "content": assistant_msg,
+ "created_at": ts,
+ }
+ await db.execute(
+ text("""
+ INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
+ VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
+ """),
+ user_row,
+ )
+ await db.execute(
+ text("""
+ INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
+ VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
+ """),
+ asst_row,
)
- pipe.ltrim(key, 0, self.MAX_HISTORY - 1)
- pipe.expire(key, self.DEFAULT_TTL)
- pipe.hset(self._meta_key(user_id, flow_id, session_id), mapping={
+ session_row = {
+ "id": uuid.uuid4(),
+ "user_id": uid,
+ "flow_id": fid,
+ "session_id": sid,
"flow_name": flow_name,
"last_active_at": ts,
- })
- pipe.expire(self._meta_key(user_id, flow_id, session_id), self.DEFAULT_TTL)
+ "created_at": ts,
+ }
+ await db.execute(
+ text("""
+ INSERT INTO memory_sessions (id, user_id, flow_id, session_id, flow_name, message_count, last_active_at, created_at)
+ VALUES (:id, :user_id, :flow_id, :session_id, :flow_name, 2, :last_active_at, :created_at)
+ ON CONFLICT (user_id, flow_id, session_id) DO UPDATE SET
+ message_count = memory_sessions.message_count + 2,
+ last_active_at = EXCLUDED.last_active_at,
+ flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name)
+ """),
+ session_row,
+ )
- pipe.sadd(f"{self.KEY_PREFIX}:{user_id}:sessions", session_id)
- pipe.expire(f"{self.KEY_PREFIX}:{user_id}:sessions", self.SESSION_INDEX_TTL)
+ await db.commit()
+ except Exception as e:
+ logger.warning(f"记录记忆失败(PG): {e}")
+ return
- await pipe.execute()
+ try:
+ new_msgs = [
+ {"role": "user", "content": user_msg, "ts": ts.isoformat()},
+ {"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()},
+ ]
+ await self._redis_append_recent(uid, sid, new_msgs)
except Exception as e:
- logger.warning(f"记录记忆失败: {e}")
+ logger.debug(f"Redis缓存更新失败: {e}")
- asyncio.create_task(self._maybe_summarize(user_id, flow_id, session_id))
+ asyncio.create_task(self._maybe_summarize(uid, fid, sid))
+ asyncio.create_task(self._maybe_extract_atoms(uid, fid, sid))
+ asyncio.create_task(self._maybe_extract_scenes(uid, fid))
+ asyncio.create_task(self._maybe_update_persona(uid, fid))
async def get_conversation_history(
self, user_id: str, flow_id: str, session_id: str, limit: int = 20
) -> list[dict]:
- messages = await self._get_recent_messages(user_id, flow_id, session_id, limit)
- return list(reversed(messages))
+ uid = uuid.UUID(user_id)
+ sid = uuid.UUID(session_id)
+ fid = uuid.UUID(flow_id) if flow_id else None
+ return await self._pg_get_recent(uid, fid, sid, limit)
async def delete_session(self, user_id: str, session_id: str):
+ uid = uuid.UUID(user_id)
+ sid = uuid.UUID(session_id)
+
try:
- patterns = await self.redis.keys(f"{self.KEY_PREFIX}:{user_id}:*:{session_id}:*")
- async with self.redis.pipeline() as pipe:
- if patterns:
- pipe.delete(*patterns)
- pipe.srem(f"{self.KEY_PREFIX}:{user_id}:sessions", session_id)
- await pipe.execute()
+ async with self.db_factory() as db:
+ await db.execute(
+ text("DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
+ {"uid": uid, "sid": sid},
+ )
+ await db.execute(
+ text("DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid"),
+ {"uid": uid, "sid": sid},
+ )
+ await db.commit()
except Exception as e:
- logger.warning(f"清除记忆失败: {e}")
+ logger.warning(f"清除记忆失败(PG): {e}")
+
+ try:
+ cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
+ await self.redis.delete(cache_key)
+ summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
+ await self.redis.delete(summary_key)
+ except Exception:
+ pass
async def list_user_sessions(self, user_id: str) -> list[dict]:
+ uid = uuid.UUID(user_id)
try:
- session_ids = await self.redis.smembers(f"{self.KEY_PREFIX}:{user_id}:sessions")
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("""
+ SELECT session_id, flow_id, flow_name, last_active_at
+ FROM memory_sessions
+ WHERE user_id = :uid
+ ORDER BY last_active_at DESC
+ LIMIT 100
+ """),
+ {"uid": uid},
+ )
+ rows = result.fetchall()
+ return [
+ {
+ "session_id": str(row[0]),
+ "flow_id": str(row[1]),
+ "flow_name": row[2] or "",
+ "last_active_at": row[3].isoformat() if row[3] else "",
+ }
+ for row in rows
+ ]
except Exception:
return []
- sessions = []
- for sid in session_ids:
- try:
- keys = await self.redis.keys(f"{self.KEY_PREFIX}:{user_id}:*:{sid}:meta")
- for k in keys:
- meta = await self.redis.hgetall(k)
- parts = k.split(":")
- flow_id = parts[2] if len(parts) > 2 else ""
- sessions.append({
- "session_id": sid,
- "flow_id": flow_id,
- "flow_name": meta.get("flow_name", ""),
- "last_active_at": meta.get("last_active_at", ""),
- })
- except Exception:
- continue
+ async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]:
+ try:
+ async with self.db_factory() as db:
+ if fid:
+ result = await db.execute(
+ text("""
+ SELECT role, content, created_at
+ FROM memory_messages
+ WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid
+ ORDER BY created_at DESC
+ LIMIT :limit
+ """),
+ {"uid": uid, "fid": fid, "sid": sid, "limit": limit},
+ )
+ else:
+ result = await db.execute(
+ text("""
+ SELECT role, content, created_at
+ FROM memory_messages
+ WHERE user_id = :uid AND session_id = :sid
+ ORDER BY created_at DESC
+ LIMIT :limit
+ """),
+ {"uid": uid, "sid": sid, "limit": limit},
+ )
+ rows = result.fetchall()
+ return [
+ {"role": row[0], "content": row[1], "ts": row[2].isoformat() if row[2] else ""}
+ for row in reversed(rows)
+ ]
+ except Exception:
+ return []
- return sorted(sessions, key=lambda s: s.get("last_active_at", ""), reverse=True)
+ async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
+ try:
+ cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
+ raw = await self.redis.get(cache_key)
+ if raw:
+ return json.loads(raw)
+ except Exception:
+ pass
+ return None
- async def _get_recent_messages(
- self, user_id: str, flow_id: str, session_id: str, limit: int = None
- ) -> list[dict]:
- limit = limit or self.MAX_HISTORY
- try:
- key = self._msg_key(user_id, flow_id, session_id)
- raw = await self.redis.lrange(key, 0, limit - 1)
- result = []
- for m in raw:
- try:
- result.append(json.loads(m))
- except json.JSONDecodeError:
- continue
- return result
+ async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
+ try:
+ cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
+ top = messages[-self.REDIS_CACHE_SIZE:]
+ await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(top, ensure_ascii=False))
except Exception:
- return []
+ pass
- async def _get_summary(self, user_id: str, flow_id: str, session_id: str) -> str:
+ async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
try:
- key = f"{self.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary"
- val = await self.redis.get(key)
+ cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
+ existing = await self._redis_get_recent(uid, sid) or []
+ combined = existing + new_msgs
+ recent = combined[-self.REDIS_CACHE_SIZE:]
+ await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(recent, ensure_ascii=False))
+ except Exception:
+ pass
+
+ async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
+ try:
+ summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
+ val = await self.redis.get(summary_key)
return val or ""
except Exception:
return ""
- async def _maybe_summarize(self, user_id: str, flow_id: str, session_id: str):
+ async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
try:
- key = self._msg_key(user_id, flow_id, session_id)
- count = await self.redis.llen(key)
- if count < 30:
- return
-
- summary_key = f"{self.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary"
+ summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
existing = await self.redis.get(summary_key)
if existing:
return
- recent = await self._get_recent_messages(user_id, flow_id, session_id, 20)
+ count = 0
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
+ {"uid": uid, "sid": sid},
+ )
+ row = result.fetchone()
+ count = row[0] if row else 0
+
+ if count < 30:
+ return
+
+ recent = await self._pg_get_recent(uid, fid, sid, 20)
dialogue = "\n".join(
- f"{m['role']}: {m['content'][:500]}" for m in reversed(recent[:10])
+ f"{m['role']}: {m['content'][:500]}" for m in recent[-10:]
)
import httpx
@@ -192,10 +341,511 @@ class MemoryManager:
except Exception:
pass
- @staticmethod
- def _msg_key(user_id: str, flow_id: str, session_id: str) -> str:
- return f"mem:{user_id}:{flow_id}:{session_id}:messages"
+ async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
+ try:
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("""
+ SELECT id, atom_type, content, priority, metadata
+ FROM memory_atoms
+ WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
+ ORDER BY priority DESC, updated_at DESC
+ LIMIT 20
+ """),
+ {"uid": uid, "fid": fid},
+ )
+ rows = result.fetchall()
+ return [
+ {"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]}
+ for row in rows
+ ]
+ except Exception:
+ return []
+
+ async def _get_persona(self, uid: uuid.UUID) -> dict:
+ try:
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"),
+ {"uid": uid},
+ )
+ row = result.fetchone()
+ if row:
+ return {"content": row[0] or {}, "raw_text": row[1] or "", "version": row[2] or 1}
+ except Exception:
+ pass
+ return {}
+
+ async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
+ try:
+ task_key = f"extract_atoms:{uid}:{fid}"
+ if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
+ return
+
+ count = 0
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("SELECT message_count FROM memory_sessions WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"),
+ {"uid": uid, "fid": fid, "sid": sid},
+ )
+ row = result.fetchone()
+ count = row[0] if row else 0
+
+ last_extract = await db.execute(
+ text("""
+ SELECT COUNT(*) FROM memory_atoms
+ WHERE user_id = :uid AND flow_id = :fid AND metadata->>'source_session_id' = :sid_str
+ """),
+ {"uid": uid, "fid": fid, "sid_str": str(sid)},
+ )
+ last_count = last_extract.fetchone()[0]
+ if count < self.ATOM_EXTRACT_EVERY or count <= last_count * self.ATOM_EXTRACT_EVERY:
+ return
+
+ recent = await self._pg_get_recent(uid, fid, sid, 20)
+ dialogue = "\n".join(
+ f"{m['role']}: {m['content'][:400]}" for m in recent[-15:]
+ )
+
+ task = asyncio.create_task(self._do_extract_atoms(uid, fid, sid, dialogue))
+ self._extract_tasks[task_key] = task
+ except Exception:
+ pass
+
+ async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
+ try:
+ prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。
+
+对话内容:
+{dialogue}
+
+请以JSON数组格式返回,每个原子包含以下字段:
+- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求)
+- content: 原子的具体内容(一句话描述)
+- priority: 0-100的整数,表示重要性(越高越核心)
+
+格式示例:
+[
+ {{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}},
+ {{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}}
+]
+
+只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。"""
+
+ import httpx
+ api_base = settings.LLM_API_BASE.rstrip("/")
+ async with httpx.AsyncClient(timeout=60) as client:
+ resp = await client.post(
+ f"{api_base}/chat/completions",
+ json={
+ "model": settings.LLM_MODEL,
+ "messages": [{"role": "user", "content": prompt}],
+ "max_tokens": 800,
+ "temperature": 0.3,
+ },
+ headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
+ )
+ data = resp.json()
+ result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]")
+
+ try:
+ atoms = json.loads(result_text)
+ if not isinstance(atoms, list):
+ return
+ except (json.JSONDecodeError, ValueError):
+ return
+
+ await self._dedup_and_store_atoms(uid, fid, sid, atoms)
+ except Exception as e:
+ logger.warning(f"L1原子提取失败: {e}")
+
+ async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
+ try:
+ async with self.db_factory() as db:
+ existing = await db.execute(
+ text("""
+ SELECT id, content FROM memory_atoms
+ WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
+ """),
+ {"uid": uid, "fid": fid},
+ )
+ existing_atoms = [{"id": row[0], "content": row[1]} for row in existing.fetchall()]
+
+ for atom in atoms:
+ content = atom.get("content", "").strip()
+ if not content:
+ continue
+
+ is_dup = False
+ for ex in existing_atoms:
+ if self._text_similarity(content, ex["content"]) > 0.75:
+ existing_atoms.remove(ex)
+ await db.execute(
+ text("""
+ UPDATE memory_atoms
+ SET priority = GREATEST(priority, :priority),
+ updated_at = NOW(),
+ metadata = metadata || :meta
+ WHERE id = :id
+ """),
+ {
+ "id": ex["id"],
+ "priority": atom.get("priority", 50),
+ "meta": json.dumps({"last_source_session_id": str(sid)}),
+ },
+ )
+ is_dup = True
+ break
+
+ if not is_dup:
+ await db.execute(
+ text("""
+ INSERT INTO memory_atoms (user_id, flow_id, atom_type, content, priority, source_session_id, metadata, created_at, updated_at)
+ VALUES (:uid, :fid, :atype, :content, :priority, :sid, :meta, NOW(), NOW())
+ """),
+ {
+ "uid": uid,
+ "fid": fid,
+ "atype": atom.get("atom_type", "episodic"),
+ "content": content,
+ "priority": atom.get("priority", 50),
+ "sid": sid,
+ "meta": json.dumps({"source_session_id": str(sid)}),
+ },
+ )
+
+ await db.commit()
+ except Exception as e:
+ logger.warning(f"L1原子存储失败: {e}")
@staticmethod
- def _meta_key(user_id: str, flow_id: str, session_id: str) -> str:
- return f"mem:{user_id}:{flow_id}:{session_id}:meta"
\ No newline at end of file
+ def _text_similarity(a: str, b: str) -> float:
+ a_words = set(a.lower().split())
+ b_words = set(b.lower().split())
+ if not a_words or not b_words:
+ return 0.0
+ intersection = a_words & b_words
+ union = a_words | b_words
+ return len(intersection) / len(union)
+
+ async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
+ try:
+ task_key = f"extract_scenes:{uid}:{fid}"
+ if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
+ return
+
+ count = 0
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("SELECT COUNT(*) FROM memory_atoms WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
+ {"uid": uid, "fid": fid},
+ )
+ row = result.fetchone()
+ count = row[0] if row else 0
+
+ scene_count_result = await db.execute(
+ text("SELECT COUNT(*) FROM memory_scenes WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
+ {"uid": uid, "fid": fid},
+ )
+ scene_count = scene_count_result.fetchone()[0]
+
+ if count < self.SCENE_EXTRACT_EVERY:
+ return
+
+ if scene_count > 0:
+ atoms_result = await db.execute(
+ text("""
+ SELECT updated_at FROM memory_scenes
+ WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
+ ORDER BY updated_at DESC LIMIT 1
+ """),
+ {"uid": uid, "fid": fid},
+ )
+ latest_scene = atoms_result.fetchone()
+ if latest_scene:
+                        from datetime import timedelta
+ ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone.utc)
+ if ago < timedelta(hours=12):
+ return
+
+ atoms_result = await db.execute(
+ text("""
+ SELECT atom_type, content, priority FROM memory_atoms
+ WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
+ ORDER BY priority DESC LIMIT 30
+ """),
+ {"uid": uid, "fid": fid},
+ )
+ atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in atoms_result.fetchall()]
+
+ if len(atoms) < 10:
+ return
+
+ task = asyncio.create_task(self._do_extract_scenes(uid, fid, atoms))
+ self._extract_tasks[task_key] = task
+ except Exception:
+ pass
+
+ async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
+ try:
+ atoms_text = "\n".join(
+ f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
+ )
+
+ prompt = f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。
+
+记忆原子:
+{atoms_text}
+
+请以JSON数组格式返回,每个场景块包含:
+- scene_name: 场景名称(简短标签)
+- summary: 场景摘要(2-3句话总结该场景的关键信息)
+
+格式示例:
+[
+ {{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}}
+]
+
+只返回JSON数组,不要其他内容。"""
+
+ import httpx
+ api_base = settings.LLM_API_BASE.rstrip("/")
+ async with httpx.AsyncClient(timeout=60) as client:
+ resp = await client.post(
+ f"{api_base}/chat/completions",
+ json={
+ "model": settings.LLM_MODEL,
+ "messages": [{"role": "user", "content": prompt}],
+ "max_tokens": 500,
+ "temperature": 0.3,
+ },
+ headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
+ )
+ data = resp.json()
+ result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]")
+
+ try:
+ scenes = json.loads(result_text)
+ if not isinstance(scenes, list):
+ return
+ except (json.JSONDecodeError, ValueError):
+ return
+
+ async with self.db_factory() as db:
+ for scene in scenes:
+ await db.execute(
+ text("""
+ INSERT INTO memory_scenes (user_id, flow_id, scene_name, summary, heat, content, created_at, updated_at)
+ VALUES (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW())
+ """),
+ {
+ "uid": uid,
+ "fid": fid,
+ "name": scene.get("scene_name", ""),
+ "summary": scene.get("summary", ""),
+ "content": json.dumps(scene, ensure_ascii=False),
+ },
+ )
+ await db.commit()
+ except Exception as e:
+ logger.warning(f"L2场景提取失败: {e}")
+
+ async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
+ try:
+ task_key = f"update_persona:{uid}"
+ if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
+ return
+
+ msg_count = 0
+ async with self.db_factory() as db:
+ result = await db.execute(
+ text("SELECT message_count FROM memory_sessions WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"),
+ {"uid": uid},
+ )
+ row = result.fetchone()
+ msg_count = row[0] if row else 0
+
+ persona_result = await db.execute(
+ text("SELECT version, updated_at FROM memory_personas WHERE user_id = :uid"),
+ {"uid": uid},
+ )
+ persona_row = persona_result.fetchone()
+ if persona_row:
+                    from datetime import timedelta
+ ago = datetime.now(timezone.utc) - persona_row[1].replace(tzinfo=timezone.utc)
+ if ago < timedelta(hours=6):
+ return
+
+ if msg_count < self.PERSONA_UPDATE_EVERY:
+ return
+
+ persona_atoms = []
+ async with self.db_factory() as db:
+ atoms_result = await db.execute(
+ text("""
+ SELECT content FROM memory_atoms
+ WHERE user_id = :uid AND atom_type = 'persona'
+ ORDER BY priority DESC, updated_at DESC LIMIT 30
+ """),
+ {"uid": uid},
+ )
+ persona_atoms = [r[0] for r in atoms_result.fetchall()]
+
+ persona_text = "\n".join(persona_atoms[:15]) if persona_atoms else "暂无用户信息"
+
+ task = asyncio.create_task(self._do_update_persona(uid, persona_text, persona_row.version + 1 if persona_row else 1))
+ self._extract_tasks[task_key] = task
+ except Exception:
+ pass
+
+ async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
+ try:
+ prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。
+
+用户信息:
+{persona_text}
+
+请返回一个JSON对象,包含以下字段:
+- name: 用户姓名(未知则"未知")
+- role: 角色/职位
+- preferences: 偏好/习惯列表
+- skills: 技能/专长列表
+- personality: 性格描述
+- raw_text: 综合画像的文本描述(2-3句话)
+
+只返回JSON对象,不要其他内容。"""
+
+ import httpx
+ api_base = settings.LLM_API_BASE.rstrip("/")
+ async with httpx.AsyncClient(timeout=60) as client:
+ resp = await client.post(
+ f"{api_base}/chat/completions",
+ json={
+ "model": settings.LLM_MODEL,
+ "messages": [{"role": "user", "content": prompt}],
+ "max_tokens": 500,
+ "temperature": 0.3,
+ },
+ headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
+ )
+ data = resp.json()
+ result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "{}")
+
+ try:
+ persona = json.loads(result_text)
+ except (json.JSONDecodeError, ValueError):
+ persona = {"raw_text": result_text}
+
+ async with self.db_factory() as db:
+ existing = await db.execute(
+ text("SELECT id FROM memory_personas WHERE user_id = :uid"),
+ {"uid": uid},
+ )
+ if existing.fetchone():
+ await db.execute(
+ text("""
+ UPDATE memory_personas
+ SET content = :content, raw_text = :raw_text, version = :version, updated_at = NOW()
+ WHERE user_id = :uid
+ """),
+ {
+ "uid": uid,
+ "content": json.dumps(persona, ensure_ascii=False),
+ "raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)),
+ "version": version,
+ },
+ )
+ else:
+ await db.execute(
+ text("""
+ INSERT INTO memory_personas (user_id, content, raw_text, version, updated_at)
+ VALUES (:uid, :content, :raw_text, :version, NOW())
+ """),
+ {
+ "uid": uid,
+ "content": json.dumps(persona, ensure_ascii=False),
+ "raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)),
+ "version": version,
+ },
+ )
+ await db.commit()
+ except Exception as e:
+ logger.warning(f"L3画像更新失败: {e}")
+
+ async def search_memory(
+ self,
+ uid: uuid.UUID,
+ query: str,
+ fid: uuid.UUID = None,
+ top_k: int = 10,
+ embedding: list[float] = None,
+ ) -> list[dict]:
+ results = []
+ try:
+ async with self.db_factory() as db:
+ if embedding and len(embedding) > 0:
+ emb_str = "[" + ",".join(str(v) for v in embedding) + "]"
+ vector_results = await db.execute(
+ text(f"""
+ SELECT id, atom_type, content, priority,
+                                   (1.0 - (embedding <=> CAST(:emb AS vector))) AS similarity
+                            FROM memory_atoms
+                            WHERE user_id = :uid
+                            {'AND (flow_id = :fid OR flow_id IS NULL)' if fid else ''}
+                            AND embedding IS NOT NULL
+                            ORDER BY embedding <=> CAST(:emb AS vector)
+ LIMIT :limit
+ """),
+ {"uid": uid, "fid": fid, "emb": emb_str, "limit": top_k * 2} if fid
+ else {"uid": uid, "emb": emb_str, "limit": top_k * 2},
+ )
+ for row in vector_results.fetchall():
+ results.append({
+ "id": str(row[0]), "type": row[1], "content": row[2],
+ "priority": row[3], "vector_score": float(row[4]),
+ })
+
+ text_results = await db.execute(
+ text(f"""
+ SELECT id, atom_type, content, priority,
+ ts_rank(content_tsv, plainto_tsquery('simple', :query)) AS text_score
+ FROM memory_atoms
+ WHERE user_id = :uid
+                        {'AND (flow_id = :fid OR flow_id IS NULL)' if fid else ''}
+ AND content_tsv @@ plainto_tsquery('simple', :query)
+ ORDER BY text_score DESC
+ LIMIT :limit
+ """),
+ {"uid": uid, "fid": fid, "query": query, "limit": top_k * 2} if fid
+ else {"uid": uid, "query": query, "limit": top_k * 2},
+ )
+ for row in text_results.fetchall():
+ results.append({
+ "id": str(row[0]), "type": row[1], "content": row[2],
+ "priority": row[3], "text_score": float(row[4]),
+ })
+
+ rrf_scores: dict[str, float] = {}
+ k = 60
+ for rank, r in enumerate(sorted(results, key=lambda x: x.get("vector_score", 0), reverse=True)):
+ if "vector_score" in r:
+ rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1)
+ for rank, r in enumerate(sorted(results, key=lambda x: x.get("text_score", 0), reverse=True)):
+ if "text_score" in r:
+ rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1)
+
+ seen = set()
+ merged = []
+ for rid, score in sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True):
+ for r in results:
+ if r["id"] == rid and rid not in seen:
+ seen.add(rid)
+ r["rrf_score"] = score
+ merged.append(r)
+ break
+
+ merged.sort(key=lambda x: x.get("rrf_score", 0), reverse=True)
+ return merged[:top_k]
+ except Exception as e:
+ logger.warning(f"混合检索失败: {e}")
+ return results[:top_k] if results else []
\ No newline at end of file
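
Review note: the fusion step in `search_memory` is Reciprocal Rank Fusion with `k = 60`. The sketch below isolates that step so it can be checked without a database. `rrf_fuse` and the sample ids are hypothetical names used only for illustration; they are not part of this patch.

```python
def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Fuse several ranked id lists: each list contributes 1 / (k + rank + 1) per id."""
    scores: dict[str, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking):  # rank is 0-based within its own list
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
    # Highest fused score first
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


vector_ranking = ["a", "b", "c"]  # ids ordered by vector similarity (made-up data)
text_ranking = ["b", "d"]         # ids ordered by full-text rank (made-up data)
fused = rrf_fuse([vector_ranking, text_ranking])
print([doc_id for doc_id, _ in fused])
```

An id that appears in both lists accumulates two terms, which is why each list has to be ranked separately before the scores are summed.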
diff --git a/backend/modules/memory/router.py b/backend/modules/memory/router.py
index 3ca3af2..dc5311a 100644
--- a/backend/modules/memory/router.py
+++ b/backend/modules/memory/router.py
@@ -13,7 +13,7 @@ async def list_sessions(request: Request, user=Depends(get_current_user)):
@router.get("/sessions/{session_id}")
-async def get_session(session_id: str, flow_id: str = "", request: Request, user=Depends(get_current_user)):
+async def get_session(session_id: str, request: Request, flow_id: str = "", user=Depends(get_current_user)):
mm = get_memory_manager()
history = await mm.get_conversation_history(
user_id=str(user.id),
diff --git a/backend/modules/model_provider/__init__.py b/backend/modules/model_provider/__init__.py
new file mode 100644
index 0000000..43a7103
--- /dev/null
+++ b/backend/modules/model_provider/__init__.py
@@ -0,0 +1,3 @@
+from .router import router
+
+__all__ = ["router"]
\ No newline at end of file
diff --git a/backend/modules/model_provider/router.py b/backend/modules/model_provider/router.py
new file mode 100644
index 0000000..b1c768f
--- /dev/null
+++ b/backend/modules/model_provider/router.py
@@ -0,0 +1,181 @@
+import uuid
+import logging
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from database import get_db
+from models import ModelProvider, ModelInstance
+from dependencies import get_current_user
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/model-providers", tags=["Model providers"])
+
+
+@router.get("")
+async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    result = await db.execute(
+        select(ModelProvider).order_by(ModelProvider.created_at.desc())
+    )
+    return {
+        "code": 200,
+        "data": [
+            {
+                "id": str(p.id),
+                "name": p.name,
+                "provider_type": p.provider_type,
+                "base_url": p.base_url,
+                "is_active": p.is_active,
+                "created_at": p.created_at.isoformat() if p.created_at else "",
+            }
+            for p in result.scalars().all()
+        ],
+    }
+
+
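+@router.post("")
+async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    # Reject incomplete payloads with a 400 instead of letting a KeyError surface as a 500.
+    missing = [f for f in ("name", "provider_type") if not payload.get(f)]
+    if missing:
+        raise HTTPException(400, f"Missing required fields: {', '.join(missing)}")
+
+    base_url = payload.get("base_url", "")
+    # Only check uniqueness when a base_url is given; otherwise every provider
+    # created without one would collide on the empty string.
+    if base_url:
+        existing = await db.execute(
+            select(ModelProvider).where(ModelProvider.base_url == base_url)
+        )
+        if existing.scalars().first():
+            raise HTTPException(400, "A provider with the same base_url already exists")
+
+    p = ModelProvider(
+        name=payload["name"],
+        provider_type=payload["provider_type"],
+        base_url=base_url,
+        api_key=payload.get("api_key", ""),
+        extra_config=payload.get("extra_config", {}),
+    )
+    db.add(p)
+    await db.commit()
+    return {"code": 200, "data": {"id": str(p.id)}}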
+
+
+@router.put("/{provider_id}")
+# Typing provider_id as uuid.UUID lets FastAPI reject malformed ids with a 422
+# instead of uuid.UUID() raising ValueError inside the handler.
+async def update_provider(provider_id: uuid.UUID, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    p = await db.get(ModelProvider, provider_id)
+    if not p:
+        raise HTTPException(404, "Provider not found")
+
+    p.name = payload.get("name", p.name)
+    p.base_url = payload.get("base_url", p.base_url)
+    p.api_key = payload.get("api_key", p.api_key)
+    p.provider_type = payload.get("provider_type", p.provider_type)
+    p.extra_config = payload.get("extra_config", p.extra_config)
+    p.is_active = payload.get("is_active", p.is_active)
+    await db.commit()
+    return {"code": 200, "data": {"id": str(p.id)}}
+
+
+@router.delete("/{provider_id}")
+async def delete_provider(provider_id: uuid.UUID, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    p = await db.get(ModelProvider, provider_id)
+    if not p:
+        raise HTTPException(404, "Provider not found")
+    await db.delete(p)
+    await db.commit()
+    return {"code": 200, "message": "Deleted"}
+
+
+@router.get("/models/all")
+async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    result = await db.execute(
+        select(ModelInstance)
+        .where(ModelInstance.is_active.is_(True))
+        .order_by(ModelInstance.model_type, ModelInstance.model_name)
+    )
+    return {
+        "code": 200,
+        "data": [
+            {
+                "id": str(m.id),
+                "model_name": m.model_name,
+                "model_type": m.model_type,
+                "display_name": m.display_name,
+            }
+            for m in result.scalars().all()
+        ],
+    }
+
+
+@router.get("/{provider_id}/models")
+async def list_models(provider_id: uuid.UUID, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    result = await db.execute(
+        select(ModelInstance)
+        .where(ModelInstance.provider_id == provider_id)
+        .order_by(ModelInstance.model_type, ModelInstance.model_name)
+    )
+    return {
+        "code": 200,
+        "data": [
+            {
+                "id": str(m.id),
+                "provider_id": str(m.provider_id),
+                "model_name": m.model_name,
+                "model_type": m.model_type,
+                "display_name": m.display_name,
+                "capabilities": m.capabilities,
+                "default_params": m.default_params,
+                "is_default": m.is_default,
+                "is_active": m.is_active,
+            }
+            for m in result.scalars().all()
+        ],
+    }
+
+
+@router.post("/{provider_id}/models")
+async def create_model(provider_id: uuid.UUID, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    p = await db.get(ModelProvider, provider_id)
+    if not p:
+        raise HTTPException(404, "Provider not found")
+
+    # Reject incomplete payloads with a 400 instead of letting a KeyError surface as a 500.
+    missing = [f for f in ("model_name", "model_type") if not payload.get(f)]
+    if missing:
+        raise HTTPException(400, f"Missing required fields: {', '.join(missing)}")
+
+    existing = await db.execute(
+        select(ModelInstance).where(
+            ModelInstance.provider_id == provider_id,
+            ModelInstance.model_name == payload["model_name"],
+        )
+    )
+    if existing.scalars().first():
+        raise HTTPException(400, "A model with the same name already exists")
+
+    m = ModelInstance(
+        provider_id=provider_id,
+        model_name=payload["model_name"],
+        model_type=payload["model_type"],
+        display_name=payload.get("display_name", payload["model_name"]),
+        capabilities=payload.get("capabilities", {}),
+        default_params=payload.get("default_params", {}),
+        is_default=payload.get("is_default", False),
+    )
+    db.add(m)
+    await db.commit()
+    return {"code": 200, "data": {"id": str(m.id)}}
+
+
+@router.put("/{provider_id}/models/{model_id}")
+async def update_model(provider_id: uuid.UUID, model_id: uuid.UUID, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    m = await db.get(ModelInstance, model_id)
+    # Compare UUID objects rather than strings so letter case in the URL cannot cause a false mismatch.
+    if not m or m.provider_id != provider_id:
+        raise HTTPException(404, "Model not found")
+
+    m.model_name = payload.get("model_name", m.model_name)
+    m.model_type = payload.get("model_type", m.model_type)
+    m.display_name = payload.get("display_name", m.display_name)
+    m.capabilities = payload.get("capabilities", m.capabilities)
+    m.default_params = payload.get("default_params", m.default_params)
+    m.is_default = payload.get("is_default", m.is_default)
+    m.is_active = payload.get("is_active", m.is_active)
+    await db.commit()
+    return {"code": 200, "data": {"id": str(m.id)}}
+
+
+@router.delete("/{provider_id}/models/{model_id}")
+async def delete_model(provider_id: uuid.UUID, model_id: uuid.UUID, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
+    m = await db.get(ModelInstance, model_id)
+    if not m or m.provider_id != provider_id:
+        raise HTTPException(404, "Model not found")
+    await db.delete(m)
+    await db.commit()
+    return {"code": 200, "message": "Deleted"}
\ No newline at end of file
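
Review note: `provider_id` and `model_id` arrive from the URL as arbitrary strings, and `uuid.UUID()` raises `ValueError` on malformed input. Unless that is caught, or the path parameter is typed as `uuid.UUID` so the framework validates it, a bad id would likely surface as a 500. A minimal stdlib sketch of the behavior follows; `parse_uuid` is a hypothetical helper, not part of this patch.

```python
import uuid
from typing import Optional


def parse_uuid(value: str) -> Optional[uuid.UUID]:
    """Return the parsed UUID, or None when the string is not a valid UUID."""
    try:
        return uuid.UUID(value)
    except ValueError:
        # uuid.UUID() rejects malformed strings with ValueError
        return None


bad = parse_uuid("not-a-uuid")
good = parse_uuid("12345678-1234-5678-1234-567812345678")
print(bad, good)
```

A handler could map the `None` case to a 404 or 422 response rather than letting the exception propagate.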
diff --git a/backend/schemas/__init__.py b/backend/schemas/__init__.py
index 62bd5d1..8fc3a32 100644
--- a/backend/schemas/__init__.py
+++ b/backend/schemas/__init__.py
@@ -275,6 +275,7 @@ class FlowDefinitionCreate(BaseModel):
trigger: dict = {}
nodes: list[FlowNode]
edges: list[FlowEdge]
+ flow_mode: str = "chatflow"
class FlowDefinitionUpdate(BaseModel):
@@ -295,6 +296,7 @@ class FlowDefinitionOut(BaseModel):
published_version_id: uuid.UUID | None = None
published_to_wecom: bool
published_to_web: bool = False
+ flow_mode: str = "chatflow"
created_at: datetime | None = None
updated_at: datetime | None = None
diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml
index 7ee4c75..d1570e0 100644
--- a/docker-compose.prod.yml
+++ b/docker-compose.prod.yml
@@ -3,6 +3,19 @@ services:
image: postgres:16-alpine
container_name: ent-postgres
restart: always
+    entrypoint: >
+      sh -c '
+        if [ ! -f /usr/local/share/postgresql/extension/vector.control ]; then
+          echo ">>> pgvector not installed; compiling and installing (first run only, about 30 seconds)..."
+          apk add --no-cache --virtual .pg_build git build-base postgresql16-dev clang llvm-dev > /dev/null 2>&1
+          wget -qO- https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz | tar xz -C /tmp
+          cd /tmp/pgvector-0.7.4 && make -j2 > /dev/null 2>&1 && make install > /dev/null 2>&1
+          cd / && rm -rf /tmp/pgvector-0.7.4
+          apk del .pg_build > /dev/null 2>&1
+          echo ">>> pgvector installation finished"
+        fi
+        exec docker-entrypoint.sh postgres
+      '
environment:
POSTGRES_USER: enterprise
POSTGRES_PASSWORD: enterprise123
diff --git a/docker-compose.yml b/docker-compose.yml
index ae8201e..9c2dfba 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -12,6 +12,19 @@ services:
image: postgres:16-alpine
container_name: ent-postgres
restart: always
+    entrypoint: >
+      sh -c '
+        if [ ! -f /usr/local/share/postgresql/extension/vector.control ]; then
+          echo ">>> pgvector not installed; compiling and installing (first run only, about 30 seconds)..."
+          apk add --no-cache --virtual .pg_build git build-base postgresql16-dev clang llvm-dev > /dev/null 2>&1
+          wget -qO- https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz | tar xz -C /tmp
+          cd /tmp/pgvector-0.7.4 && make -j2 > /dev/null 2>&1 && make install > /dev/null 2>&1
+          cd / && rm -rf /tmp/pgvector-0.7.4
+          apk del .pg_build > /dev/null 2>&1
+          echo ">>> pgvector installation finished"
+        fi
+        exec docker-entrypoint.sh postgres
+      '
environment:
POSTGRES_USER: enterprise
POSTGRES_PASSWORD: enterprise123
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
index f53f3ef..a200a17 100644
--- a/frontend/src/api/index.ts
+++ b/frontend/src/api/index.ts
@@ -89,6 +89,8 @@ export const flowApi = {
getMarket: () => api.get('/flow/market'),
getTemplates: () => api.get('/flow/templates'),
useTemplate: (id: string) => api.post(`/flow/templates/${id}/use`),
+ getVersions: (id: string) => api.get(`/flow/definitions/${id}/versions`),
+ rollbackVersion: (flowId: string, versionId: string) => api.post(`/flow/definitions/${flowId}/rollback/${versionId}`),
}
export const chatApi = {
@@ -191,4 +193,17 @@ export const ragApi = {
getDocuments: () => api.get('/rag/documents'),
deleteDocument: (source: string) => api.delete(`/rag/documents/${encodeURIComponent(source)}`),
getStats: () => api.get('/rag/stats'),
+}
+
+export const modelProviderApi = {
+ getProviders: () => api.get('/model-providers'),
+ getProvider: (id: string) => api.get(`/model-providers/${id}`),
+ createProvider: (data: any) => api.post('/model-providers', data),
+ updateProvider: (id: string, data: any) => api.put(`/model-providers/${id}`, data),
+ deleteProvider: (id: string) => api.delete(`/model-providers/${id}`),
+ getModels: (providerId: string) => api.get(`/model-providers/${providerId}/models`),
+ getAllModels: () => api.get('/model-providers/models/all'),
+ createModel: (providerId: string, data: any) => api.post(`/model-providers/${providerId}/models`, data),
+ updateModel: (providerId: string, modelId: string, data: any) => api.put(`/model-providers/${providerId}/models/${modelId}`, data),
+ deleteModel: (providerId: string, modelId: string) => api.delete(`/model-providers/${providerId}/models/${modelId}`),
}
\ No newline at end of file
diff --git a/frontend/src/components/layout/AdminLayout.vue b/frontend/src/components/layout/AdminLayout.vue
index bcd5c39..22f3cd5 100644
--- a/frontend/src/components/layout/AdminLayout.vue
+++ b/frontend/src/components/layout/AdminLayout.vue
@@ -52,6 +52,7 @@
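 The iteration node runs the downstream nodes once for each element of the input array.
+ Connect the iteration node to the nodes that should run in the loop; each iteration passes one array element as input.
+ The output is an array of objects holding the result of each iteration: [{"index":0, "item":..., "result":...}]
+Rewrite: turns a vague or incomplete question into a clear, specific one and fills in missing context automatically.
+Expand: turns a short question into a more detailed description, adding background and detail.
+Can pull the user persona, memory atoms, and recent conversation from the memory system automatically as optimization context.
+Common templates:
+Please help me handle: {"{{input}}"}
+ Based on {"{{rag_node.output}}"}, answer: {"{{input}}"}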
+ {"{"}"query": "{"{{input}}"}", "context": "{"{{trigger.data}}"}"{"}"}
+