"""Three-tier memory system manager.

The memory architecture has three layers (L1/L2/L3):
- L1 (atom layer): extracts key information atoms from conversations
  (user preferences, events, instructions).
- L2 (scene layer): groups related atoms into scene summaries.
- L3 (persona layer): combines all information into a user persona.

Storage:
- PG (primary store): persists memory messages, atoms, scenes and personas.
- Redis (cache): recent-message cache and conversation-summary cache.
"""

import asyncio
import json
import logging
import re
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Coroutine

from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from config import settings

logger = logging.getLogger(__name__)

# CJK ideographs / kana / hangul: scripts written without spaces between words.
_CJK_RE = re.compile(r"[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uac00-\ud7af]")
# An LLM reply wrapped in a Markdown code fence, e.g. ```json ... ```.
_CODE_FENCE_RE = re.compile(r"^```[\w-]*\s*(.*?)\s*```$", re.DOTALL)

_memory_manager: "MemoryManager | None" = None


def get_memory_manager() -> "MemoryManager":
    """Return the global MemoryManager singleton.

    Raises:
        RuntimeError: if init_memory_manager() has not been awaited yet.
    """
    if _memory_manager is None:
        raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
    return _memory_manager


async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
    """Create the Redis client, verify it with PING and install the global MemoryManager.

    Args:
        db_factory: factory returning an async SQLAlchemy session (used as an
            async context manager).
    """
    global _memory_manager
    redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    await redis.ping()
    _memory_manager = MemoryManager(db_factory, redis)


