You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

851 lines
35 KiB

import json
import asyncio
import uuid
import logging
from datetime import datetime, timezone
from typing import Callable
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Process-wide MemoryManager singleton; stays None until init_memory_manager() assigns it.
_memory_manager: "MemoryManager | None" = None
def get_memory_manager() -> "MemoryManager":
    """Return the module-level ``MemoryManager`` singleton.

    Raises:
        RuntimeError: If the singleton is still ``None``.
    """
    global _memory_manager
    manager = _memory_manager
    if manager is not None:
        return manager
    raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
    """Create the module-level ``MemoryManager`` singleton.

    Builds a Redis client from ``settings.REDIS_URL`` with decoded responses,
    pings it, and only then stores a new ``MemoryManager`` in the module
    global. An exception from the ping propagates and leaves the global
    unchanged.

    Args:
        db_factory: Zero-argument callable handed to the manager, which uses
            its result as an async context manager for database work.
    """
    global _memory_manager
    client = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    await client.ping()
    manager = MemoryManager(db_factory, client)
    _memory_manager = manager
# Conversation memory manager: stores messages and derived memories through
# db_factory sessions and keeps short-lived caches in Redis.
class MemoryManager:
# Number of recent messages inject_memory() requests from the database on a cache miss.
MAX_HISTORY = 40
# Maximum number of messages kept in the per-session Redis message cache.
REDIS_CACHE_SIZE = 10
# Expiry passed to SETEX for the Redis message cache (seconds).
REDIS_CACHE_TTL = 300
# Redis key prefix for cached session summaries.
SUMMARY_CACHE_KEY = "mem:summary"
# Redis key prefix for cached recent messages.
MSG_CACHE_KEY = "mem:cache:msgs"
# Message-count step _maybe_extract_atoms() uses to decide when to extract atoms.
ATOM_EXTRACT_EVERY = 10
# Minimum number of stored atoms before _maybe_extract_scenes() proceeds.
SCENE_EXTRACT_EVERY = 50
# Minimum message count of the latest session before _maybe_update_persona() proceeds.
PERSONA_UPDATE_EVERY = 30
def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
self.db_factory = db_factory
self.redis = redis
self._extract_tasks: dict[str, asyncio.Task] = {}
async def inject_memory(
self,
user_id: str,
flow_id: str,
session_id: str,
context: dict,
):
uid = uuid.UUID(user_id)
fid = uuid.UUID(flow_id)
sid = uuid.UUID(session_id)
cached = await self._redis_get_recent(uid, sid)
if cached:
recent_messages = cached
else:
recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY)
if recent_messages:
await self._redis_set_recent(uid, sid, recent_messages)
summary = await self._get_summary(uid, fid, sid)
atoms = await self._get_relevant_atoms(uid, fid)
persona = await self._get_persona(uid)
context["_memory_context"] = {
"recent_messages": recent_messages,
"summary": summary,
"atoms": atoms,
"persona": persona,
"session_id": session_id,
}
async def record_exchange(
    self,
    user_id: str,
    flow_id: str,
    session_id: str,
    user_msg: str,
    assistant_msg: str,
    flow_name: str = "",
):
    """Persist one user/assistant exchange and schedule follow-up jobs.

    Inserts both messages into ``memory_messages``, upserts the
    ``memory_sessions`` row (message count +2) and commits. If the database
    step fails, the failure is logged and nothing else happens. Otherwise
    the two messages are appended to the Redis cache (best effort) and the
    summarise / atom / scene / persona jobs are started as background tasks.

    Args:
        user_id: User id as a UUID string.
        flow_id: Flow id as a UUID string.
        session_id: Session id as a UUID string.
        user_msg: Text of the user message.
        assistant_msg: Text of the assistant reply.
        flow_name: Optional flow name; an empty string keeps the stored name.

    Raises:
        ValueError: If any id is not a valid UUID string.
    """
    from datetime import timedelta

    uid = uuid.UUID(user_id)
    fid = uuid.UUID(flow_id)
    sid = uuid.UUID(session_id)
    ts = datetime.now(timezone.utc)
    # History is read back with ORDER BY created_at, so the reply must not
    # share the user's timestamp or the order within a turn is undefined.
    reply_ts = ts + timedelta(microseconds=1)
    shared = {"user_id": uid, "flow_id": fid, "session_id": sid}
    try:
        async with self.db_factory() as db:
            insert_message = text("""
INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
""")
            await db.execute(
                insert_message,
                {**shared, "id": uuid.uuid4(), "role": "user", "content": user_msg, "created_at": ts},
            )
            await db.execute(
                insert_message,
                {
                    **shared,
                    "id": uuid.uuid4(),
                    "role": "assistant",
                    "content": assistant_msg,
                    "created_at": reply_ts,
                },
            )
            session_row = {
                **shared,
                "id": uuid.uuid4(),
                "flow_name": flow_name,
                "last_active_at": ts,
                "created_at": ts,
            }
            await db.execute(
                text("""
INSERT INTO memory_sessions (id, user_id, flow_id, session_id, flow_name, message_count, last_active_at, created_at)
VALUES (:id, :user_id, :flow_id, :session_id, :flow_name, 2, :last_active_at, :created_at)
ON CONFLICT (user_id, flow_id, session_id) DO UPDATE SET
message_count = memory_sessions.message_count + 2,
last_active_at = EXCLUDED.last_active_at,
flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name)
"""),
                session_row,
            )
            await db.commit()
    except Exception as e:
        logger.warning(f"记录记忆失败(PG): {e}")
        return
    try:
        new_msgs = [
            {"role": "user", "content": user_msg, "ts": ts.isoformat()},
            {"role": "assistant", "content": assistant_msg, "ts": reply_ts.isoformat()},
        ]
        await self._redis_append_recent(uid, sid, new_msgs)
    except Exception as e:
        logger.debug(f"Redis缓存更新失败: {e}")
    # The event loop only keeps weak references to tasks, so each background
    # job is parked in the task registry until it finishes.
    for job in (
        self._maybe_summarize(uid, fid, sid),
        self._maybe_extract_atoms(uid, fid, sid),
        self._maybe_extract_scenes(uid, fid),
        self._maybe_update_persona(uid, fid),
    ):
        task = asyncio.create_task(job)
        key = f"bg:{id(task)}"
        self._extract_tasks[key] = task
        task.add_done_callback(
            lambda _done, _key=key: self._extract_tasks.pop(_key, None)
        )
