"""Layered conversation memory backed by PostgreSQL and Redis.

Layers managed here (the L1/L2/L3 names follow this module's log messages):

* raw messages -- rows in ``memory_messages``; the most recent few per
  session are also cached in Redis;
* session summary -- an LLM-written paragraph cached in Redis;
* L1 atoms -- short structured facts in ``memory_atoms``;
* L2 scenes -- topic summaries built from atoms, in ``memory_scenes``;
* L3 persona -- one structured profile per user, in ``memory_personas``.

``record_exchange`` persists each user/assistant turn and schedules the
summary/atom/scene/persona refreshes as background asyncio tasks. Those tasks
log their failures and never raise into the caller.
"""

import asyncio
import json
import logging
import re
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Coroutine

from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from config import settings

logger = logging.getLogger(__name__)

# Runs of CJK ideographs. Chinese text has no spaces between words, so a plain
# whitespace split would turn a whole sentence into a single token.
_CJK_RUN_RE = re.compile(r"([\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]+)")

# A Markdown code fence (e.g. ```json ... ```) wrapped around an LLM's JSON reply.
_CODE_FENCE_RE = re.compile(r"^```[A-Za-z]*\s*|\s*```$")

_memory_manager: "MemoryManager | None" = None


def get_memory_manager() -> "MemoryManager":
    """Return the process-wide MemoryManager.

    Raises:
        RuntimeError: if init_memory_manager() has not been awaited yet.
    """
    if _memory_manager is None:
        raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
    return _memory_manager


async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
    """Create the process-wide MemoryManager.

    Connects to ``settings.REDIS_URL`` (with ``decode_responses=True``) and
    pings it first, so a bad Redis configuration fails here rather than on the
    first request.
    """
    global _memory_manager
    redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    await redis.ping()
    _memory_manager = MemoryManager(db_factory, redis)


