You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
3.7 KiB
121 lines
3.7 KiB
import uuid
|
|
from fastapi import APIRouter, Depends, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import User, ChatSession, ChatMessage
|
|
from agentscope_integration.factory import AgentFactory
|
|
|
|
# Router for the agent endpoints; every route registered on it is served under /api/agent.
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
|
|
|
|
|
@router.post("/chat/{agent_type}")
async def agent_chat(
    agent_type: str,
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
):
    """Send one user message to an agent and return the agent's reply.

    The user message and the agent reply are both stored as ``ChatMessage``
    rows attached to a ``ChatSession``; the session is created on first use.

    Args:
        agent_type: Kind of agent to talk to
            (employee | manager | task | document).
        request: Incoming request; ``request.state.user["id"]`` identifies
            the caller.
        payload: Request body. ``message`` is the user text (defaults to
            ``""``); ``session_id`` is optional and a fresh one is generated
            when it is missing or empty.
        db: Async database session.

    Returns:
        ``{"code": 200, "data": {"session_id": ..., "reply": ..., "role": "assistant"}}``

    Raises:
        HTTPException: 404 when the calling user is not in the database,
            403 when ``session_id`` names a session owned by another user.
    """
    from fastapi import HTTPException
    from agentscope.message import Msg

    user_ctx = request.state.user
    user_id = uuid.UUID(user_ctx["id"])
    msg_content = payload.get("message", "")
    # `or` rather than a .get() default: a client sending "session_id": null
    # (or "") must still receive a generated id instead of a null-keyed session.
    session_id = payload.get("session_id") or f"session_{uuid.uuid4().hex[:12]}"

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(404, "用户不存在")

    session_result = await db.execute(
        select(ChatSession).where(ChatSession.session_id == session_id)
    )
    session = session_result.scalar_one_or_none()
    if session is None:
        session = ChatSession(
            user_id=user.id,
            agent_type=agent_type,
            session_id=session_id,
        )
        db.add(session)
        await db.flush()
    elif session.user_id != user.id:
        # The session id comes from the client; never let one user write
        # into (and later read back from) a session created by another user.
        raise HTTPException(403, "无权访问该会话")

    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        role="user",
        content=msg_content,
    )
    db.add(user_msg)
    await db.flush()

    agent = await AgentFactory.create_agent(
        agent_type=agent_type,
        user_id=str(user.id),
        user_name=user.display_name,
        department_id=str(user.department_id) if user.department_id else None,
    )

    input_msg = Msg(name="user", content=msg_content, role="user")
    response = await agent.reply(input_msg)

    # Fall back to str() when the reply object does not expose
    # get_text_content().
    if hasattr(response, "get_text_content"):
        reply_text = response.get_text_content()
    else:
        reply_text = str(response)

    ai_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        role="assistant",
        content=reply_text,
    )
    # NOTE(review): nothing here commits; presumably the get_db dependency
    # commits when the request finishes -- confirm.
    db.add(ai_msg)

    return {
        "code": 200,
        "data": {
            "session_id": session_id,
            "reply": reply_text,
            "role": "assistant",
        },
    }
|
|
|
|
|
|
@router.get("/list")
async def get_agent_list(request: Request):
    """Return the fixed catalogue of agent types offered by this API.

    Args:
        request: Incoming request; accepted but not read.

    Returns:
        ``{"code": 200, "data": [...]}`` where each entry carries the keys
        ``type``, ``name`` and ``description``.
    """
    # (type, display name, description) for each agent, in display order.
    catalogue = (
        ("employee", "员工AI助手", "日常问答、文档处理、知识查询"),
        ("manager", "管理分析助手", "下属工作分析、效能评估"),
        ("task", "任务管理助手", "任务创建、分派、追踪"),
        ("document", "文档处理助手", "格式修正、内容提取、导入导出"),
    )
    agents = [
        {"type": kind, "name": label, "description": summary}
        for kind, label, summary in catalogue
    ]
    return {"code": 200, "data": agents}
|
|
|
|
|
|
@router.get("/history/{session_id}")
async def get_chat_history(
    session_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Return the messages of one chat session, ordered by creation time.

    Args:
        session_id: Public session identifier (``ChatSession.session_id``).
        request: Incoming request; ``request.state.user["id"]`` identifies
            the caller.
        db: Async database session.

    Returns:
        ``{"code": 200, "data": [...]}`` where each entry carries ``role``,
        ``content`` and ``created_at`` (stringified).

    Raises:
        HTTPException: 404 when the session does not exist or is owned by a
            different user.
    """
    from fastapi import HTTPException

    # NOTE(review): assumes request.state.user is populated for this route
    # the same way it is for the chat endpoint -- confirm.
    caller_id = uuid.UUID(str(request.state.user["id"]))

    session_result = await db.execute(
        select(ChatSession).where(ChatSession.session_id == session_id)
    )
    session = session_result.scalar_one_or_none()
    # A session owned by someone else is reported exactly like a missing one,
    # so callers cannot probe which session ids exist. Comparing the string
    # forms keeps the check valid whether user_id is loaded as a UUID object
    # or as its canonical string.
    if session is None or str(session.user_id) != str(caller_id):
        raise HTTPException(404, "会话不存在")

    msg_result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.created_at)
    )
    messages = msg_result.scalars().all()

    return {
        "code": 200,
        "data": [
            {
                "role": m.role,
                "content": m.content,
                "created_at": str(m.created_at),
            }
            for m in messages
        ],
    }
|