"""Redis → PostgreSQL memory data migration script.

Migrates the memory data currently held in Redis (messages, summaries and
session lists) into the PostgreSQL tables ``memory_messages``,
``memory_atoms`` and ``memory_sessions``.

Usage:
    python scripts/migrate_memory_redis_to_pg.py [--dry-run] [--user-id UUID] [--batch-size N]

Options:
    --dry-run       Scan only; nothing is written to PostgreSQL.
    --user-id       Migrate only the data belonging to the given user.
    --batch-size    Number of rows written per transaction (default 100).
"""
import argparse
import asyncio
import json
import uuid
import sys
import os
from datetime import datetime, timezone
from urllib.parse import urlsplit, urlunsplit

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))

from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

from config import settings

DEFAULT_BATCH_SIZE = 100

# Summaries shorter than this (after stripping whitespace) are not migrated.
_MIN_SUMMARY_LENGTH = 10

_INSERT_OLD_MESSAGE = text(
    """
    INSERT INTO memory_messages (user_id, flow_id, session_id, role, content, created_at)
    VALUES (:uid, :fid, :sid, :role, :content, :ts)
    ON CONFLICT DO NOTHING
    """
)

# The cache format carries no flow id, so the column is left to its default.
_INSERT_CACHED_MESSAGE = text(
    """
    INSERT INTO memory_messages (user_id, session_id, role, content, created_at)
    VALUES (:uid, :sid, :role, :content, :ts)
    ON CONFLICT DO NOTHING
    """
)

_INSERT_SESSION = text(
    """
    INSERT INTO memory_sessions (user_id, session_id, last_active_at, created_at)
    VALUES (:uid, :sid, NOW(), NOW())
    ON CONFLICT (user_id, session_id) DO NOTHING
    """
)


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _mask_url(url) -> str:
    """Return ``url`` as a string with any embedded password replaced by ``***``."""
    raw = str(url)
    parts = urlsplit(raw)
    if parts.password is None:
        return raw
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


def _message_row(msg) -> dict:
    """Build the role/content/timestamp bind parameters for one stored message.

    Args:
        msg: A decoded JSON value; must be a JSON object (``dict``).

    Returns:
        A dict with the keys ``role``, ``content`` and ``ts``. A missing or
        empty timestamp falls back to the current UTC time.

    Raises:
        TypeError: If ``msg`` is not a dict or its timestamp is not a string.
        ValueError: If the timestamp is not valid ISO-8601.
    """
    if not isinstance(msg, dict):
        raise TypeError(f"expected a JSON object, got {type(msg).__name__}")
    ts_str = msg.get("ts") or msg.get("timestamp")
    created_at = datetime.fromisoformat(ts_str) if ts_str else _utcnow_naive()
    return {
        "role": msg.get("role", "user"),
        "content": msg.get("content", ""),
        "ts": created_at,
    }


class _BatchWriter:
    """Buffer bind-parameter rows for one statement and write them in batches.

    In dry-run mode rows are only counted. Otherwise each full batch is
    written in a single transaction; if that fails, the rows are retried one
    by one so a single bad row does not discard the rest of the batch.
    ``count`` is the number of rows processed (dry run) or written.
    """

    def __init__(self, engine, statement, batch_size: int, dry_run: bool) -> None:
        self._engine = engine
        self._statement = statement
        self._batch_size = max(1, batch_size)
        self._dry_run = dry_run
        self._pending: list[dict] = []
        self.count = 0

    async def add(self, row: dict) -> None:
        """Queue ``row``; flush automatically once the batch is full."""
        if self._dry_run:
            self.count += 1
            return
        self._pending.append(row)
        if len(self._pending) >= self._batch_size:
            await self.flush()

    async def flush(self) -> None:
        """Write all queued rows. Must be called once after the last ``add``."""
        rows, self._pending = self._pending, []
        if not rows:
            return
        try:
            # A list of dicts makes SQLAlchemy run the statement as executemany.
            await self._execute(rows)
        except SQLAlchemyError as exc:
            print(f" [错误] 批量写入失败,改为逐条写入: {exc}")
        else:
            self.count += len(rows)
            return
        for row in rows:
            try:
                await self._execute(row)
            except SQLAlchemyError as exc:
                print(f" [错误] 写入失败: {str(row)[:80]}... -> {exc}")
            else:
                self.count += 1

    async def _execute(self, params) -> None:
        """Run the statement with ``params`` in its own committed transaction."""
        async with AsyncSession(self._engine) as session:
            await session.execute(self._statement, params)
            await session.commit()


