You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

249 lines
8.1 KiB

import uuid
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import CustomTool
from schemas import CustomToolCreate, CustomToolUpdate, CustomToolOut, OpenAPIImportRequest
from modules.custom_tool.parser import OpenAPIParser
from modules.custom_tool.executor import CustomToolExecutor
from modules.flow_engine.engine import ToolNodeAgent
from dependencies import get_current_user
import logging
# Module-level logger for this router module.
logger = logging.getLogger(__name__)
# Every endpoint below is mounted under /api/custom-tools and grouped under the
# "custom_tools" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"])
@router.post("/import-openapi")
async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)):
    """Bulk-import the operations of a remote OpenAPI document as custom tools.

    Downloads ``req.openapi_url``, parses it with ``OpenAPIParser`` and creates
    one ``CustomTool`` row per parsed tool whose name is not already stored.
    Each newly created tool is also registered with ``ToolNodeAgent``.

    Args:
        req: Import request. ``openapi_url`` is fetched; ``base_url_override``,
            when set, takes precedence over the base URL found in the spec.
        request: Incoming request; ``request.state.user["id"]`` is recorded as
            the creator of every imported tool.
        db: Async database session (flushed, not committed, here).

    Returns:
        Envelope dict whose ``data.tools`` lists the names that were created.
        Names that already exist (or repeat within the spec) are skipped.

    Raises:
        HTTPException(400): the document cannot be fetched, is not JSON,
            yields no tools, or no base URL can be determined.
    """
    # Parse the creator id once instead of once per created tool.
    user_id = uuid.UUID(request.state.user["id"])

    # NOTE(review): openapi_url is caller-supplied and fetched server-side; if
    # this endpoint is reachable by untrusted users, consider restricting the
    # allowed hosts (SSRF).
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.get(req.openapi_url)
            resp.raise_for_status()
            spec = resp.json()
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # InvalidURL does not derive from HTTPError; without it a malformed URL
        # escaped as an unhandled 500 instead of a 400.
        raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") from e
    except ValueError as e:
        # resp.json() raises json.JSONDecodeError (a ValueError) on non-JSON bodies.
        raise HTTPException(400, "OpenAPI 文档不是有效的 JSON 格式") from e

    parser = OpenAPIParser(spec)
    tools = parser.parse_tools()
    if not tools:
        raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具")
    base_url = req.base_url_override or parser.base_url
    if not base_url:
        raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override")

    # One round-trip for all candidate names instead of one SELECT per tool.
    existing = await db.execute(
        select(CustomTool.name).where(CustomTool.name.in_([t["name"] for t in tools]))
    )
    taken = set(existing.scalars().all())

    created = []
    for t in tools:
        if t["name"] in taken:
            continue
        # Track names added in this request too, so a spec that repeats a name
        # does not insert it twice (independent of the session's autoflush).
        taken.add(t["name"])
        db.add(
            CustomTool(
                name=t["name"],
                description=t["description"],
                schema_json=t["parameters"],
                endpoint_url=base_url,
                method=t["method"],
                path=t["path"],
                created_by=user_id,
            )
        )
        created.append(t)

    # Flush before touching the runtime registry: if the INSERTs fail, no tool
    # is left registered without a stored row.
    await db.flush()
    for t in created:
        ToolNodeAgent.register_custom_tool(
            t["name"],
            t["parameters"],
            {
                "endpoint_url": base_url,
                "method": t["method"],
                "path": t["path"],
                "headers_json": {},
                "auth_type": "none",
                "auth_config": {},
                "timeout": 30,
            },
        )

    names = [t["name"] for t in created]
    return {"code": 200, "message": f"成功导入 {len(names)} 个工具", "data": {"tools": names}}
