You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
300 lines
10 KiB
300 lines
10 KiB
"""Redis → PostgreSQL 记忆数据迁移脚本
|
|
|
|
将 Redis 中现有记忆数据(消息、摘要)迁移到 PostgreSQL memory_messages 表。
|
|
|
|
使用方式:
|
|
python scripts/migrate_memory_redis_to_pg.py [--dry-run] [--user-id UUID]
|
|
|
|
选项:
|
|
--dry-run 仅扫描不实际写入
|
|
--user-id 只迁移指定用户的数据
|
|
--batch-size 每批写入条数 (默认100)
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
import sys
|
|
import os
|
|
from datetime import datetime
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
|
|
|
|
from redis.asyncio import Redis
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
from config import settings
|
|
|
|
|
|
async def scan_redis_keys(redis: Redis, pattern: str) -> list[str]:
    """Return every key matching *pattern*, walking the keyspace with SCAN.

    SCAN is issued with a page hint of 200 keys and repeated until the server
    hands back cursor 0. Keys are returned in the order the pages arrive.
    """
    found: list[str] = []
    cursor = None  # None means "not started yet"; the first request uses cursor 0
    while cursor != 0:
        cursor, page = await redis.scan(cursor or 0, match=pattern, count=200)
        found += page
    return found
|
|
|
|
|
|
def parse_key_info(key: str) -> dict | None:
|
|
"""从 Redis key 中解析 user_id, flow_id, session_id"""
|
|
parts = key.split(":")
|
|
info = {}
|
|
|
|
uid_idx = None
|
|
fid_idx = None
|
|
sid_idx = None
|
|
|
|
for i, p in enumerate(parts):
|
|
try:
|
|
val = uuid.UUID(p)
|
|
if uid_idx is None:
|
|
uid_idx = i
|
|
info["user_id"] = val
|
|
elif sid_idx is None:
|
|
sid_idx = i
|
|
info["session_id"] = val
|
|
elif fid_idx is None:
|
|
fid_idx = i
|
|
info["flow_id"] = val
|
|
except ValueError:
|
|
if p == "messages" and uid_idx is not None and sid_idx is not None and fid_idx is None:
|
|
fid_idx = -1
|
|
continue
|
|
|
|
if not info.get("user_id"):
|
|
return None
|
|
|
|
return info
|
|
|
|
|
|
async def migrate_old_format_keys(redis: Redis, engine, dry_run: bool):
    """Migrate legacy per-session message lists into ``memory_messages``.

    Scans ``mem:*:messages`` (List of JSON messages) and ``mem:*:meta`` (Hash)
    keys, e.g. ``mem:{uid}:{fid}:{sid}:messages``. Only List keys are
    migrated. Meta hashes count toward the reported number of keys found, but
    their contents are not written anywhere.

    Each message becomes one row:

    * ``user_id`` / ``flow_id`` / ``session_id`` come from ``parse_key_info(key)``.
    * ``role`` defaults to ``"user"`` and ``content`` defaults to ``""``.
    * ``created_at`` is parsed from the message's ``ts`` (or ``timestamp``)
      ISO-8601 string, or is naive-UTC "now" when neither is set.

    All rows of one key are written in a single transaction. Messages that
    cannot be parsed are reported and skipped. If the database write for a
    key fails, that key is reported and skipped and migration continues.

    Args:
        redis: async Redis client; presumably created with
            ``decode_responses=True`` so keys/values are ``str`` (as ``main`` does).
        engine: async SQLAlchemy engine used to open sessions.
        dry_run: when true, parse and count but write nothing.

    Returns:
        Number of messages inserted (or that would be inserted in dry-run mode).
    """
    print(">>> 扫描旧格式 Redis 记忆键 ...")
    keys = await scan_redis_keys(redis, "mem:*:messages")
    keys += await scan_redis_keys(redis, "mem:*:meta")
    print(f" 找到 {len(keys)} 个旧格式键")

    insert_stmt = text("""
        INSERT INTO memory_messages (user_id, flow_id, session_id, role, content, created_at)
        VALUES (:uid, :fid, :sid, :role, :content, :ts)
        ON CONFLICT DO NOTHING
    """)

    migrated = 0
    for key in keys:
        info = parse_key_info(key)
        if not info:
            print(f" [跳过] 无法解析键: {key}")
            continue

        # Only List keys hold messages; meta hash contents are not migrated.
        if await redis.type(key) != "list":
            continue

        uid = info.get("user_id")
        fid = info.get("flow_id")
        sid = info.get("session_id")

        rows = []
        for raw in await redis.lrange(key, 0, -1):
            try:
                msg = json.loads(raw)
                ts_str = msg.get("ts", msg.get("timestamp", ""))
                created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow()
                rows.append({
                    "uid": uid,
                    "fid": fid,
                    "sid": sid,
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                    "ts": created_at,
                })
            except (ValueError, TypeError, AttributeError) as e:
                # ValueError covers json.JSONDecodeError and malformed ISO timestamps;
                # TypeError / AttributeError cover non-string timestamps and
                # messages that are not JSON objects.
                print(f" [错误] 解析消息失败: {raw[:80]}... -> {e}")

        if not rows:
            continue

        if not dry_run:
            try:
                async with AsyncSession(engine) as session:
                    # Passing a list of parameter dicts makes SQLAlchemy run an executemany.
                    await session.execute(insert_stmt, rows)
                    await session.commit()
            except Exception as e:  # best-effort per key: report and keep migrating
                print(f" [错误] 写入失败: {key} -> {e}")
                continue

        migrated += len(rows)

    return migrated
|
|
|
|
|
|
async def migrate_new_format_cache(redis: Redis, engine, dry_run: bool):
    """Migrate cached message arrays ``mem:cache:msgs:{uid}:{sid}`` into ``memory_messages``.

    Each key holds a String containing a JSON array of message objects.

    Key-level skips:

    * keys whose 4th/5th ``:``-segments are not UUIDs are skipped silently;
    * empty values and undecodable JSON are skipped silently;
    * JSON that is not an array is reported and skipped.

    Each message becomes one row with ``role`` (default ``"user"``),
    ``content`` (default ``""``) and ``created_at``. ``created_at`` is parsed
    from the ISO-8601 ``ts`` string, or is naive-UTC "now" when ``ts`` is
    missing. ``flow_id`` is not set.

    A malformed message (bad or non-string ``ts``, or not a JSON object) is
    reported and skipped instead of aborting the whole migration. All rows of
    one key are written in a single transaction. If that write fails, the key
    is reported and skipped.

    Args:
        redis: async Redis client; presumably created with
            ``decode_responses=True`` (as ``main`` does).
        engine: async SQLAlchemy engine used to open sessions.
        dry_run: when true, parse and count but write nothing.

    Returns:
        Number of messages inserted (or that would be inserted in dry-run mode).
    """
    print(">>> 扫描新格式 Redis 消息缓存 ...")
    keys = await scan_redis_keys(redis, "mem:cache:msgs:*")
    print(f" 找到 {len(keys)} 个缓存键")

    insert_stmt = text("""
        INSERT INTO memory_messages (user_id, session_id, role, content, created_at)
        VALUES (:uid, :sid, :role, :content, :ts)
        ON CONFLICT DO NOTHING
    """)

    migrated = 0
    for key in keys:
        parts = key.split(":")
        if len(parts) < 5:
            continue
        try:
            uid = uuid.UUID(parts[3])
            sid = uuid.UUID(parts[4])
        except ValueError:
            continue

        raw = await redis.get(key)
        if not raw:
            continue
        try:
            messages = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if not isinstance(messages, list):
            print(f" [跳过] 缓存内容不是消息数组: {key}")
            continue

        rows = []
        for msg in messages:
            try:
                ts_str = msg.get("ts", "")
                created_at = datetime.fromisoformat(ts_str) if ts_str else datetime.utcnow()
                rows.append({
                    "uid": uid,
                    "sid": sid,
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                    "ts": created_at,
                })
            except (ValueError, TypeError, AttributeError) as e:
                # Malformed ISO timestamp, non-string timestamp, or a message
                # that is not a JSON object: skip just this message.
                print(f" [错误] 解析消息失败: {str(msg)[:80]}... -> {e}")

        if not rows:
            continue

        if not dry_run:
            try:
                async with AsyncSession(engine) as session:
                    # Passing a list of parameter dicts makes SQLAlchemy run an executemany.
                    await session.execute(insert_stmt, rows)
                    await session.commit()
            except Exception as e:  # best-effort per key, matching the old-format step
                print(f" [错误] 写入失败: {key} -> {e}")
                continue

        migrated += len(rows)

    return migrated
|
|
|
|
|
|
async def migrate_summaries(redis: Redis, engine, dry_run: bool):
    """Copy cached Redis summaries into ``memory_atoms`` rows of type ``'summary'``.

    Keys are expected as ``mem:summary:{uid}:{sid}``. A key is skipped when:

    * its 3rd/4th ``:``-segments are not UUIDs, or
    * its value is empty or shorter than 10 characters after stripping whitespace.

    For each remaining (user, session) pair, an existing ``'summary'`` atom has
    its content overwritten and ``updated_at`` bumped. Otherwise a new atom is
    inserted with priority 60. Each key is handled in its own transaction.

    Returns the number of summaries written (or that would be written when
    ``dry_run`` is true).
    """
    print(">>> 扫描 Redis 摘要缓存 ...")
    summary_keys = await scan_redis_keys(redis, "mem:summary:*")
    print(f" 找到 {len(summary_keys)} 个摘要键")

    count = 0
    for redis_key in summary_keys:
        segments = redis_key.split(":")
        if len(segments) < 4:
            continue
        try:
            user_id, session_id = uuid.UUID(segments[2]), uuid.UUID(segments[3])
        except (ValueError, IndexError):
            continue

        summary = await redis.get(redis_key)
        # Ignore missing or trivially short summaries.
        if not summary or len(summary.strip()) < 10:
            continue

        if not dry_run:
            async with AsyncSession(engine) as db:
                lookup = await db.execute(
                    text("SELECT id FROM memory_atoms WHERE user_id = :uid AND source_session_id = :sid AND atom_type = 'summary'"),
                    {"uid": user_id, "sid": session_id},
                )
                row = lookup.fetchone()
                if row is None:
                    await db.execute(
                        text("""
                            INSERT INTO memory_atoms (user_id, atom_type, content, priority, source_session_id, created_at, updated_at)
                            VALUES (:uid, 'summary', :content, 60, :sid, NOW(), NOW())
                        """),
                        {"uid": user_id, "content": summary, "sid": session_id},
                    )
                else:
                    await db.execute(
                        text("UPDATE memory_atoms SET content = :content, updated_at = NOW() WHERE id = :id"),
                        {"content": summary, "id": row[0]},
                    )
                await db.commit()
        count += 1

    return count
|
|
|
|
|
|
async def migrate_session_list(redis: Redis, engine, dry_run: bool):
    """Register every session id stored in ``mem:{uid}:sessions`` sets in ``memory_sessions``.

    The user id is taken from the 2nd ``:``-segment of the key. Keys whose
    user segment, or set members that are not valid UUIDs, are skipped.
    Each session is inserted with ``last_active_at``/``created_at`` set to
    ``NOW()``. Existing ``(user_id, session_id)`` pairs are left untouched via
    ``ON CONFLICT ... DO NOTHING``. Each insert uses its own session and commit.

    Returns the number of session ids processed (inserted, or that would be
    inserted when ``dry_run`` is true).
    """
    print(">>> 扫描会话列表 (mem:*:sessions)...")
    set_keys = await scan_redis_keys(redis, "mem:*:sessions")
    print(f" 找到 {len(set_keys)} 个会话列表键")

    total = 0
    for set_key in set_keys:
        fields = set_key.split(":")
        if len(fields) < 3:
            continue
        try:
            owner = uuid.UUID(fields[1])
        except ValueError:
            continue

        for member in await redis.smembers(set_key):
            try:
                session_uuid = uuid.UUID(member)
            except ValueError:
                continue

            if dry_run:
                total += 1
                continue

            async with AsyncSession(engine) as db:
                await db.execute(
                    text("""
                        INSERT INTO memory_sessions (user_id, session_id, last_active_at, created_at)
                        VALUES (:uid, :sid, NOW(), NOW())
                        ON CONFLICT (user_id, session_id) DO NOTHING
                    """),
                    {"uid": owner, "sid": session_uuid},
                )
                await db.commit()
            total += 1

    return total
|
|
|
|
|
|
def _redact_url(url) -> str:
    """Return ``str(url)`` with any password in its ``user:password@`` part replaced by ``***``.

    Used so connection strings can be printed without leaking credentials.
    URLs without a password are returned unchanged. URLs that
    ``urllib.parse.urlsplit`` rejects are replaced by a placeholder.
    """
    from urllib.parse import urlsplit, urlunsplit

    url = str(url)
    try:
        parts = urlsplit(url)
    except ValueError:
        return "<unparseable URL>"
    # rpartition: a password may itself contain '@'; the host part follows the last one.
    userinfo, sep, hostport = parts.netloc.rpartition("@")
    if not sep or ":" not in userinfo:
        return url
    user = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{user}:***@{hostport}"))


async def main():
    """Script entry point: read CLI flags, then run every migration step in order.

    Flags are read from ``sys.argv``:
      --dry-run        scan and count only; nothing is written to PostgreSQL
      --batch-size N   accepted for CLI compatibility but currently ignored
      --user-id UUID   not implemented; rejected (exit code 2) so it cannot
                       silently migrate every user's data

    Prints the connection targets with passwords redacted, then runs the four
    ``migrate_*`` steps and reports the total number of records processed.
    Redis and the SQLAlchemy engine are closed even if a step raises.
    """
    argv = sys.argv
    dry_run = "--dry-run" in argv
    # NOTE(review): --batch-size used to be parsed into an unused local; none of
    # the migrate_* steps accept a batch size, so the flag is ignored.

    # No migration step can filter by user. Refuse instead of migrating everyone.
    if any(arg == "--user-id" or arg.startswith("--user-id=") for arg in argv[1:]):
        print("错误: --user-id 尚未实现, 为避免迁移全部用户数据已中止")
        raise SystemExit(2)

    print("=" * 60)
    print("Redis → PostgreSQL 记忆数据迁移")
    print(f"模式: {'试运行(不写入)' if dry_run else '正式迁移'}")
    # Redact credentials so they do not end up in terminal scrollback or CI logs.
    print(f"数据库: {_redact_url(settings.DATABASE_URL)}")
    print(f"Redis: {_redact_url(settings.REDIS_URL)}")
    print("=" * 60)

    engine = create_async_engine(settings.DATABASE_URL, echo=False)
    redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)

    try:
        await redis.ping()
        print("Redis 连接成功")

        total = 0
        total += await migrate_old_format_keys(redis, engine, dry_run)
        total += await migrate_new_format_cache(redis, engine, dry_run)
        total += await migrate_summaries(redis, engine, dry_run)
        total += await migrate_session_list(redis, engine, dry_run)

        print(f"\n迁移完成!共处理 {total} 条记录")
        if dry_run:
            print("提示: 使用不带 --dry-run 参数运行以实际写入数据")
    finally:
        await redis.aclose()
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())
|