async def scan_redis_keys(redis: Redis, pattern: str) -> list[str]:
    """Return every Redis key matching ``pattern``, without duplicates.

    Uses cursor-based SCAN rather than KEYS. SCAN may report the same key more
    than once, so the result is deduplicated while keeping first-seen order.
    """
    unique: dict[str, None] = {}
    cursor = 0
    while True:
        cursor, batch = await redis.scan(cursor, match=pattern, count=200)
        unique.update(dict.fromkeys(batch))
        if cursor == 0:
            break
    return list(unique)


def parse_key_info(key: str) -> dict | None:
    """Extract user_id, flow_id and session_id from a Redis key.

    The UUID-shaped segments of the colon-separated key are read in order:

    * one UUID    -> ``user_id``
    * two UUIDs   -> ``user_id``, ``session_id``
    * three UUIDs -> ``user_id``, ``flow_id``, ``session_id``
      (the ``mem:{uid}:{fid}:{sid}:...`` layout); any further UUIDs are ignored.

    Returns:
        A dict holding the ids that were found, or ``None`` when the key
        contains no UUID at all.
    """
    found = []
    for part in key.split(":"):
        try:
            found.append(uuid.UUID(part))
        except ValueError:
            continue
    if not found:
        return None

    info = {"user_id": found[0]}
    if len(found) == 2:
        info["session_id"] = found[1]
    elif len(found) >= 3:
        info["flow_id"] = found[1]
        info["session_id"] = found[2]
    return info


async def migrate_old_format_keys(
    redis: Redis,
    engine,
    dry_run: bool,
    user_id: uuid.UUID | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
):
    """Migrate old-format keys: mem:{uid}:{fid}:{sid}:messages (List).

    ``mem:{uid}:{fid}:{sid}:meta`` (Hash) keys are scanned and counted in the
    progress output but their contents are not written anywhere.

    Args:
        redis: Client created with ``decode_responses=True``.
        engine: SQLAlchemy async engine for the target database.
        dry_run: When true, messages are parsed and counted but not written.
        user_id: If given, only keys belonging to this user are migrated.
        batch_size: Rows per INSERT transaction.

    Returns:
        Number of messages processed (dry run) or written.
    """
    print(">>> 扫描旧格式 Redis 记忆键 ...")
    keys = await scan_redis_keys(redis, "mem:*:messages")
    keys += await scan_redis_keys(redis, "mem:*:meta")
    print(f" 找到 {len(keys)} 个旧格式键")

    writer = _BatchWriter(engine, _INSERT_OLD_MESSAGE, batch_size, dry_run)
    for key in keys:
        info = parse_key_info(key)
        if not info:
            print(f" [跳过] 无法解析键: {key}")
            continue
        if user_id is not None and info["user_id"] != user_id:
            continue
        # Only message lists carry rows to insert; meta hashes have no target
        # table in this script, so they are not fetched.
        if await redis.type(key) != "list":
            continue

        for raw in await redis.lrange(key, 0, -1):
            try:
                # json.JSONDecodeError is a ValueError subclass.
                row = _message_row(json.loads(raw))
            except (ValueError, TypeError) as e:
                print(f" [错误] 解析消息失败: {raw[:80]}... -> {e}")
                continue
            row["uid"] = info["user_id"]
            row["fid"] = info.get("flow_id")
            row["sid"] = info.get("session_id")
            await writer.add(row)

    await writer.flush()
    return writer.count


