You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
1.2 KiB
30 lines
1.2 KiB
from agentscope.memory import MemoryBase, InMemoryMemory
|
|
from agentscope.message import Msg
|
|
|
|
|
|
class UserIsolatedMemory(MemoryBase):
    """Per-user view over a (possibly shared) backend memory.

    Every message written through :meth:`add` is tagged with
    ``metadata[_USER_ID_KEY] = user_id``, and :meth:`get_memory` returns only
    messages carrying this instance's tag. Several instances constructed with
    the same ``backend_memory`` therefore share storage while each one reads
    back only its own user's messages.

    NOTE(review): only ``add`` and ``get_memory`` are user-scoped. The
    mark/summary methods forward straight to the backend and, when the
    backend is shared, act on every user's data (see their docstrings).
    """

    # Metadata key used to record message ownership; shared by the writer
    # (``add``) and the reader (``get_memory``) so the two cannot drift.
    _USER_ID_KEY = "_user_id"

    def __init__(self, user_id: str, backend_memory: MemoryBase | None = None):
        """Create a user-scoped memory.

        Args:
            user_id: Identifier written into each added message's metadata
                and used to filter reads.
            backend_memory: Memory that actually stores the messages. Pass the
                same instance to several ``UserIsolatedMemory`` objects to
                share storage between users. When ``None``, a fresh
                ``InMemoryMemory`` is created (private storage).
        """
        # Run the base-class initialiser first so any state it sets up exists
        # before we assign attributes (including another memory object).
        super().__init__()
        self.user_id = user_id
        # Explicit ``is None`` check: a truthiness test would discard a
        # caller-supplied backend that happens to evaluate falsy (e.g. an
        # empty container-like memory), silently breaking shared storage.
        self._backend = (
            backend_memory if backend_memory is not None else InMemoryMemory()
        )

    async def add(self, msg: Msg | list[Msg] | None) -> None:
        """Tag message(s) with this user's id and store them in the backend.

        Args:
            msg: A single message, a list of messages, or ``None`` (no-op).
                The value is forwarded to the backend in the same shape it
                was received.

        Note:
            The tag is written onto the caller's ``Msg`` objects in place;
            a missing/empty ``metadata`` is replaced by a new dict first.
            Adding the same ``Msg`` object through two different users'
            memories re-tags it, so only the last user will see it.
        """
        if msg is None:
            return

        msgs = msg if isinstance(msg, list) else [msg]
        for m in msgs:
            m.metadata = m.metadata or {}
            m.metadata[self._USER_ID_KEY] = self.user_id

        await self._backend.add(msg)

    async def get_memory(self, **kwargs) -> list[Msg]:
        """Return the backend's messages that belong to this user.

        Keyword arguments are forwarded unchanged to the backend's
        ``get_memory``; the result is then filtered to messages whose
        ``metadata[_USER_ID_KEY]`` equals ``self.user_id``.

        NOTE(review): anything the backend returns without this tag (for
        example a summary message, if the backend injects one) is dropped
        here — confirm that is intended.
        """
        all_msgs = await self._backend.get_memory(**kwargs)
        # Messages written to the backend directly (not via ``add``) may carry
        # ``metadata=None``; treat them as untagged instead of crashing.
        return [
            m
            for m in all_msgs
            if (m.metadata or {}).get(self._USER_ID_KEY) == self.user_id
        ]

    async def delete_by_mark(self, mark: str) -> None:
        """Forward ``delete_by_mark(mark)`` to the backend.

        NOTE(review): not user-scoped — with a shared backend this applies to
        every user's messages carrying ``mark``. TODO: restrict to this
        user's messages once the backend offers a suitable API.
        """
        await self._backend.delete_by_mark(mark)

    async def update_messages_mark(self, msg_ids: list[str], new_mark: str) -> None:
        """Forward ``update_messages_mark(msg_ids, new_mark)`` to the backend.

        NOTE(review): ownership of ``msg_ids`` is not checked, so ids of other
        users' messages in a shared backend would be updated too. Callers are
        expected to pass only ids obtained from this instance's
        ``get_memory`` — TODO enforce.
        """
        await self._backend.update_messages_mark(msg_ids, new_mark)

    async def update_compressed_summary(self, summary: str) -> None:
        """Forward ``update_compressed_summary(summary)`` to the backend.

        NOTE(review): presumably the backend keeps a single summary, in which
        case a shared backend ends up with one summary for all users —
        confirm against the backend implementation.
        """
        await self._backend.update_compressed_summary(summary)