You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
181 lines
6.6 KiB
181 lines
6.6 KiB
import uuid
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import ModelProvider, ModelInstance
|
|
from dependencies import get_current_user
|
|
|
|
# Module-level logger; no handler in the visible part of this file uses it.
logger = logging.getLogger(__name__)
# Every provider / model-instance endpoint below is registered on this router,
# so all paths live under /api/model-providers. The tag ("model providers")
# groups these endpoints in the generated OpenAPI docs.
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
|
|
|
|
|
|
@router.get("")
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """List every model provider, newest first.

    Requires an authenticated user (via ``get_current_user``). The provider's
    ``api_key`` is not included in the serialized output.

    Returns:
        ``{"code": 200, "data": [...]}`` where each item carries ``id``,
        ``name``, ``provider_type``, ``base_url``, ``is_active`` and an ISO 8601
        ``created_at`` (an empty string when the column is empty).
    """
    query = select(ModelProvider).order_by(ModelProvider.created_at.desc())
    providers = (await db.execute(query)).scalars().all()

    items = []
    for provider in providers:
        created = provider.created_at
        items.append(
            {
                "id": str(provider.id),
                "name": provider.name,
                "provider_type": provider.provider_type,
                "base_url": provider.base_url,
                "is_active": provider.is_active,
                "created_at": created.isoformat() if created else "",
            }
        )
    return {"code": 200, "data": items}
|
|
|
|
|
|
@router.post("")
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Create a model provider from a JSON payload.

    Required payload keys: ``name``, ``provider_type``.
    Optional keys: ``base_url`` (default ""), ``api_key`` (default ""),
    ``extra_config`` (default {}).

    Raises:
        HTTPException(400): a required key is missing, or a provider with the
            same ``base_url`` already exists.

    Returns:
        ``{"code": 200, "data": {"id": <new provider id>}}``.
    """
    # Report missing required fields as a client error rather than letting a
    # KeyError surface as a 500.
    missing = [key for key in ("name", "provider_type") if key not in payload]
    if missing:
        raise HTTPException(400, f"缺少必填字段: {', '.join(missing)}")

    base_url = payload.get("base_url", "")
    # NOTE(review): an empty base_url is matched here too, so only one provider
    # without a base_url can exist — confirm that is intended.
    existing = await db.execute(
        select(ModelProvider).where(ModelProvider.base_url == base_url)
    )
    if existing.scalars().first():
        raise HTTPException(400, "相同 base_url 的供应商已存在")

    p = ModelProvider(
        name=payload["name"],
        provider_type=payload["provider_type"],
        base_url=base_url,
        api_key=payload.get("api_key", ""),
        extra_config=payload.get("extra_config", {}),
    )
    db.add(p)
    # Flush so the primary key is populated, then read it *before* commit():
    # if the session expires attributes on commit, touching p.id afterwards
    # would need an implicit lazy reload, which AsyncSession cannot perform.
    await db.flush()
    new_id = str(p.id)
    await db.commit()
    return {"code": 200, "data": {"id": new_id}}
|
|
|
|
|
|
@router.put("/{provider_id}")
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Partially update a provider; keys absent from ``payload`` are kept.

    Updatable keys: ``name``, ``base_url``, ``api_key``, ``provider_type``,
    ``extra_config``, ``is_active``.

    Raises:
        HTTPException(404): ``provider_id`` is malformed or no such provider.
        HTTPException(400): the new ``base_url`` is already used by another
            provider (the same rule ``create_provider`` enforces).

    Returns:
        ``{"code": 200, "data": {"id": <provider id>}}``.
    """
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        # A malformed id cannot name an existing provider; answer 404 instead
        # of letting ValueError surface as a 500.
        raise HTTPException(404, "供应商不存在") from None
    p = await db.get(ModelProvider, pid)
    if not p:
        raise HTTPException(404, "供应商不存在")

    new_base_url = payload.get("base_url", p.base_url)
    if new_base_url != p.base_url:
        # Keep base_url unique across providers, as create_provider does.
        clash = await db.execute(
            select(ModelProvider).where(
                ModelProvider.base_url == new_base_url,
                ModelProvider.id != p.id,
            )
        )
        if clash.scalars().first():
            raise HTTPException(400, "相同 base_url 的供应商已存在")

    p.name = payload.get("name", p.name)
    p.base_url = new_base_url
    p.api_key = payload.get("api_key", p.api_key)
    p.provider_type = payload.get("provider_type", p.provider_type)
    p.extra_config = payload.get("extra_config", p.extra_config)
    p.is_active = payload.get("is_active", p.is_active)

    # Capture the id before commit() in case the session expires attributes.
    result_id = str(p.id)
    await db.commit()
    return {"code": 200, "data": {"id": result_id}}
