You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
353 lines
11 KiB
353 lines
11 KiB
"""自定义工具模块路由。
|
|
|
|
提供自定义工具的创建、管理、导入和执行功能。
|
|
支持从 OpenAPI 规范自动导入工具定义,以及手动创建自定义 HTTP 工具。
|
|
"""
|
|
import uuid
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import CustomTool
|
|
from schemas import CustomToolCreate, CustomToolUpdate, CustomToolOut, OpenAPIImportRequest
|
|
from modules.custom_tool.parser import OpenAPIParser
|
|
from modules.custom_tool.executor import CustomToolExecutor
|
|
from modules.flow_engine.engine import ToolNodeAgent
|
|
from dependencies import get_current_user
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__) # 当前模块的日志记录器
|
|
|
|
router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"])
|
|
|
|
|
|
@router.post("/import-openapi")
|
|
async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)):
|
|
"""从 OpenAPI 规范 URL 导入工具定义。
|
|
|
|
自动下载 OpenAPI 文档,解析其中的 API 端点并创建对应的自定义工具。
|
|
|
|
Args:
|
|
req: OpenAPI 导入请求体,包含 openapi_url 和可选的 base_url_override。
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
dict: 包含导入成功工具列表的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 获取 OpenAPI 文档失败或解析不到工具时抛出异常。
|
|
"""
|
|
user_ctx = request.state.user
|
|
try:
|
|
async with httpx.AsyncClient(timeout=30) as client:
|
|
resp = await client.get(req.openapi_url)
|
|
resp.raise_for_status()
|
|
spec = resp.json() # 解析 OpenAPI 规范 JSON
|
|
except httpx.HTTPError as e:
|
|
raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}")
|
|
except ValueError:
|
|
raise HTTPException(400, "OpenAPI 文档不是有效的 JSON 格式")
|
|
|
|
parser = OpenAPIParser(spec)
|
|
tools = parser.parse_tools()
|
|
if not tools:
|
|
raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具")
|
|
|
|
base_url = req.base_url_override or parser.base_url # 优先使用用户指定的基础 URL
|
|
if not base_url:
|
|
raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override")
|
|
|
|
created = []
|
|
for t in tools:
|
|
existing = await db.execute(
|
|
select(CustomTool).where(CustomTool.name == t["name"])
|
|
)
|
|
if existing.scalar_one_or_none():
|
|
continue # 跳过已存在的同名工具
|
|
|
|
tool = CustomTool(
|
|
name=t["name"],
|
|
description=t["description"],
|
|
schema_json=t["parameters"],
|
|
endpoint_url=base_url,
|
|
method=t["method"],
|
|
path=t["path"],
|
|
created_by=uuid.UUID(user_ctx["id"]),
|
|
)
|
|
db.add(tool)
|
|
created.append(t["name"])
|
|
|
|
ToolNodeAgent.register_custom_tool(
|
|
t["name"],
|
|
t["parameters"],
|
|
{
|
|
"endpoint_url": base_url,
|
|
"method": t["method"],
|
|
"path": t["path"],
|
|
"headers_json": {},
|
|
"auth_type": "none",
|
|
"auth_config": {},
|
|
"timeout": 30,
|
|
},
|
|
)
|
|
|
|
await db.flush()
|
|
return {"code": 200, "message": f"成功导入 {len(created)} 个工具", "data": {"tools": created}}
|
|
|
|
|
|
@router.post("/", response_model=CustomToolOut)
|
|
async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)):
|
|
"""创建新的自定义工具,支持手动创建或从 OpenAPI 导入。
|
|
|
|
Args:
|
|
req: 自定义工具创建请求体。
|
|
request: HTTP 请求对象。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
CustomToolOut: 创建后的自定义工具响应。
|
|
|
|
Raises:
|
|
HTTPException: 获取 OpenAPI 文档失败或创建失败时抛出异常。
|
|
"""
|
|
user_ctx = request.state.user
|
|
user_id = uuid.UUID(user_ctx["id"])
|
|
|
|
if req.openapi_url:
|
|
try:
|
|
async with httpx.AsyncClient(timeout=30) as client:
|
|
resp = await client.get(req.openapi_url)
|
|
resp.raise_for_status()
|
|
spec = resp.json()
|
|
except Exception as e:
|
|
raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}")
|
|
|
|
parser = OpenAPIParser(spec)
|
|
tools = parser.parse_tools()
|
|
|
|
created_tool = None
|
|
for t in tools:
|
|
if t["name"] == req.name or (not req.name and tools):
|
|
existing = await db.execute(
|
|
select(CustomTool).where(CustomTool.name == t["name"])
|
|
)
|
|
if existing.scalar_one_or_none():
|
|
continue
|
|
tool = CustomTool(
|
|
name=t["name"],
|
|
description=t["description"],
|
|
schema_json=t["parameters"],
|
|
endpoint_url=parser.base_url,
|
|
method=t["method"],
|
|
path=t["path"],
|
|
created_by=user_id,
|
|
)
|
|
db.add(tool)
|
|
created_tool = tool
|
|
break
|
|
|
|
if not created_tool:
|
|
raise HTTPException(400, "未找到匹配的工具")
|
|
|
|
await db.flush()
|
|
return created_tool
|
|
|
|
# 手动创建模式
|
|
schema_json = req.tool_schema or {}
|
|
if not schema_json and req.endpoint_url:
|
|
schema_json = {
|
|
"type": "object",
|
|
"properties": {},
|
|
"description": req.description or "",
|
|
}
|
|
|
|
tool = CustomTool(
|
|
name=req.name,
|
|
description=req.description,
|
|
schema_json=schema_json,
|
|
endpoint_url=req.endpoint_url or "",
|
|
method=req.method,
|
|
path=req.path,
|
|
headers_json=req.headers,
|
|
auth_type=req.auth_type,
|
|
auth_config=req.auth_config,
|
|
created_by=user_id,
|
|
)
|
|
db.add(tool)
|
|
ToolNodeAgent.register_custom_tool(
|
|
req.name,
|
|
schema_json,
|
|
{
|
|
"endpoint_url": req.endpoint_url or "",
|
|
"method": req.method,
|
|
"path": req.path,
|
|
"headers_json": req.headers,
|
|
"auth_type": req.auth_type,
|
|
"auth_config": req.auth_config,
|
|
"timeout": 30,
|
|
},
|
|
)
|
|
await db.flush()
|
|
return tool
|
|
|
|
|
|
@router.get("/", response_model=list[CustomToolOut])
|
|
async def list_custom_tools(db: AsyncSession = Depends(get_db)):
|
|
"""列出所有处于活跃状态的自定义工具。
|
|
|
|
Args:
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
list[CustomToolOut]: 活跃自定义工具列表。
|
|
"""
|
|
result = await db.execute(
|
|
select(CustomTool).where(CustomTool.is_active == True).order_by(CustomTool.updated_at.desc())
|
|
)
|
|
return result.scalars().all()
|
|
|
|
|
|
@router.get("/{tool_id}", response_model=CustomToolOut)
|
|
async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
|
|
"""获取指定自定义工具的详细信息。
|
|
|
|
Args:
|
|
tool_id: 自定义工具唯一标识 ID。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
CustomToolOut: 自定义工具详细信息。
|
|
|
|
Raises:
|
|
HTTPException: 工具不存在时抛出异常。
|
|
"""
|
|
tool = await db.get(CustomTool, tool_id)
|
|
if not tool:
|
|
raise HTTPException(404, "工具不存在")
|
|
return tool
|
|
|
|
|
|
@router.put("/{tool_id}", response_model=CustomToolOut)
|
|
async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)):
|
|
"""更新自定义工具的配置信息。
|
|
|
|
Args:
|
|
tool_id: 自定义工具唯一标识 ID。
|
|
req: 自定义工具更新请求体。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
CustomToolOut: 更新后的自定义工具响应。
|
|
|
|
Raises:
|
|
HTTPException: 工具不存在时抛出异常。
|
|
"""
|
|
tool = await db.get(CustomTool, tool_id)
|
|
if not tool:
|
|
raise HTTPException(404, "工具不存在")
|
|
if req.name is not None:
|
|
tool.name = req.name
|
|
if req.description is not None:
|
|
tool.description = req.description
|
|
if req.endpoint_url is not None:
|
|
tool.endpoint_url = req.endpoint_url
|
|
if req.method is not None:
|
|
tool.method = req.method
|
|
if req.path is not None:
|
|
tool.path = req.path
|
|
if req.headers is not None:
|
|
tool.headers_json = req.headers
|
|
if req.auth_type is not None:
|
|
tool.auth_type = req.auth_type
|
|
if req.auth_config is not None:
|
|
tool.auth_config = req.auth_config
|
|
if req.tool_schema is not None:
|
|
tool.schema_json = req.tool_schema
|
|
if req.is_active is not None:
|
|
tool.is_active = req.is_active
|
|
await db.flush()
|
|
return tool
|
|
|
|
|
|
@router.delete("/{tool_id}")
|
|
async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
|
|
"""删除(停用)自定义工具。
|
|
|
|
采用软删除方式,将工具标记为非活跃状态而非真正删除。
|
|
|
|
Args:
|
|
tool_id: 自定义工具唯一标识 ID。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
dict: 操作结果响应。
|
|
|
|
Raises:
|
|
HTTPException: 工具不存在时抛出异常。
|
|
"""
|
|
tool = await db.get(CustomTool, tool_id)
|
|
if not tool:
|
|
raise HTTPException(404, "工具不存在")
|
|
tool.is_active = False # 软删除:标记为非活跃
|
|
await db.flush()
|
|
return {"code": 200, "message": "工具已停用"}
|
|
|
|
|
|
@router.post("/{tool_id}/test")
|
|
async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)):
|
|
"""测试执行自定义工具。
|
|
|
|
使用自定义工具的配置信息创建执行器并执行,返回执行结果。
|
|
|
|
Args:
|
|
tool_id: 自定义工具唯一标识 ID。
|
|
params: 测试参数,传递给工具执行的参数体。
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
dict: 包含执行结果的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 工具不存在或执行失败时抛出异常。
|
|
"""
|
|
tool = await db.get(CustomTool, tool_id)
|
|
if not tool:
|
|
raise HTTPException(404, "工具不存在")
|
|
if params is None:
|
|
params = {}
|
|
|
|
executor = CustomToolExecutor({
|
|
"endpoint_url": tool.endpoint_url,
|
|
"method": tool.method,
|
|
"path": tool.path,
|
|
"headers_json": tool.headers_json,
|
|
"auth_type": tool.auth_type,
|
|
"auth_config": tool.auth_config,
|
|
})
|
|
try:
|
|
result = await executor.execute(params)
|
|
return {"code": 200, "data": {"result": result}}
|
|
except Exception as e:
|
|
raise HTTPException(500, f"工具执行失败: {str(e)}")
|
|
|
|
|
|
@router.get("/schemas/all")
|
|
async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)):
|
|
"""获取所有活跃自定义工具的参数 Schema。
|
|
|
|
Args:
|
|
db: 异步数据库会话。
|
|
|
|
Returns:
|
|
dict: 包含所有工具 Schema 的响应数据,格式为 {工具名: schema}。
|
|
"""
|
|
result = await db.execute(
|
|
select(CustomTool).where(CustomTool.is_active == True)
|
|
)
|
|
tools = result.scalars().all()
|
|
schemas = {}
|
|
for t in tools:
|
|
schemas[t.name] = t.schema_json
|
|
return {"code": 200, "data": schemas}
|
|
|