async def get_conversation_history(
self, user_id: str, flow_id: str, session_id: str, limit: int = 20
) -> list[dict]:
uid = uuid.UUID(user_id)
sid = uuid.UUID(session_id)
fid = uuid.UUID(flow_id) if flow_id else None
return await self._pg_get_recent(uid, fid, sid, limit)
async def delete_session(self, user_id: str, session_id: str):
    """Delete a session's messages, its session row and its Redis caches.

    The database deletes run in one transaction; a failure there is logged
    and the cache cleanup still runs. Each cache key is removed
    independently, so a failure on one key does not skip the other.

    Args:
        user_id: User id as a UUID string.
        session_id: Session id as a UUID string.

    Raises:
        ValueError: If an id is not a valid UUID string.
    """
    uid = uuid.UUID(user_id)
    sid = uuid.UUID(session_id)
    params = {"uid": uid, "sid": sid}
    try:
        async with self.db_factory() as db:
            for statement in (
                "DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid",
                "DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid",
            ):
                await db.execute(text(statement), params)
            await db.commit()
    except Exception as e:
        logger.warning(f"清除记忆失败(PG): {e}")
    for cache_key in (
        f"{self.MSG_CACHE_KEY}:{uid}:{sid}",
        f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}",
    ):
        try:
            await self.redis.delete(cache_key)
        except Exception as e:
            # Cache cleanup stays best effort, but is no longer invisible.
            logger.debug(f"清除记忆缓存失败(Redis): {e}")
async def list_user_sessions(self, user_id: str) -> list[dict]:
    """List up to 100 sessions of a user, most recently active first.

    Args:
        user_id: User id as a UUID string.

    Returns:
        One dict per session with ``session_id``, ``flow_id``, ``flow_name``
        and ``last_active_at`` (ISO string, or ``""`` when unset). Any error
        during the lookup yields an empty list.

    Raises:
        ValueError: If ``user_id`` is not a valid UUID string.
    """
    uid = uuid.UUID(user_id)
    try:
        async with self.db_factory() as db:
            result = await db.execute(
                text("""
SELECT session_id, flow_id, flow_name, last_active_at
FROM memory_sessions
WHERE user_id = :uid
ORDER BY last_active_at DESC
LIMIT 100
"""),
                {"uid": uid},
            )
            sessions = []
            for row in result.fetchall():
                last_active = row[3]
                sessions.append(
                    {
                        "session_id": str(row[0]),
                        "flow_id": str(row[1]),
                        "flow_name": row[2] or "",
                        "last_active_at": last_active.isoformat() if last_active else "",
                    }
                )
            return sessions
    except Exception:
        return []