|
|
|
|
|
|
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Delete a provider by id.

    Raises:
        HTTPException(404): ``provider_id`` is malformed or no such provider.

    Returns:
        ``{"code": 200, "message": "已删除"}`` (message text: "deleted").
    """
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        # A malformed id cannot name an existing provider; answer 404 instead
        # of letting ValueError surface as a 500.
        raise HTTPException(404, "供应商不存在") from None
    p = await db.get(ModelProvider, pid)
    if not p:
        raise HTTPException(404, "供应商不存在")
    # NOTE(review): whether this provider's ModelInstance rows are deleted,
    # orphaned, or block the delete depends on the FK / relationship cascade
    # configured in `models`, which is not visible here — confirm.
    await db.delete(p)
    await db.commit()
    return {"code": 200, "message": "已删除"}
|
|
|
|
|
|
@router.get("/models/all")
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Return a compact summary of every active model instance.

    Instances from all providers are included, ordered by ``model_type`` and
    then ``model_name``. Each entry carries only ``id``, ``model_name``,
    ``model_type`` and ``display_name``.

    NOTE(review): the provider's own ``is_active`` flag is not consulted, so
    active models of an inactive provider are still listed — confirm intended.

    Returns:
        ``{"code": 200, "data": [...]}``.
    """
    active_only = ModelInstance.is_active == True  # noqa: E712 - builds a SQL expression
    stmt = (
        select(ModelInstance)
        .where(active_only)
        .order_by(ModelInstance.model_type, ModelInstance.model_name)
    )
    instances = (await db.execute(stmt)).scalars().all()

    def _summarize(inst):
        return {
            "id": str(inst.id),
            "model_name": inst.model_name,
            "model_type": inst.model_type,
            "display_name": inst.display_name,
        }

    return {"code": 200, "data": list(map(_summarize, instances))}
|
|
|
|
|
|
@router.get("/{provider_id}/models")
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """List all model instances (active and inactive) of one provider.

    Args:
        provider_id: the provider's UUID, taken from the path.

    Raises:
        HTTPException(404): ``provider_id`` is not a valid UUID.

    Returns:
        ``{"code": 200, "data": [...]}``, ordered by ``model_type`` then
        ``model_name``. A well-formed but unknown provider id yields an empty
        list (provider existence is not checked).
    """
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        # Reject malformed ids as "not found" instead of a 500 from ValueError.
        raise HTTPException(404, "供应商不存在") from None

    result = await db.execute(
        select(ModelInstance)
        .where(ModelInstance.provider_id == pid)
        .order_by(ModelInstance.model_type, ModelInstance.model_name)
    )
    return {
        "code": 200,
        "data": [
            {
                "id": str(m.id),
                "provider_id": str(m.provider_id),
                "model_name": m.model_name,
                "model_type": m.model_type,
                "display_name": m.display_name,
                "capabilities": m.capabilities,
                "default_params": m.default_params,
                "is_default": m.is_default,
                "is_active": m.is_active,
            }
            for m in result.scalars().all()
        ],
    }
