You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1022 lines
42 KiB
1022 lines
42 KiB
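"""Three-tier memory system manager.

The memory architecture has three levels (L1/L2/L3):
- L1 (atom layer): extracts key information atoms from conversations
  (user preferences, events, instructions)
- L2 (scene layer): consolidates atoms of the same kind into scene summaries
- L3 (persona layer): synthesizes all information into a user persona

Data storage:
- PG (primary store): persists memory messages, atoms, scenes and personas
- Redis (cache): caches recent messages and conversation summaries
"""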
|
|
|
|
import asyncio
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Callable

from redis.asyncio import Redis
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from config import settings
from models import ModelInstance, ModelProvider
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide MemoryManager singleton; stays None until init_memory_manager()
# assigns it.
_memory_manager: "MemoryManager | None" = None
|
|
|
|
|
|
def get_memory_manager() -> "MemoryManager":
    """Return the module-wide MemoryManager singleton.

    Raises:
        RuntimeError: If the singleton has not been set yet, i.e.
            init_memory_manager() was not awaited beforehand.
    """
    manager = _memory_manager
    if manager is not None:
        return manager
    raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
|
|
|
|
|
|
async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
    """Create the Redis client, verify it with a ping, and build the singleton.

    Args:
        db_factory: Callable that returns an async database session.
    """
    global _memory_manager
    client = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    # Ping before publishing the singleton so a bad Redis URL surfaces here.
    await client.ping()
    _memory_manager = MemoryManager(db_factory, client)
|
|
|
|
|
|
class MemoryManager:
|
|
"""三级记忆管理器,负责记忆的存储、检索、提取与归纳。"""
|
|
|
|
MAX_HISTORY = 40 # 单次注入的最大历史消息数
|
|
REDIS_CACHE_SIZE = 10 # Redis 缓存保留的最近消息数
|
|
REDIS_CACHE_TTL = 300 # Redis 缓存 TTL(秒)
|
|
SUMMARY_CACHE_KEY = "mem:summary" # 摘要缓存 Redis key 前缀
|
|
MSG_CACHE_KEY = "mem:cache:msgs" # 消息缓存 Redis key 前缀
|
|
ATOM_EXTRACT_EVERY = 10 # 每 N 条消息触发一次 L1 原子提取
|
|
SCENE_EXTRACT_EVERY = 50 # 每 N 条原子触发一次 L2 场景提取
|
|
PERSONA_UPDATE_EVERY = 30 # 每 N 条消息触发一次 L3 画像更新
|
|
|
|
def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
|
|
"""初始化记忆管理器。
|
|
|
|
Args:
|
|
db_factory: 异步数据库会话工厂
|
|
redis: Redis 异步客户端实例
|
|
"""
|
|
self.db_factory = db_factory
|
|
self.redis = redis
|
|
self._extract_tasks: dict[str, asyncio.Task] = {} # 后台提取任务追踪
|
|
self._llm_config_cache: dict | None = None # LLM 配置缓存
|
|
|
|
async def _get_llm_config(self) -> dict:
|
|
"""从 model_instances 表获取 LLM 配置,带缓存。
|
|
|
|
优先级:数据库默认 LLM → 环境变量回退。
|
|
缓存有效期:5分钟(避免每次调用都查库)。
|
|
|
|
Returns:
|
|
dict: {model_name, api_key, base_url}
|
|
"""
|
|
if self._llm_config_cache is not None:
|
|
return self._llm_config_cache
|
|
try:
|
|
db = self.db_factory()
|
|
result = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.model_type == 'llm')
|
|
.where(ModelInstance.is_default == True)
|
|
.where(ModelInstance.is_active == True)
|
|
.limit(1)
|
|
)
|
|
row = result.first()
|
|
if not row:
|
|
result2 = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.model_type == 'llm')
|
|
.where(ModelInstance.is_active == True)
|
|
.limit(1)
|
|
)
|
|
row = result2.first()
|
|
if row:
|
|
instance, provider = row
|
|
config = {
|
|
"api_base": (provider.base_url or settings.LLM_API_BASE).rstrip("/"),
|
|
"model": instance.model_name,
|
|
"api_key": provider.api_key or settings.LLM_API_KEY,
|
|
}
|
|
else:
|
|
config = {
|
|
"api_base": settings.LLM_API_BASE.rstrip("/"),
|
|
"model": settings.LLM_MODEL,
|
|
"api_key": settings.LLM_API_KEY,
|
|
}
|
|
self._llm_config_cache = config
|
|
# 5分钟后自动清除缓存
|
|
asyncio.get_event_loop().call_later(300, lambda: setattr(self, '_llm_config_cache', None))
|
|
return config
|
|
except Exception:
|
|
return {
|
|
"api_base": settings.LLM_API_BASE.rstrip("/"),
|
|
"model": settings.LLM_MODEL,
|
|
"api_key": settings.LLM_API_KEY,
|
|
}
|
|
|
|
async def inject_memory(
|
|
self,
|
|
user_id: str,
|
|
flow_id: str,
|
|
session_id: str,
|
|
context: dict,
|
|
):
|
|
"""向对话上下文中注入三层记忆信息。
|
|
|
|
从 PG/Redis 中获取近期消息、摘要、原子记忆和画像,
|
|
合并后注入到 context["_memory_context"] 中供 LLM 使用。
|
|
|
|
Args:
|
|
user_id: 用户 ID
|
|
flow_id: 流程 ID
|
|
session_id: 会话 ID
|
|
context: 对话上下文字典(会在原地被修改)
|
|
"""
|
|
uid = uuid.UUID(user_id)
|
|
fid = uuid.UUID(flow_id)
|
|
sid = uuid.UUID(session_id)
|
|
|
|
cached = await self._redis_get_recent(uid, sid)
|
|
if cached:
|
|
recent_messages = cached
|
|
else:
|
|
recent_messages = await self._pg_get_recent(uid, fid, sid, self.MAX_HISTORY)
|
|
if recent_messages:
|
|
await self._redis_set_recent(uid, sid, recent_messages)
|
|
|
|
summary = await self._get_summary(uid, fid, sid)
|
|
atoms = await self._get_relevant_atoms(uid, fid)
|
|
persona = await self._get_persona(uid)
|
|
|
|
context["_memory_context"] = {
|
|
"recent_messages": recent_messages,
|
|
"summary": summary,
|
|
"atoms": atoms,
|
|
"persona": persona,
|
|
"session_id": session_id,
|
|
}
|
|
|
|
async def record_exchange(
|
|
self,
|
|
user_id: str,
|
|
flow_id: str,
|
|
session_id: str,
|
|
user_msg: str,
|
|
assistant_msg: str,
|
|
flow_name: str = "",
|
|
):
|
|
"""记录一次用户-助手对话交换。
|
|
|
|
将用户消息和助手消息写入 PG,同时更新 Redis 缓存,
|
|
并异步触发 L1/L1/L2/L3 记忆提取任务。
|
|
|
|
Args:
|
|
user_id: 用户 ID
|
|
flow_id: 流程 ID
|
|
session_id: 会话 ID
|
|
user_msg: 用户消息内容
|
|
assistant_msg: 助手回复内容
|
|
flow_name: 流程名称(可选,用于会话记录)
|
|
"""
|
|
uid = uuid.UUID(user_id)
|
|
fid = uuid.UUID(flow_id)
|
|
sid = uuid.UUID(session_id)
|
|
ts = datetime.now(timezone.utc)
|
|
|
|
try:
|
|
async with self.db_factory() as db:
|
|
user_row = {
|
|
"id": uuid.uuid4(),
|
|
"user_id": uid,
|
|
"flow_id": fid,
|
|
"session_id": sid,
|
|
"role": "user",
|
|
"content": user_msg,
|
|
"created_at": ts,
|
|
}
|
|
asst_row = {
|
|
"id": uuid.uuid4(),
|
|
"user_id": uid,
|
|
"flow_id": fid,
|
|
"session_id": sid,
|
|
"role": "assistant",
|
|
"content": assistant_msg,
|
|
"created_at": ts,
|
|
}
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
|
|
VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
|
|
"""),
|
|
user_row,
|
|
)
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO memory_messages (id, user_id, flow_id, session_id, role, content, created_at)
|
|
VALUES (:id, :user_id, :flow_id, :session_id, :role, :content, :created_at)
|
|
"""),
|
|
asst_row,
|
|
)
|
|
|
|
session_row = {
|
|
"id": uuid.uuid4(),
|
|
"user_id": uid,
|
|
"flow_id": fid,
|
|
"session_id": sid,
|
|
"flow_name": flow_name,
|
|
"last_active_at": ts,
|
|
"created_at": ts,
|
|
}
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO memory_sessions (id, user_id, flow_id, session_id, flow_name, message_count, last_active_at, created_at)
|
|
VALUES (:id, :user_id, :flow_id, :session_id, :flow_name, 2, :last_active_at, :created_at)
|
|
ON CONFLICT (user_id, flow_id, session_id) DO UPDATE SET
|
|
message_count = memory_sessions.message_count + 2,
|
|
last_active_at = EXCLUDED.last_active_at,
|
|
flow_name = COALESCE(NULLIF(EXCLUDED.flow_name, ''), memory_sessions.flow_name)
|
|
"""),
|
|
session_row,
|
|
)
|
|
|
|
await db.commit()
|
|
except Exception as e:
|
|
logger.warning(f"记录记忆失败(PG): {e}")
|
|
return
|
|
|
|
try:
|
|
new_msgs = [
|
|
{"role": "user", "content": user_msg, "ts": ts.isoformat()},
|
|
{"role": "assistant", "content": assistant_msg, "ts": ts.isoformat()},
|
|
]
|
|
await self._redis_append_recent(uid, sid, new_msgs)
|
|
except Exception as e:
|
|
logger.debug(f"Redis缓存更新失败: {e}")
|
|
|
|
asyncio.create_task(self._maybe_summarize(uid, fid, sid))
|
|
asyncio.create_task(self._maybe_extract_atoms(uid, fid, sid))
|
|
asyncio.create_task(self._maybe_extract_scenes(uid, fid))
|
|
asyncio.create_task(self._maybe_update_persona(uid, fid))
|
|
|
|
async def get_conversation_history(
|
|
self, user_id: str, flow_id: str, session_id: str, limit: int = 20
|
|
) -> list[dict]:
|
|
"""获取指定会话的对话历史。
|
|
|
|
Args:
|
|
user_id: 用户 ID
|
|
flow_id: 流程 ID
|
|
session_id: 会话 ID
|
|
limit: 返回的最大消息数
|
|
|
|
Returns:
|
|
消息列表,每项含 role/content/ts 字段
|
|
"""
|
|
uid = uuid.UUID(user_id)
|
|
sid = uuid.UUID(session_id)
|
|
fid = uuid.UUID(flow_id) if flow_id else None
|
|
return await self._pg_get_recent(uid, fid, sid, limit)
|
|
|
|
async def delete_session(self, user_id: str, session_id: str):
|
|
"""删除指定会话的所有记忆数据(PG + Redis)。"""
|
|
uid = uuid.UUID(user_id)
|
|
sid = uuid.UUID(session_id)
|
|
|
|
try:
|
|
async with self.db_factory() as db:
|
|
await db.execute(
|
|
text("DELETE FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
|
|
{"uid": uid, "sid": sid},
|
|
)
|
|
await db.execute(
|
|
text("DELETE FROM memory_sessions WHERE user_id = :uid AND session_id = :sid"),
|
|
{"uid": uid, "sid": sid},
|
|
)
|
|
await db.commit()
|
|
except Exception as e:
|
|
logger.warning(f"清除记忆失败(PG): {e}")
|
|
|
|
try:
|
|
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
|
|
await self.redis.delete(cache_key)
|
|
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
|
|
await self.redis.delete(summary_key)
|
|
except Exception:
|
|
pass
|
|
|
|
async def list_user_sessions(self, user_id: str) -> list[dict]:
|
|
"""列出用户的所有记忆会话。
|
|
|
|
Args:
|
|
user_id: 用户 ID
|
|
|
|
Returns:
|
|
会话列表,每项含 session_id/flow_id/flow_name/last_active_at
|
|
"""
|
|
uid = uuid.UUID(user_id)
|
|
try:
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("""
|
|
SELECT session_id, flow_id, flow_name, last_active_at
|
|
FROM memory_sessions
|
|
WHERE user_id = :uid
|
|
ORDER BY last_active_at DESC
|
|
LIMIT 100
|
|
"""),
|
|
{"uid": uid},
|
|
)
|
|
rows = result.fetchall()
|
|
return [
|
|
{
|
|
"session_id": str(row[0]),
|
|
"flow_id": str(row[1]),
|
|
"flow_name": row[2] or "",
|
|
"last_active_at": row[3].isoformat() if row[3] else "",
|
|
}
|
|
for row in rows
|
|
]
|
|
except Exception:
|
|
return []
|
|
|
|
async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]:
|
|
"""从 PG 查询最近的对话消息。"""
|
|
try:
|
|
async with self.db_factory() as db:
|
|
if fid:
|
|
result = await db.execute(
|
|
text("""
|
|
SELECT role, content, created_at
|
|
FROM memory_messages
|
|
WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid
|
|
ORDER BY created_at DESC
|
|
LIMIT :limit
|
|
"""),
|
|
{"uid": uid, "fid": fid, "sid": sid, "limit": limit},
|
|
)
|
|
else:
|
|
result = await db.execute(
|
|
text("""
|
|
SELECT role, content, created_at
|
|
FROM memory_messages
|
|
WHERE user_id = :uid AND session_id = :sid
|
|
ORDER BY created_at DESC
|
|
LIMIT :limit
|
|
"""),
|
|
{"uid": uid, "sid": sid, "limit": limit},
|
|
)
|
|
rows = result.fetchall()
|
|
return [
|
|
{"role": row[0], "content": row[1], "ts": row[2].isoformat() if row[2] else ""}
|
|
for row in reversed(rows)
|
|
]
|
|
except Exception:
|
|
return []
|
|
|
|
async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
|
|
"""从 Redis 读取缓存的消息列表。"""
|
|
try:
|
|
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
|
|
raw = await self.redis.get(cache_key)
|
|
if raw:
|
|
return json.loads(raw)
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
|
|
"""将消息列表写入 Redis 缓存。"""
|
|
try:
|
|
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
|
|
top = messages[-self.REDIS_CACHE_SIZE:]
|
|
await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(top, ensure_ascii=False))
|
|
except Exception:
|
|
pass
|
|
|
|
async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
|
|
"""追加新消息到 Redis 缓存。"""
|
|
try:
|
|
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
|
|
existing = await self._redis_get_recent(uid, sid) or []
|
|
combined = existing + new_msgs
|
|
recent = combined[-self.REDIS_CACHE_SIZE:]
|
|
await self.redis.setex(cache_key, self.REDIS_CACHE_TTL, json.dumps(recent, ensure_ascii=False))
|
|
except Exception:
|
|
pass
|
|
|
|
async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
|
|
"""从 Redis 读取对话摘要。"""
|
|
try:
|
|
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
|
|
val = await self.redis.get(summary_key)
|
|
return val or ""
|
|
except Exception:
|
|
return ""
|
|
|
|
async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
|
|
"""L1 条件触发:对话摘要生成。
|
|
|
|
当消息数 >= 30 且尚无摘要时,调用 LLM 生成摘要并缓存到 Redis。
|
|
"""
|
|
try:
|
|
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
|
|
existing = await self.redis.get(summary_key)
|
|
if existing:
|
|
return
|
|
|
|
count = 0
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("SELECT COUNT(*) FROM memory_messages WHERE user_id = :uid AND session_id = :sid"),
|
|
{"uid": uid, "sid": sid},
|
|
)
|
|
row = result.fetchone()
|
|
count = row[0] if row else 0
|
|
|
|
if count < 30:
|
|
return
|
|
|
|
recent = await self._pg_get_recent(uid, fid, sid, 20)
|
|
dialogue = "\n".join(
|
|
f"{m['role']}: {m['content'][:500]}" for m in recent[-10:]
|
|
)
|
|
|
|
import httpx
|
|
llm = await self._get_llm_config()
|
|
async with httpx.AsyncClient(timeout=30) as client:
|
|
resp = await client.post(
|
|
f"{llm['api_base']}/chat/completions",
|
|
json={
|
|
"model": llm["model"],
|
|
"messages": [{
|
|
"role": "user",
|
|
"content": f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}"
|
|
}],
|
|
"max_tokens": 200,
|
|
},
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"},
|
|
)
|
|
data = resp.json()
|
|
summary = data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
|
|
|
if summary:
|
|
await self.redis.setex(summary_key, 2592000, summary)
|
|
except Exception:
|
|
pass
|
|
|
|
async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
|
|
"""从 PG 查询与用户/流程相关的高优先级原子记忆。"""
|
|
try:
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("""
|
|
SELECT id, atom_type, content, priority, metadata
|
|
FROM memory_atoms
|
|
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
|
|
ORDER BY priority DESC, updated_at DESC
|
|
LIMIT 20
|
|
"""),
|
|
{"uid": uid, "fid": fid},
|
|
)
|
|
rows = result.fetchall()
|
|
return [
|
|
{"id": str(row[0]), "type": row[1], "content": row[2], "priority": row[3]}
|
|
for row in rows
|
|
]
|
|
except Exception:
|
|
return []
|
|
|
|
async def _get_persona(self, uid: uuid.UUID) -> dict:
|
|
"""从 PG 查询用户画像。"""
|
|
try:
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("SELECT content, raw_text, version FROM memory_personas WHERE user_id = :uid"),
|
|
{"uid": uid},
|
|
)
|
|
row = result.fetchone()
|
|
if row:
|
|
return {"content": row[0] or {}, "raw_text": row[1] or "", "version": row[2] or 1}
|
|
except Exception:
|
|
pass
|
|
return {}
|
|
|
|
async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
|
|
"""L1 条件触发:原子记忆提取。
|
|
|
|
当消息数达到 ATOM_EXTRACT_EVERY 的整数倍时,调用 LLM 从对话中提取信息原子。
|
|
"""
|
|
try:
|
|
task_key = f"extract_atoms:{uid}:{fid}"
|
|
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
|
|
return
|
|
|
|
count = 0
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("SELECT message_count FROM memory_sessions WHERE user_id = :uid AND flow_id = :fid AND session_id = :sid"),
|
|
{"uid": uid, "fid": fid, "sid": sid},
|
|
)
|
|
row = result.fetchone()
|
|
count = row[0] if row else 0
|
|
|
|
last_extract = await db.execute(
|
|
text("""
|
|
SELECT COUNT(*) FROM memory_atoms
|
|
WHERE user_id = :uid AND flow_id = :fid AND metadata->>'source_session_id' = :sid_str
|
|
"""),
|
|
{"uid": uid, "fid": fid, "sid_str": str(sid)},
|
|
)
|
|
last_count = last_extract.fetchone()[0]
|
|
if count < self.ATOM_EXTRACT_EVERY or count <= last_count * self.ATOM_EXTRACT_EVERY:
|
|
return
|
|
|
|
recent = await self._pg_get_recent(uid, fid, sid, 20)
|
|
dialogue = "\n".join(
|
|
f"{m['role']}: {m['content'][:400]}" for m in recent[-15:]
|
|
)
|
|
|
|
task = asyncio.create_task(self._do_extract_atoms(uid, fid, sid, dialogue))
|
|
self._extract_tasks[task_key] = task
|
|
except Exception:
|
|
pass
|
|
|
|
async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
|
|
"""执行 L1 原子记忆提取:调用 LLM 从对话中提取结构化记忆原子。"""
|
|
try:
|
|
prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。
|
|
|
|
对话内容:
|
|
{dialogue}
|
|
|
|
请以JSON数组格式返回,每个原子包含以下字段:
|
|
- atom_type: "persona"(用户个人信息/偏好) / "episodic"(事件/任务) / "instruction"(用户给的指令/要求)
|
|
- content: 原子的具体内容(一句话描述)
|
|
- priority: 0-100的整数,表示重要性(越高越核心)
|
|
|
|
格式示例:
|
|
[
|
|
{{"atom_type": "persona", "content": "用户名为张三,是产品经理", "priority": 80}},
|
|
{{"atom_type": "episodic", "content": "用户要求本周五前完成UI评审", "priority": 70}}
|
|
]
|
|
|
|
只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。"""
|
|
|
|
import httpx
|
|
llm = await self._get_llm_config()
|
|
async with httpx.AsyncClient(timeout=60) as client:
|
|
resp = await client.post(
|
|
f"{llm['api_base']}/chat/completions",
|
|
json={
|
|
"model": llm["model"],
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 800,
|
|
"temperature": 0.3,
|
|
},
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"},
|
|
)
|
|
data = resp.json()
|
|
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]")
|
|
|
|
try:
|
|
atoms = json.loads(result_text)
|
|
if not isinstance(atoms, list):
|
|
return
|
|
except (json.JSONDecodeError, ValueError):
|
|
return
|
|
|
|
await self._dedup_and_store_atoms(uid, fid, sid, atoms)
|
|
except Exception as e:
|
|
logger.warning(f"L1原子提取失败: {e}")
|
|
|
|
async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
|
|
"""对提取的原子进行去重并存储到 PG。
|
|
|
|
使用文本相似度判断是否与已有原子重复:
|
|
- 相似度 > 75% 时更新优先级和元数据
|
|
- 否则插入新原子记录
|
|
"""
|
|
try:
|
|
async with self.db_factory() as db:
|
|
existing = await db.execute(
|
|
text("""
|
|
SELECT id, content FROM memory_atoms
|
|
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
|
|
"""),
|
|
{"uid": uid, "fid": fid},
|
|
)
|
|
existing_atoms = [{"id": row[0], "content": row[1]} for row in existing.fetchall()]
|
|
|
|
for atom in atoms:
|
|
content = atom.get("content", "").strip()
|
|
if not content:
|
|
continue
|
|
|
|
is_dup = False
|
|
for ex in existing_atoms:
|
|
if self._text_similarity(content, ex["content"]) > 0.75:
|
|
existing_atoms.remove(ex)
|
|
await db.execute(
|
|
text("""
|
|
UPDATE memory_atoms
|
|
SET priority = GREATER(priority, :priority),
|
|
updated_at = NOW(),
|
|
metadata = metadata || :meta
|
|
WHERE id = :id
|
|
"""),
|
|
{
|
|
"id": ex["id"],
|
|
"priority": atom.get("priority", 50),
|
|
"meta": json.dumps({"last_source_session_id": str(sid)}),
|
|
},
|
|
)
|
|
is_dup = True
|
|
break
|
|
|
|
if not is_dup:
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO memory_atoms (user_id, flow_id, atom_type, content, priority, source_session_id, metadata, created_at, updated_at)
|
|
VALUES (:uid, :fid, :atype, :content, :priority, :sid, :meta, NOW(), NOW())
|
|
"""),
|
|
{
|
|
"uid": uid,
|
|
"fid": fid,
|
|
"atype": atom.get("atom_type", "episodic"),
|
|
"content": content,
|
|
"priority": atom.get("priority", 50),
|
|
"sid": sid,
|
|
"meta": json.dumps({"source_session_id": str(sid)}),
|
|
},
|
|
)
|
|
|
|
await db.commit()
|
|
except Exception as e:
|
|
logger.warning(f"L1原子存储失败: {e}")
|
|
|
|
@staticmethod
|
|
def _text_similarity(a: str, b: str) -> float:
|
|
"""计算两段文本的 Jaccard 相似度(基于单词集合)。"""
|
|
a_words = set(a.lower().split())
|
|
b_words = set(b.lower().split())
|
|
if not a_words or not b_words:
|
|
return 0.0
|
|
intersection = a_words & b_words
|
|
union = a_words | b_words
|
|
return len(intersection) / len(union)
|
|
|
|
async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
|
|
"""L2 条件触发:场景提取。
|
|
|
|
当原子数达到 SCENE_EXTRACT_EVERY 且距上次提取超过 12 小时时,
|
|
调用 LLM 对已有原子进行场景归纳。
|
|
"""
|
|
try:
|
|
task_key = f"extract_scenes:{uid}:{fid}"
|
|
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
|
|
return
|
|
|
|
count = 0
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("SELECT COUNT(*) FROM memory_atoms WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
|
|
{"uid": uid, "fid": fid},
|
|
)
|
|
row = result.fetchone()
|
|
count = row[0] if row else 0
|
|
|
|
scene_count_result = await db.execute(
|
|
text("SELECT COUNT(*) FROM memory_scenes WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)"),
|
|
{"uid": uid, "fid": fid},
|
|
)
|
|
scene_count = scene_count_result.fetchone()[0]
|
|
|
|
if count < self.SCENE_EXTRACT_EVERY:
|
|
return
|
|
|
|
if scene_count > 0:
|
|
atoms_result = await db.execute(
|
|
text("""
|
|
SELECT updated_at FROM memory_scenes
|
|
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
|
|
ORDER BY updated_at DESC LIMIT 1
|
|
"""),
|
|
{"uid": uid, "fid": fid},
|
|
)
|
|
latest_scene = atoms_result.fetchone()
|
|
if latest_scene:
|
|
from datetime import timezone, timedelta
|
|
ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone.utc)
|
|
if ago < timedelta(hours=12):
|
|
return
|
|
|
|
atoms_result = await db.execute(
|
|
text("""
|
|
SELECT atom_type, content, priority FROM memory_atoms
|
|
WHERE user_id = :uid AND (flow_id = :fid OR flow_id IS NULL)
|
|
ORDER BY priority DESC LIMIT 30
|
|
"""),
|
|
{"uid": uid, "fid": fid},
|
|
)
|
|
atoms = [{"type": r[0], "content": r[1], "priority": r[2]} for r in atoms_result.fetchall()]
|
|
|
|
if len(atoms) < 10:
|
|
return
|
|
|
|
task = asyncio.create_task(self._do_extract_scenes(uid, fid, atoms))
|
|
self._extract_tasks[task_key] = task
|
|
except Exception:
|
|
pass
|
|
|
|
async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
|
|
"""执行 L2 场景提取:调用 LLM 将原子记忆归纳为场景块。"""
|
|
try:
|
|
atoms_text = "\n".join(
|
|
f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
|
|
)
|
|
|
|
prompt = f"""以下是从用户对话中提取的结构化记忆原子。请将它们归纳为1-3个场景块,每个场景块描述一个主题领域。
|
|
|
|
记忆原子:
|
|
{atoms_text}
|
|
|
|
请以JSON数组格式返回,每个场景块包含:
|
|
- scene_name: 场景名称(简短标签)
|
|
- summary: 场景摘要(2-3句话总结该场景的关键信息)
|
|
|
|
格式示例:
|
|
[
|
|
{{"scene_name": "项目管理", "summary": "用户负责产品评审和UI改版项目,截止日期为本周五..."}}
|
|
]
|
|
|
|
只返回JSON数组,不要其他内容。"""
|
|
|
|
import httpx
|
|
llm = await self._get_llm_config()
|
|
async with httpx.AsyncClient(timeout=60) as client:
|
|
resp = await client.post(
|
|
f"{llm['api_base']}/chat/completions",
|
|
json={
|
|
"model": llm["model"],
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 500,
|
|
"temperature": 0.3,
|
|
},
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"},
|
|
)
|
|
data = resp.json()
|
|
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]")
|
|
|
|
try:
|
|
scenes = json.loads(result_text)
|
|
if not isinstance(scenes, list):
|
|
return
|
|
except (json.JSONDecodeError, ValueError):
|
|
return
|
|
|
|
async with self.db_factory() as db:
|
|
for scene in scenes:
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO memory_scenes (user_id, flow_id, scene_name, summary, heat, content, created_at, updated_at)
|
|
VALUES (:uid, :fid, :name, :summary, 1, :content, NOW(), NOW())
|
|
"""),
|
|
{
|
|
"uid": uid,
|
|
"fid": fid,
|
|
"name": scene.get("scene_name", ""),
|
|
"summary": scene.get("summary", ""),
|
|
"content": json.dumps(scene, ensure_ascii=False),
|
|
},
|
|
)
|
|
await db.commit()
|
|
except Exception as e:
|
|
logger.warning(f"L2场景提取失败: {e}")
|
|
|
|
async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
|
|
"""L3 条件触发:用户画像更新。
|
|
|
|
当消息数达到 PERSONA_UPDATE_EVERY 且距上次更新超过 6 小时时,
|
|
基于已有 persona 类型原子重新生成画像。
|
|
"""
|
|
try:
|
|
task_key = f"update_persona:{uid}"
|
|
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
|
|
return
|
|
|
|
msg_count = 0
|
|
async with self.db_factory() as db:
|
|
result = await db.execute(
|
|
text("SELECT message_count FROM memory_sessions WHERE user_id = :uid ORDER BY last_active_at DESC LIMIT 1"),
|
|
{"uid": uid},
|
|
)
|
|
row = result.fetchone()
|
|
msg_count = row[0] if row else 0
|
|
|
|
persona_result = await db.execute(
|
|
text("SELECT version, updated_at FROM memory_personas WHERE user_id = :uid"),
|
|
{"uid": uid},
|
|
)
|
|
persona_row = persona_result.fetchone()
|
|
if persona_row:
|
|
from datetime import timezone, timedelta
|
|
ago = datetime.now(timezone.utc) - persona_row[1].replace(tzinfo=timezone.utc)
|
|
if ago < timedelta(hours=6):
|
|
return
|
|
|
|
if msg_count < self.PERSONA_UPDATE_EVERY:
|
|
return
|
|
|
|
persona_atoms = []
|
|
async with self.db_factory() as db:
|
|
atoms_result = await db.execute(
|
|
text("""
|
|
SELECT content FROM memory_atoms
|
|
WHERE user_id = :uid AND atom_type = 'persona'
|
|
ORDER BY priority DESC, updated_at DESC LIMIT 30
|
|
"""),
|
|
{"uid": uid},
|
|
)
|
|
persona_atoms = [r[0] for r in atoms_result.fetchall()]
|
|
|
|
persona_text = "\n".join(persona_atoms[:15]) if persona_atoms else "暂无用户信息"
|
|
|
|
task = asyncio.create_task(self._do_update_persona(uid, persona_text, persona_row.version + 1 if persona_row else 1))
|
|
self._extract_tasks[task_key] = task
|
|
except Exception:
|
|
pass
|
|
|
|
async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
|
|
"""执行 L3 画像更新:调用 LLM 生成结构化用户画像并持久化。"""
|
|
try:
|
|
prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。
|
|
|
|
用户信息:
|
|
{persona_text}
|
|
|
|
请返回一个JSON对象,包含以下字段:
|
|
- name: 用户姓名(未知则"未知")
|
|
- role: 角色/职位
|
|
- preferences: 偏好/习惯列表
|
|
- skills: 技能/专长列表
|
|
- personality: 性格描述
|
|
- raw_text: 综合画像的文本描述(2-3句话)
|
|
|
|
只返回JSON对象,不要其他内容。"""
|
|
|
|
import httpx
|
|
llm = await self._get_llm_config()
|
|
async with httpx.AsyncClient(timeout=60) as client:
|
|
resp = await client.post(
|
|
f"{llm['api_base']}/chat/completions",
|
|
json={
|
|
"model": llm["model"],
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 500,
|
|
"temperature": 0.3,
|
|
},
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"},
|
|
)
|
|
data = resp.json()
|
|
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "{}")
|
|
|
|
try:
|
|
persona = json.loads(result_text)
|
|
except (json.JSONDecodeError, ValueError):
|
|
persona = {"raw_text": result_text}
|
|
|
|
async with self.db_factory() as db:
|
|
existing = await db.execute(
|
|
text("SELECT id FROM memory_personas WHERE user_id = :uid"),
|
|
{"uid": uid},
|
|
)
|
|
if existing.fetchone():
|
|
await db.execute(
|
|
text("""
|
|
UPDATE memory_personas
|
|
SET content = :content, raw_text = :raw_text, version = :version, updated_at = NOW()
|
|
WHERE user_id = :uid
|
|
"""),
|
|
{
|
|
"uid": uid,
|
|
"content": json.dumps(persona, ensure_ascii=False),
|
|
"raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)),
|
|
"version": version,
|
|
},
|
|
)
|
|
else:
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO memory_personas (user_id, content, raw_text, version, updated_at)
|
|
VALUES (:uid, :content, :raw_text, :version, NOW())
|
|
"""),
|
|
{
|
|
"uid": uid,
|
|
"content": json.dumps(persona, ensure_ascii=False),
|
|
"raw_text": persona.get("raw_text", json.dumps(persona, ensure_ascii=False)),
|
|
"version": version,
|
|
},
|
|
)
|
|
await db.commit()
|
|
except Exception as e:
|
|
logger.warning(f"L3画像更新失败: {e}")
|
|
|
|
async def search_memory(
|
|
self,
|
|
uid: uuid.UUID,
|
|
query: str,
|
|
fid: uuid.UUID = None,
|
|
top_k: int = 10,
|
|
embedding: list[float] = None,
|
|
) -> list[dict]:
|
|
"""混合检索记忆:向量相似度 + 全文检索,使用 RRF 算法融合排序。
|
|
|
|
Args:
|
|
uid: 用户 ID
|
|
query: 搜索查询文本
|
|
fid: 流程 ID(可选,过滤范围)
|
|
top_k: 返回结果数
|
|
embedding: 查询向量(可选,启用向量检索)
|
|
|
|
Returns:
|
|
按 RRF 分数降序排列的记忆原子列表
|
|
"""
|
|
results = []
|
|
try:
|
|
async with self.db_factory() as db:
|
|
flow_filter = "AND (flow_id = :fid OR flow_id IS NULL)" if fid else ""
|
|
if embedding and len(embedding) > 0:
|
|
emb_str = "[" + ",".join(str(v) for v in embedding) + "]"
|
|
vector_results = await db.execute(
|
|
text(f"""
|
|
SELECT id, atom_type, content, priority,
|
|
(1.0 - (embedding <=> :emb::vector)) AS similarity
|
|
FROM memory_atoms
|
|
WHERE user_id = :uid
|
|
{flow_filter}
|
|
AND embedding IS NOT NULL
|
|
ORDER BY embedding <=> :emb::vector
|
|
LIMIT :limit
|
|
"""),
|
|
{"uid": uid, "fid": fid, "emb": emb_str, "limit": top_k * 2} if fid
|
|
else {"uid": uid, "emb": emb_str, "limit": top_k * 2},
|
|
)
|
|
for row in vector_results.fetchall():
|
|
results.append({
|
|
"id": str(row[0]), "type": row[1], "content": row[2],
|
|
"priority": row[3], "vector_score": float(row[4]),
|
|
})
|
|
|
|
text_results = await db.execute(
|
|
text(f"""
|
|
SELECT id, atom_type, content, priority,
|
|
ts_rank(content_tsv, plainto_tsquery('simple', :query)) AS text_score
|
|
FROM memory_atoms
|
|
WHERE user_id = :uid
|
|
{flow_filter}
|
|
AND content_tsv @@ plainto_tsquery('simple', :query)
|
|
ORDER BY text_score DESC
|
|
LIMIT :limit
|
|
"""),
|
|
{"uid": uid, "fid": fid, "query": query, "limit": top_k * 2} if fid
|
|
else {"uid": uid, "query": query, "limit": top_k * 2},
|
|
)
|
|
for row in text_results.fetchall():
|
|
results.append({
|
|
"id": str(row[0]), "type": row[1], "content": row[2],
|
|
"priority": row[3], "text_score": float(row[4]),
|
|
})
|
|
|
|
rrf_scores: dict[str, float] = {}
|
|
k = 60
|
|
for rank, r in enumerate(sorted(results, key=lambda x: x.get("vector_score", 0), reverse=True)):
|
|
if "vector_score" in r:
|
|
rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1)
|
|
for rank, r in enumerate(sorted(results, key=lambda x: x.get("text_score", 0), reverse=True)):
|
|
if "text_score" in r:
|
|
rrf_scores[r["id"]] = rrf_scores.get(r["id"], 0) + 1.0 / (k + rank + 1)
|
|
|
|
seen = set()
|
|
merged = []
|
|
for rid, score in sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True):
|
|
for r in results:
|
|
if r["id"] == rid and rid not in seen:
|
|
seen.add(rid)
|
|
r["rrf_score"] = score
|
|
merged.append(r)
|
|
break
|
|
|
|
merged.sort(key=lambda x: x.get("rrf_score", 0), reverse=True)
|
|
return merged[:top_k]
|
|
except Exception as e:
|
|
logger.warning(f"混合检索失败: {e}")
|
|
return results[:top_k] if results else []
|