You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

206 lines
7.1 KiB

import uuid
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User, ChatSession, ChatMessage, AgentConfig
from schemas import AgentConfigCreate, AgentConfigUpdate, AgentConfigOut
from agentscope_integration.factory import AgentFactory
# Router for the agent API: agent chat, chat history, and CRUD over AgentConfig
# records. Every route below is mounted under the "/api/agent" prefix.
router = APIRouter(prefix="/api/agent", tags=["agent"])
@router.post("/chat/{agent_type}")
async def agent_chat(
    agent_type: str,
    request: Request,
    payload: dict,
    db: AsyncSession = Depends(get_db),
):
    """Send one user message to an agent and persist both sides of the exchange.

    The caller is identified by ``request.state.user["id"]`` (parsed as a
    UUID). The request body may contain:

    * ``message`` -- the user's text (defaults to ``""``).
    * ``session_id`` -- an existing session to continue. When it is missing,
      null or empty, a new ``session_<12 hex chars>`` id is generated.

    A ``ChatSession`` is created when the id is not found. The user message
    and the agent's reply are both added as ``ChatMessage`` rows and flushed.

    Returns:
        ``{"code": 200, "data": {"session_id", "reply", "role"}}``.

    Raises:
        HTTPException(404): the user does not exist, or ``session_id`` names a
            session owned by a different user. That case is reported as "not
            found" so other users' session ids cannot be probed.
    """
    user_ctx = request.state.user
    user_id = uuid.UUID(user_ctx["id"])
    msg_content = payload.get("message", "")
    # Use `or` rather than a .get() default so that an explicit null or ""
    # also gets a fresh id, instead of looking up or creating a session whose
    # id is None.
    session_id = payload.get("session_id") or f"session_{uuid.uuid4().hex[:12]}"

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(404, "用户不存在")

    session_result = await db.execute(
        select(ChatSession).where(ChatSession.session_id == session_id)
    )
    session = session_result.scalar_one_or_none()
    if not session:
        session = ChatSession(
            user_id=user.id, agent_type=agent_type,
            session_id=session_id,
        )
        db.add(session)
        await db.flush()
    elif session.user_id != user.id:
        # Without this check, any user who knows a session_id could append
        # messages to another user's conversation.
        raise HTTPException(404, "会话不存在")

    user_msg = ChatMessage(
        session_id=session.id, user_id=user.id,
        role="user", content=msg_content,
    )
    db.add(user_msg)
    await db.flush()

    agent = await AgentFactory.create_agent(
        agent_type=agent_type,
        user_id=str(user.id),
        user_name=user.display_name,
        department_id=str(user.department_id) if user.department_id else None,
    )
    # Kept as a local import, as in the original; presumably to defer loading
    # agentscope until a chat is actually requested.
    from agentscope.message import Msg
    input_msg = Msg(name="user", content=msg_content, role="user")
    response = await agent.reply(input_msg)
    # Prefer the message's text accessor when the reply object provides one;
    # otherwise fall back to its string form.
    reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response)

    ai_msg = ChatMessage(
        session_id=session.id, user_id=user.id,
        role="assistant", content=reply_text,
    )
    db.add(ai_msg)
    await db.flush()
    return {
        "code": 200,
        "data": {
            "session_id": session_id,
            "reply": reply_text,
            "role": "assistant",
        },
    }
@router.get("/list")
async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)):
    """List every agent configuration whose status is "active".

    Results are ordered by ``updated_at``, newest first. A ``temperature``
    that is already a ``float`` is returned unchanged; any other value is
    returned as ``float(value) / 10.0``. Presumably legacy rows store tenths
    as integers -- NOTE(review): confirm against the column type.

    Returns:
        ``{"code": 200, "data": [...]}``, one summary dict per agent.
    """
    stmt = (
        select(AgentConfig)
        .where(AgentConfig.status == "active")
        .order_by(AgentConfig.updated_at.desc())
    )
    rows = (await db.execute(stmt)).scalars().all()

    items = []
    for cfg in rows:
        temp = cfg.temperature
        if not isinstance(temp, float):
            temp = float(temp) / 10.0
        items.append({
            "id": str(cfg.id),
            "name": cfg.name,
            "description": cfg.description,
            "system_prompt": cfg.system_prompt,
            "model": cfg.model,
            "temperature": temp,
            "tools": cfg.tools or [],
            "status": cfg.status,
        })
    return {"code": 200, "data": items}
@router.post("/", response_model=AgentConfigOut)
async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSession = Depends(get_db)):
    """Create an agent configuration owned by the calling user.

    The creator is taken from ``request.state.user["id"]`` (parsed as a UUID).
    The new row is flushed so its id is available before the response is built.

    Returns:
        The new configuration as ``AgentConfigOut``. ``temperature`` is passed
        through when it is a ``float``; otherwise it is ``float(value) / 10.0``.
    """
    caller = request.state.user
    new_cfg = AgentConfig(
        name=req.name,
        description=req.description,
        system_prompt=req.system_prompt,
        model=req.model,
        temperature=req.temperature,
        tools=req.tools,
        creator_id=uuid.UUID(caller["id"]),
    )
    db.add(new_cfg)
    await db.flush()

    temp = new_cfg.temperature
    if not isinstance(temp, float):
        temp = float(temp) / 10.0
    return AgentConfigOut(
        id=new_cfg.id,
        name=new_cfg.name,
        description=new_cfg.description,
        system_prompt=new_cfg.system_prompt,
        model=new_cfg.model,
        temperature=temp,
        tools=new_cfg.tools or [],
        status=new_cfg.status,
        creator_id=new_cfg.creator_id,
        created_at=new_cfg.created_at,
        updated_at=new_cfg.updated_at,
    )
