You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
201 lines
7.2 KiB
201 lines
7.2 KiB
import asyncio
import json
import logging
from datetime import datetime, timezone

from redis.asyncio import Redis

from config import settings
|
|
|
|
# Module-level logger; its name is this module's __name__.
logger = logging.getLogger(__name__)

# Process-wide MemoryManager instance. Starts as None; init_memory_manager()
# assigns it and get_memory_manager() reads it.
_memory_manager: "MemoryManager | None" = None
|
|
|
|
|
|
def get_memory_manager() -> "MemoryManager":
    """Return the module-wide MemoryManager instance.

    Returns:
        The instance currently stored in the module-level ``_memory_manager``.

    Raises:
        RuntimeError: If ``_memory_manager`` is still ``None``, i.e. nothing
            has assigned it yet.
    """
    manager = _memory_manager
    if manager is not None:
        return manager
    raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
|
|
|
|
|
|
async def init_memory_manager() -> None:
    """Create the module-wide MemoryManager backed by a Redis connection.

    Builds a client from ``settings.REDIS_URL`` (with ``decode_responses=True``
    so Redis replies arrive as ``str``), verifies connectivity with ``PING``
    and only then publishes the manager in ``_memory_manager``.

    Raises:
        Exception: Whatever ``ping()`` raises when Redis is unreachable. The
            freshly created client is closed before the error propagates so
            its connection pool is not leaked, and ``_memory_manager`` is left
            unchanged.
    """
    global _memory_manager
    redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
    try:
        await redis.ping()
    except Exception:
        # Newer redis-py releases name the coroutine ``aclose``; older ones
        # only have ``close``. Use whichever this client provides.
        close = getattr(redis, "aclose", None) or redis.close
        try:
            await close()
        except Exception as close_error:
            # Cleanup is best-effort; the ping failure is the error to report.
            logger.debug("closing Redis client after failed ping: %s", close_error)
        raise
    _memory_manager = MemoryManager(redis)
