You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

353 lines
11 KiB

"""自定义工具模块路由。
提供自定义工具的创建、管理、导入和执行功能。
支持从 OpenAPI 规范自动导入工具定义,以及手动创建自定义 HTTP 工具。
"""
import uuid
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import CustomTool
from schemas import CustomToolCreate, CustomToolUpdate, CustomToolOut, OpenAPIImportRequest
from modules.custom_tool.parser import OpenAPIParser
from modules.custom_tool.executor import CustomToolExecutor
from modules.flow_engine.engine import ToolNodeAgent
from dependencies import get_current_user
import logging
logger = logging.getLogger(__name__) # 当前模块的日志记录器
router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"])
@router.post("/import-openapi")
async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)):
"""从 OpenAPI 规范 URL 导入工具定义。
自动下载 OpenAPI 文档,解析其中的 API 端点并创建对应的自定义工具。
Args:
req: OpenAPI 导入请求体,包含 openapi_url 和可选的 base_url_override。
request: HTTP 请求对象。
db: 异步数据库会话。
Returns:
dict: 包含导入成功工具列表的响应数据。
Raises:
HTTPException: 获取 OpenAPI 文档失败或解析不到工具时抛出异常。
"""
user_ctx = request.state.user
try:
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.get(req.openapi_url)
resp.raise_for_status()
spec = resp.json() # 解析 OpenAPI 规范 JSON
except httpx.HTTPError as e:
raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}")
except ValueError:
raise HTTPException(400, "OpenAPI 文档不是有效的 JSON 格式")
parser = OpenAPIParser(spec)
tools = parser.parse_tools()
if not tools:
raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具")
base_url = req.base_url_override or parser.base_url # 优先使用用户指定的基础 URL
if not base_url:
raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override")
created = []
for t in tools:
existing = await db.execute(
select(CustomTool).where(CustomTool.name == t["name"])
)
if existing.scalar_one_or_none():
continue # 跳过已存在的同名工具
tool = CustomTool(
name=t["name"],
description=t["description"],
schema_json=t["parameters"],
endpoint_url=base_url,
method=t["method"],
path=t["path"],
created_by=uuid.UUID(user_ctx["id"]),
)
db.add(tool)
created.append(t["name"])
ToolNodeAgent.register_custom_tool(
t["name"],
t["parameters"],
{
"endpoint_url": base_url,
"method": t["method"],
"path": t["path"],
"headers_json": {},
"auth_type": "none",
"auth_config": {},
"timeout": 30,
},
)
await db.flush()
return {"code": 200, "message": f"成功导入 {len(created)} 个工具", "data": {"tools": created}}
@router.post("/", response_model=CustomToolOut)
async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)):
"""创建新的自定义工具,支持手动创建或从 OpenAPI 导入。
Args:
req: 自定义工具创建请求体。
request: HTTP 请求对象。
db: 异步数据库会话。
Returns:
CustomToolOut: 创建后的自定义工具响应。
Raises:
HTTPException: 获取 OpenAPI 文档失败或创建失败时抛出异常。
"""
user_ctx = request.state.user
user_id = uuid.UUID(user_ctx["id"])
if req.openapi_url:
try:
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.get(req.openapi_url)
resp.raise_for_status()
spec = resp.json()
except Exception as e:
raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}")
parser = OpenAPIParser(spec)
tools = parser.parse_tools()
created_tool = None
for t in tools:
if t["name"] == req.name or (not req.name and tools):
existing = await db.execute(
select(CustomTool).where(CustomTool.name == t["name"])
)
if existing.scalar_one_or_none():
continue
tool = CustomTool(
name=t["name"],
description=t["description"],
schema_json=t["parameters"],
endpoint_url=parser.base_url,
method=t["method"],
path=t["path"],
created_by=user_id,
)
db.add(tool)
created_tool = tool
break
if not created_tool:
raise HTTPException(400, "未找到匹配的工具")
await db.flush()
return created_tool
# 手动创建模式
schema_json = req.tool_schema or {}
if not schema_json and req.endpoint_url:
schema_json = {
"type": "object",
"properties": {},
"description": req.description or "",
}
tool = CustomTool(
name=req.name,
description=req.description,
schema_json=schema_json,
endpoint_url=req.endpoint_url or "",
method=req.method,
path=req.path,
headers_json=req.headers,
auth_type=req.auth_type,
auth_config=req.auth_config,
created_by=user_id,
)
db.add(tool)
ToolNodeAgent.register_custom_tool(
req.name,
schema_json,
{
"endpoint_url": req.endpoint_url or "",
"method": req.method,
"path": req.path,
"headers_json": req.headers,
"auth_type": req.auth_type,
"auth_config": req.auth_config,
"timeout": 30,
},
)
await db.flush()
return tool
@router.get("/", response_model=list[CustomToolOut])
async def list_custom_tools(db: AsyncSession = Depends(get_db)):
"""列出所有处于活跃状态的自定义工具。
Args:
db: 异步数据库会话。
Returns:
list[CustomToolOut]: 活跃自定义工具列表。
"""
result = await db.execute(
select(CustomTool).where(CustomTool.is_active == True).order_by(CustomTool.updated_at.desc())
)
return result.scalars().all()
@router.get("/{tool_id}", response_model=CustomToolOut)
async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
"""获取指定自定义工具的详细信息。
Args:
tool_id: 自定义工具唯一标识 ID。
db: 异步数据库会话。
Returns:
CustomToolOut: 自定义工具详细信息。
Raises:
HTTPException: 工具不存在时抛出异常。
"""
tool = await db.get(CustomTool, tool_id)
if not tool:
raise HTTPException(404, "工具不存在")
return tool
@router.put("/{tool_id}", response_model=CustomToolOut)
async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)):
"""更新自定义工具的配置信息。
Args:
tool_id: 自定义工具唯一标识 ID。
req: 自定义工具更新请求体。
db: 异步数据库会话。
Returns:
CustomToolOut: 更新后的自定义工具响应。
Raises:
HTTPException: 工具不存在时抛出异常。
"""
tool = await db.get(CustomTool, tool_id)
if not tool:
raise HTTPException(404, "工具不存在")
if req.name is not None:
tool.name = req.name
if req.description is not None:
tool.description = req.description
if req.endpoint_url is not None:
tool.endpoint_url = req.endpoint_url
if req.method is not None:
tool.method = req.method
if req.path is not None:
tool.path = req.path
if req.headers is not None:
tool.headers_json = req.headers
if req.auth_type is not None:
tool.auth_type = req.auth_type
if req.auth_config is not None:
tool.auth_config = req.auth_config
if req.tool_schema is not None:
tool.schema_json = req.tool_schema
if req.is_active is not None:
tool.is_active = req.is_active
await db.flush()
return tool
@router.delete("/{tool_id}")
async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
"""删除(停用)自定义工具。
采用软删除方式,将工具标记为非活跃状态而非真正删除。
Args:
tool_id: 自定义工具唯一标识 ID。
db: 异步数据库会话。
Returns:
dict: 操作结果响应。
Raises:
HTTPException: 工具不存在时抛出异常。
"""
tool = await db.get(CustomTool, tool_id)
if not tool:
raise HTTPException(404, "工具不存在")
tool.is_active = False # 软删除:标记为非活跃
await db.flush()
return {"code": 200, "message": "工具已停用"}
@router.post("/{tool_id}/test")
async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)):
"""测试执行自定义工具。
使用自定义工具的配置信息创建执行器并执行,返回执行结果。
Args:
tool_id: 自定义工具唯一标识 ID。
params: 测试参数,传递给工具执行的参数体。
db: 异步数据库会话。
Returns:
dict: 包含执行结果的响应数据。
Raises:
HTTPException: 工具不存在或执行失败时抛出异常。
"""
tool = await db.get(CustomTool, tool_id)
if not tool:
raise HTTPException(404, "工具不存在")
if params is None:
params = {}
executor = CustomToolExecutor({
"endpoint_url": tool.endpoint_url,
"method": tool.method,
"path": tool.path,
"headers_json": tool.headers_json,
"auth_type": tool.auth_type,
"auth_config": tool.auth_config,
})
try:
result = await executor.execute(params)
return {"code": 200, "data": {"result": result}}
except Exception as e:
raise HTTPException(500, f"工具执行失败: {str(e)}")
@router.get("/schemas/all")
async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)):
"""获取所有活跃自定义工具的参数 Schema。
Args:
db: 异步数据库会话。
Returns:
dict: 包含所有工具 Schema 的响应数据,格式为 {工具名: schema}。
"""
result = await db.execute(
select(CustomTool).where(CustomTool.is_active == True)
)
tools = result.scalars().all()
schemas = {}
for t in tools:
schemas[t.name] = t.schema_json
return {"code": 200, "data": schemas}