import json
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional

from redis.asyncio import Redis

from config import settings

logger = logging.getLogger(__name__)

_memory_manager: "MemoryManager | None" = None

# Characters with special meaning in Redis glob patterns (KEYS/SCAN MATCH).
_GLOB_SPECIALS = frozenset("\\*?[]")


def get_memory_manager() -> "MemoryManager":
    """Return the process-wide MemoryManager.

    Raises:
        RuntimeError: if init_memory_manager() has not been awaited yet.
    """
    if _memory_manager is None:
        raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
    return _memory_manager


async def init_memory_manager():
    """Create the Redis client from settings.REDIS_URL, PING it, and install
    the process-wide MemoryManager returned by get_memory_manager().

    Any connection error raised by PING propagates to the caller.
    """
    global _memory_manager
    redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    await redis.ping()
    _memory_manager = MemoryManager(redis)


def _escape_glob(value: str) -> str:
    """Backslash-escape Redis glob metacharacters so *value* matches literally.

    Without this, an id such as "*" would widen a SCAN pattern to match
    other sessions' keys.
    """
    return "".join("\\" + ch if ch in _GLOB_SPECIALS else ch for ch in value)


class MemoryManager:
    """Conversation memory stored in Redis, keyed by user, flow and session.

    Key layout (``{p}`` is KEY_PREFIX):

    * ``{p}:{user_id}:{flow_id}:{session_id}:messages`` - list of JSON
      messages, newest first, capped at MAX_HISTORY entries.
    * ``{p}:{user_id}:{flow_id}:{session_id}:meta`` - hash with
      ``flow_name`` and ``last_active_at``.
    * ``{p}:{user_id}:{flow_id}:{session_id}:summary`` - LLM-written summary.
    * ``{p}:{user_id}:sessions`` - set of the user's session ids.
    """

    KEY_PREFIX = "mem"
    DEFAULT_TTL = 604800  # 7 days; messages and meta
    SESSION_INDEX_TTL = 2592000  # 30 days; per-user session index set
    SUMMARY_TTL = 2592000  # 30 days; generated summary
    MAX_HISTORY = 40  # messages retained per session (20 user/assistant pairs)
    SUMMARY_TRIGGER_COUNT = 30  # summarize once the list holds this many messages
    SUMMARY_SOURCE_MESSAGES = 10  # most recent messages fed to the summary prompt
    SUMMARY_SNIPPET_CHARS = 500  # per-message truncation inside the prompt

    def __init__(self, redis: Redis):
        self.redis = redis
        # The event loop only keeps weak references to tasks; hold strong ones
        # so fire-and-forget work is not garbage-collected mid-execution.
        self._background_tasks: set = set()

    async def inject_memory(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        context: dict,
    ):
        """Store recent history and the session summary in ``context``.

        Mutates *context* in place, setting ``context["_memory_context"]`` to a
        dict with ``recent_messages`` (oldest first), ``summary`` ("" if none)
        and ``session_id``. Read failures degrade to empty values.
        """
        # Both reads are independent and never raise, so run them concurrently.
        messages, summary = await asyncio.gather(
            self._get_recent_messages(user_id, flow_id, session_id),
            self._get_summary(user_id, flow_id, session_id),
        )
        context["_memory_context"] = {
            "recent_messages": list(reversed(messages)),
            "summary": summary,
            "session_id": session_id,
        }

    async def record_exchange(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        user_msg: str,
        assistant_msg: str,
        flow_name: str = "",
    ):
        """Append one user/assistant exchange and refresh session metadata.

        Both messages share one timestamp (naive UTC, ISO-8601). The message
        list is trimmed to MAX_HISTORY and all touched keys get their TTLs
        refreshed. Redis errors are logged, not raised. Afterwards a
        background task may generate a summary (see _maybe_summarize).
        """
        key = self._msg_key(user_id, flow_id, session_id)
        meta_key = self._meta_key(user_id, flow_id, session_id)
        sessions_key = self._sessions_key(user_id)
        # Same format as the deprecated datetime.utcnow().isoformat().
        ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        user_entry = json.dumps(
            {"role": "user", "content": user_msg, "ts": ts}, ensure_ascii=False
        )
        assistant_entry = json.dumps(
            {"role": "assistant", "content": assistant_msg, "ts": ts},
            ensure_ascii=False,
        )
        try:
            async with self.redis.pipeline() as pipe:
                # LPUSH inserts its arguments one at a time at the head, so the
                # LAST argument ends up first. Pushing user then assistant keeps
                # the list newest-first, and reversing it yields user -> assistant.
                pipe.lpush(key, user_entry, assistant_entry)
                pipe.ltrim(key, 0, self.MAX_HISTORY - 1)
                pipe.expire(key, self.DEFAULT_TTL)
                pipe.hset(meta_key, mapping={
                    "flow_name": flow_name,
                    "last_active_at": ts,
                })
                pipe.expire(meta_key, self.DEFAULT_TTL)
                pipe.sadd(sessions_key, session_id)
                pipe.expire(sessions_key, self.SESSION_INDEX_TTL)
                await pipe.execute()
        except Exception as e:
            logger.warning("记录记忆失败: %s", e)

        self._spawn_background(self._maybe_summarize(user_id, flow_id, session_id))

    async def get_conversation_history(
        self, user_id: str, flow_id: str, session_id: str, limit: int = 20
    ) -> list[dict]:
        """Return up to *limit* most recent messages, oldest first."""
        messages = await self._get_recent_messages(user_id, flow_id, session_id, limit)
        return list(reversed(messages))

    async def delete_session(self, user_id: str, session_id: str):
        """Delete every key of *session_id* (in any flow) for *user_id* and
        remove it from the user's session index. Failures are logged."""
        pattern = (
            f"{self.KEY_PREFIX}:{_escape_glob(user_id)}:*:"
            f"{_escape_glob(session_id)}:*"
        )
        try:
            # SCAN is incremental; KEYS would block Redis for the whole keyspace.
            keys = [k async for k in self.redis.scan_iter(match=pattern, count=500)]
            async with self.redis.pipeline() as pipe:
                if keys:
                    pipe.delete(*keys)
                pipe.srem(self._sessions_key(user_id), session_id)
                await pipe.execute()
        except Exception as e:
            logger.warning("清除记忆失败: %s", e)

    async def list_user_sessions(self, user_id: str) -> list[dict]:
        """List the user's sessions, most recently active first.

        One entry is returned per (session, flow) whose meta hash still exists,
        with keys ``session_id``, ``flow_id``, ``flow_name`` and
        ``last_active_at``. Redis errors yield [] or skip the affected session.
        """
        try:
            session_ids = await self.redis.smembers(self._sessions_key(user_id))
        except Exception as e:
            logger.warning("读取会话列表失败: %s", e)
            return []

        user_prefix = f"{self.KEY_PREFIX}:{user_id}:"
        sessions = []
        for sid in session_ids:
            meta_suffix = f":{sid}:meta"
            pattern = (
                f"{self.KEY_PREFIX}:{_escape_glob(user_id)}:*"
                f"{_escape_glob(meta_suffix)}"
            )
            try:
                async for meta_key in self.redis.scan_iter(match=pattern, count=500):
                    meta = await self.redis.hgetall(meta_key)
                    # Strip the known prefix/suffix so flow ids containing ":"
                    # are returned intact.
                    if meta_key.startswith(user_prefix) and meta_key.endswith(meta_suffix):
                        flow_id = meta_key[len(user_prefix):-len(meta_suffix)]
                    else:
                        flow_id = ""
                    sessions.append({
                        "session_id": sid,
                        "flow_id": flow_id,
                        "flow_name": meta.get("flow_name", ""),
                        "last_active_at": meta.get("last_active_at", ""),
                    })
            except Exception:
                continue
        return sorted(sessions, key=lambda s: s["last_active_at"], reverse=True)

    async def _get_recent_messages(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        limit: Optional[int] = None,
    ) -> list[dict]:
        """Return up to *limit* decoded messages, newest first.

        ``limit=None`` means MAX_HISTORY; ``limit <= 0`` returns []. Entries
        that are not valid JSON are skipped; Redis errors yield [].
        """
        if limit is None:
            limit = self.MAX_HISTORY
        if limit <= 0:
            return []
        try:
            raw = await self.redis.lrange(
                self._msg_key(user_id, flow_id, session_id), 0, limit - 1
            )
        except Exception as e:
            logger.warning("读取记忆失败: %s", e)
            return []
        messages = []
        for item in raw:
            try:
                messages.append(json.loads(item))
            except json.JSONDecodeError:
                continue
        return messages

    async def _get_summary(self, user_id: str, flow_id: str, session_id: str) -> str:
        """Return the stored session summary, or "" if absent or on error."""
        try:
            value = await self.redis.get(self._summary_key(user_id, flow_id, session_id))
        except Exception:
            return ""
        return value or ""

    async def _maybe_summarize(self, user_id: str, flow_id: str, session_id: str):
        """Generate and store a session summary via the LLM, if due.

        Runs only when the message list holds at least SUMMARY_TRIGGER_COUNT
        entries and no summary exists yet. Intended to run as a background
        task, so every failure is logged instead of raised.

        NOTE(review): once stored, the summary is not regenerated until it
        expires (SUMMARY_TTL), even as the conversation grows — confirm intent.
        """
        try:
            count = await self.redis.llen(self._msg_key(user_id, flow_id, session_id))
            if count < self.SUMMARY_TRIGGER_COUNT:
                return
            summary_key = self._summary_key(user_id, flow_id, session_id)
            if await self.redis.get(summary_key):
                return

            recent = await self._get_recent_messages(
                user_id, flow_id, session_id, self.SUMMARY_SOURCE_MESSAGES
            )
            if not recent:
                return
            # recent is newest-first; the prompt wants chronological order.
            dialogue = "\n".join(
                f"{m['role']}: {m['content'][:self.SUMMARY_SNIPPET_CHARS]}"
                for m in reversed(recent)
            )

            import httpx  # third-party; imported lazily as in the rest of this module

            api_base = settings.LLM_API_BASE.rstrip("/")
            async with httpx.AsyncClient(timeout=30) as client:
                resp = await client.post(
                    f"{api_base}/chat/completions",
                    json={
                        "model": settings.LLM_MODEL,
                        "messages": [{
                            "role": "user",
                            "content": f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}"
                        }],
                        "max_tokens": 200,
                    },
                    headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                )
                # Surface 4xx/5xx instead of silently parsing an error body.
                resp.raise_for_status()
                data = resp.json()

            choices = data.get("choices") or [{}]
            summary = (choices[0].get("message") or {}).get("content") or ""
            if summary:
                await self.redis.setex(summary_key, self.SUMMARY_TTL, summary)
        except Exception:
            logger.warning("生成对话摘要失败", exc_info=True)

    def _spawn_background(self, coro) -> "asyncio.Task":
        """Schedule *coro* as a task and keep a strong reference until it ends."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    @classmethod
    def _session_key(cls, user_id: str, flow_id: str, session_id: str, kind: str) -> str:
        """Build ``{prefix}:{user_id}:{flow_id}:{session_id}:{kind}``."""
        return f"{cls.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:{kind}"

    @classmethod
    def _msg_key(cls, user_id: str, flow_id: str, session_id: str) -> str:
        """Key of the newest-first message list."""
        return cls._session_key(user_id, flow_id, session_id, "messages")

    @classmethod
    def _meta_key(cls, user_id: str, flow_id: str, session_id: str) -> str:
        """Key of the session metadata hash."""
        return cls._session_key(user_id, flow_id, session_id, "meta")

    @classmethod
    def _summary_key(cls, user_id: str, flow_id: str, session_id: str) -> str:
        """Key of the LLM-generated session summary."""
        return cls._session_key(user_id, flow_id, session_id, "summary")

    @classmethod
    def _sessions_key(cls, user_id: str) -> str:
        """Key of the per-user set of session ids."""
        return f"{cls.KEY_PREFIX}:{user_id}:sessions"