You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

300 lines
10 KiB

"""Redis → PostgreSQL 记忆数据迁移脚本
将 Redis 中现有记忆数据(消息、摘要)迁移到 PostgreSQL memory_messages 表。
使用方式:
python scripts/migrate_memory_redis_to_pg.py [--dry-run] [--user-id UUID]
选项:
--dry-run 仅扫描不实际写入
--user-id 只迁移指定用户的数据
--batch-size 每批写入条数 (默认100)
"""
import asyncio
import json
import uuid
import sys
import os
from datetime import datetime
# Prepend the sibling ../backend directory (relative to this script) to the
# import path so the `from config import settings` import below can resolve
# when the script is run directly — presumably config.py lives in backend/; verify.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from config import settings
async def scan_redis_keys(redis: Redis, pattern: str, count: int = 200) -> list[str]:
    """Return every key matching *pattern*, using incremental SCAN.

    SCAN is allowed to return the same key more than once during a full
    iteration, so results are de-duplicated (first-seen order is kept).
    Without this, a duplicated key would be migrated twice.

    Args:
        redis: async Redis client.
        pattern: glob-style MATCH pattern, e.g. ``"mem:*:messages"``.
        count: COUNT hint passed to each SCAN call (default 200, as before).

    Returns:
        Unique matching keys.
    """
    seen: dict[str, None] = {}
    cursor = 0
    while True:
        cursor, batch = await redis.scan(cursor, match=pattern, count=count)
        seen.update(dict.fromkeys(batch))
        # A returned cursor of 0 marks the end of the full iteration.
        if cursor == 0:
            break
    return list(seen)
def parse_key_info(key: str) -> dict | None:
"""从 Redis key 中解析 user_id, flow_id, session_id"""
parts = key.split(":")
info = {}
uid_idx = None
fid_idx = None
sid_idx = None
for i, p in enumerate(parts):
try:
val = uuid.UUID(p)
if uid_idx is None:
uid_idx = i
info["user_id"] = val
elif sid_idx is None:
sid_idx = i
info["session_id"] = val
elif fid_idx is None:
fid_idx = i
info["flow_id"] = val
except ValueError:
if p == "messages" and uid_idx is not None and sid_idx is not None and fid_idx is None:
fid_idx = -1
continue
if not info.get("user_id"):
return None
return info
async def migrate_old_format_keys(redis: Redis, engine, dry_run: bool):
    """Migrate legacy per-session message lists into ``memory_messages``.

    Scans ``mem:*:messages`` (List of JSON messages) and ``mem:*:meta``
    (Hash) keys, laid out as ``mem:{uid}:{fid}:{sid}:messages`` / ``:meta``.
    All decodable messages of one list key are inserted in a single
    transaction (``ON CONFLICT DO NOTHING``).

    Error handling is best-effort:
    - a message that cannot be decoded is reported and skipped;
    - a failed write is reported and that key is skipped;
    so one bad key does not abort the whole run.

    Args:
        redis: async Redis client (``main`` creates it with ``decode_responses=True``).
        engine: SQLAlchemy async engine used to open write sessions.
        dry_run: when true, only count messages and write nothing.

    Returns:
        Number of messages written (or that would be written when ``dry_run``).
    """
    print(">>> 扫描旧格式 Redis 记忆键 ...")
    keys = await scan_redis_keys(redis, "mem:*:messages")
    keys += await scan_redis_keys(redis, "mem:*:meta")
    print(f" 找到 {len(keys)} 个旧格式键")
    insert_stmt = text(
        "\n"
        "INSERT INTO memory_messages (user_id, flow_id, session_id, role, content, created_at)\n"
        "VALUES (:uid, :fid, :sid, :role, :content, :ts)\n"
        "ON CONFLICT DO NOTHING\n"
    )
    migrated = 0
    for key in keys:
        info = parse_key_info(key)
        if not info:
            print(f" [跳过] 无法解析键: {key}")
            continue
        if await redis.type(key) != "list":
            # NOTE(review): meta hashes used to be read into a dict that was never
            # used; they are still not migrated anywhere — confirm whether they should be.
            continue
        rows = []
        for raw in await redis.lrange(key, 0, -1):
            try:
                msg = json.loads(raw)
                ts_str = msg.get("ts", msg.get("timestamp", ""))
                created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow()
            except (ValueError, TypeError, AttributeError) as e:
                # JSONDecodeError and malformed ISO timestamps are ValueErrors;
                # non-object JSON (list/str/number) fails on .get.
                print(f" [错误] 解析消息失败: {raw[:80]}... -> {e}")
                continue
            rows.append(
                {
                    "uid": info["user_id"],
                    "fid": info.get("flow_id"),
                    "sid": info.get("session_id"),
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                    "ts": created_at,
                }
            )
        if not rows:
            continue
        if not dry_run:
            try:
                async with AsyncSession(engine) as session:
                    # A list of parameter dicts makes SQLAlchemy run an executemany.
                    await session.execute(insert_stmt, rows)
                    await session.commit()
            except Exception as e:
                # Best-effort like before, but reported as a write failure rather
                # than being mislabelled as a parse failure.
                print(f" [错误] 写入失败: {key} -> {e}")
                continue
        migrated += len(rows)
    return migrated
