You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
427 lines
17 KiB
427 lines
17 KiB
"""AgentScope 智能体工厂模块。
|
|
|
|
提供统一的智能体创建接口,根据用户类型(员工/管理者/任务/文档)创建对应的 AI 智能体实例。
|
|
支持智能体缓存以减少重复创建的开销。
|
|
"""
|
|
from sqlalchemy import select
|
|
from agentscope.agent import AgentBase
|
|
from agentscope.agent._react_agent import ReActAgent
|
|
from agentscope.model import OpenAIChatModel
|
|
from agentscope.formatter import OpenAIChatFormatter
|
|
from agentscope.tool import Toolkit
|
|
from agentscope.message import Msg
|
|
from config import settings
|
|
from models import ModelInstance, ModelProvider
|
|
from database import get_db
|
|
from .memory.user_memory import UserIsolatedMemory
|
|
from .hooks.rbac_hook import register_rbac_hooks_for_user
|
|
|
|
|
|
class AgentFactory:
|
|
"""智能体工厂类,负责创建和管理不同类型的 AI 智能体实例。
|
|
|
|
采用类级别的单例模式缓存模型和格式化器实例,
|
|
同时为每个用户缓存已创建的智能体,避免重复初始化。
|
|
"""
|
|
_model: OpenAIChatModel | None = None # 缓存的大语言模型实例
|
|
_formatter: OpenAIChatFormatter | None = None # 缓存的消息格式化器实例
|
|
_agent_cache: dict[str, AgentBase] = {} # 智能体缓存:{agent_type_user_id: AgentBase}
|
|
_MAX_CACHE_SIZE = 50 # 智能体缓存上限
|
|
|
|
@classmethod
|
|
async def _get_llm_config(cls, db, model_instance_id=None):
|
|
"""从数据库获取 LLM 配置信息。
|
|
|
|
根据提供的 model_instance_id 查询对应的模型实例和供应商配置,
|
|
如果未提供或未找到,则回退到默认模型(is_default=True)。
|
|
最终如果都未找到,则使用 settings 中的默认配置。
|
|
|
|
Args:
|
|
db: 数据库会话对象。
|
|
model_instance_id: 可选的模型实例 ID,用于指定特定模型。
|
|
|
|
Returns:
|
|
dict: 包含 LLM 配置的字典:
|
|
- model_name (str): 模型名称
|
|
- api_key (str): API 密钥
|
|
- base_url (str): API 基础地址
|
|
- default_params (dict): 默认参数(temperature, max_tokens 等)
|
|
"""
|
|
try:
|
|
if model_instance_id:
|
|
stmt = select(ModelInstance).where(
|
|
ModelInstance.id == model_instance_id,
|
|
ModelInstance.is_active == True
|
|
)
|
|
result = await db.execute(stmt)
|
|
model_instance = result.scalar_one_or_none()
|
|
|
|
if model_instance:
|
|
stmt_provider = select(ModelProvider).where(
|
|
ModelProvider.id == model_instance.provider_id,
|
|
ModelProvider.is_active == True
|
|
)
|
|
result_provider = await db.execute(stmt_provider)
|
|
provider = result_provider.scalar_one_or_none()
|
|
|
|
if provider:
|
|
return {
|
|
"model_name": model_instance.model_name,
|
|
"api_key": provider.api_key or settings.LLM_API_KEY,
|
|
"base_url": provider.base_url or settings.LLM_API_BASE,
|
|
"default_params": model_instance.default_params or {},
|
|
}
|
|
|
|
stmt_default = select(ModelInstance).where(
|
|
ModelInstance.is_default == True,
|
|
ModelInstance.is_active == True,
|
|
ModelInstance.model_type == "llm"
|
|
).limit(1)
|
|
result_default = await db.execute(stmt_default)
|
|
default_instance = result_default.scalar_one_or_none()
|
|
|
|
if default_instance:
|
|
stmt_provider = select(ModelProvider).where(
|
|
ModelProvider.id == default_instance.provider_id,
|
|
ModelProvider.is_active == True
|
|
)
|
|
result_provider = await db.execute(stmt_provider)
|
|
provider = result_provider.scalar_one_or_none()
|
|
|
|
if provider:
|
|
return {
|
|
"model_name": default_instance.model_name,
|
|
"api_key": provider.api_key or settings.LLM_API_KEY,
|
|
"base_url": provider.base_url or settings.LLM_API_BASE,
|
|
"default_params": default_instance.default_params or {},
|
|
}
|
|
except Exception as e:
|
|
pass
|
|
|
|
return {
|
|
"model_name": settings.LLM_MODEL,
|
|
"api_key": settings.LLM_API_KEY,
|
|
"base_url": settings.LLM_API_BASE,
|
|
"default_params": {},
|
|
}
|
|
|
|
@classmethod
|
|
def _get_model(cls, llm_config: dict = None) -> OpenAIChatModel:
|
|
"""获取或创建全局共享的大语言模型实例。
|
|
|
|
支持通过 llm_config 参数动态指定模型配置,如果未提供则使用默认配置。
|
|
|
|
Args:
|
|
llm_config: 可选的 LLM 配置字典,包含:
|
|
- model_name (str): 模型名称
|
|
- api_key (str): API 密钥
|
|
- base_url (str): API 基础地址
|
|
- default_params (dict): 默认参数(temperature, max_tokens 等)
|
|
|
|
Returns:
|
|
OpenAIChatModel: 配置好的大语言模型实例。
|
|
"""
|
|
if cls._model is None:
|
|
if llm_config:
|
|
model_name = llm_config.get("model_name", settings.LLM_MODEL)
|
|
api_key = llm_config.get("api_key", settings.LLM_API_KEY)
|
|
base_url = llm_config.get("base_url", settings.LLM_API_BASE)
|
|
default_params = llm_config.get("default_params", {})
|
|
|
|
model_kwargs = {
|
|
"config_name": "enterprise_model",
|
|
"model_name": model_name,
|
|
"api_key": api_key,
|
|
"api_base": base_url,
|
|
}
|
|
|
|
if default_params:
|
|
if "temperature" in default_params:
|
|
model_kwargs["temperature"] = default_params["temperature"]
|
|
if "max_tokens" in default_params:
|
|
model_kwargs["max_tokens"] = default_params["max_tokens"]
|
|
|
|
cls._model = OpenAIChatModel(**model_kwargs)
|
|
else:
|
|
cls._model = OpenAIChatModel(
|
|
config_name="enterprise_model",
|
|
model_name=settings.LLM_MODEL,
|
|
api_key=settings.LLM_API_KEY,
|
|
api_base=settings.LLM_API_BASE,
|
|
)
|
|
return cls._model
|
|
|
|
@classmethod
|
|
def _get_formatter(cls) -> OpenAIChatFormatter:
|
|
"""获取或创建全局共享的消息格式化器实例。
|
|
|
|
Returns:
|
|
OpenAIChatFormatter: OpenAI 聊天格式化器实例。
|
|
"""
|
|
if cls._formatter is None:
|
|
cls._formatter = OpenAIChatFormatter()
|
|
return cls._formatter
|
|
|
|
@classmethod
|
|
async def create_agent(
|
|
cls,
|
|
agent_type: str,
|
|
user_id: str,
|
|
user_name: str,
|
|
department_id: str | None = None,
|
|
) -> AgentBase:
|
|
"""根据智能体类型和用户信息创建对应的 AI 智能体。
|
|
|
|
优先从缓存中获取已存在的智能体实例,如果缓存中不存在则创建新实例。
|
|
缓存满时会自动淘汰最旧的智能体实例。
|
|
|
|
Args:
|
|
agent_type: 智能体类型,支持 employee/manager/task/document。
|
|
user_id: 用户唯一标识。
|
|
user_name: 用户显示名称。
|
|
department_id: 所属部门 ID(可选)。
|
|
|
|
Returns:
|
|
AgentBase: 创建或缓存的 AI 智能体实例。
|
|
"""
|
|
cache_key = f"{agent_type}_{user_id}" # 缓存键:智能体类型_用户ID
|
|
if cache_key in cls._agent_cache:
|
|
return cls._agent_cache[cache_key]
|
|
|
|
model = cls._get_model()
|
|
formatter = cls._get_formatter()
|
|
|
|
if agent_type == "employee":
|
|
agent = await cls._create_employee_agent(user_id, user_name, department_id, model, formatter)
|
|
elif agent_type == "manager":
|
|
agent = await cls._create_manager_agent(user_id, user_name, model, formatter)
|
|
elif agent_type == "task":
|
|
agent = await cls._create_task_agent(user_id, user_name, model, formatter)
|
|
elif agent_type == "document":
|
|
agent = await cls._create_document_agent(user_id, user_name, model, formatter)
|
|
else:
|
|
agent = await cls._create_employee_agent(user_id, user_name, department_id, model, formatter)
|
|
|
|
if len(cls._agent_cache) >= cls._MAX_CACHE_SIZE:
|
|
oldest_key = next(iter(cls._agent_cache))
|
|
del cls._agent_cache[oldest_key]
|
|
cls._agent_cache[cache_key] = agent
|
|
return agent
|
|
|
|
@classmethod
|
|
async def _create_employee_agent(cls, user_id, user_name, department_id, model, formatter):
|
|
"""创建员工专属 AI 助手智能体。
|
|
|
|
该智能体具备文档处理、通知发送、知识库查询等功能,
|
|
数据权限范围限定为仅能访问员工自己的数据。
|
|
|
|
Args:
|
|
user_id: 用户唯一标识。
|
|
user_name: 用户显示名称。
|
|
department_id: 所属部门 ID。
|
|
model: 大语言模型实例。
|
|
formatter: 消息格式化器实例。
|
|
|
|
Returns:
|
|
ReActAgent: 配置好的员工 AI 智能体。
|
|
"""
|
|
from .tools.wecom_tools import send_notification
|
|
from .tools.document_tools import parse_document, format_correction
|
|
|
|
toolkit = Toolkit()
|
|
toolkit.register_tool_function(send_notification) # 注册企业微信通知工具
|
|
toolkit.register_tool_function(parse_document) # 注册文档解析工具
|
|
toolkit.register_tool_function(format_correction) # 注册格式修正工具
|
|
|
|
knowledge = None
|
|
try:
|
|
from modules.rag.knowledge import get_knowledge_base
|
|
knowledge = get_knowledge_base() # 尝试获取知识库
|
|
except Exception:
|
|
pass
|
|
|
|
agent = ReActAgent(
|
|
name=f"EmployeeAI_{user_name}",
|
|
sys_prompt=f"""你是 {user_name} 的专属AI工作助手。
|
|
|
|
你可以:
|
|
1. 回答工作中的问题,提供专业建议
|
|
2. 帮助处理文档,修正格式
|
|
3. 查询知识库获取信息
|
|
4. 发送通知给相关人员
|
|
|
|
重要约束:
|
|
- 只能访问该员工权限范围内的数据和工具
|
|
- 涉及敏感操作需要二次确认
|
|
- 始终保持专业和友好的态度""",
|
|
model=model,
|
|
formatter=formatter,
|
|
toolkit=toolkit,
|
|
knowledge=knowledge,
|
|
memory=UserIsolatedMemory(user_id=user_id),
|
|
max_iters=8,
|
|
)
|
|
|
|
register_rbac_hooks_for_user(agent, {
|
|
"user_id": user_id,
|
|
"user_name": user_name,
|
|
"role": "employee",
|
|
"department_id": department_id or "",
|
|
"data_scope": "self_only", # 数据权限:仅限本人
|
|
})
|
|
|
|
return agent
|
|
|
|
@classmethod
|
|
async def _create_manager_agent(cls, user_id, user_name, model, formatter):
|
|
"""创建管理者专属 AI 分析助手智能体。
|
|
|
|
该智能体具备下属管理、团队效率分析、任务统计等管理功能,
|
|
数据权限范围限定为仅能访问其下属员工的数据。
|
|
|
|
Args:
|
|
user_id: 用户唯一标识。
|
|
user_name: 用户显示名称。
|
|
model: 大语言模型实例。
|
|
formatter: 消息格式化器实例。
|
|
|
|
Returns:
|
|
ReActAgent: 配置好的管理者 AI 智能体。
|
|
"""
|
|
from .tools.manager_tools import list_subordinates, get_employee_dashboard, generate_efficiency_report, get_task_statistics
|
|
from .tools.wecom_tools import send_notification
|
|
|
|
toolkit = Toolkit()
|
|
toolkit.register_tool_function(list_subordinates) # 注册下属列表查询工具
|
|
toolkit.register_tool_function(get_employee_dashboard) # 注册员工看板查询工具
|
|
toolkit.register_tool_function(generate_efficiency_report) # 注册效率报告生成工具
|
|
toolkit.register_tool_function(get_task_statistics) # 注册任务统计查询工具
|
|
toolkit.register_tool_function(send_notification) # 注册企业微信通知工具
|
|
|
|
agent = ReActAgent(
|
|
name=f"ManagerAI_{user_name}",
|
|
sys_prompt=f"""你是 {user_name} 的管理分析助手。
|
|
|
|
你可以:
|
|
1. 查看下属员工列表和工作数据 (list_subordinates, get_employee_dashboard)
|
|
2. 生成团队效率报告 (generate_efficiency_report)
|
|
3. 统计分析任务完成情况 (get_task_statistics)
|
|
4. 向下属发送企业微信通知提醒 (send_notification)
|
|
|
|
重要约束:
|
|
- 只能查看你的直接和间接下属的数据
|
|
- 不能查看非下属或跨部门员工的数据
|
|
- 生成报告时注意数据隐私""",
|
|
model=model,
|
|
formatter=formatter,
|
|
toolkit=toolkit,
|
|
memory=UserIsolatedMemory(user_id=user_id),
|
|
max_iters=8,
|
|
)
|
|
|
|
register_rbac_hooks_for_user(agent, {
|
|
"user_id": user_id,
|
|
"user_name": user_name,
|
|
"role": "dept_manager",
|
|
"data_scope": "subordinate_only", # 数据权限:仅限下属
|
|
})
|
|
|
|
return agent
|
|
|
|
@classmethod
|
|
async def _create_task_agent(cls, user_id, user_name, model, formatter):
|
|
"""创建任务管理专属 AI 助手智能体。
|
|
|
|
该智能体专注于任务的创建、查询、更新和通知推送,
|
|
帮助用户高效管理日常工作事务。
|
|
|
|
Args:
|
|
user_id: 用户唯一标识。
|
|
user_name: 用户显示名称。
|
|
model: 大语言模型实例。
|
|
formatter: 消息格式化器实例。
|
|
|
|
Returns:
|
|
ReActAgent: 配置好的任务管理 AI 智能体。
|
|
"""
|
|
from .tools.task_tools import list_tasks, create_task, get_task, update_task
|
|
from .tools.wecom_tools import send_notification
|
|
|
|
toolkit = Toolkit()
|
|
toolkit.register_tool_function(list_tasks) # 注册任务列表查询工具
|
|
toolkit.register_tool_function(create_task) # 注册任务创建工具
|
|
toolkit.register_tool_function(get_task) # 注册任务详情查询工具
|
|
toolkit.register_tool_function(update_task) # 注册任务更新工具
|
|
toolkit.register_tool_function(send_notification) # 注册企业微信通知工具
|
|
|
|
agent = ReActAgent(
|
|
name=f"TaskAI_{user_name}",
|
|
sys_prompt=f"""你是任务管理助手。帮助用户创建、跟踪和管理工作任务。
|
|
|
|
你可以:
|
|
1. 创建新任务并分配给指定人员 (create_task)
|
|
2. 查询任务状态和进度 (list_tasks, get_task)
|
|
3. 更新任务信息 (update_task)
|
|
4. 推送任务通知到企业微信 (send_notification)
|
|
|
|
重要约束:
|
|
- 创建任务前确保标题和负责人信息完整
|
|
- 修改任务状态前告知用户变更
|
|
- 优先级: low/medium/high/urgent""",
|
|
model=model,
|
|
formatter=formatter,
|
|
toolkit=toolkit,
|
|
memory=UserIsolatedMemory(user_id=user_id),
|
|
max_iters=8,
|
|
)
|
|
|
|
return agent
|
|
|
|
@classmethod
|
|
async def _create_document_agent(cls, user_id, user_name, model, formatter):
|
|
"""创建文档处理专属 AI 助手智能体。
|
|
|
|
该智能体专注于各类办公文档的解析、格式修正和内容提取,
|
|
支持 PDF、Word、Excel 等常见格式。
|
|
|
|
Args:
|
|
user_id: 用户唯一标识。
|
|
user_name: 用户显示名称。
|
|
model: 大语言模型实例。
|
|
formatter: 消息格式化器实例。
|
|
|
|
Returns:
|
|
ReActAgent: 配置好的文档处理 AI 智能体。
|
|
"""
|
|
from .tools.document_tools import parse_document, format_correction
|
|
|
|
toolkit = Toolkit()
|
|
toolkit.register_tool_function(parse_document) # 注册文档解析工具
|
|
toolkit.register_tool_function(format_correction) # 注册格式修正工具
|
|
|
|
knowledge = None
|
|
try:
|
|
from modules.rag.knowledge import get_knowledge_base
|
|
knowledge = get_knowledge_base() # 尝试获取知识库
|
|
except Exception:
|
|
pass
|
|
|
|
agent = ReActAgent(
|
|
name=f"DocAI_{user_name}",
|
|
sys_prompt=f"""你是文档处理专家。帮助用户处理各类文档。
|
|
|
|
你可以:
|
|
1. 解析PDF/Word/Excel/PPT等格式
|
|
2. 修正文档格式
|
|
3. 提取文档关键信息
|
|
4. 从知识库中检索文档内容
|
|
5. 格式转换""",
|
|
model=model,
|
|
formatter=formatter,
|
|
toolkit=toolkit,
|
|
knowledge=knowledge,
|
|
memory=UserIsolatedMemory(user_id=user_id),
|
|
max_iters=8,
|
|
)
|
|
|
|
return agent
|
|
|