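diff --git a/PLAN9.md b/PLAN9.md
new file mode 100644
index 0000000..56212ec
--- /dev/null
+++ b/PLAN9.md
@@ -0,0 +1,893 @@
+# PLAN9: Full Upgrade Plan (Memory Enhancement · Flow Refactoring · Model Management · Node Expansion)
+
+> Date: 2026-05-15
+> Status: planning; to be implemented after confirmation
+
+---
+
+## 1. Memory Management Upgrade: Adopting the Agent-Memory Layered Architecture
+
+### 1.1 Current State
+
+**Our current approach (implemented in PLAN8)**:
+- Redis-only storage: List for messages, Hash for metadata, String for summaries
+- Transparent middleware: memory is injected in `execute_flow` / `execute_flow_stream` in `flow_engine/router.py`
+- LLMNodeAgent.reply() reads the history summary and recent messages from `context["_memory_context"]`
+- Keys expire after a 7-day TTL; recording is asynchronous and does not block the response
+
+**The Agent-Memory project (ali-agentscope-src/Agent-Memory)**:
+- Implemented in TypeScript, runs as an OpenClaw plugin
+- Four-layer memory pyramid: L0 (raw conversation) → L1 (structured atoms) → L2 (scene blocks) → L3 (user persona)
+- Dual write: JSONL persistence + SQLite vector retrieval
+- Hybrid search: BM25 keywords + embedding vectors + RRF fusion
+- Context Offload: when a long task overflows the context, it is compressed symbolically with Mermaid
+- Pipeline scheduling: L1 extraction is triggered every N conversation turns; L2/L3 are triggered with a delay
+
+### 1.2 Comparison
+
+| Dimension | Our approach | Agent-Memory | Assessment |
+|------|---------|-------------|------|
+| **Storage** | Redis (in-memory only) | SQLite + JSONL (disk) | Ours is fast but unsuited to massive history; theirs is slower but can persist large volumes |
+| **Memory layering** | None (flat message list) | L0/L1/L2/L3 four-layer pyramid | **Largest gap**: we do no structured extraction; everything is raw messages |
+| **Retrieval** | Most recent N messages + summary | BM25 + vector + RRF hybrid search | We can only fetch recent messages by time, not by meaning |
+| **Deduplication** | None | batchDedup after L1 extraction (store/update/merge/skip) | We store the same information repeatedly |
+| **User persona** | Not implemented | L3 Persona, auto-generated and incrementally updated | They have a complete persona system |
+| **Context overflow** | Not handled | Context Offload + Mermaid compression | Our long conversations are simply truncated and lose information |
+| **Integration** | Native Python integration | TypeScript plugin, requires an OpenClaw/Hermes host | **Cannot be deployed directly**: language and architecture are incompatible |
+| **Performance** | Milliseconds (Redis) | Seconds (SQLite + LLM extraction) | Ours is fast but shallow; theirs is slow but deep |
+
+### 1.3 Conclusion
+
+**It cannot be deployed directly**: Agent-Memory is a TypeScript project that depends on an OpenClaw/Hermes host environment and is incompatible with our Python/FastAPI architecture.
+
+**We should adopt its core design ideas**, but the storage backend needs to be reselected; see the analysis in section 1.4.
+
+### 1.4 Choosing the Memory Storage Backend: Redis vs MongoDB vs PostgreSQL
+
+#### 1.4.1 Data Structures and Access Patterns of the Current Redis Approach
+
+```
+Current Redis key layout:
+  mem:{uid}:{fid}:{sid}:messages  → List (LPUSH to write, LRANGE to read the latest N)
+  mem:{uid}:{fid}:{sid}:meta      → Hash (HSET to write, HGETALL to read)
+  mem:{uid}:{fid}:{sid}:summary   → String (SETEX to write, GET to read)
+  mem:{uid}:sessions              → Set (SADD to write, SMEMBERS to read)
+
+Access patterns:
+  Write:  record_exchange()     → pipelined LPUSH + HSET + SADD (once per exchange)
+  Read:   inject_memory()       → LRANGE + GET (once per exchange, latency < 1ms)
+  Delete: delete_session()      → KEYS + DEL (infrequent, user-initiated)
+  List:   list_user_sessions()  → SMEMBERS + KEYS + HGETALL (infrequent, admin page)
+```
+
+#### 1.4.2 Three Databases Compared on Four Dimensions
+
+**Dimension 1: Memory footprint and performance** (figures are rough planning estimates)
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Write latency** | < 1ms (in-memory) | 2-5ms (journal + WiredTiger cache) | 3-8ms (WAL + fsync) |
+| **Read latency** | < 1ms | 1-3ms (WiredTiger cache hit) | 1-5ms (shared_buffers hit) |
+| **inject_memory latency** | ~1ms | ~5ms | ~8ms |
+| **Memory use (10k sessions × 40 messages)** | ~800MB (all in memory, uncompressed) | ~200MB (disk + hot-data cache) | ~150MB (disk + shared_buffers) |
+| **Memory use (1M sessions)** | ~80GB (impractical, needs sharding) | ~20GB (cache + cold data evicted to disk) | ~15GB (mostly disk, hot data cached) |
+| **Behavior at scale** | Bounded by RAM; exceeding it means OOM | Hot/cold separation in the cache; can support TB scale | Disk-based; can support TB scale and beyond |
+| **Vector search** | Needs Redis Stack (RediSearch); 100k vectors ~50ms | Native vector index (Atlas Vector Search); 100k vectors ~30ms | pgvector extension; 100k vectors ~40ms |
+| **Full-text search** | Needs Redis Stack (FTS module), non-standard | Native text index, mature | tsvector + GIN index, mature |
+
+**Key takeaways**:
+- Redis performs best at small scale (< 100k sessions), but its memory cost grows linearly
+- MongoDB and PostgreSQL do better at large scale because cold data lives on disk rather than in memory
+- Vector search performance is similar across all three, but Redis needs the extra Stack module
+
+**Dimension 2: Data model fit**
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Message list (L0)** | List (naturally ordered, but no structured queries) | Embedded document array or standalone documents (flexible) | Relational table + time index (standard) |
+| **Structured atoms (L1)** | Hash + Sorted Set (managed by hand) | Documents (native JSON, flexible schema) | JSONB column (flexible and indexable) |
+| **Scene blocks (L2)** | Hash (manual serialization) | Documents (nested structure expressed naturally) | JSONB or a dedicated table |
+| **User persona (L3)** | String (plain text, not queryable) | Documents (structured, queryable by field) | JSONB or a dedicated table (queryable by field) |
+| **Vector storage** | Needs Redis Stack (non-standard deployment) | Native support (Atlas Vector Search) | pgvector extension (mature) |
+| **Migration complexity** | Baseline (current approach) | Medium (new MongoDB connection + collection design) | **Low** (PostgreSQL connection pool already exists; reuse SQLAlchemy) |
+| **Schema evolution** | No schema (flexible but unconstrained) | No schema (flexible; validation can be added) | Schema (needs migration SQL, but type-safe) |
+
+**Key takeaways**:
+- MongoDB is the most natural fit for document-shaped memory data (native JSON storage, no ORM mapping)
+- PostgreSQL has the lowest migration cost (the project already uses SQLAlchemy + asyncpg and can reuse the existing pool)
+- Redis is the most natural fit for the L0 message list (LPUSH/LRANGE) but weak for structured L1/L2/L3 data
+
+**Dimension 3: Scalability and maintenance cost**
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Horizontal scaling** | Redis Cluster (complex sharding, cross-slot operations restricted) | Native sharding (automatic routing by shard key) | Read/write splitting + partitioned tables (Citus extension) |
+| **Operational complexity** | Low (single instance is simple, Cluster is complex) | Medium (replica sets + sharding need monitoring) | **Lowest** (the project already operates PostgreSQL) |
+| **New infrastructure** | None (already deployed) | **Requires a new MongoDB service** | **None** (reuses the existing PostgreSQL) |
+| **Backup and restore** | RDB/AOF (persistence policy must be configured) | mongodump/oplog | pg_dump/WAL archiving (**already in place**) |
+| **Monitoring** | redis-cli INFO | MongoDB Compass/Cloud Manager | pg_stat_statements (**already in place**) |
+| **Long-term maintenance cost** | High (memory keeps growing and must be scaled) | Medium | **Low** (disk is cheap to expand, operations are mature) |
+
+**Key takeaways**:
+- **PostgreSQL has the lowest operational cost**: the project already has a PostgreSQL instance, connection pool, and backup strategy, so no new infrastructure is needed
+- MongoDB would require deploying and maintaining an additional database service
+- Redis has the highest long-term cost: memory costs 10-50x as much as disk
+
+**Dimension 4: Transaction and consistency requirements**
+
+| Metric | Redis | MongoDB | PostgreSQL |
+|------|-------|---------|-----------|
+| **Transaction support** | Pipelines are not atomic; MULTI/EXEC runs its commands atomically but has no rollback | Multi-document transactions since 4.0 (replica sets) | Full ACID transactions |
+| **Transaction needs of memory data** | Low (a partially failed write is tolerable) | Low | Low |
+| **Transaction needs of model configuration** | Not applicable | Not applicable | High (must join with existing business tables) |
+| **Data consistency** | Eventual (asynchronous AOF fsync can lose ~1s of data) | Strong (replica set with w:majority) | Strong (synchronous WAL flush) |
+| **Risk of memory loss** | Data can be lost if the server crashes between AOF fsyncs | Replica set prevents loss | WAL prevents loss |
+
+**Key takeaways**:
+- Memory data has low transactional requirements (losing one message is acceptable), and all three databases meet them
+- PostgreSQL's strong consistency is still an advantage for business data such as model configuration
+- With asynchronous AOF, Redis has a ~1-second data-loss window, which is a risk for the "never forget" requirement
+
+#### 1.4.3 Overall Evaluation and Decision
+
+| Criterion | Redis | MongoDB | PostgreSQL | Weight |
+|---------|-------|---------|-----------|------|
+| Performance (small scale) | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ | 20% |
+| Performance (large scale) | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 25% |
+| Data model fit | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 20% |
+| Migration cost | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | 15% |
+| Operational cost | ⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | 10% |
+| Transactions/consistency | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | 10% |
+| **Weighted total** | **3.15** | **3.50** | **4.00** | - |
+
+#### 1.4.4 Decision: PostgreSQL + Redis Hybrid
+
+**PostgreSQL becomes the primary memory store; Redis is kept as the hot-data cache layer.**
+
+Reasons:
+1. **No new infrastructure**: the project already runs PostgreSQL (asyncpg + SQLAlchemy), so no new service has to be deployed
+2. **Lowest migration cost**: reuse the existing connection pool, ORM, and migration framework; only new tables are added
+3. **pgvector support**: vector retrieval needs no separate service; once the pgvector extension is installed on the server, `CREATE EXTENSION vector` enables it
+4. **JSONB is flexible and indexable**: structured L1/L2/L3 data is stored as JSONB, with GIN indexes for field-level queries
+5. **Strong consistency protects against memory loss**: the WAL is flushed synchronously, without the ~1-second loss window of Redis AOF
+6. 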
**Cost advantage at scale**: disk storage is 10-50x cheaper than memory; 1 million sessions need only ~15GB of disk
+
+**Roles Redis keeps**:
+- Hot-data cache: the 10 most recent messages are cached in Redis (inject_memory reads the cache directly, latency < 1ms)
+- Rate limiting: the existing cache_manager rate-limit feature
+- Session state: transient state of currently active sessions
+
+#### 1.4.5 PostgreSQL Memory Table Design
+
+```sql
+-- L0: raw conversation messages
+CREATE TABLE memory_messages (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID NOT NULL REFERENCES flow_definitions(id),
+    session_id UUID NOT NULL,
+    role VARCHAR(20) NOT NULL,           -- "user" / "assistant"
+    content TEXT NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX idx_memory_messages_session ON memory_messages(user_id, flow_id, session_id, created_at DESC);
+CREATE INDEX idx_memory_messages_user_flow ON memory_messages(user_id, flow_id, created_at DESC);
+
+-- L1: structured memory atoms
+CREATE TABLE memory_atoms (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID REFERENCES flow_definitions(id),  -- NULL means global
+    atom_type VARCHAR(20) NOT NULL,      -- "persona" / "episodic" / "instruction"
+    content TEXT NOT NULL,
+    priority SMALLINT DEFAULT 50,        -- 0-100, higher is more central
+    source_session_id UUID,              -- originating session
+    metadata JSONB DEFAULT '{}',         -- extension metadata
+    embedding vector(1536),              -- vector (requires the pgvector extension)
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX idx_memory_atoms_user ON memory_atoms(user_id, atom_type, priority DESC);
+CREATE INDEX idx_memory_atoms_embedding ON memory_atoms USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
+
+-- L2: scene blocks
+CREATE TABLE memory_scenes (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID REFERENCES flow_definitions(id),
+    scene_name VARCHAR(200) NOT NULL,
+    summary TEXT NOT NULL,
+    heat INTEGER DEFAULT 0,              -- heat (access frequency)
+    content JSONB DEFAULT '{}',          -- full scene content
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX idx_memory_scenes_user ON memory_scenes(user_id, flow_id, heat DESC);
+
+-- L3: user persona
+CREATE TABLE memory_personas (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL UNIQUE REFERENCES users(id),  -- one row per user
+    content JSONB NOT NULL DEFAULT '{}', -- structured persona
+    raw_text TEXT DEFAULT '',            -- raw persona text (injected into the LLM)
+    version INTEGER DEFAULT 1,
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+-- Session metadata
+CREATE TABLE memory_sessions (
+    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    user_id UUID NOT NULL REFERENCES users(id),
+    flow_id UUID NOT NULL REFERENCES flow_definitions(id),
+    session_id UUID NOT NULL,
+    flow_name VARCHAR(200) DEFAULT '',
+    message_count INTEGER DEFAULT 0,
+    last_active_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    UNIQUE(user_id, flow_id, session_id)
+);
+
+CREATE INDEX idx_memory_sessions_user ON memory_sessions(user_id, last_active_at DESC);
+```
+
+#### 1.4.6 MemoryManager Refactoring
+
+```python
+import asyncio
+from datetime import datetime
+
+from redis.asyncio import Redis
+
+from models import MemoryMessage  # ORM model for the memory_messages table
+
+
+class MemoryManager:
+    def __init__(self, db_session_factory, redis: Redis):
+        self.db_session_factory = db_session_factory  # AsyncSessionLocal
+        self.redis = redis                             # hot-data cache
+
+    async def inject_memory(self, user_id, flow_id, session_id, context):
+        # 1. Check the Redis cache first (10 most recent messages)
+        cached = await self._get_cached_messages(user_id, flow_id, session_id)
+        if cached:
+            recent_messages = cached
+        else:
+            # 2. Cache miss: query PostgreSQL
+            recent_messages = await self._query_recent_messages(user_id, flow_id, session_id)
+            # 3. Backfill the Redis cache
+            await self._cache_messages(user_id, flow_id, session_id, recent_messages)
+
+        # 4. Query L1 atoms (PostgreSQL)
+        atoms = await self._query_relevant_atoms(user_id, flow_id, context.get("input", ""))
+
+        # 5. Query the L3 persona (PostgreSQL)
+        persona = await self._query_persona(user_id)
+
+        context["_memory_context"] = {
+            "recent_messages": recent_messages,
+            "atoms": atoms,
+            "persona": persona,
+            "session_id": session_id,
+        }
+
+    async def record_exchange(self, user_id, flow_id, session_id, user_msg, assistant_msg, flow_name=""):
+        # Both rows share one timestamp, so created_at alone does not order them.
+        ts = datetime.utcnow()
+
+        # 1. Write to PostgreSQL (primary store, guarantees persistence)
+        async with self.db_session_factory() as session:
+            session.add(MemoryMessage(
+                user_id=user_id, flow_id=flow_id, session_id=session_id,
+                role="user", content=user_msg, created_at=ts,
+            ))
+            session.add(MemoryMessage(
+                user_id=user_id, flow_id=flow_id, session_id=session_id,
+                role="assistant", content=assistant_msg, created_at=ts,
+            ))
+            await session.commit()
+
+        # 2. Update the Redis cache (hot data)
+        await self._cache_append_message(user_id, flow_id, session_id, [
+            {"role": "user", "content": user_msg, "ts": ts.isoformat()},
+            {"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()},
+        ])
+
+        # 3. Trigger L1 extraction in the background
+        #    (asyncio only holds a weak reference; keep one if the task must not be collected)
+        asyncio.create_task(self._maybe_extract_atoms(user_id, flow_id, session_id))
+```
+
+#### 1.4.7 Migration Steps
+
+```
+Step 1: Create the PostgreSQL memory tables (init-db/03-memory-tables.sql)
+Step 2: Enable the pgvector extension (CREATE EXTENSION IF NOT EXISTS vector in init-db/03-memory-tables.sql)
+Step 3: Refactor MemoryManager: PostgreSQL as the primary store, Redis as the cache
+Step 4: Data migration script: Redis → PostgreSQL (one-off; reads existing memory data from Redis and writes it to PG)
+Step 5: Verify the full path: inject_memory / record_exchange / delete_session
+Step 6: Once verified, give the memory keys in Redis a short TTL and let them expire
+```
+
+#### 1.4.8 Migration SQL File
+
+See `init-db/03-memory-tables.sql` (created during implementation), which contains:
+- `CREATE EXTENSION IF NOT EXISTS vector`
+- CREATE TABLE + INDEX statements for the 5 tables above
+- Idempotent execution (IF NOT EXISTS)
+
+---
+
+## 2. Flow Management Overhaul (Benchmarked Against Dify)
+
+### 2.1 Current Problems
+
+1. **Too few node types**: we have 11 (trigger/llm/tool/mcp/notify/condition/rag/loop/merge/code/output); Dify has 15+
+2. **Memory is not a default behavior**: PLAN8 implemented the transparent middleware, but there is no Dify-style distinction between conversational flows and workflows
+3. **Primitive variable system**: only the `{{node_id.output}}` template, with no types, scopes, or aggregation
+4. 
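**Incomplete version management**: FlowVersion snapshots exist, but a full draft/publish separation workflow is missing
+
+### 2.2 Flow Types (Chatflow vs Workflow)
+
+**Dify's core design**: two application modes with different memory behavior
+
+| Mode | Memory | Typical scenarios | Our equivalent |
+|------|------|---------|-----------|
+| **Chatflow (conversational)** | Conversation history is maintained automatically; LLM nodes get context injected automatically | Customer service, assistants, Q&A | Flows used by FlowChat.vue |
+| **Workflow** | No automatic memory; every run is independent | Data processing, batch jobs, API calls | Flows invoked through the API gateway |
+
+**Proposed change**:
+
+```python
+# New field on FlowDefinition
+flow_mode = Column(String(20), default="chatflow")  # "chatflow" | "workflow"
+
+# Adjusted memory-middleware logic
+if f.flow_mode == "chatflow":
+    await mm.inject_memory(...)                   # inject memory
+    asyncio.create_task(mm.record_exchange(...))  # record the exchange
+else:
+    pass  # workflow mode: no memory injection
+```
+
+The frontend FlowEditor.vue gains a flow-type selector (chosen when the flow is created, immutable afterwards).
+
+### 2.3 Memory Enabled by Default
+
+**Core principle**: for every Chatflow-type flow, memory management is the default behavior; users do not need to add a node for it.
+
+```
+User input → router.py
+  │
+  ├─ Chatflow mode:
+  │   ├─ Before execution: inject_memory() → retrieve history + persona → inject into context
+  │   ├─ LLM call: LLMNodeAgent reads _memory_context automatically
+  │   └─ After execution: record_exchange() → asynchronous recording + L1 extraction
+  │
+  └─ Workflow mode:
+      └─ Execute directly, no memory injection
+```
+
+### 2.4 Enhanced Flow Creation Process
+
+**Current**: create flow → edit canvas → publish
+
+**After the upgrade**:
+```
+Create flow
+  ├─ Choose the flow type (Chatflow / Workflow)
+  ├─ Choose a template (optional, from the template market)
+  └─ Enter the editor
+
+Edit canvas
+  ├─ Drag in nodes
+  ├─ Configure node parameters
+  └─ Connect edges
+
+Publish
+  ├─ Topology check (no isolated nodes; start and end nodes exist)
+  ├─ Required-parameter validation
+  ├─ Create a version snapshot (immutable)
+  └─ Update the published_version_id pointer
+```
+
+### 2.5 Clarifying Flow Market vs Flow List
+
+**Current problem**: the "published workflows" list in the Flow Market overlaps with the top-level "Flow List" menu, which confuses users.
+
+**Proposed split**:
+
+| Page | Purpose | Data source | Actions |
+|------|------|---------|------|
+| **Flow List** (admin) | All flows I created or manage | `GET /api/flow/definitions` | Edit, delete, publish/unpublish, list on the market |
+| **Flow Market** (user) | Public flows listed on the market | `GET /api/flow/market` | View details, install/use, rate |
+| **Template Center** (at creation) | Built-in and community-contributed templates | `GET /api/flow/templates` | Create a new flow from a template in one click |
+
+**Key distinctions**:
+- Flow List = private workspace; shows the flows you manage
+- Flow Market = public store; shows flows others have published to the market
+- Template Center = quick start; the entry point when creating a new flow
+
+---
+
+## 3. Agent Management Refactoring: OpenAI-API-Compatible Model Management
+
+### 3.1 Current Problems
+
+The current `AgentConfig` model has a single `model` field (String). It cannot distinguish model types, manage Embedding/Rerank models, or support multiple providers.
+
+```python
+# Current AgentConfig
+model = Column(String(50), default="gpt-4o-mini")  # can only hold one model name
+```
+
+### 3.2 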
Dify's Model Provider Architecture
+
+Dify uses a three-layer architecture:
+
+```
+ModelProvider (provider abstraction layer)
+  ├── Model types: LLM / Embedding / Rerank / TTS / Speech-to-Text
+  ├── Provider instances: OpenAI / Anthropic / ZhipuAI / Ollama / OpenAI-API-Compatible
+  └── Model instances: gpt-4o / text-embedding-3-small / cohere-rerank
+```
+
+**OpenAI-API-Compatible mode**: any service that implements the OpenAI API format is plug-and-play; only base_url + api_key + model_name need to be configured.
+
+### 3.3 Proposed Changes
+
+#### Data Models
+
+```python
+# New: model provider table
+class ModelProvider(Base):
+    __tablename__ = "model_providers"
+
+    id = Column(UUID, primary_key=True, default=uuid.uuid4)
+    name = Column(String(100), nullable=False)           # "OpenAI" / "ZhipuAI" / "Local Ollama"
+    provider_type = Column(String(50), nullable=False)   # "openai" / "zhipu" / "ollama" / "openai_compatible"
+    base_url = Column(String(500))                       # API endpoint
+    api_key = Column(Text)                               # stored encrypted
+    extra_config = Column(JSON, default=dict)            # provider-specific settings
+    is_active = Column(Boolean, default=True)
+    created_at = Column(DateTime, default=datetime.utcnow)
+
+
+# New: model instance table
+class ModelInstance(Base):
+    __tablename__ = "model_instances"
+
+    id = Column(UUID, primary_key=True, default=uuid.uuid4)
+    provider_id = Column(UUID, ForeignKey("model_providers.id"))
+    model_name = Column(String(100), nullable=False)     # "gpt-4o" / "embedding-3" / "bge-rerank"
+    model_type = Column(String(30), nullable=False)      # "llm" / "embedding" / "rerank"
+    display_name = Column(String(200))                   # "GPT-4o" / "Text Embedding v3"
+    capabilities = Column(JSON, default=dict)            # {"vision": true, "function_calling": true, "max_tokens": 128000}
+    default_params = Column(JSON, default=dict)          # {"temperature": 0.7, "top_p": 1.0}
+    is_default = Column(Boolean, default=False)          # whether this is the default model for its type
+    is_active = Column(Boolean, default=True)
+    created_at = Column(DateTime, default=datetime.utcnow)
+```
+
+#### Configuration Parameters by Model Type
+
+| Model type | Common parameters | Type-specific parameters |
+|---------|---------|---------|
+| **LLM** | model, temperature, top_p, max_tokens, stream | vision (multimodal), function_calling, response_format |
+| **Embedding** | model, dimensions | encoding_format, input_type |
+| **Rerank** | model, top_n | query, documents, return_documents |
+
+#### API Endpoint Design
+
+```
+# Provider management
+POST   /api/model-providers/                 # add a provider
+GET    /api/model-providers/                 # list providers
+PUT    /api/model-providers/{id}             # update a provider
+DELETE /api/model-providers/{id}             # delete a provider
+POST   /api/model-providers/{id}/test        # test connectivity
+
+# Model instance management
+POST   /api/model-providers/{id}/models/     # add a model
+GET    /api/model-instances/                 # list all models (supports ?type=llm filtering)
+PUT    /api/model-instances/{id}             # update a model
+DELETE /api/model-instances/{id}             # delete a model
+POST   /api/model-instances/{id}/test        # test a model call
+
+# Default model settings
+PUT    /api/model-defaults/                  # set the default model per type
+GET    /api/model-defaults/                  # get the default model per type
+```
+
+#### Frontend Page
+
+New `ModelProviderManager.vue` (admin):
+- Provider list (cards showing name, type, status, and model count)
+- Add-provider form (choose a type → enter base_url/api_key → auto-detect available models)
+- Model instance list (tabs grouped by type: LLM / Embedding / Rerank)
+- Each model can set default parameters and whether it is the default model
+- Test button: sends a test request to verify connectivity
+
+#### AgentConfig Association Changes
+
+```python
+# AgentConfig changes
+class AgentConfig(Base):
+    # ...existing fields kept...
+    model = Column(String(50), default="gpt-4o-mini")  # kept for compatibility
+    model_instance_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)   # new: linked model instance
+    embedding_model_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)  # new: linked embedding model
+```
+
+#### Engine Layer Adaptation
+
+```python
+# LLMNodeAgent changes in engine.py
+class LLMNodeAgent(AgentBase):
+    async def reply(self, msg, **kwargs):
+        context = kwargs.get("context", {})
+        config = self.config
+
+        # Prefer model_instance_id when resolving the model configuration
+        model_instance_id = config.get("model_instance_id")
+        if model_instance_id:
+            provider, model = await self._resolve_model(model_instance_id)
+            base_url = provider.base_url
+            api_key = provider.api_key
+            model_name = model.model_name
+        else:
+            # Fall back to the legacy logic (same global settings as section 7.1)
+            base_url = settings.LLM_API_BASE
+            api_key = settings.LLM_API_KEY
+            model_name = config.get("model", "gpt-4o-mini")
+```
+
+### 3.4 Implementation Priority
+
+| Step | Content | Priority |
+|------|------|--------|
+| Step 1 | Create the ModelProvider + ModelInstance data models + migration SQL | **P0** |
+| Step 2 | Implement the provider/model CRUD API | **P0** |
+| Step 3 | Frontend ModelProviderManager.vue page | **P0** |
+| Step 4 | Adapt LLMNodeAgent to model_instance_id | P1 |
+| Step 5 | Adapt RAGNodeAgent to embedding_model_id | P1 |
+| Step 6 | Link AgentConfig to model_instance_id | P1 |
+| Step 7 | Provider auto-detection of available models | P2 |
+
+---
+
+## 4. Flow Editor Node Expansion (Benchmarked Against Dify)
+
+### 4.1 Existing Nodes vs Dify Nodes
+
+| Node | Ours | Dify | Gap |
+|------|:------:|:-------:|---------|
+| Trigger/start | ✅ trigger | ✅ start | We additionally have a WeCom trigger; Dify only has input variables |
+| LLM | ✅ llm | ✅ llm | Roughly equivalent |
+| Tool call | ✅ tool | ✅ tool | Roughly equivalent |
+| MCP | ✅ mcp | ❌ | Unique to us |
+| Notification | ✅ notify | ❌ | Unique to us (WeCom notifications) |
+| Conditional branch | ✅ condition | ✅ if-else | Roughly equivalent |
+| RAG retrieval | ✅ rag | ✅ knowledge-retrieval | Roughly equivalent |
+| Loop | ✅ loop | ✅ iteration | Not equivalent: Dify iterates over arrays, ours is a general loop (see the Iteration row) |
+| Variable aggregation | ✅ merge | ✅ variable-aggregator | Roughly equivalent |
+| Code execution | ✅ code | ✅ code | Roughly equivalent |
+| Output/end | ✅ output | ✅ end | Roughly equivalent |
+| **HTTP request** | ❌ | ✅ http-request | **Missing**: calling external APIs |
+| **Question classifier** | ❌ | ✅ question-classifier | **Missing**: intent routing |
+| **Template transform** | ❌ | ✅ template-transform | **Missing**: Jinja2 formatting |
+| **Variable assignment** | ❌ | ✅ variable-assigner | **Missing**: runtime variable operations |
+| **Iteration** | ❌ | ✅ iteration | **Missing**: per-item array processing (different from loop) |
+| **Question optimization** | ❌ | ✅ question-optimiser | **Missing**: query rewriting before retrieval |
+
+### 4.2 New Nodes
+
+#### P0: Must Add (High Usage)
+
+**1. HTTP request node (http_request)**
+
+```
+Function: send HTTP requests (GET/POST/PUT/DELETE/PATCH)
+Configuration:
+  - method: request method
+  - url: request URL (supports variable templates)
+  - headers: request headers (JSON)
+  - body: request body (supports variable templates)
+  - auth_type: authentication type (none/api_key/bearer/basic/oauth2)
+  - auth_config: authentication settings
+  - timeout: timeout (seconds)
+  - retry_count: number of retries
+Output:
+  - status_code: status code
+  - headers: response headers
+  - body: response body (parsed JSON)
+  - raw: raw text
+Frontend: HttpRequestConfig.vue
+```
+
+**2. Question classifier node (question_classifier)**
+
+```
+Function: classify the intent of user input with an LLM and route to different branches
+Configuration:
+  - model: LLM to use (optional, defaults to the system default)
+  - categories: category list [{name, description}]
+  - instruction: classification instruction (supplementary notes)
+Output:
+  - category: classification result
+  - confidence: confidence score
+  - Multiple exits: one exit per category
+Frontend: QuestionClassifierConfig.vue
+Engine: calls the LLM to classify, then follows the matching branch (like condition, but LLM-based)
+```
+
+**3. 
Variable assigner node (variable_assigner)**
+
+```
+Function: assign or modify variables at runtime
+Configuration:
+  - assignments: [{target_var, source_type, source_value}]
+  - source_type: "constant" / "upstream_output" / "template" / "expression"
+Output:
+  - variable values are updated in context
+Frontend: VariableAssignerConfig.vue
+```
+
+#### P1: Recommended (Better Experience)
+
+**4. Template transform node (template_transform)**
+
+```
+Function: format, concatenate, or convert variables with Jinja2 template syntax
+Configuration:
+  - template: Jinja2 template string
+  - output_type: output type (string/json/array)
+Output:
+  - rendered: rendered text
+Frontend: TemplateTransformConfig.vue
+Engine: renders the template with the jinja2 library
+```
+
+**5. Iteration node (iteration)**
+
+```
+Function: process a list/array item by item; a sub-workflow can be nested inside
+Difference from loop:
+  - loop: fixed count or conditional loop, repeating the same logic
+  - iteration: walks an array, handles one element per pass, and outputs a result array
+Configuration:
+  - input_array: input array (variable reference)
+  - output_variable: output variable name for each pass
+  - max_iterations: maximum number of iterations
+Output:
+  - output: result array
+Frontend: IterationConfig.vue
+```
+
+**6. Question optimiser node (question_optimiser)**
+
+```
+Function: rewrite or expand the user query before RAG retrieval to improve recall
+Configuration:
+  - model: LLM to use
+  - strategy: optimization strategy (rewrite/expand/decompose)
+  - instruction: custom optimization instruction
+Output:
+  - optimized_query: optimized query
+  - original_query: original query
+Frontend: QuestionOptimiserConfig.vue
+```
+
+### 4.3 Frontend Configuration Components for New Nodes
+
+| New node | Config component | Complexity |
+|---------|---------|--------|
+| http_request | HttpRequestConfig.vue | Medium |
+| question_classifier | QuestionClassifierConfig.vue | Medium |
+| variable_assigner | VariableAssignerConfig.vue | Low |
+| template_transform | TemplateTransformConfig.vue | Low |
+| iteration | IterationConfig.vue | Medium |
+| question_optimiser | QuestionOptimiserConfig.vue | Low |
+
+### 4.4 Engine Layer Changes
+
+```python
+# New branches in engine.py _create_node_agent()
+elif node_type == "http_request":
+    return HttpRequestNodeAgent(...)
+elif node_type == "question_classifier":
+    return QuestionClassifierNodeAgent(...)
+elif node_type == "variable_assigner":
+    return VariableAssignerNodeAgent(...)
+elif node_type == "template_transform":
+    return TemplateTransformNodeAgent(...)
+elif node_type == "iteration":
+    return IterationNodeAgent(...)
+elif node_type == "question_optimiser":
+    return QuestionOptimiserNodeAgent(...)
+```
+
+---
+
+## 5. Version Management and Template System Improvements
+
+### 5.1 Current State of Version Management
+
+- ✅ The FlowVersion model exists
+- ✅ The publish_flow / unpublish_flow / rollback_flow APIs are implemented
+- ✅ Execution loads the definition_json of published_version
+- ❌ No frontend version management UI (version list, diff, rollback button)
+- ❌ No integrity validation before publishing
+- ❌ No draft autosave
+
+### 5.2 Improvements
+
+**Frontend version management**:
+- Add a "Version history" button to the FlowEditor.vue toolbar
+- Version history dialog: version list (version number, publish time, publisher, change notes)
+- View the definition JSON of a historical version
+- Roll back to a chosen version
+- Prompt for "change notes" when publishing
+
+**Pre-publish validation**:
+```python
+async def _validate_before_publish(definition: dict, flow_mode: str = "chatflow") -> list[str]:
+    errors = []
+    nodes = definition.get("nodes", [])
+    edges = definition.get("edges", [])
+
+    # 1. A start node is required
+    if not any(n.get("type") == "trigger" for n in nodes):
+        errors.append("Missing trigger/start node")
+
+    # 2. A Chatflow must contain an LLM node
+    if flow_mode == "chatflow" and not any(n.get("type") == "llm" for n in nodes):
+        errors.append("A conversational flow must contain at least one LLM node")
+
+    # 3. No isolated nodes
+    connected_ids = set()
+    for e in edges:
+        connected_ids.add(e["source"])
+        connected_ids.add(e["target"])
+    for n in nodes:
+        if n["id"] not in connected_ids and len(nodes) > 1:
+            errors.append(f"Node '{n.get('label', n['id'])}' is not connected")
+
+    return errors
+```
+
+**Draft autosave**:
+- FlowEditor.vue saves the draft to `draft_definition_json` every 30 seconds
+- When switching flows, detect unsaved drafts and prompt the user
+
+### 5.3 Template System Enhancements
+
+**Current**: 2 hard-coded templates (document processing flow, WeCom notification flow)
+
+**After the upgrade**:
+- Templates are stored in the database (new `flow_templates` table)
+- Administrators can create templates ("save as template" from an existing flow)
+- Template categories: customer-service chat / document processing / data analysis / notifications / custom
+- A template picker is shown when creating a flow (optional; starting from blank is also possible)
+
+---
+
+## 6. Implementation Roadmap
+
+### Phase 1: Foundation Upgrade (P0, 2-3 weeks)
+
+| Task | Files involved | Depends on |
+|------|---------|------|
+| 1.1 Create the PostgreSQL memory tables + pgvector extension | init-db/03-memory-tables.sql, models/\_\_init\_\_.py | None |
+| 1.2 Refactor MemoryManager: PG primary store + Redis cache | memory/manager.py | 1.1 |
+| 1.3 Redis → PostgreSQL data migration script | scripts/migrate_memory_redis_to_pg.py | 1.2 |
+| 1.4 Create the ModelProvider + ModelInstance data models | models/\_\_init\_\_.py, init-db/04-model-provider.sql | None |
+| 1.5 Implement the provider/model CRUD API | modules/model_provider/router.py | 1.4 |
+| 1.6 Frontend ModelProviderManager.vue | frontend/src/views/model/ | 1.5 |
+| 1.7 Add the flow_mode field to FlowDefinition | models/\_\_init\_\_.py, init-db/ | None |
+| 1.8 Make the memory middleware depend on flow_mode | flow_engine/router.py | 1.7 |
+| 1.9 Add the HTTP request node | engine.py, HttpRequestConfig.vue | None |
+| 1.10 Add the question classifier node | engine.py, QuestionClassifierConfig.vue | None |
+| 1.11 Add the variable assigner node | engine.py, VariableAssignerConfig.vue | None |
+
+### Phase 2: Memory Layering + Model Adaptation (P1, 2-3 weeks)
+
+| Task | Files involved | Depends on |
+|------|---------|------|
+| 2.1 Add L1 structured extraction to MemoryManager (stored in PG) | memory/manager.py | Phase 1 |
+| 2.2 Add L1 deduplication to MemoryManager | memory/manager.py | 2.1 |
+| 2.3 Adapt LLMNodeAgent to model_instance_id | engine.py | 1.5 |
+| 2.4 Adapt RAGNodeAgent to embedding_model_id | engine.py | 1.5 |
+| 2.5 Link AgentConfig to model_instance_id (backward compatible) | models/\_\_init\_\_.py, agent_manager/ | 1.5 |
+| 2.6 Add the template transform node | engine.py, TemplateTransformConfig.vue | None |
+| 2.7 Add the iteration node | engine.py, IterationConfig.vue | None |
+| 2.8 Pre-publish validation + version management frontend | flow_engine/router.py, FlowEditor.vue | None |
+
+### Phase 3: Persona + Retrieval Optimization (P2, 2-3 weeks)
+
+| Task | Files involved | Depends on |
+|------|---------|------|
+| 3.1 L2 scene block extraction (stored in PG) | memory/manager.py | Phase 2 |
+| 3.2 L3 user persona generation (stored in PG) | memory/manager.py | 3.1 |
+| 3.3 Hybrid retrieval (pgvector vectors + tsvector full text + RRF) | memory/manager.py | Embedding model |
+| 3.4 Add the question optimiser node | engine.py, QuestionOptimiserConfig.vue | 3.3 |
+| 3.5 Move the template system into the database | flow_templates table, router.py | None |
+| 3.6 Draft autosave | FlowEditor.vue, router.py | None |
+
+---
+
+## 7. Backward Compatibility
+
+### 7.1 AgentConfig Backward Compatibility
+
+```python
+# AgentConfig changes: keep the model field, add optional fields
+class AgentConfig(Base):
+    # Existing fields stay unchanged
+    model = Column(String(50), default="gpt-4o-mini")  # kept! old data stays compatible
+
+    # New optional fields (nullable=True, so old rows are NULL automatically)
+    model_instance_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)
+    embedding_model_id = Column(UUID, ForeignKey("model_instances.id"), nullable=True)
+```
+
+**Compatibility logic**:
+```python
+# How LLMNodeAgent in engine.py obtains its model configuration
+async def _resolve_model_config(self, config: dict) -> dict:
+    model_instance_id = config.get("model_instance_id")
+    if model_instance_id:
+        # New path: full configuration from ModelInstance and its provider
+        provider, model = await self._resolve_model(model_instance_id)
+        return {
+            "base_url": provider.base_url,
+            "api_key": provider.api_key,
+            "model": model.model_name,
+            "params": model.default_params,
+        }
+    else:
+        # Legacy path: fall back to AgentConfig.model + global settings
+        return {
+            "base_url": settings.LLM_API_BASE,
+            "api_key": settings.LLM_API_KEY,
+            "model": config.get("model", "gpt-4o-mini"),
+            "params": {},
+        }
+```
+
+### 7.2 Memory Storage Migration Compatibility
+
+```python
+# MemoryManager supports both Redis and PostgreSQL
+class MemoryManager:
+    def __init__(self, db_session_factory, redis: Redis):
+        self.db_session_factory = db_session_factory
+        self.redis = redis
+
+    async def inject_memory(self, user_id, flow_id, session_id, context):
+        # Read from the Redis cache first (compatible with old data)
+        cached = await self._get_cached_messages(user_id, flow_id, session_id)
+        if cached:
+            recent_messages = cached
+        else:
+            # Cache miss: query PostgreSQL
+            recent_messages = await self._query_recent_messages(user_id, flow_id, session_id)
+            if recent_messages:
+                # Backfill the Redis cache
+                await self._cache_messages(user_id, flow_id, session_id, recent_messages)
+
+        # ...remaining logic
+```
+
+### 7.3 Database Change Management Rules
+
+Every change that touches table structure must:
+1. Ship a matching SQL migration file in the `init-db/` directory
+2. Use `IF NOT EXISTS` / `IF EXISTS` in the SQL so that it is idempotent
+3. Give new columns a `DEFAULT` value or make them nullable, so old data stays compatible
+4. Name migration files in sequence: `01-init.sql` → `02-add-published-cols.sql` → `03-memory-tables.sql` → `04-model-provider.sql`
+
+---
+
+## 8. Risks and Caveats
+
+1. **Installing the pgvector extension**: confirm whether the PostgreSQL in the Docker image ships pgvector. If not, add `RUN apt-get update && apt-get install -y postgresql-15-pgvector` to the Dockerfile or use the `pgvector/pgvector:pg15` image
+2. **Redis → PG data migration**: writes must be paused briefly during migration, so run it off-peak. The script must map the JSON serialization used in Redis to the PG table structure
+3. **LLM call cost**: L1 extraction and deduplication both need extra LLM calls; use a low-cost model (such as gpt-4o-mini) and limit how often they trigger
+4. **Provider API key security**: keys must be stored encrypted, never in plaintext in the database. Symmetric encryption with `cryptography.fernet` is suggested
+5. **Backward compatibility**: the AgentConfig.model field is kept and model_instance_id is optional (nullable=True), so old data works without migration
+6. **Node type registration**: a new node must be added to both the frontend nodeTypes array and the backend _create_node_agent() branches; neither can be skipped
+7. **Flow type is immutable**: flow_mode is fixed at creation and cannot be changed later (the engine logic for Chatflow and Workflow differs substantially)
+8. 
**PostgreSQL 连接池**:记忆表查询复用现有连接池(pool_size=20),需监控连接数是否够用,必要时调整 +9. **Redis 缓存一致性**:PG 写入成功但 Redis 缓存更新失败时,下次读取会缓存未命中走 PG 查询,自动修复,无需额外处理 diff --git a/backend/main.py b/backend/main.py index 723d341..1edb916 100644 --- a/backend/main.py +++ b/backend/main.py @@ -19,22 +19,23 @@ from modules.rag.router import router as rag_router from modules.chat.router import router as chat_router from modules.custom_tool.router import router as custom_tool_router from modules.memory.router import router as memory_router -from modules.memory.manager import init_memory_manager +from modules.memory.manager import init_memory_manager, get_memory_manager +from modules.model_provider.router import router as model_provider_router from websocket_manager import ws_manager from middleware.rbac_middleware import rbac_middleware from middleware.rate_limiter import rate_limit_middleware from middleware.cache_manager import cache_manager +from database import AsyncSessionLocal @asynccontextmanager async def lifespan(app: AgentApp): await init_db() await cache_manager.connect() - await init_memory_manager() + await init_memory_manager(AsyncSessionLocal) yield await cache_manager.disconnect() try: - from modules.memory.manager import get_memory_manager mm = get_memory_manager() await mm.redis.close() except Exception: @@ -70,4 +71,5 @@ app.include_router(system_router) app.include_router(rag_router) app.include_router(chat_router) app.include_router(custom_tool_router) -app.include_router(memory_router) \ No newline at end of file +app.include_router(memory_router) +app.include_router(model_provider_router) \ No newline at end of file diff --git a/backend/models/__init__.py b/backend/models/__init__.py index dffc3d7..6ff3fea 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -141,6 +141,7 @@ class FlowDefinition(Base): published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True) draft_definition_json = Column(JSON, nullable=True, 
default=None) creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + flow_mode = Column(String(20), default="chatflow") published_to_wecom = Column(Boolean, default=False) published_to_web = Column(Boolean, default=False) created_at = Column(DateTime, default=datetime.utcnow) @@ -176,6 +177,23 @@ class FlowApiKey(Base): created_at = Column(DateTime, default=datetime.utcnow) +class FlowTemplate(Base): + __tablename__ = "flow_templates" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(200), nullable=False) + description = Column(Text, default="") + category = Column(String(50), default="") + definition_json = Column(JSON, nullable=False, default=dict) + icon = Column(String(50), default="") + sort_order = Column(Integer, default=0) + is_builtin = Column(Boolean, default=False) + usage_count = Column(Integer, default=0) + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + class CustomTool(Base): __tablename__ = "custom_tools" @@ -261,6 +279,8 @@ class AgentConfig(Base): description = Column(String(500)) system_prompt = Column(Text, default="") model = Column(String(50), default="gpt-4o-mini") + model_instance_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True) + embedding_model_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True) temperature = Column(Float, default=0.7) tools = Column(JSON, default=list) status = Column(String(20), default="active") @@ -279,4 +299,97 @@ class AuditLog(Base): resource_id = Column(String(100)) detail = Column(JSON, default=dict) ip_address = Column(String(50)) + created_at = Column(DateTime, default=datetime.utcnow) + + +class MemoryMessage(Base): + __tablename__ = "memory_messages" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = 
Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) + session_id = Column(UUID(as_uuid=True), nullable=False) + role = Column(String(20), nullable=False) + content = Column(Text, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + + +class MemoryAtom(Base): + __tablename__ = "memory_atoms" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True) + atom_type = Column(String(20), nullable=False) + content = Column(Text, nullable=False) + priority = Column(Integer, default=50) + source_session_id = Column(UUID(as_uuid=True), nullable=True) + metadata = Column(JSON, default=dict) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + +class MemoryScene(Base): + __tablename__ = "memory_scenes" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True) + scene_name = Column(String(200), nullable=False) + summary = Column(Text, nullable=False) + heat = Column(Integer, default=0) + content = Column(JSON, default=dict) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + +class MemoryPersona(Base): + __tablename__ = "memory_personas" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), 
nullable=False, unique=True) + content = Column(JSON, default=dict, nullable=False) + raw_text = Column(Text, default="") + version = Column(Integer, default=1) + updated_at = Column(DateTime, default=datetime.utcnow) + + +class MemorySession(Base): + __tablename__ = "memory_sessions" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) + session_id = Column(UUID(as_uuid=True), nullable=False) + flow_name = Column(String(200), default="") + message_count = Column(Integer, default=0) + last_active_at = Column(DateTime, default=datetime.utcnow) + created_at = Column(DateTime, default=datetime.utcnow) + + +class ModelProvider(Base): + __tablename__ = "model_providers" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(100), nullable=False) + provider_type = Column(String(50), nullable=False) + base_url = Column(String(500)) + api_key = Column(Text) + extra_config = Column(JSON, default=dict) + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, default=datetime.utcnow) + + +class ModelInstance(Base): + __tablename__ = "model_instances" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + provider_id = Column(UUID(as_uuid=True), ForeignKey("model_providers.id", ondelete="CASCADE")) + model_name = Column(String(100), nullable=False) + model_type = Column(String(30), nullable=False) + display_name = Column(String(200)) + capabilities = Column(JSON, default=dict) + default_params = Column(JSON, default=dict) + is_default = Column(Boolean, default=False) + is_active = Column(Boolean, default=True) created_at = Column(DateTime, default=datetime.utcnow) \ No newline at end of file diff --git a/backend/modules/flow_engine/engine.py 
b/backend/modules/flow_engine/engine.py index 2106f25..bda58d7 100644 --- a/backend/modules/flow_engine/engine.py +++ b/backend/modules/flow_engine/engine.py @@ -3,6 +3,7 @@ import uuid import logging import re import asyncio +import httpx from agentscope.agent import AgentBase from agentscope.message import Msg from agentscope.tool import Toolkit @@ -12,6 +13,35 @@ from config import settings logger = logging.getLogger(__name__) +async def _resolve_model_instance(model_instance_id: str) -> dict | None: + try: + from database import AsyncSessionLocal + from sqlalchemy import text + import uuid as _uuid + uid = _uuid.UUID(model_instance_id) + async with AsyncSessionLocal() as db: + result = await db.execute( + text(""" + SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key + FROM model_instances mi + JOIN model_providers mp ON mi.provider_id = mp.id + WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true + """), + {"id": uid}, + ) + row = result.fetchone() + if row: + return { + "model": row[0], + "base_url": row[2], + "api_key": row[3], + "params": row[1] or {}, + } + except Exception as e: + logger.warning(f"resolve model instance {model_instance_id} failed: {e}") + return None + + class FlowSessionMemory: def __init__(self, session_id: str = "", user_id: str = ""): self.session_id = session_id @@ -241,6 +271,15 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase: elif node_type == "llm": model_config = config.get("model", settings.LLM_MODEL) + model_instance_id = config.get("model_instance_id") + base_url = settings.LLM_API_BASE + api_key = settings.LLM_API_KEY + if model_instance_id: + resolved = await _resolve_model_instance(model_instance_id) + if resolved: + model_config = resolved.get("model", model_config) + base_url = resolved.get("base_url", base_url) + api_key = resolved.get("api_key", api_key) temperature = config.get("temperature", 0.7) system_prompt = config.get("system_prompt", "你是AI助手。") max_tokens = config.get("max_tokens", 2000) @@ -254,6 +293,8 @@ async def 
_create_node_agent(node: dict, context: dict) -> AgentBase: max_tokens=max_tokens, stream=stream, stream_callback=stream_cb, + base_url=base_url, + api_key=api_key, ) memory = context.get("_memory") if memory: @@ -297,6 +338,12 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase: return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type) elif node_type == "rag": + model_instance_id = config.get("model_instance_id") + if model_instance_id: + resolved = await _resolve_model_instance(model_instance_id) + if resolved: + config = dict(config) + config["_resolved_model"] = resolved return RAGNodeAgent(node_id=node_id, config=config) elif node_type == "output": @@ -314,6 +361,24 @@ async def _create_node_agent(node: dict, context: dict) -> AgentBase: sandbox = config.get("sandbox", True) return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox) + elif node_type == "http_request": + return HttpRequestNodeAgent(node_id=node_id, config=config) + + elif node_type == "question_classifier": + return QuestionClassifierNodeAgent(node_id=node_id, config=config) + + elif node_type == "variable_assigner": + return VariableAssignerNodeAgent(node_id=node_id, config=config, context=context) + + elif node_type == "template_transform": + return TemplateTransformNodeAgent(node_id=node_id, config=config) + + elif node_type == "iteration": + return IterationNodeAgent(node_id=node_id, config=config) + + elif node_type == "question_optimiser": + return QuestionOptimiserNodeAgent(node_id=node_id, config=config) + else: return PassThroughAgent(node_id) @@ -330,8 +395,178 @@ class PassThroughAgent(AgentBase): pass +class HttpRequestNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"HttpRequest_{node_id}" + self.config = config or {} + + async def reply(self, msg: Msg, **kwargs) -> Msg: + method = self.config.get("method", 
"GET").upper() + url = self.config.get("url", "") + headers = self.config.get("headers", {}) + body = self.config.get("body", "") + auth_type = self.config.get("auth_type", "none") + auth_config = self.config.get("auth_config", {}) + timeout = self.config.get("timeout", 30) + retry_count = self.config.get("retry_count", 0) + + if isinstance(headers, str): + try: + headers = json.loads(headers) + except (json.JSONDecodeError, ValueError): + headers = {} + + request_headers = dict(headers) + if auth_type == "bearer" and auth_config.get("token"): + request_headers["Authorization"] = f"Bearer {auth_config['token']}" + elif auth_type == "api_key" and auth_config.get("api_key"): + key_name = auth_config.get("key_name", "X-API-Key") + request_headers[key_name] = auth_config["api_key"] + elif auth_type == "basic" and auth_config.get("username"): + import base64 + credentials = f"{auth_config['username']}:{auth_config.get('password', '')}" + request_headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}" + + last_error = None + for attempt in range(retry_count + 1): + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client: + request_kwargs = {"headers": request_headers} + if method in ("POST", "PUT", "PATCH"): + if isinstance(body, dict): + request_kwargs["json"] = body + elif body and isinstance(body, str): + try: + request_kwargs["json"] = json.loads(body) + except (json.JSONDecodeError, ValueError): + request_kwargs["content"] = body + response = await client.request(method, url, **request_kwargs) + resp_status = response.status_code + resp_text = response.text + try: + resp_body = json.loads(resp_text) + except (json.JSONDecodeError, ValueError): + resp_body = resp_text + + result = json.dumps({ + "status_code": resp_status, + "body": resp_body, + }, ensure_ascii=False) + return Msg(self.name, result, "assistant") + except Exception as e: + last_error = str(e) + if attempt < retry_count: + await asyncio.sleep(1) + + 
error_result = json.dumps({ + "status_code": 0, + "error": last_error or "unknown error", + "body": None, + }, ensure_ascii=False) + return Msg(self.name, error_result, "assistant") + + async def observe(self, msg) -> None: + pass + + +class QuestionClassifierNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"Classifier_{node_id}" + self.config = config or {} + + async def reply(self, msg: Msg, **kwargs) -> Msg: + categories = self.config.get("categories", []) + instruction = self.config.get("instruction", "") + model_name = self.config.get("model", settings.LLM_MODEL) + temperature = self.config.get("temperature", 0.3) + + user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + + if not categories: + return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant") + + category_desc = "\n".join([f'{c.get("name")}: {c.get("description", "")}' for c in categories]) + prompt = f"""请对以下用户输入进行意图分类。 + +分类选项: +{category_desc} + +{instruction} + +用户输入:{user_text} + +请只返回一个JSON对象,格式为:{{"category": "分类名称", "confidence": 0.0~1.0}}""" + + try: + import openai + client = openai.AsyncOpenAI( + base_url=settings.LLM_API_BASE, + api_key=settings.LLM_API_KEY, + ) + response = await client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + max_tokens=200, + ) + result_text = response.choices[0].message.content.strip() + try: + parsed = json.loads(result_text) + return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant") + except (json.JSONDecodeError, ValueError): + return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}), "assistant") + except Exception as e: + logger.error(f"QuestionClassifier error: {e}") + return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}), "assistant") + + async def 
observe(self, msg) -> None: + pass + + +class VariableAssignerNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None, context: dict = None): + super().__init__() + self.name = f"VarAssign_{node_id}" + self.config = config or {} + self._context = context or {} + + async def reply(self, msg: Msg, **kwargs) -> Msg: + assignments = self.config.get("assignments", []) + results = {} + for assignment in assignments: + target_var = assignment.get("target_var", "") + source_type = assignment.get("source_type", "constant") + source_value = assignment.get("source_value", "") + + if source_type == "constant": + value = source_value + elif source_type == "upstream_output": + value = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + elif source_type == "template": + value = _resolve_template(source_value, self._context, msg) + elif source_type == "expression": + try: + safe_locals = {"msg": msg, "context": self._context} + value = eval(source_value, {"__builtins__": {}}, safe_locals) + value = str(value) + except Exception as e: + value = f"[expression error: {e}]" + else: + value = source_value + + results[target_var] = value + + output = json.dumps(results, ensure_ascii=False) + return Msg(self.name, output, "assistant") + + async def observe(self, msg) -> None: + pass + + class LLMNodeAgent(AgentBase): - def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None): + def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None, base_url: str = "", api_key: str = ""): super().__init__() self.name = f"LLM_{node_id}" self.system_prompt = system_prompt @@ -340,6 +575,8 @@ class LLMNodeAgent(AgentBase): self.max_tokens = max_tokens self.stream = stream self.stream_callback = stream_callback + self.base_url = base_url or 
settings.LLM_API_BASE + self.api_key = api_key or settings.LLM_API_KEY self._memory = None def set_memory(self, memory): @@ -390,29 +627,40 @@ class LLMNodeAgent(AgentBase): return Msg(self.name, res_text, "assistant") async def _blocking_llm_call(self, messages: list[dict]) -> str: - from agentscope_integration.factory import AgentFactory - from agentscope.formatter import OpenAIChatFormatter + import httpx - model = AgentFactory._get_model() - formatter = OpenAIChatFormatter() - scope_msgs = [] - for m in messages: - scope_msgs.append(Msg(m["role"], m["content"], m["role"])) - prompt = formatter.format(scope_msgs) + api_base = self.base_url.rstrip("/") + api_key = self.api_key + model_name = self.model_name or settings.LLM_MODEL + + url = f"{api_base}/chat/completions" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + body = { + "model": model_name, + "messages": messages, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "stream": False, + } - res = await model(prompt) - if isinstance(res, list): - return res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0]) - elif hasattr(res, 'get_text_content'): - return res.get_text_content() - return str(res) + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client: + resp = await client.post(url, json=body, headers=headers) + data = resp.json() + return data.get("choices", [{}])[0].get("message", {}).get("content", "") + except Exception as e: + logger.warning(f"LLM 阻塞调用失败: {e}") + return f"[LLM 调用失败: {e}]" async def _stream_llm_call(self, messages: list[dict]) -> str: import httpx import json - api_base = settings.LLM_API_BASE.rstrip("/") - api_key = settings.LLM_API_KEY + api_base = self.base_url.rstrip("/") + api_key = self.api_key model_name = self.model_name or settings.LLM_MODEL url = f"{api_base}/chat/completions" @@ -977,11 +1225,12 @@ class RAGNodeAgent(AgentBase): def _get_model(self): 
from agentscope.model import OpenAIChatModel + resolved = self.config.get("_resolved_model", {}) return OpenAIChatModel( config_name=f"rag_{self.name}", - model_name=settings.LLM_MODEL, - api_key=settings.LLM_API_KEY, - api_base=settings.LLM_API_BASE, + model_name=resolved.get("model") or settings.LLM_MODEL, + api_key=resolved.get("api_key") or settings.LLM_API_KEY, + api_base=resolved.get("base_url") or settings.LLM_API_BASE, ) def _get_formatter(self): @@ -1050,6 +1299,164 @@ def _resolve_template(template: str, context: dict, current_msg: Msg) -> str: return result +class TemplateTransformNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"Template_{node_id}" + self.config = config or {} + + async def reply(self, msg: Msg, **kwargs) -> Msg: + template = self.config.get("template", "") + output_type = self.config.get("output_type", "string") + context = kwargs.get("context", {}) + + user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + + try: + rendered = _resolve_template(template, context, msg) + if not rendered: + rendered = template + rendered = rendered.replace("{{input}}", user_text) + except Exception: + rendered = template.replace("{{input}}", user_text) if "{{input}}" in template else user_text + + if output_type == "json": + try: + parsed = json.loads(rendered) + return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant") + except (json.JSONDecodeError, ValueError): + pass + + return Msg(self.name, rendered, "assistant") + + async def observe(self, msg) -> None: + pass + + +class IterationNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"Iteration_{node_id}" + self.config = config or {} + self._results: list[str] = [] + + async def reply(self, msg: Msg, **kwargs) -> Msg: + input_array_source = self.config.get("input_array_source", "") + max_iterations = 
self.config.get("max_iterations", 20) + context = kwargs.get("context", {}) + + items = [] + if input_array_source: + resolved = _resolve_template(input_array_source, context, msg) + try: + items = json.loads(resolved) + if not isinstance(items, list): + items = [resolved] + except (json.JSONDecodeError, ValueError): + items = [resolved] if resolved else [] + else: + user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + try: + items = json.loads(user_text) + if not isinstance(items, list): + items = [user_text] + except (json.JSONDecodeError, ValueError): + items = [item.strip() for item in user_text.split("\n") if item.strip()] + if not items: + items = [user_text] + + items = items[:max_iterations] + results = [] + for i, item in enumerate(items): + results.append({ + "index": i, + "item": item, + }) + + output = json.dumps(results, ensure_ascii=False) + return Msg(self.name, output, "assistant") + + def get_iteration_items(self) -> list: + return self._results + + async def observe(self, msg) -> None: + pass + + +class QuestionOptimiserNodeAgent(AgentBase): + def __init__(self, node_id: str, config: dict = None): + super().__init__() + self.name = f"QOpt_{node_id}" + self.config = config or {} + + async def reply(self, msg: Msg, **kwargs) -> Msg: + optimization_type = self.config.get("optimization_type", "rewrite") + model_name = self.config.get("model", settings.LLM_MODEL) + context = kwargs.get("context", {}) + + user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) + history = context.get("_memory", {}).get("recent_messages", []) + persona = context.get("_memory", {}).get("persona", {}) + atoms = context.get("_memory", {}).get("atoms", []) + + persona_text = persona.get("raw_text", "") + atoms_text = "\n".join([a.get("content", "") for a in atoms[:5]]) if atoms else "" + history_text = "\n".join( + f"{m['role']}: {m['content'][:200]}" for m in history[-6:] + ) if history else "" + + context_parts 
= [] + if persona_text: + context_parts.append(f"用户画像: {persona_text}") + if atoms_text: + context_parts.append(f"已知信息: {atoms_text}") + if history_text: + context_parts.append(f"近期对话: {history_text}") + + context_block = "\n".join(context_parts) if context_parts else "" + + if optimization_type == "rewrite": + prompt = f"""{context_block} + +原始问题: {user_text} + +请将以上问题进行优化改写,使其更清晰、具体、完整。补充可能缺失的上下文信息。 +只返回优化后的问题,不要其他内容。""" + elif optimization_type == "expand": + prompt = f"""{context_block} + +简短问题: {user_text} + +请将以上问题扩展为更详细的版本,添加必要的背景和细节。 +只返回扩展后的问题,不要其他内容。""" + else: + return Msg(self.name, user_text, "assistant") + + try: + import httpx + api_base = settings.LLM_API_BASE.rstrip("/") + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{api_base}/chat/completions", + json={ + "model": model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 300, + "temperature": 0.3, + }, + headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, + ) + data = resp.json() + result = data.get("choices", [{}])[0].get("message", {}).get("content", user_text) + return Msg(self.name, result.strip(), "assistant") + except Exception as e: + logger.warning(f"问题优化失败: {e}") + return Msg(self.name, user_text, "assistant") + + async def observe(self, msg) -> None: + pass + + class ParallelMergeNodeAgent(AgentBase): def __init__(self, node_id: str, config: dict = None): super().__init__() diff --git a/backend/modules/flow_engine/gateway.py b/backend/modules/flow_engine/gateway.py index 4554e7b..a895407 100644 --- a/backend/modules/flow_engine/gateway.py +++ b/backend/modules/flow_engine/gateway.py @@ -80,12 +80,12 @@ async def chat_messages(request: Request, db: AsyncSession = Depends(get_db)): username = user if response_mode == "streaming": - return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db) + return await _chat_stream(flow_id, definition, input_text, user_id, username, f, db, session_id) - 
return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db) + return await _chat_blocking(flow_id, definition, input_text, user_id, username, f, db, session_id) -async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db): +async def _chat_blocking(flow_id, definition, input_text, user_id, username, flow, db, session_id=None): engine = FlowEngine(definition) input_msg = Msg(name="user", content=input_text, role="user") context = {"user_id": user_id, "username": username, "_node_results": {}, "session_id": str(uuid.uuid4())} @@ -127,7 +127,7 @@ async def _chat_blocking(flow_id, definition, input_text, user_id, username, flo raise HTTPException(500, f"流执行失败: {str(e)}") -async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db): +async def _chat_stream(flow_id, definition, input_text, user_id, username, flow, db, session_id=None): async def event_generator(): import asyncio engine = FlowEngine(definition) diff --git a/backend/modules/flow_engine/router.py b/backend/modules/flow_engine/router.py index c45956e..1ea6867 100644 --- a/backend/modules/flow_engine/router.py +++ b/backend/modules/flow_engine/router.py @@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_db -from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User +from models import FlowDefinition, FlowVersion, FlowApiKey, FlowExecution, User, MemoryMessage, FlowTemplate from schemas import ( FlowDefinitionCreate, FlowDefinitionUpdate, FlowDefinitionOut, FlowVersionOut, FlowApiKeyCreate, FlowApiKeyOut, @@ -34,6 +34,7 @@ def _build_flow_out(f) -> FlowDefinitionOut: published_version_id=getattr(f, 'published_version_id', None), published_to_wecom=getattr(f, 'published_to_wecom', False), published_to_web=getattr(f, 'published_to_web', False), + flow_mode=getattr(f, 'flow_mode', 'chatflow') or 
'chatflow', created_at=f.created_at, updated_at=f.updated_at, ) @@ -71,6 +72,7 @@ async def create_flow(req: FlowDefinitionCreate, request: Request, db: AsyncSess definition_json=definition_json, draft_definition_json=definition_json, creator_id=uuid.UUID(user_ctx["id"]), + flow_mode=req.flow_mode, ) db.add(flow) await db.flush() @@ -133,6 +135,43 @@ async def _snapshot_publish(flow: FlowDefinition, db: AsyncSession, user_id: str return new_version +def validate_flow_definition(definition: dict, flow_mode: str = "chatflow") -> list[str]: + errors = [] + nodes = definition.get("nodes", []) + edges = definition.get("edges", []) + + if not nodes: + errors.append("流定义中没有节点") + return errors + + if not any(n.get("type") == "trigger" for n in nodes): + errors.append("缺少触发/起始节点 (trigger)") + + if flow_mode == "chatflow" and not any(n.get("type") == "llm" for n in nodes): + errors.append("对话型流必须包含至少一个 LLM 节点") + + node_ids = {n["id"] for n in nodes} + connected_ids = set() + for e in edges: + connected_ids.add(e.get("source", "")) + connected_ids.add(e.get("target", "")) + for n in nodes: + if n["id"] not in connected_ids and len(nodes) > 1: + errors.append(f"节点 '{n.get('label', n['id'])}' 未连接") + + return errors + + +@router.post("/definitions/{flow_id}/validate") +async def validate_flow(flow_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + f = await db.get(FlowDefinition, flow_id) + if not f: + raise HTTPException(404, "流定义不存在") + flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow' + errors = validate_flow_definition(f.definition_json, flow_mode) + return {"code": 200, "valid": len(errors) == 0, "errors": errors} + + @router.post("/definitions/{flow_id}/publish") async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): f = await db.get(FlowDefinition, flow_id) @@ -141,8 +180,18 @@ async def publish_flow(flow_id: uuid.UUID, request: Request, db: AsyncSession = nodes = f.definition_json.get("nodes", []) if not nodes: 
raise HTTPException(400, "流定义中没有节点") + flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow' + errors = validate_flow_definition(f.definition_json, flow_mode) + if errors: + raise HTTPException(400, f"流校验失败: {'; '.join(errors)}") user_ctx = request.state.user - await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True) + body = {} + try: + body = await request.json() + except Exception: + pass + changelog = body.get("changelog", "") + await _snapshot_publish(f, db, user_ctx["id"], publish_wecom=True, changelog=changelog) return {"code": 200, "message": "流已上架到企微", "data": {"status": "published", "version": f.version}} @@ -261,12 +310,14 @@ async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db: try: from modules.memory.manager import get_memory_manager mm = get_memory_manager() - await mm.inject_memory( - user_id=user_ctx["id"], - flow_id=str(flow_id), - session_id=session_id, - context=context, - ) + flow_mode = getattr(f, 'flow_mode', 'chatflow') or 'chatflow' + if flow_mode == 'chatflow': + await mm.inject_memory( + user_id=user_ctx["id"], + flow_id=str(flow_id), + session_id=session_id, + context=context, + ) except Exception as e: logger.debug(f"记忆注入跳过: {e}") @@ -279,14 +330,15 @@ async def execute_flow(flow_id: uuid.UUID, request: Request, payload: dict, db: try: from modules.memory.manager import get_memory_manager mm = get_memory_manager() - asyncio.create_task(mm.record_exchange( - user_id=user_ctx["id"], - flow_id=str(flow_id), - session_id=session_id, - user_msg=input_text, - assistant_msg=output_text, - flow_name=f.name, - )) + if flow_mode == 'chatflow': + asyncio.create_task(mm.record_exchange( + user_id=user_ctx["id"], + flow_id=str(flow_id), + session_id=session_id, + user_msg=input_text, + assistant_msg=output_text, + flow_name=f.name, + )) except Exception as e: logger.debug(f"记忆记录跳过: {e}") @@ -383,6 +435,8 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes except Exception: 
mm_stream = None + flow_mode_stream = getattr(f, 'flow_mode', 'chatflow') or 'chatflow' + async def event_generator(): engine = FlowEngine(definition) context = { @@ -394,7 +448,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes "_stream_callback": None, } - if mm_stream: + if mm_stream and flow_mode_stream == 'chatflow': try: await mm_stream.inject_memory( user_id=user_ctx["id"], @@ -476,7 +530,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes yield f"data: {json.dumps({'event': 'workflow_finished', 'data': {'output': output_text, 'session_id': session_id, 'node_results': {k: str(v)[:200] for k, v in context.get('_node_results', {}).items()}, 'latency_ms': elapsed_ms}}, ensure_ascii=False)}\n\n" - if mm_stream: + if mm_stream and flow_mode_stream == 'chatflow': try: asyncio.create_task(mm_stream.record_exchange( user_id=user_ctx["id"], @@ -501,8 +555,7 @@ async def execute_flow_stream(flow_id: uuid.UUID, request: Request, db: AsyncSes finished_at=datetime.utcnow(), ) db.add(execution) - finally: - yield "data: [DONE]\n\n" + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -615,86 +668,6 @@ async def list_executions( # ============================== 模板 ============================== -FLOW_TEMPLATES = [ - { - "id": "tpl_doc_process", "name": "文档处理流", "description": "自动解析文档内容,提取关键信息并生成摘要", "icon": "Document", - "nodes": [ - {"id": "n1", "type": "trigger", "label": "文档上传", "config": {"event_type": "document_upload"}, "position": {"x": 100, "y": 100}}, - {"id": "n2", "type": "tool", "label": "解析文档", "config": {"tool_name": "parse_document"}, "position": {"x": 400, "y": 100}}, - {"id": "n3", "type": "llm", "label": "生成摘要", "config": {"system_prompt": "请为以下文档内容生成简洁摘要", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 700, "y": 100}}, - {"id": "n4", "type": "output", "label": "输出结果", "config": {"format": "text"}, "position": {"x": 1000, "y": 100}}, - ], - "edges": 
[ - {"source": "n1", "target": "n2", "sourceHandle": "source"}, - {"source": "n2", "target": "n3", "sourceHandle": "source"}, - {"source": "n3", "target": "n4", "sourceHandle": "source"}, - ], - }, - { - "id": "tpl_wecom_notify", "name": "企微通知流", "description": "接收触发后查询数据并推送企微通知", "icon": "Bell", - "nodes": [ - {"id": "n1", "type": "trigger", "label": "定时触发", "config": {"event_type": "scheduled"}, "position": {"x": 100, "y": 100}}, - {"id": "n2", "type": "tool", "label": "查询任务", "config": {"tool_name": "list_tasks"}, "position": {"x": 400, "y": 100}}, - {"id": "n3", "type": "condition", "label": "有待办任务?", "config": {"condition": "tasks.length > 0"}, "position": {"x": 700, "y": 100}}, - {"id": "n4", "type": "wecom_notify", "label": "推送通知", "config": {"message_template": "您有{{tasks.length}}条待办任务", "target": "@all"}, "position": {"x": 1000, "y": 50}}, - {"id": "n5", "type": "output", "label": "无任务", "config": {"format": "text"}, "position": {"x": 1000, "y": 200}}, - ], - "edges": [ - {"source": "n1", "target": "n2", "sourceHandle": "source"}, - {"source": "n2", "target": "n3", "sourceHandle": "source"}, - {"source": "n3", "target": "n4", "sourceHandle": "true"}, - {"source": "n3", "target": "n5", "sourceHandle": "false"}, - ], - }, - { - "id": "tpl_data_analysis", "name": "数据分析流", "description": "查询员工数据并生成效率分析报告", "icon": "DataAnalysis", - "nodes": [ - {"id": "n1", "type": "trigger", "label": "分析请求", "config": {"event_type": "button_click"}, "position": {"x": 100, "y": 100}}, - {"id": "n2", "type": "tool", "label": "查询下属", "config": {"tool_name": "list_subordinates"}, "position": {"x": 400, "y": 100}}, - {"id": "n3", "type": "tool", "label": "统计数据", "config": {"tool_name": "get_task_statistics"}, "position": {"x": 700, "y": 100}}, - {"id": "n4", "type": "llm", "label": "生成报告", "config": {"system_prompt": "基于以下数据生成团队效率分析报告", "model": "gpt-4o", "temperature": 0.7}, "position": {"x": 1000, "y": 100}}, - {"id": "n5", "type": "output", "label": "报告输出", "config": {"format": 
"json"}, "position": {"x": 1300, "y": 100}}, - ], - "edges": [ - {"source": "n1", "target": "n2", "sourceHandle": "source"}, - {"source": "n2", "target": "n3", "sourceHandle": "source"}, - {"source": "n3", "target": "n4", "sourceHandle": "source"}, - {"source": "n4", "target": "n5", "sourceHandle": "source"}, - ], - }, - { - "id": "tpl_rag_qa", "name": "知识库问答流", "description": "从知识库检索信息后由LLM回答", "icon": "Search", - "nodes": [ - {"id": "n1", "type": "trigger", "label": "问题触发", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}}, - {"id": "n2", "type": "rag", "label": "知识检索", "config": {"knowledge_base": "default", "top_k": 5}, "position": {"x": 400, "y": 100}}, - {"id": "n3", "type": "llm", "label": "生成回答", "config": {"system_prompt": "基于知识库检索结果回答用户问题", "model": "gpt-4o-mini", "temperature": 0.3}, "position": {"x": 700, "y": 100}}, - {"id": "n4", "type": "output", "label": "输出答案", "config": {"format": "text"}, "position": {"x": 1000, "y": 100}}, - ], - "edges": [ - {"source": "n1", "target": "n2", "sourceHandle": "source"}, - {"source": "n2", "target": "n3", "sourceHandle": "source"}, - {"source": "n3", "target": "n4", "sourceHandle": "source"}, - ], - }, - { - "id": "tpl_task_auto", "name": "任务自动分配流", "description": "根据描述自动创建任务并分派给合适人员", "icon": "Tools", - "nodes": [ - {"id": "n1", "type": "trigger", "label": "任务描述", "config": {"event_type": "text_message"}, "position": {"x": 100, "y": 100}}, - {"id": "n2", "type": "llm", "label": "分析任务", "config": {"system_prompt": "分析以下任务描述,提取标题、优先级、负责人", "model": "gpt-4o-mini", "temperature": 0.5}, "position": {"x": 400, "y": 100}}, - {"id": "n3", "type": "tool", "label": "创建任务", "config": {"tool_name": "create_task"}, "position": {"x": 700, "y": 100}}, - {"id": "n4", "type": "wecom_notify", "label": "通知负责人", "config": {"message_template": "您有新任务: {{task_title}}", "target": "@all"}, "position": {"x": 1000, "y": 100}}, - {"id": "n5", "type": "output", "label": "完成", "config": {"format": "text"}, 
"position": {"x": 1300, "y": 100}}, - ], - "edges": [ - {"source": "n1", "target": "n2", "sourceHandle": "source"}, - {"source": "n2", "target": "n3", "sourceHandle": "source"}, - {"source": "n3", "target": "n4", "sourceHandle": "source"}, - {"source": "n4", "target": "n5", "sourceHandle": "source"}, - ], - }, -] - - @router.get("/market", response_model=list[FlowDefinitionOut]) async def flow_market(db: AsyncSession = Depends(get_db)): result = await db.execute( @@ -704,24 +677,80 @@ async def flow_market(db: AsyncSession = Depends(get_db)): @router.get("/templates") -async def get_flow_templates(request: Request): - return {"code": 200, "data": FLOW_TEMPLATES} +async def get_flow_templates(db: AsyncSession = Depends(get_db)): + result = await db.execute( + select(FlowTemplate).order_by(FlowTemplate.sort_order.asc(), FlowTemplate.created_at.desc()) + ) + templates = result.scalars().all() + return {"code": 200, "data": [ + { + "id": str(t.id), "name": t.name, "description": t.description, + "category": t.category, "definition_json": t.definition_json, + "icon": t.icon, "is_builtin": t.is_builtin, "usage_count": t.usage_count, + } + for t in templates + ]} + + +@router.post("/templates") +async def create_flow_template(request: Request, db: AsyncSession = Depends(get_db)): + body = await request.json() + user_ctx = request.state.user + ft = FlowTemplate( + name=body.get("name", ""), + description=body.get("description", ""), + category=body.get("category", "general"), + definition_json=body.get("definition_json", {}), + icon=body.get("icon", ""), + sort_order=body.get("sort_order", 0), + created_by=uuid.UUID(user_ctx["id"]), + ) + db.add(ft) + await db.commit() + await db.refresh(ft) + return {"code": 200, "data": {"id": str(ft.id), "name": ft.name}} + + +@router.put("/templates/{template_id}") +async def update_flow_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + ft = await db.get(FlowTemplate, template_id) + if not ft: + 
raise HTTPException(404, "模板不存在") + body = await request.json() + for field in ("name", "description", "category", "definition_json", "icon", "sort_order"): + if field in body: + setattr(ft, field, body[field]) + await db.commit() + return {"code": 200, "data": {"id": str(ft.id)}} + + +@router.delete("/templates/{template_id}") +async def delete_flow_template(template_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + ft = await db.get(FlowTemplate, template_id) + if not ft: + raise HTTPException(404, "模板不存在") + if ft.is_builtin: + raise HTTPException(400, "内置模板不可删除") + await db.delete(ft) + await db.commit() + return {"code": 200, "message": "模板已删除"} @router.post("/templates/{template_id}/use") -async def use_flow_template(template_id: str, request: Request, db: AsyncSession = Depends(get_db)): - template = next((t for t in FLOW_TEMPLATES if t["id"] == template_id), None) - if not template: +async def use_flow_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + ft = await db.get(FlowTemplate, template_id) + if not ft: raise HTTPException(404, "模板不存在") + ft.usage_count = (ft.usage_count or 0) + 1 user_ctx = request.state.user - definition_json = {"nodes": template["nodes"], "edges": template["edges"], "trigger": {}} flow = FlowDefinition( - name=template["name"] + " (副本)", - description=template["description"], - definition_json=definition_json, - draft_definition_json=definition_json, + name=ft.name + " (副本)", + description=ft.description or "", + definition_json=ft.definition_json, + draft_definition_json=ft.definition_json, creator_id=uuid.UUID(user_ctx["id"]), ) db.add(flow) await db.flush() + await db.commit() return _build_flow_out(flow) \ No newline at end of file diff --git a/backend/modules/memory/__init__.py b/backend/modules/memory/__init__.py index c0d3a31..0872066 100644 --- a/backend/modules/memory/__init__.py +++ b/backend/modules/memory/__init__.py @@ -1,4 +1,4 @@ -from .manager import MemoryManager, 
get_memory_manager +from .manager import MemoryManager, get_memory_manager, init_memory_manager from .router import router -__all__ = ["MemoryManager", "get_memory_manager", "router"] \ No newline at end of file +__all__ = ["MemoryManager", "get_memory_manager", "init_memory_manager", "router"] \ No newline at end of file diff --git a/backend/modules/memory/manager.py b/backend/modules/memory/manager.py index 7ae0298..e18669f 100644 --- a/backend/modules/memory/manager.py +++ b/backend/modules/memory/manager.py @@ -1,8 +1,13 @@ import json import asyncio +import uuid import logging -from datetime import datetime +from datetime import datetime, timezone +from typing import Callable + from redis.asyncio import Redis +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession from config import settings logger = logging.getLogger(__name__) @@ -17,21 +22,27 @@ def get_memory_manager() -> "MemoryManager": return _memory_manager -async def init_memory_manager(): +async def init_memory_manager(db_factory: Callable[[], AsyncSession]): global _memory_manager redis = Redis.from_url(settings.REDIS_URL, decode_responses=True) await redis.ping() - _memory_manager = MemoryManager(redis) + _memory_manager = MemoryManager(db_factory, redis) class MemoryManager: - KEY_PREFIX = "mem" - DEFAULT_TTL = 604800 - SESSION_INDEX_TTL = 2592000 MAX_HISTORY = 40 + REDIS_CACHE_SIZE = 10 + REDIS_CACHE_TTL = 300 + SUMMARY_CACHE_KEY = "mem:summary" + MSG_CACHE_KEY = "mem:cache:msgs" + ATOM_EXTRACT_EVERY = 10 + SCENE_EXTRACT_EVERY = 50 + PERSONA_UPDATE_EVERY = 30 - def __init__(self, redis: Redis): + def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis): + self.db_factory = db_factory self.redis = redis + self._extract_tasks: dict[str, asyncio.Task] = {} async def inject_memory( self, @@ -40,12 +51,27 @@ class MemoryManager: session_id: str, context: dict, ): - messages = await self._get_recent_messages(user_id, flow_id, session_id) - summary = await 
self._get_summary(user_id, flow_id, session_id) + uid = uuid.UUID(user_id) + fid = uuid.UUID(flow_id) + sid = uuid.UUID(session_id) + + cached = await self._redis_get_recent(uid, sid) + if cached: + recent_messages = cached + else: + recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY) + if recent_messages: + await self._redis_set_recent(uid, sid, recent_messages) + + summary = await self._get_summary(uid, fid, sid) + atoms = await self._get_relevant_atoms(uid, fid) + persona = await self._get_persona(uid) context["_memory_context"] = { - "recent_messages": list(reversed(messages)), + "recent_messages": recent_messages, "summary": summary, + "atoms": atoms, + "persona": persona, "session_id": session_id, } @@ -58,115 +84,238 @@ class MemoryManager: assistant_msg: str, flow_name: str = "", ): - key = self._msg_key(user_id, flow_id, session_id) - ts = datetime.utcnow().isoformat() + uid = uuid.UUID(user_id) + fid = uuid.UUID(flow_id) + sid = uuid.UUID(session_id) + ts = datetime.now(timezone.utc) try: - async with self.redis.pipeline() as pipe: - pipe.lpush(key, - json.dumps({"role": "assistant", "content": assistant_msg, "ts": ts}, ensure_ascii=False), - json.dumps({"role": "user", "content": user_msg, "ts": ts}, ensure_ascii=False), + async with self.db_factory() as db: + user_row = { + "id": uuid.uuid4(), + "user_id": uid, + "flow_id": fid, + "session_id": sid, + "role": "user", + "content": user_msg, + "created_at": ts, + } + asst_row = { + "id": uuid.uuid4(), + "user_id": uid, + "flow_id": fid, + "session_id": sid, + "role": "assistant", + "content": assistant_msg, + "created_at": ts, + } + await db.execute( + text(""" + INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at) + VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at) + """), + user_row, + ) + await db.execute( + text(""" + INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at) + VALUES 
(:id, :user_id, :flow_id, :session_id, :role, :content, :created_at) + """), + asst_row, ) - pipe.ltrim(key, 0, self.MAX_HISTORY - 1) - pipe.expire(key, self.DEFAULT_TTL) - pipe.hset(self._meta_key(user_id, flow_id, session_id), mapping={ + session_row = { + "id": uuid.uuid4(), + "user_id": uid, + "flow_id": fid, + "session_id": sid, "flow_name": flow_name, "last_active_at": ts, - }) - pipe.expire(self._meta_key(user_id, flow_id, session_id), self.DEFAULT_TTL) + "created_at": ts, + } + await db.execute( + text(""" + INSERT INTO memory_sessions (id, user_id, flow_id, session_id, flow_name, message_count, last_active_at, created_at) + VALUES (:id, :user_id, :flow_id, :session_id, :flow_name, 2, :last_active_at, :created_at) + ON CONFLICT (user_id, flow_id, session_id) DO UPDATE SET + message_count = memory_sessions.message_count + 2, + last_active_at = EXCLUDED.last_active_at, + flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name) + """), + session_row, + ) - pipe.sadd(f"{self.KEY_PREFIX}:{user_id}:sessions", session_id) - pipe.expire(f"{self.KEY_PREFIX}:{user_id}:sessions", self.SESSION_INDEX_TTL) + await db.commit() + except Exception as e: + logger.warning(f"记录记忆失败(PG): {e}") + return - await pipe.execute() + try: + new_msgs = [ + {"role": "user", "content": user_msg, "ts": ts.isoformat()}, + {"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()}, + ] + await self._redis_append_recent(uid, sid, new_msgs) except Exception as e: - logger.warning(f"记录记忆失败: {e}") + logger.debug(f"Redis缓存更新失败: {e}") - asyncio.create_task(self._maybe_summarize(user_id, flow_id, session_id)) + asyncio.create_task(self._maybe_summarize(uid, fid, sid)) + asyncio.create_task(self._maybe_extract_atoms(uid, fid, sid)) + asyncio.create_task(self._maybe_extract_scenes(uid, fid)) + asyncio.create_task(self._maybe_update_persona(uid, fid)) async def get_conversation_history( self, user_id: str, flow_id: str, session_id: str, limit: int = 20 ) -> 
list[dict]: - messages = await self._get_recent_messages(user_id, flow_id, session_id, limit) - return list(reversed(messages)) + uid = uuid.UUID(user_id) + sid = uuid.UUID(session_id) + fid = uuid.UUID(flow_id) if flow_id else None + return await self._pg_get_recent(uid, fid, sid, limit) async def delete_session(self, user_id: str, session_id: str): + uid = uuid.UUID(user_id) + sid = uuid.UUID(session_id) + try: - patterns = await self.redis.keys(f"{self.KEY_PREFIX}:{user_id}:*:{session_id}:*") - async with self.redis.pipeline() as pipe: - if patterns: - pipe.delete(*patterns) - pipe.srem(f"{self.KEY_PREFIX}:{user_id}:sessions", session_id) - await pipe.execute() + async with self.db_factory() as db: + await db.execute( + text("DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid"), + {"uid": uid, "sid": sid}, + ) + await db.execute( + text("DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid"), + {"uid": uid, "sid": sid}, + ) + await db.commit() except Exception as e: - logger.warning(f"清除记忆失败: {e}") + logger.warning(f"清除记忆失败(PG): {e}") + + try: + cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" + await self.redis.delete(cache_key) + summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" + await self.redis.delete(summary_key) + except Exception: + pass async def list_user_sessions(self, user_id: str) -> list[dict]: + uid = uuid.UUID(user_id) try: - session_ids = await self.redis.smembers(f"{self.KEY_PREFIX}:{user_id}:sessions") + async with self.db_factory() as db: + result = await db.execute( + text(""" + SELECT session_id, flow_id, flow_name, last_active_at + FROM memory_sessions + WHERE user_id = :uid + ORDER BY last_active_at DESC + LIMIT 100 + """), + {"uid": uid}, + ) + rows = result.fetchall() + return [ + { + "session_id": str(row[0]), + "flow_id": str(row[1]), + "flow_name": row[2] or "", + "last_active_at": row[3].isoformat() if row[3] else "", + } + for row in rows + ] except Exception: return [] - sessions = [] - for 
sid in session_ids: - try: - keys = await self.redis.keys(f"{self.KEY_PREFIX}:{user_id}:*:{sid}:meta") - for k in keys: - meta = await self.redis.hgetall(k) - parts = k.split(":") - flow_id = parts[2] if len(parts) > 2 else "" - sessions.append({ - "session_id": sid, - "flow_id": flow_id, - "flow_name": meta.get("flow_name", ""), - "last_active_at": meta.get("last_active_at", ""), - }) - except Exception: - continue + async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]: + try: + async with self.db_factory() as db: + if fid: + result = await db.execute( + text(""" + SELECT role, content, created_at + FROM memory_messages + WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid + ORDER BY created_at DESC + LIMIT :limit + """), + {"uid": uid, "fid": fid, "sid": sid, "limit": limit}, + ) + else: + result = await db.execute( + text(""" + SELECT role, content, created_at + FROM memory_messages + WHERE user_id = :uid AND session_id = :sid + ORDER BY created_at DESC + LIMIT :limit + """), + {"uid": uid, "sid": sid, "limit": limit}, + ) + rows = result.fetchall() + return [ + {"role": row[0], "content": row[1], "ts": row[2].isoformat() if row[2] else ""} + for row in reversed(rows) + ] + except Exception: + return [] - return sorted(sessions, key=lambda s: s.get("last_active_at", ""), reverse=True) + async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None: + try: + cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" + raw = await self.redis.get(cache_key) + if raw: + return json.loads(raw) + except Exception: + pass + return None - async def _get_recent_messages( - self, user_id: str, flow_id: str, session_id: str, limit: int = None - ) -> list[dict]: - limit = limit or self.MAX_HISTORY - try: - key = self._msg_key(user_id, flow_id, session_id) - raw = await self.redis.lrange(key, 0, limit - 1) - result = [] - for m in raw: - try: - result.append(json.loads(m)) - except 
json.JSONDecodeError: - continue - return result + async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]): + try: + cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" + top = messages[-self.REDIS_CACHE_SIZE:] + await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(top, ensure_ascii=False)) except Exception: - return [] + pass - async def _get_summary(self, user_id: str, flow_id: str, session_id: str) -> str: + async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]): try: - key = f"{self.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary" - val = await self.redis.get(key) + cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" + existing = await self._redis_get_recent(uid, sid) or [] + combined = existing + new_msgs + recent = combined[-self.REDIS_CACHE_SIZE:] + await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(recent, ensure_ascii=False)) + except Exception: + pass + + async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str: + try: + summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" + val = await self.redis.get(summary_key) return val or "" except Exception: return "" - async def _maybe_summarize(self, user_id: str, flow_id: str, session_id: str): + async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID): try: - key = self._msg_key(user_id, flow_id, session_id) - count = await self.redis.llen(key) - if count < 30: - return - - summary_key = f"{self.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary" + summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" existing = await self.redis.get(summary_key) if existing: return - recent = await self._get_recent_messages(user_id, flow_id, session_id, 20) + count = 0 + async with self.db_factory() as db: + result = await db.execute( + text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"), + {"uid": uid, "sid": sid}, + ) + row = result.fetchone() 
+ count = row[0] if row else 0 + + if count < 30: + return + + recent = await self._pg_get_recent(uid, fid, sid, 20) dialogue = "\n".join( - f"{m['role']}: {m['content'][:500]}" for m in reversed(recent[:10]) + f"{m['role']}: {m['content'][:500]}" for m in recent[-10:] ) import httpx @@ -192,10 +341,511 @@ class MemoryManager: except Exception: pass - @staticmethod - def _msg_key(user_id: str, flow_id: str, session_id: str) -> str: - return f"mem:{user_id}:{flow_id}:{session_id}:messages" + async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]: + try: + async with self.db_factory() as db: + result = await db.execute( + text(""" + SELECT id, atom_type, content, priority, metadata + FROM memory_atoms + WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL) + ORDER BY priority DESC, updated_at DESC + LIMIT 20 + """), + {"uid": uid, "fid": fid}, + ) + rows = result.fetchall() + return [ + {"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]} + for row in rows + ] + except Exception: + return [] + + async def _get_persona(self, uid: uuid.UUID) -> dict: + try: + async with self.db_factory() as db: + result = await db.execute( + text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"), + {"uid": uid}, + ) + row = result.fetchone() + if row: + return {"content": row[0] or {}, "raw_text": row[1] or "", "version": row[2] or 1} + except Exception: + pass + return {} + + async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID): + try: + task_key = f"extract_atoms:{uid}:{fid}" + if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): + return + + count = 0 + async with self.db_factory() as db: + result = await db.execute( + text("SELECT message_count FROM memory_sessions WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"), + {"uid": uid, "fid": fid, "sid": sid}, + ) + row = result.fetchone() + count = row[0] if row else 0 + + 
last_extract = await db.execute( + text(""" + SELECT COUNT(*) FROM memory_atoms + WHERE user_id = :uid AND flow_id = :fid AND metadata->>'source_session_id' = :sid_str + """), + {"uid": uid, "fid": fid, "sid_str": str(sid)}, + ) + last_count = last_extract.fetchone()[0] + if count < self.ATOM_EXTRACT_EVERY or count <= last_count * self.ATOM_EXTRACT_EVERY: + return + + recent = await self._pg_get_recent(uid, fid, sid, 20) + dialogue = "\n".join( + f"{m['role']}: {m['content'][:400]}" for m in recent[-15:] + ) + + task = asyncio.create_task(self._do_extract_atoms(uid, fid, sid, dialogue)) + self._extract_tasks[task_key] = task + except Exception: + pass + + async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str): + try: + prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。 + +对话内容: +{dialogue} + +请以JSON数组格式返回,每个原子包含以下字段: +- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求) +- content: 原子的具体内容(一句话描述) +- priority: 0-100的整数,表示重要性(越高越核心) + +格式示例: +[ + {{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}}, + {{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}} +] + +只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。""" + + import httpx + api_base = settings.LLM_API_BASE.rstrip("/") + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.post( + f"{api_base}/chat/completions", + json={ + "model": settings.LLM_MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 800, + "temperature": 0.3, + }, + headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, + ) + data = resp.json() + result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]") + + try: + atoms = json.loads(result_text) + if not isinstance(atoms, list): + return + except (json.JSONDecodeError, ValueError): + return + + await self._dedup_and_store_atoms(uid, fid, sid, atoms) + except Exception as e: + logger.warning(f"L1原子提取失败: {e}") + + async def 
_dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]): + try: + async with self.db_factory() as db: + existing = await db.execute( + text(""" + SELECT id, content FROM memory_atoms + WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL) + """), + {"uid": uid, "fid": fid}, + ) + existing_atoms = [{"id": row[0], "content": row[1]} for row in existing.fetchall()] + + for atom in atoms: + content = atom.get("content", "").strip() + if not content: + continue + + is_dup = False + for ex in existing_atoms: + if self._text_similarity(content, ex["content"]) > 0.75: + existing_atoms.remove(ex) + await db.execute( + text(""" + UPDATE memory_atoms + SET priority = GREATEST(priority, :priority), + updated_at = NOW(), + metadata = metadata || :meta + WHERE id = :id + """), + { + "id": ex["id"], + "priority": atom.get("priority", 50), + "meta": json.dumps({"last_source_session_id": str(sid)}), + }, + ) + is_dup = True + break + + if not is_dup: + await db.execute( + text(""" + INSERT INTO memory_atoms (user_id, flow_id, atom_type, content, priority, source_session_id, metadata, created_at, updated_at) + VALUES (:uid, :fid, :atype, :content, :priority, :sid, :meta, NOW(), NOW()) + """), + { + "uid": uid, + "fid": fid, + "atype": atom.get("atom_type", "episodic"), + "content": content, + "priority": atom.get("priority", 50), + "sid": sid, + "meta": json.dumps({"source_session_id": str(sid)}), + }, + ) + + await db.commit() + except Exception as e: + logger.warning(f"L1原子存储失败: {e}") @staticmethod - def _meta_key(user_id: str, flow_id: str, session_id: str) -> str: - return f"mem:{user_id}:{flow_id}:{session_id}:meta" \ No newline at end of file + def _text_similarity(a: str, b: str) -> float: + a_words = set(a.lower().split()) + b_words = set(b.lower().split()) + if not a_words or not b_words: + return 0.0 + intersection = a_words & b_words + union = a_words | b_words + return len(intersection) / len(union) + + async def 
_maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID): + try: + task_key = f"extract_scenes:{uid}:{fid}" + if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): + return + + count = 0 + async with self.db_factory() as db: + result = await db.execute( + text("SELECT COUNT(*) FROM memory_atoms WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"), + {"uid": uid, "fid": fid}, + ) + row = result.fetchone() + count = row[0] if row else 0 + + scene_count_result = await db.execute( + text("SELECT COUNT(*) FROM memory_scenes WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"), + {"uid": uid, "fid": fid}, + ) + scene_count = scene_count_result.fetchone()[0] + + if count < self.SCENE_EXTRACT_EVERY: + return + + if scene_count > 0: + atoms_result = await db.execute( + text(""" + SELECT updated_at FROM memory_scenes + WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL) + ORDER BY updated_at DESC LIMIT 1 + """), + {"uid": uid, "fid": fid}, + ) + latest_scene = atoms_result.fetchone() + if latest_scene: + from datetime import timezone, timedelta + ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone.utc) + if ago < timedelta(hours=12): + return + + atoms_result = await db.execute( + text(""" + SELECT atom_type, content, priority FROM memory_atoms + WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL) + ORDER BY priority DESC LIMIT 30 + """), + {"uid": uid, "fid": fid}, + ) + atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in atoms_result.fetchall()] + + if len(atoms) < 10: + return + + task = asyncio.create_task(self._do_extract_scenes(uid, fid, atoms)) + self._extract_tasks[task_key] = task + except Exception: + pass + + async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]): + try: + atoms_text = "\n".join( + f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms + ) + + prompt = 
f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。 + +记忆原子: +{atoms_text} + +请以JSON数组格式返回,每个场景块包含: +- scene_name: 场景名称(简短标签) +- summary: 场景摘要(2-3句话总结该场景的关键信息) + +格式示例: +[ + {{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}} +] + +只返回JSON数组,不要其他内容。""" + + import httpx + api_base = settings.LLM_API_BASE.rstrip("/") + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.post( + f"{api_base}/chat/completions", + json={ + "model": settings.LLM_MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 500, + "temperature": 0.3, + }, + headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, + ) + data = resp.json() + result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]") + + try: + scenes = json.loads(result_text) + if not isinstance(scenes, list): + return + except (json.JSONDecodeError, ValueError): + return + + async with self.db_factory() as db: + for scene in scenes: + await db.execute( + text(""" + INSERT INTO memory_scenes (user_id, flow_id, scene_name, summary, heat, content, created_at, updated_at) + VALUES (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW()) + """), + { + "uid": uid, + "fid": fid, + "name": scene.get("scene_name", ""), + "summary": scene.get("summary", ""), + "content": json.dumps(scene, ensure_ascii=False), + }, + ) + await db.commit() + except Exception as e: + logger.warning(f"L2场景提取失败: {e}") + + async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID): + try: + task_key = f"update_persona:{uid}" + if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): + return + + msg_count = 0 + async with self.db_factory() as db: + result = await db.execute( + text("SELECT message_count FROM memory_sessions WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"), + {"uid": uid}, + ) + row = result.fetchone() + msg_count = row[0] if row else 0 + + persona_result = await db.execute( + text("SELECT version, updated_at FROM 
memory_personas WHERE user_id = :uid"), + {"uid": uid}, + ) + persona_row = persona_result.fetchone() + if persona_row: + from datetime import timezone, timedelta + ago = datetime.now(timezone.utc) - persona_row[1].replace(tzinfo=timezone.utc) + if ago < timedelta(hours=6): + return + + if msg_count < self.PERSONA_UPDATE_EVERY: + return + + persona_atoms = [] + async with self.db_factory() as db: + atoms_result = await db.execute( + text(""" + SELECT content FROM memory_atoms + WHERE user_id = :uid AND atom_type = 'persona' + ORDER BY priority DESC, updated_at DESC LIMIT 30 + """), + {"uid": uid}, + ) + persona_atoms = [r[0] for r in atoms_result.fetchall()] + + persona_text = "\n".join(persona_atoms[:15]) if persona_atoms else "暂无用户信息" + + task = asyncio.create_task(self._do_update_persona(uid, persona_text, persona_row.version + 1 if persona_row else 1)) + self._extract_tasks[task_key] = task + except Exception: + pass + + async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int): + try: + prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。 + +用户信息: +{persona_text} + +请返回一个JSON对象,包含以下字段: +- name: 用户姓名(未知则"未知") +- role: 角色/职位 +- preferences: 偏好/习惯列表 +- skills: 技能/专长列表 +- personality: 性格描述 +- raw_text: 综合画像的文本描述(2-3句话) + +只返回JSON对象,不要其他内容。""" + + import httpx + api_base = settings.LLM_API_BASE.rstrip("/") + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.post( + f"{api_base}/chat/completions", + json={ + "model": settings.LLM_MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 500, + "temperature": 0.3, + }, + headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, + ) + data = resp.json() + result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "{}") + + try: + persona = json.loads(result_text) + except (json.JSONDecodeError, ValueError): + persona = {"raw_text": result_text} + + async with self.db_factory() as db: + existing = await db.execute( + text("SELECT id FROM 
memory_personas WHERE user_id = :uid"), + {"uid": uid}, + ) + if existing.fetchone(): + await db.execute( + text(""" + UPDATE memory_personas + SET content = :content, raw_text = :raw_text, version = :version, updated_at = NOW() + WHERE user_id = :uid + """), + { + "uid": uid, + "content": json.dumps(persona, ensure_ascii=False), + "raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)), + "version": version, + }, + ) + else: + await db.execute( + text(""" + INSERT INTO memory_personas (user_id, content, raw_text, version, updated_at) + VALUES (:uid, :content, :raw_text, :version, NOW()) + """), + { + "uid": uid, + "content": json.dumps(persona, ensure_ascii=False), + "raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)), + "version": version, + }, + ) + await db.commit() + except Exception as e: + logger.warning(f"L3画像更新失败: {e}") + + async def search_memory( + self, + uid: uuid.UUID, + query: str, + fid: uuid.UUID = None, + top_k: int = 10, + embedding: list[float] = None, + ) -> list[dict]: + results = [] + try: + async with self.db_factory() as db: + if embedding and len(embedding) > 0: + emb_str = "[" + ",".join(str(v) for v in embedding) + "]" + vector_results = await db.execute( + text(f""" + SELECT id, atom_type, content, priority, + (1.0 - (embedding <=> CAST(:emb AS vector))) AS similarity + FROM memory_atoms + WHERE user_id = :uid + {'AND (flow_id = :fid OR flow_id IS NULL)' if fid else ''} + AND embedding IS NOT NULL + ORDER BY embedding <=> CAST(:emb AS vector) + LIMIT :limit + """), + {"uid": uid, "fid": fid, "emb": emb_str, "limit": top_k * 2} if fid + else {"uid": uid, "emb": emb_str, "limit": top_k * 2}, + ) + for row in vector_results.fetchall(): + results.append({ + "id": str(row[0]), "type": row[1], "content": row[2], + "priority": row[3], "vector_score": float(row[4]), + }) + + text_results = await db.execute( + text(f""" + SELECT id, atom_type, content, priority, + ts_rank(content_tsv, plainto_tsquery('simple', 
:query)) AS text_score + FROM memory_atoms + WHERE user_id = :uid + {'AND (flow_id = :fid OR flow_id IS NULL)' if fid else ''} + AND content_tsv @@ plainto_tsquery('simple', :query) + ORDER BY text_score DESC + LIMIT :limit + """), + {"uid": uid, "fid": fid, "query": query, "limit": top_k * 2} if fid + else {"uid": uid, "query": query, "limit": top_k * 2}, + ) + for row in text_results.fetchall(): + results.append({ + "id": str(row[0]), "type": row[1], "content": row[2], + "priority": row[3], "text_score": float(row[4]), + }) + + rrf_scores: dict[str, float] = {} + k = 60 + for rank, r in enumerate(sorted(results, key=lambda x: x.get("vector_score", 0), reverse=True)): + if "vector_score" in r: + rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1) + for rank, r in enumerate(sorted(results, key=lambda x: x.get("text_score", 0), reverse=True)): + if "text_score" in r: + rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1) + + seen = set() + merged = [] + for rid, score in sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True): + for r in results: + if r["id"] == rid and rid not in seen: + seen.add(rid) + r["rrf_score"] = score + merged.append(r) + break + + merged.sort(key=lambda x: x.get("rrf_score", 0), reverse=True) + return merged[:top_k] + except Exception as e: + logger.warning(f"混合检索失败: {e}") + return results[:top_k] if results else [] \ No newline at end of file diff --git a/backend/modules/memory/router.py b/backend/modules/memory/router.py index 3ca3af2..dc5311a 100644 --- a/backend/modules/memory/router.py +++ b/backend/modules/memory/router.py @@ -13,7 +13,7 @@ async def list_sessions(request: Request, user=Depends(get_current_user)): @router.get("/sessions/{session_id}") -async def get_session(session_id: str, flow_id: str = "", request: Request, user=Depends(get_current_user)): +async def get_session(session_id: str, request: Request, flow_id: str = "", user=Depends(get_current_user)): mm = 
get_memory_manager() history = await mm.get_conversation_history( user_id=str(user.id), diff --git a/backend/modules/model_provider/__init__.py b/backend/modules/model_provider/__init__.py new file mode 100644 index 0000000..43a7103 --- /dev/null +++ b/backend/modules/model_provider/__init__.py @@ -0,0 +1,3 @@ +from .router import router + +__all__ = ["router"] \ No newline at end of file diff --git a/backend/modules/model_provider/router.py b/backend/modules/model_provider/router.py new file mode 100644 index 0000000..b1c768f --- /dev/null +++ b/backend/modules/model_provider/router.py @@ -0,0 +1,181 @@ +import uuid +import logging +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from database import get_db +from models import ModelProvider, ModelInstance +from dependencies import get_current_user + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"]) + + +@router.get("") +async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + result = await db.execute( + select(ModelProvider).order_by(ModelProvider.created_at.desc()) + ) + return { + "code": 200, + "data": [ + { + "id": str(p.id), + "name": p.name, + "provider_type": p.provider_type, + "base_url": p.base_url, + "is_active": p.is_active, + "created_at": p.created_at.isoformat() if p.created_at else "", + } + for p in result.scalars().all() + ], + } + + +@router.post("") +async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + existing = await db.execute( + select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", "")) + ) + if existing.scalars().first(): + raise HTTPException(400, "相同 base_url 的供应商已存在") + + p = ModelProvider( + name=payload["name"], + provider_type=payload["provider_type"], + base_url=payload.get("base_url", ""), + api_key=payload.get("api_key", 
""), + extra_config=payload.get("extra_config", {}), + ) + db.add(p) + await db.commit() + return {"code": 200, "data": {"id": str(p.id)}} + + +@router.put("/{provider_id}") +async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + p = await db.get(ModelProvider, uuid.UUID(provider_id)) + if not p: + raise HTTPException(404, "供应商不存在") + + p.name = payload.get("name", p.name) + p.base_url = payload.get("base_url", p.base_url) + p.api_key = payload.get("api_key", p.api_key) + p.provider_type = payload.get("provider_type", p.provider_type) + p.extra_config = payload.get("extra_config", p.extra_config) + p.is_active = payload.get("is_active", p.is_active) + await db.commit() + return {"code": 200, "data": {"id": str(p.id)}} + + +@router.delete("/{provider_id}") +async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + p = await db.get(ModelProvider, uuid.UUID(provider_id)) + if not p: + raise HTTPException(404, "供应商不存在") + await db.delete(p) + await db.commit() + return {"code": 200, "message": "已删除"} + + +@router.get("/models/all") +async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + result = await db.execute( + select(ModelInstance) + .where(ModelInstance.is_active == True) + .order_by(ModelInstance.model_type, ModelInstance.model_name) + ) + return { + "code": 200, + "data": [ + { + "id": str(m.id), + "model_name": m.model_name, + "model_type": m.model_type, + "display_name": m.display_name, + } + for m in result.scalars().all() + ], + } + + +@router.get("/{provider_id}/models") +async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + result = await db.execute( + select(ModelInstance) + .where(ModelInstance.provider_id == uuid.UUID(provider_id)) + .order_by(ModelInstance.model_type, ModelInstance.model_name) + ) + return { + "code": 200, + 
"data": [ + { + "id": str(m.id), + "provider_id": str(m.provider_id), + "model_name": m.model_name, + "model_type": m.model_type, + "display_name": m.display_name, + "capabilities": m.capabilities, + "default_params": m.default_params, + "is_default": m.is_default, + "is_active": m.is_active, + } + for m in result.scalars().all() + ], + } + + +@router.post("/{provider_id}/models") +async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + p = await db.get(ModelProvider, uuid.UUID(provider_id)) + if not p: + raise HTTPException(404, "供应商不存在") + + existing = await db.execute( + select(ModelInstance).where( + ModelInstance.provider_id == uuid.UUID(provider_id), + ModelInstance.model_name == payload["model_name"], + ) + ) + if existing.scalars().first(): + raise HTTPException(400, "相同名称的模型已存在") + + m = ModelInstance( + provider_id=uuid.UUID(provider_id), + model_name=payload["model_name"], + model_type=payload["model_type"], + display_name=payload.get("display_name", payload["model_name"]), + capabilities=payload.get("capabilities", {}), + default_params=payload.get("default_params", {}), + is_default=payload.get("is_default", False), + ) + db.add(m) + await db.commit() + return {"code": 200, "data": {"id": str(m.id)}} + + +@router.put("/{provider_id}/models/{model_id}") +async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + m = await db.get(ModelInstance, uuid.UUID(model_id)) + if not m or str(m.provider_id) != provider_id: + raise HTTPException(404, "模型不存在") + + m.model_name = payload.get("model_name", m.model_name) + m.model_type = payload.get("model_type", m.model_type) + m.display_name = payload.get("display_name", m.display_name) + m.capabilities = payload.get("capabilities", m.capabilities) + m.default_params = payload.get("default_params", m.default_params) + m.is_default = payload.get("is_default", 
m.is_default) + m.is_active = payload.get("is_active", m.is_active) + await db.commit() + return {"code": 200, "data": {"id": str(m.id)}} + + +@router.delete("/{provider_id}/models/{model_id}") +async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + m = await db.get(ModelInstance, uuid.UUID(model_id)) + if not m or str(m.provider_id) != provider_id: + raise HTTPException(404, "模型不存在") + await db.delete(m) + await db.commit() + return {"code": 200, "message": "已删除"} \ No newline at end of file diff --git a/backend/schemas/__init__.py b/backend/schemas/__init__.py index 62bd5d1..8fc3a32 100644 --- a/backend/schemas/__init__.py +++ b/backend/schemas/__init__.py @@ -275,6 +275,7 @@ class FlowDefinitionCreate(BaseModel): trigger: dict = {} nodes: list[FlowNode] edges: list[FlowEdge] + flow_mode: str = "chatflow" class FlowDefinitionUpdate(BaseModel): @@ -295,6 +296,7 @@ class FlowDefinitionOut(BaseModel): published_version_id: uuid.UUID | None = None published_to_wecom: bool published_to_web: bool = False + flow_mode: str = "chatflow" created_at: datetime | None = None updated_at: datetime | None = None diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index 7ee4c75..d1570e0 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -3,6 +3,19 @@ services: image: postgres:16-alpine container_name: ent-postgres restart: always + entrypoint: > + sh -c ' + if [ ! -f /usr/local/share/postgresql/extension/vector.control ]; then + echo ">>> pgvector 未安装,正在编译安装(仅首次约 30 秒)..." 
+ apk add --no-cache --virtual .pg_build git build-base postgresql16-dev clang llvm-dev > /dev/null 2>&1 + wget -qO- https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz | tar xz -C /tmp + cd /tmp/pgvector-0.7.4 && make -j2 > /dev/null 2>&1 && make install > /dev/null 2>&1 + cd / && rm -rf /tmp/pgvector-0.7.4 + apk del .pg_build > /dev/null 2>&1 + echo ">>> pgvector 安装完成" + fi + exec docker-entrypoint.sh postgres + ' environment: POSTGRES_USER: enterprise POSTGRES_PASSWORD: enterprise123 diff --git a/docker-compose.yml b/docker-compose.yml index ae8201e..9c2dfba 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,19 @@ services: image: postgres:16-alpine container_name: ent-postgres restart: always + entrypoint: > + sh -c ' + if [ ! -f /usr/local/share/postgresql/extension/vector.control ]; then + echo ">>> pgvector 未安装,正在编译安装(仅首次约 30 秒)..." + apk add --no-cache --virtual .pg_build git build-base postgresql16-dev clang llvm-dev > /dev/null 2>&1 + wget -qO- https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz | tar xz -C /tmp + cd /tmp/pgvector-0.7.4 && make -j2 > /dev/null 2>&1 && make install > /dev/null 2>&1 + cd / && rm -rf /tmp/pgvector-0.7.4 + apk del .pg_build > /dev/null 2>&1 + echo ">>> pgvector 安装完成" + fi + exec docker-entrypoint.sh postgres + ' environment: POSTGRES_USER: enterprise POSTGRES_PASSWORD: enterprise123 diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index f53f3ef..a200a17 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -89,6 +89,8 @@ export const flowApi = { getMarket: () => api.get('/flow/market'), getTemplates: () => api.get('/flow/templates'), useTemplate: (id: string) => api.post(`/flow/templates/${id}/use`), + getVersions: (id: string) => api.get(`/flow/definitions/${id}/versions`), + rollbackVersion: (flowId: string, versionId: string) => api.post(`/flow/definitions/${flowId}/rollback/${versionId}`), } export const chatApi = { @@ -191,4 
+193,17 @@ export const ragApi = { getDocuments: () => api.get('/rag/documents'), deleteDocument: (source: string) => api.delete(`/rag/documents/${encodeURIComponent(source)}`), getStats: () => api.get('/rag/stats'), +} + +export const modelProviderApi = { + getProviders: () => api.get('/model-providers'), + getProvider: (id: string) => api.get(`/model-providers/${id}`), + createProvider: (data: any) => api.post('/model-providers', data), + updateProvider: (id: string, data: any) => api.put(`/model-providers/${id}`, data), + deleteProvider: (id: string) => api.delete(`/model-providers/${id}`), + getModels: (providerId: string) => api.get(`/model-providers/${providerId}/models`), + getAllModels: () => api.get('/model-providers/models/all'), + createModel: (providerId: string, data: any) => api.post(`/model-providers/${providerId}/models`, data), + updateModel: (providerId: string, modelId: string, data: any) => api.put(`/model-providers/${providerId}/models/${modelId}`, data), + deleteModel: (providerId: string, modelId: string) => api.delete(`/model-providers/${providerId}/models/${modelId}`), } \ No newline at end of file diff --git a/frontend/src/components/layout/AdminLayout.vue b/frontend/src/components/layout/AdminLayout.vue index bcd5c39..22f3cd5 100644 --- a/frontend/src/components/layout/AdminLayout.vue +++ b/frontend/src/components/layout/AdminLayout.vue @@ -52,6 +52,7 @@ AI能力配置 + 模型供应商管理 知识库管理 智能体管理 企微机器人配置 @@ -133,6 +134,7 @@ const activeMenu = computed(() => { if (path.startsWith('/admin/rag')) return path if (path.startsWith('/admin/agent')) return path if (path.startsWith('/admin/wecom')) return path + if (path.startsWith('/admin/model')) return path if (path.startsWith('/admin/task')) return path if (path.startsWith('/admin/monitor')) return path if (path.startsWith('/admin/audit')) return path diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 0177566..adfea0a 100644 --- a/frontend/src/router/index.ts +++ 
b/frontend/src/router/index.ts @@ -167,6 +167,12 @@ const router = createRouter({ component: () => import('@/views/wecom/BotConfig.vue'), meta: { title: '企微机器人配置', perms: ['admin:access'] }, }, + { + path: 'model/providers', + name: 'AdminModelProviders', + component: () => import('@/views/model/ModelProviderManager.vue'), + meta: { title: '模型供应商管理', perms: ['admin:access'] }, + }, { path: 'agent/list', name: 'AdminAgentList', diff --git a/frontend/src/views/flow/FlowEditor.vue b/frontend/src/views/flow/FlowEditor.vue index bb14428..7710afa 100644 --- a/frontend/src/views/flow/FlowEditor.vue +++ b/frontend/src/views/flow/FlowEditor.vue @@ -8,10 +8,15 @@
+ + + + 保存 验证 上架到企微 上架到网页 + 版本历史
@@ -85,15 +90,42 @@ + + + + + + + + + + + + + + +
+ 暂无版本记录 +
+
\ No newline at end of file diff --git a/frontend/src/views/flow/node-configs/IterationConfig.vue b/frontend/src/views/flow/node-configs/IterationConfig.vue new file mode 100644 index 0000000..80f6b57 --- /dev/null +++ b/frontend/src/views/flow/node-configs/IterationConfig.vue @@ -0,0 +1,80 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/flow/node-configs/MergeConfig.vue b/frontend/src/views/flow/node-configs/MergeConfig.vue new file mode 100644 index 0000000..b4b138d --- /dev/null +++ b/frontend/src/views/flow/node-configs/MergeConfig.vue @@ -0,0 +1,24 @@ + + + diff --git a/frontend/src/views/flow/node-configs/QuestionClassifierConfig.vue b/frontend/src/views/flow/node-configs/QuestionClassifierConfig.vue new file mode 100644 index 0000000..3f22ec6 --- /dev/null +++ b/frontend/src/views/flow/node-configs/QuestionClassifierConfig.vue @@ -0,0 +1,120 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/flow/node-configs/QuestionOptimiserConfig.vue b/frontend/src/views/flow/node-configs/QuestionOptimiserConfig.vue new file mode 100644 index 0000000..0a2980b --- /dev/null +++ b/frontend/src/views/flow/node-configs/QuestionOptimiserConfig.vue @@ -0,0 +1,81 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/flow/node-configs/TemplateTransformConfig.vue b/frontend/src/views/flow/node-configs/TemplateTransformConfig.vue new file mode 100644 index 0000000..a83f0d2 --- /dev/null +++ b/frontend/src/views/flow/node-configs/TemplateTransformConfig.vue @@ -0,0 +1,94 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/flow/node-configs/VariableAssignerConfig.vue b/frontend/src/views/flow/node-configs/VariableAssignerConfig.vue new file mode 100644 index 0000000..f3b7145 --- /dev/null +++ b/frontend/src/views/flow/node-configs/VariableAssignerConfig.vue @@ -0,0 +1,123 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/views/model/ModelProviderManager.vue 
b/frontend/src/views/model/ModelProviderManager.vue new file mode 100644 index 0000000..ec7ee85 --- /dev/null +++ b/frontend/src/views/model/ModelProviderManager.vue @@ -0,0 +1,338 @@ + + + \ No newline at end of file diff --git a/init-db/02-add-published-cols.sql b/init-db/02-add-published-cols.sql new file mode 100644 index 0000000..f649478 --- /dev/null +++ b/init-db/02-add-published-cols.sql @@ -0,0 +1,6 @@ +-- 02-migration: 补充 flow_definitions 缺失列 +-- 幂等执行,IF NOT EXISTS 重复运行不会报错 + +ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_to_wecom BOOLEAN DEFAULT FALSE; + +ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_to_web BOOLEAN DEFAULT FALSE; diff --git a/init-db/03-memory-tables.sql b/init-db/03-memory-tables.sql new file mode 100644 index 0000000..955ff4f --- /dev/null +++ b/init-db/03-memory-tables.sql @@ -0,0 +1,88 @@ +-- 03-memory-tables.sql +-- 记忆管理模块:PostgreSQL 主存储 + Redis 缓存层 +-- 幂等执行,IF NOT EXISTS + +-- pgvector 扩展(容器启动时已通过 docker-compose entrypoint 自动安装) +CREATE EXTENSION IF NOT EXISTS vector; + +-- ============================================================ +-- L0: 原始对话消息 +-- ============================================================ +CREATE TABLE IF NOT EXISTS memory_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + flow_id UUID NOT NULL REFERENCES flow_definitions(id) ON DELETE CASCADE, + session_id UUID NOT NULL, + role VARCHAR(20) NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_memory_messages_session ON memory_messages(user_id, flow_id, session_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_memory_messages_user_flow ON memory_messages(user_id, flow_id, created_at DESC); + +-- ============================================================ +-- L1: 结构化记忆原子 +-- ============================================================ +CREATE TABLE IF NOT EXISTS memory_atoms ( + 
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + flow_id UUID REFERENCES flow_definitions(id) ON DELETE SET NULL, + atom_type VARCHAR(20) NOT NULL, + content TEXT NOT NULL, + priority SMALLINT DEFAULT 50, + source_session_id UUID, + metadata JSONB DEFAULT '{}', + embedding vector(1536), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_memory_atoms_user ON memory_atoms(user_id, atom_type, priority DESC); +CREATE INDEX IF NOT EXISTS idx_memory_atoms_embedding ON memory_atoms USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); + +-- ============================================================ +-- L2: 场景块 +-- ============================================================ +CREATE TABLE IF NOT EXISTS memory_scenes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + flow_id UUID REFERENCES flow_definitions(id) ON DELETE SET NULL, + scene_name VARCHAR(200) NOT NULL, + summary TEXT NOT NULL, + heat INTEGER DEFAULT 0, + content JSONB DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_memory_scenes_user ON memory_scenes(user_id, flow_id, heat DESC); + +-- ============================================================ +-- L3: 用户画像 +-- ============================================================ +CREATE TABLE IF NOT EXISTS memory_personas ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL UNIQUE REFERENCES users(id) ON DELETE CASCADE, + content JSONB NOT NULL DEFAULT '{}', + raw_text TEXT DEFAULT '', + version INTEGER DEFAULT 1, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- ============================================================ +-- 会话元数据 +-- ============================================================ +CREATE TABLE IF NOT 
EXISTS memory_sessions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + flow_id UUID NOT NULL REFERENCES flow_definitions(id) ON DELETE CASCADE, + session_id UUID NOT NULL, + flow_name VARCHAR(200) DEFAULT '', + message_count INTEGER DEFAULT 0, + last_active_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(user_id, flow_id, session_id) +); + +CREATE INDEX IF NOT EXISTS idx_memory_sessions_user ON memory_sessions(user_id, last_active_at DESC); \ No newline at end of file diff --git a/init-db/04-model-provider.sql b/init-db/04-model-provider.sql new file mode 100644 index 0000000..fa7f7bb --- /dev/null +++ b/init-db/04-model-provider.sql @@ -0,0 +1,30 @@ +-- 04-model-provider.sql +-- OpenAI-API-Compatible 模型供应商管理 +-- 支持 LLM / Embedding / Rerank 三种模型类型,每种独立配置 + +CREATE TABLE IF NOT EXISTS model_providers ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name VARCHAR(100) NOT NULL, + provider_type VARCHAR(50) NOT NULL, + base_url VARCHAR(500), + api_key TEXT, + extra_config JSONB DEFAULT '{}', + is_active BOOLEAN DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS model_instances ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + provider_id UUID NOT NULL REFERENCES model_providers(id) ON DELETE CASCADE, + model_name VARCHAR(100) NOT NULL, + model_type VARCHAR(30) NOT NULL, + display_name VARCHAR(200), + capabilities JSONB DEFAULT '{}', + default_params JSONB DEFAULT '{}', + is_default BOOLEAN DEFAULT FALSE, + is_active BOOLEAN DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_model_instances_type ON model_instances(model_type, is_active); +CREATE INDEX IF NOT EXISTS idx_model_instances_provider ON model_instances(provider_id); \ No newline at end of file diff --git a/init-db/05-flow-mode.sql b/init-db/05-flow-mode.sql new file mode 100644 index 
0000000..842eb42 --- /dev/null +++ b/init-db/05-flow-mode.sql @@ -0,0 +1,7 @@ +-- 05-flow-mode.sql +-- Add flow_mode column to FlowDefinition to distinguish chat-type flows from workflow-type flows + +ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS flow_mode VARCHAR(20) DEFAULT 'chatflow'; + +-- Existing rows default to chatflow (chat type) for backward compatibility +UPDATE flow_definitions SET flow_mode = 'chatflow' WHERE flow_mode IS NULL; \ No newline at end of file diff --git a/init-db/06-agent-model-ids.sql b/init-db/06-agent-model-ids.sql new file mode 100644 index 0000000..ddddbf4 --- /dev/null +++ b/init-db/06-agent-model-ids.sql @@ -0,0 +1,7 @@ +-- 06-agent-model-ids.sql +-- Add optional foreign keys model_instance_id and embedding_model_id to AgentConfig +-- nullable=True: existing rows get NULL automatically, fully backward compatible + +ALTER TABLE agent_configs ADD COLUMN IF NOT EXISTS model_instance_id UUID REFERENCES model_instances(id); + +ALTER TABLE agent_configs ADD COLUMN IF NOT EXISTS embedding_model_id UUID REFERENCES model_instances(id); \ No newline at end of file diff --git a/init-db/07-hybrid-search.sql b/init-db/07-hybrid-search.sql new file mode 100644 index 0000000..e6aec66 --- /dev/null +++ b/init-db/07-hybrid-search.sql @@ -0,0 +1,27 @@ +-- Add a full-text search index to memory_atoms to support hybrid retrieval +CREATE EXTENSION IF NOT EXISTS vector; + +ALTER TABLE memory_atoms ADD COLUMN IF NOT EXISTS content_tsv tsvector; + +CREATE INDEX IF NOT EXISTS idx_memory_atoms_content_tsv + ON memory_atoms USING GIN (content_tsv); + +CREATE OR REPLACE FUNCTION update_memory_atoms_tsv() RETURNS trigger AS $$ +BEGIN + NEW.content_tsv := to_tsvector('simple', COALESCE(NEW.content, '')); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trg_memory_atoms_tsv' + ) THEN + CREATE TRIGGER trg_memory_atoms_tsv + BEFORE INSERT OR UPDATE ON memory_atoms + FOR EACH ROW EXECUTE FUNCTION update_memory_atoms_tsv(); + END IF; +END $$; + +UPDATE memory_atoms SET content_tsv = to_tsvector('simple', COALESCE(content, '')); \ No newline at end of file diff --git
a/init-db/08-flow-templates.sql b/init-db/08-flow-templates.sql new file mode 100644 index 0000000..0ee08f0 --- /dev/null +++ b/init-db/08-flow-templates.sql @@ -0,0 +1,18 @@ +-- Flow template system +CREATE TABLE IF NOT EXISTS flow_templates ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name VARCHAR(200) NOT NULL, + description TEXT, + category VARCHAR(50), + definition_json JSONB NOT NULL DEFAULT '{}', + icon VARCHAR(50), + sort_order INTEGER DEFAULT 0, + is_builtin BOOLEAN DEFAULT false, + usage_count INTEGER DEFAULT 0, + created_by UUID REFERENCES users(id), + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_flow_templates_category ON flow_templates(category); +CREATE INDEX IF NOT EXISTS idx_flow_templates_sort ON flow_templates(sort_order); \ No newline at end of file diff --git a/scripts/migrate_memory_redis_to_pg.py b/scripts/migrate_memory_redis_to_pg.py new file mode 100644 index 0000000..2d2b341 --- /dev/null +++ b/scripts/migrate_memory_redis_to_pg.py @@ -0,0 +1,300 @@ +"""Redis → PostgreSQL memory data migration script + +Migrates the memory data currently held in Redis (messages, summaries, session lists) into the PostgreSQL memory_messages, memory_atoms and memory_sessions tables. + +Usage: + python scripts/migrate_memory_redis_to_pg.py [--dry-run] [--user-id UUID] + +Options: + --dry-run scan only, write nothing + --user-id migrate only the given user's data + --batch-size rows written per batch (default 100) +""" + +import asyncio +import json +import uuid +import sys +import os +from datetime import datetime, timezone + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend")) + +from redis.asyncio import Redis +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from config import settings + + +async def scan_redis_keys(redis: Redis, pattern: str) -> list[str]: + keys = [] + cursor = 0 + while True: + cursor, batch = await redis.scan(cursor, match=pattern, count=200) + keys.extend(batch) + if cursor == 0: + break + return keys + + +def parse_key_info(key: str) -> dict | None: +
"""Parse user_id, flow_id and session_id from a Redis key. + + Returns None when the key contains no UUID segment. + """ + parts = key.split(":") + info = {} + + # Collect every key segment that is a valid UUID, in key order. + found = [] + for p in parts: + try: + found.append(uuid.UUID(p)) + except ValueError: + continue + + if not found: + return None + + info["user_id"] = found[0] + # Old-format keys are mem:{uid}:{fid}:{sid}:messages, so with three UUIDs + # the order is user, flow, session. Keys with only two UUIDs carry + # user and session and have no flow segment. + if len(found) >= 3: + info["flow_id"] = found[1] + info["session_id"] = found[2] + elif len(found) == 2: + info["session_id"] = found[1] + + return info + + +async def migrate_old_format_keys(redis: Redis, engine, dry_run: bool): + """Migrate old-format keys: mem:{uid}:{fid}:{sid}:messages (List) and mem:{uid}:{fid}:{sid}:meta (Hash)""" + print(">>> 扫描旧格式 Redis 记忆键 ...") + keys = await scan_redis_keys(redis, "mem:*:messages") + keys += await scan_redis_keys(redis, "mem:*:meta") + print(f" 找到 {len(keys)} 个旧格式键") + + migrated = 0 + session_info = {} + + for key in keys: + key_type = await redis.type(key) + info = parse_key_info(key) + if not info: + print(f" [跳过] 无法解析键: {key}") + continue + + uid = info.get("user_id") + sid = info.get("session_id") + + if key_type == "hash": + meta = await redis.hgetall(key) + session_info[str(sid)] = meta + elif key_type == "list": + messages = await redis.lrange(key, 0, -1) + for raw in messages: + try: + msg = json.loads(raw) + role = msg.get("role", "user") + content = msg.get("content", "") + ts_str = msg.get("ts", msg.get("timestamp", "")) + created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.now(timezone.utc) + + if not dry_run: + async with AsyncSession(engine) as session: + await session.execute( + text(""" + INSERT INTO memory_messages (user_id, flow_id, session_id, role, content, created_at) + VALUES (:uid, :fid, :sid, :role, :content, :ts) + ON CONFLICT DO NOTHING + """), + { + "uid": uid, + "fid": info.get("flow_id"), + "sid": sid, + "role": role, + "content": content, + "ts": created_at, + }, + ) + await
session.commit() + migrated += 1 + except (json.JSONDecodeError, Exception) as e: + print(f" [错误] 解析消息失败: {raw[:80]}... -> {e}") + + return migrated + + +async def migrate_new_format_cache(redis: Redis, engine, dry_run: bool): + """迁移新格式缓存: mem:cache:msgs:{uid}:{sid} (String JSON array)""" + print(">>> 扫描新格式 Redis 消息缓存 ...") + keys = await scan_redis_keys(redis, "mem:cache:msgs:*") + print(f" 找到 {len(keys)} 个缓存键") + + migrated = 0 + for key in keys: + parts = key.split(":") + if len(parts) < 5: + continue + try: + uid = uuid.UUID(parts[3]) + sid = uuid.UUID(parts[4]) + except (ValueError, IndexError): + continue + + raw = await redis.get(key) + if not raw: + continue + try: + messages = json.loads(raw) + except json.JSONDecodeError: + continue + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + ts_str = msg.get("ts", "") + created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow() + + if not dry_run: + async with AsyncSession(engine) as session: + await session.execute( + text(""" + INSERT INTO memory_messages (user_id, session_id, role, content, created_at) + VALUES (:uid, :sid, :role, :content, :ts) + ON CONFLICT DO NOTHING + """), + {"uid": uid, "sid": sid, "role": role, "content": content, "ts": created_at}, + ) + await session.commit() + migrated += 1 + + return migrated + + +async def migrate_summaries(redis: Redis, engine, dry_run: bool): + """迁移 Redis 摘要到 PostgreSQL memory_atoms""" + print(">>> 扫描 Redis 摘要缓存 ...") + keys = await scan_redis_keys(redis, "mem:summary:*") + print(f" 找到 {len(keys)} 个摘要键") + + migrated = 0 + for key in keys: + parts = key.split(":") + if len(parts) < 4: + continue + try: + uid = uuid.UUID(parts[2]) + sid = uuid.UUID(parts[3]) + except (ValueError, IndexError): + continue + + summary = await redis.get(key) + if not summary or len(summary.strip()) < 10: + continue + + if not dry_run: + async with AsyncSession(engine) as session: + result = await session.execute( + 
text("SELECT id FROM memory_atoms WHERE user_id = :uid AND source_session_id = :sid AND atom_type = 'summary'"), + {"uid": uid, "sid": sid}, + ) + existing = result.fetchone() + if existing: + await session.execute( + text("UPDATE memory_atoms SET content = :content, updated_at = NOW() WHERE id = :id"), + {"content": summary, "id": existing[0]}, + ) + else: + await session.execute( + text(""" + INSERT INTO memory_atoms (user_id, atom_type, content, priority, source_session_id, created_at, updated_at) + VALUES (:uid, 'summary', :content, 60, :sid, NOW(), NOW()) + """), + {"uid": uid, "content": summary, "sid": sid}, + ) + await session.commit() + migrated += 1 + + return migrated + + +async def migrate_session_list(redis: Redis, engine, dry_run: bool): + """迁移 mem:{uid}:sessions Set 中的会话列表""" + print(">>> 扫描会话列表 (mem:*:sessions)...") + keys = await scan_redis_keys(redis, "mem:*:sessions") + print(f" 找到 {len(keys)} 个会话列表键") + + migrated = 0 + for key in keys: + parts = key.split(":") + if len(parts) < 3: + continue + try: + uid = uuid.UUID(parts[1]) + except ValueError: + continue + + sessions = await redis.smembers(key) + for sid_str in sessions: + try: + sid = uuid.UUID(sid_str) + except ValueError: + continue + + if not dry_run: + async with AsyncSession(engine) as session: + await session.execute( + text(""" + INSERT INTO memory_sessions (user_id, session_id, last_active_at, created_at) + VALUES (:uid, :sid, NOW(), NOW()) + ON CONFLICT (user_id, session_id) DO NOTHING + """), + {"uid": uid, "sid": sid}, + ) + await session.commit() + migrated += 1 + + return migrated + + +async def main(): + dry_run = "--dry-run" in sys.argv + batch_size = 100 + + for i, arg in enumerate(sys.argv): + if arg == "--batch-size" and i + 1 < len(sys.argv): + batch_size = int(sys.argv[i + 1]) + + print("=" * 60) + print("Redis → PostgreSQL 记忆数据迁移") + print(f"模式: {'试运行(不写入)' if dry_run else '正式迁移'}") + print(f"数据库: {settings.DATABASE_URL}") + print(f"Redis: {settings.REDIS_URL}") + 
print("=" * 60) + + engine = create_async_engine(settings.DATABASE_URL, echo=False) + redis = Redis.from_url(settings.REDIS_URL, decode_responses=True) + + try: + await redis.ping() + print("Redis 连接成功") + + total = 0 + total += await migrate_old_format_keys(redis, engine, dry_run) + total += await migrate_new_format_cache(redis, engine, dry_run) + total += await migrate_summaries(redis, engine, dry_run) + total += await migrate_session_list(redis, engine, dry_run) + + print(f"\n迁移完成!共处理 {total} 条记录") + if dry_run: + print("提示: 使用不带 --dry-run 参数运行以实际写入数据") + finally: + await redis.aclose() + await engine.dispose() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file
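The migration script reads `--batch-size` in `main()`, while each migrate function commits one row per session. A minimal sketch of how rows might be grouped into batches before writing, assuming a hypothetical helper named `chunked` that is not part of the original script:

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def chunked(rows: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    """Yield consecutive lists of at most batch_size items from rows.

    Hypothetical helper for illustration; not defined in the diff above.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


# Illustrative data only: five parameter dicts grouped in batches of two.
rows = [{"role": "user", "content": f"msg {i}"} for i in range(5)]
batches = list(chunked(rows, 2))
print([len(b) for b in batches])  # prints [2, 2, 1]
```

Each batch could then be passed as a list of parameter dictionaries to one `session.execute(text(...), batch)` call followed by a single commit, which would reduce round trips compared with opening a session per row. Whether that fits the script's error handling (one bad row currently skips only itself) is a design choice left open here.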