class MemoryManager:
    """Three-tier memory manager: stores, retrieves, extracts and consolidates memory."""

    MAX_HISTORY = 40  # max history messages loaded from PG for one injection
    REDIS_CACHE_SIZE = 10  # number of most recent messages kept in the Redis cache
    REDIS_CACHE_TTL = 300  # Redis message-cache TTL (seconds)
    SUMMARY_CACHE_KEY = "mem:summary"  # Redis key prefix for conversation summaries
    MSG_CACHE_KEY = "mem:cache:msgs"  # Redis key prefix for the recent-message cache
    ATOM_EXTRACT_EVERY = 10  # run L1 atom extraction every N messages of a session
    SCENE_EXTRACT_EVERY = 50  # minimum atom count before L2 scene extraction runs
    PERSONA_UPDATE_EVERY = 30  # minimum message count (latest session) before an L3 update
    SUMMARY_MIN_MESSAGES = 30  # minimum session message count before summarising
    SUMMARY_TTL = 2592000  # summary cache TTL (seconds; 30 days)
    _MESSAGES_PER_EXCHANGE = 2  # record_exchange stores one user + one assistant message
    _DEDUP_THRESHOLD = 0.75  # similarity above which an extracted atom counts as a duplicate

    def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
        """Initialise the manager.

        Args:
            db_factory: async database session factory.
            redis: async Redis client instance.
        """
        self.db_factory = db_factory
        self.redis = redis
        # Running background extraction tasks, keyed so that at most one of each
        # kind runs per user/flow.
        self._extract_tasks: dict[str, asyncio.Task] = {}
        # Strong references to fire-and-forget tasks; the event loop only keeps
        # weak references, so unreferenced tasks may be garbage-collected.
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------ #
    # Task helpers
    # ------------------------------------------------------------------ #

    def _spawn(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
        """Schedule *coro* as a task and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _extract_running(self, key: str) -> bool:
        """Return True if an extraction task registered under *key* is still running."""
        task = self._extract_tasks.get(key)
        return task is not None and not task.done()

    def _start_extract_task(self, key: str, coro: Coroutine[Any, Any, Any]) -> None:
        """Start *coro* as the extraction task for *key*; unregister it when done."""
        task = self._spawn(coro)
        self._extract_tasks[key] = task

        def _forget(done: asyncio.Task) -> None:
            # Only drop the entry if it still points at this task.
            if self._extract_tasks.get(key) is done:
                del self._extract_tasks[key]

        task.add_done_callback(_forget)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def inject_memory(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        context: dict,
    ):
        """Inject all three memory layers into a conversation context.

        Recent messages (Redis cache first, PG fallback), the cached summary,
        relevant atoms and the user persona are stored in
        ``context["_memory_context"]`` for the LLM.

        Args:
            user_id: user ID (UUID string).
            flow_id: flow ID (UUID string).
            session_id: session ID (UUID string).
            context: conversation context dict; modified in place.
        """
        uid = uuid.UUID(user_id)
        fid = uuid.UUID(flow_id)
        sid = uuid.UUID(session_id)

        # NOTE(review): a cache hit yields at most REDIS_CACHE_SIZE messages while a
        # PG fallback yields up to MAX_HISTORY; confirm this asymmetry is intended.
        recent_messages = await self._redis_get_recent(uid, sid)
        if not recent_messages:
            recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY)
            if recent_messages:
                await self._redis_set_recent(uid, sid, recent_messages)

        summary = await self._get_summary(uid, fid, sid)
        atoms = await self._get_relevant_atoms(uid, fid)
        persona = await self._get_persona(uid)

        context["_memory_context"] = {
            "recent_messages": recent_messages,
            "summary": summary,
            "atoms": atoms,
            "persona": persona,
            "session_id": session_id,
        }

    async def record_exchange(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        user_msg: str,
        assistant_msg: str,
        flow_name: str = "",
    ):
        """Persist one user/assistant exchange and trigger background memory work.

        Both messages are written to PG and the session row is upserted. The
        Redis message cache is updated, and summary/L1/L2/L3 extraction tasks
        are scheduled in the background. If the PG write fails, the failure is
        logged and nothing else happens.

        Args:
            user_id: user ID (UUID string).
            flow_id: flow ID (UUID string).
            session_id: session ID (UUID string).
            user_msg: the user's message.
            assistant_msg: the assistant's reply.
            flow_name: optional flow name recorded on the session row; an empty
                string keeps the previously stored name.
        """
        uid = uuid.UUID(user_id)
        fid = uuid.UUID(flow_id)
        sid = uuid.UUID(session_id)
        ts = datetime.now(timezone.utc)

        insert_message = text("""
            INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
            VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
        """)
        message_rows = [
            {
                "id": uuid.uuid4(),
                "user_id": uid,
                "flow_id": fid,
                "session_id": sid,
                "role": role,
                "content": content,
                "created_at": ts,
            }
            for role, content in (("user", user_msg), ("assistant", assistant_msg))
        ]

        try:
            async with self.db_factory() as db:
                for row in message_rows:
                    await db.execute(insert_message, row)

                # message_count is bumped by 2 (= _MESSAGES_PER_EXCHANGE) per exchange.
                await db.execute(
                    text("""
                        INSERT INTO memory_sessions (id, user_id, flow_id, session_id, flow_name, message_count, last_active_at, created_at)
                        VALUES (:id, :user_id, :flow_id, :session_id, :flow_name, 2, :last_active_at, :created_at)
                        ON CONFLICT (user_id, flow_id, session_id)
                        DO UPDATE SET
                            message_count = memory_sessions.message_count + 2,
                            last_active_at = EXCLUDED.last_active_at,
                            flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name)
                    """),
                    {
                        "id": uuid.uuid4(),
                        "user_id": uid,
                        "flow_id": fid,
                        "session_id": sid,
                        "flow_name": flow_name,
                        "last_active_at": ts,
                        "created_at": ts,
                    },
                )
                await db.commit()
        except Exception as e:
            logger.warning(f"记录记忆失败(PG): {e}")
            return

        try:
            new_msgs = [
                {"role": "user", "content": user_msg, "ts": ts.isoformat()},
                {"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()},
            ]
            await self._redis_append_recent(uid, sid, new_msgs)
        except Exception as e:
            logger.debug(f"Redis缓存更新失败: {e}")

        self._spawn(self._maybe_summarize(uid, fid, sid))
        self._spawn(self._maybe_extract_atoms(uid, fid, sid))
        self._spawn(self._maybe_extract_scenes(uid, fid))
        self._spawn(self._maybe_update_persona(uid, fid))

    async def get_conversation_history(
        self, user_id: str, flow_id: str, session_id: str, limit: int = 20
    ) -> list[dict]:
        """Return up to *limit* most recent messages of a session, oldest first.

        Args:
            user_id: user ID (UUID string).
            flow_id: flow ID (UUID string); falsy means "any flow".
            session_id: session ID (UUID string).
            limit: maximum number of messages.

        Returns:
            List of dicts with ``role``/``content``/``ts`` keys ([] on DB error).
        """
        uid = uuid.UUID(user_id)
        sid = uuid.UUID(session_id)
        fid = uuid.UUID(flow_id) if flow_id else None
        return await self._pg_get_recent(uid, fid, sid, limit)

    async def delete_session(self, user_id: str, session_id: str):
        """Delete a session's messages and session row from PG and its Redis caches.

        Both stores are cleared best-effort: a PG failure is logged as a warning,
        a Redis failure at debug level.
        """
        uid = uuid.UUID(user_id)
        sid = uuid.UUID(session_id)
        params = {"uid": uid, "sid": sid}
        try:
            async with self.db_factory() as db:
                await db.execute(
                    text("DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
                    params,
                )
                await db.execute(
                    text("DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid"),
                    params,
                )
                await db.commit()
        except Exception as e:
            logger.warning(f"清除记忆失败(PG): {e}")

        try:
            await self.redis.delete(
                self._msg_cache_key(uid, sid),
                f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}",
            )
        except Exception:
            logger.debug("failed to clear Redis memory caches", exc_info=True)

    async def list_user_sessions(self, user_id: str) -> list[dict]:
        """List a user's memory sessions, most recently active first (max 100).

        Args:
            user_id: user ID (UUID string).

        Returns:
            List of dicts with ``session_id``/``flow_id``/``flow_name``/
            ``last_active_at`` keys ([] on DB error).
        """
        uid = uuid.UUID(user_id)
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text("""
                        SELECT session_id, flow_id, flow_name, last_active_at
                        FROM memory_sessions
                        WHERE user_id = :uid
                        ORDER BY last_active_at DESC
                        LIMIT 100
                    """),
                    {"uid": uid},
                )
                rows = result.fetchall()
        except Exception:
            logger.debug("failed to list memory sessions", exc_info=True)
            return []
        return [
            {
                "session_id": str(session_id),
                "flow_id": str(flow_id),
                "flow_name": flow_name or "",
                "last_active_at": last_active_at.isoformat() if last_active_at else "",
            }
            for session_id, flow_id, flow_name, last_active_at in rows
        ]

    # ------------------------------------------------------------------ #
    # Recent messages (PG + Redis)
    # ------------------------------------------------------------------ #

    async def _pg_get_recent(
        self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int
    ) -> list[dict]:
        """Load up to *limit* most recent messages of a session from PG, oldest first.

        The flow filter is applied only when *fid* is given. Returns [] on any
        database error (best effort).
        """
        flow_clause = "AND flow_id = :fid" if fid else ""
        params: dict[str, Any] = {"uid": uid, "sid": sid, "limit": limit}
        if fid:
            params["fid"] = fid
        # record_exchange stores a user message and its reply with the same
        # created_at. On a tie, list the assistant row first in this DESC order
        # so that, after reversing, the user message precedes its reply.
        query = text(f"""
            SELECT role, content, created_at
            FROM memory_messages
            WHERE user_id = :uid {flow_clause} AND session_id = :sid
            ORDER BY created_at DESC,
                     CASE WHEN role = 'user' THEN 1 ELSE 0 END
            LIMIT :limit
        """)
        try:
            async with self.db_factory() as db:
                result = await db.execute(query, params)
                rows = result.fetchall()
        except Exception:
            logger.debug("failed to load recent messages from PG", exc_info=True)
            return []
        return [
            {"role": role, "content": content, "ts": created_at.isoformat() if created_at else ""}
            for role, content, created_at in reversed(rows)
        ]

    def _msg_cache_key(self, uid: uuid.UUID, sid: uuid.UUID) -> str:
        """Redis key of a session's recent-message cache."""
        return f"{self.MSG_CACHE_KEY}:{uid}:{sid}"

    async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
        """Read the cached message list from Redis; None on miss or error."""
        try:
            raw = await self.redis.get(self._msg_cache_key(uid, sid))
            if raw:
                return json.loads(raw)
        except Exception:
            logger.debug("failed to read Redis message cache", exc_info=True)
        return None

    async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
        """Cache the last REDIS_CACHE_SIZE messages in Redis with REDIS_CACHE_TTL (best effort)."""
        try:
            newest = messages[-self.REDIS_CACHE_SIZE:]
            await self.redis.setex(
                self._msg_cache_key(uid, sid),
                self.REDIS_CACHE_TTL,
                json.dumps(newest, ensure_ascii=False),
            )
        except Exception:
            logger.debug("failed to write Redis message cache", exc_info=True)

    async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
        """Append messages to the Redis cache, keeping only the newest REDIS_CACHE_SIZE.

        Not atomic: concurrent appends for the same session may lose messages
        from the cache (PG remains authoritative).
        """
        existing = await self._redis_get_recent(uid, sid) or []
        await self._redis_set_recent(uid, sid, existing + new_msgs)

    # ------------------------------------------------------------------ #
    # Conversation summary
    # ------------------------------------------------------------------ #

    async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
        """Read the cached conversation summary from Redis ("" if absent or on error).

        *fid* is accepted for signature symmetry but the key is per user+session.
        """
        try:
            val = await self.redis.get(f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}")
            return val or ""
        except Exception:
            return ""

    async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
        """Generate a conversation summary once a session is long enough.

        When no summary is cached yet and the session has at least
        SUMMARY_MIN_MESSAGES messages, the last 10 messages are summarised by
        the LLM and cached in Redis for SUMMARY_TTL seconds. Best effort.
        """
        summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
        try:
            if await self.redis.get(summary_key):
                return

            async with self.db_factory() as db:
                result = await db.execute(
                    text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
                    {"uid": uid, "sid": sid},
                )
                row = result.fetchone()
            count = row[0] if row else 0
            if count < self.SUMMARY_MIN_MESSAGES:
                return

            recent = await self._pg_get_recent(uid, fid, sid, 20)
            if not recent:
                return
            dialogue = "\n".join(
                f"{m['role']}: {(m['content'] or '')[:500]}" for m in recent[-10:]
            )
            summary = await self._call_llm(
                f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}",
                max_tokens=200,
                timeout=30,
            )
            if summary:
                await self.redis.setex(summary_key, self.SUMMARY_TTL, summary)
        except Exception:
            logger.debug("conversation summary failed", exc_info=True)

    # ------------------------------------------------------------------ #
    # Reads for injection
    # ------------------------------------------------------------------ #

    async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
        """Return the 20 highest-priority atoms of the user for this flow or global ([] on error)."""
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text("""
                        SELECT id, atom_type, content, priority, metadata
                        FROM memory_atoms
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                        ORDER BY priority DESC, updated_at DESC
                        LIMIT 20
                    """),
                    {"uid": uid, "fid": fid},
                )
                rows = result.fetchall()
        except Exception:
            logger.debug("failed to load memory atoms", exc_info=True)
            return []
        return [
            {"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]}
            for row in rows
        ]

    async def _get_persona(self, uid: uuid.UUID) -> dict:
        """Return the user's persona as {content, raw_text, version}, or {} if missing/error."""
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"),
                    {"uid": uid},
                )
                row = result.fetchone()
            if row:
                return {"content": row[0] or {}, "raw_text": row[1] or "", "version": row[2] or 1}
        except Exception:
            logger.debug("failed to load persona", exc_info=True)
        return {}

    # ------------------------------------------------------------------ #
    # LLM helpers
    # ------------------------------------------------------------------ #

    async def _call_llm(
        self,
        prompt: str,
        *,
        max_tokens: int,
        timeout: float,
        temperature: float | None = None,
        default: str = "",
    ) -> str:
        """Send a single-turn chat completion request and return the reply text.

        Returns *default* when the response has no choices or empty content.

        Raises:
            httpx.HTTPError: on transport errors or a non-2xx status.
        """
        import httpx  # imported lazily, as in the original call sites

        payload: dict[str, Any] = {
            "model": settings.LLM_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
        }
        if temperature is not None:
            payload["temperature"] = temperature

        api_base = settings.LLM_API_BASE.rstrip("/")
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json=payload,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
            )
        resp.raise_for_status()
        data = resp.json()
        choices = data.get("choices") or [{}]
        return (choices[0].get("message") or {}).get("content") or default

    @staticmethod
    def _parse_llm_json(raw: str) -> Any:
        """Parse an LLM reply as JSON, tolerating a surrounding Markdown code fence.

        Raises:
            ValueError: (json.JSONDecodeError) if the reply is not valid JSON.
        """
        cleaned = raw.strip()
        fenced = _CODE_FENCE_RE.match(cleaned)
        if fenced:
            cleaned = fenced.group(1)
        return json.loads(cleaned)

    @staticmethod
    def _as_utc(dt: datetime) -> datetime:
        """Interpret a naive DB timestamp as UTC; convert an aware one to UTC."""
        if dt.tzinfo is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    @staticmethod
    def _coerce_priority(value: Any, default: int = 50) -> int:
        """Convert an LLM-supplied priority to an int clamped to 0..100."""
        try:
            priority = int(value)
        except (TypeError, ValueError):
            return default
        return max(0, min(100, priority))

    @staticmethod
    def _crossed_multiple(count: int, every: int, step: int) -> bool:
        """Return True if the last increment (count - step -> count) crossed a multiple of *every*."""
        if every <= 0 or count < every:
            return False
        return (count - step) // every < count // every

    # ------------------------------------------------------------------ #
    # L1: atoms
    # ------------------------------------------------------------------ #

    async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
        """L1 trigger: start atom extraction every ATOM_EXTRACT_EVERY session messages.

        The session's message_count (incremented by _MESSAGES_PER_EXCHANGE per
        exchange) is checked. Extraction runs when the latest exchange crossed a
        multiple of ATOM_EXTRACT_EVERY and no extraction for this user/flow is
        running. Best effort; errors are logged at debug level.
        """
        task_key = f"extract_atoms:{uid}:{fid}"
        try:
            if self._extract_running(task_key):
                return

            async with self.db_factory() as db:
                result = await db.execute(
                    text("SELECT message_count FROM memory_sessions WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"),
                    {"uid": uid, "fid": fid, "sid": sid},
                )
                row = result.fetchone()
            count = (row[0] or 0) if row else 0

            if not self._crossed_multiple(count, self.ATOM_EXTRACT_EVERY, self._MESSAGES_PER_EXCHANGE):
                return

            recent = await self._pg_get_recent(uid, fid, sid, 20)
            if not recent:
                return
            dialogue = "\n".join(
                f"{m['role']}: {(m['content'] or '')[:400]}" for m in recent[-15:]
            )

            # Re-check: another exchange may have started one while we awaited.
            if self._extract_running(task_key):
                return
            self._start_extract_task(task_key, self._do_extract_atoms(uid, fid, sid, dialogue))
        except Exception:
            logger.debug("atom extraction trigger failed", exc_info=True)

    async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
        """Run L1 extraction: ask the LLM for structured atoms and store them deduplicated."""
        try:
            prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。