|
|
|
|
|
|
@router.post("/{provider_id}/models")
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Create a model instance under an existing provider.

    Required payload keys: ``model_name``, ``model_type``.
    Optional keys: ``display_name`` (defaults to ``model_name``),
    ``capabilities`` (default {}), ``default_params`` (default {}),
    ``is_default`` (default False).

    Raises:
        HTTPException(404): ``provider_id`` is malformed or no such provider.
        HTTPException(400): a required key is missing, or the provider already
            has a model with the same ``model_name``.

    Returns:
        ``{"code": 200, "data": {"id": <new model id>}}``.
    """
    try:
        pid = uuid.UUID(provider_id)
    except ValueError:
        # A malformed id cannot name an existing provider; answer 404 instead
        # of letting ValueError surface as a 500.
        raise HTTPException(404, "供应商不存在") from None
    if not await db.get(ModelProvider, pid):
        raise HTTPException(404, "供应商不存在")

    # Report missing required fields as a client error, not a KeyError/500.
    missing = [key for key in ("model_name", "model_type") if key not in payload]
    if missing:
        raise HTTPException(400, f"缺少必填字段: {', '.join(missing)}")

    model_name = payload["model_name"]
    existing = await db.execute(
        select(ModelInstance).where(
            ModelInstance.provider_id == pid,
            ModelInstance.model_name == model_name,
        )
    )
    if existing.scalars().first():
        raise HTTPException(400, "相同名称的模型已存在")

    m = ModelInstance(
        provider_id=pid,
        model_name=model_name,
        model_type=payload["model_type"],
        display_name=payload.get("display_name", model_name),
        capabilities=payload.get("capabilities", {}),
        default_params=payload.get("default_params", {}),
        is_default=payload.get("is_default", False),
    )
    db.add(m)
    # Flush so the primary key is populated, then read it before commit():
    # an expired attribute cannot be lazily reloaded through AsyncSession.
    await db.flush()
    new_id = str(m.id)
    await db.commit()
    return {"code": 200, "data": {"id": new_id}}
|
|
|
|
|
|
@router.put("/{provider_id}/models/{model_id}")
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Partially update a model instance that belongs to ``provider_id``.

    Keys absent from ``payload`` keep their current values. Updatable keys:
    ``model_name``, ``model_type``, ``display_name``, ``capabilities``,
    ``default_params``, ``is_default``, ``is_active``.

    Raises:
        HTTPException(404): either id is malformed, the model does not exist,
            or it belongs to a different provider.
        HTTPException(400): the new ``model_name`` is already used by another
            model of the same provider (the rule ``create_model`` enforces).

    Returns:
        ``{"code": 200, "data": {"id": <model id>}}``.
    """
    try:
        pid = uuid.UUID(provider_id)
        mid = uuid.UUID(model_id)
    except ValueError:
        # Malformed ids cannot name an existing model; 404 instead of a 500.
        raise HTTPException(404, "模型不存在") from None

    m = await db.get(ModelInstance, mid)
    # Compare canonical UUID strings so a valid but non-canonical path id
    # (e.g. upper-case hex) still matches the stored provider.
    if not m or str(m.provider_id) != str(pid):
        raise HTTPException(404, "模型不存在")

    new_name = payload.get("model_name", m.model_name)
    if new_name != m.model_name:
        # Keep model_name unique per provider, as create_model does.
        clash = await db.execute(
            select(ModelInstance).where(
                ModelInstance.provider_id == m.provider_id,
                ModelInstance.model_name == new_name,
                ModelInstance.id != m.id,
            )
        )
        if clash.scalars().first():
            raise HTTPException(400, "相同名称的模型已存在")

    m.model_name = new_name
    m.model_type = payload.get("model_type", m.model_type)
    m.display_name = payload.get("display_name", m.display_name)
    m.capabilities = payload.get("capabilities", m.capabilities)
    m.default_params = payload.get("default_params", m.default_params)
    m.is_default = payload.get("is_default", m.is_default)
    m.is_active = payload.get("is_active", m.is_active)

    # Capture the id before commit() in case the session expires attributes.
    result_id = str(m.id)
    await db.commit()
    return {"code": 200, "data": {"id": result_id}}
|
|
|
|
|
|
@router.delete("/{provider_id}/models/{model_id}")
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
    """Delete a model instance that belongs to ``provider_id``.

    Raises:
        HTTPException(404): either id is malformed, the model does not exist,
            or it belongs to a different provider.

    Returns:
        ``{"code": 200, "message": "已删除"}`` (message text: "deleted").
    """
    try:
        pid = uuid.UUID(provider_id)
        mid = uuid.UUID(model_id)
    except ValueError:
        # Malformed ids cannot name an existing model; 404 instead of a 500.
        raise HTTPException(404, "模型不存在") from None

    m = await db.get(ModelInstance, mid)
    # Compare canonical UUID strings so a valid but non-canonical path id
    # (e.g. upper-case hex) still matches the stored provider.
    if not m or str(m.provider_id) != str(pid):
        raise HTTPException(404, "模型不存在")
    await db.delete(m)
    await db.commit()
    return {"code": 200, "message": "已删除"}
|