async def _create_tool_from_openapi(req: CustomToolCreate, user_id: uuid.UUID, db: AsyncSession) -> CustomTool:
    """Create a single tool taken from the OpenAPI document at ``req.openapi_url``.

    Selects the parsed tool named ``req.name``; when ``req.name`` is empty, the
    first parsed tool whose name is not already stored is used. The tool is
    persisted (flushed) and then registered with ``ToolNodeAgent``.

    Raises:
        HTTPException(400): the document cannot be fetched or parsed as JSON,
            no matching unused tool exists, or no base URL can be determined.
    """
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.get(req.openapi_url)
            resp.raise_for_status()
            spec = resp.json()
    except Exception as e:
        raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") from e

    parser = OpenAPIParser(spec)
    chosen = None
    for t in parser.parse_tools():
        if req.name and t["name"] != req.name:
            continue
        existing = await db.execute(select(CustomTool).where(CustomTool.name == t["name"]))
        # .first() rather than scalar_one_or_none(): pre-existing duplicate rows
        # must mean "taken", not MultipleResultsFound.
        if existing.scalars().first() is not None:
            continue
        chosen = t
        break
    if chosen is None:
        raise HTTPException(400, "未找到匹配的工具")

    # Previously parser.base_url was stored unchecked and could be empty. Fall back
    # to the request's endpoint_url and, like import_openapi, refuse to store a
    # tool with no base URL.
    base_url = parser.base_url or req.endpoint_url
    if not base_url:
        raise HTTPException(400, "未能确定 API 基础 URL,请提供 endpoint_url")

    tool = CustomTool(
        name=chosen["name"],
        description=chosen["description"],
        schema_json=chosen["parameters"],
        endpoint_url=base_url,
        method=chosen["method"],
        path=chosen["path"],
        created_by=user_id,
    )
    db.add(tool)
    await db.flush()
    # This path previously skipped registration, so the tool was stored but not
    # known to ToolNodeAgent. Register it the same way import_openapi does.
    ToolNodeAgent.register_custom_tool(
        chosen["name"],
        chosen["parameters"],
        {
            "endpoint_url": base_url,
            "method": chosen["method"],
            "path": chosen["path"],
            "headers_json": {},
            "auth_type": "none",
            "auth_config": {},
            "timeout": 30,
        },
    )
    return tool


async def _create_tool_manually(req: CustomToolCreate, user_id: uuid.UUID, db: AsyncSession) -> CustomTool:
    """Create a tool from the explicit fields of ``req`` and register it.

    Raises:
        HTTPException(400): a tool with ``req.name`` is already stored (the
            runtime registry is keyed by name, so a duplicate would silently
            replace the existing tool's registration).
    """
    existing = await db.execute(select(CustomTool).where(CustomTool.name == req.name))
    if existing.scalars().first() is not None:
        raise HTTPException(400, f"工具名称已存在: {req.name}")

    schema_json = req.tool_schema or {}
    if not schema_json and req.endpoint_url:
        # No schema supplied: fall back to a parameterless object schema.
        schema_json = {
            "type": "object",
            "properties": {},
            "description": req.description or "",
        }
    endpoint_url = req.endpoint_url or ""

    tool = CustomTool(
        name=req.name,
        description=req.description,
        schema_json=schema_json,
        endpoint_url=endpoint_url,
        method=req.method,
        path=req.path,
        headers_json=req.headers,
        auth_type=req.auth_type,
        auth_config=req.auth_config,
        created_by=user_id,
    )
    db.add(tool)
    # Flush before registering so a failed INSERT leaves no orphan registration.
    await db.flush()
    ToolNodeAgent.register_custom_tool(
        req.name,
        schema_json,
        {
            "endpoint_url": endpoint_url,
            "method": req.method,
            "path": req.path,
            "headers_json": req.headers,
            "auth_type": req.auth_type,
            "auth_config": req.auth_config,
            "timeout": 30,
        },
    )
    return tool


@router.post("/", response_model=CustomToolOut)
async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)):
    """Create one custom tool, either from an OpenAPI document or from explicit fields.

    When ``req.openapi_url`` is set, a single tool is taken from that document.
    Otherwise the tool is built from the request's own fields. In both cases the
    tool is flushed to the session and registered with ``ToolNodeAgent``.
    ``request.state.user["id"]`` is recorded as the creator.

    Raises:
        HTTPException(400): see the two private helpers above.
    """
    user_id = uuid.UUID(request.state.user["id"])
    if req.openapi_url:
        return await _create_tool_from_openapi(req, user_id, db)
    return await _create_tool_manually(req, user_id, db)
