You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
6.7 KiB
189 lines
6.7 KiB
from fastapi import APIRouter, Depends, UploadFile, File, Request
|
|
from database import get_db
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from dependencies import get_current_user
|
|
import os
|
|
import uuid
|
|
from config import settings
|
|
|
|
from .knowledge import add_document, add_text, search, retrieve_for_agent, get_knowledge_base
|
|
|
|
# Every knowledge-base (RAG) endpoint in this module is mounted under /api/rag.
router = APIRouter(tags=["rag"], prefix="/api/rag")
|
|
|
|
|
|
@router.post("/upload")
async def rag_upload(
    request: Request,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Save an uploaded file under ``settings.UPLOAD_DIR`` and index it.

    The file is stored as ``<random hex>_<client file name>`` (only the final
    component of the client-supplied name is kept) and the stored path is
    passed to ``add_document``.

    Returns:
        ``{"code": 200, "message": <add_document result>,
        "file_id": <stored file name>, "file_name": <client file name>}``
    """
    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)

    # ``file.filename`` is client-controlled: it may be None or contain
    # directory components ("a/b.pdf", "..\\x.pdf"). Joined as-is, those
    # yield a nested path (open() then fails on the missing sub-directory)
    # and ".." segments could point outside UPLOAD_DIR. Normalise Windows
    # separators so basename() strips them on POSIX servers too.
    original_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(original_name)
    filename = f"{uuid.uuid4().hex}_{safe_name}"
    file_path = os.path.join(settings.UPLOAD_DIR, filename)

    content = await file.read()
    with open(file_path, "wb") as f:
        f.write(content)

    result = await add_document(file_path)
    return {"code": 200, "message": result, "file_id": filename, "file_name": file.filename}
|
|
|
|
|
|
@router.post("/index-text")
async def rag_index_text(
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Index a raw text snippet into the knowledge base.

    Expects a JSON object with ``text`` (required) and an optional
    ``source`` label (``"manual"`` when missing, null, or empty).

    Returns:
        ``{"code": 400, ...}`` when ``text`` is missing, not a string, or
        whitespace-only; otherwise ``{"code": 200, "message": <add_text result>}``.
    """
    text = payload.get("text", "")
    # ``or`` also covers an explicit JSON null / "" so add_text always
    # receives a non-empty source label.
    source = payload.get("source") or "manual"
    # Non-strings and whitespace-only strings carry nothing to embed; treat
    # them the same as an empty string.
    if not isinstance(text, str) or not text.strip():
        return {"code": 400, "message": "文本内容不能为空"}
    result = await add_text(text, source)
    return {"code": 200, "message": result}
|
|
|
|
|
|
@router.get("/search")
async def rag_search(
    request: Request,
    q: str = "",
    limit: int = 5,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Search the knowledge base for ``q``, returning up to ``limit`` hits.

    Returns:
        ``{"code": 400, ...}`` for a blank query or a non-positive ``limit``;
        otherwise ``{"code": 200, "data": <search results>, "query": q}``.
    """
    # A whitespace-only query is as empty as "" for search purposes.
    if not q.strip():
        return {"code": 400, "message": "查询内容不能为空"}
    # Reject zero/negative limits here instead of handing them to the
    # search backend.
    if limit < 1:
        return {"code": 400, "message": "limit 必须大于 0"}
    results = await search(q, limit=limit)
    return {"code": 200, "data": results, "query": q}
|
|
|
|
|
|
@router.get("/retrieve")
async def rag_retrieve(
    request: Request,
    q: str = "",
    limit: int = 5,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return ``retrieve_for_agent``'s result for ``q`` (at most ``limit`` items).

    Returns:
        ``{"code": 400, ...}`` for a blank query or a non-positive ``limit``;
        otherwise ``{"code": 200, "data": <retrieve_for_agent result>}``.
    """
    # A whitespace-only query is as empty as "" for retrieval purposes.
    if not q.strip():
        return {"code": 400, "message": "查询内容不能为空"}
    # Reject zero/negative limits here instead of handing them to the
    # retrieval backend.
    if limit < 1:
        return {"code": 400, "message": "limit 必须大于 0"}
    result = await retrieve_for_agent(q, limit=limit)
    return {"code": 200, "data": result}
|
|
|
|
|
|
def _group_chunks_by_source(collection, total, batch_size=100):
    """Page through ``collection`` and aggregate its chunks per source label.

    Chunks are fetched with ``collection.get(offset=..., limit=batch_size)``
    until ``total`` chunks have been requested. Each chunk is keyed by its
    payload ``source``, falling back to ``file_path``, then ``"unknown"``.
    The first chunk seen for a source supplies ``first_id`` and ``preview``
    (the first 200 characters of its payload ``content``, if a string).

    Returns:
        dict mapping source label -> ``{"source", "chunk_count",
        "first_id", "preview"}``, in first-seen order.
    """
    grouped = {}
    for offset in range(0, total, batch_size):
        result = collection.get(offset=offset, limit=batch_size)
        # Only ids and payloads are read; vectors are never used, so they
        # are not zipped in (a result without vectors used to break this).
        for doc_id, payload in zip(result.ids or [], result.payloads or []):
            # Non-dict payloads are treated as "no metadata" instead of
            # failing on .get().
            meta = payload if isinstance(payload, dict) else {}
            src = meta.get("source", "") or meta.get("file_path", "") or "unknown"
            entry = grouped.get(src)
            if entry is None:
                content = meta.get("content", "")
                entry = grouped[src] = {
                    "source": src,
                    "chunk_count": 0,
                    "first_id": doc_id,
                    "preview": content[:200] if isinstance(content, str) else "",
                }
            entry["chunk_count"] += 1
    return grouped


@router.get("/documents")
async def list_documents(
    request: Request,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """List indexed documents, one entry per source.

    Responses (``code`` is a payload field, not the HTTP status):
      * no knowledge base / no embedding store:
        ``{"code": 200, "data": [], "total": 0}``
      * success: ``{"code": 200, "data": [...], "total_chunks": N,
        "total_files": M}``
      * error while reading the collection:
        ``{"code": 200, "data": [], "error": <message>, "total": 0}``
      * any other error: ``{"code": 500, "message": ...}``
    """
    try:
        kb = get_knowledge_base()
        if not kb or not hasattr(kb, "_embedding_store"):
            return {"code": 200, "data": [], "total": 0}

        store = kb._embedding_store
        try:
            # Initialised up front so a store without a collection client
            # yields an empty listing instead of an unbound ``total``.
            total = 0
            grouped = {}
            if hasattr(store, "client") and hasattr(store.client, "get_collection"):
                collection = store.client.get_collection(collection_name=store.collection_name)
                total = collection.count()
                grouped = _group_chunks_by_source(collection, total)
            return {
                "code": 200,
                "data": list(grouped.values()),
                "total_chunks": total,
                "total_files": len(grouped),
            }
        except Exception as e:
            return {"code": 200, "data": [], "error": str(e), "total": 0}
    except Exception as e:
        return {"code": 500, "message": f"获取文档列表失败: {e}"}
|
|
|
|
|
|
# ``:path`` lets ``source`` contain "/" (e.g. a stored file path); with a
# plain ``{source}`` such sources can never reach this handler.
@router.delete("/documents/{source:path}")
async def delete_document(
    source: str,
    request: Request,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Delete every chunk whose payload ``source`` equals ``source``.

    Returns:
        ``{"code": 200, "message": ..., "deleted_count": n}`` on success,
        ``{"code": 404, ...}`` when nothing matches (or ``source`` is empty),
        ``{"code": 500, ...}`` when the vector store is unavailable or an
        error occurs.
    """
    # ``:path`` also matches an empty segment ("/documents/"); never turn
    # that into a delete of all chunks whose source is "".
    if not source:
        return {"code": 404, "message": "未找到该来源的文档"}
    try:
        kb = get_knowledge_base()
        # A missing knowledge base / embedding store is reported as an
        # unavailable store rather than as an AttributeError message.
        store = getattr(kb, "_embedding_store", None) if kb else None
        if store is None or not (hasattr(store, "client") and hasattr(store.client, "get_collection")):
            return {"code": 500, "message": "无法连接向量存储"}

        collection = store.client.get_collection(collection_name=store.collection_name)
        # NOTE(review): if collection.get() applies a default page size, only
        # the first page of matching chunks is deleted per call — confirm
        # against the vector-store client.
        result = collection.get(where={"source": source})
        if not (result and result.ids):
            return {"code": 404, "message": "未找到该来源的文档"}

        collection.delete(ids=result.ids)
        deleted = len(result.ids)
        return {"code": 200, "message": f"已删除 {deleted} 个文档块", "deleted_count": deleted}
    except Exception as e:
        return {"code": 500, "message": f"删除文档失败: {e}"}
|
|
|
|
|
|
@router.get("/stats")
async def knowledge_stats(
    request: Request,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Summarise the knowledge base's vector collection.

    ``data`` contains:
      * ``status``: "initialized", "not_initialized" (no knowledge base /
        embedding store) or "collection_error" (reading the collection failed)
      * ``collection_name`` / ``dimensions`` as reported by the store
      * ``total_chunks``: ``collection.count()``
      * ``total_files``: number of distinct non-empty payload sources
        (``source``, falling back to ``file_path``)

    Any unexpected error yields ``{"code": 500, "message": ...}``.
    """
    try:
        kb = get_knowledge_base()
        if not kb or not hasattr(kb, "_embedding_store"):
            # Report an uninitialised knowledge base as empty stats, the way
            # the document listing treats it, instead of an AttributeError.
            return {
                "code": 200,
                "data": {
                    "status": "not_initialized",
                    "collection_name": "unknown",
                    "dimensions": 0,
                    "total_chunks": 0,
                    "total_files": 0,
                },
            }

        store = kb._embedding_store
        stats_data = {
            "status": "initialized",
            "collection_name": getattr(store, "collection_name", "unknown"),
            "dimensions": getattr(store, "dimensions", 0),
            "total_chunks": 0,
            "total_files": 0,
        }

        if hasattr(store, "client") and hasattr(store.client, "get_collection"):
            try:
                collection = store.client.get_collection(collection_name=store.collection_name)
                total = collection.count()
                stats_data["total_chunks"] = total

                # Page through the whole collection: a single get(limit=1000)
                # undercounts files once there are more than 1000 chunks.
                batch_size = 1000
                sources = set()
                for offset in range(0, total, batch_size):
                    result = collection.get(offset=offset, limit=batch_size)
                    for p in result.payloads or []:
                        if isinstance(p, dict):
                            s = p.get("source", "") or p.get("file_path", "")
                            if s:
                                sources.add(s)
                stats_data["total_files"] = len(sources)
            except Exception:
                stats_data["status"] = "collection_error"

        return {"code": 200, "data": stats_data}
    except Exception as e:
        return {"code": 500, "message": f"获取统计信息失败: {e}"}
|