You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

201 lines
7.2 KiB

import asyncio
import json
import logging
from datetime import datetime, timezone

from redis.asyncio import Redis

from config import settings
# Module-level logger; MemoryManager reports swallowed Redis errors through it.
logger = logging.getLogger(__name__)
# Process-wide MemoryManager instance. None until init_memory_manager() has
# been awaited; read through get_memory_manager().
_memory_manager: "MemoryManager | None" = None
def get_memory_manager() -> "MemoryManager":
    """Return the process-wide MemoryManager singleton.

    Raises:
        RuntimeError: if init_memory_manager() has not been awaited yet.
    """
    manager = _memory_manager
    if manager is not None:
        return manager
    raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
async def init_memory_manager():
    """Connect to Redis and install the process-wide MemoryManager.

    Must be awaited before get_memory_manager() is used. Any error raised by
    the initial PING propagates to the caller and no manager is installed.
    """
    global _memory_manager
    client = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    # Verify connectivity before publishing the manager, so a bad URL fails here.
    await client.ping()
    _memory_manager = MemoryManager(client)
class MemoryManager:
    """Redis-backed conversation memory scoped by user, flow and session.

    Key layout (``KEY_PREFIX`` defaults to ``"mem"``):

    * ``mem:{user_id}:{flow_id}:{session_id}:messages`` - list of JSON entries,
      newest at the head, trimmed to ``MAX_HISTORY`` entries.
    * ``mem:{user_id}:{flow_id}:{session_id}:meta`` - hash with ``flow_name``
      and ``last_active_at``.
    * ``mem:{user_id}:{flow_id}:{session_id}:summary`` - LLM-generated summary.
    * ``mem:{user_id}:sessions`` - set of the user's session ids.

    Expects a client created with ``decode_responses=True`` (as
    init_memory_manager does), so keys and values come back as ``str``.
    Public methods are best-effort: Redis/LLM failures are logged and degrade
    to empty results instead of propagating to the caller.
    """

    KEY_PREFIX = "mem"
    DEFAULT_TTL = 604800  # 7 days; messages list and meta hash
    SESSION_INDEX_TTL = 2592000  # 30 days; per-user session-id set
    SUMMARY_TTL = 2592000  # 30 days; generated summaries
    MAX_HISTORY = 40  # stored entries (2 per exchange) kept per session
    SUMMARY_THRESHOLD = 30  # list length at which a summary is generated
    SUMMARY_SOURCE_MESSAGES = 10  # newest entries fed to the summarizer
    SUMMARY_TIMEOUT = 30  # seconds; HTTP timeout of the summary request

    def __init__(self, redis: "Redis"):
        self.redis = redis
        # The event loop keeps only weak references to tasks, so
        # fire-and-forget tasks are held here until they finish.
        self._background_tasks: set[asyncio.Task] = set()

    async def inject_memory(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        context: dict,
    ):
        """Attach the session's memory to *context* under ``"_memory_context"``.

        The injected dict holds ``recent_messages`` (oldest first, at most
        ``MAX_HISTORY`` entries), ``summary`` (``""`` when none exists) and
        ``session_id``. *context* is mutated in place.
        """
        messages = await self._get_recent_messages(user_id, flow_id, session_id)
        summary = await self._get_summary(user_id, flow_id, session_id)
        context["_memory_context"] = {
            # Stored newest-first; consumers want chronological order.
            "recent_messages": list(reversed(messages)),
            "summary": summary,
            "session_id": session_id,
        }

    async def record_exchange(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        user_msg: str,
        assistant_msg: str,
        flow_name: str = "",
    ):
        """Store one user/assistant exchange and refresh session bookkeeping.

        Appends both turns, trims the list to ``MAX_HISTORY``, updates the
        meta hash and the user's session index, and refreshes their TTLs in
        one pipeline. On success, a background summarization check is
        scheduled. Failures are logged, not raised.
        """
        key = self._msg_key(user_id, flow_id, session_id)
        meta_key = self._meta_key(user_id, flow_id, session_id)
        sessions_key = self._sessions_key(user_id)
        ts = self._utc_timestamp()
        try:
            user_entry = json.dumps({"role": "user", "content": user_msg, "ts": ts}, ensure_ascii=False)
            assistant_entry = json.dumps({"role": "assistant", "content": assistant_msg, "ts": ts}, ensure_ascii=False)
            async with self.redis.pipeline() as pipe:
                # LPUSH inserts its arguments left to right, so the LAST one
                # lands at the head. Pushing (user, assistant) stores
                # [assistant, user, ...] newest-first, which reads
                # user -> assistant once reversed into chronological order.
                # NOTE: entries written by the previous (assistant, user)
                # ordering stay reversed until they expire after DEFAULT_TTL.
                pipe.lpush(key, user_entry, assistant_entry)
                pipe.ltrim(key, 0, self.MAX_HISTORY - 1)
                pipe.expire(key, self.DEFAULT_TTL)
                pipe.hset(meta_key, mapping={
                    "flow_name": flow_name,
                    "last_active_at": ts,
                })
                pipe.expire(meta_key, self.DEFAULT_TTL)
                pipe.sadd(sessions_key, session_id)
                pipe.expire(sessions_key, self.SESSION_INDEX_TTL)
                await pipe.execute()
        except Exception as e:
            logger.warning(f"记录记忆失败: {e}")
            # Nothing new was stored, so there is nothing to summarize.
            return
        self._spawn_background(self._maybe_summarize(user_id, flow_id, session_id))

    async def get_conversation_history(
        self, user_id: str, flow_id: str, session_id: str, limit: int = 20
    ) -> list[dict]:
        """Return up to *limit* most recent entries in chronological order.

        A non-positive *limit* yields an empty list.
        """
        messages = await self._get_recent_messages(user_id, flow_id, session_id, limit)
        return list(reversed(messages))

    async def delete_session(self, user_id: str, session_id: str):
        """Delete all keys of *session_id* (across every flow) for *user_id*.

        Also removes the id from the user's session index. Best-effort:
        failures are logged, not raised.
        """
        pattern = (
            f"{self.KEY_PREFIX}:{self._glob_escape(user_id)}:*:"
            f"{self._glob_escape(session_id)}:*"
        )
        try:
            # SCAN instead of KEYS: KEYS walks the whole keyspace in one
            # blocking call and stalls every other Redis client meanwhile.
            keys = [k async for k in self.redis.scan_iter(match=pattern)]
            async with self.redis.pipeline() as pipe:
                if keys:
                    pipe.delete(*keys)
                pipe.srem(self._sessions_key(user_id), session_id)
                await pipe.execute()
        except Exception as e:
            logger.warning(f"清除记忆失败: {e}")

    async def list_user_sessions(self, user_id: str) -> list[dict]:
        """List the user's sessions, most recently active first.

        Each item has ``session_id``, ``flow_id``, ``flow_name`` and
        ``last_active_at``. A session used under several flows yields one item
        per flow. Sessions whose meta hash has expired are omitted.
        """
        try:
            session_ids = await self.redis.smembers(self._sessions_key(user_id))
        except Exception as e:
            logger.warning("读取会话索引失败: %s", e)
            return []
        prefix = f"{self.KEY_PREFIX}:{user_id}:"
        user_glob = self._glob_escape(user_id)
        sessions = []
        for sid in session_ids:
            suffix = f":{sid}:meta"
            pattern = f"{self.KEY_PREFIX}:{user_glob}:*:{self._glob_escape(sid)}:meta"
            try:
                async for k in self.redis.scan_iter(match=pattern):
                    meta = await self.redis.hgetall(k)
                    if not meta:
                        # Expired (or deleted) between SCAN and HGETALL.
                        continue
                    # Slice by known prefix/suffix; split(":") breaks when
                    # user_id or session_id themselves contain ':'.
                    flow_id = k[len(prefix):-len(suffix)]
                    sessions.append({
                        "session_id": sid,
                        "flow_id": flow_id,
                        "flow_name": meta.get("flow_name", ""),
                        "last_active_at": meta.get("last_active_at", ""),
                    })
            except Exception as e:
                logger.warning("读取会话元数据失败 (session=%s): %s", sid, e)
                continue
        return sorted(sessions, key=lambda s: s.get("last_active_at", ""), reverse=True)

    async def _get_recent_messages(
        self, user_id: str, flow_id: str, session_id: str, limit: "int | None" = None
    ) -> list[dict]:
        """Return up to *limit* decoded entries, newest first.

        *limit* defaults to ``MAX_HISTORY``; a non-positive value returns
        ``[]``. Entries that are not valid JSON are skipped. Redis errors are
        logged and yield ``[]``.
        """
        if limit is None:
            limit = self.MAX_HISTORY
        if limit <= 0:
            # LRANGE 0 -1 (or any negative stop) would return most or all of
            # the list, which is the opposite of what a zero limit asks for.
            return []
        key = self._msg_key(user_id, flow_id, session_id)
        try:
            raw = await self.redis.lrange(key, 0, limit - 1)
        except Exception as e:
            logger.warning("读取历史消息失败: %s", e)
            return []
        result = []
        for item in raw:
            try:
                result.append(json.loads(item))
            except json.JSONDecodeError:
                continue  # skip a corrupt entry rather than lose the whole read
        return result

    async def _get_summary(self, user_id: str, flow_id: str, session_id: str) -> str:
        """Return the stored summary for the session, or ``""`` if none or on error."""
        try:
            val = await self.redis.get(self._summary_key(user_id, flow_id, session_id))
        except Exception as e:
            logger.warning("读取摘要失败: %s", e)
            return ""
        return val or ""

    async def _maybe_summarize(self, user_id: str, flow_id: str, session_id: str):
        """Generate and store an LLM summary once the session is long enough.

        Does nothing until the message list reaches ``SUMMARY_THRESHOLD``
        entries, or when a summary already exists (summaries are produced once
        and kept for ``SUMMARY_TTL``). Runs as a background task, so every
        failure is logged and swallowed.
        """
        try:
            count = await self.redis.llen(self._msg_key(user_id, flow_id, session_id))
            if count < self.SUMMARY_THRESHOLD:
                return
            summary_key = self._summary_key(user_id, flow_id, session_id)
            if await self.redis.get(summary_key):
                return
            recent = await self._get_recent_messages(
                user_id, flow_id, session_id, self.SUMMARY_SOURCE_MESSAGES
            )
            # Newest-first from Redis; present the dialogue chronologically.
            dialogue = "\n".join(
                f"{m['role']}: {m['content'][:500]}" for m in reversed(recent)
            )
            import httpx  # imported lazily: only needed once a session crosses the threshold
            api_base = settings.LLM_API_BASE.rstrip("/")
            async with httpx.AsyncClient(timeout=self.SUMMARY_TIMEOUT) as client:
                resp = await client.post(
                    f"{api_base}/chat/completions",
                    json={
                        "model": settings.LLM_MODEL,
                        "messages": [{
                            "role": "user",
                            "content": f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}"
                        }],
                        "max_tokens": 200,
                    },
                    headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                )
            # Don't parse an error body (4xx/5xx) as if it were a completion.
            resp.raise_for_status()
            data = resp.json()
            # `or` guards both a missing and an empty "choices" list.
            choices = data.get("choices") or [{}]
            summary = (choices[0].get("message") or {}).get("content") or ""
            if summary:
                # NX: if a concurrent task already stored a summary, keep it.
                await self.redis.set(summary_key, summary, ex=self.SUMMARY_TTL, nx=True)
        except Exception as e:
            logger.warning("生成摘要失败: %s", e)

    def _spawn_background(self, coro) -> None:
        """Schedule *coro* as a task and keep a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string."""
        # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo
        # keeps the stored format (no "+00:00" suffix) unchanged for readers
        # and for the lexicographic sort in list_user_sessions.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _glob_escape(value) -> str:
        """Escape Redis glob metacharacters so *value* matches only itself."""
        return "".join("\\" + ch if ch in "\\*?[]" else ch for ch in str(value))

    @classmethod
    def _scoped_key(cls, user_id: str, flow_id: str, session_id: str, kind: str) -> str:
        """Build a per-(user, flow, session) key ending in *kind*."""
        return f"{cls.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:{kind}"

    @classmethod
    def _msg_key(cls, user_id: str, flow_id: str, session_id: str) -> str:
        return cls._scoped_key(user_id, flow_id, session_id, "messages")

    @classmethod
    def _meta_key(cls, user_id: str, flow_id: str, session_id: str) -> str:
        return cls._scoped_key(user_id, flow_id, session_id, "meta")

    @classmethod
    def _summary_key(cls, user_id: str, flow_id: str, session_id: str) -> str:
        return cls._scoped_key(user_id, flow_id, session_id, "summary")

    @classmethod
    def _sessions_key(cls, user_id: str) -> str:
        """Key of the set holding all session ids of *user_id*."""
        return f"{cls.KEY_PREFIX}:{user_id}:sessions"