You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
6.9 KiB
205 lines
6.9 KiB
import uuid
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import User, ChatSession, ChatMessage, AgentConfig
|
|
from schemas import AgentConfigCreate, AgentConfigUpdate, AgentConfigOut
|
|
from agentscope_integration.factory import AgentFactory
|
|
|
|
# Router for the agent endpoints in this module; every route below is mounted
# under the /api/agent prefix and grouped under the "agent" tag.
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
|
|
|
|
|
@router.post("/chat/{agent_type}")
async def agent_chat(
    agent_type: str,
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
):
    """Send one user message to an agent and persist both sides of the exchange.

    Reads ``message`` and the optional ``session_id`` from the request body.
    When no session id is supplied, a new ``session_<12 hex chars>`` id is
    generated and a ChatSession row is created for it. The user's message is
    stored, an agent is built via ``AgentFactory.create_agent`` and asked to
    reply, and the reply is stored as an ``assistant`` message.

    Args:
        agent_type: Agent kind from the URL path; passed to the factory and
            recorded on newly created sessions.
        request: Incoming request; ``request.state.user["id"]`` must hold the
            caller's user id as a UUID string (presumably populated by auth
            middleware outside this file).
        payload: Request body dict (keys read: ``message``, ``session_id``).
        db: Async database session provided by ``get_db``.

    Returns:
        ``{"code": 200, "data": {"session_id": ..., "reply": ..., "role": "assistant"}}``

    Raises:
        HTTPException: 400 if ``message`` is not a non-blank string; 404 if the
            calling user does not exist; 403 if ``session_id`` refers to a
            session owned by a different user.
    """
    user_id = uuid.UUID(request.state.user["id"])

    msg_content = payload.get("message", "")
    if not isinstance(msg_content, str) or not msg_content.strip():
        # Do not persist (or pay an LLM call for) an empty / non-text message.
        raise HTTPException(400, "消息内容不能为空")

    # `or` instead of a .get() default: an explicit JSON null (or "") must also
    # get a fresh id rather than looking up / creating a session with a NULL id.
    session_id = payload.get("session_id") or f"session_{uuid.uuid4().hex[:12]}"

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(404, "用户不存在")

    session_result = await db.execute(
        select(ChatSession).where(ChatSession.session_id == session_id)
    )
    session = session_result.scalar_one_or_none()
    if not session:
        session = ChatSession(
            user_id=user.id,
            agent_type=agent_type,
            session_id=session_id,
        )
        db.add(session)
        # Flush so session.id is populated before it is used as the message FK.
        await db.flush()
    elif session.user_id != user.id:
        # session_id is client-controlled; without this check a caller could
        # append messages to (and receive replies in) another user's session.
        raise HTTPException(403, "无权访问该会话")

    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        role="user",
        content=msg_content,
    )
    db.add(user_msg)
    await db.flush()

    agent = await AgentFactory.create_agent(
        agent_type=agent_type,
        user_id=str(user.id),
        user_name=user.display_name,
        department_id=str(user.department_id) if user.department_id else None,
    )

    # Imported lazily, as in the original, so agentscope is only loaded on use.
    from agentscope.message import Msg

    input_msg = Msg(name="user", content=msg_content, role="user")
    response = await agent.reply(input_msg)

    if hasattr(response, "get_text_content"):
        # agentscope's Msg.get_text_content() may return None when the reply has
        # no text blocks; store an empty string instead of NULL content.
        reply_text = response.get_text_content() or ""
    else:
        reply_text = str(response)

    ai_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        role="assistant",
        content=reply_text,
    )
    db.add(ai_msg)

    return {
        "code": 200,
        "data": {
            "session_id": session_id,
            "reply": reply_text,
            "role": "assistant",
        },
    }
|
|
|
|
|
|
@router.get("/list")
async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)):
    """List all agent configurations whose status is ``"active"``.

    Rows are ordered by ``updated_at`` descending (most recently updated first).
    The stored integer temperature is divided by 10 to produce the float value
    returned to the client; a missing ``tools`` value is reported as ``[]``.

    Returns:
        ``{"code": 200, "data": [ {id, name, description, system_prompt, model,
        temperature, tools, status}, ... ]}``
    """
    query = (
        select(AgentConfig)
        .where(AgentConfig.status == "active")
        .order_by(AgentConfig.updated_at.desc())
    )
    query_result = await db.execute(query)
    configs = query_result.scalars().all()

    items = []
    for cfg in configs:
        items.append(
            {
                "id": str(cfg.id),
                "name": cfg.name,
                "description": cfg.description,
                "system_prompt": cfg.system_prompt,
                "model": cfg.model,
                "temperature": float(cfg.temperature) / 10.0,
                "tools": cfg.tools or [],
                "status": cfg.status,
            }
        )

    return {"code": 200, "data": items}
