You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

82 lines
2.9 KiB

"""用户隔离记忆模块。
提供基于用户 ID 隔离的记忆存储机制,确保每个用户只能访问自己的对话历史。
通过包装 AgentScope 的 MemoryBase 实现用户级别的记忆隔离。
"""
from agentscope.memory import MemoryBase, InMemoryMemory
from agentscope.message import Msg
class UserIsolatedMemory(MemoryBase):
"""用户隔离记忆类,确保每个用户只能访问自己对话历史的记忆管理器。
通过在消息元数据中标记用户 ID,在获取记忆时过滤出当前用户的消息,
实现多用户环境下的对话历史隔离。
Attributes:
user_id: 当前记忆实例绑定的用户唯一标识。
_backend: 底层记忆存储实例,默认为 InMemoryMemory。
"""
def __init__(self, user_id: str, backend_memory: MemoryBase | None = None):
"""初始化用户隔离记忆实例。
Args:
user_id: 用户唯一标识。
backend_memory: 可选的底层记忆存储实例,不提供则使用内存存储。
"""
self.user_id = user_id
self._backend = backend_memory or InMemoryMemory()
async def add(self, msg: Msg | list[Msg] | None) -> None:
"""添加消息到记忆中,自动标记当前用户 ID。
Args:
msg: 要添加的消息,可以是单条消息、消息列表或 None。
"""
if msg is None:
return
msgs = msg if isinstance(msg, list) else [msg]
for m in msgs:
m.metadata = m.metadata or {}
m.metadata["_user_id"] = self.user_id
await self._backend.add(msg)
async def get_memory(self, **kwargs) -> list[Msg]:
"""获取当前用户的记忆历史。
从底层存储中获取所有消息后,过滤出属于当前用户的消息。
Args:
**kwargs: 传递给底层存储的额外参数。
Returns:
list[Msg]: 属于当前用户的消息列表。
"""
all_msgs = await self._backend.get_memory(**kwargs)
return [m for m in all_msgs if m.metadata.get("_user_id") == self.user_id]
async def delete_by_mark(self, mark: str) -> None:
"""根据标记删除消息。
Args:
mark: 要删除的消息标记。
"""
await self._backend.delete_by_mark(mark)
async def update_messages_mark(self, msg_ids: list[str], new_mark: str) -> None:
"""更新消息的标记。
Args:
msg_ids: 要更新标记的消息 ID 列表。
new_mark: 新的标记字符串。
"""
await self._backend.update_messages_mark(msg_ids, new_mark)
async def update_compressed_summary(self, summary: str) -> None:
"""更新压缩后的记忆摘要。
Args:
summary: 新的记忆摘要字符串。
"""
await self._backend.update_compressed_summary(summary)