You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
8.1 KiB
249 lines
8.1 KiB
import uuid
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import CustomTool
|
|
from schemas import CustomToolCreate, CustomToolUpdate, CustomToolOut, OpenAPIImportRequest
|
|
from modules.custom_tool.parser import OpenAPIParser
|
|
from modules.custom_tool.executor import CustomToolExecutor
|
|
from modules.flow_engine.engine import ToolNodeAgent
|
|
from dependencies import get_current_user
|
|
import logging
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Every endpoint defined below is mounted under /api/custom-tools and grouped
# under the "custom_tools" tag.
router = APIRouter(tags=["custom_tools"], prefix="/api/custom-tools")
|
|
|
|
|
|
@router.post("/import-openapi")
async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)):
    """Import every operation of a remote OpenAPI document as a custom tool.

    The document at ``req.openapi_url`` is fetched and parsed. Each parsed
    operation whose name is not already stored becomes a ``CustomTool`` row,
    and the new tools are then registered with ``ToolNodeAgent``. Operations
    whose name already exists (or repeats inside the same document) are
    skipped.

    Returns:
        ``{"code": 200, "message": ..., "data": {"tools": [names created]}}``

    Raises:
        HTTPException(400): the document cannot be fetched, is not JSON,
            yields no tools, or no base URL can be determined.
    """
    user_ctx = request.state.user

    # NOTE(review): req.openapi_url is user-supplied and fetched server-side,
    # which is an SSRF vector unless the URL is validated/allow-listed
    # upstream — confirm.
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.get(req.openapi_url)
            resp.raise_for_status()
            spec = resp.json()
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # InvalidURL is not an HTTPError subclass; without it a malformed URL
        # surfaced as a 500 instead of a client error.
        raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") from e
    except ValueError as e:
        # resp.json() raises json.JSONDecodeError, a ValueError subclass.
        raise HTTPException(400, "OpenAPI 文档不是有效的 JSON 格式") from e

    parser = OpenAPIParser(spec)
    tools = parser.parse_tools()
    if not tools:
        raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具")

    base_url = req.base_url_override or parser.base_url
    if not base_url:
        raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override")

    created_by = uuid.UUID(user_ctx["id"])

    # One query for all candidate names instead of one SELECT per operation.
    existing = await db.execute(
        select(CustomTool.name).where(CustomTool.name.in_([t["name"] for t in tools]))
    )
    taken = set(existing.scalars().all())

    new_tools = []
    for t in tools:
        if t["name"] in taken:
            continue
        # Reserve the name so an operation name repeated inside the same
        # document cannot produce a second row.
        taken.add(t["name"])
        db.add(
            CustomTool(
                name=t["name"],
                description=t["description"],
                schema_json=t["parameters"],
                endpoint_url=base_url,
                method=t["method"],
                path=t["path"],
                created_by=created_by,
            )
        )
        new_tools.append(t)

    # Flush before touching the in-process registry so a failed insert does
    # not leave tools registered that were never persisted.
    await db.flush()

    for t in new_tools:
        ToolNodeAgent.register_custom_tool(
            t["name"],
            t["parameters"],
            {
                "endpoint_url": base_url,
                "method": t["method"],
                "path": t["path"],
                "headers_json": {},
                "auth_type": "none",
                "auth_config": {},
                "timeout": 30,
            },
        )

    created = [t["name"] for t in new_tools]
    return {"code": 200, "message": f"成功导入 {len(created)} 个工具", "data": {"tools": created}}