async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]:
    """Read the newest ``limit`` messages of a session, oldest first.

    Args:
        uid: User id.
        fid: Flow id; when falsy the query is not restricted to a flow.
        sid: Session id.
        limit: Maximum number of rows to fetch.

    Returns:
        Dicts with ``role``, ``content`` and ``ts`` (ISO string, or ``""``
        when the timestamp is unset) in chronological order. A database
        error is logged and yields an empty list.
    """
    params = {"uid": uid, "sid": sid, "limit": limit}
    if fid:
        params["fid"] = fid
        query = """
SELECT role, content, created_at
FROM memory_messages
WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid
ORDER BY created_at DESC
LIMIT :limit
"""
    else:
        query = """
SELECT role, content, created_at
FROM memory_messages
WHERE user_id = :uid AND session_id = :sid
ORDER BY created_at DESC
LIMIT :limit
"""
    try:
        async with self.db_factory() as db:
            result = await db.execute(text(query), params)
            rows = result.fetchall()
            # Rows arrive newest first; callers expect chronological order.
            return [
                {"role": row[0], "content": row[1], "ts": row[2].isoformat() if row[2] else ""}
                for row in reversed(rows)
            ]
    except Exception as e:
        # An empty result looks the same as "no history", so record the cause.
        logger.warning(f"读取历史消息失败(PG): {e}")
        return []
async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
raw = await self.redis.get(cache_key)
if raw:
return json.loads(raw)
except Exception:
pass
return None
async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
top = messages[-self.REDIS_CACHE_SIZE:]
await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(top, ensure_ascii=False))
except Exception:
pass
async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
existing = await self._redis_get_recent(uid, sid) or []
combined = existing + new_msgs
recent = combined[-self.REDIS_CACHE_SIZE:]
await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(recent, ensure_ascii=False))
except Exception:
pass
async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
try:
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
val = await self.redis.get(summary_key)
return val or ""
except Exception:
return ""
async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
    """Create and cache a session summary once the session is long enough.

    Does nothing when a summary is already cached or the session has fewer
    than 30 stored messages. Otherwise the 10 most recent messages (each
    truncated to 500 characters) are sent to the chat-completions endpoint
    and a non-empty answer is cached for 2592000 seconds. Failures are
    logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id, used when reading recent messages.
        sid: Session id.
    """
    try:
        summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
        if await self.redis.get(summary_key):
            return
        async with self.db_factory() as db:
            result = await db.execute(
                text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
                {"uid": uid, "sid": sid},
            )
            row = result.fetchone()
        count = row[0] if row else 0
        if count < 30:
            return
        # Only the newest 10 messages are summarised, so fetch exactly those.
        recent = await self._pg_get_recent(uid, fid, sid, 10)
        if not recent:
            return
        dialogue = "\n".join(
            f"{m['role']}: {(m.get('content') or '')[:500]}" for m in recent
        )
        import httpx
        api_base = settings.LLM_API_BASE.rstrip("/")
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [{
                        "role": "user",
                        "content": f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}"
                    }],
                    "max_tokens": 200,
                },
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
            )
            # Treat HTTP error responses as failures instead of parsing them.
            resp.raise_for_status()
            data = resp.json()
        choices = data.get("choices") or [{}]
        summary = (choices[0].get("message") or {}).get("content") or ""
        if summary:
            await self.redis.setex(summary_key, 2592000, summary)
    except Exception as e:
        logger.warning(f"会话摘要生成失败: {e}")