async def migrate_new_format_cache(
    redis: Redis,
    engine,
    dry_run: bool,
    user_id: uuid.UUID | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
):
    """Migrate the new-format cache: mem:cache:msgs:{uid}:{sid} (String JSON array).

    Malformed keys, payloads and individual messages are reported and skipped
    instead of aborting the whole run.

    Args:
        redis: Client created with ``decode_responses=True``.
        engine: SQLAlchemy async engine for the target database.
        dry_run: When true, messages are parsed and counted but not written.
        user_id: If given, only keys belonging to this user are migrated.
        batch_size: Rows per INSERT transaction.

    Returns:
        Number of messages processed (dry run) or written.
    """
    print(">>> 扫描新格式 Redis 消息缓存 ...")
    keys = await scan_redis_keys(redis, "mem:cache:msgs:*")
    print(f" 找到 {len(keys)} 个缓存键")

    writer = _BatchWriter(engine, _INSERT_CACHED_MESSAGE, batch_size, dry_run)
    for key in keys:
        parts = key.split(":")
        if len(parts) < 5:
            print(f" [跳过] 无法解析键: {key}")
            continue
        try:
            uid = uuid.UUID(parts[3])
            sid = uuid.UUID(parts[4])
        except ValueError:
            print(f" [跳过] 无法解析键: {key}")
            continue
        if user_id is not None and uid != user_id:
            continue

        raw = await redis.get(key)
        if not raw:
            continue
        try:
            messages = json.loads(raw)
        except json.JSONDecodeError as e:
            print(f" [跳过] 无法解析缓存内容: {key} -> {e}")
            continue
        if not isinstance(messages, list):
            print(f" [跳过] 缓存内容不是 JSON 数组: {key}")
            continue

        for msg in messages:
            try:
                row = _message_row(msg)
            except (ValueError, TypeError) as e:
                print(f" [错误] 解析消息失败: {str(msg)[:80]}... -> {e}")
                continue
            row["uid"] = uid
            row["sid"] = sid
            await writer.add(row)

    await writer.flush()
    return writer.count


async def _upsert_summary(engine, uid: uuid.UUID, sid: uuid.UUID, summary: str) -> None:
    """Update the existing summary atom for (user, session) or insert a new one."""
    async with AsyncSession(engine) as session:
        result = await session.execute(
            text(
                "SELECT id FROM memory_atoms "
                "WHERE user_id = :uid AND source_session_id = :sid AND atom_type = 'summary'"
            ),
            {"uid": uid, "sid": sid},
        )
        existing = result.fetchone()
        if existing:
            await session.execute(
                text("UPDATE memory_atoms SET content = :content, updated_at = NOW() WHERE id = :id"),
                {"content": summary, "id": existing[0]},
            )
        else:
            await session.execute(
                text(
                    """
                    INSERT INTO memory_atoms
                        (user_id, atom_type, content, priority, source_session_id, created_at, updated_at)
                    VALUES (:uid, 'summary', :content, 60, :sid, NOW(), NOW())
                    """
                ),
                {"uid": uid, "content": summary, "sid": sid},
            )
        await session.commit()


async def migrate_summaries(
    redis: Redis,
    engine,
    dry_run: bool,
    user_id: uuid.UUID | None = None,
):
    """Migrate Redis summaries (mem:summary:{uid}:{sid}) into PostgreSQL memory_atoms.

    Empty summaries and those shorter than 10 characters after stripping are
    skipped. An existing ``summary`` atom for the same user and session is
    updated in place; otherwise a new one is inserted with priority 60.

    Args:
        redis: Client created with ``decode_responses=True``.
        engine: SQLAlchemy async engine for the target database.
        dry_run: When true, summaries are counted but not written.
        user_id: If given, only summaries belonging to this user are migrated.

    Returns:
        Number of summaries processed.
    """
    print(">>> 扫描 Redis 摘要缓存 ...")
    keys = await scan_redis_keys(redis, "mem:summary:*")
    print(f" 找到 {len(keys)} 个摘要键")

    migrated = 0
    for key in keys:
        parts = key.split(":")
        if len(parts) < 4:
            continue
        try:
            uid = uuid.UUID(parts[2])
            sid = uuid.UUID(parts[3])
        except ValueError:
            continue
        if user_id is not None and uid != user_id:
            continue

        summary = await redis.get(key)
        if not summary or len(summary.strip()) < _MIN_SUMMARY_LENGTH:
            continue

        if not dry_run:
            await _upsert_summary(engine, uid, sid, summary)
        migrated += 1
    return migrated


