You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

181 lines
6.6 KiB

import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import ModelProvider, ModelInstance
from dependencies import get_current_user
# Module-level logger (not referenced by any handler in the visible code).
logger = logging.getLogger(__name__)
# Router for every endpoint in this module: all paths are mounted under
# /api/model-providers, and the tag "模型供应商" ("model providers") groups them
# in the generated OpenAPI documentation.
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
@router.get("")
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """List every model provider, newest first.

    Resolves ``get_current_user`` as a dependency (presumably enforcing
    authentication). The provider's ``api_key`` and ``extra_config`` are not
    included in the response.

    Returns:
        ``{"code": 200, "data": [...]}`` where each item holds the provider id
        (as a string), name, type, base URL, active flag and the ISO-8601
        creation timestamp (empty string when it is unset).
    """
    newest_first = select(ModelProvider).order_by(ModelProvider.created_at.desc())
    providers = (await db.execute(newest_first)).scalars().all()

    items = []
    for provider in providers:
        created = provider.created_at
        items.append(
            {
                "id": str(provider.id),
                "name": provider.name,
                "provider_type": provider.provider_type,
                "base_url": provider.base_url,
                "is_active": provider.is_active,
                "created_at": created.isoformat() if created else "",
            }
        )
    return {"code": 200, "data": items}