async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
    """Fetch the top 20 atoms for a user within a flow.

    Rows match the user and either the given flow or no flow, ordered by
    priority and then by last update, both descending.

    Args:
        uid: User id.
        fid: Flow id.

    Returns:
        Dicts with ``id`` (string), ``type``, ``content`` and ``priority``;
        an empty list when anything goes wrong.
    """
    try:
        async with self.db_factory() as db:
            result = await db.execute(
                text("""
SELECT id, atom_type, content, priority, metadata
FROM memory_atoms
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
ORDER BY priority DESC, updated_at DESC
LIMIT 20
"""),
                {"uid": uid, "fid": fid},
            )
            atoms = []
            for row in result.fetchall():
                atoms.append(
                    {"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]}
                )
            return atoms
    except Exception:
        return []
async def _get_persona(self, uid: uuid.UUID) -> dict:
    """Load the stored persona of a user.

    Args:
        uid: User id.

    Returns:
        ``{"content", "raw_text", "version"}`` with falsy columns replaced
        by ``{}``, ``""`` and ``1`` respectively; ``{}`` when no row exists
        or the lookup fails.
    """
    try:
        async with self.db_factory() as db:
            result = await db.execute(
                text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"),
                {"uid": uid},
            )
            row = result.fetchone()
            if not row:
                return {}
            return {
                "content": row[0] or {},
                "raw_text": row[1] or "",
                "version": row[2] or 1,
            }
    except Exception:
        return {}
async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
    """Start atom extraction for a session when its message count warrants it.

    Skips when an extraction task for the same user and flow is still
    running, when the session has fewer than ``ATOM_EXTRACT_EVERY``
    messages, or when the message count does not exceed
    ``ATOM_EXTRACT_EVERY`` times the number of atoms already attributed to
    this session. Otherwise the 15 newest messages (400 characters each)
    are handed to ``_do_extract_atoms`` in a background task. Failures are
    logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id.
        sid: Session id.
    """
    task_key = f"extract_atoms:{uid}:{fid}"
    try:
        running = self._extract_tasks.get(task_key)
        if running is not None and not running.done():
            return
        async with self.db_factory() as db:
            result = await db.execute(
                text("SELECT message_count FROM memory_sessions WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"),
                {"uid": uid, "fid": fid, "sid": sid},
            )
            row = result.fetchone()
            count = (row[0] or 0) if row else 0
            last_extract = await db.execute(
                text("""
SELECT COUNT(*) FROM memory_atoms
WHERE user_id = :uid AND flow_id = :fid AND metadata->>'source_session_id' = :sid_str
"""),
                {"uid": uid, "fid": fid, "sid_str": str(sid)},
            )
            last_count = last_extract.fetchone()[0]
        if count < self.ATOM_EXTRACT_EVERY or count <= last_count * self.ATOM_EXTRACT_EVERY:
            return
        recent = await self._pg_get_recent(uid, fid, sid, 20)
        dialogue = "\n".join(
            f"{m['role']}: {m['content'][:400]}" for m in recent[-15:]
        )
        task = asyncio.create_task(self._do_extract_atoms(uid, fid, sid, dialogue))
        self._extract_tasks[task_key] = task

        def _forget(done_task, key=task_key):
            # Drop the finished task so the registry does not grow forever,
            # but never remove a newer task registered under the same key.
            if self._extract_tasks.get(key) is done_task:
                del self._extract_tasks[key]

        task.add_done_callback(_forget)
    except Exception as e:
        logger.warning(f"L1原子提取调度失败: {e}")
async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
    """Ask the LLM for memory atoms in ``dialogue`` and store them.

    The reply is expected to be a JSON array of objects. A Markdown code
    fence around the JSON is tolerated, non-object items are dropped, and
    anything left is passed to ``_dedup_and_store_atoms``. Failures are
    logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id.
        sid: Session id the dialogue belongs to.
        dialogue: Dialogue text embedded in the prompt.
    """
    try:
        prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。
对话内容:
{dialogue}
请以JSON数组格式返回,每个原子包含以下字段:
- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求)
- content: 原子的具体内容(一句话描述)
- priority: 0-100的整数,表示重要性(越高越核心)
格式示例:
[
{{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}},
{{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}}
]
只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。"""
        import httpx
        api_base = settings.LLM_API_BASE.rstrip("/")
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 800,
                    "temperature": 0.3,
                },
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
            )
            # Treat HTTP error responses as failures instead of parsing them.
            resp.raise_for_status()
            data = resp.json()
        choices = data.get("choices") or [{}]
        result_text = (choices[0].get("message") or {}).get("content") or "[]"
        cleaned = result_text.strip()
        if cleaned.startswith("```"):
            # Strip a Markdown code fence (optionally tagged "json").
            cleaned = cleaned.strip("`").strip()
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
        try:
            atoms = json.loads(cleaned)
        except (json.JSONDecodeError, ValueError):
            logger.debug("L1原子提取返回内容不是有效JSON")
            return
        if not isinstance(atoms, list):
            return
        atoms = [atom for atom in atoms if isinstance(atom, dict)]
        if not atoms:
            return
        await self._dedup_and_store_atoms(uid, fid, sid, atoms)
    except Exception as e:
        logger.warning(f"L1原子提取失败: {e}")
async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
    """Merge extracted atoms into ``memory_atoms`` for a user and flow.

    An atom whose content is more than 0.75 similar (per
    ``_text_similarity``) to an existing atom refreshes that row: the
    priority is raised to the larger value, ``updated_at`` is bumped and
    the session id is merged into the metadata. Each existing row is
    matched at most once per call. Other atoms are inserted, except those
    that duplicate an atom already inserted in this call. Malformed items
    (non-dict, missing or blank content) are skipped, and an unusable
    priority falls back to 50. All changes are committed together;
    failures are logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id.
        sid: Session id recorded as the source of the atoms.
        atoms: Candidate atoms with ``atom_type``, ``content`` and ``priority``.
    """
    try:
        async with self.db_factory() as db:
            existing = await db.execute(
                text("""