async def migrate_new_format_cache(redis: Redis, engine, dry_run: bool):
    """Migrate cached message arrays ``mem:cache:msgs:{uid}:{sid}`` into ``memory_messages``.

    Each key holds a JSON array of message objects (String value). All valid
    messages of one key are inserted in a single transaction
    (``ON CONFLICT DO NOTHING``); no flow id is written for this format.

    Previously, one malformed ``ts`` value, a non-object element, or a
    non-array payload raised out of this function and aborted the whole
    migration after partial writes. Such entries are now reported and skipped,
    matching the legacy-format migration.

    Args:
        redis: async Redis client (``main`` creates it with ``decode_responses=True``).
        engine: SQLAlchemy async engine used to open write sessions.
        dry_run: when true, only count messages and write nothing.

    Returns:
        Number of messages written (or that would be written when ``dry_run``).
    """
    print(">>> 扫描新格式 Redis 消息缓存 ...")
    keys = await scan_redis_keys(redis, "mem:cache:msgs:*")
    print(f" 找到 {len(keys)} 个缓存键")
    insert_stmt = text(
        "\n"
        "INSERT INTO memory_messages (user_id, session_id, role, content, created_at)\n"
        "VALUES (:uid, :sid, :role, :content, :ts)\n"
        "ON CONFLICT DO NOTHING\n"
    )
    migrated = 0
    for key in keys:
        parts = key.split(":")
        # Expected layout: mem:cache:msgs:{uid}:{sid}
        if len(parts) < 5:
            continue
        try:
            uid = uuid.UUID(parts[3])
            sid = uuid.UUID(parts[4])
        except ValueError:
            continue
        raw = await redis.get(key)
        if not raw:
            continue
        try:
            messages = json.loads(raw)
        except json.JSONDecodeError:
            print(f" [跳过] 无法解析缓存: {key}")
            continue
        if not isinstance(messages, list):
            # Iterating a dict/str here would yield keys/characters, not messages.
            print(f" [跳过] 缓存内容不是消息列表: {key}")
            continue
        rows = []
        for msg in messages:
            try:
                ts_str = msg.get("ts", "")
                created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow()
            except (ValueError, TypeError, AttributeError) as e:
                print(f" [错误] 解析消息失败: {str(msg)[:80]}... -> {e}")
                continue
            rows.append(
                {
                    "uid": uid,
                    "sid": sid,
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                    "ts": created_at,
                }
            )
        if rows and not dry_run:
            async with AsyncSession(engine) as session:
                # A list of parameter dicts makes SQLAlchemy run an executemany.
                await session.execute(insert_stmt, rows)
                await session.commit()
        migrated += len(rows)
    return migrated
async def migrate_summaries(redis: Redis, engine, dry_run: bool):
    """Copy cached conversation summaries into ``memory_atoms``.

    Each ``mem:summary:{uid}:{sid}`` key is handled as follows:
    - if a ``memory_atoms`` row with ``atom_type = 'summary'`` already exists
      for that user and source session, its content is overwritten;
    - otherwise a new row is inserted with priority 60.

    Keys whose id segments are not UUIDs are skipped, as are summaries shorter
    than 10 characters after stripping whitespace.

    Args:
        redis: async Redis client.
        engine: SQLAlchemy async engine used to open write sessions.
        dry_run: when true, only count summaries and write nothing.

    Returns:
        Number of summaries written (or that would be written when ``dry_run``).
    """
    print(">>> 扫描 Redis 摘要缓存 ...")
    summary_keys = await scan_redis_keys(redis, "mem:summary:*")
    print(f" 找到 {len(summary_keys)} 个摘要键")

    def ids_of(redis_key):
        # (user uuid, session uuid) for mem:summary:{uid}:{sid}, else None.
        segments = redis_key.split(":")
        if len(segments) < 4:
            return None
        try:
            return uuid.UUID(segments[2]), uuid.UUID(segments[3])
        except (ValueError, IndexError):
            return None

    written = 0
    for summary_key in summary_keys:
        ids = ids_of(summary_key)
        if ids is None:
            continue
        user_id, session_id = ids
        body = await redis.get(summary_key)
        if not body or len(body.strip()) < 10:
            continue
        if dry_run:
            written += 1
            continue
        async with AsyncSession(engine) as session:
            lookup = await session.execute(
                text("SELECT id FROM memory_atoms WHERE user_id = :uid AND source_session_id = :sid AND atom_type = 'summary'"),
                {"uid": user_id, "sid": session_id},
            )
            row = lookup.fetchone()
            if row:
                stmt = text("UPDATE memory_atoms SET content = :content, updated_at = NOW() WHERE id = :id")
                params = {"content": body, "id": row[0]}
            else:
                stmt = text(
                    "\n"
                    "INSERT INTO memory_atoms (user_id, atom_type, content, priority, source_session_id, created_at, updated_at)\n"
                    "VALUES (:uid, 'summary', :content, 60, :sid, NOW(), NOW())\n"
                )
                params = {"uid": user_id, "content": body, "sid": session_id}
            await session.execute(stmt, params)
            await session.commit()
        written += 1
    return written