@router.get("/{agent_id}", response_model=AgentConfigOut)
async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Fetch a single agent configuration by its primary key.

    Returns:
        The configuration as ``AgentConfigOut``. ``temperature`` is passed
        through when it is a ``float``; otherwise it is ``float(value) / 10.0``.

    Raises:
        HTTPException(404): no ``AgentConfig`` has this id.
    """
    found = (
        await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
    ).scalar_one_or_none()
    if not found:
        raise HTTPException(404, "Agent不存在")

    temp = found.temperature
    if not isinstance(temp, float):
        temp = float(temp) / 10.0
    return AgentConfigOut(
        id=found.id,
        name=found.name,
        description=found.description,
        system_prompt=found.system_prompt,
        model=found.model,
        temperature=temp,
        tools=found.tools or [],
        status=found.status,
        creator_id=found.creator_id,
        created_at=found.created_at,
        updated_at=found.updated_at,
    )
@router.put("/{agent_id}", response_model=AgentConfigOut)
async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Request, db: AsyncSession = Depends(get_db)):
    """Partially update an agent configuration.

    Each field in ``req`` that is not ``None`` overwrites the stored value.
    Fields left as ``None`` are unchanged, so this endpoint cannot clear a
    field to ``None``.

    Returns:
        The updated configuration as ``AgentConfigOut``. ``temperature`` is
        passed through when it is a ``float``; otherwise it is
        ``float(value) / 10.0``.

    Raises:
        HTTPException(404): no ``AgentConfig`` has this id.
    """
    result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
    agent = result.scalar_one_or_none()
    if not agent:
        raise HTTPException(404, "Agent不存在")

    for field in ("name", "description", "system_prompt", "model",
                  "temperature", "tools", "status"):
        value = getattr(req, field)
        if value is not None:
            setattr(agent, field, value)

    await db.flush()
    # After an UPDATE flush, SQLAlchemy expires columns that the database
    # fills in itself (e.g. an `updated_at` with a server-side onupdate -- not
    # visible here, confirm in models). On an AsyncSession, reading an expired
    # attribute triggers implicit lazy IO and raises instead of awaiting.
    # Reload explicitly so the response also reflects what was stored.
    await db.refresh(agent)

    temperature = agent.temperature if isinstance(agent.temperature, float) else float(agent.temperature) / 10.0
    return AgentConfigOut(
        id=agent.id, name=agent.name, description=agent.description,
        system_prompt=agent.system_prompt, model=agent.model,
        temperature=temperature,
        tools=agent.tools or [], status=agent.status,
        creator_id=agent.creator_id,
        created_at=agent.created_at, updated_at=agent.updated_at,
    )
@router.delete("/{agent_id}")
async def delete_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Delete an agent configuration by id.

    The DELETE is flushed before the success response is built, as the
    create and update handlers do. A database error (for example, a
    foreign-key violation if other rows still reference this agent) then
    fails this request, rather than surfacing only at the later commit
    (presumably done by ``get_db``).

    Returns:
        ``{"code": 200, "message": "已删除"}``.

    Raises:
        HTTPException(404): no ``AgentConfig`` has this id.
    """
    result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
    agent = result.scalar_one_or_none()
    if not agent:
        raise HTTPException(404, "Agent不存在")
    await db.delete(agent)
    await db.flush()
    return {"code": 200, "message": "已删除"}
@router.get("/history/{session_id}")
async def get_chat_history(
    session_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Return every message of one of the caller's chat sessions, oldest first.

    The caller is identified by ``request.state.user["id"]`` (parsed as a
    UUID). Only the user who owns the session may read it.

    Returns:
        ``{"code": 200, "data": [{"role", "content", "created_at"}, ...]}``.
        ``created_at`` is rendered with ``str()``.

    Raises:
        HTTPException(404): the session does not exist or belongs to another
            user. Both cases give the same response so session ids cannot be
            probed.
    """
    user_id = uuid.UUID(request.state.user["id"])

    session_result = await db.execute(
        select(ChatSession).where(ChatSession.session_id == session_id)
    )
    session = session_result.scalar_one_or_none()
    # NOTE(review): if admins must be able to read other users' histories,
    # add a role-based bypass here.
    if not session or session.user_id != user_id:
        raise HTTPException(404, "会话不存在")

    msg_result = await db.execute(
        select(ChatMessage).where(ChatMessage.session_id == session.id).order_by(ChatMessage.created_at)
    )
    messages = msg_result.scalars().all()
    return {
        "code": 200,
        "data": [{
            "role": m.role,
            "content": m.content,
            "created_at": str(m.created_at),
        } for m in messages],
    }