SELECT id, content FROM memory_atoms
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
"""),
                {"uid": uid, "fid": fid},
            )
            existing_atoms = [{"id": row[0], "content": row[1] or ""} for row in existing.fetchall()]
            inserted_contents: list[str] = []
            for atom in atoms:
                if not isinstance(atom, dict):
                    continue
                raw_content = atom.get("content")
                content = raw_content.strip() if isinstance(raw_content, str) else ""
                if not content:
                    continue
                try:
                    priority = int(atom.get("priority", 50))
                except (TypeError, ValueError):
                    priority = 50
                # The prompt asks for 0-100; keep stray values inside that range.
                priority = max(0, min(100, priority))
                atom_type = atom.get("atom_type")
                if not isinstance(atom_type, str) or not atom_type:
                    atom_type = "episodic"
                match_index = next(
                    (
                        index
                        for index, ex in enumerate(existing_atoms)
                        if self._text_similarity(content, ex["content"]) > 0.75
                    ),
                    None,
                )
                if match_index is not None:
                    # Consume the match so later atoms cannot hit the same row.
                    matched = existing_atoms.pop(match_index)
                    await db.execute(
                        text("""
UPDATE memory_atoms
SET priority = GREATEST(priority, :priority),
updated_at = NOW(),
metadata = metadata || :meta
WHERE id = :id
"""),
                        {
                            "id": matched["id"],
                            "priority": priority,
                            "meta": json.dumps({"last_source_session_id": str(sid)}),
                        },
                    )
                    continue
                if any(self._text_similarity(content, prior) > 0.75 for prior in inserted_contents):
                    # Same fact repeated within this batch; one row is enough.
                    continue
                await db.execute(
                    text("""
INSERT INTO memory_atoms (user_id, flow_id, atom_type, content, priority, source_session_id, metadata, created_at, updated_at)
VALUES (:uid, :fid, :atype, :content, :priority, :sid, :meta, NOW(), NOW())
"""),
                    {
                        "uid": uid,
                        "fid": fid,
                        "atype": atom_type,
                        "content": content,
                        "priority": priority,
                        "sid": sid,
                        "meta": json.dumps({"source_session_id": str(sid)}),
                    },
                )
                inserted_contents.append(content)
            await db.commit()
    except Exception as e:
        logger.warning(f"L1原子存储失败: {e}")
@staticmethod
def _text_similarity(a: str, b: str) -> float:
a_words = set(a.lower().split())
b_words = set(b.lower().split())
if not a_words or not b_words:
return 0.0
intersection = a_words & b_words
union = a_words | b_words
return len(intersection) / len(union)
async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
    """Start scene extraction when enough atoms exist and scenes are stale.

    Skips when a scene task for the same user and flow is still running,
    when fewer than ``SCENE_EXTRACT_EVERY`` atoms exist, when the newest
    scene was updated less than 12 hours ago, or when fewer than 10 atoms
    are returned by the top-30 query. Otherwise ``_do_extract_scenes`` runs
    in a background task. Failures are logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id; rows without a flow are included as well.
    """
    from datetime import timedelta

    task_key = f"extract_scenes:{uid}:{fid}"
    try:
        running = self._extract_tasks.get(task_key)
        if running is not None and not running.done():
            return
        params = {"uid": uid, "fid": fid}
        async with self.db_factory() as db:
            result = await db.execute(
                text("SELECT COUNT(*) FROM memory_atoms WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
                params,
            )
            row = result.fetchone()
            count = row[0] if row else 0
            if count < self.SCENE_EXTRACT_EVERY:
                return
            # No row here simply means there are no scenes yet.
            latest_result = await db.execute(
                text("""
