You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
196 lines
6.8 KiB
196 lines
6.8 KiB
import uuid
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import MCPService, AuditLog
|
|
from schemas import MCPServiceCreate, MCPServiceUpdate, MCPServiceOut
|
|
from dependencies import get_current_user
|
|
|
|
# Router for the MCP service registry; routes registered on it are served under /api/mcp.
router = APIRouter(tags=["mcp"], prefix="/api/mcp")
|
|
|
|
|
|
@router.get("/servers", response_model=list[MCPServiceOut])
async def list_servers(request: Request, db: AsyncSession = Depends(get_db)):
    """Return all registered MCP services, most recently updated first.

    NOTE(review): unlike the mutating routes, no ``get_current_user``
    dependency is declared here — confirm that listing is meant to be public.
    """
    newest_first = select(MCPService).order_by(MCPService.updated_at.desc())
    rows = await db.execute(newest_first)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/servers/{server_id}", response_model=MCPServiceOut)
async def get_server(server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
    """Return the MCP service with id ``server_id``.

    Raises:
        HTTPException(404): no service with that id exists.
    """
    lookup = select(MCPService).where(MCPService.id == server_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found is None:
        raise HTTPException(404, "MCP服务不存在")
    return found
|
|
|
|
|
|
@router.post("/servers", response_model=MCPServiceOut)
async def register_server(
    req: MCPServiceCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Register a new MCP service and write an ``mcp.register`` audit entry.

    Args:
        req: the service definition (name, transport, url, command, args, env).
        request: used only to record the caller's IP in the audit entry.
        db: async database session.
        user: the authenticated operator; ``user["id"]`` must be a UUID string.

    Returns:
        The newly created ``MCPService`` row.

    Raises:
        HTTPException(400): a service with the same name already exists.
    """
    # NOTE(review): this check-then-insert is not atomic; two concurrent
    # requests can both pass it unless the database enforces unique names.
    existing = await db.execute(
        select(MCPService).where(MCPService.name == req.name).limit(1)
    )
    if existing.scalar_one_or_none():
        raise HTTPException(400, "服务名称已存在")

    operator_id = uuid.UUID(user["id"])
    server = MCPService(
        name=req.name,
        transport=req.transport,
        url=req.url,
        command=req.command,
        args=req.args,
        env=req.env,
        creator_id=operator_id,
    )
    db.add(server)
    # Flush so the primary key is populated; the audit entry can then use the
    # service id as resource_id, like the update/delete/test audit entries do.
    await db.flush()

    audit = AuditLog(
        operator_id=operator_id,
        action="mcp.register",
        resource="mcp_service",
        resource_id=str(server.id),
        detail={"name": req.name, "transport": req.transport},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()

    return server
|
|
|
|
|
|
@router.put("/servers/{server_id}", response_model=MCPServiceOut)
async def update_server(
    server_id: uuid.UUID, req: MCPServiceUpdate,
    request: Request, db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Partially update an MCP service and write an ``mcp.update`` audit entry.

    Only fields of ``req`` that are not None are applied; the others keep
    their stored values. The name of the service is not updatable here.

    Returns:
        The updated ``MCPService`` row.

    Raises:
        HTTPException(404): no service with ``server_id`` exists.
    """
    result = await db.execute(select(MCPService).where(MCPService.id == server_id))
    server = result.scalar_one_or_none()
    if not server:
        raise HTTPException(404, "MCP服务不存在")

    changed_fields = []
    for field in ("transport", "url", "command", "args", "env"):
        value = getattr(req, field)
        if value is not None:
            setattr(server, field, value)
            changed_fields.append(field)

    audit = AuditLog(
        operator_id=uuid.UUID(user["id"]),
        action="mcp.update",
        resource="mcp_service",
        resource_id=str(server_id),
        # Record which fields changed but not their values: env/args may
        # carry credentials that should not end up in the audit log.
        detail={"name": server.name, "fields": changed_fields},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()

    return server
|
|
|
|
|
|
@router.delete("/servers/{server_id}")
async def delete_server(
    server_id: uuid.UUID, request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Delete an MCP service and write an ``mcp.delete`` audit entry.

    Returns:
        ``{"code": 200, "message": ...}`` once the deletion has been flushed.

    Raises:
        HTTPException(404): no service with ``server_id`` exists.
    """
    lookup = select(MCPService).where(MCPService.id == server_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        raise HTTPException(404, "MCP服务不存在")

    await db.delete(target)

    client_ip = request.client.host if request.client else None
    db.add(
        AuditLog(
            operator_id=uuid.UUID(user["id"]),
            action="mcp.delete",
            resource="mcp_service",
            resource_id=str(server_id),
            detail={"name": target.name},
            ip_address=client_ip,
        )
    )
    await db.flush()

    return {"code": 200, "message": "已注销"}
|
|
|
|
|
|
@router.post("/servers/{server_id}/test")
async def test_connection(
    server_id: uuid.UUID, request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Probe an MCP service, store the outcome, and write an ``mcp.test`` audit entry.

    For ``http`` services with a URL, GETs ``<url>/.well-known/mcp`` and reads
    its ``tools`` list. On success the service's ``tools`` are replaced and its
    ``status`` becomes "connected"; on any HTTP/parse failure ``status`` becomes
    "error" and the reason is returned in ``error``. Failures are reported in
    the response body, not raised.

    Returns:
        ``{"code": 200, "data": {"connectivity", "tools_discovered", "tools", "error"}}``.

    Raises:
        HTTPException(404): no service with ``server_id`` exists.
    """
    result = await db.execute(select(MCPService).where(MCPService.id == server_id))
    server = result.scalar_one_or_none()
    if not server:
        raise HTTPException(404, "MCP服务不存在")

    test_results = {"connectivity": False, "tools_discovered": 0, "tools": [], "error": None}

    if server.transport == "http" and server.url:
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(server.url.rstrip("/") + "/.well-known/mcp")
                if resp.status_code == 200:
                    payload = resp.json()
                    # A missing or null "tools" key is treated as "no tools".
                    raw_tools = payload.get("tools") or []
                    tools = [
                        {"name": t.get("name", ""), "description": t.get("description", "")}
                        for t in raw_tools
                    ]
                    # Publish results only after the whole payload has parsed, so a
                    # malformed body can never leave connectivity=True next to an error.
                    test_results["connectivity"] = True
                    test_results["tools_discovered"] = len(tools)
                    test_results["tools"] = tools
                    server.tools = tools
                    server.status = "connected"
                else:
                    test_results["error"] = f"HTTP {resp.status_code}"
                    server.status = "error"
        except Exception as e:
            # Deliberate best-effort probe: network errors, timeouts, bad JSON and
            # unexpected payload shapes are all reported to the caller, not raised.
            # Some exceptions stringify to "", so fall back to the exception type.
            test_results["error"] = str(e) or type(e).__name__
            server.status = "error"
    else:
        # Only HTTP services with a URL can be probed; say so instead of
        # returning a failed result with no explanation.
        test_results["error"] = "仅支持对已配置URL的HTTP服务进行连接测试"

    audit = AuditLog(
        operator_id=uuid.UUID(user["id"]),
        action="mcp.test",
        resource="mcp_service",
        resource_id=str(server_id),
        detail={"name": server.name, "result": "connected" if test_results["connectivity"] else "failed"},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()

    return {"code": 200, "data": test_results}
|
|
|
|
|
|
@router.post("/servers/{server_id}/discover-tools")
async def discover_tools(
    server_id: uuid.UUID, request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Refresh an MCP service's tool list and write an ``mcp.discover`` audit entry.

    For ``http`` services with a URL, POSTs ``{}`` to ``<url>/tools/list`` and
    replaces the stored ``tools`` (name, description, inputSchema) and sets
    ``status`` to "connected". For other services the stored tool list is
    returned unchanged. Like the other mutating routes, this requires an
    authenticated user: it writes to the database and makes outbound requests.

    Returns:
        ``{"code": 200, "data": {"tools": [...], "count": int}}``.

    Raises:
        HTTPException(404): no service with ``server_id`` exists.
        HTTPException(500): the request failed, returned a non-200 status, or
            its body could not be parsed.
    """
    result = await db.execute(select(MCPService).where(MCPService.id == server_id))
    server = result.scalar_one_or_none()
    if not server:
        raise HTTPException(404, "MCP服务不存在")

    if server.transport == "http" and server.url:
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(server.url.rstrip("/") + "/tools/list", json={})
            if resp.status_code != 200:
                # Treat non-200 replies as failures instead of returning stale
                # tools with code 200; handled by the except clause below.
                raise ValueError(f"HTTP {resp.status_code}")
            payload = resp.json()
            # A missing or null "tools" key is treated as "no tools".
            raw_tools = payload.get("tools") or []
            discovered = [
                {
                    "name": t.get("name", ""),
                    "description": t.get("description", ""),
                    "inputSchema": t.get("inputSchema", {}),
                }
                for t in raw_tools
            ]
        except Exception as e:
            # Some exceptions stringify to "", so fall back to the exception type.
            raise HTTPException(500, f"工具发现失败: {str(e) or type(e).__name__}") from e
        server.tools = discovered
        server.status = "connected"

    # server.tools may be None for a service that was never discovered — guard
    # so len() cannot fail (NOTE(review): confirm the model's default).
    tools = server.tools or []

    audit = AuditLog(
        operator_id=uuid.UUID(user["id"]),
        action="mcp.discover",
        resource="mcp_service",
        resource_id=str(server_id),
        detail={"name": server.name, "count": len(tools)},
        ip_address=request.client.host if request.client else None,
    )
    db.add(audit)
    await db.flush()

    return {"code": 200, "data": {"tools": tools, "count": len(tools)}}
|