class MemoryManager:
    """Store and retrieve layered conversation memory per user/flow/session.

    Args:
        db_factory: zero-argument callable returning an AsyncSession; every
            operation opens its own session with ``async with db_factory()``.
        redis: async Redis client. The cache code stores ``str`` values and
            returns what it reads as-is, which presumably relies on
            ``decode_responses=True`` (as init_memory_manager sets).
    """

    # Max raw messages loaded from PG when the Redis cache misses.
    MAX_HISTORY = 40
    # Number of most recent messages kept in the per-session Redis cache.
    REDIS_CACHE_SIZE = 10
    # Lifetime of the per-session message cache, in seconds.
    REDIS_CACHE_TTL = 300
    # Redis key prefixes; full keys are "<prefix>:<user_id>:<session_id>".
    SUMMARY_CACHE_KEY = "mem:summary"
    MSG_CACHE_KEY = "mem:cache:msgs"
    ATOM_MARK_KEY = "mem:atoms:mark"
    # Lifetime (seconds, 30 days) of a session summary and of the marker that
    # records the message count at the last atom extraction.
    SUMMARY_TTL = 2592000
    ATOM_MARK_TTL = 2592000
    # A session needs at least this many messages before it is summarized.
    SUMMARY_MIN_MESSAGES = 30
    # Extract L1 atoms each time a session grows by this many messages.
    ATOM_EXTRACT_EVERY = 10
    # Minimum number of atoms (user + flow) before L2 scenes are built.
    SCENE_EXTRACT_EVERY = 50
    # Minimum messages in the user's latest session before the persona is rebuilt.
    PERSONA_UPDATE_EVERY = 30
    # Jaccard similarity above which a new atom counts as a duplicate.
    ATOM_DEDUP_THRESHOLD = 0.75
    # Atom types the extraction prompt asks for; anything else is stored as "episodic".
    ATOM_TYPES = ("persona", "episodic", "instruction")
    # Minimum time between two L2 scene builds / two L3 persona rebuilds.
    SCENE_REFRESH_INTERVAL = timedelta(hours=12)
    PERSONA_REFRESH_INTERVAL = timedelta(hours=6)

    def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
        self.db_factory = db_factory
        self.redis = redis
        # Extraction tasks by key, so at most one extraction per key runs at a
        # time; entries are removed when their task finishes.
        self._extract_tasks: dict[str, asyncio.Task] = {}
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------ tasks

    def _spawn(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
        """Start *coro* as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _extract_task_running(self, task_key: str) -> bool:
        """Return True if the extraction task registered under *task_key* is still running."""
        task = self._extract_tasks.get(task_key)
        return task is not None and not task.done()

    def _start_extract_task(self, task_key: str, coro: Coroutine[Any, Any, Any]) -> None:
        """Run *coro* in the background, registered under *task_key* until it finishes."""
        task = self._spawn(coro)
        self._extract_tasks[task_key] = task

        def _forget(done: asyncio.Task) -> None:
            # Only drop the entry if a newer task has not replaced it.
            if self._extract_tasks.get(task_key) is done:
                del self._extract_tasks[task_key]

        task.add_done_callback(_forget)

    # ------------------------------------------------------------- public API

    async def inject_memory(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        context: dict,
    ):
        """Load a session's memory into ``context["_memory_context"]``.

        The entry holds ``recent_messages`` (from the Redis cache, or PG on a
        miss), ``summary``, ``atoms``, ``persona`` and the ``session_id`` string.

        Raises:
            ValueError: if an id is not a valid UUID string.
        """
        uid = uuid.UUID(user_id)
        fid = uuid.UUID(flow_id)
        sid = uuid.UUID(session_id)

        recent_messages = await self._redis_get_recent(uid, sid)
        if not recent_messages:
            # NOTE(review): a PG load returns up to MAX_HISTORY messages but
            # only the last REDIS_CACHE_SIZE are cached, so later cache hits
            # expose a shorter history -- confirm this is intended.
            recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY)
            if recent_messages:
                await self._redis_set_recent(uid, sid, recent_messages)

        summary = await self._get_summary(uid, fid, sid)
        atoms = await self._get_relevant_atoms(uid, fid)
        persona = await self._get_persona(uid)

        context["_memory_context"] = {
            "recent_messages": recent_messages,
            "summary": summary,
            "atoms": atoms,
            "persona": persona,
            "session_id": session_id,
        }

    async def record_exchange(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        user_msg: str,
        assistant_msg: str,
        flow_name: str = "",
    ):
        """Persist one user/assistant turn and schedule background memory updates.

        Writes both messages to ``memory_messages`` and upserts the session row
        (``message_count += 2``) in one transaction. On a database error it logs
        and returns without touching the cache or scheduling anything.
        Otherwise it appends to the Redis message cache (best-effort) and starts
        the summary / atom / scene / persona checks as background tasks.
        """
        uid = uuid.UUID(user_id)
        fid = uuid.UUID(flow_id)
        sid = uuid.UUID(session_id)
        user_ts = datetime.now(timezone.utc)
        # History is ordered by created_at; give the reply a strictly later
        # timestamp so it can never sort before the message it answers.
        assistant_ts = user_ts + timedelta(microseconds=1)

        insert_message = text("""
            INSERT INTO memory_messages
                (id, user_id, flow_id, session_id, role, content, created_at)
            VALUES
                (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
        """)
        base_row = {"user_id": uid, "flow_id": fid, "session_id": sid}

        try:
            async with self.db_factory() as db:
                for role, content, created_at in (
                    ("user", user_msg, user_ts),
                    ("assistant", assistant_msg, assistant_ts),
                ):
                    await db.execute(
                        insert_message,
                        {
                            **base_row,
                            "id": uuid.uuid4(),
                            "role": role,
                            "content": content,
                            "created_at": created_at,
                        },
                    )
                await db.execute(
                    text("""
                        INSERT INTO memory_sessions
                            (id, user_id, flow_id, session_id, flow_name,
                             message_count, last_active_at, created_at)
                        VALUES
                            (:id, :user_id, :flow_id, :session_id, :flow_name,
                             2, :last_active_at, :created_at)
                        ON CONFLICT (user_id, flow_id, session_id) DO UPDATE SET
                            message_count = memory_sessions.message_count + 2,
                            last_active_at = EXCLUDED.last_active_at,
                            flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name)
                    """),
                    {
                        **base_row,
                        "id": uuid.uuid4(),
                        "flow_name": flow_name,
                        "last_active_at": user_ts,
                        "created_at": user_ts,
                    },
                )
                await db.commit()
        except Exception as e:
            logger.warning("记录记忆失败(PG): %s", e)
            return

        try:
            await self._redis_append_recent(uid, sid, [
                {"role": "user", "content": user_msg, "ts": user_ts.isoformat()},
                {"role": "assistant", "content": assistant_msg, "ts": assistant_ts.isoformat()},
            ])
        except Exception as e:
            logger.debug("Redis缓存更新失败: %s", e)

        self._spawn(self._maybe_summarize(uid, fid, sid))
        self._spawn(self._maybe_extract_atoms(uid, fid, sid))
        self._spawn(self._maybe_extract_scenes(uid, fid))
        self._spawn(self._maybe_update_persona(uid, fid))

    async def get_conversation_history(
        self, user_id: str, flow_id: str, session_id: str, limit: int = 20
    ) -> list[dict]:
        """Return up to *limit* most recent messages of a session, oldest first.

        An empty *flow_id* disables the flow filter. Returns [] on database errors.
        """
        uid = uuid.UUID(user_id)
        sid = uuid.UUID(session_id)
        fid = uuid.UUID(flow_id) if flow_id else None
        return await self._pg_get_recent(uid, fid, sid, limit)

    async def delete_session(self, user_id: str, session_id: str):
        """Delete a session's messages and session rows plus its Redis keys.

        PG and Redis cleanup are independent: a failure in one is logged and
        does not prevent the other.
        """
        uid = uuid.UUID(user_id)
        sid = uuid.UUID(session_id)
        params = {"uid": uid, "sid": sid}
        try:
            async with self.db_factory() as db:
                await db.execute(
                    text("DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
                    params,
                )
                await db.execute(
                    text("DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid"),
                    params,
                )
                await db.commit()
        except Exception as e:
            logger.warning("清除记忆失败(PG): %s", e)

        try:
            await self.redis.delete(
                f"{self.MSG_CACHE_KEY}:{uid}:{sid}",
                f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}",
                f"{self.ATOM_MARK_KEY}:{uid}:{sid}",
            )
        except Exception as e:
            logger.debug("failed to clear Redis memory keys: %s", e)

    async def list_user_sessions(self, user_id: str) -> list[dict]:
        """Return the user's 100 most recently active sessions (newest first).

        Each item has string ``session_id``, ``flow_id``, ``flow_name`` and an
        ISO-8601 ``last_active_at`` ("" when NULL). Returns [] on database errors.
        """
        uid = uuid.UUID(user_id)
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text("""
                        SELECT session_id, flow_id, flow_name, last_active_at
                        FROM memory_sessions
                        WHERE user_id = :uid
                        ORDER BY last_active_at DESC
                        LIMIT 100
                    """),
                    {"uid": uid},
                )
                rows = result.fetchall()
        except Exception as e:
            logger.debug("listing sessions failed: %s", e)
            return []
        return [
            {
                "session_id": str(row[0]),
                "flow_id": str(row[1]),
                "flow_name": row[2] or "",
                "last_active_at": row[3].isoformat() if row[3] else "",
            }
            for row in rows
        ]

    # ------------------------------------------------------ raw message store

    async def _pg_get_recent(
        self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int
    ) -> list[dict]:
        """Fetch the *limit* newest messages of a session, returned oldest first.

        When *fid* is None the flow filter is omitted. Each item is
        ``{"role", "content", "ts"}`` with ``ts`` in ISO-8601 ("" when NULL).
        Returns [] on database errors.
        """
        # The optional fragment is a constant, never user input.
        flow_clause = "AND flow_id = :fid" if fid is not None else ""
        params: dict[str, Any] = {"uid": uid, "sid": sid, "limit": limit}
        if fid is not None:
            params["fid"] = fid
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text(f"""
                        SELECT role, content, created_at
                        FROM memory_messages
                        WHERE user_id = :uid {flow_clause} AND session_id = :sid
                        ORDER BY created_at DESC
                        LIMIT :limit
                    """),
                    params,
                )
                rows = result.fetchall()
        except Exception as e:
            logger.debug("loading messages from PG failed: %s", e)
            return []
        return [
            {"role": row[0], "content": row[1], "ts": row[2].isoformat() if row[2] else ""}
            for row in reversed(rows)
        ]

    async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
        """Return the cached message list for a session, or None if absent/unreadable."""
        try:
            raw = await self.redis.get(f"{self.MSG_CACHE_KEY}:{uid}:{sid}")
            if raw:
                return json.loads(raw)
        except Exception:
            pass  # cache is best-effort; callers fall back to PG
        return None

    async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
        """Cache the last REDIS_CACHE_SIZE *messages* for REDIS_CACHE_TTL seconds (best-effort)."""
        try:
            newest = messages[-self.REDIS_CACHE_SIZE:]
            await self.redis.setex(
                f"{self.MSG_CACHE_KEY}:{uid}:{sid}",
                self.REDIS_CACHE_TTL,
                json.dumps(newest, ensure_ascii=False),
            )
        except Exception:
            pass  # cache is best-effort

    async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
        """Append *new_msgs* to an existing session cache (best-effort).

        Does nothing when no cache exists: seeding it with only the newest
        messages would make inject_memory() treat that fragment as the recent
        history and skip loading it from PG.
        """
        try:
            existing = await self._redis_get_recent(uid, sid)
            if existing is None:
                return
            await self._redis_set_recent(uid, sid, existing + new_msgs)
        except Exception:
            pass  # cache is best-effort

    # --------------------------------------------------------------- summary

    async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
        """Return the cached session summary, or "" if none/unreadable. *fid* is unused."""
        try:
            value = await self.redis.get(f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}")
            return value or ""
        except Exception:
            return ""

    async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
        """Create a session summary once the session has SUMMARY_MIN_MESSAGES messages.

        Does nothing if a summary is already cached.
        NOTE(review): the summary is never refreshed until its TTL expires.
        """
        summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
        try:
            if await self.redis.get(summary_key):
                return
            async with self.db_factory() as db:
                result = await db.execute(
                    text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
                    {"uid": uid, "sid": sid},
                )
                count = result.scalar() or 0
            if count < self.SUMMARY_MIN_MESSAGES:
                return

            recent = await self._pg_get_recent(uid, fid, sid, 10)
            if not recent:
                return
            dialogue = "\n".join(
                f"{m['role']}: {(m['content'] or '')[:500]}" for m in recent
            )
            summary = await self._chat_completion(
                f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}",
                max_tokens=200,
                timeout=30,
            )
            if summary:
                await self.redis.setex(summary_key, self.SUMMARY_TTL, summary)
        except Exception as e:
            logger.debug("session summary failed: %s", e)

    # -------------------------------------------------------- read-side layers

    async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
        """Return the 20 highest-priority atoms for the user (this flow or flow-less).

        Returns [] on database errors.
        """
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text("""
                        SELECT id, atom_type, content, priority
                        FROM memory_atoms
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                        ORDER BY priority DESC, updated_at DESC
                        LIMIT 20
                    """),
                    {"uid": uid, "fid": fid},
                )
                rows = result.fetchall()
        except Exception as e:
            logger.debug("loading atoms failed: %s", e)
            return []
        return [
            {"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]}
            for row in rows
        ]

    async def _get_persona(self, uid: uuid.UUID) -> dict:
        """Return ``{"content", "raw_text", "version"}`` for the user, or {} if none/on error."""
        try:
            async with self.db_factory() as db:
                result = await db.execute(
                    text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"),
                    {"uid": uid},
                )
                row = result.fetchone()
            if row:
                return {"content": row[0] or {}, "raw_text": row[1] or "", "version": row[2] or 1}
        except Exception as e:
            logger.debug("loading persona failed: %s", e)
        return {}

    # ----------------------------------------------------------- LLM helpers

    async def _chat_completion(
        self,
        prompt: str,
        *,
        max_tokens: int,
        temperature: float | None = None,
        timeout: float = 60,
    ) -> str:
        """Send a single-turn *prompt* to ``{LLM_API_BASE}/chat/completions``.

        Returns the first choice's message content, or "" when it is missing.
        Raises httpx errors on transport failures and non-2xx responses.
        """
        import httpx  # imported lazily; only the background LLM calls need it

        payload: dict[str, Any] = {
            "model": settings.LLM_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
        }
        if temperature is not None:
            payload["temperature"] = temperature
        api_base = settings.LLM_API_BASE.rstrip("/")
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json=payload,
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
            )
        resp.raise_for_status()
        choices = resp.json().get("choices") or [{}]
        return (choices[0].get("message") or {}).get("content") or ""

    @staticmethod
    def _parse_llm_json(raw: str) -> Any:
        """Parse JSON from an LLM reply, tolerating a surrounding Markdown code fence.

        Raises:
            ValueError: (json.JSONDecodeError) if the reply is not valid JSON.
        """
        return json.loads(_CODE_FENCE_RE.sub("", raw.strip()))

    @staticmethod
    def _coerce_priority(value: Any, default: int = 50) -> int:
        """Convert an LLM-supplied priority to an int clamped to 0..100."""
        try:
            priority = int(value)
        except (TypeError, ValueError):
            return default
        return max(0, min(100, priority))

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as an aware UTC datetime (naive values are taken as UTC)."""
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    # ----------------------------------------------------------- L1: atoms

    async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
        """Schedule L1 atom extraction each time a session grows by ATOM_EXTRACT_EVERY messages.

        The session's message_count at the last scheduled extraction is kept in
        Redis under ATOM_MARK_KEY. Skips when fewer than ATOM_EXTRACT_EVERY
        messages were added since then, or when an extraction for this
        user/flow is already running.
        """
        task_key = f"extract_atoms:{uid}:{fid}"
        mark_key = f"{self.ATOM_MARK_KEY}:{uid}:{sid}"
        try:
            if self._extract_task_running(task_key):
                return
            async with self.db_factory() as db:
                result = await db.execute(
                    text(
                        "SELECT message_count FROM memory_sessions "
                        "WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"
                    ),
                    {"uid": uid, "fid": fid, "sid": sid},
                )
                count = result.scalar() or 0
            if count < self.ATOM_EXTRACT_EVERY:
                return

            # Track progress explicitly: the number of stored atoms is not a
            # usable proxy, since a run may store none (all duplicates or an
            # empty reply) and would then re-trigger on every exchange.
            last_mark = int(await self.redis.get(mark_key) or 0)
            if count - last_mark < self.ATOM_EXTRACT_EVERY:
                return

            recent = await self._pg_get_recent(uid, fid, sid, 15)
            # Re-check: another exchange may have started a task during the awaits.
            if not recent or self._extract_task_running(task_key):
                return
            dialogue = "\n".join(
                f"{m['role']}: {(m['content'] or '')[:400]}" for m in recent
            )
            self._start_extract_task(task_key, self._do_extract_atoms(uid, fid, sid, dialogue))
            await self.redis.setex(mark_key, self.ATOM_MARK_TTL, count)
        except Exception as e:
            logger.debug("atom extraction scheduling failed: %s", e)

    async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
        """Ask the LLM for memory atoms in *dialogue* and store them (with dedup)."""
        try:
            prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。

对话内容:
{dialogue}

请以JSON数组格式返回,每个原子包含以下字段:
- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求)
- content: 原子的具体内容(一句话描述)
- priority: 0-100的整数,表示重要性(越高越核心)