SELECT updated_at FROM memory_scenes
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
ORDER BY updated_at DESC LIMIT 1
"""),
                params,
            )
            latest_scene = latest_result.fetchone()
            if latest_scene and latest_scene[0] is not None:
                last_update = latest_scene[0]
                if last_update.tzinfo is None:
                    # Naive timestamps are taken to be UTC; aware ones are
                    # compared as they are.
                    last_update = last_update.replace(tzinfo=timezone.utc)
                if datetime.now(timezone.utc) - last_update < timedelta(hours=12):
                    return
            atoms_result = await db.execute(
                text("""
SELECT atom_type, content, priority FROM memory_atoms
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
ORDER BY priority DESC LIMIT 30
"""),
                params,
            )
            atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in atoms_result.fetchall()]
        if len(atoms) < 10:
            return
        task = asyncio.create_task(self._do_extract_scenes(uid, fid, atoms))
        self._extract_tasks[task_key] = task

        def _forget(done_task, key=task_key):
            # Drop the finished task, but never a newer one under the same key.
            if self._extract_tasks.get(key) is done_task:
                del self._extract_tasks[key]

        task.add_done_callback(_forget)
    except Exception as e:
        logger.warning(f"L2场景提取调度失败: {e}")
async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
    """Ask the LLM to group atoms into scenes and insert them.

    The reply is expected to be a JSON array of objects with ``scene_name``
    and ``summary``. A Markdown code fence around the JSON is tolerated and
    non-object items are dropped. Each remaining scene is inserted into
    ``memory_scenes`` with ``heat`` 1 and the raw scene JSON as content.
    Failures are logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id.
        atoms: Dicts with ``type``, ``priority`` and ``content``.
    """
    try:
        atoms_text = "\n".join(
            f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
        )
        prompt = f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。
记忆原子:
{atoms_text}
请以JSON数组格式返回,每个场景块包含:
- scene_name: 场景名称(简短标签)
- summary: 场景摘要(2-3句话总结该场景的关键信息)
格式示例:
[
{{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}}
]
只返回JSON数组,不要其他内容。"""
        import httpx
        api_base = settings.LLM_API_BASE.rstrip("/")
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 500,
                    "temperature": 0.3,
                },
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
            )
            # Treat HTTP error responses as failures instead of parsing them.
            resp.raise_for_status()
            data = resp.json()
        choices = data.get("choices") or [{}]
        result_text = (choices[0].get("message") or {}).get("content") or "[]"
        cleaned = result_text.strip()
        if cleaned.startswith("```"):
            # Strip a Markdown code fence (optionally tagged "json").
            cleaned = cleaned.strip("`").strip()
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
        try:
            scenes = json.loads(cleaned)
        except (json.JSONDecodeError, ValueError):
            logger.debug("L2场景提取返回内容不是有效JSON")
            return
        if not isinstance(scenes, list):
            return
        scenes = [scene for scene in scenes if isinstance(scene, dict)]
        if not scenes:
            return
        async with self.db_factory() as db:
            for scene in scenes:
                await db.execute(
                    text("""
INSERT INTO memory_scenes (user_id, flow_id, scene_name, summary, heat, content, created_at, updated_at)
VALUES (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW())
"""),
                    {
                        "uid": uid,
                        "fid": fid,
                        "name": scene.get("scene_name", ""),
                        "summary": scene.get("summary", ""),
                        "content": json.dumps(scene, ensure_ascii=False),
                    },
                )
            await db.commit()
    except Exception as e:
        logger.warning(f"L2场景提取失败: {e}")