|
|
|
|
|
|
@router.post("/", response_model=CustomToolOut)
async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)):
    """Create one custom tool, from an OpenAPI document or from explicit fields.

    When ``req.openapi_url`` is set, the document is fetched and parsed.
    The operation named ``req.name`` is persisted; if no name is given, the
    first operation whose name is not already stored is used instead. Its
    base URL comes from the document, falling back to ``req.endpoint_url``.
    Otherwise the tool is built from the request's
    endpoint/method/path/headers/auth fields.

    In both cases the new tool is registered with ``ToolNodeAgent`` only after
    the insert has been flushed successfully.

    Raises:
        HTTPException(400): the document cannot be fetched or parsed, no
            matching operation exists, the matching operation(s) already
            exist, or no base URL can be determined.
    """
    user_ctx = request.state.user
    user_id = uuid.UUID(user_ctx["id"])

    if req.openapi_url:
        # NOTE(review): server-side fetch of a user-supplied URL (SSRF risk);
        # confirm the URL is validated upstream.
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                resp = await client.get(req.openapi_url)
                resp.raise_for_status()
                spec = resp.json()
        except Exception as e:
            raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") from e

        parser = OpenAPIParser(spec)
        tools = parser.parse_tools()

        # With an explicit name only that operation is eligible; without one,
        # every parsed operation is.
        candidates = [t for t in tools if not req.name or t["name"] == req.name]
        if not candidates:
            raise HTTPException(400, "未找到匹配的工具")

        existing = await db.execute(
            select(CustomTool.name).where(CustomTool.name.in_([t["name"] for t in candidates]))
        )
        taken = set(existing.scalars().all())
        picked = next((t for t in candidates if t["name"] not in taken), None)
        if picked is None:
            # Every eligible operation is already stored; report that rather
            # than the misleading "no matching tool".
            raise HTTPException(400, f"工具已存在: {req.name}" if req.name else "OpenAPI 文档中的工具均已存在")

        base_url = parser.base_url or req.endpoint_url
        if not base_url:
            raise HTTPException(400, "未能确定 API 基础 URL,请提供 endpoint_url")

        tool = CustomTool(
            name=picked["name"],
            description=picked["description"],
            schema_json=picked["parameters"],
            endpoint_url=base_url,
            method=picked["method"],
            path=picked["path"],
            created_by=user_id,
        )
        db.add(tool)
        await db.flush()

        # Register like import_openapi does; previously this path persisted
        # the tool without making it available to the flow engine.
        ToolNodeAgent.register_custom_tool(
            picked["name"],
            picked["parameters"],
            {
                "endpoint_url": base_url,
                "method": picked["method"],
                "path": picked["path"],
                "headers_json": {},
                "auth_type": "none",
                "auth_config": {},
                "timeout": 30,
            },
        )
        return tool

    schema_json = req.schema_json or {}
    if not schema_json and req.endpoint_url:
        # Minimal parameter-less schema so an endpoint-only tool still has a
        # valid object schema.
        schema_json = {
            "type": "object",
            "properties": {},
            "description": req.description or "",
        }

    endpoint_url = req.endpoint_url or ""
    tool = CustomTool(
        name=req.name,
        description=req.description,
        schema_json=schema_json,
        endpoint_url=endpoint_url,
        method=req.method,
        path=req.path,
        headers_json=req.headers,
        auth_type=req.auth_type,
        auth_config=req.auth_config,
        created_by=user_id,
    )
    db.add(tool)
    # Flush first so a failed insert does not leave a registered-but-unsaved tool.
    await db.flush()

    ToolNodeAgent.register_custom_tool(
        req.name,
        schema_json,
        {
            "endpoint_url": endpoint_url,
            "method": req.method,
            "path": req.path,
            "headers_json": req.headers,
            "auth_type": req.auth_type,
            "auth_config": req.auth_config,
            "timeout": 30,
        },
    )
    return tool