格式示例:
[
  {{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}},
  {{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}}
]

只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。"""
            result_text = await self._chat_completion(
                prompt, max_tokens=800, temperature=0.3, timeout=60
            )
            try:
                atoms = self._parse_llm_json(result_text or "[]")
            except ValueError:
                return
            if not isinstance(atoms, list) or not atoms:
                return
            await self._dedup_and_store_atoms(uid, fid, sid, atoms)
        except Exception as e:
            logger.warning("L1原子提取失败: %s", e)

    async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
        """Insert new atoms, or merge them into similar stored atoms, in one transaction.

        An atom whose similarity to a stored atom (user's, this flow or
        flow-less) exceeds ATOM_DEDUP_THRESHOLD bumps that atom's priority and
        records ``last_source_session_id`` in its metadata instead of being
        inserted; each stored atom is matched at most once per batch. Atoms that
        duplicate one already inserted from this batch are skipped.
        Non-dict entries and empty contents are ignored.
        """
        try:
            async with self.db_factory() as db:
                existing = await db.execute(
                    text("""
                        SELECT id, content
                        FROM memory_atoms
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                    """),
                    {"uid": uid, "fid": fid},
                )
                # (atom id, token set) of stored atoms not yet matched in this batch.
                candidates = [(row[0], self._tokenize(row[1] or "")) for row in existing.fetchall()]
                inserted_tokens: list[set[str]] = []

                for atom in atoms:
                    if not isinstance(atom, dict):
                        continue
                    content = str(atom.get("content") or "").strip()
                    if not content:
                        continue
                    tokens = self._tokenize(content)
                    priority = self._coerce_priority(atom.get("priority"))

                    match = next(
                        (
                            idx
                            for idx, (_, stored_tokens) in enumerate(candidates)
                            if self._jaccard(tokens, stored_tokens) > self.ATOM_DEDUP_THRESHOLD
                        ),
                        None,
                    )
                    if match is not None:
                        atom_id, _ = candidates.pop(match)
                        # Assumes metadata is jsonb (it is merged with ||).
                        await db.execute(
                            text("""
                                UPDATE memory_atoms
                                SET priority = GREATEST(priority, :priority),
                                    updated_at = NOW(),
                                    metadata = COALESCE(metadata, CAST('{}' AS jsonb)) || CAST(:meta AS jsonb)
                                WHERE id = :id
                            """),
                            {
                                "id": atom_id,
                                "priority": priority,
                                "meta": json.dumps({"last_source_session_id": str(sid)}),
                            },
                        )
                        continue

                    if any(
                        self._jaccard(tokens, seen) > self.ATOM_DEDUP_THRESHOLD
                        for seen in inserted_tokens
                    ):
                        continue

                    atom_type = atom.get("atom_type")
                    if atom_type not in self.ATOM_TYPES:
                        atom_type = "episodic"
                    await db.execute(
                        text("""
                            INSERT INTO memory_atoms
                                (user_id, flow_id, atom_type, content, priority,
                                 source_session_id, metadata, created_at, updated_at)
                            VALUES
                                (:uid, :fid, :atype, :content, :priority,
                                 :sid, :meta, NOW(), NOW())
                        """),
                        {
                            "uid": uid,
                            "fid": fid,
                            "atype": atom_type,
                            "content": content,
                            "priority": priority,
                            "sid": sid,
                            "meta": json.dumps({"source_session_id": str(sid)}),
                        },
                    )
                    inserted_tokens.append(tokens)
                await db.commit()
        except Exception as e:
            logger.warning("L1原子存储失败: %s", e)

    @staticmethod
    def _tokenize(value: str) -> set[str]:
        """Split text into a set of lowercase tokens for similarity checks.

        Whitespace-separated words are kept as tokens; inside a word, every run
        of CJK ideographs is replaced by its character bigrams (a lone ideograph
        stays a single token). Text without CJK yields exactly its words.
        """
        tokens: set[str] = set()
        for word in value.lower().split():
            # With a capturing group, re.split alternates non-CJK / CJK pieces.
            for idx, piece in enumerate(_CJK_RUN_RE.split(word)):
                if not piece:
                    continue
                if idx % 2 == 0 or len(piece) == 1:
                    tokens.add(piece)
                else:
                    tokens.update(piece[i:i + 2] for i in range(len(piece) - 1))
        return tokens

    @staticmethod
    def _jaccard(a: set[str], b: set[str]) -> float:
        """Jaccard index of two token sets; 0.0 if either is empty."""
        if not a or not b:
            return 0.0
        return len(a & b) / len(a | b)

    @staticmethod
    def _text_similarity(a: str, b: str) -> float:
        """Jaccard similarity of the token sets of *a* and *b* (see _tokenize)."""
        return MemoryManager._jaccard(MemoryManager._tokenize(a), MemoryManager._tokenize(b))

    # ----------------------------------------------------------- L2: scenes

    async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
        """Schedule an L2 scene build for the user/flow when it is due.

        Skips when fewer than SCENE_EXTRACT_EVERY atoms exist, when a scene was
        updated within SCENE_REFRESH_INTERVAL, when fewer than 10 atoms are
        returned by the top-30 query, or when a build is already running.
        """
        task_key = f"extract_scenes:{uid}:{fid}"
        params = {"uid": uid, "fid": fid}
        try:
            if self._extract_task_running(task_key):
                return
            async with self.db_factory() as db:
                atom_count = (await db.execute(
                    text(
                        "SELECT COUNT(*) FROM memory_atoms "
                        "WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"
                    ),
                    params,
                )).scalar() or 0
                if atom_count < self.SCENE_EXTRACT_EVERY:
                    return

                latest_scene_at = (await db.execute(
                    text(
                        "SELECT MAX(updated_at) FROM memory_scenes "
                        "WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"
                    ),
                    params,
                )).scalar()
                if (
                    latest_scene_at is not None
                    and datetime.now(timezone.utc) - self._as_utc(latest_scene_at)
                    < self.SCENE_REFRESH_INTERVAL
                ):
                    return

                rows = (await db.execute(
                    text("""
                        SELECT atom_type, content, priority
                        FROM memory_atoms
                        WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
                        ORDER BY priority DESC
                        LIMIT 30
                    """),
                    params,
                )).fetchall()

            atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in rows]
            # Re-check: another exchange may have started a task during the awaits.
            if len(atoms) < 10 or self._extract_task_running(task_key):
                return
            self._start_extract_task(task_key, self._do_extract_scenes(uid, fid, atoms))
        except Exception as e:
            logger.debug("scene extraction scheduling failed: %s", e)

    async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
        """Ask the LLM to group *atoms* into 1-3 scenes and insert them.

        NOTE(review): scenes are always inserted, never replaced or merged, so
        they accumulate with every build -- confirm this is intended.
        """
        try:
            atoms_text = "\n".join(
                f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
            )
            prompt = f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。

记忆原子:
{atoms_text}

请以JSON数组格式返回,每个场景块包含:
- scene_name: 场景名称(简短标签)
- summary: 场景摘要(2-3句话总结该场景的关键信息)

格式示例:
[
  {{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}}
]

只返回JSON数组,不要其他内容。"""
            result_text = await self._chat_completion(
                prompt, max_tokens=500, temperature=0.3, timeout=60
            )
            try:
                scenes = self._parse_llm_json(result_text or "[]")
            except ValueError:
                return
            if not isinstance(scenes, list):
                return
            scenes = [scene for scene in scenes if isinstance(scene, dict)]
            if not scenes:
                return

            async with self.db_factory() as db:
                for scene in scenes:
                    await db.execute(
                        text("""
                            INSERT INTO memory_scenes
                                (user_id, flow_id, scene_name, summary, heat, content,
                                 created_at, updated_at)
                            VALUES
                                (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW())
                        """),
                        {
                            "uid": uid,
                            "fid": fid,
                            "name": scene.get("scene_name", ""),
                            "summary": scene.get("summary", ""),
                            "content": json.dumps(scene, ensure_ascii=False),
                        },
                    )
                await db.commit()
        except Exception as e:
            logger.warning("L2场景提取失败: %s", e)

    # ---------------------------------------------------------- L3: persona

    async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
        """Schedule an L3 persona rebuild for the user when it is due.

        Skips when the stored persona was updated within
        PERSONA_REFRESH_INTERVAL, when the user's most recently active session
        has fewer than PERSONA_UPDATE_EVERY messages, or when a rebuild is
        already running. *fid* is not used.
        """
        task_key = f"update_persona:{uid}"
        try:
            if self._extract_task_running(task_key):
                return
            async with self.db_factory() as db:
                msg_count = (await db.execute(
                    text(
                        "SELECT message_count FROM memory_sessions "
                        "WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"
                    ),
                    {"uid": uid},
                )).scalar() or 0
                persona_row = (await db.execute(
                    text("SELECT version, updated_at FROM memory_personas WHERE user_id = :uid"),
                    {"uid": uid},
                )).fetchone()
                if (
                    persona_row is not None
                    and persona_row[1] is not None
                    and datetime.now(timezone.utc) - self._as_utc(persona_row[1])
                    < self.PERSONA_REFRESH_INTERVAL
                ):
                    return
                if msg_count < self.PERSONA_UPDATE_EVERY:
                    return

                atoms_result = await db.execute(
                    text("""
                        SELECT content
                        FROM memory_atoms
                        WHERE user_id = :uid AND atom_type = 'persona'
                        ORDER BY priority DESC, updated_at DESC
                        LIMIT 15
                    """),
                    {"uid": uid},
                )
                persona_atoms = [row[0] for row in atoms_result.fetchall() if row[0]]

            persona_text = "\n".join(persona_atoms) if persona_atoms else "暂无用户信息"
            next_version = (persona_row[0] or 0) + 1 if persona_row is not None else 1
            # Re-check: another exchange may have started a task during the awaits.
            if self._extract_task_running(task_key):
                return
            self._start_extract_task(task_key, self._do_update_persona(uid, persona_text, next_version))
        except Exception as e:
            logger.debug("persona update scheduling failed: %s", e)

    async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
        """Ask the LLM for a structured persona and upsert it with *version*.

        A reply that is not a JSON object is stored as ``{"raw_text": reply}``;
        an empty reply leaves the existing persona untouched.
        """
        try:
            prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。

