You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

297 lines
9.9 KiB

"""模型供应商模块路由。
提供模型供应商和模型实例的 CRUD 管理功能。
支持多供应商接入和模型能力的统一管理。
"""
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import ModelProvider, ModelInstance
from dependencies import get_current_user
logger = logging.getLogger(__name__) # 当前模块的日志记录器
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
@router.get("")
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出所有已注册的模型供应商。
Args:
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含模型供应商列表的响应数据。
"""
result = await db.execute(
select(ModelProvider).order_by(ModelProvider.created_at.desc())
)
return {
"code": 200,
"data": [
{
"id": str(p.id),
"name": p.name,
"provider_type": p.provider_type,
"base_url": p.base_url,
"is_active": p.is_active,
"created_at": p.created_at.isoformat() if p.created_at else "",
}
for p in result.scalars().all()
],
}
@router.post("")
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""注册新的模型供应商。
Args:
payload: 请求体,包含 name、provider_type、base_url、api_key、extra_config 字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含新供应商 ID 的响应数据。
Raises:
HTTPException: 相同 base_url 的供应商已存在时抛出异常。
"""
existing = await db.execute(
select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", ""))
)
if existing.scalars().first():
raise HTTPException(400, "相同 base_url 的供应商已存在")
p = ModelProvider(
name=payload["name"],
provider_type=payload["provider_type"],
base_url=payload.get("base_url", ""),
api_key=payload.get("api_key", ""),
extra_config=payload.get("extra_config", {}),
)
db.add(p)
await db.commit()
return {"code": 200, "data": {"id": str(p.id)}}
@router.put("/{provider_id}")
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""更新模型供应商的配置信息。
Args:
provider_id: 模型供应商唯一标识 ID。
payload: 请求体,包含要更新的供应商字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含更新后供应商 ID 的响应数据。
Raises:
HTTPException: 供应商不存在时抛出异常。
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
p.name = payload.get("name", p.name)
p.base_url = payload.get("base_url", p.base_url)
p.api_key = payload.get("api_key", p.api_key)
p.provider_type = payload.get("provider_type", p.provider_type)
p.extra_config = payload.get("extra_config", p.extra_config)
p.is_active = payload.get("is_active", p.is_active)
await db.commit()
return {"code": 200, "data": {"id": str(p.id)}}
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""删除指定的模型供应商。
Args:
provider_id: 模型供应商唯一标识 ID。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 操作结果响应。
Raises:
HTTPException: 供应商不存在时抛出异常。
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
await db.delete(p)
await db.commit()
return {"code": 200, "message": "已删除"}
@router.get("/models/all")
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出所有处于活跃状态的模型实例。
Args:
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含所有活跃模型实例列表的响应数据。
"""
result = await db.execute(
select(ModelInstance)
.where(ModelInstance.is_active == True)
.order_by(ModelInstance.model_type, ModelInstance.model_name)
)
return {
"code": 200,
"data": [
{
"id": str(m.id),
"model_name": m.model_name,
"model_type": m.model_type,
"display_name": m.display_name,
}
for m in result.scalars().all()
],
}
@router.get("/{provider_id}/models")
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出指定供应商下的所有模型实例。
Args:
provider_id: 模型供应商唯一标识 ID。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含模型实例列表的响应数据。
"""
result = await db.execute(
select(ModelInstance)
.where(ModelInstance.provider_id == uuid.UUID(provider_id))
.order_by(ModelInstance.model_type, ModelInstance.model_name)
)
return {
"code": 200,
"data": [
{
"id": str(m.id),
"provider_id": str(m.provider_id),
"model_name": m.model_name,
"model_type": m.model_type,
"display_name": m.display_name,
"capabilities": m.capabilities,
"default_params": m.default_params,
"is_default": m.is_default,
"is_active": m.is_active,
}
for m in result.scalars().all()
],
}
@router.post("/{provider_id}/models")
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""在指定供应商下添加新的模型实例。
Args:
provider_id: 模型供应商唯一标识 ID。
payload: 请求体,包含 model_name、model_type、display_name、capabilities、default_params、is_default 字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含新模型 ID 的响应数据。
Raises:
HTTPException: 供应商不存在或相同名称的模型已存在时抛出异常。
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
existing = await db.execute(
select(ModelInstance).where(
ModelInstance.provider_id == uuid.UUID(provider_id),
ModelInstance.model_name == payload["model_name"],
)
)
if existing.scalars().first():
raise HTTPException(400, "相同名称的模型已存在")
m = ModelInstance(
provider_id=uuid.UUID(provider_id),
model_name=payload["model_name"],
model_type=payload["model_type"],
display_name=payload.get("display_name", payload["model_name"]),
capabilities=payload.get("capabilities", {}),
default_params=payload.get("default_params", {}),
is_default=payload.get("is_default", False),
)
db.add(m)
await db.commit()
return {"code": 200, "data": {"id": str(m.id)}}
@router.put("/{provider_id}/models/{model_id}")
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""更新模型实例的配置信息。
Args:
provider_id: 模型供应商唯一标识 ID。
model_id: 模型实例唯一标识 ID。
payload: 请求体,包含要更新的模型字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含更新后模型 ID 的响应数据。
Raises:
HTTPException: 模型不存在时抛出异常。
"""
m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在")
m.model_name = payload.get("model_name", m.model_name)
m.model_type = payload.get("model_type", m.model_type)
m.display_name = payload.get("display_name", m.display_name)
m.capabilities = payload.get("capabilities", m.capabilities)
m.default_params = payload.get("default_params", m.default_params)
m.is_default = payload.get("is_default", m.is_default)
m.is_active = payload.get("is_active", m.is_active)
await db.commit()
return {"code": 200, "data": {"id": str(m.id)}}
@router.delete("/{provider_id}/models/{model_id}")
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""删除指定的模型实例。
Args:
provider_id: 模型供应商唯一标识 ID。
model_id: 模型实例唯一标识 ID。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 操作结果响应。
Raises:
HTTPException: 模型不存在时抛出异常。
"""
m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在")
await db.delete(m)
await db.commit()
return {"code": 200, "message": "已删除"}