对话内容:
{dialogue}

请以JSON数组格式返回,每个原子包含以下字段:
- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求)
- content: 原子的具体内容(一句话描述)
- priority: 0-100的整数,表示重要性(越高越核心)

格式示例:
[
  {{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}},
  {{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}}
]

只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。"""

            result_text = await self._call_llm(
                prompt, max_tokens=800, timeout=60, temperature=0.3, default="[]"
            )
            try:
                atoms = self._parse_llm_json(result_text)
            except ValueError:
                return
            if not isinstance(atoms, list):
                return

            await self._dedup_and_store_atoms(uid, fid, sid, atoms)
        except Exception as e:
            logger.warning(f"L1原子提取失败: {e}")

    async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
        """Deduplicate extracted atoms against stored ones and persist them in one transaction.

        - An atom whose similarity to an existing atom exceeds _DEDUP_THRESHOLD
          raises that atom's priority (GREATEST) and merges
          ``last_source_session_id`` into its metadata; each existing atom is
          matched at most once per batch.
        - An atom similar to one already inserted in this batch is skipped.
        - Otherwise a new atom row is inserted.
        Non-dict items and empty contents are ignored.
        """
        try:
            async with self.db_factory() as db:
                existing = await db.execute(
                    text("""
                        SELECT id, content FROM memory_atoms
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                    """),
                    {"uid": uid, "fid": fid},
                )
                existing_atoms = [{"id": row[0], "content": row[1] or ""} for row in existing.fetchall()]
                inserted_contents: list[str] = []

                for atom in atoms:
                    if not isinstance(atom, dict):
                        continue
                    content = str(atom.get("content") or "").strip()
                    if not content:
                        continue
                    priority = self._coerce_priority(atom.get("priority", 50))

                    match_idx = next(
                        (
                            i
                            for i, ex in enumerate(existing_atoms)
                            if self._text_similarity(content, ex["content"]) > self._DEDUP_THRESHOLD
                        ),
                        None,
                    )
                    if match_idx is not None:
                        ex = existing_atoms.pop(match_idx)
                        await db.execute(
                            text("""
                                UPDATE memory_atoms
                                SET priority = GREATEST(priority, :priority),
                                    updated_at = NOW(),
                                    metadata = COALESCE(metadata, CAST('{}' AS jsonb)) || CAST(:meta AS jsonb)
                                WHERE id = :id
                            """),
                            {
                                "id": ex["id"],
                                "priority": priority,
                                "meta": json.dumps({"last_source_session_id": str(sid)}),
                            },
                        )
                        continue

                    if any(
                        self._text_similarity(content, prev) > self._DEDUP_THRESHOLD
                        for prev in inserted_contents
                    ):
                        continue

                    await db.execute(
                        text("""
                            INSERT INTO memory_atoms (user_id, flow_id, atom_type, content, priority, source_session_id, metadata, created_at, updated_at)
                            VALUES (:uid, :fid, :atype, :content, :priority, :sid, :meta, NOW(), NOW())
                        """),
                        {
                            "uid": uid,
                            "fid": fid,
                            "atype": atom.get("atom_type", "episodic"),
                            "content": content,
                            "priority": priority,
                            "sid": sid,
                            "meta": json.dumps({"source_session_id": str(sid)}),
                        },
                    )
                    inserted_contents.append(content)

                await db.commit()
        except Exception as e:
            logger.warning(f"L1原子存储失败: {e}")

    @staticmethod
    def _similarity_tokens(s: str) -> set[str]:
        """Tokenise *s* for Jaccard similarity.

        Lower-cased whitespace-separated words are the tokens. Words containing
        CJK characters (written without spaces) are split into character bigrams
        instead, so near-duplicate Chinese sentences still share tokens.
        """
        tokens: set[str] = set()
        for word in s.lower().split():
            if len(word) > 1 and _CJK_RE.search(word):
                tokens.update(word[i:i + 2] for i in range(len(word) - 1))
            else:
                tokens.add(word)
        return tokens

    @staticmethod
    def _text_similarity(a: str, b: str) -> float:
        """Jaccard similarity of the token sets of *a* and *b* (0.0 if either is empty)."""
        a_tokens = MemoryManager._similarity_tokens(a)
        b_tokens = MemoryManager._similarity_tokens(b)
        if not a_tokens or not b_tokens:
            return 0.0
        return len(a_tokens & b_tokens) / len(a_tokens | b_tokens)

    # ------------------------------------------------------------------ #
    # L2: scenes
    # ------------------------------------------------------------------ #

    async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
        """L2 trigger: start scene extraction when enough atoms exist.

        Runs when the user has at least SCENE_EXTRACT_EVERY atoms for this flow
        (or global), the newest scene is older than 12 hours (or none exists),
        at least 10 atoms are available for the prompt, and no scene extraction
        for this user/flow is running. Best effort.
        """
        task_key = f"extract_scenes:{uid}:{fid}"
        try:
            if self._extract_running(task_key):
                return

            params = {"uid": uid, "fid": fid}
            async with self.db_factory() as db:
                count_result = await db.execute(
                    text("SELECT COUNT(*) FROM memory_atoms WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
                    params,
                )
                row = count_result.fetchone()
                atom_count = row[0] if row else 0
                if atom_count < self.SCENE_EXTRACT_EVERY:
                    return

                latest_result = await db.execute(
                    text("""
                        SELECT updated_at FROM memory_scenes
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                        ORDER BY updated_at DESC NULLS LAST
                        LIMIT 1
                    """),
                    params,
                )
                latest_scene = latest_result.fetchone()
                if latest_scene and latest_scene[0] is not None:
                    age = datetime.now(timezone.utc) - self._as_utc(latest_scene[0])
                    if age < timedelta(hours=12):
                        return

                atoms_result = await db.execute(
                    text("""
                        SELECT atom_type, content, priority FROM memory_atoms
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                        ORDER BY priority DESC
                        LIMIT 30
                    """),
                    params,
                )
                atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in atoms_result.fetchall()]

            if len(atoms) < 10:
                return
            if self._extract_running(task_key):
                return
            self._start_extract_task(task_key, self._do_extract_scenes(uid, fid, atoms))
        except Exception:
            logger.debug("scene extraction trigger failed", exc_info=True)

    async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
        """Run L2 extraction: ask the LLM to group atoms into scenes and insert them."""
        try:
            atoms_text = "\n".join(
                f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
            )
            prompt = f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。

记忆原子:
{atoms_text}

请以JSON数组格式返回,每个场景块包含:
- scene_name: 场景名称(简短标签)
- summary: 场景摘要(2-3句话总结该场景的关键信息)

格式示例:
[
  {{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}}
]

只返回JSON数组,不要其他内容。"""

            result_text = await self._call_llm(
                prompt, max_tokens=500, timeout=60, temperature=0.3, default="[]"
            )
            try:
                scenes = self._parse_llm_json(result_text)
            except ValueError:
                return
            if not isinstance(scenes, list):
                return

            async with self.db_factory() as db:
                for scene in scenes:
                    if not isinstance(scene, dict):
                        continue
                    await db.execute(
                        text("""
                            INSERT INTO memory_scenes (user_id, flow_id, scene_name, summary, heat, content, created_at, updated_at)
                            VALUES (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW())
                        """),
                        {
                            "uid": uid,
                            "fid": fid,
                            "name": scene.get("scene_name", ""),
                            "summary": scene.get("summary", ""),
                            "content": json.dumps(scene, ensure_ascii=False),
                        },
                    )
                await db.commit()
        except Exception as e:
            logger.warning(f"L2场景提取失败: {e}")

    # ------------------------------------------------------------------ #
    # L3: persona
    # ------------------------------------------------------------------ #

    async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
        """L3 trigger: regenerate the user persona from 'persona' atoms.

        Runs when the user's most recently active session has at least
        PERSONA_UPDATE_EVERY messages, the stored persona is older than 6 hours
        (or missing), and no persona update for this user is running. *fid* is
        currently unused. Best effort.
        """
        task_key = f"update_persona:{uid}"
        try:
            if self._extract_running(task_key):
                return

            async with self.db_factory() as db:
                result = await db.execute(
                    text("SELECT message_count FROM memory_sessions WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"),
                    {"uid": uid},
                )
                row = result.fetchone()
                msg_count = (row[0] or 0) if row else 0

                persona_result = await db.execute(
                    text("SELECT version, updated_at FROM memory_personas WHERE user_id = :uid"),
                    {"uid": uid},
                )
                persona_row = persona_result.fetchone()

            if persona_row and persona_row[1] is not None:
                age = datetime.now(timezone.utc) - self._as_utc(persona_row[1])
                if age < timedelta(hours=6):
                    return

            if msg_count < self.PERSONA_UPDATE_EVERY:
                return

            async with self.db_factory() as db:
                atoms_result = await db.execute(
                    text("""
                        SELECT content FROM memory_atoms
                        WHERE user_id = :uid AND atom_type = 'persona'
                        ORDER BY priority DESC, updated_at DESC
                        LIMIT 30
                    """),
                    {"uid": uid},
                )
                persona_atoms = [r[0] for r in atoms_result.fetchall() if r[0]]

            persona_text = "\n".join(persona_atoms[:15]) if persona_atoms else "暂无用户信息"
            next_version = (persona_row[0] or 0) + 1 if persona_row else 1

            if self._extract_running(task_key):
                return
            self._start_extract_task(task_key, self._do_update_persona(uid, persona_text, next_version))
        except Exception:
            logger.debug("persona update trigger failed", exc_info=True)

    async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
        """Run L3: ask the LLM for a structured persona and upsert it with *version*.

        A reply that is not a JSON object is stored as ``{"raw_text": reply}``.
        """
        try:
            prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。