async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
    """Start a persona refresh when the user has been active enough.

    Skips when a persona task for the user is still running, when the most
    recently active session has fewer than ``PERSONA_UPDATE_EVERY``
    messages, or when the stored persona was updated less than 6 hours
    ago. Otherwise the top persona atoms (at most 15 are used) are handed
    to ``_do_update_persona`` in a background task with the next version
    number. Failures are logged and never propagate.

    Args:
        uid: User id.
        fid: Flow id (not used by this check).
    """
    from datetime import timedelta

    task_key = f"update_persona:{uid}"
    try:
        running = self._extract_tasks.get(task_key)
        if running is not None and not running.done():
            return
        async with self.db_factory() as db:
            result = await db.execute(
                text("SELECT message_count FROM memory_sessions WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"),
                {"uid": uid},
            )
            row = result.fetchone()
            msg_count = (row[0] or 0) if row else 0
            if msg_count < self.PERSONA_UPDATE_EVERY:
                return
            persona_result = await db.execute(
                text("SELECT version, updated_at FROM memory_personas WHERE user_id = :uid"),
                {"uid": uid},
            )
            persona_row = persona_result.fetchone()
            if persona_row and persona_row[1] is not None:
                last_update = persona_row[1]
                if last_update.tzinfo is None:
                    # Naive timestamps are taken to be UTC; aware ones are
                    # compared as they are.
                    last_update = last_update.replace(tzinfo=timezone.utc)
                if datetime.now(timezone.utc) - last_update < timedelta(hours=6):
                    return
            atoms_result = await db.execute(
                text("""
SELECT content FROM memory_atoms
WHERE user_id = :uid AND atom_type = 'persona'
ORDER BY priority DESC, updated_at DESC LIMIT 30
"""),
                {"uid": uid},
            )
            persona_atoms = [r[0] for r in atoms_result.fetchall()]
        persona_text = "\n".join(persona_atoms[:15]) if persona_atoms else "暂无用户信息"
        # A NULL stored version counts as 0 so the next version is 1.
        next_version = (persona_row[0] or 0) + 1 if persona_row else 1
        task = asyncio.create_task(self._do_update_persona(uid, persona_text, next_version))
        self._extract_tasks[task_key] = task

        def _forget(done_task, key=task_key):
            # Drop the finished task, but never a newer one under the same key.
            if self._extract_tasks.get(key) is done_task:
                del self._extract_tasks[key]

        task.add_done_callback(_forget)
    except Exception as e:
        logger.warning(f"L3画像更新调度失败: {e}")
async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
    """Generate a structured persona with the LLM and store it.

    The reply is parsed as JSON (a Markdown code fence is tolerated). A
    reply that is not valid JSON, or not a JSON object, is stored as
    ``{"raw_text": <reply>}``. The persona row of the user is updated when
    it exists and inserted otherwise. Failures are logged and never
    propagate.

    Args:
        uid: User id.
        persona_text: Text embedded in the prompt as the known user facts.
        version: Version number written to the persona row.
    """
    try:
        prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。