@router.post("")
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Create a model provider from a JSON body.

    Expected keys: ``name`` and ``provider_type`` (required); ``base_url`` and
    ``api_key`` (default ``""``); ``extra_config`` (default ``{}``).

    Raises:
        HTTPException(400): a required key is missing, or a provider with the
            same ``base_url`` already exists.

    Returns:
        ``{"code": 200, "data": {"id": <new provider id as str>}}``.
    """
    # Validate required keys up front: a bare payload["..."] lookup would
    # escape as an unhandled KeyError (HTTP 500) instead of a client error.
    missing = [key for key in ("name", "provider_type") if key not in payload]
    if missing:
        raise HTTPException(400, f"缺少必填字段: {', '.join(missing)}")

    base_url = payload.get("base_url", "")
    existing = await db.execute(
        select(ModelProvider).where(ModelProvider.base_url == base_url)
    )
    if existing.scalars().first():
        raise HTTPException(400, "相同 base_url 的供应商已存在")

    provider = ModelProvider(
        name=payload["name"],
        provider_type=payload["provider_type"],
        base_url=base_url,
        api_key=payload.get("api_key", ""),
        extra_config=payload.get("extra_config", {}),
    )
    db.add(provider)
    # Flush so the primary key is populated, and capture it before commit:
    # when the session expires instances on commit, reading provider.id
    # afterwards needs a lazy refresh, which an AsyncSession cannot do
    # implicitly.
    await db.flush()
    new_id = provider.id
    await db.commit()
    return {"code": 200, "data": {"id": str(new_id)}}
@router.put("/{provider_id}")
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Partially update a provider; only keys present in ``payload`` change.

    Updatable keys: ``name``, ``base_url``, ``api_key``, ``provider_type``,
    ``extra_config``, ``is_active``.

    Raises:
        HTTPException(404): ``provider_id`` is not a valid UUID or no such
            provider exists.
        HTTPException(400): the new ``base_url`` is already used by another
            provider (the same uniqueness rule applied on creation).

    Returns:
        ``{"code": 200, "data": {"id": <provider id as str>}}``.
    """
    # A malformed id can never match a row; report "not found" instead of
    # letting uuid.UUID's ValueError escape as an HTTP 500.
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        raise HTTPException(404, "供应商不存在") from None

    p = await db.get(ModelProvider, pid)
    if not p:
        raise HTTPException(404, "供应商不存在")

    # Enforce base_url uniqueness on update too; otherwise a rename here could
    # produce the duplicate that creation explicitly rejects.
    if "base_url" in payload and payload["base_url"] != p.base_url:
        clash = await db.execute(
            select(ModelProvider).where(
                ModelProvider.base_url == payload["base_url"],
                ModelProvider.id != pid,
            )
        )
        if clash.scalars().first():
            raise HTTPException(400, "相同 base_url 的供应商已存在")

    for field in ("name", "base_url", "api_key", "provider_type", "extra_config", "is_active"):
        if field in payload:
            setattr(p, field, payload[field])
    await db.commit()
    # Use the already-parsed id rather than p.id: after commit the instance
    # may be expired, and refreshing it implicitly is not possible in async.
    return {"code": 200, "data": {"id": str(pid)}}
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Delete a provider by id.

    Raises:
        HTTPException(404): ``provider_id`` is not a valid UUID or no such
            provider exists.

    Returns:
        ``{"code": 200, "message": "已删除"}`` ("deleted").
    """
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        # A malformed id cannot exist; answer 404 rather than an HTTP 500.
        raise HTTPException(404, "供应商不存在") from None

    p = await db.get(ModelProvider, pid)
    if not p:
        raise HTTPException(404, "供应商不存在")
    # NOTE(review): what happens to this provider's ModelInstance rows depends
    # on the FK / relationship cascade configured in models — confirm.
    await db.delete(p)
    await db.commit()
    return {"code": 200, "message": "已删除"}
@router.get("/models/all")
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """List every active model across all providers.

    Results are ordered by ``model_type`` and then ``model_name``; each item is
    a compact summary (id, name, type, display name).

    NOTE(review): only the model's own ``is_active`` flag is checked, so models
    of an inactive provider are still listed — confirm that is intended.
    """
    # `== True` builds a SQL comparison expression; `is True` would not.
    only_active = ModelInstance.is_active == True  # noqa: E712
    stmt = (
        select(ModelInstance)
        .where(only_active)
        .order_by(ModelInstance.model_type, ModelInstance.model_name)
    )
    active_models = (await db.execute(stmt)).scalars().all()

    def _summary(model):
        # Public-facing subset of the model's columns.
        return {
            "id": str(model.id),
            "model_name": model.model_name,
            "model_type": model.model_type,
            "display_name": model.display_name,
        }

    return {"code": 200, "data": list(map(_summary, active_models))}
@router.get("/{provider_id}/models")
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """List all models (active and inactive) registered under one provider.

    Ordered by ``model_type`` then ``model_name``. A well-formed id with no
    matching models (including an unknown provider) yields an empty list.

    Raises:
        HTTPException(404): ``provider_id`` is not a valid UUID.

    Returns:
        ``{"code": 200, "data": [...]}`` with the full model record per item.
    """
    # Malformed ids used to raise ValueError here and surface as HTTP 500.
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        raise HTTPException(404, "供应商不存在") from None

    result = await db.execute(
        select(ModelInstance)
        .where(ModelInstance.provider_id == pid)
        .order_by(ModelInstance.model_type, ModelInstance.model_name)
    )
    return {
        "code": 200,
        "data": [
            {
                "id": str(m.id),
                "provider_id": str(m.provider_id),
                "model_name": m.model_name,
                "model_type": m.model_type,
                "display_name": m.display_name,
                "capabilities": m.capabilities,
                "default_params": m.default_params,
                "is_default": m.is_default,
                "is_active": m.is_active,
            }
            for m in result.scalars().all()
        ],
    }
@router.post("/{provider_id}/models")
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Register a new model under an existing provider.

    Expected keys: ``model_name`` and ``model_type`` (required);
    ``display_name`` (defaults to ``model_name``); ``capabilities`` and
    ``default_params`` (default ``{}``); ``is_default`` (default ``False``).

    Raises:
        HTTPException(404): ``provider_id`` is not a valid UUID or the provider
            does not exist.
        HTTPException(400): a required key is missing, or the provider already
            has a model with the same ``model_name``.

    Returns:
        ``{"code": 200, "data": {"id": <new model id as str>}}``.
    """
    # Parse once; a malformed id cannot name a provider, so answer 404 rather
    # than letting ValueError surface as an HTTP 500.
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        raise HTTPException(404, "供应商不存在") from None
    if not await db.get(ModelProvider, pid):
        raise HTTPException(404, "供应商不存在")

    # Missing required keys are a client error, not a KeyError / HTTP 500.
    missing = [key for key in ("model_name", "model_type") if key not in payload]
    if missing:
        raise HTTPException(400, f"缺少必填字段: {', '.join(missing)}")
    model_name = payload["model_name"]

    existing = await db.execute(
        select(ModelInstance).where(
            ModelInstance.provider_id == pid,
            ModelInstance.model_name == model_name,
        )
    )
    if existing.scalars().first():
        raise HTTPException(400, "相同名称的模型已存在")

    m = ModelInstance(
        provider_id=pid,
        model_name=model_name,
        model_type=payload["model_type"],
        display_name=payload.get("display_name", model_name),
        capabilities=payload.get("capabilities", {}),
        default_params=payload.get("default_params", {}),
        is_default=payload.get("is_default", False),
    )
    db.add(m)
    # Capture the generated id before commit so the response never depends
    # on an implicit post-commit refresh (unsupported on an AsyncSession).
    await db.flush()
    new_id = m.id
    await db.commit()
    return {"code": 200, "data": {"id": str(new_id)}}
