You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
432 lines
15 KiB
432 lines
15 KiB
"""模型供应商模块路由。
|
|
|
|
提供模型供应商和模型实例的 CRUD 管理功能。
|
|
支持多供应商接入和模型能力的统一管理。
|
|
"""
|
|
import uuid
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from database import get_db
|
|
from models import ModelProvider, ModelInstance
|
|
from dependencies import get_current_user
|
|
|
|
logger = logging.getLogger(__name__) # 当前模块的日志记录器
|
|
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
|
|
|
|
|
|
@router.get("")
|
|
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""列出所有已注册的模型供应商。
|
|
|
|
Args:
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含模型供应商列表的响应数据。
|
|
"""
|
|
result = await db.execute(
|
|
select(ModelProvider).order_by(ModelProvider.created_at.desc())
|
|
)
|
|
return {
|
|
"code": 200,
|
|
"data": [
|
|
{
|
|
"id": str(p.id),
|
|
"name": p.name,
|
|
"provider_type": p.provider_type,
|
|
"base_url": p.base_url,
|
|
"is_active": p.is_active,
|
|
"created_at": p.created_at.isoformat() if p.created_at else "",
|
|
}
|
|
for p in result.scalars().all()
|
|
],
|
|
}
|
|
|
|
|
|
@router.post("")
|
|
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""注册新的模型供应商。
|
|
|
|
Args:
|
|
payload: 请求体,包含 name、provider_type、base_url、api_key、extra_config 字段。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含新供应商 ID 的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 相同 base_url 的供应商已存在时抛出异常。
|
|
"""
|
|
existing = await db.execute(
|
|
select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", ""))
|
|
)
|
|
if existing.scalars().first():
|
|
raise HTTPException(400, "相同 base_url 的供应商已存在")
|
|
|
|
p = ModelProvider(
|
|
name=payload["name"],
|
|
provider_type=payload["provider_type"],
|
|
base_url=payload.get("base_url", ""),
|
|
api_key=payload.get("api_key", ""),
|
|
extra_config=payload.get("extra_config", {}),
|
|
)
|
|
db.add(p)
|
|
await db.commit()
|
|
return {"code": 200, "data": {"id": str(p.id)}}
|
|
|
|
|
|
@router.put("/{provider_id}")
|
|
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""更新模型供应商的配置信息。
|
|
|
|
Args:
|
|
provider_id: 模型供应商唯一标识 ID。
|
|
payload: 请求体,包含要更新的供应商字段。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含更新后供应商 ID 的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 供应商不存在时抛出异常。
|
|
"""
|
|
p = await db.get(ModelProvider, uuid.UUID(provider_id))
|
|
if not p:
|
|
raise HTTPException(404, "供应商不存在")
|
|
|
|
p.name = payload.get("name", p.name)
|
|
p.base_url = payload.get("base_url", p.base_url)
|
|
p.api_key = payload.get("api_key", p.api_key)
|
|
p.provider_type = payload.get("provider_type", p.provider_type)
|
|
p.extra_config = payload.get("extra_config", p.extra_config)
|
|
p.is_active = payload.get("is_active", p.is_active)
|
|
await db.commit()
|
|
return {"code": 200, "data": {"id": str(p.id)}}
|
|
|
|
|
|
@router.delete("/{provider_id}")
|
|
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""删除指定的模型供应商。
|
|
|
|
Args:
|
|
provider_id: 模型供应商唯一标识 ID。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 操作结果响应。
|
|
|
|
Raises:
|
|
HTTPException: 供应商不存在时抛出异常。
|
|
"""
|
|
p = await db.get(ModelProvider, uuid.UUID(provider_id))
|
|
if not p:
|
|
raise HTTPException(404, "供应商不存在")
|
|
await db.delete(p)
|
|
await db.commit()
|
|
return {"code": 200, "message": "已删除"}
|
|
|
|
|
|
@router.get("/models/all")
|
|
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""列出所有处于活跃状态的模型实例。
|
|
|
|
Args:
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含所有活跃模型实例列表的响应数据。
|
|
"""
|
|
result = await db.execute(
|
|
select(ModelInstance)
|
|
.where(ModelInstance.is_active == True)
|
|
.order_by(ModelInstance.model_type, ModelInstance.model_name)
|
|
)
|
|
return {
|
|
"code": 200,
|
|
"data": [
|
|
{
|
|
"id": str(m.id),
|
|
"model_name": m.model_name,
|
|
"model_type": m.model_type,
|
|
"display_name": m.display_name,
|
|
}
|
|
for m in result.scalars().all()
|
|
],
|
|
}
|
|
|
|
|
|
@router.get("/{provider_id}/models")
|
|
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""列出指定供应商下的所有模型实例。
|
|
|
|
Args:
|
|
provider_id: 模型供应商唯一标识 ID。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含模型实例列表的响应数据。
|
|
"""
|
|
result = await db.execute(
|
|
select(ModelInstance)
|
|
.where(ModelInstance.provider_id == uuid.UUID(provider_id))
|
|
.order_by(ModelInstance.model_type, ModelInstance.model_name)
|
|
)
|
|
return {
|
|
"code": 200,
|
|
"data": [
|
|
{
|
|
"id": str(m.id),
|
|
"provider_id": str(m.provider_id),
|
|
"model_name": m.model_name,
|
|
"model_type": m.model_type,
|
|
"display_name": m.display_name,
|
|
"capabilities": m.capabilities,
|
|
"default_params": m.default_params,
|
|
"is_default": m.is_default,
|
|
"is_active": m.is_active,
|
|
}
|
|
for m in result.scalars().all()
|
|
],
|
|
}
|
|
|
|
|
|
@router.post("/{provider_id}/models")
|
|
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""在指定供应商下添加新的模型实例。
|
|
|
|
Args:
|
|
provider_id: 模型供应商唯一标识 ID。
|
|
payload: 请求体,包含 model_name、model_type、display_name、capabilities、default_params、is_default 字段。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含新模型 ID 的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 供应商不存在或相同名称的模型已存在时抛出异常。
|
|
"""
|
|
p = await db.get(ModelProvider, uuid.UUID(provider_id))
|
|
if not p:
|
|
raise HTTPException(404, "供应商不存在")
|
|
|
|
existing = await db.execute(
|
|
select(ModelInstance).where(
|
|
ModelInstance.provider_id == uuid.UUID(provider_id),
|
|
ModelInstance.model_name == payload["model_name"],
|
|
)
|
|
)
|
|
if existing.scalars().first():
|
|
raise HTTPException(400, "相同名称的模型已存在")
|
|
|
|
m = ModelInstance(
|
|
provider_id=uuid.UUID(provider_id),
|
|
model_name=payload["model_name"],
|
|
model_type=payload["model_type"],
|
|
display_name=payload.get("display_name", payload["model_name"]),
|
|
capabilities=payload.get("capabilities", {}),
|
|
default_params=payload.get("default_params", {}),
|
|
is_default=payload.get("is_default", False),
|
|
)
|
|
db.add(m)
|
|
await db.commit()
|
|
return {"code": 200, "data": {"id": str(m.id)}}
|
|
|
|
|
|
@router.put("/{provider_id}/models/{model_id}")
|
|
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""更新模型实例的配置信息。
|
|
|
|
Args:
|
|
provider_id: 模型供应商唯一标识 ID。
|
|
model_id: 模型实例唯一标识 ID。
|
|
payload: 请求体,包含要更新的模型字段。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 包含更新后模型 ID 的响应数据。
|
|
|
|
Raises:
|
|
HTTPException: 模型不存在时抛出异常。
|
|
"""
|
|
m = await db.get(ModelInstance, uuid.UUID(model_id))
|
|
if not m or str(m.provider_id) != provider_id:
|
|
raise HTTPException(404, "模型不存在")
|
|
|
|
m.model_name = payload.get("model_name", m.model_name)
|
|
m.model_type = payload.get("model_type", m.model_type)
|
|
m.display_name = payload.get("display_name", m.display_name)
|
|
m.capabilities = payload.get("capabilities", m.capabilities)
|
|
m.default_params = payload.get("default_params", m.default_params)
|
|
m.is_default = payload.get("is_default", m.is_default)
|
|
m.is_active = payload.get("is_active", m.is_active)
|
|
await db.commit()
|
|
return {"code": 200, "data": {"id": str(m.id)}}
|
|
|
|
|
|
@router.delete("/{provider_id}/models/{model_id}")
|
|
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
|
|
"""删除指定的模型实例。
|
|
|
|
Args:
|
|
provider_id: 模型供应商唯一标识 ID。
|
|
model_id: 模型实例唯一标识 ID。
|
|
db: 异步数据库会话。
|
|
user: 当前登录用户信息。
|
|
|
|
Returns:
|
|
dict: 操作结果响应。
|
|
|
|
Raises:
|
|
HTTPException: 模型不存在时抛出异常。
|
|
"""
|
|
m = await db.get(ModelInstance, uuid.UUID(model_id))
|
|
if not m or str(m.provider_id) != provider_id:
|
|
raise HTTPException(404, "模型不存在")
|
|
await db.delete(m)
|
|
await db.commit()
|
|
return {"code": 200, "message": "已删除"}
|
|
|
|
|
|
@router.get("/default-llm")
|
|
async def get_default_llm(db: AsyncSession = Depends(get_db)):
|
|
"""获取系统默认 LLM 模型配置。
|
|
|
|
从 model_instances 表中查找 is_default=True 且 model_type='llm' 的模型实例,
|
|
返回包含供应商连接信息的完整配置,供系统中各模块统一使用。
|
|
|
|
Returns:
|
|
dict: 默认 LLM 配置(model_name, api_key, base_url, default_params 等),
|
|
若未配置则返回 null。
|
|
"""
|
|
result = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.model_type == 'llm')
|
|
.where(ModelInstance.is_default == True)
|
|
.where(ModelInstance.is_active == True)
|
|
.limit(1)
|
|
)
|
|
row = result.first()
|
|
if not row:
|
|
# 回退:取第一个启用的 LLM
|
|
result2 = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.model_type == 'llm')
|
|
.where(ModelInstance.is_active == True)
|
|
.order_by(ModelInstance.created_at)
|
|
.limit(1)
|
|
)
|
|
row = result2.first()
|
|
if not row:
|
|
return {"code": 200, "data": None}
|
|
|
|
instance, provider = row
|
|
return {
|
|
"code": 200,
|
|
"data": {
|
|
"id": str(instance.id),
|
|
"model_name": instance.model_name,
|
|
"display_name": instance.display_name or instance.model_name,
|
|
"model_type": instance.model_type,
|
|
"api_key": provider.api_key or "",
|
|
"base_url": provider.base_url or settings.LLM_API_BASE,
|
|
"default_params": instance.default_params or {},
|
|
"capabilities": instance.capabilities or {},
|
|
"is_default": instance.is_default,
|
|
"is_active": instance.is_active,
|
|
"provider_id": str(instance.provider_id),
|
|
"provider_name": provider.name,
|
|
},
|
|
}
|
|
|
|
|
|
async def resolve_model_config(db: AsyncSession, model_instance_id: str = None) -> dict:
|
|
"""根据模型实例 ID 解析完整的模型调用配置。
|
|
|
|
这是系统的核心模型解析函数。各模块(Agent工厂、流程引擎、记忆管理器等)
|
|
应通过此函数获取模型配置,实现统一的模型管理。
|
|
|
|
解析优先级:
|
|
1. 如果提供了 model_instance_id → 从数据库读取该模型的完整配置
|
|
2. 否则 → 获取默认 LLM 配置(is_default 的 LLM 实例)
|
|
3. 都没有 → 回退到环境变量 settings.LLM_*
|
|
|
|
Args:
|
|
db: 异步数据库会话。
|
|
model_instance_id: 可选的模型实例 UUID 字符串。
|
|
|
|
Returns:
|
|
dict: 包含 model_name, api_key, base_url, default_params 等的配置字典。
|
|
"""
|
|
if model_instance_id:
|
|
try:
|
|
uid = uuid.UUID(model_instance_id)
|
|
result = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.id == uid)
|
|
.limit(1)
|
|
)
|
|
row = result.first()
|
|
if row:
|
|
instance, provider = row
|
|
return {
|
|
"model_name": instance.model_name,
|
|
"api_key": provider.api_key or settings.LLM_API_KEY,
|
|
"base_url": provider.base_url or settings.LLM_API_BASE,
|
|
"default_params": instance.default_params or {},
|
|
"capabilities": instance.capabilities or {},
|
|
}
|
|
except (ValueError, Exception):
|
|
pass
|
|
|
|
# 无指定实例或未找到:使用默认 LLM
|
|
result = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.model_type == 'llm')
|
|
.where(ModelInstance.is_default == True)
|
|
.where(ModelInstance.is_active == True)
|
|
.limit(1)
|
|
)
|
|
row = result.first()
|
|
if not row:
|
|
result = await db.execute(
|
|
select(ModelInstance, ModelProvider)
|
|
.join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
|
|
.where(ModelInstance.model_type == 'llm')
|
|
.where(ModelInstance.is_active == True)
|
|
.order_by(ModelInstance.created_at)
|
|
.limit(1)
|
|
)
|
|
row = result.first()
|
|
|
|
if row:
|
|
instance, provider = row
|
|
return {
|
|
"model_name": instance.model_name,
|
|
"api_key": provider.api_key or settings.LLM_API_KEY,
|
|
"base_url": provider.base_url or settings.LLM_API_BASE,
|
|
"default_params": instance.default_params or {},
|
|
"capabilities": instance.capabilities or {},
|
|
}
|
|
|
|
# 最终回退到环境变量
|
|
return {
|
|
"model_name": settings.LLM_MODEL,
|
|
"api_key": settings.LLM_API_KEY,
|
|
"base_url": settings.LLM_API_BASE,
|
|
"default_params": {},
|
|
"capabilities": {},
|
|
}
|
|
|