You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
4.9 KiB
134 lines
4.9 KiB
import os
|
|
import asyncio
|
|
import logging
|
|
from agentscope.embedding import OpenAITextEmbedding
|
|
from agentscope.rag import SimpleKnowledge, QdrantStore, TextReader, PDFReader, WordReader, ExcelReader
|
|
from config import settings
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Vector-store settings. _VECTOR_DIM is passed both to the embedding model and
# to the Qdrant store when they are constructed.
_VECTOR_DIM = 1536
_COLLECTION_NAME = "enterprise_knowledge"

# Location handed to QdrantStore: <UPLOAD_DIR>/../data/qdrant.
_STORE_PATH = os.path.join(settings.UPLOAD_DIR, "..", "data", "qdrant")

# Cached knowledge base; stays None until get_knowledge_base() builds it.
_knowledge_base: SimpleKnowledge | None = None
|
|
|
|
|
|
def _get_embedding_model():
    """Build the OpenAI text-embedding client used by the knowledge base.

    The API key comes from ``settings.LLM_API_KEY``; the model name is fixed
    to ``text-embedding-3-small`` and the requested vector size is
    ``_VECTOR_DIM``.
    """
    embedding_options = {
        "api_key": settings.LLM_API_KEY,
        "model_name": "text-embedding-3-small",
        "dimensions": _VECTOR_DIM,
    }
    return OpenAITextEmbedding(**embedding_options)
|
|
|
|
|
|
def get_knowledge_base() -> SimpleKnowledge:
    """Return the module-wide knowledge base, building it on first use.

    On the first call the storage directory ``_STORE_PATH`` is created if
    needed, a ``QdrantStore`` is opened on it and wrapped in a
    ``SimpleKnowledge`` together with the embedding model. The instance is
    cached in ``_knowledge_base`` and returned by every later call.
    """
    global _knowledge_base

    # Fast path: reuse the instance built by an earlier call.
    if _knowledge_base is not None:
        return _knowledge_base

    os.makedirs(_STORE_PATH, exist_ok=True)
    vector_store = QdrantStore(
        location=_STORE_PATH,
        collection_name=_COLLECTION_NAME,
        dimensions=_VECTOR_DIM,
    )
    _knowledge_base = SimpleKnowledge(
        embedding_store=vector_store,
        embedding_model=_get_embedding_model(),
    )
    logger.info(f"知识库已初始化: {_STORE_PATH}")
    return _knowledge_base
|
|
|
|
|
|
# Lower-cased file extension -> reader kind used when file_type is "auto".
_EXTENSION_KINDS = {
    ".pdf": "pdf",
    ".docx": "word",
    ".doc": "word",
    ".xlsx": "excel",
    ".xls": "excel",
}


def _resolve_reader_kind(file_path: str, file_type: str) -> str:
    """Return the explicit file type, or one inferred from the file extension."""
    if file_type != "auto":
        return file_type
    extension = os.path.splitext(file_path)[1].lower()
    return _EXTENSION_KINDS.get(extension, "text")


def _read_text_file(file_path: str) -> str:
    """Read a whole file as UTF-8 text."""
    with open(file_path, "r", encoding="utf-8") as handle:
        return handle.read()


async def _load_documents(file_path: str, kind: str):
    """Run the reader matching ``kind`` on ``file_path`` and return its chunks."""
    if kind == "pdf":
        reader = PDFReader(chunk_size=1024, split_by="sentence")
        return await reader(pdf_path=file_path)
    if kind == "word":
        reader = WordReader(chunk_size=1024)
        return await reader(file_path=file_path)
    if kind == "excel":
        reader = ExcelReader(chunk_size=1024)
        return await reader(file_path=file_path)

    # Any other kind is treated as plain text. Sentence splitting is used for
    # both the auto-detected and the explicit case so that a file is chunked
    # the same way regardless of how its type was specified.
    reader = TextReader(chunk_size=1024, split_by="sentence")
    # Read off the event loop so a large file does not stall other coroutines.
    content = await asyncio.to_thread(_read_text_file, file_path)
    return await reader(text=content)


def _source_path(document, fallback: str):
    """Return ``document.metadata.file_path``, or ``fallback`` if it is absent."""
    try:
        return document.metadata.file_path
    except (AttributeError, KeyError):
        # NOTE(review): it is not visible here whether every reader records
        # file_path in the metadata -- confirm. A missing field must not turn
        # an already successful indexing run into a reported failure.
        return fallback


