You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.1 KiB
73 lines
2.1 KiB
from fastapi import APIRouter, Depends, UploadFile, File, Request
|
|
from database import get_db
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from dependencies import get_current_user
|
|
import os
|
|
import uuid
|
|
from config import settings
|
|
|
|
from .knowledge import add_document, add_text, search, retrieve_for_agent
|
|
|
|
# Router holding the RAG endpoints defined below; each route is mounted under
# the "/api/rag" prefix and listed under the "rag" tag in the OpenAPI schema.
router = APIRouter(tags=["rag"], prefix="/api/rag")
|
|
|
|
|
|
@router.post("/upload")
async def rag_upload(
    request: Request,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Persist an uploaded file under ``settings.UPLOAD_DIR`` and index it.

    The file is stored as ``<uuid4 hex>_<base name>`` so repeated uploads with
    the same client-side name never collide. Only the final path component of
    the client-supplied name is kept (after treating ``\\`` as a separator), so
    names like ``../../x`` or ``C:\\docs\\x.pdf`` neither escape the upload
    directory nor point into a sub-directory that does not exist.

    Args:
        request: Incoming request (not used by the body).
        file: Multipart upload to store and index.
        db: Database session dependency (not used by the body).
        current_user: Value from ``get_current_user`` (not used by the body).

    Returns:
        dict with ``code`` 200, ``message`` (whatever ``add_document`` returns),
        ``file_id`` (the stored file name) and ``file_name`` (the name exactly
        as sent by the client).
    """
    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)

    # file.filename is client-controlled and may be None. Normalise Windows
    # separators, then keep only the last path component so the join below
    # always yields a plain file directly inside UPLOAD_DIR.
    client_name = (file.filename or "").replace("\\", "/")
    base_name = os.path.basename(client_name) or "upload"
    filename = f"{uuid.uuid4().hex}_{base_name}"
    file_path = os.path.join(settings.UPLOAD_DIR, filename)

    content = await file.read()
    # NOTE(review): synchronous file write inside an async handler; very large
    # uploads block the event loop while this runs. Consider a thread pool.
    with open(file_path, "wb") as f:
        f.write(content)

    result = await add_document(file_path)
    return {"code": 200, "message": result, "file_id": filename, "file_name": file.filename}
|
|
|
|
|
|
@router.post("/index-text")
async def rag_index_text(
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Index a raw text snippet via ``add_text``.

    The JSON body is read as a dict. Its ``text`` key is required and must be
    a non-blank string. Its ``source`` key is an optional label that defaults
    to ``"manual"``.

    Args:
        request: Incoming request (not used by the body).
        payload: Parsed JSON body.
        db: Database session dependency (not used by the body).
        current_user: Value from ``get_current_user`` (not used by the body).

    Returns:
        ``{"code": 400, "message": ...}`` when ``text`` is missing, not a
        string, or whitespace-only; otherwise ``{"code": 200, "message": ...}``
        carrying the return value of ``add_text``.
    """
    text = payload.get("text", "")
    source = payload.get("source", "manual")
    # Whitespace-only or non-string text has nothing meaningful to index, so
    # it is rejected like empty text. The original string (not the stripped
    # one) is forwarded so stored content is left exactly as submitted.
    # The message below reads "text content cannot be empty".
    if not isinstance(text, str) or not text.strip():
        return {"code": 400, "message": "文本内容不能为空"}
    result = await add_text(text, source)
    return {"code": 200, "message": result}
|
|
|
|
|
|
@router.get("/search")
async def rag_search(
    request: Request,
    q: str = "",
    limit: int = 5,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Search the knowledge base for the query string ``q``.

    ``limit`` is passed straight through to ``search``. On success the
    response echoes the query back as ``{"code": 200, "data": ..., "query": q}``.
    When ``q`` is empty it returns ``{"code": 400, "message": ...}`` without
    calling ``search``.
    """
    if q:
        hits = await search(q, limit=limit)
        return {"code": 200, "data": hits, "query": q}
    # Message reads "query content cannot be empty".
    return {"code": 400, "message": "查询内容不能为空"}
|
|
|
|
|
|
@router.get("/retrieve")
async def rag_retrieve(
    request: Request,
    q: str = "",
    limit: int = 5,
    db: AsyncSession = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return the output of ``retrieve_for_agent`` for the query ``q``.

    ``limit`` is forwarded unchanged. An empty ``q`` short-circuits with
    ``{"code": 400, "message": ...}``. Otherwise the helper's result is
    wrapped as ``{"code": 200, "data": ...}``.
    """
    query_missing = not q
    if query_missing:
        # Message reads "query content cannot be empty".
        return {"code": 400, "message": "查询内容不能为空"}

    context = await retrieve_for_agent(q, limit=limit)
    return {"code": 200, "data": context}
|