用户信息:
{persona_text}

请返回一个JSON对象,包含以下字段:
- name: 用户姓名(未知则"未知")
- role: 角色/职位
- preferences: 偏好/习惯列表
- skills: 技能/专长列表
- personality: 性格描述
- raw_text: 综合画像的文本描述(2-3句话)

只返回JSON对象,不要其他内容。"""
            result_text = (await self._chat_completion(
                prompt, max_tokens=500, temperature=0.3, timeout=60
            )).strip()
            if not result_text:
                return
            try:
                persona = self._parse_llm_json(result_text)
            except ValueError:
                persona = {"raw_text": result_text}
            if not isinstance(persona, dict):
                persona = {"raw_text": result_text}

            content_json = json.dumps(persona, ensure_ascii=False)
            raw_text = persona.get("raw_text")
            if not isinstance(raw_text, str):
                raw_text = content_json
            params = {"uid": uid, "content": content_json, "raw_text": raw_text, "version": version}

            async with self.db_factory() as db:
                existing = await db.execute(
                    text("SELECT id FROM memory_personas WHERE user_id = :uid"),
                    {"uid": uid},
                )
                if existing.fetchone():
                    await db.execute(
                        text("""
                            UPDATE memory_personas
                            SET content = :content,
                                raw_text = :raw_text,
                                version = :version,
                                updated_at = NOW()
                            WHERE user_id = :uid
                        """),
                        params,
                    )
                else:
                    await db.execute(
                        text("""
                            INSERT INTO memory_personas
                                (user_id, content, raw_text, version, updated_at)
                            VALUES
                                (:uid, :content, :raw_text, :version, NOW())
                        """),
                        params,
                    )
                await db.commit()
        except Exception as e:
            logger.warning("L3画像更新失败: %s", e)

    # ---------------------------------------------------------------- search

    async def search_memory(
        self,
        uid: uuid.UUID,
        query: str,
        fid: uuid.UUID | None = None,
        top_k: int = 10,
        embedding: list[float] | None = None,
    ) -> list[dict]:
        """Hybrid search over the user's atoms, fused with Reciprocal Rank Fusion.

        Runs a pgvector cosine-distance search when *embedding* is given, plus a
        full-text search (``content_tsv`` against ``plainto_tsquery('simple', query)``).
        Each fetches up to ``2 * top_k`` hits; when *fid* is given, only atoms of
        that flow or flow-less atoms are considered.

        Returns up to *top_k* dicts with ``id``, ``type``, ``content``,
        ``priority``, ``rrf_score`` and whichever of ``vector_score`` /
        ``text_score`` applied. On a database error, logs and returns the
        unfused hits fetched so far (at most *top_k*).
        """
        if top_k <= 0:
            return []
        # The optional fragment is a constant, never user input.
        flow_clause = "AND (flow_id = :fid OR flow_id IS NULL)" if fid is not None else ""
        base_params: dict[str, Any] = {"uid": uid, "limit": top_k * 2}
        if fid is not None:
            base_params["fid"] = fid

        vector_hits: list[dict] = []
        text_hits: list[dict] = []
        try:
            async with self.db_factory() as db:
                if embedding:
                    emb_str = "[" + ",".join(str(v) for v in embedding) + "]"
                    # CAST(:emb AS vector) instead of ":emb::vector": text() does
                    # not reliably parse a bind name directly followed by "::".
                    result = await db.execute(
                        text(f"""
                            SELECT id, atom_type, content, priority,
                                   (1.0 - (embedding <=> CAST(:emb AS vector))) AS similarity
                            FROM memory_atoms
                            WHERE user_id = :uid {flow_clause}
                              AND embedding IS NOT NULL
                            ORDER BY embedding <=> CAST(:emb AS vector)
                            LIMIT :limit
                        """),
                        {**base_params, "emb": emb_str},
                    )
                    vector_hits = [
                        {
                            "id": str(row[0]),
                            "type": row[1],
                            "content": row[2],
                            "priority": row[3],
                            "vector_score": float(row[4]),
                        }
                        for row in result.fetchall()
                    ]

                result = await db.execute(
                    text(f"""
                        SELECT id, atom_type, content, priority,
                               ts_rank(content_tsv, plainto_tsquery('simple', :query)) AS text_score
                        FROM memory_atoms
                        WHERE user_id = :uid {flow_clause}
                          AND content_tsv @@ plainto_tsquery('simple', :query)
                        ORDER BY text_score DESC
                        LIMIT :limit
                    """),
                    {**base_params, "query": query},
                )
                text_hits = [
                    {
                        "id": str(row[0]),
                        "type": row[1],
                        "content": row[2],
                        "priority": row[3],
                        "text_score": float(row[4]),
                    }
                    for row in result.fetchall()
                ]
        except Exception as e:
            logger.warning("混合检索失败: %s", e)
            return (vector_hits + text_hits)[:top_k]

        return self._rrf_merge(vector_hits, text_hits, top_k)

    @staticmethod
    def _rrf_merge(
        vector_hits: list[dict], text_hits: list[dict], top_k: int, k: int = 60
    ) -> list[dict]:
        """Fuse two best-first hit lists with Reciprocal Rank Fusion.

        A hit at 0-based position r of a list contributes ``1 / (k + r + 1)``;
        each list is ranked on its own. Hits sharing an ``id`` are merged into
        one dict carrying both scores plus the summed ``rrf_score``.
        """
        merged: dict[str, dict] = {}
        for hits in (vector_hits, text_hits):
            for rank, hit in enumerate(hits):
                entry = merged.get(hit["id"])
                if entry is None:
                    entry = merged[hit["id"]] = {**hit, "rrf_score": 0.0}
                else:
                    entry.update(hit)
                entry["rrf_score"] += 1.0 / (k + rank + 1)
        return sorted(merged.values(), key=lambda e: e["rrf_score"], reverse=True)[:top_k]