"""知识库(RAG)模块路由。 提供知识库的文档上传、文本索引、语义检索和文档管理功能。 支持文件上传后自动解析和向量化,以及基于相似度的知识检索。 """ from fastapi import APIRouter, Depends, UploadFile, File, Request from database import get_db from sqlalchemy.ext.asyncio import AsyncSession from dependencies import get_current_user import os import uuid from config import settings from .knowledge import add_document, add_text, search, retrieve_for_agent, get_knowledge_base router = APIRouter(prefix="/api/rag", tags=["rag"]) @router.post("/upload") async def rag_upload( request: Request, file: UploadFile = File(...), db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """上传文档文件并自动索引到知识库。 将上传的文件保存到服务器后,调用文档解析器将其切分为文本块 并进行向量化存储到知识库中。 Args: request: HTTP 请求对象。 file: 上传的文件对象。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含索引结果和文件信息的响应数据。 """ os.makedirs(settings.UPLOAD_DIR, exist_ok=True) filename = f"{uuid.uuid4().hex}_{file.filename}" # 生成唯一文件名,避免冲突 file_path = os.path.join(settings.UPLOAD_DIR, filename) content = await file.read() with open(file_path, "wb") as f: f.write(content) result = await add_document(file_path) return {"code": 200, "message": result, "file_id": filename, "file_name": file.filename} @router.post("/index-text") async def rag_index_text( request: Request, payload: dict, db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """将纯文本内容索引到知识库。 Args: request: HTTP 请求对象。 payload: 请求体,包含 text 和可选的 source 字段。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含索引结果的响应数据。 """ text = payload.get("text", "") # 待索引的文本内容 source = payload.get("source", "manual") # 文本来源标识 if not text: return {"code": 400, "message": "文本内容不能为空"} result = await add_text(text, source) return {"code": 200, "message": result} @router.get("/search") async def rag_search( request: Request, q: str = "", limit: int = 5, db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """在知识库中执行语义检索。 Args: request: HTTP 请求对象。 q: 查询文本。 limit: 返回结果的最大数量。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含检索结果列表的响应数据。 """ if not q: return {"code": 400, "message": "查询内容不能为空"} results = await search(q, limit=limit) return {"code": 200, "data": results, "query": q} @router.get("/retrieve") async def rag_retrieve( request: Request, q: str = "", limit: int = 5, db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """为 AI 智能体执行知识库检索,返回格式化的结果文本。 Args: request: HTTP 请求对象。 q: 查询文本。 limit: 返回结果的最大数量。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含格式化检索结果的响应数据。 """ if not q: return {"code": 400, "message": "查询内容不能为空"} result = await retrieve_for_agent(q, limit=limit) return {"code": 200, "data": result} @router.get("/documents") async def list_documents( request: Request, db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """列出知识库中已索引的所有文档来源及其统计信息。 从向量存储中获取所有文档,按来源分组并统计每个来源的文档块数量。 Args: request: HTTP 请求对象。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含文档列表和统计信息的响应数据。 """ try: kb = get_knowledge_base() if not kb or not hasattr(kb, '_embedding_store'): return {"code": 200, "data": [], "total": 0} store = kb._embedding_store all_docs = [] try: if hasattr(store, 'client') and hasattr(store.client, 'get_collection'): collection = store.client.get_collection(collection_name=store.collection_name) total = collection.count() offset = 0 batch_size = 100 while offset < total: result = collection.get(offset=offset, limit=batch_size) for doc_id, payload, vector in zip(result.ids, result.payloads, result.vectors): source = (payload or {}).get("source", "") or (payload or {}).get("file_path", "") content_preview = "" if isinstance(payload, dict) and "content" in payload: c = payload["content"] content_preview = c[:200] if len(c) > 200 else c all_docs.append({ "id": doc_id, "source": source, "content_preview": content_preview, "metadata": payload or {}, }) offset += batch_size # 按来源分组统计 seen_sources = {} for d in all_docs: src = d["source"] or "unknown" if src not in seen_sources: seen_sources[src] = {"source": src, "chunk_count": 0, "first_id": d["id"], "preview": d["content_preview"]} seen_sources[src]["chunk_count"] += 1 return { "code": 200, "data": list(seen_sources.values()), "total_chunks": total, "total_files": len(seen_sources), } except Exception as e: return {"code": 200, "data": [], "error": str(e), "total": 0} except Exception as e: return {"code": 500, "message": f"获取文档列表失败: {e}"} @router.delete("/documents/{source}") async def delete_document( source: str, request: Request, db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """按来源删除知识库中的文档块。 Args: source: 文档来源标识。 request: HTTP 请求对象。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含删除结果的响应数据。 """ try: kb = get_knowledge_base() store = kb._embedding_store if hasattr(store, 'client') and hasattr(store.client, 'get_collection'): collection = store.client.get_collection(collection_name=store.collection_name) result = collection.get(where={"source": source}) if result and result.ids: collection.delete(ids=result.ids) return {"code": 200, "message": f"已删除 {len(result.ids)} 个文档块", "deleted_count": len(result.ids)} else: return {"code": 404, "message": "未找到该来源的文档"} return {"code": 500, "message": "无法连接向量存储"} except Exception as e: return {"code": 500, "message": f"删除文档失败: {e}"} @router.get("/stats") async def knowledge_stats( request: Request, db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): """获取知识库的统计信息,包括文档块数量和来源文件数量。 Args: request: HTTP 请求对象。 db: 异步数据库会话。 current_user: 当前登录用户信息。 Returns: dict: 包含知识库统计信息的响应数据。 """ try: kb = get_knowledge_base() store = kb._embedding_store stats_data = { "status": "initialized", "collection_name": getattr(store, 'collection_name', 'unknown'), "dimensions": getattr(store, 'dimensions', 0), "total_chunks": 0, "total_files": 0, } if hasattr(store, 'client') and hasattr(store.client, 'get_collection'): try: collection = store.client.get_collection(collection_name=store.collection_name) stats_data["total_chunks"] = collection.count() result = collection.get(limit=1000) sources = set() for p in (result.payloads or []): if isinstance(p, dict): s = p.get("source", "") or p.get("file_path", "") if s: sources.add(s) stats_data["total_files"] = len(sources) except Exception: stats_data["status"] = "collection_error" return {"code": 200, "data": stats_data} except Exception as e: return {"code": 500, "message": f"获取统计信息失败: {e}"}