async def add_document(file_path: str, file_type: str = "auto") -> str:
    """Read a file, split it into chunks and index them in the knowledge base.

    Args:
        file_path: Path of the file to index.
        file_type: ``"pdf"``, ``"word"`` or ``"excel"`` to force a reader;
            any other explicit value is read as UTF-8 plain text. The default
            ``"auto"`` picks the reader from the file extension
            (.pdf, .docx/.doc, .xlsx/.xls, otherwise plain text).

    Returns:
        A human-readable status message. Failures are logged with their
        traceback and reported in the returned message; no exception is
        raised to the caller.
    """
    try:
        kb = get_knowledge_base()
        kind = _resolve_reader_kind(file_path, file_type)
        documents = await _load_documents(file_path, kind)
        await kb.add_documents(documents)
        filenames = {_source_path(doc, file_path) for doc in documents}
        return f"成功索引 {len(documents)} 个文档块 (来自 {len(filenames)} 个文件)"
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, return a message.
        logger.exception(f"文档索引失败: {e}")
        return f"文档索引失败: {e}"
|
|
|
|
|
|
async def add_text(text: str, source: str = "manual") -> str:
    """Split a raw text string into chunks and index them.

    Args:
        text: The text to index.
        source: Label written to ``metadata.source`` of every chunk.

    Returns:
        A status message with the number of indexed chunks, or a failure
        message containing the error text. Exceptions are logged, not raised.
    """
    try:
        knowledge = get_knowledge_base()
        splitter = TextReader(chunk_size=1024, split_by="sentence")
        chunks = await splitter(text=text)

        # Tag each chunk with its origin before it is stored.
        for chunk in chunks:
            chunk.metadata.source = source

        await knowledge.add_documents(chunks)
        return f"成功索引 {len(chunks)} 个文档块"
    except Exception as exc:
        logger.error(f"文本索引失败: {exc}")
        return f"文本索引失败: {exc}"
|
|
|
|
|
|
def _metadata_field(metadata, name: str):
    """Read ``metadata.<name>``, returning None when the field is absent."""
    try:
        return getattr(metadata, name)
    except (AttributeError, KeyError):
        # NOTE(review): whether stored chunks always carry `source` or
        # `file_path` is not visible here -- confirm. One chunk without the
        # field should not discard the whole result list.
        return None


async def search(
    query: str,
    limit: int = 5,
    score_threshold: float = 0.3,
    timeout: float = 10.0,
) -> list[dict]:
    """Retrieve the chunks most relevant to ``query`` from the knowledge base.

    Args:
        query: Search text.
        limit: Maximum number of hits requested from the knowledge base.
        score_threshold: Minimum score passed through to ``retrieve``.
        timeout: Seconds to wait for the retrieval before giving up
            (``None`` waits indefinitely). Defaults to the previously
            hard-coded 10 seconds.

    Returns:
        A list of dicts with the keys ``id``, ``content`` (first 500
        characters of the chunk text), ``score`` (rounded to 4 decimals, 0
        when falsy) and ``source``. An empty list is returned on timeout or
        on any error; errors are logged, not raised.
    """
    try:
        kb = get_knowledge_base()
        if not kb or not hasattr(kb, "retrieve"):
            logger.warning("知识库未初始化或不可用")
            return []

        docs = await asyncio.wait_for(
            kb.retrieve(query=query, limit=limit, score_threshold=score_threshold),
            timeout=timeout,
        )

        hits = []
        for doc in docs:
            metadata = doc.metadata
            # Prefer the explicit source label, then the originating path.
            source = (
                _metadata_field(metadata, "source")
                or _metadata_field(metadata, "file_path")
                or ""
            )
            hits.append(
                {
                    "id": doc.id,
                    "content": metadata.content.get("text", "")[:500],
                    "score": round(doc.score, 4) if doc.score else 0,
                    "source": source,
                }
            )
        return hits
    except asyncio.TimeoutError:
        logger.warning(f"知识检索超时 (query={query[:50]})")
        return []
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, return no hits.
        logger.exception(f"知识检索失败: {e}")
        return []
|
|
|
|
|
|
async def retrieve_for_agent(query: str, limit: int = 5) -> str:
    """Run ``search`` and render the hits as one text block.

    Args:
        query: Search text forwarded to ``search``.
        limit: Maximum number of hits forwarded to ``search``.

    Returns:
        A header line followed by one numbered entry per hit (score and
        content), or a fixed "nothing found" message when there are no hits.
    """
    hits = await search(query, limit=limit)
    if not hits:
        return "未找到相关文档。"

    header = "根据知识库检索到以下相关内容:"
    # Each entry starts with its own newline, so joining with "\n" leaves a
    # blank line between entries.
    entries = [
        f"\n[{rank}] (相关度: {hit['score']})\n{hit['content']}"
        for rank, hit in enumerate(hits, 1)
    ]
    return "\n".join([header, *entries])
|