async def migrate_session_list(redis: Redis, engine, dry_run: bool):
    """Copy per-user session sets ``mem:{uid}:sessions`` into ``memory_sessions``.

    Each set member that parses as a UUID becomes one ``memory_sessions`` row
    for that user. ``last_active_at`` and ``created_at`` are set to ``NOW()``,
    and existing ``(user_id, session_id)`` rows are left untouched.
    Keys or members that are not UUIDs are skipped.

    Args:
        redis: async Redis client.
        engine: SQLAlchemy async engine used to open write sessions.
        dry_run: when true, only count sessions and write nothing.

    Returns:
        Number of sessions written (or that would be written when ``dry_run``).
    """
    print(">>> 扫描会话列表 (mem:*:sessions)...")
    set_keys = await scan_redis_keys(redis, "mem:*:sessions")
    print(f" 找到 {len(set_keys)} 个会话列表键")
    upsert = text(
        "\n"
        "INSERT INTO memory_sessions (user_id, session_id, last_active_at, created_at)\n"
        "VALUES (:uid, :sid, NOW(), NOW())\n"
        "ON CONFLICT (user_id, session_id) DO NOTHING\n"
    )
    count = 0
    for set_key in set_keys:
        segments = set_key.split(":")
        if len(segments) < 3:
            continue
        try:
            user_id = uuid.UUID(segments[1])
        except ValueError:
            continue
        for member in await redis.smembers(set_key):
            try:
                session_id = uuid.UUID(member)
            except ValueError:
                continue
            if not dry_run:
                async with AsyncSession(engine) as session:
                    await session.execute(upsert, {"uid": user_id, "sid": session_id})
                    await session.commit()
            count += 1
    return count
def _redact_url(url) -> str:
    """Return *url* as a string with any password replaced by ``***``.

    Only the ``user:password@`` part of the network location is touched.
    URLs without a password are returned unchanged (as ``str(url)``).
    """
    from urllib.parse import urlsplit, urlunsplit

    raw_url = str(url)
    parts = urlsplit(raw_url)
    # rpartition so an unescaped '@' inside the password is still covered.
    userinfo, at, hostinfo = parts.netloc.rpartition("@")
    if not at:
        return raw_url
    user, colon, _ = userinfo.partition(":")
    if not colon:
        return raw_url
    return urlunsplit(parts._replace(netloc=f"{user}:***@{hostinfo}"))


async def main():
    """Parse CLI flags, connect to Redis and PostgreSQL, and run every migration step.

    Flags:
        --dry-run       count only, write nothing.
        --batch-size N  parsed for CLI compatibility, but not yet passed to any
                        migration step.
        --user-id       documented in the module docstring but not implemented.
                        It is rejected instead of being silently ignored, which
                        would otherwise migrate every user's data.

    Connection URLs are printed with passwords redacted.
    """
    dry_run = "--dry-run" in sys.argv
    if "--user-id" in sys.argv:
        raise SystemExit("错误: --user-id 选项尚未实现,为避免误迁移全部用户数据已中止")
    batch_size = 100  # NOTE: not yet wired into the migrate_* steps
    for i, arg in enumerate(sys.argv):
        if arg == "--batch-size" and i + 1 < len(sys.argv):
            batch_size = int(sys.argv[i + 1])
    print("=" * 60)
    print("Redis → PostgreSQL 记忆数据迁移")
    print(f"模式: {'试运行(不写入)' if dry_run else '正式迁移'}")
    # Never echo credentials embedded in connection URLs to stdout/logs.
    print(f"数据库: {_redact_url(settings.DATABASE_URL)}")
    print(f"Redis: {_redact_url(settings.REDIS_URL)}")
    print("=" * 60)
    engine = create_async_engine(settings.DATABASE_URL, echo=False)
    try:
        # Created inside the outer try so the engine is disposed even if this fails.
        redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
        try:
            await redis.ping()
            print("Redis 连接成功")
            total = 0
            total += await migrate_old_format_keys(redis, engine, dry_run)
            total += await migrate_new_format_cache(redis, engine, dry_run)
            total += await migrate_summaries(redis, engine, dry_run)
            total += await migrate_session_list(redis, engine, dry_run)
            print(f"\n迁移完成!共处理 {total} 条记录")
            if dry_run:
                print("提示: 使用不带 --dry-run 参数运行以实际写入数据")
        finally:
            await redis.aclose()
    finally:
        await engine.dispose()
if __name__ == "__main__":
    # Script entry point: run the async migration driver via asyncio.run.
    asyncio.run(main())