@router.get("/", response_model=list[CustomToolOut])
async def list_custom_tools(db: AsyncSession = Depends(get_db)):
    """Return every active custom tool, most recently updated first."""
    stmt = (
        select(CustomTool)
        .where(CustomTool.is_active == True)  # noqa: E712 - SQLAlchemy column expression
        .order_by(CustomTool.updated_at.desc())
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
@router.get("/{tool_id}", response_model=CustomToolOut)
async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
    """Look up a custom tool by primary key.

    Inactive (soft-deleted) tools are returned as well. Raises 404 when no row
    has this id.
    """
    found = await db.get(CustomTool, tool_id)
    if found:
        return found
    raise HTTPException(404, "工具不存在")
@router.put("/{tool_id}", response_model=CustomToolOut)
async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)):
    """Partially update a custom tool; fields left as ``None`` in ``req`` are untouched.

    After the change is flushed, an active tool is re-registered with
    ``ToolNodeAgent`` so the runtime sees the new configuration, matching what
    create/import do on write.

    Raises:
        HTTPException(404): no tool with ``tool_id`` exists.
    """
    tool = await db.get(CustomTool, tool_id)
    if not tool:
        raise HTTPException(404, "工具不存在")

    # (request field, model attribute). The names differ for headers and the schema.
    field_map = (
        ("name", "name"),
        ("description", "description"),
        ("endpoint_url", "endpoint_url"),
        ("method", "method"),
        ("path", "path"),
        ("headers", "headers_json"),
        ("auth_type", "auth_type"),
        ("auth_config", "auth_config"),
        ("tool_schema", "schema_json"),
        ("is_active", "is_active"),
    )
    for req_field, model_attr in field_map:
        value = getattr(req, req_field)
        if value is not None:
            setattr(tool, model_attr, value)
    await db.flush()

    # Previously edits never reached the runtime registry.
    # NOTE(review): renaming or deactivating does not remove the old
    # registration; no unregister API is visible from this module.
    if tool.is_active:
        ToolNodeAgent.register_custom_tool(
            tool.name,
            tool.schema_json,
            {
                "endpoint_url": tool.endpoint_url,
                "method": tool.method,
                "path": tool.path,
                # OpenAPI-imported tools never set these; import_openapi registers
                # them with the same fallbacks.
                "headers_json": tool.headers_json or {},
                "auth_type": tool.auth_type or "none",
                "auth_config": tool.auth_config or {},
                "timeout": 30,
            },
        )
    return tool
@router.delete("/{tool_id}")
async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
    """Soft-delete a custom tool by clearing its ``is_active`` flag.

    The row is kept, so it can be re-enabled via the update endpoint.
    NOTE(review): the tool's ToolNodeAgent registration (if any) is not removed
    here.
    """
    target = await db.get(CustomTool, tool_id)
    if target:
        target.is_active = False
        await db.flush()
        return {"code": 200, "message": "工具已停用"}
    raise HTTPException(404, "工具不存在")
@router.post("/{tool_id}/test")
async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)):
    """Run a stored custom tool once with ``params`` and return its raw result.

    Args:
        tool_id: Primary key of the tool to run.
        params: Arguments passed to the tool (empty dict when omitted).
        db: Async database session.

    Raises:
        HTTPException(404): no tool with ``tool_id`` exists.
        HTTPException(500): the executor raised; the message is echoed back to
            the client and the traceback is logged.
    """
    tool = await db.get(CustomTool, tool_id)
    if not tool:
        raise HTTPException(404, "工具不存在")

    # NOTE(review): unlike the registry config built in create/import, no
    # "timeout" key is passed here, so the executor's own default applies.
    executor = CustomToolExecutor({
        "endpoint_url": tool.endpoint_url,
        "method": tool.method,
        "path": tool.path,
        "headers_json": tool.headers_json,
        "auth_type": tool.auth_type,
        "auth_config": tool.auth_config,
    })
    try:
        result = await executor.execute(params if params is not None else {})
    except Exception as e:
        # Boundary handler: the executor calls an arbitrary external endpoint.
        # Previously the failure was only echoed to the client; log the traceback too.
        logger.exception("Custom tool %s (%s) test run failed", tool.name, tool_id)
        raise HTTPException(500, f"工具执行失败: {str(e)}") from e
    return {"code": 200, "data": {"result": result}}
@router.get("/schemas/all")
async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)):
    """Map each active tool's name to its stored JSON schema.

    If several active tools share a name, the one returned last by the query wins.
    """
    query = select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712 - SQLAlchemy column expression
    active_tools = (await db.execute(query)).scalars().all()
    return {"code": 200, "data": {item.name: item.schema_json for item in active_tools}}