用户信息:
{persona_text}
请返回一个JSON对象,包含以下字段:
- name: 用户姓名(未知则"未知"
- role: 角色/职位
- preferences: 偏好/习惯列表
- skills: 技能/专长列表
- personality: 性格描述
- raw_text: 综合画像的文本描述(2-3句话)
只返回JSON对象,不要其他内容。"""
        import httpx
        api_base = settings.LLM_API_BASE.rstrip("/")
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json={
                    "model": settings.LLM_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 500,
                    "temperature": 0.3,
                },
                headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
            )
            # Treat HTTP error responses as failures instead of parsing them.
            resp.raise_for_status()
            data = resp.json()
        choices = data.get("choices") or [{}]
        result_text = (choices[0].get("message") or {}).get("content") or "{}"
        cleaned = result_text.strip()
        if cleaned.startswith("```"):
            # Strip a Markdown code fence (optionally tagged "json").
            cleaned = cleaned.strip("`").strip()
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
        try:
            persona = json.loads(cleaned)
        except (json.JSONDecodeError, ValueError):
            persona = {"raw_text": result_text}
        if not isinstance(persona, dict):
            # A list or scalar cannot be read with .get(); keep the reply as text.
            persona = {"raw_text": result_text}
        content_json = json.dumps(persona, ensure_ascii=False)
        raw_text = persona.get("raw_text", content_json)
        if not isinstance(raw_text, str):
            raw_text = content_json
        params = {
            "uid": uid,
            "content": content_json,
            "raw_text": raw_text,
            "version": version,
        }
        async with self.db_factory() as db:
            existing = await db.execute(
                text("SELECT id FROM memory_personas WHERE user_id = :uid"),
                {"uid": uid},
            )
            if existing.fetchone():
                await db.execute(
                    text("""
UPDATE memory_personas
SET content = :content, raw_text = :raw_text, version = :version, updated_at = NOW()
WHERE user_id = :uid
"""),
                    params,
                )
            else:
                await db.execute(
                    text("""
INSERT INTO memory_personas (user_id, content, raw_text, version, updated_at)
VALUES (:uid, :content, :raw_text, :version, NOW())
"""),
                    params,
                )
            await db.commit()
    except Exception as e:
        logger.warning(f"L3画像更新失败: {e}")
async def search_memory(
    self,
    uid: uuid.UUID,
    query: str,
    fid: uuid.UUID = None,
    top_k: int = 10,
    embedding: list[float] = None,
) -> list[dict]:
    """Search a user's atoms by vector distance and full-text match.

    When ``embedding`` is given, atoms are ranked by the ``<=>`` distance
    to it; atoms are always ranked by ``ts_rank`` against ``query``. Each
    ranking fetches up to ``2 * top_k`` rows. The two rankings are combined
    with reciprocal rank fusion (constant 60), with ranks counted within
    each ranking separately, and every atom appears once in the result
    carrying whichever of ``vector_score`` / ``text_score`` it obtained.

    Args:
        uid: User id.
        query: Text for the full-text search.
        fid: Optional flow id; when set, only atoms of that flow or without
            a flow are considered.
        top_k: Maximum number of results.
        embedding: Optional query vector.

    Returns:
        Up to ``top_k`` dicts with ``id``, ``type``, ``content``,
        ``priority``, the available scores and ``rrf_score``, best first.
        On error the failure is logged and the rows collected so far are
        returned without fusion.
    """
    results = []
    # Built up front: nesting triple-quoted strings inside an f-string only
    # parses on Python 3.12+.
    flow_filter = "AND (flow_id = :fid OR flow_id IS NULL)" if fid else ""
    base_params = {"uid": uid, "limit": top_k * 2}
    if fid:
        base_params["fid"] = fid
    try:
        async with self.db_factory() as db:
            if embedding is not None and len(embedding) > 0:
                emb_str = "[" + ",".join(str(v) for v in embedding) + "]"
                # CAST(... AS vector) is used because a ":name::type" cast is
                # not reliably recognised as a bind parameter by text().
                vector_results = await db.execute(
                    text(f"""
SELECT id, atom_type, content, priority,
(1.0 - (embedding <=> CAST(:emb AS vector))) AS similarity
FROM memory_atoms
WHERE user_id = :uid
{flow_filter}
AND embedding IS NOT NULL
ORDER BY embedding <=> CAST(:emb AS vector)
LIMIT :limit
"""),
                    {**base_params, "emb": emb_str},
                )
                for row in vector_results.fetchall():
                    results.append({
                        "id": str(row[0]), "type": row[1], "content": row[2],
                        "priority": row[3], "vector_score": float(row[4]),
                    })
            text_results = await db.execute(
                text(f"""
SELECT id, atom_type, content, priority,
ts_rank(content_tsv, plainto_tsquery('simple', :query)) AS text_score
FROM memory_atoms
WHERE user_id = :uid
{flow_filter}
AND content_tsv @@ plainto_tsquery('simple', :query)
ORDER BY text_score DESC
LIMIT :limit
"""),
                {**base_params, "query": query},
            )
            for row in text_results.fetchall():
                results.append({
                    "id": str(row[0]), "type": row[1], "content": row[2],
                    "priority": row[3], "text_score": float(row[4]),
                })
        k = 60
        fused: dict[str, dict] = {}
        for score_key in ("vector_score", "text_score"):
            # Rank only the hits that belong to this ranking, so the other
            # ranking's rows cannot shift the positions.
            ranked = sorted(
                (hit for hit in results if score_key in hit),
                key=lambda hit: hit[score_key],
                reverse=True,
            )
            for rank, hit in enumerate(ranked):
                entry = fused.setdefault(hit["id"], {**hit, "rrf_score": 0.0})
                entry[score_key] = hit[score_key]
                entry["rrf_score"] += 1.0 / (k + rank + 1)
        merged = sorted(fused.values(), key=lambda hit: hit["rrf_score"], reverse=True)
        return merged[:top_k]
    except Exception as e:
        logger.warning(f"混合检索失败: {e}")
        return results[:top_k] if results else []