@router.put("/{provider_id}/models/{model_id}")
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Partially update a model that belongs to ``provider_id``.

    Only keys present in ``payload`` change. Updatable keys: ``model_name``,
    ``model_type``, ``display_name``, ``capabilities``, ``default_params``,
    ``is_default``, ``is_active``.

    Raises:
        HTTPException(404): either id is not a valid UUID, the model does not
            exist, or it belongs to a different provider.
        HTTPException(400): the new ``model_name`` is already used by another
            model of the same provider (the rule applied on creation).

    Returns:
        ``{"code": 200, "data": {"id": <model id as str>}}``.
    """
    try:
        pid = uuid.UUID(provider_id)
        mid = uuid.UUID(model_id)
    except ValueError:
        # Malformed ids cannot match any model; avoid an HTTP 500.
        raise HTTPException(404, "模型不存在") from None

    m = await db.get(ModelInstance, mid)
    # Compare canonical string forms: the raw URL text may spell the same UUID
    # differently (upper case, no hyphens), which a plain string compare with
    # str(m.provider_id) would wrongly reject.
    if not m or str(m.provider_id) != str(pid):
        raise HTTPException(404, "模型不存在")

    new_name = payload.get("model_name", m.model_name)
    if new_name != m.model_name:
        clash = await db.execute(
            select(ModelInstance).where(
                ModelInstance.provider_id == m.provider_id,
                ModelInstance.model_name == new_name,
                ModelInstance.id != mid,
            )
        )
        if clash.scalars().first():
            raise HTTPException(400, "相同名称的模型已存在")

    for field in (
        "model_name",
        "model_type",
        "display_name",
        "capabilities",
        "default_params",
        "is_default",
        "is_active",
    ):
        if field in payload:
            setattr(m, field, payload[field])
    await db.commit()
    # Respond with the parsed id instead of reading m.id on a possibly
    # expired instance after commit.
    return {"code": 200, "data": {"id": str(mid)}}
@router.delete("/{provider_id}/models/{model_id}")
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Delete a model that belongs to ``provider_id``.

    Raises:
        HTTPException(404): either id is not a valid UUID, the model does not
            exist, or it belongs to a different provider.

    Returns:
        ``{"code": 200, "message": "已删除"}`` ("deleted").
    """
    try:
        pid = uuid.UUID(provider_id)
        mid = uuid.UUID(model_id)
    except ValueError:
        # Malformed ids cannot match any model; avoid an HTTP 500.
        raise HTTPException(404, "模型不存在") from None

    m = await db.get(ModelInstance, mid)
    # Compare canonical forms so equivalent spellings of the provider UUID in
    # the URL are accepted.
    if not m or str(m.provider_id) != str(pid):
        raise HTTPException(404, "模型不存在")
    await db.delete(m)
    await db.commit()
    return {"code": 200, "message": "已删除"}