"""模型供应商模块路由。 提供模型供应商和模型实例的 CRUD 管理功能。 支持多供应商接入和模型能力的统一管理。 """ import uuid import logging from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_db from models import ModelProvider, ModelInstance from dependencies import get_current_user logger = logging.getLogger(__name__) # 当前模块的日志记录器 router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"]) @router.get("") async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """列出所有已注册的模型供应商。 Args: db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含模型供应商列表的响应数据。 """ result = await db.execute( select(ModelProvider).order_by(ModelProvider.created_at.desc()) ) return { "code": 200, "data": [ { "id": str(p.id), "name": p.name, "provider_type": p.provider_type, "base_url": p.base_url, "is_active": p.is_active, "created_at": p.created_at.isoformat() if p.created_at else "", } for p in result.scalars().all() ], } @router.post("") async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """注册新的模型供应商。 Args: payload: 请求体,包含 name、provider_type、base_url、api_key、extra_config 字段。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含新供应商 ID 的响应数据。 Raises: HTTPException: 相同 base_url 的供应商已存在时抛出异常。 """ existing = await db.execute( select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", "")) ) if existing.scalars().first(): raise HTTPException(400, "相同 base_url 的供应商已存在") p = ModelProvider( name=payload["name"], provider_type=payload["provider_type"], base_url=payload.get("base_url", ""), api_key=payload.get("api_key", ""), extra_config=payload.get("extra_config", {}), ) db.add(p) await db.commit() return {"code": 200, "data": {"id": str(p.id)}} @router.put("/{provider_id}") async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """更新模型供应商的配置信息。 Args: provider_id: 模型供应商唯一标识 ID。 payload: 请求体,包含要更新的供应商字段。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含更新后供应商 ID 的响应数据。 Raises: HTTPException: 供应商不存在时抛出异常。 """ p = await db.get(ModelProvider, uuid.UUID(provider_id)) if not p: raise HTTPException(404, "供应商不存在") p.name = payload.get("name", p.name) p.base_url = payload.get("base_url", p.base_url) p.api_key = payload.get("api_key", p.api_key) p.provider_type = payload.get("provider_type", p.provider_type) p.extra_config = payload.get("extra_config", p.extra_config) p.is_active = payload.get("is_active", p.is_active) await db.commit() return {"code": 200, "data": {"id": str(p.id)}} @router.delete("/{provider_id}") async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """删除指定的模型供应商。 Args: provider_id: 模型供应商唯一标识 ID。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 操作结果响应。 Raises: HTTPException: 供应商不存在时抛出异常。 """ p = await db.get(ModelProvider, uuid.UUID(provider_id)) if not p: raise HTTPException(404, "供应商不存在") await db.delete(p) await db.commit() return {"code": 200, "message": "已删除"} @router.get("/models/all") async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """列出所有处于活跃状态的模型实例。 Args: db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含所有活跃模型实例列表的响应数据。 """ result = await db.execute( select(ModelInstance) .where(ModelInstance.is_active == True) .order_by(ModelInstance.model_type, ModelInstance.model_name) ) return { "code": 200, "data": [ { "id": str(m.id), "model_name": m.model_name, "model_type": m.model_type, "display_name": m.display_name, } for m in result.scalars().all() ], } @router.get("/{provider_id}/models") async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """列出指定供应商下的所有模型实例。 Args: provider_id: 模型供应商唯一标识 ID。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含模型实例列表的响应数据。 """ result = await db.execute( select(ModelInstance) .where(ModelInstance.provider_id == uuid.UUID(provider_id)) .order_by(ModelInstance.model_type, ModelInstance.model_name) ) return { "code": 200, "data": [ { "id": str(m.id), "provider_id": str(m.provider_id), "model_name": m.model_name, "model_type": m.model_type, "display_name": m.display_name, "capabilities": m.capabilities, "default_params": m.default_params, "is_default": m.is_default, "is_active": m.is_active, } for m in result.scalars().all() ], } @router.post("/{provider_id}/models") async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """在指定供应商下添加新的模型实例。 Args: provider_id: 模型供应商唯一标识 ID。 payload: 请求体,包含 model_name、model_type、display_name、capabilities、default_params、is_default 字段。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含新模型 ID 的响应数据。 Raises: HTTPException: 供应商不存在或相同名称的模型已存在时抛出异常。 """ p = await db.get(ModelProvider, uuid.UUID(provider_id)) if not p: raise HTTPException(404, "供应商不存在") existing = await db.execute( select(ModelInstance).where( ModelInstance.provider_id == uuid.UUID(provider_id), ModelInstance.model_name == payload["model_name"], ) ) if existing.scalars().first(): raise HTTPException(400, "相同名称的模型已存在") m = ModelInstance( provider_id=uuid.UUID(provider_id), model_name=payload["model_name"], model_type=payload["model_type"], display_name=payload.get("display_name", payload["model_name"]), capabilities=payload.get("capabilities", {}), default_params=payload.get("default_params", {}), is_default=payload.get("is_default", False), ) db.add(m) await db.commit() return {"code": 200, "data": {"id": str(m.id)}} @router.put("/{provider_id}/models/{model_id}") async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """更新模型实例的配置信息。 Args: provider_id: 模型供应商唯一标识 ID。 model_id: 模型实例唯一标识 ID。 payload: 请求体,包含要更新的模型字段。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 包含更新后模型 ID 的响应数据。 Raises: HTTPException: 模型不存在时抛出异常。 """ m = await db.get(ModelInstance, uuid.UUID(model_id)) if not m or str(m.provider_id) != provider_id: raise HTTPException(404, "模型不存在") m.model_name = payload.get("model_name", m.model_name) m.model_type = payload.get("model_type", m.model_type) m.display_name = payload.get("display_name", m.display_name) m.capabilities = payload.get("capabilities", m.capabilities) m.default_params = payload.get("default_params", m.default_params) m.is_default = payload.get("is_default", m.is_default) m.is_active = payload.get("is_active", m.is_active) await db.commit() return {"code": 200, "data": {"id": str(m.id)}} @router.delete("/{provider_id}/models/{model_id}") async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): """删除指定的模型实例。 Args: provider_id: 模型供应商唯一标识 ID。 model_id: 模型实例唯一标识 ID。 db: 异步数据库会话。 user: 当前登录用户信息。 Returns: dict: 操作结果响应。 Raises: HTTPException: 模型不存在时抛出异常。 """ m = await db.get(ModelInstance, uuid.UUID(model_id)) if not m or str(m.provider_id) != provider_id: raise HTTPException(404, "模型不存在") await db.delete(m) await db.commit() return {"code": 200, "message": "已删除"} @router.get("/default-llm") async def get_default_llm(db: AsyncSession = Depends(get_db)): """获取系统默认 LLM 模型配置。 从 model_instances 表中查找 is_default=True 且 model_type='llm' 的模型实例, 返回包含供应商连接信息的完整配置,供系统中各模块统一使用。 Returns: dict: 默认 LLM 配置(model_name, api_key, base_url, default_params 等), 若未配置则返回 null。 """ result = await db.execute( select(ModelInstance, ModelProvider) .join(ModelProvider, ModelInstance.provider_id == ModelProvider.id) .where(ModelInstance.model_type == 'llm') .where(ModelInstance.is_default == True) .where(ModelInstance.is_active == True) .limit(1) ) row = result.first() if not row: # 回退:取第一个启用的 LLM result2 = await db.execute( select(ModelInstance, ModelProvider) .join(ModelProvider, ModelInstance.provider_id == ModelProvider.id) .where(ModelInstance.model_type == 'llm') .where(ModelInstance.is_active == True) .order_by(ModelInstance.created_at) .limit(1) ) row = result2.first() if not row: return {"code": 200, "data": None} instance, provider = row return { "code": 200, "data": { "id": str(instance.id), "model_name": instance.model_name, "display_name": instance.display_name or instance.model_name, "model_type": instance.model_type, "api_key": provider.api_key or "", "base_url": provider.base_url or settings.LLM_API_BASE, "default_params": instance.default_params or {}, "capabilities": instance.capabilities or {}, "is_default": instance.is_default, "is_active": instance.is_active, "provider_id": str(instance.provider_id), "provider_name": provider.name, }, } async def resolve_model_config(db: AsyncSession, model_instance_id: str = None) -> dict: """根据模型实例 ID 解析完整的模型调用配置。 这是系统的核心模型解析函数。各模块(Agent工厂、流程引擎、记忆管理器等) 应通过此函数获取模型配置,实现统一的模型管理。 解析优先级: 1. 如果提供了 model_instance_id → 从数据库读取该模型的完整配置 2. 否则 → 获取默认 LLM 配置(is_default 的 LLM 实例) 3. 都没有 → 回退到环境变量 settings.LLM_* Args: db: 异步数据库会话。 model_instance_id: 可选的模型实例 UUID 字符串。 Returns: dict: 包含 model_name, api_key, base_url, default_params 等的配置字典。 """ if model_instance_id: try: uid = uuid.UUID(model_instance_id) result = await db.execute( select(ModelInstance, ModelProvider) .join(ModelProvider, ModelInstance.provider_id == ModelProvider.id) .where(ModelInstance.id == uid) .limit(1) ) row = result.first() if row: instance, provider = row return { "model_name": instance.model_name, "api_key": provider.api_key or settings.LLM_API_KEY, "base_url": provider.base_url or settings.LLM_API_BASE, "default_params": instance.default_params or {}, "capabilities": instance.capabilities or {}, } except (ValueError, Exception): pass # 无指定实例或未找到:使用默认 LLM result = await db.execute( select(ModelInstance, ModelProvider) .join(ModelProvider, ModelInstance.provider_id == ModelProvider.id) .where(ModelInstance.model_type == 'llm') .where(ModelInstance.is_default == True) .where(ModelInstance.is_active == True) .limit(1) ) row = result.first() if not row: result = await db.execute( select(ModelInstance, ModelProvider) .join(ModelProvider, ModelInstance.provider_id == ModelProvider.id) .where(ModelInstance.model_type == 'llm') .where(ModelInstance.is_active == True) .order_by(ModelInstance.created_at) .limit(1) ) row = result.first() if row: instance, provider = row return { "model_name": instance.model_name, "api_key": provider.api_key or settings.LLM_API_KEY, "base_url": provider.base_url or settings.LLM_API_BASE, "default_params": instance.default_params or {}, "capabilities": instance.capabilities or {}, } # 最终回退到环境变量 return { "model_name": settings.LLM_MODEL, "api_key": settings.LLM_API_KEY, "base_url": settings.LLM_API_BASE, "default_params": {}, "capabilities": {}, }