|
|
|
|
|
|
@router.post("/", response_model=AgentConfigOut)
async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSession = Depends(get_db)):
    """Create a new agent configuration owned by the calling user.

    The temperature is persisted as an integer number of tenths (value * 10)
    and converted back to a float for the response.

    Args:
        req: Validated creation payload.
        request: Incoming request; ``request.state.user["id"]`` is the
            creator's user id as a UUID string.
        db: Async database session provided by ``get_db``.

    Returns:
        The created configuration as ``AgentConfigOut``.
    """
    user_ctx = request.state.user
    agent = AgentConfig(
        name=req.name,
        description=req.description,
        system_prompt=req.system_prompt,
        model=req.model,
        # Round to the nearest tenth rather than truncating with int(), so e.g.
        # 0.29 is stored as 3 (0.3) instead of 2 (0.2).
        temperature=round(req.temperature * 10),
        tools=req.tools,
        creator_id=uuid.UUID(user_ctx["id"]),
    )
    db.add(agent)
    await db.flush()
    # Reload DB-generated values (id/created_at/updated_at/status, if any of them
    # come from server-side defaults — the model is not visible here). Reading an
    # expired attribute on an AsyncSession would otherwise attempt implicit I/O
    # and fail.
    await db.refresh(agent)

    return AgentConfigOut(
        id=agent.id,
        name=agent.name,
        description=agent.description,
        system_prompt=agent.system_prompt,
        model=agent.model,
        temperature=float(agent.temperature) / 10.0,
        tools=agent.tools or [],
        status=agent.status,
        creator_id=agent.creator_id,
        created_at=agent.created_at,
        updated_at=agent.updated_at,
    )
|
|
|
|
|
|
@router.get("/{agent_id}", response_model=AgentConfigOut)
async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Fetch a single agent configuration by its id.

    The status is not filtered, so non-active configurations are returned too.
    The stored integer temperature is divided by 10 for the response.

    Raises:
        HTTPException: 404 when no configuration has the given id.
    """
    lookup = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
    found = lookup.scalar_one_or_none()
    if not found:
        raise HTTPException(404, "Agent不存在")

    out_temperature = float(found.temperature) / 10.0
    out_tools = found.tools or []
    return AgentConfigOut(
        id=found.id,
        name=found.name,
        description=found.description,
        system_prompt=found.system_prompt,
        model=found.model,
        temperature=out_temperature,
        tools=out_tools,
        status=found.status,
        creator_id=found.creator_id,
        created_at=found.created_at,
        updated_at=found.updated_at,
    )
|
|
|
|
|
|
@router.put("/{agent_id}", response_model=AgentConfigOut)
async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Request, db: AsyncSession = Depends(get_db)):
    """Partially update an agent configuration.

    Only fields that are not ``None`` in ``req`` are applied. The temperature
    is stored as an integer number of tenths (value * 10).

    NOTE(review): no ownership/permission check is performed here; confirm that
    restricting edits (e.g. to ``creator_id``) is handled elsewhere.

    Raises:
        HTTPException: 404 when no configuration has the given id.
    """
    result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
    agent = result.scalar_one_or_none()
    if not agent:
        raise HTTPException(404, "Agent不存在")

    # Fields copied verbatim when provided.
    for field in ("name", "description", "system_prompt", "model", "tools", "status"):
        value = getattr(req, field)
        if value is not None:
            setattr(agent, field, value)
    if req.temperature is not None:
        # Round to the nearest tenth rather than truncating with int().
        agent.temperature = round(req.temperature * 10)

    await db.flush()
    # Reload columns the database may have changed on UPDATE (e.g. updated_at if
    # it uses onupdate/server defaults — not visible here); reading them while
    # expired on an AsyncSession would attempt implicit I/O and fail.
    await db.refresh(agent)

    return AgentConfigOut(
        id=agent.id,
        name=agent.name,
        description=agent.description,
        system_prompt=agent.system_prompt,
        model=agent.model,
        temperature=float(agent.temperature) / 10.0,
        tools=agent.tools or [],
        status=agent.status,
        creator_id=agent.creator_id,
        created_at=agent.created_at,
        updated_at=agent.updated_at,
    )
|
|
|
|
|
|
@router.delete("/{agent_id}")
async def delete_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Delete an agent configuration by its id.

    The row is marked for deletion on the session; no flush or commit happens
    in this handler (presumably ``get_db`` commits — not visible here).

    NOTE(review): no ownership/permission check is performed; confirm that is
    intended.

    Returns:
        ``{"code": 200, "message": "已删除"}``

    Raises:
        HTTPException: 404 when no configuration has the given id.
    """
    stmt = select(AgentConfig).where(AgentConfig.id == agent_id)
    target = (await db.execute(stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(404, "Agent不存在")

    await db.delete(target)
    response_body = {"code": 200, "message": "已删除"}
    return response_body
|
|
|
|
|
|
@router.get("/history/{session_id}")
async def get_chat_history(
    session_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Return the messages of one of the caller's chat sessions, oldest first.

    Args:
        session_id: Public session identifier (the ``session_id`` column, not
            the row primary key).
        request: Incoming request; ``request.state.user["id"]`` is the caller's
            user id as a UUID string.
        db: Async database session provided by ``get_db``.

    Returns:
        ``{"code": 200, "data": [{"role", "content", "created_at"}, ...]}`` with
        ``created_at`` rendered via ``str()``.

    Raises:
        HTTPException: 404 when the session does not exist *or* belongs to
            another user. Both cases use the same response so that session ids
            owned by others cannot be probed.
    """
    user_id = uuid.UUID(request.state.user["id"])

    # Scope the lookup to the caller: session_id is guessable/client-chosen, so
    # without this filter any user could read any conversation.
    # NOTE(review): if admins must read other users' history, add an explicit
    # role-based path rather than dropping this filter.
    session_result = await db.execute(
        select(ChatSession).where(
            ChatSession.session_id == session_id,
            ChatSession.user_id == user_id,
        )
    )
    session = session_result.scalar_one_or_none()
    if not session:
        raise HTTPException(404, "会话不存在")

    msg_result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.created_at)
    )
    messages = msg_result.scalars().all()

    return {
        "code": 200,
        "data": [
            {
                "role": m.role,
                "content": m.content,
                "created_at": str(m.created_at),
            }
            for m in messages
        ],
    }
|