You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

432 lines
15 KiB

"""模型供应商模块路由。
提供模型供应商和模型实例的 CRUD 管理功能。
支持多供应商接入和模型能力的统一管理。
"""
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import ModelProvider, ModelInstance
from dependencies import get_current_user
logger = logging.getLogger(__name__) # 当前模块的日志记录器
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
@router.get("")
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出所有已注册的模型供应商。
Args:
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含模型供应商列表的响应数据。
"""
result = await db.execute(
select(ModelProvider).order_by(ModelProvider.created_at.desc())
)
return {
"code": 200,
"data": [
{
"id": str(p.id),
"name": p.name,
"provider_type": p.provider_type,
"base_url": p.base_url,
"is_active": p.is_active,
"created_at": p.created_at.isoformat() if p.created_at else "",
}
for p in result.scalars().all()
],
}
@router.post("")
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""注册新的模型供应商。
Args:
payload: 请求体,包含 name、provider_type、base_url、api_key、extra_config 字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含新供应商 ID 的响应数据。
Raises:
HTTPException: 相同 base_url 的供应商已存在时抛出异常。
"""
existing = await db.execute(
select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", ""))
)
if existing.scalars().first():
raise HTTPException(400, "相同 base_url 的供应商已存在")
p = ModelProvider(
name=payload["name"],
provider_type=payload["provider_type"],
base_url=payload.get("base_url", ""),
api_key=payload.get("api_key", ""),
extra_config=payload.get("extra_config", {}),
)
db.add(p)
await db.commit()
return {"code": 200, "data": {"id": str(p.id)}}
@router.put("/{provider_id}")
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""更新模型供应商的配置信息。
Args:
provider_id: 模型供应商唯一标识 ID。
payload: 请求体,包含要更新的供应商字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含更新后供应商 ID 的响应数据。
Raises:
HTTPException: 供应商不存在时抛出异常。
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
p.name = payload.get("name", p.name)
p.base_url = payload.get("base_url", p.base_url)
p.api_key = payload.get("api_key", p.api_key)
p.provider_type = payload.get("provider_type", p.provider_type)
p.extra_config = payload.get("extra_config", p.extra_config)
p.is_active = payload.get("is_active", p.is_active)
await db.commit()
return {"code": 200, "data": {"id": str(p.id)}}
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""删除指定的模型供应商。
Args:
provider_id: 模型供应商唯一标识 ID。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 操作结果响应。
Raises:
HTTPException: 供应商不存在时抛出异常。
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
await db.delete(p)
await db.commit()
return {"code": 200, "message": "已删除"}
@router.get("/models/all")
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出所有处于活跃状态的模型实例。
Args:
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含所有活跃模型实例列表的响应数据。
"""
result = await db.execute(
select(ModelInstance)
.where(ModelInstance.is_active == True)
.order_by(ModelInstance.model_type, ModelInstance.model_name)
)
return {
"code": 200,
"data": [
{
"id": str(m.id),
"model_name": m.model_name,
"model_type": m.model_type,
"display_name": m.display_name,
}
for m in result.scalars().all()
],
}
@router.get("/{provider_id}/models")
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出指定供应商下的所有模型实例。
Args:
provider_id: 模型供应商唯一标识 ID。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含模型实例列表的响应数据。
"""
result = await db.execute(
select(ModelInstance)
.where(ModelInstance.provider_id == uuid.UUID(provider_id))
.order_by(ModelInstance.model_type, ModelInstance.model_name)
)
return {
"code": 200,
"data": [
{
"id": str(m.id),
"provider_id": str(m.provider_id),
"model_name": m.model_name,
"model_type": m.model_type,
"display_name": m.display_name,
"capabilities": m.capabilities,
"default_params": m.default_params,
"is_default": m.is_default,
"is_active": m.is_active,
}
for m in result.scalars().all()
],
}
@router.post("/{provider_id}/models")
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""在指定供应商下添加新的模型实例。
Args:
provider_id: 模型供应商唯一标识 ID。
payload: 请求体,包含 model_name、model_type、display_name、capabilities、default_params、is_default 字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含新模型 ID 的响应数据。
Raises:
HTTPException: 供应商不存在或相同名称的模型已存在时抛出异常。
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p:
raise HTTPException(404, "供应商不存在")
existing = await db.execute(
select(ModelInstance).where(
ModelInstance.provider_id == uuid.UUID(provider_id),
ModelInstance.model_name == payload["model_name"],
)
)
if existing.scalars().first():
raise HTTPException(400, "相同名称的模型已存在")
m = ModelInstance(
provider_id=uuid.UUID(provider_id),
model_name=payload["model_name"],
model_type=payload["model_type"],
display_name=payload.get("display_name", payload["model_name"]),
capabilities=payload.get("capabilities", {}),
default_params=payload.get("default_params", {}),
is_default=payload.get("is_default", False),
)
db.add(m)
await db.commit()
return {"code": 200, "data": {"id": str(m.id)}}
@router.put("/{provider_id}/models/{model_id}")
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""更新模型实例的配置信息。
Args:
provider_id: 模型供应商唯一标识 ID。
model_id: 模型实例唯一标识 ID。
payload: 请求体,包含要更新的模型字段。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 包含更新后模型 ID 的响应数据。
Raises:
HTTPException: 模型不存在时抛出异常。
"""
m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在")
m.model_name = payload.get("model_name", m.model_name)
m.model_type = payload.get("model_type", m.model_type)
m.display_name = payload.get("display_name", m.display_name)
m.capabilities = payload.get("capabilities", m.capabilities)
m.default_params = payload.get("default_params", m.default_params)
m.is_default = payload.get("is_default", m.is_default)
m.is_active = payload.get("is_active", m.is_active)
await db.commit()
return {"code": 200, "data": {"id": str(m.id)}}
@router.delete("/{provider_id}/models/{model_id}")
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""删除指定的模型实例。
Args:
provider_id: 模型供应商唯一标识 ID。
model_id: 模型实例唯一标识 ID。
db: 异步数据库会话。
user: 当前登录用户信息。
Returns:
dict: 操作结果响应。
Raises:
HTTPException: 模型不存在时抛出异常。
"""
m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在")
await db.delete(m)
await db.commit()
return {"code": 200, "message": "已删除"}
@router.get("/default-llm")
async def get_default_llm(db: AsyncSession = Depends(get_db)):
"""获取系统默认 LLM 模型配置。
从 model_instances 表中查找 is_default=True 且 model_type='llm' 的模型实例,
返回包含供应商连接信息的完整配置,供系统中各模块统一使用。
Returns:
dict: 默认 LLM 配置(model_name, api_key, base_url, default_params 等),
若未配置则返回 null。
"""
result = await db.execute(
select(ModelInstance, ModelProvider)
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
.where(ModelInstance.model_type == 'llm')
.where(ModelInstance.is_default == True)
.where(ModelInstance.is_active == True)
.limit(1)
)
row = result.first()
if not row:
# 回退:取第一个启用的 LLM
result2 = await db.execute(
select(ModelInstance, ModelProvider)
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
.where(ModelInstance.model_type == 'llm')
.where(ModelInstance.is_active == True)
.order_by(ModelInstance.created_at)
.limit(1)
)
row = result2.first()
if not row:
return {"code": 200, "data": None}
instance, provider = row
return {
"code": 200,
"data": {
"id": str(instance.id),
"model_name": instance.model_name,
"display_name": instance.display_name or instance.model_name,
"model_type": instance.model_type,
"api_key": provider.api_key or "",
"base_url": provider.base_url or settings.LLM_API_BASE,
"default_params": instance.default_params or {},
"capabilities": instance.capabilities or {},
"is_default": instance.is_default,
"is_active": instance.is_active,
"provider_id": str(instance.provider_id),
"provider_name": provider.name,
},
}
async def resolve_model_config(db: AsyncSession, model_instance_id: str = None) -> dict:
"""根据模型实例 ID 解析完整的模型调用配置。
这是系统的核心模型解析函数。各模块(Agent工厂、流程引擎、记忆管理器等)
应通过此函数获取模型配置,实现统一的模型管理。
解析优先级:
1. 如果提供了 model_instance_id → 从数据库读取该模型的完整配置
2. 否则 → 获取默认 LLM 配置(is_default 的 LLM 实例)
3. 都没有 → 回退到环境变量 settings.LLM_*
Args:
db: 异步数据库会话。
model_instance_id: 可选的模型实例 UUID 字符串。
Returns:
dict: 包含 model_name, api_key, base_url, default_params 等的配置字典。
"""
if model_instance_id:
try:
uid = uuid.UUID(model_instance_id)
result = await db.execute(
select(ModelInstance, ModelProvider)
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
.where(ModelInstance.id == uid)
.limit(1)
)
row = result.first()
if row:
instance, provider = row
return {
"model_name": instance.model_name,
"api_key": provider.api_key or settings.LLM_API_KEY,
"base_url": provider.base_url or settings.LLM_API_BASE,
"default_params": instance.default_params or {},
"capabilities": instance.capabilities or {},
}
except (ValueError, Exception):
pass
# 无指定实例或未找到:使用默认 LLM
result = await db.execute(
select(ModelInstance, ModelProvider)
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
.where(ModelInstance.model_type == 'llm')
.where(ModelInstance.is_default == True)
.where(ModelInstance.is_active == True)
.limit(1)
)
row = result.first()
if not row:
result = await db.execute(
select(ModelInstance, ModelProvider)
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
.where(ModelInstance.model_type == 'llm')
.where(ModelInstance.is_active == True)
.order_by(ModelInstance.created_at)
.limit(1)
)
row = result.first()
if row:
instance, provider = row
return {
"model_name": instance.model_name,
"api_key": provider.api_key or settings.LLM_API_KEY,
"base_url": provider.base_url or settings.LLM_API_BASE,
"default_params": instance.default_params or {},
"capabilities": instance.capabilities or {},
}
# 最终回退到环境变量
return {
"model_name": settings.LLM_MODEL,
"api_key": settings.LLM_API_KEY,
"base_url": settings.LLM_API_BASE,
"default_params": {},
"capabilities": {},
}