用户信息:
{persona_text}

请返回一个JSON对象,包含以下字段:
- name: 用户姓名(未知则"未知")
- role: 角色/职位
- preferences: 偏好/习惯列表
- skills: 技能/专长列表
- personality: 性格描述
- raw_text: 综合画像的文本描述(2-3句话)

只返回JSON对象,不要其他内容。"""

            result_text = await self._call_llm(
                prompt, max_tokens=500, timeout=60, temperature=0.3, default="{}"
            )
            try:
                persona = self._parse_llm_json(result_text)
            except ValueError:
                persona = {"raw_text": result_text}
            if not isinstance(persona, dict):
                persona = {"raw_text": result_text}

            content_json = json.dumps(persona, ensure_ascii=False)
            raw_text = persona.get("raw_text")
            if not isinstance(raw_text, str):
                raw_text = content_json
            params = {"uid": uid, "content": content_json, "raw_text": raw_text, "version": version}

            async with self.db_factory() as db:
                existing = await db.execute(
                    text("SELECT id FROM memory_personas WHERE user_id = :uid"),
                    {"uid": uid},
                )
                if existing.fetchone():
                    await db.execute(
                        text("""
                            UPDATE memory_personas
                            SET content = :content, raw_text = :raw_text, version = :version, updated_at = NOW()
                            WHERE user_id = :uid
                        """),
                        params,
                    )
                else:
                    await db.execute(
                        text("""
                            INSERT INTO memory_personas (user_id, content, raw_text, version, updated_at)
                            VALUES (:uid, :content, :raw_text, :version, NOW())
                        """),
                        params,
                    )
                await db.commit()
        except Exception as e:
            logger.warning(f"L3画像更新失败: {e}")

    # ------------------------------------------------------------------ #
    # Hybrid search
    # ------------------------------------------------------------------ #

    async def search_memory(
        self,
        uid: uuid.UUID,
        query: str,
        fid: uuid.UUID | None = None,
        top_k: int = 10,
        embedding: list[float] | None = None,
    ) -> list[dict]:
        """Hybrid memory search: vector similarity + full-text search fused with RRF.

        Args:
            uid: user ID.
            query: search text for the full-text query.
            fid: optional flow ID; restricts to this flow plus global atoms.
            top_k: number of results to return.
            embedding: optional query vector; enables the vector search.

        Returns:
            Atoms sorted by descending ``rrf_score``. Each hit carries
            ``vector_score`` and/or ``text_score`` depending on which searches
            found it. On a database error, the unranked hits gathered so far
            are returned (truncated to *top_k*).
        """
        vector_hits: list[dict] = []
        text_hits: list[dict] = []
        try:
            flow_filter = "AND (flow_id = :fid OR flow_id IS NULL)" if fid else ""
            base_params: dict[str, Any] = {"uid": uid, "limit": top_k * 2}
            if fid:
                base_params["fid"] = fid

            async with self.db_factory() as db:
                if embedding:
                    emb_str = "[" + ",".join(map(str, embedding)) + "]"
                    # CAST(... AS vector) instead of ':emb::vector': text() does
                    # not parse a bind parameter directly followed by '::'.
                    vector_result = await db.execute(
                        text(f"""
                            SELECT id, atom_type, content, priority,
                                   (1.0 - (embedding <=> CAST(:emb AS vector))) AS similarity
                            FROM memory_atoms
                            WHERE user_id = :uid {flow_filter} AND embedding IS NOT NULL
                            ORDER BY embedding <=> CAST(:emb AS vector)
                            LIMIT :limit
                        """),
                        {**base_params, "emb": emb_str},
                    )
                    vector_hits = [
                        {
                            "id": str(row[0]),
                            "type": row[1],
                            "content": row[2],
                            "priority": row[3],
                            "vector_score": float(row[4]),
                        }
                        for row in vector_result.fetchall()
                    ]

                text_result = await db.execute(
                    text(f"""
                        SELECT id, atom_type, content, priority,
                               ts_rank(content_tsv, plainto_tsquery('simple', :query)) AS text_score
                        FROM memory_atoms
                        WHERE user_id = :uid {flow_filter}
                          AND content_tsv @@ plainto_tsquery('simple', :query)
                        ORDER BY text_score DESC
                        LIMIT :limit
                    """),
                    {**base_params, "query": query},
                )
                text_hits = [
                    {
                        "id": str(row[0]),
                        "type": row[1],
                        "content": row[2],
                        "priority": row[3],
                        "text_score": float(row[4]),
                    }
                    for row in text_result.fetchall()
                ]

            return self._rrf_merge([vector_hits, text_hits], top_k)
        except Exception as e:
            logger.warning(f"混合检索失败: {e}")
            return (vector_hits + text_hits)[:top_k]

    @staticmethod
    def _rrf_merge(ranked_lists: list[list[dict]], top_k: int, k: int = 60) -> list[dict]:
        """Fuse ranked hit lists with Reciprocal Rank Fusion.

        Each list must already be ordered best-first. A hit at 1-based rank r
        contributes 1 / (k + r). Hits with the same ``id`` are merged into one
        dict that carries every score field found, plus ``rrf_score``.
        """
        scores: dict[str, float] = {}
        merged: dict[str, dict] = {}
        for hits in ranked_lists:
            for rank, hit in enumerate(hits, start=1):
                rid = hit["id"]
                scores[rid] = scores.get(rid, 0.0) + 1.0 / (k + rank)
                merged.setdefault(rid, {}).update(hit)
        for rid, score in scores.items():
            merged[rid]["rrf_score"] = score
        return sorted(merged.values(), key=lambda r: r["rrf_score"], reverse=True)[:top_k]