|
|
|
|
|
|
class MemoryManager:
    """Conversation memory stored in Redis, scoped by user, flow and session.

    Keys written by this class (``{p}`` is ``KEY_PREFIX``):

    * ``{p}:{user_id}:{flow_id}:{session_id}:messages`` - list of JSON-encoded
      messages (``role``, ``content``, ``ts``), newest at the head, trimmed to
      ``MAX_HISTORY`` entries.
    * ``{p}:{user_id}:{flow_id}:{session_id}:meta`` - hash holding
      ``flow_name`` and ``last_active_at``.
    * ``{p}:{user_id}:{flow_id}:{session_id}:summary`` - string with an
      LLM-written summary of the conversation.
    * ``{p}:{user_id}:sessions`` - set of the user's session ids.

    Every public method is best-effort: Redis/LLM failures are logged and
    swallowed so that memory problems never break the calling request.
    """

    KEY_PREFIX = "mem"
    DEFAULT_TTL = 604800  # 7 days, in seconds (messages and meta)
    SESSION_INDEX_TTL = 2592000  # 30 days, in seconds (per-user session set)
    MAX_HISTORY = 40  # maximum number of messages kept per session
    SUMMARY_TTL = 2592000  # 30 days, in seconds (stored summary)
    SUMMARIZE_THRESHOLD = 30  # stored-message count that triggers a summary
    SCAN_COUNT = 500  # per-iteration work hint passed to SCAN

    def __init__(self, redis: "Redis"):
        """Store the Redis client used for all operations.

        Args:
            redis: An asyncio Redis client. The code treats replies as ``str``
                (see ``list_user_sessions``), so the client is expected to be
                created with ``decode_responses=True``.
        """
        self.redis = redis
        # Strong references to in-flight summarization tasks. The event loop
        # only keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks = set()

    async def inject_memory(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        context: dict,
    ) -> None:
        """Attach the session's memory to ``context`` under ``_memory_context``.

        Args:
            user_id: Owner of the session.
            flow_id: Flow the session belongs to.
            session_id: Session whose memory is loaded.
            context: Mutated in place; receives a dict with
                ``recent_messages`` (oldest first), ``summary`` and
                ``session_id``.
        """
        messages = await self._get_recent_messages(user_id, flow_id, session_id)
        summary = await self._get_summary(user_id, flow_id, session_id)

        context["_memory_context"] = {
            # Stored newest-first; consumers get chronological order.
            "recent_messages": messages[::-1],
            "summary": summary,
            "session_id": session_id,
        }

    async def record_exchange(
        self,
        user_id: str,
        flow_id: str,
        session_id: str,
        user_msg: str,
        assistant_msg: str,
        flow_name: str = "",
    ) -> None:
        """Persist one user/assistant exchange and refresh session metadata.

        All writes go through a single pipeline. A failure is logged as a
        warning and not raised. Afterwards a background task is scheduled that
        may generate a summary (see ``_maybe_summarize``); it is scheduled
        whether or not the write succeeded.

        Args:
            user_id: Owner of the session.
            flow_id: Flow the session belongs to.
            session_id: Session receiving the exchange.
            user_msg: Text sent by the user.
            assistant_msg: Text produced by the assistant.
            flow_name: Name stored in the session's meta hash.
        """
        msg_key = self._msg_key(user_id, flow_id, session_id)
        meta_key = self._meta_key(user_id, flow_id, session_id)
        index_key = self._sessions_key(user_id)
        # Naive UTC ISO-8601 string, the same format datetime.utcnow() gave,
        # without calling the deprecated function.
        ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

        try:
            user_entry = json.dumps(
                {"role": "user", "content": user_msg, "ts": ts}, ensure_ascii=False
            )
            assistant_entry = json.dumps(
                {"role": "assistant", "content": assistant_msg, "ts": ts},
                ensure_ascii=False,
            )
            async with self.redis.pipeline() as pipe:
                # LPUSH inserts its arguments one by one at the head, so the
                # LAST argument ends up first. Pushing user then assistant
                # keeps the list strictly newest-first, which the readers rely
                # on when they reverse it into chronological order.
                pipe.lpush(msg_key, user_entry, assistant_entry)
                pipe.ltrim(msg_key, 0, self.MAX_HISTORY - 1)
                pipe.expire(msg_key, self.DEFAULT_TTL)

                pipe.hset(
                    meta_key,
                    mapping={"flow_name": flow_name, "last_active_at": ts},
                )
                pipe.expire(meta_key, self.DEFAULT_TTL)

                pipe.sadd(index_key, session_id)
                pipe.expire(index_key, self.SESSION_INDEX_TTL)

                await pipe.execute()
        except Exception as e:
            logger.warning("记录记忆失败: %s", e)

        task = asyncio.create_task(
            self._maybe_summarize(user_id, flow_id, session_id)
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def get_conversation_history(
        self, user_id: str, flow_id: str, session_id: str, limit: int = 20
    ) -> list[dict]:
        """Return up to ``limit`` of the latest messages, oldest first.

        Returns an empty list when nothing is stored or Redis is unavailable.
        """
        messages = await self._get_recent_messages(user_id, flow_id, session_id, limit)
        return messages[::-1]

    async def delete_session(self, user_id: str, session_id: str) -> None:
        """Delete every key of ``session_id`` (across all flows) for a user.

        Also removes the session id from the user's session index. Failures
        are logged as a warning and not raised.
        """
        pattern = (
            f"{self.KEY_PREFIX}:{self._escape_glob(user_id)}"
            f":*:{self._escape_glob(session_id)}:*"
        )
        try:
            # SCAN walks the keyspace incrementally; KEYS would block the
            # server for the whole traversal.
            doomed = [
                key
                async for key in self.redis.scan_iter(
                    match=pattern, count=self.SCAN_COUNT
                )
            ]
            async with self.redis.pipeline() as pipe:
                if doomed:
                    pipe.delete(*doomed)
                pipe.srem(self._sessions_key(user_id), session_id)
                await pipe.execute()
        except Exception as e:
            logger.warning("清除记忆失败: %s", e)

    async def list_user_sessions(self, user_id: str) -> list[dict]:
        """List the user's sessions that still have metadata, newest first.

        Returns:
            One dict per (session, flow) pair with ``session_id``, ``flow_id``,
            ``flow_name`` and ``last_active_at``, sorted by ``last_active_at``
            descending. Session ids whose meta hash no longer exists are
            omitted. An empty list is returned on Redis failure.
        """
        try:
            session_ids = await self.redis.smembers(self._sessions_key(user_id))
        except Exception as e:
            logger.warning("列出会话失败: %s", e)
            return []
        if not session_ids:
            return []

        prefix = f"{self.KEY_PREFIX}:{user_id}:"
        suffix = ":meta"
        pattern = f"{self.KEY_PREFIX}:{self._escape_glob(user_id)}:*{suffix}"
        try:
            # One incremental scan for all of the user's meta keys instead of
            # one blocking KEYS call per session. A set removes the duplicates
            # SCAN is allowed to return.
            meta_keys = {
                key
                async for key in self.redis.scan_iter(
                    match=pattern, count=self.SCAN_COUNT
                )
            }
        except Exception as e:
            logger.warning("列出会话失败: %s", e)
            return []

        sessions = []
        for key in meta_keys:
            try:
                # What is left between prefix and suffix is
                # "{flow_id}:{session_id}"; either part may itself contain ":".
                middle = key[len(prefix):-len(suffix)]
                owners = [sid for sid in session_ids if middle.endswith(f":{sid}")]
                if not owners:
                    continue
                meta = await self.redis.hgetall(key)
                for sid in owners:
                    sessions.append({
                        "session_id": sid,
                        "flow_id": middle[:-(len(sid) + 1)],
                        "flow_name": meta.get("flow_name", ""),
                        "last_active_at": meta.get("last_active_at", ""),
                    })
            except Exception as e:
                # Skip the entry that failed; keep listing the others.
                logger.debug("跳过会话元数据 %r: %s", key, e)
                continue

        return sorted(sessions, key=lambda s: s.get("last_active_at", ""), reverse=True)

    async def _get_recent_messages(
        self, user_id: str, flow_id: str, session_id: str, limit: "int | None" = None
    ) -> list[dict]:
        """Read up to ``limit`` stored messages, newest first.

        Args:
            limit: Maximum number of entries to read. ``None``, zero or a
                negative value falls back to ``MAX_HISTORY``.

        Returns:
            The decoded entries; items that are not valid JSON are skipped.
            An empty list on Redis failure.
        """
        if limit is None or limit <= 0:
            limit = self.MAX_HISTORY
        try:
            key = self._msg_key(user_id, flow_id, session_id)
            raw = await self.redis.lrange(key, 0, limit - 1)
        except Exception as e:
            logger.warning("读取记忆失败: %s", e)
            return []

        decoded = []
        for item in raw:
            try:
                decoded.append(json.loads(item))
            except (json.JSONDecodeError, TypeError):
                continue
        return decoded

    async def _get_summary(self, user_id: str, flow_id: str, session_id: str) -> str:
        """Return the stored summary, or ``""`` when absent or on failure."""
        try:
            value = await self.redis.get(
                self._summary_key(user_id, flow_id, session_id)
            )
        except Exception as e:
            logger.warning("读取摘要失败: %s", e)
            return ""
        return value or ""

    async def _maybe_summarize(self, user_id: str, flow_id: str, session_id: str) -> None:
        """Create a summary once the session is long enough and has none yet.

        Does nothing when fewer than ``SUMMARIZE_THRESHOLD`` messages are
        stored or a summary already exists. Otherwise sends the latest
        messages to ``{settings.LLM_API_BASE}/chat/completions`` and stores a
        non-empty answer for ``SUMMARY_TTL`` seconds. Runs as a background
        task, so every failure is logged and swallowed.
        """
        try:
            msg_key = self._msg_key(user_id, flow_id, session_id)
            if await self.redis.llen(msg_key) < self.SUMMARIZE_THRESHOLD:
                return

            summary_key = self._summary_key(user_id, flow_id, session_id)
            if await self.redis.get(summary_key):
                return

            recent = await self._get_recent_messages(user_id, flow_id, session_id, 20)
            # Ten newest usable entries, in chronological order, each capped
            # at 500 characters. Entries missing a field are tolerated.
            dialogue = "\n".join(
                f"{m.get('role', '')}: {str(m.get('content', ''))[:500]}"
                for m in reversed(recent[:10])
                if isinstance(m, dict)
            )
            if not dialogue:
                return

            # Imported lazily, as in the original code, so the module loads
            # even when httpx is not installed.
            import httpx

            api_base = settings.LLM_API_BASE.rstrip("/")
            async with httpx.AsyncClient(timeout=30) as client:
                resp = await client.post(
                    f"{api_base}/chat/completions",
                    json={
                        "model": settings.LLM_MODEL,
                        "messages": [{
                            "role": "user",
                            "content": f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}"
                        }],
                        "max_tokens": 200,
                    },
                    headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                )
                # Surface HTTP errors instead of parsing an error payload.
                resp.raise_for_status()
                data = resp.json()

            # Guard against a missing or empty "choices" list.
            choices = data.get("choices") or [{}]
            summary = (choices[0].get("message") or {}).get("content") or ""

            if summary:
                await self.redis.setex(summary_key, self.SUMMARY_TTL, summary)
        except Exception as e:
            logger.warning("生成摘要失败: %s", e)

    @staticmethod
    def _msg_key(user_id: str, flow_id: str, session_id: str) -> str:
        """Return the key of the session's message list."""
        return f"{MemoryManager.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:messages"

    @staticmethod
    def _meta_key(user_id: str, flow_id: str, session_id: str) -> str:
        """Return the key of the session's metadata hash."""
        return f"{MemoryManager.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:meta"

    @staticmethod
    def _summary_key(user_id: str, flow_id: str, session_id: str) -> str:
        """Return the key of the session's stored summary."""
        return f"{MemoryManager.KEY_PREFIX}:{user_id}:{flow_id}:{session_id}:summary"

    @staticmethod
    def _sessions_key(user_id: str) -> str:
        """Return the key of the user's session-id set."""
        return f"{MemoryManager.KEY_PREFIX}:{user_id}:sessions"

    @staticmethod
    def _escape_glob(value: str) -> str:
        """Backslash-escape Redis glob metacharacters (``* ? [ ] \\``).

        Lets an id be embedded in a ``SCAN MATCH`` pattern and match only
        itself, even when it contains pattern characters.
        """
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in str(value))
|