async def migrate_session_list(
    redis: Redis,
    engine,
    dry_run: bool,
    user_id: uuid.UUID | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
):
    """Migrate the session ids stored in the mem:{uid}:sessions Sets.

    Each valid (user, session) pair is inserted into ``memory_sessions``;
    pairs that already exist are left untouched. Set members that are not
    UUIDs are ignored.

    Args:
        redis: Client created with ``decode_responses=True``.
        engine: SQLAlchemy async engine for the target database.
        dry_run: When true, sessions are counted but not written.
        user_id: If given, only this user's session list is migrated.
        batch_size: Rows per INSERT transaction.

    Returns:
        Number of sessions processed (dry run) or written.
    """
    print(">>> 扫描会话列表 (mem:*:sessions)...")
    keys = await scan_redis_keys(redis, "mem:*:sessions")
    print(f" 找到 {len(keys)} 个会话列表键")

    writer = _BatchWriter(engine, _INSERT_SESSION, batch_size, dry_run)
    for key in keys:
        parts = key.split(":")
        if len(parts) < 3:
            continue
        try:
            uid = uuid.UUID(parts[1])
        except ValueError:
            continue
        if user_id is not None and uid != user_id:
            continue

        for sid_str in await redis.smembers(key):
            try:
                sid = uuid.UUID(sid_str)
            except ValueError:
                continue
            await writer.add({"uid": uid, "sid": sid})

    await writer.flush()
    return writer.count


def _positive_int(value: str) -> int:
    """argparse type: parse ``value`` as an integer >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"必须为正整数: {value}")
    return number


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line.

    Unknown flags are rejected, so a mistyped ``--dry-run`` cannot silently
    start a real migration.
    """
    parser = argparse.ArgumentParser(description="Redis → PostgreSQL 记忆数据迁移")
    parser.add_argument("--dry-run", action="store_true", help="仅扫描不实际写入")
    parser.add_argument("--user-id", type=uuid.UUID, default=None, help="只迁移指定用户的数据")
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=DEFAULT_BATCH_SIZE,
        help="每批写入条数 (默认100)",
    )
    return parser.parse_args(argv)


async def main():
    """Run all four migrations in sequence and print a summary.

    Connection URLs are printed with their passwords masked. The Redis client
    and the database engine are always released, even when a step fails.
    """
    args = _parse_args()
    dry_run = args.dry_run

    print("=" * 60)
    print("Redis → PostgreSQL 记忆数据迁移")
    print(f"模式: {'试运行(不写入)' if dry_run else '正式迁移'}")
    print(f"数据库: {_mask_url(settings.DATABASE_URL)}")
    print(f"Redis: {_mask_url(settings.REDIS_URL)}")
    if args.user_id is not None:
        print(f"用户: {args.user_id}")
    print("=" * 60)

    engine = create_async_engine(settings.DATABASE_URL, echo=False)
    redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)

    try:
        await redis.ping()
        print("Redis 连接成功")

        total = 0
        total += await migrate_old_format_keys(
            redis, engine, dry_run, user_id=args.user_id, batch_size=args.batch_size
        )
        total += await migrate_new_format_cache(
            redis, engine, dry_run, user_id=args.user_id, batch_size=args.batch_size
        )
        total += await migrate_summaries(redis, engine, dry_run, user_id=args.user_id)
        total += await migrate_session_list(
            redis, engine, dry_run, user_id=args.user_id, batch_size=args.batch_size
        )

        print(f"\n迁移完成!共处理 {total} 条记录")
        if dry_run:
            print("提示: 使用不带 --dry-run 参数运行以实际写入数据")
    finally:
        # Nested so the engine is disposed even if closing Redis fails.
        try:
            await redis.aclose()
        finally:
            await engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())