You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

277 lines
9.5 KiB

"""知识库(RAG)模块路由。
提供知识库的文档上传、文本索引、语义检索和文档管理功能。
支持文件上传后自动解析和向量化,以及基于相似度的知识检索。
"""
from fastapi import APIRouter, Depends, UploadFile, File, Request
from database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
from dependencies import get_current_user
import os
import uuid
from config import settings
from .knowledge import add_document, add_text, search, retrieve_for_agent, get_knowledge_base
router = APIRouter(prefix="/api/rag", tags=["rag"])
@router.post("/upload")
async def rag_upload(
request: Request,
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""上传文档文件并自动索引到知识库。
将上传的文件保存到服务器后,调用文档解析器将其切分为文本块
并进行向量化存储到知识库中。
Args:
request: HTTP 请求对象。
file: 上传的文件对象。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含索引结果和文件信息的响应数据。
"""
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
filename = f"{uuid.uuid4().hex}_{file.filename}" # 生成唯一文件名,避免冲突
file_path = os.path.join(settings.UPLOAD_DIR, filename)
content = await file.read()
with open(file_path, "wb") as f:
f.write(content)
result = await add_document(file_path)
return {"code": 200, "message": result, "file_id": filename, "file_name": file.filename}
@router.post("/index-text")
async def rag_index_text(
request: Request,
payload: dict,
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""将纯文本内容索引到知识库。
Args:
request: HTTP 请求对象。
payload: 请求体,包含 text 和可选的 source 字段。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含索引结果的响应数据。
"""
text = payload.get("text", "") # 待索引的文本内容
source = payload.get("source", "manual") # 文本来源标识
if not text:
return {"code": 400, "message": "文本内容不能为空"}
result = await add_text(text, source)
return {"code": 200, "message": result}
@router.get("/search")
async def rag_search(
request: Request,
q: str = "",
limit: int = 5,
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""在知识库中执行语义检索。
Args:
request: HTTP 请求对象。
q: 查询文本。
limit: 返回结果的最大数量。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含检索结果列表的响应数据。
"""
if not q:
return {"code": 400, "message": "查询内容不能为空"}
results = await search(q, limit=limit)
return {"code": 200, "data": results, "query": q}
@router.get("/retrieve")
async def rag_retrieve(
request: Request,
q: str = "",
limit: int = 5,
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""为 AI 智能体执行知识库检索,返回格式化的结果文本。
Args:
request: HTTP 请求对象。
q: 查询文本。
limit: 返回结果的最大数量。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含格式化检索结果的响应数据。
"""
if not q:
return {"code": 400, "message": "查询内容不能为空"}
result = await retrieve_for_agent(q, limit=limit)
return {"code": 200, "data": result}
@router.get("/documents")
async def list_documents(
request: Request,
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""列出知识库中已索引的所有文档来源及其统计信息。
从向量存储中获取所有文档,按来源分组并统计每个来源的文档块数量。
Args:
request: HTTP 请求对象。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含文档列表和统计信息的响应数据。
"""
try:
kb = get_knowledge_base()
if not kb or not hasattr(kb, '_embedding_store'):
return {"code": 200, "data": [], "total": 0}
store = kb._embedding_store
all_docs = []
try:
if hasattr(store, 'client') and hasattr(store.client, 'get_collection'):
collection = store.client.get_collection(collection_name=store.collection_name)
total = collection.count()
offset = 0
batch_size = 100
while offset < total:
result = collection.get(offset=offset, limit=batch_size)
for doc_id, payload, vector in zip(result.ids, result.payloads, result.vectors):
source = (payload or {}).get("source", "") or (payload or {}).get("file_path", "")
content_preview = ""
if isinstance(payload, dict) and "content" in payload:
c = payload["content"]
content_preview = c[:200] if len(c) > 200 else c
all_docs.append({
"id": doc_id,
"source": source,
"content_preview": content_preview,
"metadata": payload or {},
})
offset += batch_size
# 按来源分组统计
seen_sources = {}
for d in all_docs:
src = d["source"] or "unknown"
if src not in seen_sources:
seen_sources[src] = {"source": src, "chunk_count": 0, "first_id": d["id"], "preview": d["content_preview"]}
seen_sources[src]["chunk_count"] += 1
return {
"code": 200,
"data": list(seen_sources.values()),
"total_chunks": total,
"total_files": len(seen_sources),
}
except Exception as e:
return {"code": 200, "data": [], "error": str(e), "total": 0}
except Exception as e:
return {"code": 500, "message": f"获取文档列表失败: {e}"}
@router.delete("/documents/{source}")
async def delete_document(
source: str,
request: Request,
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""按来源删除知识库中的文档块。
Args:
source: 文档来源标识。
request: HTTP 请求对象。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含删除结果的响应数据。
"""
try:
kb = get_knowledge_base()
store = kb._embedding_store
if hasattr(store, 'client') and hasattr(store.client, 'get_collection'):
collection = store.client.get_collection(collection_name=store.collection_name)
result = collection.get(where={"source": source})
if result and result.ids:
collection.delete(ids=result.ids)
return {"code": 200, "message": f"已删除 {len(result.ids)} 个文档块", "deleted_count": len(result.ids)}
else:
return {"code": 404, "message": "未找到该来源的文档"}
return {"code": 500, "message": "无法连接向量存储"}
except Exception as e:
return {"code": 500, "message": f"删除文档失败: {e}"}
@router.get("/stats")
async def knowledge_stats(
request: Request,
db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取知识库的统计信息,包括文档块数量和来源文件数量。
Args:
request: HTTP 请求对象。
db: 异步数据库会话。
current_user: 当前登录用户信息。
Returns:
dict: 包含知识库统计信息的响应数据。
"""
try:
kb = get_knowledge_base()
store = kb._embedding_store
stats_data = {
"status": "initialized",
"collection_name": getattr(store, 'collection_name', 'unknown'),
"dimensions": getattr(store, 'dimensions', 0),
"total_chunks": 0,
"total_files": 0,
}
if hasattr(store, 'client') and hasattr(store.client, 'get_collection'):
try:
collection = store.client.get_collection(collection_name=store.collection_name)
stats_data["total_chunks"] = collection.count()
result = collection.get(limit=1000)
sources = set()
for p in (result.payloads or []):
if isinstance(p, dict):
s = p.get("source", "") or p.get("file_path", "")
if s:
sources.add(s)
stats_data["total_files"] = len(sources)
except Exception:
stats_data["status"] = "collection_error"
return {"code": 200, "data": stats_data}
except Exception as e:
return {"code": 500, "message": f"获取统计信息失败: {e}"}