|
|
|
|
|
|
@router.get("/", response_model=list[CustomToolOut])
async def list_custom_tools(db: AsyncSession = Depends(get_db)):
    """Return every active tool, most recently updated first."""
    newest_first = CustomTool.updated_at.desc()
    stmt = (
        select(CustomTool)
        .where(CustomTool.is_active == True)  # noqa: E712 - SQLAlchemy needs ==, not `is`
        .order_by(newest_first)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/{tool_id}", response_model=CustomToolOut)
async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
    """Fetch a single tool by primary key.

    No ``is_active`` filter is applied, so deactivated tools are returned too.

    Raises:
        HTTPException(404): no tool has this id.
    """
    found = await db.get(CustomTool, tool_id)
    if not found:
        raise HTTPException(404, "工具不存在")
    return found
|
|
|
|
|
|
@router.put("/{tool_id}", response_model=CustomToolOut)
async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)):
    """Partially update a tool; request fields left as ``None`` are untouched.

    After the change is flushed, an active tool is re-registered with
    ``ToolNodeAgent``. The runtime then uses the new endpoint/schema/auth
    instead of the configuration captured when the tool was created. A tool
    reactivated through ``is_active`` is registered again as well.

    Raises:
        HTTPException(404): no tool has this id.
    """
    tool = await db.get(CustomTool, tool_id)
    if not tool:
        raise HTTPException(404, "工具不存在")

    # (request attribute, model column); only ``headers`` is stored under a
    # different column name.
    field_to_column = (
        ("name", "name"),
        ("description", "description"),
        ("endpoint_url", "endpoint_url"),
        ("method", "method"),
        ("path", "path"),
        ("headers", "headers_json"),
        ("auth_type", "auth_type"),
        ("auth_config", "auth_config"),
        ("schema_json", "schema_json"),
        ("is_active", "is_active"),
    )
    for field, column in field_to_column:
        value = getattr(req, field)
        if value is not None:
            setattr(tool, column, value)

    await db.flush()

    if tool.is_active:
        # NOTE(review): no unregister API is visible from this module, so a
        # rename or a deactivation leaves the previous registration in place
        # — confirm whether ToolNodeAgent offers one.
        ToolNodeAgent.register_custom_tool(
            tool.name,
            tool.schema_json,
            {
                "endpoint_url": tool.endpoint_url,
                "method": tool.method,
                "path": tool.path,
                "headers_json": tool.headers_json,
                "auth_type": tool.auth_type,
                "auth_config": tool.auth_config,
                "timeout": 30,
            },
        )
    return tool
|
|
|
|
|
|
@router.delete("/{tool_id}")
async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
    """Soft-delete a tool: the row is kept and only ``is_active`` is cleared.

    NOTE(review): the ``ToolNodeAgent`` registration is not touched here, so
    the runtime may still hold this tool — confirm that is intended.

    Raises:
        HTTPException(404): no tool has this id.
    """
    target = await db.get(CustomTool, tool_id)
    if not target:
        raise HTTPException(404, "工具不存在")

    target.is_active = False
    await db.flush()
    return {"code": 200, "message": "工具已停用"}
|
|
|
|
|
|
@router.post("/{tool_id}/test")
async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)):
    """Run a stored tool once with ``params`` and return its raw result.

    The executor receives the same configuration keys, including the 30 s
    timeout, that the create/import paths pass to ``ToolNodeAgent``. A test
    run therefore mirrors what the flow engine is configured with.

    Returns:
        ``{"code": 200, "data": {"result": <executor result>}}``

    Raises:
        HTTPException(404): no tool has this id.
        HTTPException(500): the executor raised; the error is logged with its
            traceback before being reported.
    """
    tool = await db.get(CustomTool, tool_id)
    if not tool:
        raise HTTPException(404, "工具不存在")
    if params is None:
        params = {}

    executor = CustomToolExecutor({
        "endpoint_url": tool.endpoint_url,
        "method": tool.method,
        "path": tool.path,
        "headers_json": tool.headers_json,
        "auth_type": tool.auth_type,
        "auth_config": tool.auth_config,
        # Same timeout every registration path in this module uses.
        "timeout": 30,
    })
    try:
        result = await executor.execute(params)
    except Exception as e:
        # Previously the traceback was discarded; keep it in the logs.
        logger.exception("Test run of custom tool %s failed", tool_id)
        raise HTTPException(500, f"工具执行失败: {str(e)}") from e
    return {"code": 200, "data": {"result": result}}
|
|
|
|
|
|
@router.get("/schemas/all")
async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)):
    """Map each active tool's name to its stored JSON schema.

    If several active tools share a name, the row the query returns later
    overwrites the earlier one (the query has no ORDER BY).
    """
    active_stmt = select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712
    active_tools = (await db.execute(active_stmt)).scalars().all()
    by_name = {row.name: row.schema_json for row in active_tools}
    return {"code": 200, "data": by_name}
|