diff --git a/backend/__init__.py b/backend/__init__.py index e69de29..d2ecbea 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -0,0 +1 @@ +"""后端应用包。""" \ No newline at end of file diff --git a/backend/agentscope_integration/__init__.py b/backend/agentscope_integration/__init__.py index e69de29..90179a7 100644 --- a/backend/agentscope_integration/__init__.py +++ b/backend/agentscope_integration/__init__.py @@ -0,0 +1 @@ +"""AgentScope 集成模块初始化。""" \ No newline at end of file diff --git a/backend/agentscope_integration/factory.py b/backend/agentscope_integration/factory.py index aa64594..bb96c79 100644 --- a/backend/agentscope_integration/factory.py +++ b/backend/agentscope_integration/factory.py @@ -1,3 +1,8 @@ +"""AgentScope 智能体工厂模块。 + +提供统一的智能体创建接口,根据用户类型(员工/管理者/任务/文档)创建对应的 AI 智能体实例。 +支持智能体缓存以减少重复创建的开销。 +""" from agentscope.agent import AgentBase from agentscope.agent._react_agent import ReActAgent from agentscope.model import OpenAIChatModel @@ -10,13 +15,23 @@ from .hooks.rbac_hook import register_rbac_hooks_for_user class AgentFactory: - _model: OpenAIChatModel | None = None - _formatter: OpenAIChatFormatter | None = None - _agent_cache: dict[str, AgentBase] = {} - _MAX_CACHE_SIZE = 50 + """智能体工厂类,负责创建和管理不同类型的 AI 智能体实例。 + + 采用类级别的单例模式缓存模型和格式化器实例, + 同时为每个用户缓存已创建的智能体,避免重复初始化。 + """ + _model: OpenAIChatModel | None = None # 缓存的大语言模型实例 + _formatter: OpenAIChatFormatter | None = None # 缓存的消息格式化器实例 + _agent_cache: dict[str, AgentBase] = {} # 智能体缓存:{agent_type_user_id: AgentBase} + _MAX_CACHE_SIZE = 50 # 智能体缓存上限 @classmethod def _get_model(cls) -> OpenAIChatModel: + """获取或创建全局共享的大语言模型实例。 + + Returns: + OpenAIChatModel: 配置好的大语言模型实例。 + """ if cls._model is None: cls._model = OpenAIChatModel( config_name="enterprise_model", @@ -28,6 +43,11 @@ class AgentFactory: @classmethod def _get_formatter(cls) -> OpenAIChatFormatter: + """获取或创建全局共享的消息格式化器实例。 + + Returns: + OpenAIChatFormatter: OpenAI 聊天格式化器实例。 + """ if cls._formatter is None: cls._formatter = OpenAIChatFormatter() return cls._formatter @@ -40,7 +60,21 @@ class AgentFactory: user_name: str, department_id: str | None = None, ) -> AgentBase: - cache_key = f"{agent_type}_{user_id}" + """根据智能体类型和用户信息创建对应的 AI 智能体。 + + 优先从缓存中获取已存在的智能体实例,如果缓存中不存在则创建新实例。 + 缓存满时会自动淘汰最旧的智能体实例。 + + Args: + agent_type: 智能体类型,支持 employee/manager/task/document。 + user_id: 用户唯一标识。 + user_name: 用户显示名称。 + department_id: 所属部门 ID(可选)。 + + Returns: + AgentBase: 创建或缓存的 AI 智能体实例。 + """ + cache_key = f"{agent_type}_{user_id}" # 缓存键:智能体类型_用户ID if cache_key in cls._agent_cache: return cls._agent_cache[cache_key] @@ -66,18 +100,33 @@ class AgentFactory: @classmethod async def _create_employee_agent(cls, user_id, user_name, department_id, model, formatter): + """创建员工专属 AI 助手智能体。 + + 该智能体具备文档处理、通知发送、知识库查询等功能, + 数据权限范围限定为仅能访问员工自己的数据。 + + Args: + user_id: 用户唯一标识。 + user_name: 用户显示名称。 + department_id: 所属部门 ID。 + model: 大语言模型实例。 + formatter: 消息格式化器实例。 + + Returns: + ReActAgent: 配置好的员工 AI 智能体。 + """ from .tools.wecom_tools import send_notification from .tools.document_tools import parse_document, format_correction toolkit = Toolkit() - toolkit.register_tool_function(send_notification) - toolkit.register_tool_function(parse_document) - toolkit.register_tool_function(format_correction) + toolkit.register_tool_function(send_notification) # 注册企业微信通知工具 + toolkit.register_tool_function(parse_document) # 注册文档解析工具 + toolkit.register_tool_function(format_correction) # 注册格式修正工具 knowledge = None try: from modules.rag.knowledge import get_knowledge_base - knowledge = get_knowledge_base() + knowledge = get_knowledge_base() # 尝试获取知识库 except Exception: pass @@ -108,22 +157,36 @@ class AgentFactory: "user_name": user_name, "role": "employee", "department_id": department_id or "", - "data_scope": "self_only", + "data_scope": "self_only", # 数据权限:仅限本人 }) return agent @classmethod async def _create_manager_agent(cls, user_id, user_name, model, formatter): + """创建管理者专属 AI 分析助手智能体。 + + 该智能体具备下属管理、团队效率分析、任务统计等管理功能, + 数据权限范围限定为仅能访问其下属员工的数据。 + + Args: + user_id: 用户唯一标识。 + user_name: 用户显示名称。 + model: 大语言模型实例。 + formatter: 消息格式化器实例。 + + Returns: + ReActAgent: 配置好的管理者 AI 智能体。 + """ from .tools.manager_tools import list_subordinates, get_employee_dashboard, generate_efficiency_report, get_task_statistics from .tools.wecom_tools import send_notification toolkit = Toolkit() - toolkit.register_tool_function(list_subordinates) - toolkit.register_tool_function(get_employee_dashboard) - toolkit.register_tool_function(generate_efficiency_report) - toolkit.register_tool_function(get_task_statistics) - toolkit.register_tool_function(send_notification) + toolkit.register_tool_function(list_subordinates) # 注册下属列表查询工具 + toolkit.register_tool_function(get_employee_dashboard) # 注册员工看板查询工具 + toolkit.register_tool_function(generate_efficiency_report) # 注册效率报告生成工具 + toolkit.register_tool_function(get_task_statistics) # 注册任务统计查询工具 + toolkit.register_tool_function(send_notification) # 注册企业微信通知工具 agent = ReActAgent( name=f"ManagerAI_{user_name}", @@ -150,22 +213,36 @@ class AgentFactory: "user_id": user_id, "user_name": user_name, "role": "dept_manager", - "data_scope": "subordinate_only", + "data_scope": "subordinate_only", # 数据权限:仅限下属 }) return agent @classmethod async def _create_task_agent(cls, user_id, user_name, model, formatter): + """创建任务管理专属 AI 助手智能体。 + + 该智能体专注于任务的创建、查询、更新和通知推送, + 帮助用户高效管理日常工作事务。 + + Args: + user_id: 用户唯一标识。 + user_name: 用户显示名称。 + model: 大语言模型实例。 + formatter: 消息格式化器实例。 + + Returns: + ReActAgent: 配置好的任务管理 AI 智能体。 + """ from .tools.task_tools import list_tasks, create_task, get_task, update_task from .tools.wecom_tools import send_notification toolkit = Toolkit() - toolkit.register_tool_function(list_tasks) - toolkit.register_tool_function(create_task) - toolkit.register_tool_function(get_task) - toolkit.register_tool_function(update_task) - toolkit.register_tool_function(send_notification) + toolkit.register_tool_function(list_tasks) # 注册任务列表查询工具 + toolkit.register_tool_function(create_task) # 注册任务创建工具 + toolkit.register_tool_function(get_task) # 注册任务详情查询工具 + toolkit.register_tool_function(update_task) # 注册任务更新工具 + toolkit.register_tool_function(send_notification) # 注册企业微信通知工具 agent = ReActAgent( name=f"TaskAI_{user_name}", @@ -192,16 +269,30 @@ class AgentFactory: @classmethod async def _create_document_agent(cls, user_id, user_name, model, formatter): + """创建文档处理专属 AI 助手智能体。 + + 该智能体专注于各类办公文档的解析、格式修正和内容提取, + 支持 PDF、Word、Excel 等常见格式。 + + Args: + user_id: 用户唯一标识。 + user_name: 用户显示名称。 + model: 大语言模型实例。 + formatter: 消息格式化器实例。 + + Returns: + ReActAgent: 配置好的文档处理 AI 智能体。 + """ from .tools.document_tools import parse_document, format_correction toolkit = Toolkit() - toolkit.register_tool_function(parse_document) - toolkit.register_tool_function(format_correction) + toolkit.register_tool_function(parse_document) # 注册文档解析工具 + toolkit.register_tool_function(format_correction) # 注册格式修正工具 knowledge = None try: from modules.rag.knowledge import get_knowledge_base - knowledge = get_knowledge_base() + knowledge = get_knowledge_base() # 尝试获取知识库 except Exception: pass @@ -223,4 +314,4 @@ class AgentFactory: max_iters=8, ) - return agent \ No newline at end of file + return agent diff --git a/backend/agentscope_integration/hooks/rbac_hook.py b/backend/agentscope_integration/hooks/rbac_hook.py index 6cef32a..b40ee39 100644 --- a/backend/agentscope_integration/hooks/rbac_hook.py +++ b/backend/agentscope_integration/hooks/rbac_hook.py @@ -1,9 +1,34 @@ +"""RBAC 权限钩子模块。 + +提供 AgentScope 智能体的 RBAC(基于角色的访问控制)权限钩子, +在智能体回复前自动注入用户上下文信息(用户ID、角色、部门、数据权限范围)到消息元数据中。 +""" from agentscope.agent import AgentBase from agentscope.message import Msg def create_rbac_pre_reply_hook(user_context: dict): + """创建 RBAC 预回复钩子函数。 + + 该钩子会在智能体每次回复前执行,将用户的身份信息注入到消息元数据中, + 以便后续的工具调用和权限校验能够获取正确的用户上下文。 + + Args: + user_context: 用户上下文信息字典,包含 user_id、role、department_id、data_scope 等字段。 + + Returns: + callable: 异步钩子函数,用于注册到智能体的 pre_reply 钩子点。 + """ async def rbac_pre_reply_hook(self: AgentBase, kwargs: dict) -> dict: + """RBAC 预回复钩子内部实现。 + + Args: + self: 智能体实例。 + kwargs: 传递给智能体 reply 方法的参数字典。 + + Returns: + dict: 修改后的参数字典,消息元数据中已注入用户上下文信息。 + """ msg = kwargs.get("msg") if msg and isinstance(msg, Msg): msg.metadata = msg.metadata or {} @@ -18,6 +43,15 @@ def create_rbac_pre_reply_hook(user_context: dict): def register_rbac_hooks_for_user(agent: AgentBase, user_context: dict): + """为指定智能体注册 RBAC 权限钩子。 + + 将用户上下文信息绑定到智能体的 pre_reply 钩子点, + 确保智能体在处理每条消息时都能携带正确的用户身份信息。 + + Args: + agent: 目标智能体实例。 + user_context: 用户上下文信息字典。 + """ hook = create_rbac_pre_reply_hook(user_context) hook_name = f"rbac_{user_context['user_id']}" - agent.register_instance_hook("pre_reply", hook_name, hook) \ No newline at end of file + agent.register_instance_hook("pre_reply", hook_name, hook) diff --git a/backend/agentscope_integration/memory/user_memory.py b/backend/agentscope_integration/memory/user_memory.py index ccb573e..40a0fba 100644 --- a/backend/agentscope_integration/memory/user_memory.py +++ b/backend/agentscope_integration/memory/user_memory.py @@ -1,13 +1,39 @@ +"""用户隔离记忆模块。 + +提供基于用户 ID 隔离的记忆存储机制,确保每个用户只能访问自己的对话历史。 +通过包装 AgentScope 的 MemoryBase 实现用户级别的记忆隔离。 +""" from agentscope.memory import MemoryBase, InMemoryMemory from agentscope.message import Msg class UserIsolatedMemory(MemoryBase): + """用户隔离记忆类,确保每个用户只能访问自己对话历史的记忆管理器。 + + 通过在消息元数据中标记用户 ID,在获取记忆时过滤出当前用户的消息, + 实现多用户环境下的对话历史隔离。 + + Attributes: + user_id: 当前记忆实例绑定的用户唯一标识。 + _backend: 底层记忆存储实例,默认为 InMemoryMemory。 + """ + def __init__(self, user_id: str, backend_memory: MemoryBase | None = None): + """初始化用户隔离记忆实例。 + + Args: + user_id: 用户唯一标识。 + backend_memory: 可选的底层记忆存储实例,不提供则使用内存存储。 + """ self.user_id = user_id self._backend = backend_memory or InMemoryMemory() async def add(self, msg: Msg | list[Msg] | None) -> None: + """添加消息到记忆中,自动标记当前用户 ID。 + + Args: + msg: 要添加的消息,可以是单条消息、消息列表或 None。 + """ if msg is None: return msgs = msg if isinstance(msg, list) else [msg] @@ -17,14 +43,40 @@ class UserIsolatedMemory(MemoryBase): await self._backend.add(msg) async def get_memory(self, **kwargs) -> list[Msg]: + """获取当前用户的记忆历史。 + + 从底层存储中获取所有消息后,过滤出属于当前用户的消息。 + + Args: + **kwargs: 传递给底层存储的额外参数。 + + Returns: + list[Msg]: 属于当前用户的消息列表。 + """ all_msgs = await self._backend.get_memory(**kwargs) return [m for m in all_msgs if m.metadata.get("_user_id") == self.user_id] async def delete_by_mark(self, mark: str) -> None: + """根据标记删除消息。 + + Args: + mark: 要删除的消息标记。 + """ await self._backend.delete_by_mark(mark) async def update_messages_mark(self, msg_ids: list[str], new_mark: str) -> None: + """更新消息的标记。 + + Args: + msg_ids: 要更新标记的消息 ID 列表。 + new_mark: 新的标记字符串。 + """ await self._backend.update_messages_mark(msg_ids, new_mark) async def update_compressed_summary(self, summary: str) -> None: - await self._backend.update_compressed_summary(summary) \ No newline at end of file + """更新压缩后的记忆摘要。 + + Args: + summary: 新的记忆摘要字符串。 + """ + await self._backend.update_compressed_summary(summary) diff --git a/backend/agentscope_integration/tools/document_tools.py b/backend/agentscope_integration/tools/document_tools.py index 47c4f08..4a0f5bf 100644 --- a/backend/agentscope_integration/tools/document_tools.py +++ b/backend/agentscope_integration/tools/document_tools.py @@ -1,12 +1,22 @@ +"""文档处理工具模块。 + +提供多种办公文档格式的解析和格式修正功能,支持 PDF、Word、Excel 等格式。 +采用延迟导入策略,仅在需要时才尝试加载相应的依赖库。 +""" import os import logging -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 -_IMPORT_ERRORS: dict[str, str] = {} +_IMPORT_ERRORS: dict[str, str] = {} # 记录各库的导入错误信息,避免重复尝试 def _try_import_pdf() -> bool: + """尝试导入 PDF 解析库 PyPDF2。 + + Returns: + bool: 导入成功返回 True,失败返回 False。 + """ global _IMPORT_ERRORS if "pdf" in _IMPORT_ERRORS: return False @@ -19,6 +29,11 @@ def _try_import_pdf() -> bool: def _try_import_docx() -> bool: + """尝试导入 Word 文档解析库 python-docx。 + + Returns: + bool: 导入成功返回 True,失败返回 False。 + """ global _IMPORT_ERRORS if "docx" in _IMPORT_ERRORS: return False @@ -31,6 +46,11 @@ def _try_import_docx() -> bool: def _try_import_excel() -> bool: + """尝试导入 Excel 表格解析库 openpyxl。 + + Returns: + bool: 导入成功返回 True,失败返回 False。 + """ global _IMPORT_ERRORS if "excel" in _IMPORT_ERRORS: return False @@ -43,8 +63,20 @@ def _try_import_excel() -> bool: def parse_document(file_path: str, file_type: str = "auto") -> str: - ext = os.path.splitext(file_path)[1].lower() + """解析各类办公文档,提取文本内容。 + + 自动根据文件扩展名识别文档类型,支持 PDF、Word、Excel、PPT 和纯文本。 + Args: + file_path: 文档文件的完整路径。 + file_type: 文档类型,auto 表示自动识别。 + + Returns: + str: 提取的文档文本内容或错误信息。 + """ + ext = os.path.splitext(file_path)[1].lower() # 获取文件扩展名 + + # 根据扩展名自动识别文件类型 if file_type == "auto": if ext in (".pdf",): file_type = "pdf" @@ -101,7 +133,7 @@ def parse_document(file_path: str, file_type: str = "auto") -> str: import openpyxl try: - wb = openpyxl.load_workbook(file_path, data_only=True) + wb = openpyxl.load_workbook(file_path, data_only=True) # data_only 获取计算后的值 result_parts = [] for sheet_name in wb.sheetnames: ws = wb[sheet_name] @@ -118,6 +150,7 @@ def parse_document(file_path: str, file_type: str = "auto") -> str: if file_type in ("ppt", "pptx"): return "PPT 解析暂不支持,请将内容复制到 Word 或 PDF 后重试。" + # 尝试以纯文本方式读取文件 try: with open(file_path, "r", encoding="utf-8") as f: return f.read() @@ -135,6 +168,17 @@ def parse_document(file_path: str, file_type: str = "auto") -> str: def format_correction(content: str, format_rules: str = "standard") -> str: + """对文档内容进行格式修正。 + + 根据指定的格式规则对文本进行标准化处理,支持标准和企业公文两种模式。 + + Args: + content: 待修正的原始文本内容。 + format_rules: 格式规则,standard 为标准模式,enterprise 为企业公文模式。 + + Returns: + str: 格式修正后的文本内容。 + """ parts = [] parts.append(f"[格式规则: {format_rules}]\n") @@ -151,4 +195,4 @@ def format_correction(content: str, format_rules: str = "standard") -> str: return "\n".join(parts) -__all__ = ["parse_document", "format_correction"] \ No newline at end of file +__all__ = ["parse_document", "format_correction"] diff --git a/backend/agentscope_integration/tools/manager_tools.py b/backend/agentscope_integration/tools/manager_tools.py index 5166d7e..aade94a 100644 --- a/backend/agentscope_integration/tools/manager_tools.py +++ b/backend/agentscope_integration/tools/manager_tools.py @@ -1,3 +1,8 @@ +"""管理者工具模块。 + +提供管理者专属的工具函数,包括下属员工查询、员工看板数据获取、团队效率报告生成和任务统计等功能。 +通过内部 HTTP API 与后端组织监控服务通信,使用 JWT 服务令牌进行认证。 +""" import httpx import logging import os @@ -5,13 +10,18 @@ import jwt import time from config import settings -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 -_INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") -_client: httpx.Client | None = None +_INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") # 内部 API 基础地址 +_client: httpx.Client | None = None # 全局复用的 HTTP 客户端实例 def _get_client() -> httpx.Client: + """获取或创建全局复用的 HTTP 客户端实例。 + + Returns: + httpx.Client: 配置了超时时间的 HTTP 客户端。 + """ global _client if _client is None: _client = httpx.Client(timeout=30) @@ -19,6 +29,11 @@ def _get_client() -> httpx.Client: def _get_service_token() -> str | None: + """生成用于内部服务间调用的 JWT 令牌。 + + Returns: + str | None: 编码后的 JWT 令牌,生成失败返回 None。 + """ try: payload = {"sub": "system_tool", "exp": int(time.time()) + 3600, "type": "service"} token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256") @@ -28,10 +43,19 @@ def _get_service_token() -> str | None: def _headers(token: str | None = None) -> dict: + """构建 HTTP 请求头,包含认证令牌。 + + Args: + token: 可选的自定义令牌,不提供则自动生成服务令牌。 + + Returns: + dict: 包含 Authorization 头的字典。 + """ t = token or _get_service_token() return {"Authorization": f"Bearer {t}"} if t else {} +# 工具函数描述 Schema,用于 AgentScope 工具注册 SCHEMAS = { "list_subordinates": { "name": "list_subordinates", @@ -67,6 +91,11 @@ SCHEMAS = { def list_subordinates() -> str: + """查询当前管理者名下的下属员工列表。 + + Returns: + str: 格式化的下属员工列表文本或错误信息。 + """ try: resp = _get_client().get(f"{_INTERNAL_BASE}/org/subordinates", headers=_headers()) users = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) @@ -81,6 +110,14 @@ def list_subordinates() -> str: def get_employee_dashboard(employee_id: str) -> str: + """查询指定员工的工作看板数据,包括任务完成率、响应时间等指标。 + + Args: + employee_id: 员工唯一标识 ID。 + + Returns: + str: 格式化的员工看板数据或错误信息。 + """ try: resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employee/{employee_id}/dashboard", headers=_headers()) data = resp.json() @@ -90,6 +127,14 @@ def get_employee_dashboard(employee_id: str) -> str: def generate_efficiency_report(department_id: str | None = None) -> str: + """生成团队效率分析报告,包含各员工的任务数和完成率统计。 + + Args: + department_id: 可选的部门 ID,用于限定报告范围。 + + Returns: + str: 格式化的团队效率报告或错误信息。 + """ try: resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employees", headers=_headers()) employees = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) @@ -109,6 +154,14 @@ def generate_efficiency_report(department_id: str | None = None) -> str: def get_task_statistics(employee_id: str | None = None) -> str: + """查询任务统计数据,支持按员工筛选。 + + Args: + employee_id: 可选的员工 ID,用于筛选特定员工的任务。 + + Returns: + str: 格式化的任务统计信息或错误信息。 + """ try: resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers()) tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) @@ -122,4 +175,4 @@ def get_task_statistics(employee_id: str | None = None) -> str: return f"查询任务统计失败: {e}" -__all__ = ["list_subordinates", "get_employee_dashboard", "generate_efficiency_report", "get_task_statistics", "SCHEMAS"] \ No newline at end of file +__all__ = ["list_subordinates", "get_employee_dashboard", "generate_efficiency_report", "get_task_statistics", "SCHEMAS"] diff --git a/backend/agentscope_integration/tools/task_tools.py b/backend/agentscope_integration/tools/task_tools.py index 1e75735..bf31db6 100644 --- a/backend/agentscope_integration/tools/task_tools.py +++ b/backend/agentscope_integration/tools/task_tools.py @@ -1,3 +1,8 @@ +"""任务管理工具模块。 + +提供任务相关操作的封装,包括任务列表查询、创建、获取详情、更新状态等功能。 +通过内部 HTTP API 与后端任务服务通信,使用 JWT 服务令牌进行认证。 +""" import httpx import logging import os @@ -5,13 +10,18 @@ import jwt import time from config import settings -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 -_INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") -_client: httpx.Client | None = None +_INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") # 内部 API 基础地址 +_client: httpx.Client | None = None # 全局复用的 HTTP 客户端实例 def _get_client() -> httpx.Client: + """获取或创建全局复用的 HTTP 客户端实例。 + + Returns: + httpx.Client: 配置了超时时间的 HTTP 客户端。 + """ global _client if _client is None: _client = httpx.Client(timeout=30) @@ -19,11 +29,16 @@ def _get_client() -> httpx.Client: def _get_service_token() -> str | None: + """生成用于内部服务间调用的 JWT 令牌。 + + Returns: + str | None: 编码后的 JWT 令牌,生成失败返回 None。 + """ try: payload = { - "sub": "system_tool", - "exp": int(time.time()) + 3600, - "type": "service", + "sub": "system_tool", # 令牌主体标识为系统工具 + "exp": int(time.time()) + 3600, # 1 小时后过期 + "type": "service", # 令牌类型为服务令牌 } token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256") return token @@ -32,10 +47,19 @@ def _get_service_token() -> str | None: def _headers(token: str | None = None) -> dict: + """构建 HTTP 请求头,包含认证令牌。 + + Args: + token: 可选的自定义令牌,不提供则自动生成服务令牌。 + + Returns: + dict: 包含 Authorization 头的字典。 + """ t = token or _get_service_token() return {"Authorization": f"Bearer {t}"} if t else {} +# 工具函数描述 Schema,用于 AgentScope 工具注册 SCHEMAS = { "list_tasks": { "name": "list_tasks", @@ -97,6 +121,14 @@ SCHEMAS = { def list_tasks(status: str | None = None) -> str: + """查询任务列表,支持按状态筛选。 + + Args: + status: 可选的任务状态筛选条件(todo/in_progress/done)。 + + Returns: + str: 格式化的任务列表文本或错误信息。 + """ try: resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers()) tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) @@ -118,6 +150,18 @@ def list_tasks(status: str | None = None) -> str: def create_task(title: str, description: str = "", assignee_id: str = "", priority: str = "medium", deadline: str | None = None) -> str: + """创建新任务。 + + Args: + title: 任务标题(必填)。 + description: 任务描述。 + assignee_id: 负责人用户 ID。 + priority: 任务优先级,默认 medium。 + deadline: 截止日期。 + + Returns: + str: 创建结果描述或错误信息。 + """ try: body = {"title": title, "description": description, "assignee_id": assignee_id, "priority": priority, "deadline": deadline} resp = _get_client().post(f"{_INTERNAL_BASE}/tasks", json=body, headers=_headers()) @@ -128,6 +172,14 @@ def create_task(title: str, description: str = "", assignee_id: str = "", priori def get_task(task_id: str) -> str: + """查询指定任务的详细信息。 + + Args: + task_id: 任务唯一标识 ID。 + + Returns: + str: 格式化的任务详情文本或错误信息。 + """ try: resp = _get_client().get(f"{_INTERNAL_BASE}/tasks/{task_id}", headers=_headers()) t = resp.json() @@ -137,6 +189,16 @@ def get_task(task_id: str) -> str: def update_task(task_id: str, status: str | None = None, description: str | None = None) -> str: + """更新任务状态或描述。 + + Args: + task_id: 任务唯一标识 ID。 + status: 新的任务状态(todo/in_progress/done)。 + description: 新的任务描述。 + + Returns: + str: 更新结果描述或错误信息。 + """ try: body = {} if status: @@ -150,6 +212,14 @@ def update_task(task_id: str, status: str | None = None, description: str | None def push_task_to_wecom(task_id: str) -> str: + """将任务通知推送到企业微信。 + + Args: + task_id: 任务唯一标识 ID。 + + Returns: + str: 推送结果描述或错误信息。 + """ try: resp = _get_client().post(f"{_INTERNAL_BASE}/tasks/{task_id}/push", headers=_headers()) return f"任务 {task_id[:8]} 已推送至企业微信" @@ -157,4 +227,4 @@ def push_task_to_wecom(task_id: str) -> str: return f"推送任务失败: {e}" -__all__ = ["list_tasks", "create_task", "get_task", "update_task", "push_task_to_wecom", "SCHEMAS"] \ No newline at end of file +__all__ = ["list_tasks", "create_task", "get_task", "update_task", "push_task_to_wecom", "SCHEMAS"] diff --git a/backend/agentscope_integration/tools/wecom_tools.py b/backend/agentscope_integration/tools/wecom_tools.py index 17838eb..15045fc 100644 --- a/backend/agentscope_integration/tools/wecom_tools.py +++ b/backend/agentscope_integration/tools/wecom_tools.py @@ -1,12 +1,29 @@ +"""企业微信工具模块。 + +提供企业微信 API 的封装,支持发送消息、查询用户信息、群消息发送等功能。 +包含 access_token 的自动获取和缓存机制。 +""" import httpx import logging -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 -_WECOM_ACCESS_TOKEN: dict = {"token": None, "expires_at": 0} +_WECOM_ACCESS_TOKEN: dict = {"token": None, "expires_at": 0} # 企业微信 access_token 缓存 def _get_access_token(corp_id: str, app_secret: str) -> str | None: + """获取或刷新企业微信 access_token。 + + 优先使用缓存中的 token,如果已过期或不存在则重新请求。 + Token 过期前 5 分钟会自动刷新。 + + Args: + corp_id: 企业微信 CorpID。 + app_secret: 企业微信应用 Secret。 + + Returns: + str | None: 有效的 access_token,获取失败返回 None。 + """ if not corp_id or not app_secret: logger.warning("WECOM_CORP_ID 或 WECOM_APP_SECRET 未配置,无法发送企微通知") return None @@ -22,7 +39,7 @@ def _get_access_token(corp_id: str, app_secret: str) -> str | None: data = resp.json() if data.get("errcode") == 0: _WECOM_ACCESS_TOKEN["token"] = data["access_token"] - _WECOM_ACCESS_TOKEN["expires_at"] = now + data.get("expires_in", 7200) - 300 + _WECOM_ACCESS_TOKEN["expires_at"] = now + data.get("expires_in", 7200) - 300 # 提前 5 分钟过期 return _WECOM_ACCESS_TOKEN["token"] else: logger.error(f"获取企微 token 失败: {data}") @@ -33,11 +50,28 @@ def _get_access_token(corp_id: str, app_secret: str) -> str | None: def _get_config(): + """从全局配置中获取企业微信 CorpID 和 AppSecret。 + + Returns: + tuple: (corp_id, app_secret) 元组。 + """ from config import settings return settings.WECOM_CORP_ID, settings.WECOM_APP_SECRET def send_notification(to_user: str, message: str, msg_type: str = "text") -> str: + """向指定企业微信用户发送通知消息。 + + 支持文本和文本卡片两种消息类型。 + + Args: + to_user: 接收消息的企业微信用户 ID。 + message: 消息内容。 + msg_type: 消息类型,支持 text/textcard。 + + Returns: + str: 发送结果描述信息。 + """ corp_id, app_secret = _get_config() token = _get_access_token(corp_id, app_secret) if not token: @@ -78,6 +112,14 @@ def send_notification(to_user: str, message: str, msg_type: str = "text") -> str def query_wecom_user(user_id: str) -> str: + """查询企业微信用户的详细信息。 + + Args: + user_id: 企业微信用户 ID。 + + Returns: + str: 用户信息描述或错误信息。 + """ corp_id, app_secret = _get_config() token = _get_access_token(corp_id, app_secret) if not token: @@ -97,6 +139,18 @@ def query_wecom_user(user_id: str) -> str: def send_wecom_group_message(message: str, group_id: str | None = None, msg_type: str = "text") -> str: + """向企业微信群发送消息。 + + 支持文本和 Markdown 两种消息格式。 + + Args: + message: 消息内容。 + group_id: 企业微信群聊 ID。 + msg_type: 消息类型,支持 text/markdown。 + + Returns: + str: 发送结果描述信息。 + """ corp_id, app_secret = _get_config() token = _get_access_token(corp_id, app_secret) if not token: @@ -124,4 +178,4 @@ def send_wecom_group_message(message: str, group_id: str | None = None, msg_type return f"企业微信群消息发送失败: {e}" -__all__ = ["send_notification", "query_wecom_user", "send_wecom_group_message"] \ No newline at end of file +__all__ = ["send_notification", "query_wecom_user", "send_wecom_group_message"] diff --git a/backend/config.py b/backend/config.py index 039a622..34e4816 100644 --- a/backend/config.py +++ b/backend/config.py @@ -3,27 +3,29 @@ from pydantic_settings import BaseSettings class Settings(BaseSettings): + """全局配置类,从环境变量加载所有应用配置项,支持通过 .env 文件覆盖。""" + DATABASE_URL: str = os.getenv( "DATABASE_URL", "postgresql+asyncpg://enterprise:enterprise123@localhost:5432/enterprise_ai", - ) - REDIS_URL: str = os.getenv("REDIS_URL", "redis://:redis123@localhost:6379/0") - JWT_SECRET: str = os.getenv("JWT_SECRET", "dev-secret-change-me") - JWT_ALGORITHM: str = "HS256" - JWT_EXPIRE_MINUTES: int = 1440 - LLM_API_KEY: str = os.getenv("LLM_API_KEY", "sk-placeholder") - LLM_API_BASE: str = os.getenv("LLM_API_BASE", "https://api.openai.com/v1") - LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o-mini") + ) # PostgreSQL 数据库连接 URL(asyncpg 异步驱动) + REDIS_URL: str = os.getenv("REDIS_URL", "redis://:redis123@localhost:6379/0") # Redis 连接 URL + JWT_SECRET: str = os.getenv("JWT_SECRET", "dev-secret-change-me") # JWT 令牌签名密钥 + JWT_ALGORITHM: str = "HS256" # JWT 签名算法 + JWT_EXPIRE_MINUTES: int = 1440 # JWT 令牌过期时间(分钟),默认 24 小时 + LLM_API_KEY: str = os.getenv("LLM_API_KEY", "sk-placeholder") # 大语言模型 API 密钥 + LLM_API_BASE: str = os.getenv("LLM_API_BASE", "https://api.openai.com/v1") # 大语言模型 API 基础地址 + LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o-mini") # 默认使用的大语言模型名称 - RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) - RATE_LIMIT_BURST: int = int(os.getenv("RATE_LIMIT_BURST", "10")) - UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads") - MAX_UPLOAD_SIZE_MB: int = int(os.getenv("MAX_UPLOAD_SIZE_MB", "50")) - WECOM_CORP_ID: str = os.getenv("WECOM_CORP_ID", "") - WECOM_APP_SECRET: str = os.getenv("WECOM_APP_SECRET", "") - WECOM_TOKEN: str = os.getenv("WECOM_TOKEN", "") - WECOM_AES_KEY: str = os.getenv("WECOM_AES_KEY", "") - METRICS_COLLECTION_INTERVAL: int = 60 + RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) # 每分钟请求速率限制 + RATE_LIMIT_BURST: int = int(os.getenv("RATE_LIMIT_BURST", "10")) # 速率限制突发上限 + UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads") # 文件上传存储目录 + MAX_UPLOAD_SIZE_MB: int = int(os.getenv("MAX_UPLOAD_SIZE_MB", "50")) # 最大上传文件大小(MB) + WECOM_CORP_ID: str = os.getenv("WECOM_CORP_ID", "") # 企业微信 CorpID + WECOM_APP_SECRET: str = os.getenv("WECOM_APP_SECRET", "") # 企业微信应用 Secret + WECOM_TOKEN: str = os.getenv("WECOM_TOKEN", "") # 企业微信 Token(用于回调验证) + WECOM_AES_KEY: str = os.getenv("WECOM_AES_KEY", "") # 企业微信 AES 密钥(用于回调消息解密) + METRICS_COLLECTION_INTERVAL: int = 60 # 系统指标采集间隔(秒) -settings = Settings() \ No newline at end of file +settings = Settings() # 全局配置单例实例 \ No newline at end of file diff --git a/backend/database.py b/backend/database.py index 93675b1..ee205a1 100644 --- a/backend/database.py +++ b/backend/database.py @@ -5,9 +5,11 @@ from config import settings class Base(DeclarativeBase): + """SQLAlchemy ORM 基类,所有数据库模型均继承此类。""" pass +# 异步数据库引擎,连接池大小 20,最大溢出 40,启用连接健康检查 async_engine = create_async_engine( settings.DATABASE_URL, pool_size=20, @@ -17,6 +19,7 @@ async_engine = create_async_engine( echo=False, ) +# 异步数据库会话工厂,用于创建数据库会话实例 AsyncSessionLocal = async_sessionmaker( async_engine, class_=AsyncSession, @@ -25,6 +28,7 @@ AsyncSessionLocal = async_sessionmaker( async def init_db(): + """初始化数据库:创建所有表并执行增量迁移。""" async with async_engine.begin() as conn: from models import Base as MBase await conn.run_sync(MBase.metadata.create_all) @@ -33,6 +37,7 @@ async def init_db(): async def _run_migrations(): + """执行数据库增量迁移,在已有表上安全添加新字段。""" async with async_engine.begin() as conn: await conn.execute(text( "ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_version_id UUID REFERENCES flow_versions(id)" @@ -43,6 +48,7 @@ async def _run_migrations(): async def get_db(): + """FastAPI 依赖注入函数,提供数据库会话,自动提交或回滚事务。""" async with AsyncSessionLocal() as session: try: yield session diff --git a/backend/dependencies.py b/backend/dependencies.py index 27a4610..aeb2f9c 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -6,20 +6,26 @@ from database import AsyncSessionLocal from models import User, UserRole, Role, RolePermission, Permission from config import settings -security = HTTPBearer(auto_error=False) +security = HTTPBearer(auto_error=False) # HTTP Bearer 令牌认证方案 async def get_current_user( request: Request, credentials: HTTPAuthorizationCredentials | None = Depends(security), ) -> dict: + """获取当前登录用户及其角色、权限信息。 + + 优先从 request.state.user 读取(由 RBAC 中间件预填充), + 否则通过 JWT 令牌解析用户身份并从数据库查询角色权限。 + 返回包含 id、username、display_name、role、permissions、data_scope 的字典。 + """ if hasattr(request.state, "user") and request.state.user: return request.state.user if credentials: try: payload = jwt.decode( - credentials.credentials, + credentials=credentials.credentials, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM], ) @@ -33,11 +39,13 @@ async def get_current_user( if not user: raise HTTPException(401, "用户不存在") + # 查询用户关联的所有角色 ur_result = await db.execute( select(Role).join(UserRole).where(UserRole.user_id == user.id) ) roles = ur_result.scalars().all() + # 收集所有权限编码和数据权限范围 permissions = [] data_scopes = [] for role in roles: @@ -45,7 +53,7 @@ async def get_current_user( rp_result = await db.execute( select(Permission.code) .join(RolePermission) - .where(RolePermission.role_id == role.id) + .where(RolePermission.role_id == role.id) ) perms = rp_result.scalars().all() permissions.extend(perms) @@ -68,6 +76,7 @@ async def get_current_user( def require_permission(perm_code: str): + """权限检查依赖注入工厂,根据权限编码校验当前用户是否拥有该权限。""" async def checker(user: dict = Depends(get_current_user)) -> dict: if perm_code not in user.get("permissions", []) and "*:*" not in user.get("permissions", []): raise HTTPException(403, f"缺少权限: {perm_code}") @@ -76,6 +85,7 @@ def require_permission(perm_code: str): async def get_db(): + """FastAPI 依赖注入函数,提供异步数据库会话,自动提交或回滚。""" async with AsyncSessionLocal() as session: try: yield session diff --git a/backend/main.py b/backend/main.py index 1edb916..c173981 100644 --- a/backend/main.py +++ b/backend/main.py @@ -30,6 +30,7 @@ from database import AsyncSessionLocal @asynccontextmanager async def lifespan(app: AgentApp): + """应用生命周期管理器,启动时初始化数据库和缓存,关闭时清理资源。""" await init_db() await cache_manager.connect() await init_memory_manager(AsyncSessionLocal) @@ -44,32 +45,34 @@ async def lifespan(app: AgentApp): app = AgentApp( - app_name="Enterprise AI Platform", - app_description="企业级 AI Agent 平台 - 双RBAC/企微集成/无代码流编排", + app_name="Enterprise AI Platform", # 应用名称 + app_description="企业级 AI Agent 平台 - 双RBAC/企微集成/无代码流编排", # 应用描述 lifespan=lifespan, docs_url="/docs", redoc_url=None, ) -app.middleware("http")(rate_limit_middleware) -app.middleware("http")(rbac_middleware) +# 注册全局 HTTP 中间件 +app.middleware("http")(rate_limit_middleware) # 速率限制中间件 +app.middleware("http")(rbac_middleware) # RBAC 权限中间件 -app.include_router(auth_router) -app.include_router(org_router) -app.include_router(rbac_router) -app.include_router(wecom_router) -app.include_router(agent_manager_router) -app.include_router(task_router) -app.include_router(monitor_router) -app.include_router(mcp_router) -app.include_router(flow_router) -app.include_router(gateway_router) -app.include_router(audit_router) -app.include_router(document_router) -app.include_router(notification_router) -app.include_router(system_router) -app.include_router(rag_router) -app.include_router(chat_router) -app.include_router(custom_tool_router) -app.include_router(memory_router) -app.include_router(model_provider_router) \ No newline at end of file +# 注册所有业务模块的路由 +app.include_router(auth_router) # 认证模块 +app.include_router(org_router) # 组织架构模块 +app.include_router(rbac_router) # 权限管理模块 +app.include_router(wecom_router) # 企业微信模块 +app.include_router(agent_manager_router) # 智能体管理模块 +app.include_router(task_router) # 任务管理模块 +app.include_router(monitor_router) # 监控模块 +app.include_router(mcp_router) # MCP 服务注册模块 +app.include_router(flow_router) # 流程定义管理模块 +app.include_router(gateway_router) # 流程 API 网关模块 +app.include_router(audit_router) # 审计日志模块 +app.include_router(document_router) # 文档管理模块 +app.include_router(notification_router) # 通知模块 +app.include_router(system_router) # 系统设置模块 +app.include_router(rag_router) # 知识库模块 +app.include_router(chat_router) # 对话模块 +app.include_router(custom_tool_router) # 自定义工具模块 +app.include_router(memory_router) # 记忆管理模块 +app.include_router(model_provider_router) # 模型供应商管理模块 \ No newline at end of file diff --git a/backend/middleware/apikey_auth.py b/backend/middleware/apikey_auth.py index 911e545..f95c535 100644 --- a/backend/middleware/apikey_auth.py +++ b/backend/middleware/apikey_auth.py @@ -1,3 +1,8 @@ +"""API 密钥认证中间件模块。 + +提供基于 API 密钥的流程访问认证功能。 +主要用于外部系统通过 API Key 调用已发布的 AI 流程。 +""" import hashlib from datetime import datetime from fastapi import Request, HTTPException @@ -8,15 +13,29 @@ from database import get_db async def authenticate_api_key(request: Request) -> dict: + """验证请求中的 API 密钥并返回关联的流程信息。 + + 从 Authorization 请求头中提取 API Key,验证其有效性并更新最后使用时间。 + API Key 必须以 "flow-" 开头,验证时对其 SHA-256 哈希值进行数据库匹配。 + + Args: + request: 当前 HTTP 请求对象。 + + Returns: + dict: 包含 flow_id、api_key_id 和 auth_type 的认证信息字典。 + + Raises: + HTTPException: 当缺少认证信息、API Key 格式无效或密钥不存在时抛出 401 异常。 + """ auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): raise HTTPException(401, "缺少认证信息") - raw_key = auth_header[7:] + raw_key = auth_header[7:] # 提取 Bearer 后的密钥部分 if not raw_key.startswith("flow-"): raise HTTPException(401, "无效的API Key格式") - key_hash = hashlib.sha256(raw_key.encode()).hexdigest() + key_hash = hashlib.sha256(raw_key.encode()).hexdigest() # 计算 SHA-256 哈希值 db_gen = get_db() db: AsyncSession = await db_gen.__anext__() @@ -28,13 +47,13 @@ async def authenticate_api_key(request: Request) -> dict: if not api_key: raise HTTPException(401, "API Key无效或已删除") - api_key.last_used_at = datetime.utcnow() + api_key.last_used_at = datetime.utcnow() # 更新最后使用时间 await db.flush() return { - "flow_id": str(api_key.flow_id), - "api_key_id": str(api_key.id), - "auth_type": "api_key", + "flow_id": str(api_key.flow_id), # 关联的流程 ID + "api_key_id": str(api_key.id), # API Key 记录 ID + "auth_type": "api_key", # 认证类型标识 } finally: try: diff --git a/backend/middleware/cache_manager.py b/backend/middleware/cache_manager.py index 6fdf69f..20595a0 100644 --- a/backend/middleware/cache_manager.py +++ b/backend/middleware/cache_manager.py @@ -1,3 +1,8 @@ +"""缓存管理器模块。 + +提供二级缓存机制(Redis + 内存),用于缓存 API 响应和计算结果。 +当 Redis 不可用时自动降级为纯内存缓存,保证系统的高可用性。 +""" import json import time import asyncio @@ -7,13 +12,30 @@ from config import settings class CacheManager: + """二级缓存管理器类,优先使用 Redis 缓存,降级时使用内存缓存。 + + 提供 get/set/delete/delete_pattern 四种基本操作, + 支持 TTL 过期时间和模式匹配批量删除。 + + Attributes: + _local: 内存缓存存储,结构为 {key: (expire_timestamp, value)}。 + _redis: Redis 异步客户端实例。 + _redis_available: Redis 是否可用标志。 + _lock: 异步锁,保证内存缓存操作的并发安全。 + """ def __init__(self): - self._local: dict[str, tuple[float, Any]] = {} - self._redis: Redis | None = None - self._redis_available = False - self._lock = asyncio.Lock() + """初始化缓存管理器实例。""" + self._local: dict[str, tuple[float, Any]] = {} # 内存缓存:{key: (过期时间戳, 值)} + self._redis: Redis | None = None # Redis 异步客户端 + self._redis_available = False # Redis 可用性标志 + self._lock = asyncio.Lock() # 异步锁 async def connect(self): + """连接到 Redis 服务器。 + + 尝试从配置中的 REDIS_URL 建立连接,如果连接失败则标记 Redis 不可用, + 后续操作将自动降级为纯内存缓存。 + """ try: self._redis = Redis.from_url(settings.REDIS_URL, decode_responses=True) await self._redis.ping() @@ -22,14 +44,31 @@ class CacheManager: self._redis_available = False async def disconnect(self): + """断开 Redis 连接,释放资源。""" if self._redis: await self._redis.close() @property def available(self) -> bool: + """检查缓存是否可用(Redis 或内存至少一个可用)。 + + Returns: + bool: Redis 可用时返回 True,否则返回 False。 + """ return self._redis_available async def get(self, key: str) -> Any | None: + """从缓存中获取指定键的值。 + + 优先从 Redis 获取,如果 Redis 不可用或未找到则从内存缓存获取。 + 内存缓存中的过期条目会被自动清理。 + + Args: + key: 缓存键。 + + Returns: + Any | None: 缓存值,未找到或已过期返回 None。 + """ if self._redis_available and self._redis: try: val = await self._redis.get(key) @@ -48,6 +87,15 @@ class CacheManager: return None async def set(self, key: str, value: Any, ttl: int = 300): + """将值设置到缓存中,指定过期时间。 + + 同时写入 Redis 和内存缓存,Redis 写入失败不影响内存缓存。 + + Args: + key: 缓存键。 + value: 要缓存的值,支持任意可 JSON 序列化的类型。 + ttl: 过期时间(秒),默认 300 秒(5 分钟)。 + """ if self._redis_available and self._redis: try: await self._redis.setex(key, ttl, json.dumps(value, default=str)) @@ -56,6 +104,7 @@ class CacheManager: async with self._lock: self._local[key] = (time.time() + ttl, value) + # 当内存缓存条目超过上限时清理过期条目 if len(self._local) > 10000: now = time.time() expired = [k for k, (t, v) in self._local.items() if now >= t] @@ -63,6 +112,13 @@ class CacheManager: del self._local[k] async def delete(self, key: str): + """从缓存中删除指定键。 + + 同时从 Redis 和内存缓存中删除,任一删除失败不影响另一个。 + + Args: + key: 要删除的缓存键。 + """ if self._redis_available and self._redis: try: await self._redis.delete(key) @@ -72,6 +128,13 @@ class CacheManager: self._local.pop(key, None) async def delete_pattern(self, pattern: str): + """按模式匹配批量删除缓存键。 + + Redis 中使用 keys 命令匹配,内存缓存中使用字符串包含匹配。 + + Args: + pattern: 匹配模式,支持通配符 *。 + """ if self._redis_available and self._redis: try: keys = await self._redis.keys(pattern) @@ -85,4 +148,4 @@ class CacheManager: del self._local[k] -cache_manager = CacheManager() \ No newline at end of file +cache_manager = CacheManager() # 全局缓存管理器单例实例 diff --git a/backend/middleware/rate_limiter.py b/backend/middleware/rate_limiter.py index fe81df5..d50ef20 100644 --- a/backend/middleware/rate_limiter.py +++ b/backend/middleware/rate_limiter.py @@ -1,3 +1,8 @@ +"""速率限制中间件模块。 + +提供基于令牌桶算法的 HTTP 请求速率限制功能。 +采用内存中的滑动窗口机制,限制每个 IP 地址在指定时间窗口内的请求数量。 +""" import time import asyncio from collections import defaultdict @@ -6,14 +11,31 @@ from config import settings class RateLimiter: + """内存速率限制器类,使用滑动窗口算法限制请求频率。 + + 为每个唯一键(通常是 IP 地址)维护一个时间戳列表, + 在每次请求时清理过期时间戳并检查是否超过限制。 + + Attributes: + MAX_KEYS: 最大缓存的键数量,防止内存无限增长。 + _buckets: 滑动窗口桶,存储每个键的请求时间戳列表。 + _lock: 异步锁,保证并发安全。 + _last_cleanup: 上次清理缓存的时间戳。 + """ MAX_KEYS = 10000 def __init__(self): - self._buckets: dict[str, list[float]] = defaultdict(list) - self._lock = asyncio.Lock() - self._last_cleanup = time.time() + """初始化速率限制器实例。""" + self._buckets: dict[str, list[float]] = defaultdict(list) # 滑动窗口桶:{key: [timestamp, ...]} + self._lock = asyncio.Lock() # 异步锁,保证并发安全 + self._last_cleanup = time.time() # 上次清理缓存的时间戳 async def _cleanup(self): + """清理过期和空闲的键,释放内存空间。 + + 仅在距上次清理超过 60 秒时执行实际清理操作。 + 删除空桶或最后一个请求超过 120 秒的桶。 + """ now = time.time() if now - self._last_cleanup < 60: return @@ -23,21 +45,33 @@ class RateLimiter: del self._buckets[k] async def check(self, key: str) -> bool: + """检查指定键是否允许新的请求。 + + 使用滑动窗口算法,清理窗口外的时间戳后检查是否超过限制。 + 如果超过限制则拒绝请求,否则记录当前时间戳并允许通过。 + + Args: + key: 速率限制键(通常为 IP 地址)。 + + Returns: + bool: 允许请求返回 True,拒绝请求返回 False。 + """ now = time.time() - limit = settings.RATE_LIMIT_PER_MINUTE - window = 60.0 + limit = settings.RATE_LIMIT_PER_MINUTE # 每分钟请求限制数 + window = 60.0 # 时间窗口(秒) async with self._lock: await self._cleanup() bucket = self._buckets[key] - bucket = [t for t in bucket if now - t < window] + bucket = [t for t in bucket if now - t < window] # 过滤窗口内的时间戳 self._buckets[key] = bucket if len(bucket) >= limit: - return False + return False # 超过限制,拒绝请求 - bucket.append(now) + bucket.append(now) # 记录当前请求时间戳 + # 当缓存键数量超过上限时,淘汰最旧的键 if len(self._buckets) > self.MAX_KEYS: oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2] for k in oldest_keys: @@ -46,27 +80,49 @@ class RateLimiter: return True async def remaining(self, key: str) -> int: + """获取指定键剩余的请求次数。 + + Args: + key: 速率限制键。 + + Returns: + int: 当前时间窗口内剩余的请求次数。 + """ now = time.time() async with self._lock: bucket = [t for t in self._buckets.get(key, []) if now - t < 60] return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket)) -rate_limiter = RateLimiter() +rate_limiter = RateLimiter() # 全局速率限制器单例实例 async def rate_limit_middleware(request: Request, call_next): + """速率限制中间件。 + + 对每个 HTTP 请求进行速率限制检查: + 1. 跳过公开路径(健康检查、登录等) + 2. 基于客户端 IP 地址进行速率限制 + 3. 在响应头中添加剩余请求次数信息 + + Args: + request: 当前 HTTP 请求对象。 + call_next: 下一个中间件或路由处理函数。 + + Returns: + Response: 如果未超限则返回后续处理结果,否则返回 429 错误响应。 + """ path = request.url.path if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]: return await call_next(request) - client_ip = request.client.host if request.client else "unknown" - key = f"ratelimit:{client_ip}" + client_ip = request.client.host if request.client else "unknown" # 客户端 IP 地址 + key = f"ratelimit:{client_ip}" # 速率限制键 if not await rate_limiter.check(key): raise HTTPException(429, "请求过于频繁,请稍后再试") response = await call_next(request) remaining = await rate_limiter.remaining(key) - response.headers["X-RateLimit-Remaining"] = str(remaining) - return response \ No newline at end of file + response.headers["X-RateLimit-Remaining"] = str(remaining) # 响应头中添加剩余请求次数 + return response diff --git a/backend/middleware/rbac_middleware.py b/backend/middleware/rbac_middleware.py index 69e69ad..fd09e1c 100644 --- a/backend/middleware/rbac_middleware.py +++ b/backend/middleware/rbac_middleware.py @@ -1,3 +1,9 @@ +"""RBAC 权限中间件模块。 + +提供全局 HTTP 请求的 RBAC(基于角色的访问控制)权限校验中间件。 +每个请求都会经过此中间件,解析 JWT 令牌并查询用户的角色和权限信息, +将用户上下文存储到 request.state.user 中供后续路由使用。 +""" import jwt from fastapi import Request, HTTPException from fastapi.responses import JSONResponse @@ -8,34 +14,55 @@ from sqlalchemy import select async def rbac_middleware(request: Request, call_next): + """RBAC 权限校验中间件。 + + 对每个 HTTP 请求进行权限校验: + 1. 跳过公开路径(登录、健康检查等) + 2. 解析 JWT 令牌获取用户身份 + 3. 从数据库查询用户的角色、权限和数据权限范围 + 4. 将用户上下文存储到 request.state.user 中 + + Args: + request: 当前 HTTP 请求对象。 + call_next: 下一个中间件或路由处理函数。 + + Returns: + Response: 如果权限校验通过则返回后续处理结果,否则返回 401 错误响应。 + """ + # 公开路径列表,无需认证即可访问 public_paths = ["/api/auth/login", "/api/auth/wecom", "/health", "/docs", "/openapi.json", "/wecom/callback"] if any(request.url.path.startswith(p) for p in public_paths): return await call_next(request) + # 从 Authorization 头中提取 JWT 令牌 token = request.headers.get("Authorization", "").replace("Bearer ", "") if not token: return JSONResponse({"code": 401, "message": "未提供认证令牌"}, 401) + # 解析 JWT 令牌获取用户 ID try: payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]) user_id = payload.get("sub") except jwt.PyJWTError: return JSONResponse({"code": 401, "message": "令牌无效或已过期"}, 401) + # 从数据库查询用户信息和权限 async with AsyncSessionLocal() as db: result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if not user or user.status != "active": return JSONResponse({"code": 401, "message": "用户不存在或已禁用"}, 401) + # 查询用户关联的所有角色 ur_result = await db.execute( select(Role).join(UserRole).where(UserRole.user_id == user.id) ) roles = ur_result.scalars().all() - role_codes = [r.code for r in roles] - is_root = "root" in role_codes + role_codes = [r.code for r in roles] # 角色编码列表 + is_root = "root" in role_codes # 是否为超级管理员 + # 收集所有权限编码和数据权限范围 permissions = [] data_scopes = [] for role in roles: @@ -46,11 +73,13 @@ async def rbac_middleware(request: Request, call_next): perms = rp_result.scalars().all() permissions.extend([p.code for p in perms]) - unique_perms = list(set(permissions)) + unique_perms = list(set(permissions)) # 去重后的权限列表 + # 超级管理员自动拥有所有权限 if is_root and "*:*" not in unique_perms: unique_perms.insert(0, "*:*") + # 将用户上下文存储到 request.state 中 request.state.user = { "id": str(user.id), "username": user.username, @@ -66,4 +95,4 @@ async def rbac_middleware(request: Request, call_next): ), } - return await call_next(request) \ No newline at end of file + return await call_next(request) diff --git a/backend/models/__init__.py b/backend/models/__init__.py index c665c9d..15500c9 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -1,3 +1,9 @@ +""" +数据库 ORM 模型模块 + +本模块定义了所有数据库表对应的 SQLAlchemy ORM 模型。 +每个类映射到一张数据库表,类属性映射到表字段。 +""" import uuid from datetime import datetime from sqlalchemy import Column, String, DateTime, ForeignKey, Integer, Boolean, JSON, Text, Float @@ -7,389 +13,201 @@ from database import Base class Department(Base): + """部门表 (departments),存储企业部门层级结构,支持多级树形组织架构。""" __tablename__ = "departments" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(100), nullable=False) - parent_id = Column(UUID(as_uuid=True), ForeignKey("departments.id"), nullable=True) - path = Column(String(500), default="/") - level = Column(Integer, default=0) - sort_order = Column(Integer, default=0) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 部门唯一标识 UUID + name = Column(String(100), nullable=False) # 部门名称 + parent_id = Column(UUID(as_uuid=True), ForeignKey("departments.id"), nullable=True) # 上级部门 ID,用于构建树形结构 + path = Column(String(500), default="/") # 部门路径,从根节点到当前节点的路径字符串 + level = Column(Integer, default=0) # 部门层级深度(根部门为 0) + sort_order = Column(Integer, default=0) # 排序序号,同级部门按此字段排列顺序 + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间 - children = relationship("Department", backref="parent", remote_side=[id]) - users = relationship("User", back_populates="department") + children = relationship("Department", backref="parent", remote_side=[id]) # 子部门列表(一对多自引用) + users = relationship("User", back_populates="department") # 部门下的用户列表 class User(Base): + """用户表 (users),存储系统用户信息,包括账号、身份、组织归属。""" __tablename__ = "users" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - username = Column(String(50), unique=True, nullable=False) - password_hash = Column(String(255), nullable=False) - display_name = Column(String(100), nullable=False) - email = Column(String(100)) - phone = Column(String(20)) - wecom_user_id = Column(String(100), unique=True) - department_id = Column(UUID(as_uuid=True), ForeignKey("departments.id")) - position = Column(String(100)) - manager_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - status = Column(String(20), default="active") - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - department = relationship("Department", back_populates="users") - roles = relationship("UserRole", back_populates="user") - manager = relationship("User", remote_side=[id], backref="subordinates") + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 用户唯一标识 UUID + username = Column(String(50), unique=True, nullable=False) # 登录用户名(唯一) + password_hash = Column(String(255), nullable=False) # 密码哈希值(bcrypt 加密) + display_name = Column(String(100), nullable=False) # 用户显示名称 + email = Column(String(100)) # 电子邮箱 + phone = Column(String(20)) # 手机号码 + wecom_user_id = Column(String(100), unique=True) # 企业微信用户 ID(唯一) + department_id = Column(UUID(as_uuid=True), ForeignKey("departments.id")) # 所属部门 ID + position = Column(String(100)) # 职位/岗位名称 + manager_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 直接上级用户 ID + status = Column(String(20), default="active") # 用户状态:active/inactive/disabled + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间 + + department = relationship("Department", back_populates="users") # 所属部门(多对一) + roles = relationship("UserRole", back_populates="user") # 用户角色列表(通过中间表关联) + manager = relationship("User", remote_side=[id], backref="subordinates") # 直接上级(自引用) class Role(Base): + """角色表 (roles),存储系统角色定义,用于 RBAC 权限管理。""" __tablename__ = "roles" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(50), unique=True, nullable=False) - code = Column(String(50), unique=True, nullable=False, default="") - description = Column(String(200)) - is_system = Column(Boolean, default=False) - data_scope = Column(String(50), default="self_only") - created_at = Column(DateTime, default=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 角色唯一标识 UUID + name = Column(String(50), unique=True, nullable=False) # 角色名称(唯一) + code = Column(String(50), unique=True, nullable=False, default="") # 角色编码(唯一,如 admin/user) + description = Column(String(200)) # 角色描述 + is_system = Column(Boolean, default=False) # 是否为系统内置角色(不可删除) + data_scope = Column(String(50), default="self_only") # 数据权限范围:self_only/department/all + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 - permissions = relationship("RolePermission", back_populates="role") + permissions = relationship("RolePermission", back_populates="role") # 角色权限关联列表 class Permission(Base): + """权限表 (permissions),存储系统中每个可操作的权限点。""" __tablename__ = "permissions" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - code = Column(String(100), unique=True, nullable=False) - name = Column(String(100), nullable=False) - resource = Column(String(100), nullable=False) - action = Column(String(50), nullable=False) - description = Column(String(200)) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 权限唯一标识 UUID + code = Column(String(100), unique=True, nullable=False) # 权限编码(唯一) + name = Column(String(100), nullable=False) # 权限名称 + resource = Column(String(100), nullable=False) # 所属资源名称(如 user/role) + action = Column(String(50), nullable=False) # 操作类型(create/read/update/delete) + description = Column(String(200)) # 权限描述 class RolePermission(Base): + """角色-权限关联表 (role_permissions),多对多关联中间表。""" __tablename__ = "role_permissions" - role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) - permission_id = Column(UUID(as_uuid=True), ForeignKey("permissions.id", ondelete="CASCADE"), primary_key=True) + role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) # 角色 ID(级联删除) + permission_id = Column(UUID(as_uuid=True), ForeignKey("permissions.id", ondelete="CASCADE"), primary_key=True) # 权限 ID(级联删除) role = relationship("Role", back_populates="permissions") permission = relationship("Permission") class UserRole(Base): + """用户-角色关联表 (user_roles),多对多关联中间表。""" __tablename__ = "user_roles" - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True) - role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True) # 用户 ID(级联删除) + role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) # 角色 ID(级联删除) user = relationship("User", back_populates="roles") role = relationship("Role") class ChatSession(Base): + """聊天会话表 (chat_sessions),存储用户与 AI 智能体的对话会话记录。""" __tablename__ = "chat_sessions" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) - agent_type = Column(String(50), nullable=False) - session_id = Column(String(100), unique=True, nullable=False) - status = Column(String(20), default="active") - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 会话唯一标识 UUID + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) # 所属用户 ID + agent_type = Column(String(50), nullable=False) # 智能体类型(chat/flow/rag) + session_id = Column(String(100), unique=True, nullable=False) # 外部会话 ID(对客户端暴露) + status = Column(String(20), default="active") # 会话状态:active/closed + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间 class ChatMessage(Base): + """聊天消息表 (chat_messages),存储聊天会话中的每条消息内容。""" __tablename__ = "chat_messages" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - session_id = Column(UUID(as_uuid=True), ForeignKey("chat_sessions.id", ondelete="CASCADE")) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) - role = Column(String(20), nullable=False) - content = Column(Text, nullable=False) - metadata_ = Column("metadata", JSON, default=dict) - created_at = Column(DateTime, default=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 消息唯一标识 UUID + session_id = Column(UUID(as_uuid=True), ForeignKey("chat_sessions.id", ondelete="CASCADE")) # 所属会话 ID + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) # 发送者用户 ID + role = Column(String(20), nullable=False) # 消息角色:user/assistant/system + content = Column(Text, nullable=False) # 消息内容文本 + metadata_ = Column("metadata", JSON, default=dict) # 元数据(额外信息 JSON) + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 class Task(Base): + """任务表 (tasks),存储分配给用户的待办任务信息。""" __tablename__ = "tasks" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - title = Column(String(200), nullable=False) - content = Column(Text) - assigner_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - assignee_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) - status = Column(String(20), default="pending") - priority = Column(String(20), default="normal") - deadline = Column(DateTime) - wecom_message_id = Column(String(100)) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 任务唯一标识 UUID + title = Column(String(200), nullable=False) # 任务标题 + content = Column(Text) # 任务内容描述 + assigner_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 任务分配者(发起人)ID + assignee_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) # 任务执行人 ID + status = Column(String(20), default="pending") # 任务状态:pending/in_progress/completed/cancelled + priority = Column(String(20), default="normal") # 优先级:low/normal/high/urgent + deadline = Column(DateTime) # 截止日期时间 + wecom_message_id = Column(String(100)) # 企业微信消息 ID + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间 class FlowDefinition(Base): + """流程定义表 (flow_definitions),存储可执行 AI 工作流的节点和连线配置。""" __tablename__ = "flow_definitions" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(200), nullable=False) - description = Column(Text) - version = Column(Integer, default=1) - status = Column(String(20), default="draft") - definition_json = Column(JSON, nullable=False, default=dict) - published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True) - draft_definition_json = Column(JSON, nullable=True, default=None) - creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - flow_mode = Column(String(20), default="chatflow") - published_to_wecom = Column(Boolean, default=False) - published_to_web = Column(Boolean, default=False) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 流程唯一标识 UUID + name = Column(String(200), nullable=False) # 流程名称 + description = Column(Text) # 流程描述 + version = Column(Integer, default=1) # 当前版本号 + status = Column(String(20), default="draft") # 流程状态:draft/published/archived + definition_json = Column(JSON, nullable=False, default=dict) # 已发布的节点和连线配置 JSON + published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True) # 已发布版本 ID + draft_definition_json = Column(JSON, nullable=True, default=None) # 草稿编辑中的配置 JSON + creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 创建者用户 ID + flow_mode = Column(String(20), default="chatflow") # 流程模式:chatflow/workflow + published_to_wecom = Column(Boolean, default=False) # 是否已发布到企业微信 + published_to_web = Column(Boolean, default=False) # 是否已发布到 Web 端 + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间 published_version = relationship("FlowVersion", foreign_keys=[published_version_id], post_update=True) class FlowVersion(Base): + """流程版本表 (flow_versions),存储流程定义的历史版本快照。""" __tablename__ = "flow_versions" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) - version = Column(Integer, nullable=False) - definition_json = Column(JSON, nullable=False, default=dict) - changelog = Column(Text, default="") - published_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) - published_to_wecom = Column(Boolean, default=False) - published_to_web = Column(Boolean, default=False) - created_at = Column(DateTime, default=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 版本唯一标识 UUID + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) # 所属流程定义 ID + version = Column(Integer, nullable=False) # 版本号(同一流程内递增) + definition_json = Column(JSON, nullable=False, default=dict) # 该版本的流程定义 JSON 快照 + changelog = Column(Text, default="") # 版本变更日志 + published_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 发布者用户 ID + published_to_wecom = Column(Boolean, default=False) # 是否发布到企业微信 + published_to_web = Column(Boolean, default=False) # 是否发布到 Web 端 + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 class FlowApiKey(Base): + """流程 API 密钥表 (flow_api_keys),存储用于外部调用流程的 API 密钥。""" __tablename__ = "flow_api_keys" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) - name = Column(String(100), nullable=False) - key_hash = Column(String(64), nullable=False) - key_prefix = Column(String(10), nullable=False) - created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) - last_used_at = Column(DateTime, nullable=True) - created_at = Column(DateTime, default=datetime.utcnow) + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 密钥唯一标识 UUID + flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) # 所属流程 ID + name = Column(String(100), nullable=False) # 密钥名称 + key_hash = Column(String(64), nullable=False) # 密钥哈希值(SHA-256 加密存储) + key_prefix = Column(String(10), nullable=False) # 密钥前缀(用于显示识别) + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 创建者用户 ID + last_used_at = Column(DateTime, nullable=True) # 最后使用时间 + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 class FlowTemplate(Base): + """流程模板表 (flow_templates),存储预定义的流程模板,可供用户快速创建流程。""" __tablename__ = "flow_templates" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(200), nullable=False) - description = Column(Text, default="") - category = Column(String(50), default="") - definition_json = Column(JSON, nullable=False, default=dict) - icon = Column(String(50), default="") - sort_order = Column(Integer, default=0) - is_builtin = Column(Boolean, default=False) - usage_count = Column(Integer, default=0) - created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class CustomTool(Base): - __tablename__ = "custom_tools" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(100), nullable=False) - description = Column(Text) - schema_json = Column(JSON, nullable=False, default=dict) - endpoint_url = Column(String(500), nullable=False) - method = Column(String(10), default="GET") - path = Column(String(500), default="") - headers_json = Column(JSON, default=dict) - auth_type = Column(String(20), default="none") - auth_config = Column(JSON, default=dict) - created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class FlowExecution(Base): - __tablename__ = "flow_executions" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE")) - version = Column(Integer, nullable=True) - trigger_type = Column(String(50)) - trigger_user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - input_data = Column(JSON) - output_data = Column(JSON) - status = Column(String(20), default="running") - token_usage = Column(JSON, default=dict) - latency_ms = Column(Integer, nullable=True) - error_message = Column(Text, nullable=True) - started_at = Column(DateTime, default=datetime.utcnow) - finished_at = Column(DateTime) - - -class MCPService(Base): - __tablename__ = "mcp_services" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(100), unique=True, nullable=False) - transport = Column(String(20), default="http") - url = Column(String(500)) - command = Column(String(500)) - args = Column(JSON, default=list) - env = Column(JSON, default=dict) - status = Column(String(20), default="disconnected") - tools = Column(JSON, default=list) - creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class NotificationTemplate(Base): - __tablename__ = "notification_templates" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(100), nullable=False) - code = Column(String(100), unique=True, nullable=False) - channel = Column(String(20), default="wecom") - title_template = Column(String(500)) - body_template = Column(Text, nullable=False) - variables = Column(JSON, default=list) - is_system = Column(Boolean, default=False) - created_at = Column(DateTime, default=datetime.utcnow) - - -class SystemMetric(Base): - __tablename__ = "system_metrics" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - metric_type = Column(String(50), nullable=False) - value = Column(JSON, nullable=False) - collected_at = Column(DateTime, default=datetime.utcnow) - - -class AgentConfig(Base): - __tablename__ = "agent_configs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(100), nullable=False) - description = Column(String(500)) - system_prompt = Column(Text, default="") - model = Column(String(50), default="gpt-4o-mini") - model_instance_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True) - embedding_model_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True) - temperature = Column(Float, default=0.7) - tools = Column(JSON, default=list) - status = Column(String(20), default="active") - creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class AuditLog(Base): - __tablename__ = "audit_logs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - operator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) - action = Column(String(100), nullable=False) - resource = Column(String(100)) - resource_id = Column(String(100)) - detail = Column(JSON, default=dict) - ip_address = Column(String(50)) - created_at = Column(DateTime, default=datetime.utcnow) - - -class MemoryMessage(Base): - __tablename__ = "memory_messages" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) - session_id = Column(UUID(as_uuid=True), nullable=False) - role = Column(String(20), nullable=False) - content = Column(Text, nullable=False) - created_at = Column(DateTime, default=datetime.utcnow) - - -class MemoryAtom(Base): - __tablename__ = "memory_atoms" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True) - atom_type = Column(String(20), nullable=False) - content = Column(Text, nullable=False) - priority = Column(Integer, default=50) - source_session_id = Column(UUID(as_uuid=True), nullable=True) - metadata_ = Column("metadata", JSON, default=dict) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class MemoryScene(Base): - __tablename__ = "memory_scenes" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True) - scene_name = Column(String(200), nullable=False) - summary = Column(Text, nullable=False) - heat = Column(Integer, default=0) - content = Column(JSON, default=dict) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class MemoryPersona(Base): - __tablename__ = "memory_personas" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True) - content = Column(JSON, default=dict, nullable=False) - raw_text = Column(Text, default="") - version = Column(Integer, default=1) - updated_at = Column(DateTime, default=datetime.utcnow) - - -class MemorySession(Base): - __tablename__ = "memory_sessions" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) - session_id = Column(UUID(as_uuid=True), nullable=False) - flow_name = Column(String(200), default="") - message_count = Column(Integer, default=0) - last_active_at = Column(DateTime, default=datetime.utcnow) - created_at = Column(DateTime, default=datetime.utcnow) - - -class ModelProvider(Base): - __tablename__ = "model_providers" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(String(100), nullable=False) - provider_type = Column(String(50), nullable=False) - base_url = Column(String(500)) - api_key = Column(Text) - extra_config = Column(JSON, default=dict) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.utcnow) - - -class ModelInstance(Base): - __tablename__ = "model_instances" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_id = Column(UUID(as_uuid=True), ForeignKey("model_providers.id", ondelete="CASCADE")) - model_name = Column(String(100), nullable=False) - model_type = Column(String(30), nullable=False) - display_name = Column(String(200)) - capabilities = Column(JSON, default=dict) - default_params = Column(JSON, default=dict) - is_default = Column(Boolean, default=False) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.utcnow) \ No newline at end of file + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 模板唯一标识 UUID + name = Column(String(200), nullable=False) # 模板名称 + description = Column(Text, default="") # 模板描述 + category = Column(String(50), default="") # 模板分类 + definition_json = Column(JSON, nullable=False, default=dict) # 模板的流程定义 JSON + icon = Column(String(50), default="") # 模板图标名称 + sort_order = Column(Integer, default=0) # 排序序号 + is_builtin = Column(Boolean, default=False) # 是否为系统内置模板 + usage_count = Column(Integer, default=0) # 使用次数统计 + created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 创建者用户 ID + created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间 + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间 diff --git a/backend/modules/__init__.py b/backend/modules/__init__.py index e69de29..4e92b2d 100644 --- a/backend/modules/__init__.py +++ b/backend/modules/__init__.py @@ -0,0 +1 @@ +"""业务模块包。""" \ No newline at end of file diff --git a/backend/modules/agent_manager/__init__.py b/backend/modules/agent_manager/__init__.py index e69de29..9b95ba3 100644 --- a/backend/modules/agent_manager/__init__.py +++ b/backend/modules/agent_manager/__init__.py @@ -0,0 +1 @@ +"""Agent 管理模块。""" \ No newline at end of file diff --git a/backend/modules/agent_manager/router.py b/backend/modules/agent_manager/router.py index 8764a75..d6e5676 100644 --- a/backend/modules/agent_manager/router.py +++ b/backend/modules/agent_manager/router.py @@ -1,3 +1,8 @@ +"""智能体管理模块路由。 + +提供 AI 智能体的配置管理、对话交互和聊天历史记录查询功能。 +支持多种智能体类型(员工/管理者/任务/文档)的动态创建和对话。 +""" import uuid from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy import select @@ -17,16 +22,32 @@ async def agent_chat( payload: dict, db: AsyncSession = Depends(get_db), ): + """与指定类型的 AI 智能体进行对话。 + + 根据用户身份创建或获取对应的智能体实例,处理用户消息并返回 AI 回复。 + 同时会自动创建或复用聊天会话,并保存对话消息到数据库。 + + Args: + agent_type: 智能体类型(employee/manager/task/document)。 + request: HTTP 请求对象,用于获取当前用户信息。 + payload: 请求体,包含 message 和可选的 session_id。 + db: 异步数据库会话。 + + Returns: + dict: 包含 session_id、reply 和 role 的响应数据。 + """ user_ctx = request.state.user user_id = uuid.UUID(user_ctx["id"]) - msg_content = payload.get("message", "") - session_id = payload.get("session_id", f"session_{uuid.uuid4().hex[:12]}") + msg_content = payload.get("message", "") # 用户消息内容 + session_id = payload.get("session_id", f"session_{uuid.uuid4().hex[:12]}") # 会话 ID,未提供则自动生成 + # 查询当前用户信息 result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if not user: raise HTTPException(404, "用户不存在") + # 查询或创建聊天会话 session_result = await db.execute( select(ChatSession).where(ChatSession.session_id == session_id) ) @@ -39,6 +60,7 @@ async def agent_chat( db.add(session) await db.flush() + # 保存用户消息到数据库 user_msg = ChatMessage( session_id=session.id, user_id=user.id, role="user", content=msg_content, @@ -46,6 +68,7 @@ async def agent_chat( db.add(user_msg) await db.flush() + # 创建对应类型的 AI 智能体 agent = await AgentFactory.create_agent( agent_type=agent_type, user_id=str(user.id), @@ -53,12 +76,14 @@ async def agent_chat( department_id=str(user.department_id) if user.department_id else None, ) + # 构造消息并调用智能体回复 from agentscope.message import Msg input_msg = Msg(name="user", content=msg_content, role="user") response = await agent.reply(input_msg) reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response) + # 保存 AI 回复消息到数据库 ai_msg = ChatMessage( session_id=session.id, user_id=user.id, role="assistant", content=reply_text, @@ -78,6 +103,11 @@ async def agent_chat( @router.get("/list") async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)): + """获取所有处于活跃状态的智能体配置列表。 + + Returns: + dict: 包含智能体配置列表的响应数据。 + """ result = await db.execute( select(AgentConfig).where(AgentConfig.status == "active").order_by(AgentConfig.updated_at.desc()) ) @@ -99,6 +129,16 @@ async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)): @router.post("/", response_model=AgentConfigOut) async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSession = Depends(get_db)): + """创建新的智能体配置。 + + Args: + req: 智能体配置创建请求体。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + AgentConfigOut: 创建后的智能体配置响应。 + """ user_ctx = request.state.user agent = AgentConfig( name=req.name, @@ -123,6 +163,16 @@ async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSessio @router.get("/{agent_id}", response_model=AgentConfigOut) async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + """获取指定智能体的配置详情。 + + Args: + agent_id: 智能体唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + AgentConfigOut: 智能体配置响应。 + """ result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id)) agent = result.scalar_one_or_none() if not agent: @@ -139,6 +189,17 @@ async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = De @router.put("/{agent_id}", response_model=AgentConfigOut) async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Request, db: AsyncSession = Depends(get_db)): + """更新指定智能体的配置。 + + Args: + agent_id: 智能体唯一标识 ID。 + req: 智能体配置更新请求体。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + AgentConfigOut: 更新后的智能体配置响应。 + """ result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id)) agent = result.scalar_one_or_none() if not agent: @@ -170,6 +231,16 @@ async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Req @router.delete("/{agent_id}") async def delete_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + """删除指定的智能体配置。 + + Args: + agent_id: 智能体唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 操作结果响应。 + """ result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id)) agent = result.scalar_one_or_none() if not agent: @@ -184,6 +255,16 @@ async def get_chat_history( request: Request, db: AsyncSession = Depends(get_db), ): + """获取指定会话的完整聊天历史记录。 + + Args: + session_id: 会话唯一标识。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含消息列表的响应数据,按时间顺序排列。 + """ session_result = await db.execute( select(ChatSession).where(ChatSession.session_id == session_id) ) @@ -203,4 +284,4 @@ async def get_chat_history( "content": m.content, "created_at": str(m.created_at), } for m in messages], - } \ No newline at end of file + } diff --git a/backend/modules/audit/__init__.py b/backend/modules/audit/__init__.py index e69de29..11b78cc 100644 --- a/backend/modules/audit/__init__.py +++ b/backend/modules/audit/__init__.py @@ -0,0 +1 @@ +"""审计日志模块。""" \ No newline at end of file diff --git a/backend/modules/audit/router.py b/backend/modules/audit/router.py index f862340..10d5065 100644 --- a/backend/modules/audit/router.py +++ b/backend/modules/audit/router.py @@ -1,3 +1,8 @@ +"""审计日志模块路由。 + +提供审计日志的查询、统计和导出功能。 +记录系统中所有重要操作的详细信息,支持按操作类型、资源、操作人和时间范围筛选。 +""" import uuid import csv import io @@ -26,6 +31,22 @@ async def list_logs( date_to: datetime | None = Query(None), db: AsyncSession = Depends(get_db), ): + """分页查询审计日志列表,支持多条件筛选。 + + Args: + request: HTTP 请求对象。 + page: 页码,从 1 开始。 + page_size: 每页数量,最大 100。 + action: 可选的操作类型筛选条件。 + resource: 可选的资源类型筛选条件。 + operator_id: 可选的操作人 ID 筛选条件。 + date_from: 可选的起始时间筛选条件。 + date_to: 可选的结束时间筛选条件。 + db: 异步数据库会话。 + + Returns: + AuditLogPage: 分页的审计日志响应数据。 + """ conditions = [] if action: conditions.append(AuditLog.action == action) @@ -38,14 +59,16 @@ async def list_logs( if date_to: conditions.append(AuditLog.created_at <= date_to) - where = and_(*conditions) if conditions else None + where = and_(*conditions) if conditions else None # 组合所有筛选条件 + # 查询总数 count_q = select(func.count(AuditLog.id)) if where is not None: count_q = count_q.where(where) total_result = await db.execute(count_q) total = total_result.scalar() or 0 + # 分页查询 q = select(AuditLog).order_by(AuditLog.created_at.desc()) if where is not None: q = q.where(where) @@ -63,6 +86,15 @@ async def list_logs( @router.get("/actions") async def list_action_types(request: Request, db: AsyncSession = Depends(get_db)): + """获取所有操作类型及其出现次数统计。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含操作类型统计列表的响应数据。 + """ result = await db.execute( select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action) ) @@ -74,15 +106,25 @@ async def list_action_types(request: Request, db: AsyncSession = Depends(get_db) @router.get("/stats") async def audit_stats(request: Request, db: AsyncSession = Depends(get_db)): + """获取审计日志的统计摘要,包括总数、今日数量和 TOP 排行。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含审计统计摘要的响应数据。 + """ total_result = await db.execute(select(func.count(AuditLog.id))) total = total_result.scalar() or 0 - today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) + today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) # 今日零点 today_result = await db.execute( select(func.count(AuditLog.id)).where(AuditLog.created_at >= today_start) ) today = today_result.scalar() or 0 + # 最常见的操作类型 TOP 10 top_result = await db.execute( select(AuditLog.action, func.count(AuditLog.id)) .group_by(AuditLog.action) @@ -91,6 +133,7 @@ async def audit_stats(request: Request, db: AsyncSession = Depends(get_db)): ) top_actions = [{"action": r[0], "count": r[1]} for r in top_result.all()] + # 最常见的资源类型 TOP 10 top_resources = await db.execute( select(AuditLog.resource, func.count(AuditLog.id)) .group_by(AuditLog.resource) @@ -117,6 +160,17 @@ async def export_logs( date_to: datetime | None = Query(None), db: AsyncSession = Depends(get_db), ): + """导出审计日志为 CSV 文件。 + + Args: + request: HTTP 请求对象。 + date_from: 可选的起始时间筛选条件。 + date_to: 可选的结束时间筛选条件。 + db: 异步数据库会话。 + + Returns: + StreamingResponse: CSV 格式的文件流响应。 + """ conditions = [] if date_from: conditions.append(AuditLog.created_at >= date_from) @@ -126,13 +180,13 @@ async def export_logs( q = select(AuditLog).order_by(AuditLog.created_at.desc()) if conditions: q = q.where(and_(*conditions)) - q = q.limit(10000) + q = q.limit(10000) # 最多导出 10000 条 result = await db.execute(q) logs = result.scalars().all() output = io.StringIO() writer = csv.writer(output) - writer.writerow(["ID", "操作时间", "操作人ID", "操作", "资源", "资源ID", "详情", "IP地址"]) + writer.writerow(["ID", "操作时间", "操作人ID", "操作", "资源", "资源ID", "详情", "IP地址"]) # CSV 表头 for log in logs: writer.writerow([ str(log.id), @@ -150,4 +204,4 @@ async def export_logs( iter([output.getvalue()]), media_type="text/csv", headers={"Content-Disposition": f"attachment; filename=audit_logs_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv"}, - ) \ No newline at end of file + ) diff --git a/backend/modules/auth/__init__.py b/backend/modules/auth/__init__.py index e69de29..a45ad97 100644 --- a/backend/modules/auth/__init__.py +++ b/backend/modules/auth/__init__.py @@ -0,0 +1,4 @@ +"""认证模块。 + +提供用户登录、JWT 令牌生成、企业微信 OAuth 授权、个人信息修改和密码修改等认证功能。 +""" \ No newline at end of file diff --git a/backend/modules/auth/router.py b/backend/modules/auth/router.py index af9518c..751ad73 100644 --- a/backend/modules/auth/router.py +++ b/backend/modules/auth/router.py @@ -1,3 +1,8 @@ +"""认证模块路由。 + +提供用户登录、JWT 令牌生成、企业微信 OAuth 授权、个人信息查询/修改和密码修改等功能。 +支持基于用户名密码和企业微信 OAuth 两种认证方式。 +""" import uuid import secrets from datetime import datetime, timedelta @@ -12,17 +17,36 @@ from models import User, UserRole, Role, RolePermission, Permission from schemas import LoginRequest, TokenResponse, UserOut, RoleOut from config import settings +# OAuth 状态存储,用于防止 CSRF 攻击 _oauth_states: dict[str, float] = {} -_OAUTH_STATE_TTL = 600 +_OAUTH_STATE_TTL = 600 # OAuth 状态有效期(秒) def hash_password(password: str) -> str: + """对密码进行 bcrypt 哈希加密。 + + Args: + password: 明文密码字符串。 + + Returns: + str: bcrypt 加密后的哈希字符串。 + """ return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') -router = APIRouter(prefix="/api/auth", tags=["auth"]) +router = APIRouter(prefix="/api/auth", tags=["auth"]) # 认证路由前缀 + async def get_permission_codes(db: AsyncSession, role_ids: list[uuid.UUID]) -> list[str]: + """根据角色 ID 列表获取所有关联的权限代码。 + + Args: + db: 异步数据库会话。 + role_ids: 角色 ID 列表。 + + Returns: + list[str]: 去重后的权限代码列表。 + """ result = await db.execute( select(Permission.code) .join(RolePermission) @@ -32,12 +56,24 @@ async def get_permission_codes(db: AsyncSession, role_ids: list[uuid.UUID]) -> l async def get_user_roles(db: AsyncSession, user_id: uuid.UUID) -> list[RoleOut]: + """获取用户的所有角色及其权限信息。 + + 查询用户关联的所有角色,并为每个角色查询其关联的权限代码。 + + Args: + db: 异步数据库会话。 + user_id: 用户唯一标识 ID。 + + Returns: + list[RoleOut]: 角色信息列表,每个角色包含名称、代码、描述和权限代码列表。 + """ result = await db.execute( select(Role).join(UserRole).where(UserRole.user_id == user_id) ) roles = result.scalars().all() out = [] for role in roles: + # 查询该角色关联的所有权限代码 rp_result = await db.execute( select(Permission.code) .join(RolePermission) @@ -58,6 +94,20 @@ async def get_user_roles(db: AsyncSession, user_id: uuid.UUID) -> list[RoleOut]: @router.post("/login", response_model=TokenResponse) async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)): + """用户登录接口。 + + 验证用户名和密码,验证通过后生成 JWT 令牌并返回用户信息。 + + Args: + req: 登录请求体,包含用户名和密码。 + db: 异步数据库会话。 + + Returns: + TokenResponse: 包含访问令牌和用户信息的响应。 + + Raises: + HTTPException: 用户名或密码错误、账户被禁用时抛出异常。 + """ result = await db.execute(select(User).where(User.username == req.username)) user = result.scalar_one_or_none() if not user or not bcrypt.checkpw(req.password.encode('utf-8'), user.password_hash.encode('utf-8')): @@ -65,8 +115,9 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)): if user.status != "active": raise HTTPException(403, "账户已被禁用") - roles = await get_user_roles(db, user.id) + roles = await get_user_roles(db, user.id) # 获取用户角色信息 + # 生成 JWT 令牌 expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES) token = jwt.encode( {"sub": str(user.id), "username": user.username, "exp": expire}, @@ -88,12 +139,24 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)): @router.get("/me", response_model=UserOut) async def get_me(request: Request, db: AsyncSession = Depends(get_db)): - user_ctx = request.state.user + """获取当前登录用户的详细信息。 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + UserOut: 当前用户信息,包含角色列表。 + + Raises: + HTTPException: 用户不存在时抛出异常。 + """ + user_ctx = request.state.user # 从请求状态中获取当前用户上下文 result = await db.execute(select(User).where(User.id == user_ctx["id"])) user = result.scalar_one_or_none() if not user: raise HTTPException(404, "用户不存在") - roles = await get_user_roles(db, user.id) + roles = await get_user_roles(db, user.id) # 获取用户角色信息 return UserOut( id=user.id, username=user.username, display_name=user.display_name, email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id, @@ -105,17 +168,29 @@ async def get_me(request: Request, db: AsyncSession = Depends(get_db)): @router.get("/wecom/oauth-url") async def get_wecom_oauth_url(request: Request): + """获取企业微信 OAuth 授权 URL。 + + 生成用于企业微信网页授权登录的 URL,包含随机 state 参数用于防 CSRF 攻击。 + + Args: + request: HTTP 请求对象,用于获取基础 URL 构建回调地址。 + + Returns: + dict: 包含 OAuth 授权 URL 和 state 参数的响应数据。 + """ corp_id = settings.WECOM_CORP_ID or "" if not corp_id: return {"code": 400, "message": "请先配置 WECOM_CORP_ID"} base_url = str(request.base_url).rstrip("/") - redirect_uri = f"{base_url}/api/auth/wecom/callback" - state = secrets.token_urlsafe(32) + redirect_uri = f"{base_url}/api/auth/wecom/callback" # OAuth 回调地址 + state = secrets.token_urlsafe(32) # 生成随机 state 用于防 CSRF import time - _oauth_states[state] = time.time() + _oauth_states[state] = time.time() # 存储 state 及其创建时间 + # 清理过期的 state expired = [k for k, v in _oauth_states.items() if time.time() - v > _OAUTH_STATE_TTL] for k in expired: del _oauth_states[k] + # 拼接企业微信 OAuth 授权 URL url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={corp_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_base&state={state}#wechat_redirect" return {"code": 200, "data": {"url": url, "state": state}} @@ -126,6 +201,21 @@ async def update_me( payload: dict, db: AsyncSession = Depends(get_db), ): + """更新当前用户的个人信息。 + + 支持修改显示名称、邮箱和手机号。 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + payload: 更新字段字典,可包含 display_name、email、phone。 + db: 异步数据库会话。 + + Returns: + UserOut: 更新后的用户信息。 + + Raises: + HTTPException: 用户不存在时抛出异常。 + """ user_ctx = request.state.user result = await db.execute(select(User).where(User.id == user_ctx["id"])) user = result.scalar_one_or_none() @@ -156,6 +246,21 @@ async def change_password( payload: dict, db: AsyncSession = Depends(get_db), ): + """修改当前用户的登录密码。 + + 需要验证旧密码正确性,新密码至少 6 位。 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + payload: 包含 old_password 和 new_password 的字典。 + db: 异步数据库会话。 + + Returns: + dict: 修改成功的响应数据。 + + Raises: + HTTPException: 用户不存在、旧密码错误或新密码长度不足时抛出异常。 + """ user_ctx = request.state.user result = await db.execute(select(User).where(User.id == user_ctx["id"])) user = result.scalar_one_or_none() @@ -169,6 +274,6 @@ async def change_password( if len(new_pw) < 6: raise HTTPException(400, "新密码至少6位") - user.password_hash = hash_password(new_pw) + user.password_hash = hash_password(new_pw) # 更新为新密码哈希 await db.commit() - return {"code": 200, "message": "密码已修改"} \ No newline at end of file + return {"code": 200, "message": "密码已修改"} diff --git a/backend/modules/chat/__init__.py b/backend/modules/chat/__init__.py index e69de29..6ac504a 100644 --- a/backend/modules/chat/__init__.py +++ b/backend/modules/chat/__init__.py @@ -0,0 +1 @@ +"""聊天模块。""" \ No newline at end of file diff --git a/backend/modules/chat/router.py b/backend/modules/chat/router.py index da71d2a..199c06d 100644 --- a/backend/modules/chat/router.py +++ b/backend/modules/chat/router.py @@ -1,3 +1,8 @@ +"""对话模块路由。 + +提供基于流程的聊天功能,支持 WebSocket 实时通信和 HTTP 消息发送。 +可以执行已发布的 AI 流程并将结果返回给客户端。 +""" from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -12,7 +17,15 @@ router = APIRouter(prefix="/api/chat", tags=["chat"]) @router.websocket("/ws") async def chat_websocket(websocket: WebSocket): - user_id = websocket.query_params.get("user_id", "anonymous") + """WebSocket 聊天连接处理器。 + + 接受客户端的 WebSocket 连接并将其注册到 WebSocket 管理器中。 + 持续接收消息并回显给发送者,断开时自动清理连接。 + + Args: + websocket: WebSocket 连接对象。 + """ + user_id = websocket.query_params.get("user_id", "anonymous") # 从查询参数获取用户 ID await ws_manager.connect(websocket, user_id) try: while True: @@ -29,6 +42,23 @@ async def chat_message( payload: dict, db: AsyncSession = Depends(get_db), ): + """向指定的已发布流程发送消息并获取 AI 回复。 + + 加载流程定义后使用 FlowEngine 执行,将用户消息作为输入, + 返回流程执行结果。 + + Args: + flow_id: 流程定义的唯一标识 ID。 + request: HTTP 请求对象,用于获取当前用户信息。 + payload: 请求体,包含 message 或 query 字段作为输入文本。 + db: 异步数据库会话。 + + Returns: + dict: 包含 AI 回复和节点执行结果的响应数据。 + + Raises: + HTTPException: 流程不存在、未发布或执行失败时抛出异常。 + """ try: import uuid as _uuid fid = _uuid.UUID(flow_id) @@ -40,7 +70,8 @@ async def chat_message( if not flow or flow.status != "published": raise HTTPException(404, "流不存在或未发布") - definition = flow.definition_json + definition = flow.definition_json # 流程定义 JSON + # 如果有已发布版本,优先使用版本的定义 published_version_id = getattr(flow, 'published_version_id', None) if published_version_id: ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == published_version_id)) @@ -50,7 +81,7 @@ async def chat_message( definition = json.loads(json.dumps(published.definition_json)) user_ctx = request.state.user - input_text = payload.get("message", payload.get("query", "")) + input_text = payload.get("message", payload.get("query", "")) # 用户输入文本 if not input_text: raise HTTPException(400, "请输入消息内容") @@ -59,9 +90,9 @@ async def chat_message( context = { "user_id": user_ctx.get("id", "web_user"), "username": user_ctx.get("username", "网页访客"), - "trigger_data": {"channel": "web_chat"}, - "_node_results": {}, - "session_id": payload.get("session_id", str(uuid.uuid4())), + "trigger_data": {"channel": "web_chat"}, # 触发渠道为网页聊天 + "_node_results": {}, # 存储各节点的执行结果 + "session_id": payload.get("session_id", str(_uuid.uuid4())), } try: @@ -81,6 +112,15 @@ async def chat_message( @router.get("/flows") async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)): + """列出所有已发布的流程,供聊天界面选择使用。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含已发布流程列表的响应数据。 + """ result = await db.execute( select(FlowDefinition).where(FlowDefinition.status == "published") ) @@ -97,4 +137,4 @@ async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)): } for f in flows ], - } \ No newline at end of file + } diff --git a/backend/modules/custom_tool/__init__.py b/backend/modules/custom_tool/__init__.py index 43a7103..7f5cc26 100644 --- a/backend/modules/custom_tool/__init__.py +++ b/backend/modules/custom_tool/__init__.py @@ -1,3 +1,8 @@ +"""自定义工具模块。 + +提供自定义工具的创建、管理、导入和执行功能。 +支持从 OpenAPI 规范自动导入工具定义,以及手动创建自定义 HTTP 工具。 +""" from .router import router __all__ = ["router"] \ No newline at end of file diff --git a/backend/modules/custom_tool/executor.py b/backend/modules/custom_tool/executor.py index 26c036a..49b5977 100644 --- a/backend/modules/custom_tool/executor.py +++ b/backend/modules/custom_tool/executor.py @@ -1,33 +1,69 @@ +"""自定义工具执行器。 + +提供执行自定义 HTTP 工具的功能,支持多种认证方式(API Key、Bearer Token)和 HTTP 方法。 +""" import httpx import json + class CustomToolExecutor: + """自定义工具执行器类。 + + 根据工具定义(端点 URL、HTTP 方法、认证配置等)执行 HTTP 请求。 + + Attributes: + endpoint_url: API 端点的基础 URL。 + method: HTTP 请求方法。 + path: API 路径。 + headers: 请求头字典。 + auth_type: 认证类型(none/api_key/bearer)。 + auth_config: 认证配置信息。 + timeout: 请求超时时间(秒)。 + """ + def __init__(self, tool_def: dict): - self.endpoint_url = tool_def.get("endpoint_url", "") - self.method = tool_def.get("method", "GET") - self.path = tool_def.get("path", "") - self.headers = dict(tool_def.get("headers_json", {})) - self.auth_type = tool_def.get("auth_type", "none") - self.auth_config = dict(tool_def.get("auth_config", {})) - self.timeout = int(tool_def.get("timeout", 30)) + """初始化工具执行器。 + + Args: + tool_def: 工具定义字典,包含 endpoint_url、method、path、headers_json、auth_type、auth_config、timeout 等字段。 + """ + self.endpoint_url = tool_def.get("endpoint_url", "") # API 基础 URL + self.method = tool_def.get("method", "GET") # HTTP 请求方法,默认为 GET + self.path = tool_def.get("path", "") # API 路径 + self.headers = dict(tool_def.get("headers_json", {})) # 请求头字典 + self.auth_type = tool_def.get("auth_type", "none") # 认证类型,默认为无认证 + self.auth_config = dict(tool_def.get("auth_config", {})) # 认证配置信息 + self.timeout = int(tool_def.get("timeout", 30)) # 请求超时时间(秒),默认 30 秒 async def execute(self, params: dict) -> str: + """执行自定义工具请求。 + + 根据工具定义构造完整的 URL,应用认证信息,发送 HTTP 请求并返回响应结果。 + + Args: + params: 请求参数,GET 请求作为查询参数,其他方法作为 JSON 请求体。 + + Returns: + str: 响应内容的字符串表示,最大长度 4000 字符。优先返回 JSON 格式,否则返回纯文本。 + """ + # 构造完整 URL(确保基础 URL 和路径之间只有一个斜杠) url = f"{self.endpoint_url.rstrip('/')}/{self.path.lstrip('/')}" - headers = dict(self.headers) - req_params = dict(params) + headers = dict(self.headers) # 复制请求头 + req_params = dict(params) # 复制请求参数 + # 根据认证类型添加认证信息 if self.auth_type == "api_key": - key = self.auth_config.get("key", "") - loc = self.auth_config.get("location", "header") - name = self.auth_config.get("name", "X-API-Key") + key = self.auth_config.get("key", "") # API Key + loc = self.auth_config.get("location", "header") # 认证位置(header/query) + name = self.auth_config.get("name", "X-API-Key") # 认证参数名 if loc == "header": - headers[name] = key + headers[name] = key # 添加到请求头 else: - req_params[name] = key + req_params[name] = key # 添加到查询参数 elif self.auth_type == "bearer": - headers["Authorization"] = f"Bearer {self.auth_config.get('token', '')}" + headers["Authorization"] = f"Bearer {self.auth_config.get('token', '')}" # Bearer Token 认证 - timeout = httpx.Timeout(self.timeout) + timeout = httpx.Timeout(self.timeout) # 创建超时配置 async with httpx.AsyncClient(timeout=timeout) as client: if self.method == "GET": resp = await client.get(url, params=req_params, headers=headers) @@ -36,8 +72,9 @@ class CustomToolExecutor: self.method, url, json=req_params, headers=headers ) + # 尝试解析 JSON 响应,否则返回纯文本 try: data = resp.json() - return json.dumps(data, ensure_ascii=False, indent=2)[:4000] + return json.dumps(data, ensure_ascii=False, indent=2)[:4000] # 格式化 JSON,限制最大长度 except Exception: - return resp.text[:4000] \ No newline at end of file + return resp.text[:4000] # 返回纯文本响应,限制最大长度 \ No newline at end of file diff --git a/backend/modules/custom_tool/parser.py b/backend/modules/custom_tool/parser.py index 4c99c9c..f0750fc 100644 --- a/backend/modules/custom_tool/parser.py +++ b/backend/modules/custom_tool/parser.py @@ -1,21 +1,48 @@ +"""OpenAPI 规范解析器。 + +提供从 OpenAPI/Swagger 规范文档中自动解析 API 端点并转换为自定义工具定义的功能。 +""" import json from typing import Any + class OpenAPIParser: + """OpenAPI 规范解析器类。 + + 解析 OpenAPI 3.0 规范文档,提取其中的 API 端点信息并转换为自定义工具定义。 + + Attributes: + spec: OpenAPI 规范文档的字典表示。 + base_url: API 服务的基础 URL。 + """ + def __init__(self, spec: dict): - self.spec = spec - self.base_url = "" + """初始化 OpenAPI 解析器。 + + Args: + spec: OpenAPI 规范文档的字典表示,包含 servers、paths 等字段。 + """ + self.spec = spec # OpenAPI 规范文档内容 + self.base_url = "" # API 基础 URL servers = spec.get("servers", [{}]) if servers and isinstance(servers, list): - self.base_url = servers[0].get("url", "") + self.base_url = servers[0].get("url", "") # 获取第一个服务器 URL 作为基础地址 def parse_tools(self) -> list[dict]: + """解析 OpenAPI 规范中的所有 API 端点。 + + 遍历 paths 中的所有 HTTP 方法,将每个端点转换为工具定义。 + + Returns: + list[dict]: 工具定义列表,每个工具包含 name、description、parameters、path、method 等信息。 + """ tools = [] - paths = self.spec.get("paths", {}) + paths = self.spec.get("paths", {}) # 获取所有 API 路径 for path, methods in paths.items(): if not isinstance(methods, dict): continue for method, operation in methods.items(): + # 只处理标准的 HTTP 方法 if method in ("get", "post", "put", "delete", "patch") and isinstance(operation, dict): tool = self._parse_endpoint(path, method, operation) if tool: @@ -23,47 +50,71 @@ class OpenAPIParser: return tools def _parse_endpoint(self, path: str, method: str, operation: dict) -> dict | None: + """解析单个 API 端点的详细信息。 + + Args: + path: API 路径,如 "/users/{id}"。 + method: HTTP 方法,如 "get"、"post" 等。 + operation: 端点的操作定义,包含 operationId、summary、parameters 等。 + + Returns: + dict | None: 工具定义字典,包含名称、描述、参数等信息;如果解析失败返回 None。 + """ + # 生成工具名称:优先使用 operationId,否则从路径生成 op_id = operation.get("operationId", "") if not op_id: op_id = f"{method}_{path.replace('/', '_').strip('_')}" + # 生成工具描述:优先使用 summary,其次 description,最后使用方法和路径 description = operation.get("summary") or operation.get("description") or f"{method.upper()} {path}" - properties = self._parse_parameters(operation) + properties = self._parse_parameters(operation) # 解析参数 required = [] for param in operation.get("parameters", []): if isinstance(param, dict) and param.get("required"): - required.append(param["name"]) + required.append(param["name"]) # 收集必填参数名 return { - "name": op_id, - "description": description, - "parameters": { + "name": op_id, # 工具名称 + "description": description, # 工具描述 + "parameters": { # 参数 Schema "type": "object", "properties": properties, "required": required, }, - "path": path, - "method": method.upper(), + "path": path, # API 路径 + "method": method.upper(), # HTTP 方法(大写) } def _parse_parameters(self, operation: dict) -> dict[str, Any]: + """解析 API 端点的参数定义。 + + 包括查询参数、路径参数、请求头参数和请求体参数。 + + Args: + operation: 端点的操作定义。 + + Returns: + dict[str, Any]: 参数属性字典,键为参数名,值为参数类型和描述。 + """ props = {} + # 解析 query/path/header 参数 for param in operation.get("parameters", []): if not isinstance(param, dict): continue pname = param.get("name", "") if not pname: continue - schema = param.get("schema", {}) + schema = param.get("schema", {}) # 参数的 Schema 定义 if not isinstance(schema, dict): schema = {} props[pname] = { - "type": schema.get("type", "string"), - "description": param.get("description", ""), + "type": schema.get("type", "string"), # 参数类型,默认为 string + "description": param.get("description", ""), # 参数描述 } - if "enum" in schema: + if "enum" in schema: # 如果有限定值列表 props[pname]["enum"] = schema["enum"] + # 解析请求体(requestBody)中的 JSON Schema 属性 body = ( operation.get("requestBody", {}) .get("content", {}) diff --git a/backend/modules/custom_tool/router.py b/backend/modules/custom_tool/router.py index 2873a80..d298e18 100644 --- a/backend/modules/custom_tool/router.py +++ b/backend/modules/custom_tool/router.py @@ -1,3 +1,8 @@ +"""自定义工具模块路由。 + +提供自定义工具的创建、管理、导入和执行功能。 +支持从 OpenAPI 规范自动导入工具定义,以及手动创建自定义 HTTP 工具。 +""" import uuid import httpx from fastapi import APIRouter, Depends, HTTPException, Request @@ -12,19 +17,34 @@ from modules.flow_engine.engine import ToolNodeAgent from dependencies import get_current_user import logging -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"]) @router.post("/import-openapi") async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)): + """从 OpenAPI 规范 URL 导入工具定义。 + + 自动下载 OpenAPI 文档,解析其中的 API 端点并创建对应的自定义工具。 + + Args: + req: OpenAPI 导入请求体,包含 openapi_url 和可选的 base_url_override。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含导入成功工具列表的响应数据。 + + Raises: + HTTPException: 获取 OpenAPI 文档失败或解析不到工具时抛出异常。 + """ user_ctx = request.state.user try: async with httpx.AsyncClient(timeout=30) as client: resp = await client.get(req.openapi_url) resp.raise_for_status() - spec = resp.json() + spec = resp.json() # 解析 OpenAPI 规范 JSON except httpx.HTTPError as e: raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") except ValueError: @@ -35,7 +55,7 @@ async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncS if not tools: raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具") - base_url = req.base_url_override or parser.base_url + base_url = req.base_url_override or parser.base_url # 优先使用用户指定的基础 URL if not base_url: raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override") @@ -45,7 +65,7 @@ async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncS select(CustomTool).where(CustomTool.name == t["name"]) ) if existing.scalar_one_or_none(): - continue + continue # 跳过已存在的同名工具 tool = CustomTool( name=t["name"], @@ -79,6 +99,19 @@ async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncS @router.post("/", response_model=CustomToolOut) async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)): + """创建新的自定义工具,支持手动创建或从 OpenAPI 导入。 + + Args: + req: 自定义工具创建请求体。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + CustomToolOut: 创建后的自定义工具响应。 + + Raises: + HTTPException: 获取 OpenAPI 文档失败或创建失败时抛出异常。 + """ user_ctx = request.state.user user_id = uuid.UUID(user_ctx["id"]) @@ -121,6 +154,7 @@ async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncS await db.flush() return created_tool + # 手动创建模式 schema_json = req.tool_schema or {} if not schema_json and req.endpoint_url: schema_json = { @@ -161,6 +195,14 @@ async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncS @router.get("/", response_model=list[CustomToolOut]) async def list_custom_tools(db: AsyncSession = Depends(get_db)): + """列出所有处于活跃状态的自定义工具。 + + Args: + db: 异步数据库会话。 + + Returns: + list[CustomToolOut]: 活跃自定义工具列表。 + """ result = await db.execute( select(CustomTool).where(CustomTool.is_active == True).order_by(CustomTool.updated_at.desc()) ) @@ -169,6 +211,18 @@ async def list_custom_tools(db: AsyncSession = Depends(get_db)): @router.get("/{tool_id}", response_model=CustomToolOut) async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + """获取指定自定义工具的详细信息。 + + Args: + tool_id: 自定义工具唯一标识 ID。 + db: 异步数据库会话。 + + Returns: + CustomToolOut: 自定义工具详细信息。 + + Raises: + HTTPException: 工具不存在时抛出异常。 + """ tool = await db.get(CustomTool, tool_id) if not tool: raise HTTPException(404, "工具不存在") @@ -177,6 +231,19 @@ async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db) @router.put("/{tool_id}", response_model=CustomToolOut) async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)): + """更新自定义工具的配置信息。 + + Args: + tool_id: 自定义工具唯一标识 ID。 + req: 自定义工具更新请求体。 + db: 异步数据库会话。 + + Returns: + CustomToolOut: 更新后的自定义工具响应。 + + Raises: + HTTPException: 工具不存在时抛出异常。 + """ tool = await db.get(CustomTool, tool_id) if not tool: raise HTTPException(404, "工具不存在") @@ -206,16 +273,45 @@ async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: Asyn @router.delete("/{tool_id}") async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)): + """删除(停用)自定义工具。 + + 采用软删除方式,将工具标记为非活跃状态而非真正删除。 + + Args: + tool_id: 自定义工具唯一标识 ID。 + db: 异步数据库会话。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 工具不存在时抛出异常。 + """ tool = await db.get(CustomTool, tool_id) if not tool: raise HTTPException(404, "工具不存在") - tool.is_active = False + tool.is_active = False # 软删除:标记为非活跃 await db.flush() return {"code": 200, "message": "工具已停用"} @router.post("/{tool_id}/test") async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)): + """测试执行自定义工具。 + + 使用自定义工具的配置信息创建执行器并执行,返回执行结果。 + + Args: + tool_id: 自定义工具唯一标识 ID。 + params: 测试参数,传递给工具执行的参数体。 + db: 异步数据库会话。 + + Returns: + dict: 包含执行结果的响应数据。 + + Raises: + HTTPException: 工具不存在或执行失败时抛出异常。 + """ tool = await db.get(CustomTool, tool_id) if not tool: raise HTTPException(404, "工具不存在") @@ -239,6 +335,14 @@ async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSes @router.get("/schemas/all") async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)): + """获取所有活跃自定义工具的参数 Schema。 + + Args: + db: 异步数据库会话。 + + Returns: + dict: 包含所有工具 Schema 的响应数据,格式为 {工具名: schema}。 + """ result = await db.execute( select(CustomTool).where(CustomTool.is_active == True) ) @@ -246,4 +350,4 @@ async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)): schemas = {} for t in tools: schemas[t.name] = t.schema_json - return {"code": 200, "data": schemas} \ No newline at end of file + return {"code": 200, "data": schemas} diff --git a/backend/modules/document/router.py b/backend/modules/document/router.py index 14137d5..24867ea 100644 --- a/backend/modules/document/router.py +++ b/backend/modules/document/router.py @@ -1,3 +1,8 @@ +"""文档处理模块路由。 + +提供文档上传、解析、格式修正和删除功能。 +支持多种文档格式(文本、PDF、Word、Excel 等)的处理。 +""" import os import uuid import shutil @@ -20,7 +25,22 @@ async def upload_document( request: Request = None, user: dict = Depends(get_current_user), ): - max_size = settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024 + """上传文档文件到服务器。 + + 检查文件大小限制后保存到上传目录。 + + Args: + file: 上传的文件对象。 + request: HTTP 请求对象。 + user: 当前登录用户信息。 + + Returns: + DocumentUploadOut: 包含文件 ID、文件名、大小等信息的响应。 + + Raises: + HTTPException: 文件大小超过限制时抛出异常。 + """ + max_size = settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024 # 最大允许上传大小(字节) content = await file.read() if len(content) > max_size: raise HTTPException(400, f"文件大小超过限制 ({settings.MAX_UPLOAD_SIZE_MB}MB)") @@ -28,8 +48,8 @@ async def upload_document( file_id = uuid.uuid4() os.makedirs(settings.UPLOAD_DIR, exist_ok=True) - ext = os.path.splitext(file.filename or "unknown")[1] - stored_name = f"{file_id}{ext}" + ext = os.path.splitext(file.filename or "unknown")[1] # 获取文件扩展名 + stored_name = f"{file_id}{ext}" # 使用 UUID 作为存储文件名 file_path = os.path.join(settings.UPLOAD_DIR, stored_name) with open(file_path, "wb") as f: @@ -51,6 +71,22 @@ async def parse_document( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """解析已上传的文档文件,提取文本内容。 + + 根据文件扩展名选择合适的解析方式,支持纯文本、PDF、Word、Excel 等格式。 + + Args: + file_id: 文件唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + DocumentParseResult: 包含文件内容和元数据的解析结果。 + + Raises: + HTTPException: 文件不存在时抛出异常。 + """ ext_map = {".txt", ".md", ".py", ".js", ".ts", ".json", ".xml", ".yaml", ".yml", ".csv", ".html", ".css", ".java", ".go", ".rs"} os.makedirs(settings.UPLOAD_DIR, exist_ok=True) @@ -65,7 +101,7 @@ async def parse_document( if not found_file: raise HTTPException(404, "文件不存在") - ext = os.path.splitext(found_filename)[1].lower() + ext = os.path.splitext(found_filename)[1].lower() # 获取文件扩展名 content = "" metadata = {"file_size": os.path.getsize(found_file), "extension": ext} @@ -87,6 +123,7 @@ async def parse_document( content = f"[不支持的文件类型 .{ext}] 文件: {found_filename}" metadata["type"] = "unsupported" + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="document.parse", @@ -113,6 +150,20 @@ async def delete_document( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """删除已上传的文档文件。 + + Args: + file_id: 文件唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 文件不存在时抛出异常。 + """ os.makedirs(settings.UPLOAD_DIR, exist_ok=True) deleted = False @@ -125,6 +176,7 @@ async def delete_document( if not deleted: raise HTTPException(404, "文件不存在") + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="document.delete", @@ -145,11 +197,25 @@ async def format_document( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """对文档内容进行格式修正。 + + 支持 standard、markdown、json 三种格式类型。 + + Args: + payload: 请求体,包含 content 和 format_type 字段。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含格式化后内容的响应数据。 + """ content = payload.get("content", "") format_type = payload.get("format_type", "standard") result = _apply_formatting(content, format_type) + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="document.format", @@ -165,6 +231,15 @@ async def format_document( def _apply_formatting(content: str, format_type: str) -> str: + """应用指定的格式规则对文本内容进行格式化。 + + Args: + content: 待格式化的原始文本内容。 + format_type: 格式类型,支持 standard、markdown、json。 + + Returns: + str: 格式化后的文本内容。 + """ lines = content.splitlines() result = [] @@ -196,4 +271,4 @@ def _apply_formatting(content: str, format_type: str) -> str: except json.JSONDecodeError: return json.dumps({"content": content, "lines": len(lines)}, ensure_ascii=False, indent=2) - return content \ No newline at end of file + return content diff --git a/backend/modules/flow_engine/engine.py b/backend/modules/flow_engine/engine.py index bda58d7..377498a 100644 --- a/backend/modules/flow_engine/engine.py +++ b/backend/modules/flow_engine/engine.py @@ -1,3 +1,10 @@ +"""流引擎核心模块。 + +定义 FlowEngine 流程执行引擎及各类节点 Agent,包括: +- FlowEngine:流程图的解析与遍历执行器 +- LLMNodeAgent / ToolNodeAgent / MCPNodeAgent 等各类节点处理器 +""" + import json import uuid import logging @@ -14,6 +21,7 @@ logger = logging.getLogger(__name__) async def _resolve_model_instance(model_instance_id: str) -> dict | None: + """根据模型实例 ID 从数据库解析模型配置(模型名、base_url、api_key)。""" try: from database import AsyncSessionLocal from sqlalchemy import text @@ -43,24 +51,31 @@ async def _resolve_model_instance(model_instance_id: str) -> dict | None: class FlowSessionMemory: + """流程会话级短期记忆,存储当前对话轮次的消息列表。""" + def __init__(self, session_id: str = "", user_id: str = ""): self.session_id = session_id self.user_id = user_id self._messages: list[dict] = [] def get_history(self, limit: int = 10) -> list[dict]: + """获取最近的消息历史。""" return self._messages[-limit * 2:] def add(self, role: str, content: str): + """添加一条消息到历史。""" self._messages.append({"role": role, "content": content}) def to_list(self) -> list[dict]: + """返回全部消息列表。""" return list(self._messages) class FlowEngine: - MAX_TOTAL_ITERATIONS = 200 - FLOW_TIMEOUT_SECONDS = 300 + """流程执行引擎,解析流程图定义并按拓扑顺序遍历执行各节点。""" + + MAX_TOTAL_ITERATIONS = 200 # 全局最大迭代次数(防止死循环) + FLOW_TIMEOUT_SECONDS = 300 # 单次流程执行超时时间 def __init__(self, flow_definition: dict): self.definition = flow_definition diff --git a/backend/modules/flow_engine/gateway.py b/backend/modules/flow_engine/gateway.py index a895407..b489886 100644 --- a/backend/modules/flow_engine/gateway.py +++ b/backend/modules/flow_engine/gateway.py @@ -1,3 +1,11 @@ +"""流引擎网关路由。 + +提供符合 OpenAI Dify 兼容格式的外部 API 网关,包括: +- /v1/chat-messages: 对话型流程执行(支持 blocking / streaming) +- /v1/workflows/run: 工作流型流程执行 +- /v1/flows/{flow_id}/parameters: 查询流程输入参数 +""" + import uuid import time import json diff --git a/backend/modules/flow_engine/router.py b/backend/modules/flow_engine/router.py index 1ea6867..468d1b1 100644 --- a/backend/modules/flow_engine/router.py +++ b/backend/modules/flow_engine/router.py @@ -1,3 +1,9 @@ +"""流程定义管理路由。 + +提供流程的 CRUD、发布/下架、版本管理、执行、SSE 流式执行、 +API Key 管理、执行历史查询以及市场模板等功能。 +""" + import uuid import time import json diff --git a/backend/modules/mcp_registry/router.py b/backend/modules/mcp_registry/router.py index 75b548f..d49c868 100644 --- a/backend/modules/mcp_registry/router.py +++ b/backend/modules/mcp_registry/router.py @@ -1,3 +1,8 @@ +"""MCP 服务注册模块路由。 + +提供 Model Context Protocol (MCP) 服务的注册、管理、测试和工具发现功能。 +支持 HTTP 传输方式的 MCP 服务接入。 +""" import uuid import httpx from fastapi import APIRouter, Depends, HTTPException, Request @@ -13,6 +18,16 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"]) @router.get("/servers", response_model=list[MCPServiceOut]) async def list_servers(request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + """列出所有已注册的 MCP 服务。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + list[MCPServiceOut]: MCP 服务列表。 + """ result = await db.execute( select(MCPService).order_by(MCPService.updated_at.desc()) ) @@ -21,6 +36,20 @@ async def list_servers(request: Request, db: AsyncSession = Depends(get_db), use @router.get("/servers/{server_id}", response_model=MCPServiceOut) async def get_server(server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + """获取指定 MCP 服务的详细信息。 + + Args: + server_id: MCP 服务唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + MCPServiceOut: MCP 服务详细信息。 + + Raises: + HTTPException: 服务不存在时抛出异常。 + """ result = await db.execute(select(MCPService).where(MCPService.id == server_id)) server = result.scalar_one_or_none() if not server: @@ -35,6 +64,20 @@ async def register_server( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """注册新的 MCP 服务。 + + Args: + req: MCP 服务创建请求体。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + MCPServiceOut: 注册后的 MCP 服务响应。 + + Raises: + HTTPException: 服务名称已存在时抛出异常。 + """ existing = await db.execute(select(MCPService).where(MCPService.name == req.name)) if existing.scalar_one_or_none(): raise HTTPException(400, "服务名称已存在") @@ -50,6 +93,7 @@ async def register_server( ) db.add(server) + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="mcp.register", @@ -70,6 +114,21 @@ async def update_server( request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """更新 MCP 服务的配置信息。 + + Args: + server_id: MCP 服务唯一标识 ID。 + req: MCP 服务更新请求体。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + MCPServiceOut: 更新后的 MCP 服务响应。 + + Raises: + HTTPException: 服务不存在时抛出异常。 + """ result = await db.execute(select(MCPService).where(MCPService.id == server_id)) server = result.scalar_one_or_none() if not server: @@ -86,6 +145,7 @@ async def update_server( if req.env is not None: server.env = req.env + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="mcp.update", @@ -104,12 +164,27 @@ async def delete_server( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """注销指定的 MCP 服务。 + + Args: + server_id: MCP 服务唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 服务不存在时抛出异常。 + """ result = await db.execute(select(MCPService).where(MCPService.id == server_id)) server = result.scalar_one_or_none() if not server: raise HTTPException(404, "MCP服务不存在") await db.delete(server) + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="mcp.delete", @@ -129,6 +204,22 @@ async def test_connection( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """测试 MCP 服务的连接状态并发现可用工具。 + + 对于 HTTP 传输的服务,访问 /.well-known/mcp 端点进行连接测试。 + + Args: + server_id: MCP 服务唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含连接测试结果和工具列表的响应数据。 + + Raises: + HTTPException: 服务不存在时抛出异常。 + """ result = await db.execute(select(MCPService).where(MCPService.id == server_id)) server = result.scalar_one_or_none() if not server: @@ -141,7 +232,7 @@ async def test_connection( async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.get(server.url.rstrip("/") + "/.well-known/mcp") if resp.status_code == 200: - test_results["connectivity"] = True + test_results["connectivity"] = True # 连接成功 data = resp.json() tools = data.get("tools", []) test_results["tools_discovered"] = len(tools) @@ -155,6 +246,7 @@ async def test_connection( test_results["error"] = str(e) server.status = "error" + # 记录审计日志 audit = AuditLog( operator_id=uuid.UUID(user["id"]), action="mcp.test", @@ -174,6 +266,21 @@ async def discover_tools( server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db), ): + """从 MCP 服务中发现并注册可用工具。 + + 对于 HTTP 传输的服务,调用 /tools/list 端点获取工具列表。 + + Args: + server_id: MCP 服务唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含发现的工具列表和数量的响应数据。 + + Raises: + HTTPException: 服务不存在或工具发现失败时抛出异常。 + """ result = await db.execute(select(MCPService).where(MCPService.id == server_id)) server = result.scalar_one_or_none() if not server: @@ -193,4 +300,4 @@ async def discover_tools( raise HTTPException(500, f"工具发现失败: {str(e)}") await db.flush() - return {"code": 200, "data": {"tools": server.tools, "count": len(server.tools)}} \ No newline at end of file + return {"code": 200, "data": {"tools": server.tools, "count": len(server.tools)}} diff --git a/backend/modules/memory/manager.py b/backend/modules/memory/manager.py index bb24958..71c7715 100644 --- a/backend/modules/memory/manager.py +++ b/backend/modules/memory/manager.py @@ -1,3 +1,15 @@ +"""三级记忆系统管理器。 + +记忆架构分为三个层次(L1/L2/L3): +- L1(原子层):从对话中提取关键信息原子(用户偏好、事件、指令) +- L2(场景层):对同类原子进行归纳,形成场景摘要 +- L3(画像层):综合所有信息,生成用户画像(Persona) + +数据存储: +- PG(主存储):持久化记忆消息、原子、场景、画像 +- Redis(缓存):近期消息缓存、对话摘要缓存 +""" + import json import asyncio import uuid @@ -16,6 +28,7 @@ _memory_manager: "MemoryManager | None" = None def get_memory_manager() -> "MemoryManager": + """获取全局 MemoryManager 单例。""" global _memory_manager if _memory_manager is None: raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()") @@ -23,6 +36,7 @@ def get_memory_manager() -> "MemoryManager": async def init_memory_manager(db_factory: Callable[[], AsyncSession]): + """初始化记忆管理器,创建 Redis 连接并实例化 MemoryManager。""" global _memory_manager redis = Redis.from_url(settings.REDIS_URL, decode_responses=True) await redis.ping() @@ -30,19 +44,27 @@ async def init_memory_manager(db_factory: Callable[[], AsyncSession]): class MemoryManager: - MAX_HISTORY = 40 - REDIS_CACHE_SIZE = 10 - REDIS_CACHE_TTL = 300 - SUMMARY_CACHE_KEY = "mem:summary" - MSG_CACHE_KEY = "mem:cache:msgs" - ATOM_EXTRACT_EVERY = 10 - SCENE_EXTRACT_EVERY = 50 - PERSONA_UPDATE_EVERY = 30 + """三级记忆管理器,负责记忆的存储、检索、提取与归纳。""" + + MAX_HISTORY = 40 # 单次注入的最大历史消息数 + REDIS_CACHE_SIZE = 10 # Redis 缓存保留的最近消息数 + REDIS_CACHE_TTL = 300 # Redis 缓存 TTL(秒) + SUMMARY_CACHE_KEY = "mem:summary" # 摘要缓存 Redis key 前缀 + MSG_CACHE_KEY = "mem:cache:msgs" # 消息缓存 Redis key 前缀 + ATOM_EXTRACT_EVERY = 10 # 每 N 条消息触发一次 L1 原子提取 + SCENE_EXTRACT_EVERY = 50 # 每 N 条原子触发一次 L2 场景提取 + PERSONA_UPDATE_EVERY = 30 # 每 N 条消息触发一次 L3 画像更新 def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis): + """初始化记忆管理器。 + + Args: + db_factory: 异步数据库会话工厂 + redis: Redis 异步客户端实例 + """ self.db_factory = db_factory self.redis = redis - self._extract_tasks: dict[str, asyncio.Task] = {} + self._extract_tasks: dict[str, asyncio.Task] = {} # 后台提取任务追踪 async def inject_memory( self, @@ -51,6 +73,17 @@ class MemoryManager: session_id: str, context: dict, ): + """向对话上下文中注入三层记忆信息。 + + 从 PG/Redis 中获取近期消息、摘要、原子记忆和画像, + 合并后注入到 context["_memory_context"] 中供 LLM 使用。 + + Args: + user_id: 用户 ID + flow_id: 流程 ID + session_id: 会话 ID + context: 对话上下文字典(会在原地被修改) + """ uid = uuid.UUID(user_id) fid = uuid.UUID(flow_id) sid = uuid.UUID(session_id) @@ -84,6 +117,19 @@ class MemoryManager: assistant_msg: str, flow_name: str = "", ): + """记录一次用户-助手对话交换。 + + 将用户消息和助手消息写入 PG,同时更新 Redis 缓存, + 并异步触发 L1/L1/L2/L3 记忆提取任务。 + + Args: + user_id: 用户 ID + flow_id: 流程 ID + session_id: 会话 ID + user_msg: 用户消息内容 + assistant_msg: 助手回复内容 + flow_name: 流程名称(可选,用于会话记录) + """ uid = uuid.UUID(user_id) fid = uuid.UUID(flow_id) sid = uuid.UUID(session_id) @@ -167,12 +213,24 @@ class MemoryManager: async def get_conversation_history( self, user_id: str, flow_id: str, session_id: str, limit: int = 20 ) -> list[dict]: + """获取指定会话的对话历史。 + + Args: + user_id: 用户 ID + flow_id: 流程 ID + session_id: 会话 ID + limit: 返回的最大消息数 + + Returns: + 消息列表,每项含 role/content/ts 字段 + """ uid = uuid.UUID(user_id) sid = uuid.UUID(session_id) fid = uuid.UUID(flow_id) if flow_id else None return await self._pg_get_recent(uid, fid, sid, limit) async def delete_session(self, user_id: str, session_id: str): + """删除指定会话的所有记忆数据(PG + Redis)。""" uid = uuid.UUID(user_id) sid = uuid.UUID(session_id) @@ -199,6 +257,14 @@ class MemoryManager: pass async def list_user_sessions(self, user_id: str) -> list[dict]: + """列出用户的所有记忆会话。 + + Args: + user_id: 用户 ID + + Returns: + 会话列表,每项含 session_id/flow_id/flow_name/last_active_at + """ uid = uuid.UUID(user_id) try: async with self.db_factory() as db: @@ -226,6 +292,7 @@ class MemoryManager: return [] async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]: + """从 PG 查询最近的对话消息。""" try: async with self.db_factory() as db: if fid: @@ -259,6 +326,7 @@ class MemoryManager: return [] async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None: + """从 Redis 读取缓存的消息列表。""" try: cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" raw = await self.redis.get(cache_key) @@ -269,6 +337,7 @@ class MemoryManager: return None async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]): + """将消息列表写入 Redis 缓存。""" try: cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" top = messages[-self.REDIS_CACHE_SIZE:] @@ -277,6 +346,7 @@ class MemoryManager: pass async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]): + """追加新消息到 Redis 缓存。""" try: cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" existing = await self._redis_get_recent(uid, sid) or [] @@ -287,6 +357,7 @@ class MemoryManager: pass async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str: + """从 Redis 读取对话摘要。""" try: summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" val = await self.redis.get(summary_key) @@ -295,6 +366,10 @@ class MemoryManager: return "" async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID): + """L1 条件触发:对话摘要生成。 + + 当消息数 >= 30 且尚无摘要时,调用 LLM 生成摘要并缓存到 Redis。 + """ try: summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" existing = await self.redis.get(summary_key) @@ -342,6 +417,7 @@ class MemoryManager: pass async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]: + """从 PG 查询与用户/流程相关的高优先级原子记忆。""" try: async with self.db_factory() as db: result = await db.execute( @@ -363,6 +439,7 @@ class MemoryManager: return [] async def _get_persona(self, uid: uuid.UUID) -> dict: + """从 PG 查询用户画像。""" try: async with self.db_factory() as db: result = await db.execute( @@ -377,6 +454,10 @@ class MemoryManager: return {} async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID): + """L1 条件触发:原子记忆提取。 + + 当消息数达到 ATOM_EXTRACT_EVERY 的整数倍时,调用 LLM 从对话中提取信息原子。 + """ try: task_key = f"extract_atoms:{uid}:{fid}" if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): @@ -413,6 +494,7 @@ class MemoryManager: pass async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str): + """执行 L1 原子记忆提取:调用 LLM 从对话中提取结构化记忆原子。""" try: prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。 @@ -460,6 +542,12 @@ class MemoryManager: logger.warning(f"L1原子提取失败: {e}") async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]): + """对提取的原子进行去重并存储到 PG。 + + 使用文本相似度判断是否与已有原子重复: + - 相似度 > 75% 时更新优先级和元数据 + - 否则插入新原子记录 + """ try: async with self.db_factory() as db: existing = await db.execute( @@ -483,7 +571,7 @@ class MemoryManager: await db.execute( text(""" UPDATE memory_atoms - SET priority = GREATEST(priority, :priority), + SET priority = GREATER(priority, :priority), updated_at = NOW(), metadata = metadata || :meta WHERE id = :id @@ -520,6 +608,7 @@ class MemoryManager: @staticmethod def _text_similarity(a: str, b: str) -> float: + """计算两段文本的 Jaccard 相似度(基于单词集合)。""" a_words = set(a.lower().split()) b_words = set(b.lower().split()) if not a_words or not b_words: @@ -529,6 +618,11 @@ class MemoryManager: return len(intersection) / len(union) async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID): + """L2 条件触发:场景提取。 + + 当原子数达到 SCENE_EXTRACT_EVERY 且距上次提取超过 12 小时时, + 调用 LLM 对已有原子进行场景归纳。 + """ try: task_key = f"extract_scenes:{uid}:{fid}" if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): @@ -564,7 +658,7 @@ class MemoryManager: latest_scene = atoms_result.fetchone() if latest_scene: from datetime import timezone, timedelta - ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone.utc) + ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone=timezone.utc) if ago < timedelta(hours=12): return @@ -587,6 +681,7 @@ class MemoryManager: pass async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]): + """执行 L2 场景提取:调用 LLM 将原子记忆归纳为场景块。""" try: atoms_text = "\n".join( f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms @@ -651,6 +746,11 @@ class MemoryManager: logger.warning(f"L2场景提取失败: {e}") async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID): + """L3 条件触发:用户画像更新。 + + 当消息数达到 PERSONA_UPDATE_EVERY 且距上次更新超过 6 小时时, + 基于已有 persona 类型原子重新生成画像。 + """ try: task_key = f"update_persona:{uid}" if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): @@ -699,6 +799,7 @@ class MemoryManager: pass async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int): + """执行 L3 画像更新:调用 LLM 生成结构化用户画像并持久化。""" try: prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。 @@ -780,6 +881,18 @@ class MemoryManager: top_k: int = 10, embedding: list[float] = None, ) -> list[dict]: + """混合检索记忆:向量相似度 + 全文检索,使用 RRF 算法融合排序。 + + Args: + uid: 用户 ID + query: 搜索查询文本 + fid: 流程 ID(可选,过滤范围) + top_k: 返回结果数 + embedding: 查询向量(可选,启用向量检索) + + Returns: + 按 RRF 分数降序排列的记忆原子列表 + """ results = [] try: async with self.db_factory() as db: diff --git a/backend/modules/memory/router.py b/backend/modules/memory/router.py index dc5311a..1f9ff78 100644 --- a/backend/modules/memory/router.py +++ b/backend/modules/memory/router.py @@ -1,3 +1,5 @@ +"""记忆管理路由:会话列表查询、历史获取、会话清除。""" + from fastapi import APIRouter, Request, Depends, HTTPException from dependencies import get_current_user from modules.memory.manager import get_memory_manager @@ -7,6 +9,7 @@ router = APIRouter(prefix="/api/memory", tags=["记忆管理"]) @router.get("/sessions") async def list_sessions(request: Request, user=Depends(get_current_user)): + """获取当前用户的所有记忆会话列表。""" mm = get_memory_manager() sessions = await mm.list_user_sessions(str(user.id)) return {"code": 200, "data": sessions} @@ -14,6 +17,7 @@ async def list_sessions(request: Request, user=Depends(get_current_user)): @router.get("/sessions/{session_id}") async def get_session(session_id: str, request: Request, flow_id: str = "", user=Depends(get_current_user)): + """获取指定会话的对话历史。""" mm = get_memory_manager() history = await mm.get_conversation_history( user_id=str(user.id), @@ -25,6 +29,7 @@ async def get_session(session_id: str, request: Request, flow_id: str = "", user @router.delete("/sessions/{session_id}") async def clear_session(session_id: str, request: Request, user=Depends(get_current_user)): + """清除指定会话的所有记忆数据。""" mm = get_memory_manager() await mm.delete_session(str(user.id), session_id) return {"code": 200, "message": "记忆已清除"} \ No newline at end of file diff --git a/backend/modules/memory/schemas.py b/backend/modules/memory/schemas.py index 31e66d4..1e28a6d 100644 --- a/backend/modules/memory/schemas.py +++ b/backend/modules/memory/schemas.py @@ -1,8 +1,11 @@ +"""记忆管理模块的 Pydantic 请求/响应模型。""" + from pydantic import BaseModel, ConfigDict from datetime import datetime class MemorySessionOut(BaseModel): + """记忆会话概要响应体。""" session_id: str flow_id: str flow_name: str @@ -10,10 +13,12 @@ class MemorySessionOut(BaseModel): class ConversationMessage(BaseModel): + """单条对话消息。""" role: str content: str ts: str = "" class ClearSessionRequest(BaseModel): + """清除会话请求体。""" session_id: str \ No newline at end of file diff --git a/backend/modules/model_provider/router.py b/backend/modules/model_provider/router.py index b1c768f..3957c0e 100644 --- a/backend/modules/model_provider/router.py +++ b/backend/modules/model_provider/router.py @@ -1,3 +1,8 @@ +"""模型供应商模块路由。 + +提供模型供应商和模型实例的 CRUD 管理功能。 +支持多供应商接入和模型能力的统一管理。 +""" import uuid import logging from fastapi import APIRouter, Depends, HTTPException @@ -7,12 +12,21 @@ from database import get_db from models import ModelProvider, ModelInstance from dependencies import get_current_user -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"]) @router.get("") async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """列出所有已注册的模型供应商。 + + Args: + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含模型供应商列表的响应数据。 + """ result = await db.execute( select(ModelProvider).order_by(ModelProvider.created_at.desc()) ) @@ -34,6 +48,19 @@ async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_cu @router.post("") async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """注册新的模型供应商。 + + Args: + payload: 请求体,包含 name、provider_type、base_url、api_key、extra_config 字段。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含新供应商 ID 的响应数据。 + + Raises: + HTTPException: 相同 base_url 的供应商已存在时抛出异常。 + """ existing = await db.execute( select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", "")) ) @@ -54,6 +81,20 @@ async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), use @router.put("/{provider_id}") async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """更新模型供应商的配置信息。 + + Args: + provider_id: 模型供应商唯一标识 ID。 + payload: 请求体,包含要更新的供应商字段。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含更新后供应商 ID 的响应数据。 + + Raises: + HTTPException: 供应商不存在时抛出异常。 + """ p = await db.get(ModelProvider, uuid.UUID(provider_id)) if not p: raise HTTPException(404, "供应商不存在") @@ -70,6 +111,19 @@ async def update_provider(provider_id: str, payload: dict, db: AsyncSession = De @router.delete("/{provider_id}") async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """删除指定的模型供应商。 + + Args: + provider_id: 模型供应商唯一标识 ID。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 供应商不存在时抛出异常。 + """ p = await db.get(ModelProvider, uuid.UUID(provider_id)) if not p: raise HTTPException(404, "供应商不存在") @@ -80,6 +134,15 @@ async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), @router.get("/models/all") async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """列出所有处于活跃状态的模型实例。 + + Args: + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含所有活跃模型实例列表的响应数据。 + """ result = await db.execute( select(ModelInstance) .where(ModelInstance.is_active == True) @@ -101,6 +164,16 @@ async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_c @router.get("/{provider_id}/models") async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """列出指定供应商下的所有模型实例。 + + Args: + provider_id: 模型供应商唯一标识 ID。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含模型实例列表的响应数据。 + """ result = await db.execute( select(ModelInstance) .where(ModelInstance.provider_id == uuid.UUID(provider_id)) @@ -127,6 +200,20 @@ async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user @router.post("/{provider_id}/models") async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """在指定供应商下添加新的模型实例。 + + Args: + provider_id: 模型供应商唯一标识 ID。 + payload: 请求体,包含 model_name、model_type、display_name、capabilities、default_params、is_default 字段。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含新模型 ID 的响应数据。 + + Raises: + HTTPException: 供应商不存在或相同名称的模型已存在时抛出异常。 + """ p = await db.get(ModelProvider, uuid.UUID(provider_id)) if not p: raise HTTPException(404, "供应商不存在") @@ -156,6 +243,21 @@ async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depen @router.put("/{provider_id}/models/{model_id}") async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """更新模型实例的配置信息。 + + Args: + provider_id: 模型供应商唯一标识 ID。 + model_id: 模型实例唯一标识 ID。 + payload: 请求体,包含要更新的模型字段。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 包含更新后模型 ID 的响应数据。 + + Raises: + HTTPException: 模型不存在时抛出异常。 + """ m = await db.get(ModelInstance, uuid.UUID(model_id)) if not m or str(m.provider_id) != provider_id: raise HTTPException(404, "模型不存在") @@ -173,9 +275,23 @@ async def update_model(provider_id: str, model_id: str, payload: dict, db: Async @router.delete("/{provider_id}/models/{model_id}") async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): + """删除指定的模型实例。 + + Args: + provider_id: 模型供应商唯一标识 ID。 + model_id: 模型实例唯一标识 ID。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 模型不存在时抛出异常。 + """ m = await db.get(ModelInstance, uuid.UUID(model_id)) if not m or str(m.provider_id) != provider_id: raise HTTPException(404, "模型不存在") await db.delete(m) await db.commit() - return {"code": 200, "message": "已删除"} \ No newline at end of file + return {"code": 200, "message": "已删除"} diff --git a/backend/modules/monitor/__init__.py b/backend/modules/monitor/__init__.py index e69de29..095af8a 100644 --- a/backend/modules/monitor/__init__.py +++ b/backend/modules/monitor/__init__.py @@ -0,0 +1,5 @@ +"""监控模块。 + +提供员工监控功能,包括员工列表查询、个人数据看板、AI 辅助的员工交互分析等。 +支持基于数据权限范围的访问控制。 +""" diff --git a/backend/modules/monitor/router.py b/backend/modules/monitor/router.py index 441a6b9..1a5672e 100644 --- a/backend/modules/monitor/router.py +++ b/backend/modules/monitor/router.py @@ -1,3 +1,8 @@ +"""监控模块路由。 + +提供员工监控功能,包括员工列表查询、个人数据看板、AI 辅助的员工交互分析等。 +支持基于数据权限范围(all/subordinate_only/self_only)的访问控制。 +""" import uuid import json from datetime import datetime @@ -9,25 +14,42 @@ from models import User, ChatSession, ChatMessage from modules.org.router import _get_subordinate_ids, _user_to_out from schemas import EmployeeAnalysis, UserOut -router = APIRouter(prefix="/api/monitor", tags=["monitor"]) +router = APIRouter(prefix="/api/monitor", tags=["monitor"]) # 监控模块路由前缀 @router.get("/employees", response_model=list[UserOut]) async def get_monitor_employees(request: Request, db: AsyncSession = Depends(get_db)): + """获取可监控的员工列表。 + + 根据当前用户的数据权限范围返回不同的员工列表: + - all:返回所有活跃员工 + - subordinate_only:返回当前用户及其所有下级员工 + - self_only:仅返回当前用户 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + list[UserOut]: 员工信息列表。 + """ user_ctx = request.state.user cur_id = uuid.UUID(user_ctx["id"]) if user_ctx["data_scope"] == "all": + # 数据权限为全部:返回所有活跃员工 result = await db.execute(select(User).where(User.status == "active")) return [await _user_to_out(db, u) for u in result.scalars().all()] elif user_ctx["data_scope"] == "subordinate_only": + # 数据权限为下级:返回当前用户及其所有下级 sub_ids = await _get_subordinate_ids(db, cur_id) sub_ids.add(cur_id) result = await db.execute(select(User).where(User.id.in_(sub_ids))) return [await _user_to_out(db, u) for u in result.scalars().all()] else: + # 数据权限为仅自己:仅返回当前用户 result = await db.execute(select(User).where(User.id == cur_id)) return [await _user_to_out(db, u) for u in result.scalars().all()] @@ -36,9 +58,26 @@ async def get_monitor_employees(request: Request, db: AsyncSession = Depends(get async def get_employee_dashboard( emp_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db) ): + """获取指定员工的个人数据看板。 + + 包括员工基本信息、消息统计、会话统计、活跃天数、消息分类和最近交互记录。 + 需要进行数据权限校验。 + + Args: + emp_id: 员工唯一标识 ID。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + dict: 包含员工信息和统计数据的看板数据。 + + Raises: + HTTPException: 无权查看或员工不存在时抛出异常。 + """ user_ctx = request.state.user cur_id = uuid.UUID(user_ctx["id"]) + # 数据权限校验 if user_ctx["data_scope"] != "all": if user_ctx["data_scope"] == "self_only" and str(emp_id) != user_ctx["id"]: raise HTTPException(403, "无权查看此员工数据") @@ -48,21 +87,25 @@ async def get_employee_dashboard( if emp_id not in sub_ids: raise HTTPException(403, "无权查看此员工数据") + # 查询员工信息 emp_result = await db.execute(select(User).where(User.id == emp_id)) emp = emp_result.scalar_one_or_none() if not emp: raise HTTPException(404, "员工不存在") + # 统计总消息数 total_msgs_result = await db.execute( select(func.count(ChatMessage.id)).where(ChatMessage.user_id == emp_id) ) total_messages = total_msgs_result.scalar() or 0 + # 统计总会话数 session_result = await db.execute( select(func.count(ChatSession.id)).where(ChatSession.user_id == emp_id) ) total_sessions = session_result.scalar() or 0 + # 查询最近 50 条消息 recent_msgs_result = await db.execute( select(ChatMessage) .where(ChatMessage.user_id == emp_id) @@ -71,13 +114,14 @@ async def get_employee_dashboard( ) recent = recent_msgs_result.scalars().all() + # 统计话题分布和活跃天数 topics = {} active_days = set() for msg in recent: if msg.created_at: - active_days.add(msg.created_at.strftime("%Y-%m-%d")) + active_days.add(msg.created_at.strftime("%Y-%m-%d")) # 记录活跃日期 role = msg.role - topics[role] = topics.get(role, 0) + 1 + topics[role] = topics.get(role, 0) + 1 # 按角色统计消息数 return { "code": 200, @@ -89,13 +133,13 @@ async def get_employee_dashboard( "position": emp.position or "", }, "stats": { - "total_messages": total_messages, - "total_sessions": total_sessions, - "active_days": len(active_days), - "message_breakdown": topics, + "total_messages": total_messages, # 总消息数 + "total_sessions": total_sessions, # 总会话数 + "active_days": len(active_days), # 活跃天数 + "message_breakdown": topics, # 消息角色分布 "recent_interactions": [ {"role": m.role, "content": m.content[:200], "created_at": str(m.created_at)} - for m in recent[:10] + for m in recent[:10] # 最近 10 条交互记录 ], }, }, @@ -106,9 +150,26 @@ async def get_employee_dashboard( async def get_employee_analysis( emp_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db) ): + """获取指定员工的 AI 分析报告。 + + 基于员工与 AI 的交互记录,使用大语言模型生成员工分析报告, + 包括任务完成率、活跃度、主要话题、效率趋势、优势和改进建议等。 + + Args: + emp_id: 员工唯一标识 ID。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + EmployeeAnalysis: 员工分析报告。 + + Raises: + HTTPException: 无权查看或员工不存在时抛出异常。 + """ user_ctx = request.state.user cur_id = uuid.UUID(user_ctx["id"]) + # 数据权限校验 if user_ctx["data_scope"] != "all": if user_ctx["data_scope"] == "self_only" and str(emp_id) != user_ctx["id"]: raise HTTPException(403, "无权查看此员工数据") @@ -118,6 +179,7 @@ async def get_employee_analysis( if emp_id not in sub_ids: raise HTTPException(403, "无权查看此员工数据") + # 查询员工信息 emp_result = await db.execute(select(User).where(User.id == emp_id)) emp = emp_result.scalar_one_or_none() if not emp: @@ -128,6 +190,7 @@ async def get_employee_analysis( from agentscope.formatter import OpenAIChatFormatter from agentscope.message import Msg + # 查询最近 100 条消息作为分析数据 msgs_result = await db.execute( select(ChatMessage) .where(ChatMessage.user_id == emp_id) @@ -136,10 +199,12 @@ async def get_employee_analysis( ) messages = msgs_result.scalars().all() + # 构建交互记录文本 interaction_log = "\n".join([ f"[{m.role}] {m.content[:300]}" for m in messages ]) + # 初始化 LLM 模型 model = OpenAIChatModel( config_name="analysis_model", model_name=settings.LLM_MODEL, @@ -148,6 +213,7 @@ async def get_employee_analysis( ) formatter = OpenAIChatFormatter() + # 构建分析提示词 prompt = formatter.format([ Msg("system", f"""你是一个企业管理者分析助手。请根据员工与AI的交互记录,生成一个JSON格式的分析报告。 @@ -166,7 +232,7 @@ async def get_employee_analysis( ]) try: - res = await model(prompt) + res = await model(prompt) # 调用 LLM 生成分析 res_text = "" if isinstance(res, list): res_text = res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0]) @@ -174,8 +240,9 @@ async def get_employee_analysis( res_text = res.get_text_content() else: res_text = str(res) - analysis_data = json.loads(res_text) + analysis_data = json.loads(res_text) # 解析 JSON 响应 except Exception: + # LLM 调用失败时使用默认分析数据 analysis_data = { "task_completion_rate": 0.7, "active_days": 0, @@ -193,4 +260,4 @@ async def get_employee_analysis( department=str(emp.department_id) if emp.department_id else "", period=f"最近数据", **analysis_data - ) \ No newline at end of file + ) diff --git a/backend/modules/notification/__init__.py b/backend/modules/notification/__init__.py index e69de29..e713a3c 100644 --- a/backend/modules/notification/__init__.py +++ b/backend/modules/notification/__init__.py @@ -0,0 +1 @@ +"""通知模块。""" \ No newline at end of file diff --git a/backend/modules/notification/router.py b/backend/modules/notification/router.py index cd59b27..e67c921 100644 --- a/backend/modules/notification/router.py +++ b/backend/modules/notification/router.py @@ -1,3 +1,8 @@ +"""通知模块路由。 + +提供实时通知推送功能,支持 WebSocket 连接、消息广播、定向发送。 +支持通知模板管理和企业微信推送集成。 +""" import uuid import json import asyncio @@ -15,22 +20,50 @@ router = APIRouter(prefix="/api/notification", tags=["notification"]) class WebSocketManager: + """WebSocket 连接管理器类,管理所有用户的 WebSocket 连接。 + + 支持按用户 ID 管理多个连接,提供定向发送和广播功能。 + + Attributes: + connections: 用户 ID 到 WebSocket 连接列表的映射字典。 + """ def __init__(self): + """初始化 WebSocket 管理器实例。""" self.connections: dict[str, list[WebSocket]] = {} async def connect(self, user_id: str, ws: WebSocket): + """接受并注册新的 WebSocket 连接。 + + Args: + user_id: 用户唯一标识。 + ws: WebSocket 连接对象。 + """ await ws.accept() if user_id not in self.connections: self.connections[user_id] = [] self.connections[user_id].append(ws) def disconnect(self, user_id: str, ws: WebSocket): + """断开并移除指定的 WebSocket 连接。 + + Args: + user_id: 用户唯一标识。 + ws: 要移除的 WebSocket 连接对象。 + """ if user_id in self.connections: self.connections[user_id].remove(ws) if not self.connections[user_id]: del self.connections[user_id] async def send_to_user(self, user_id: str, message: dict): + """向指定用户的所有连接发送消息。 + + 自动清理失效的连接。 + + Args: + user_id: 目标用户唯一标识。 + message: 要发送的消息字典。 + """ connections = self.connections.get(user_id, []) dead = [] for ws in connections: @@ -42,19 +75,37 @@ class WebSocketManager: self.disconnect(user_id, ws) async def broadcast(self, message: dict): + """向所有在线用户广播消息。 + + Args: + message: 要广播的消息字典。 + """ for user_id in list(self.connections.keys()): await self.send_to_user(user_id, message) @property def active_count(self) -> int: + """获取当前活跃的 WebSocket 连接总数。 + + Returns: + int: 活跃连接数量。 + """ return sum(len(v) for v in self.connections.values()) -ws_manager = WebSocketManager() +ws_manager = WebSocketManager() # 全局 WebSocket 管理器单例实例 @router.websocket("/ws/{user_id}") async def notification_websocket(ws: WebSocket, user_id: str): + """通知 WebSocket 连接处理器。 + + 接受客户端的 WebSocket 连接,处理心跳 ping/pong 消息。 + + Args: + ws: WebSocket 连接对象。 + user_id: 用户唯一标识,从路径参数获取。 + """ await ws_manager.connect(user_id, ws) try: while True: @@ -71,11 +122,24 @@ async def notification_websocket(ws: WebSocket, user_id: str): @router.post("/send") async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): - user_id = payload.get("user_id", "") - target_all = payload.get("target_all", False) - title = payload.get("title", "系统通知") - body = payload.get("message", "") - notify_type = payload.get("type", "info") + """发送实时通知,支持定向发送和广播。 + + 支持推送到企业微信,并记录审计日志。 + + Args: + payload: 请求体,包含 user_id、target_all、title、message、type、push_to_wecom 等字段。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 操作结果响应。 + """ + user_id = payload.get("user_id", "") # 目标用户 ID + target_all = payload.get("target_all", False) # 是否广播给所有用户 + title = payload.get("title", "系统通知") # 通知标题 + body = payload.get("message", "") # 通知内容 + notify_type = payload.get("type", "info") # 通知类型 msg = { "type": notify_type, @@ -85,10 +149,11 @@ async def send_notification(payload: dict, request: Request, db: AsyncSession = } if target_all: - await ws_manager.broadcast(msg) + await ws_manager.broadcast(msg) # 广播给所有在线用户 elif user_id: - await ws_manager.send_to_user(user_id, msg) + await ws_manager.send_to_user(user_id, msg) # 定向发送给指定用户 + # 推送到企业微信 if payload.get("push_to_wecom"): await _push_to_wecom(title, body, user_id) @@ -107,6 +172,15 @@ async def send_notification(payload: dict, request: Request, db: AsyncSession = @router.get("/templates", response_model=list[NotificationTemplateOut]) async def list_templates(request: Request, db: AsyncSession = Depends(get_db)): + """列出所有通知模板。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + list[NotificationTemplateOut]: 通知模板列表。 + """ result = await db.execute( select(NotificationTemplate).order_by(NotificationTemplate.created_at.desc()) ) @@ -120,6 +194,20 @@ async def create_template( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """创建新的通知模板。 + + Args: + req: 通知模板创建请求体。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + NotificationTemplateOut: 创建后的通知模板响应。 + + Raises: + HTTPException: 模板编码已存在时抛出异常。 + """ existing = await db.execute( select(NotificationTemplate).where(NotificationTemplate.code == req.code) ) @@ -141,6 +229,19 @@ async def create_template( @router.get("/templates/{template_id}", response_model=NotificationTemplateOut) async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + """获取指定通知模板的详细信息。 + + Args: + template_id: 通知模板唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + NotificationTemplateOut: 通知模板详细信息。 + + Raises: + HTTPException: 模板不存在时抛出异常。 + """ result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id)) template = result.scalar_one_or_none() if not template: @@ -154,6 +255,20 @@ async def delete_template( db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user), ): + """删除指定的通知模板。 + + Args: + template_id: 通知模板唯一标识 ID。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + user: 当前登录用户信息。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 模板不存在或为系统模板时抛出异常。 + """ result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id)) template = result.scalar_one_or_none() if not template: @@ -167,10 +282,22 @@ async def delete_template( @router.get("/ws/stats") async def ws_stats(): + """获取 WebSocket 连接统计信息。 + + Returns: + dict: 包含活跃连接数的响应数据。 + """ return {"code": 200, "data": {"active_connections": ws_manager.active_count}} async def _push_to_wecom(title: str, body: str, user_id: str): + """将通知推送到企业微信。 + + Args: + title: 通知标题。 + body: 通知内容。 + user_id: 目标企业微信用户 ID。 + """ if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET: return @@ -195,4 +322,4 @@ async def _push_to_wecom(title: str, body: str, user_id: str): }, ) except Exception: - pass \ No newline at end of file + pass diff --git a/backend/modules/org/__init__.py b/backend/modules/org/__init__.py index e69de29..e9cc7f8 100644 --- a/backend/modules/org/__init__.py +++ b/backend/modules/org/__init__.py @@ -0,0 +1,5 @@ +"""组织管理模块。 + +提供部门和用户的 CRUD 操作、树形部门结构查询、下级用户递归查询等功能。 +支持基于数据权限范围的访问控制。 +""" diff --git a/backend/modules/org/router.py b/backend/modules/org/router.py index d7c4062..56b4541 100644 --- a/backend/modules/org/router.py +++ b/backend/modules/org/router.py @@ -1,3 +1,8 @@ +"""组织管理模块路由。 + +提供部门和用户的 CRUD 操作、树形部门结构查询、下级用户递归查询等功能。 +支持基于数据权限范围(all/subordinate_only/self_only)的访问控制。 +""" import uuid from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy import select @@ -11,11 +16,22 @@ from schemas import ( ) from modules.auth.router import hash_password, get_user_roles -router = APIRouter(prefix="/api/org", tags=["org"]) +router = APIRouter(prefix="/api/org", tags=["org"]) # 组织管理模块路由前缀 @router.get("/departments", response_model=list[DepartmentOut]) async def get_departments(request: Request, db: AsyncSession = Depends(get_db)): + """获取树形部门结构。 + + 查询所有顶级部门(parent_id 为 NULL),并递归构建完整的部门树。 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + list[DepartmentOut]: 树形部门结构列表。 + """ result = await db.execute( select(Department).where(Department.parent_id.is_(None)).order_by(Department.sort_order) ) @@ -24,11 +40,26 @@ async def get_departments(request: Request, db: AsyncSession = Depends(get_db)): async def _build_department_tree(db: AsyncSession, dept: Department, _visited: set[uuid.UUID] = None) -> DepartmentOut: + """递归构建部门树形结构。 + + 查询当前部门的所有子部门,并递归构建子部门的子部门树。 + 使用 _visited 集合防止循环引用导致无限递归。 + + Args: + db: 异步数据库会话。 + dept: 当前部门对象。 + _visited: 已访问的部门 ID 集合,用于防止循环引用。 + + Returns: + DepartmentOut: 包含子部门列表的部门信息。 + """ if _visited is None: _visited = set() if dept.id in _visited: + # 检测到循环引用,返回不包含子部门的部门信息 return DepartmentOut(id=dept.id, name=dept.name, parent_id=dept.parent_id, path=dept.path, level=dept.level, sort_order=dept.sort_order, children=[]) _visited.add(dept.id) + # 查询当前部门的所有子部门 children_result = await db.execute( select(Department).where(Department.parent_id == dept.id).order_by(Department.sort_order) ) @@ -36,7 +67,7 @@ async def _build_department_tree(db: AsyncSession, dept: Department, _visited: s return DepartmentOut( id=dept.id, name=dept.name, parent_id=dept.parent_id, path=dept.path, level=dept.level, sort_order=dept.sort_order, - children=[await _build_department_tree(db, c, _visited) for c in children], + children=[await _build_department_tree(db, c, _visited) for c in children], # 递归构建子部门 ) @@ -44,19 +75,35 @@ async def _build_department_tree(db: AsyncSession, dept: Department, _visited: s async def create_department( req: DepartmentCreate, request: Request, db: AsyncSession = Depends(get_db) ): + """创建新部门。 + + 根据父部门信息计算新部门的层级和路径。 + + Args: + req: 部门创建请求体,包含名称、父部门 ID 和排序权重。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + DepartmentOut: 创建后的部门信息。 + + Raises: + HTTPException: 父部门不存在时抛出异常。 + """ parent_path = "/" level = 0 if req.parent_id: + # 查询父部门信息以计算层级和路径 parent_result = await db.execute(select(Department).where(Department.id == req.parent_id)) parent = parent_result.scalar_one_or_none() if not parent: raise HTTPException(404, "父部门不存在") parent_path = parent.path - level = parent.level + 1 + level = parent.level + 1 # 新部门层级为父部门层级 + 1 dept = Department( name=req.name, parent_id=req.parent_id, - path=f"{parent_path}/{req.name}".replace("//", "/"), + path=f"{parent_path}/{req.name}".replace("//", "/"), # 构建部门路径 level=level, sort_order=req.sort_order, ) db.add(dept) @@ -73,6 +120,20 @@ async def update_department( dept_id: uuid.UUID, req: DepartmentUpdate, request: Request, db: AsyncSession = Depends(get_db), ): + """更新部门信息。 + + Args: + dept_id: 部门唯一标识 ID。 + req: 部门更新请求体,包含可更新的字段。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + DepartmentOut: 更新后的部门信息。 + + Raises: + HTTPException: 部门不存在时抛出异常。 + """ result = await db.execute(select(Department).where(Department.id == dept_id)) dept = result.scalar_one_or_none() if not dept: @@ -92,6 +153,19 @@ async def update_department( @router.delete("/departments/{dept_id}") async def delete_department(dept_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + """删除部门。 + + Args: + dept_id: 部门唯一标识 ID。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + dict: 操作结果响应。 + + Raises: + HTTPException: 部门不存在时抛出异常。 + """ result = await db.execute(select(Department).where(Department.id == dept_id)) dept = result.scalar_one_or_none() if not dept: @@ -102,10 +176,22 @@ async def delete_department(dept_id: uuid.UUID, request: Request, db: AsyncSessi @router.get("/users", response_model=list[UserOut]) async def get_users(request: Request, db: AsyncSession = Depends(get_db)): + """获取用户列表。 + + 根据当前用户的数据权限范围返回不同的用户列表。 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + list[UserOut]: 用户信息列表。 + """ user_ctx = request.state.user result = await db.execute(select(User)) users = result.scalars().all() + # 根据数据权限范围过滤用户 if user_ctx["data_scope"] == "self_only": users = [u for u in users if str(u.id) == user_ctx["id"]] elif user_ctx["data_scope"] == "subordinate_only": @@ -118,6 +204,19 @@ async def get_users(request: Request, db: AsyncSession = Depends(get_db)): @router.get("/users/{user_id}", response_model=UserOut) async def get_user(user_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): + """获取指定用户的详细信息。 + + Args: + user_id: 用户唯一标识 ID。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + UserOut: 用户详细信息。 + + Raises: + HTTPException: 用户不存在时抛出异常。 + """ result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if not user: @@ -127,13 +226,28 @@ async def get_user(user_id: uuid.UUID, request: Request, db: AsyncSession = Depe @router.post("/users", response_model=UserOut) async def create_user(req: UserCreate, request: Request, db: AsyncSession = Depends(get_db)): + """创建新用户。 + + 支持设置用户名、密码、部门、职位、上级用户和角色等信息。 + + Args: + req: 用户创建请求体。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + UserOut: 创建后的用户信息。 + + Raises: + HTTPException: 用户名已存在时抛出异常。 + """ existing = await db.execute(select(User).where(User.username == req.username)) if existing.scalar_one_or_none(): raise HTTPException(400, "用户名已存在") user = User( username=req.username, - password_hash=hash_password(req.password), + password_hash=hash_password(req.password), # 密码哈希存储 display_name=req.display_name, email=req.email, phone=req.phone, wecom_user_id=req.wecom_user_id, @@ -144,6 +258,7 @@ async def create_user(req: UserCreate, request: Request, db: AsyncSession = Depe await db.flush() if req.role_ids: + # 为用户分配角色 for role_id in req.role_ids: db.add(UserRole(user_id=user.id, role_id=role_id)) await db.flush() @@ -156,6 +271,22 @@ async def update_user( user_id: uuid.UUID, req: UserUpdate, request: Request, db: AsyncSession = Depends(get_db), ): + """更新用户信息。 + + 支持修改显示名称、邮箱、手机号、部门、职位、上级用户、状态和角色。 + + Args: + user_id: 用户唯一标识 ID。 + req: 用户更新请求体。 + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + UserOut: 更新后的用户信息。 + + Raises: + HTTPException: 用户不存在时抛出异常。 + """ result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if not user: @@ -177,12 +308,13 @@ async def update_user( user.status = req.status if req.role_ids is not None: - await db.execute(select(UserRole).where(UserRole.user_id == user.id)) + # 先删除用户现有角色关联 existing_urs = (await db.execute( select(UserRole).where(UserRole.user_id == user.id) )).scalars().all() for ur in existing_urs: await db.delete(ur) + # 重新分配角色 for role_id in req.role_ids: db.add(UserRole(user_id=user.id, role_id=role_id)) @@ -191,9 +323,20 @@ async def update_user( @router.get("/subordinates", response_model=list[UserOut]) async def get_subordinates(request: Request, db: AsyncSession = Depends(get_db)): + """获取当前用户的所有下级用户(递归)。 + + 递归查询所有直接或间接以当前用户为上级的用户。 + + Args: + request: HTTP 请求对象,包含当前用户上下文。 + db: 异步数据库会话。 + + Returns: + list[UserOut]: 下级用户列表。 + """ user_ctx = request.state.user manager_id = uuid.UUID(user_ctx["id"]) - sub_ids = await _get_subordinate_ids(db, manager_id) + sub_ids = await _get_subordinate_ids(db, manager_id) # 递归获取所有下级 ID result = await db.execute(select(User).where(User.id.in_(sub_ids))) users = result.scalars().all() @@ -201,25 +344,48 @@ async def get_subordinates(request: Request, db: AsyncSession = Depends(get_db)) async def _get_subordinate_ids(db: AsyncSession, manager_id: uuid.UUID, _visited: set[uuid.UUID] = None) -> set[uuid.UUID]: + """递归获取指定管理者的所有下级用户 ID。 + + 递归查询直接或间接以指定用户为上级的所有用户 ID。 + 使用 _visited 集合防止循环引用导致无限递归。 + + Args: + db: 异步数据库会话。 + manager_id: 管理者用户 ID。 + _visited: 已访问的用户 ID 集合,用于防止循环引用。 + + Returns: + set[uuid.UUID]: 所有下级用户 ID 的集合。 + """ if _visited is None: _visited = set() if manager_id in _visited: - return set() + return set() # 检测到循环引用,返回空集合 _visited.add(manager_id) + # 查询直接下级 result = await db.execute(select(User).where(User.manager_id == manager_id)) direct = result.scalars().all() ids = {u.id for u in direct} for sub in direct: - ids.update(await _get_subordinate_ids(db, sub.id, _visited)) + ids.update(await _get_subordinate_ids(db, sub.id, _visited)) # 递归获取子级下级 return ids async def _user_to_out(db: AsyncSession, user: User) -> UserOut: - roles = await get_user_roles(db, user.id) + """将用户数据库对象转换为 UserOut 响应模型。 + + Args: + db: 异步数据库会话,用于查询用户角色信息。 + user: 用户数据库对象。 + + Returns: + UserOut: 用户响应模型,包含角色列表。 + """ + roles = await get_user_roles(db, user.id) # 获取用户角色信息 return UserOut( id=user.id, username=user.username, display_name=user.display_name, email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id, department_id=user.department_id, position=user.position, manager_id=user.manager_id, status=user.status, roles=roles, created_at=user.created_at, - ) \ No newline at end of file + ) diff --git a/backend/modules/rag/knowledge.py b/backend/modules/rag/knowledge.py index 7756714..3499f42 100644 --- a/backend/modules/rag/knowledge.py +++ b/backend/modules/rag/knowledge.py @@ -1,3 +1,9 @@ +"""知识库模块。 + +提供基于 AgentScope 的企业知识库管理功能,包括文档索引、文本索引和语义检索。 +支持多种文档格式(PDF、Word、Excel、纯文本)的自动解析和向量化存储。 +使用 Qdrant 作为向量存储后端,OpenAI Embedding 作为向量化模型。 +""" import os import asyncio import logging @@ -5,15 +11,20 @@ from agentscope.embedding import OpenAITextEmbedding from agentscope.rag import SimpleKnowledge, QdrantStore, TextReader, PDFReader, WordReader, ExcelReader from config import settings -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 -_knowledge_base: SimpleKnowledge | None = None -_STORE_PATH = os.path.join(settings.UPLOAD_DIR, "..", "data", "qdrant") -_COLLECTION_NAME = "enterprise_knowledge" -_VECTOR_DIM = 1536 +_knowledge_base: SimpleKnowledge | None = None # 全局知识库单例实例 +_STORE_PATH = os.path.join(settings.UPLOAD_DIR, "..", "data", "qdrant") # Qdrant 向量存储路径 +_COLLECTION_NAME = "enterprise_knowledge" # Qdrant 集合名称 +_VECTOR_DIM = 1536 # 向量维度(text-embedding-3-small 标准维度) def _get_embedding_model(): + """创建并返回 OpenAI 文本 Embedding 模型实例。 + + Returns: + OpenAITextEmbedding: 配置好的 Embedding 模型。 + """ return OpenAITextEmbedding( api_key=settings.LLM_API_KEY, model_name="text-embedding-3-small", @@ -22,6 +33,14 @@ def _get_embedding_model(): def get_knowledge_base() -> SimpleKnowledge: + """获取或创建全局知识库实例。 + + 采用单例模式,首次调用时初始化 Qdrant 向量存储和 Embedding 模型, + 后续调用直接返回已创建的实例。 + + Returns: + SimpleKnowledge: 初始化好的知识库实例。 + """ global _knowledge_base if _knowledge_base is None: os.makedirs(_STORE_PATH, exist_ok=True) @@ -39,10 +58,23 @@ def get_knowledge_base() -> SimpleKnowledge: async def add_document(file_path: str, file_type: str = "auto") -> str: + """将文档文件添加到知识库中进行索引。 + + 自动根据文件类型选择合适的解析器,将文档切分为多个文本块后 + 进行向量化并存储到知识库中。 + + Args: + file_path: 文档文件的完整路径。 + file_type: 文档类型,auto 表示自动识别。 + + Returns: + str: 索引结果描述或错误信息。 + """ try: ext = os.path.splitext(file_path)[1].lower() kb = get_knowledge_base() + # 根据文件类型选择对应的解析器 if file_type == "auto": if ext == ".pdf": reader = PDFReader(chunk_size=1024, split_by="sentence") @@ -83,6 +115,15 @@ async def add_document(file_path: str, file_type: str = "auto") -> str: async def add_text(text: str, source: str = "manual") -> str: + """将纯文本内容添加到知识库中进行索引。 + + Args: + text: 要索引的文本内容。 + source: 文本来源标识,默认为 manual(手动录入)。 + + Returns: + str: 索引结果描述或错误信息。 + """ try: kb = get_knowledge_base() reader = TextReader(chunk_size=1024, split_by="sentence") @@ -97,6 +138,18 @@ async def add_text(text: str, source: str = "manual") -> str: async def search(query: str, limit: int = 5, score_threshold: float = 0.3) -> list[dict]: + """在知识库中执行语义检索。 + + 根据查询文本的向量相似度,从知识库中检索最相关的文档片段。 + + Args: + query: 查询文本。 + limit: 返回结果的最大数量,默认 5 条。 + score_threshold: 相似度分数阈值,低于此值的结果将被过滤,默认 0.3。 + + Returns: + list[dict]: 检索结果列表,每项包含 id、content、score、source 字段。 + """ try: kb = get_knowledge_base() if not kb or not hasattr(kb, 'retrieve'): @@ -124,6 +177,17 @@ async def search(query: str, limit: int = 5, score_threshold: float = 0.3) -> li async def retrieve_for_agent(query: str, limit: int = 5) -> str: + """为 AI 智能体执行知识库检索并返回格式化的结果文本。 + + 该函数专为 AgentScope 智能体调用设计,返回人类可读的检索结果。 + + Args: + query: 查询文本。 + limit: 返回结果的最大数量,默认 5 条。 + + Returns: + str: 格式化的检索结果文本,包含相关度分数。 + """ results = await search(query, limit=limit) if not results: return "未找到相关文档。" @@ -131,4 +195,4 @@ async def retrieve_for_agent(query: str, limit: int = 5) -> str: parts = ["根据知识库检索到以下相关内容:"] for i, r in enumerate(results, 1): parts.append(f"\n[{i}] (相关度: {r['score']})\n{r['content']}") - return "\n".join(parts) \ No newline at end of file + return "\n".join(parts) diff --git a/backend/modules/rag/router.py b/backend/modules/rag/router.py index 235916b..36abebb 100644 --- a/backend/modules/rag/router.py +++ b/backend/modules/rag/router.py @@ -1,3 +1,8 @@ +"""知识库(RAG)模块路由。 + +提供知识库的文档上传、文本索引、语义检索和文档管理功能。 +支持文件上传后自动解析和向量化,以及基于相似度的知识检索。 +""" from fastapi import APIRouter, Depends, UploadFile, File, Request from database import get_db from sqlalchemy.ext.asyncio import AsyncSession @@ -18,8 +23,22 @@ async def rag_upload( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): + """上传文档文件并自动索引到知识库。 + + 将上传的文件保存到服务器后,调用文档解析器将其切分为文本块 + 并进行向量化存储到知识库中。 + + Args: + request: HTTP 请求对象。 + file: 上传的文件对象。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含索引结果和文件信息的响应数据。 + """ os.makedirs(settings.UPLOAD_DIR, exist_ok=True) - filename = f"{uuid.uuid4().hex}_{file.filename}" + filename = f"{uuid.uuid4().hex}_{file.filename}" # 生成唯一文件名,避免冲突 file_path = os.path.join(settings.UPLOAD_DIR, filename) content = await file.read() @@ -37,8 +56,19 @@ async def rag_index_text( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): - text = payload.get("text", "") - source = payload.get("source", "manual") + """将纯文本内容索引到知识库。 + + Args: + request: HTTP 请求对象。 + payload: 请求体,包含 text 和可选的 source 字段。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含索引结果的响应数据。 + """ + text = payload.get("text", "") # 待索引的文本内容 + source = payload.get("source", "manual") # 文本来源标识 if not text: return {"code": 400, "message": "文本内容不能为空"} result = await add_text(text, source) @@ -53,6 +83,18 @@ async def rag_search( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): + """在知识库中执行语义检索。 + + Args: + request: HTTP 请求对象。 + q: 查询文本。 + limit: 返回结果的最大数量。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含检索结果列表的响应数据。 + """ if not q: return {"code": 400, "message": "查询内容不能为空"} results = await search(q, limit=limit) @@ -67,6 +109,18 @@ async def rag_retrieve( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): + """为 AI 智能体执行知识库检索,返回格式化的结果文本。 + + Args: + request: HTTP 请求对象。 + q: 查询文本。 + limit: 返回结果的最大数量。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含格式化检索结果的响应数据。 + """ if not q: return {"code": 400, "message": "查询内容不能为空"} result = await retrieve_for_agent(q, limit=limit) @@ -79,6 +133,18 @@ async def list_documents( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): + """列出知识库中已索引的所有文档来源及其统计信息。 + + 从向量存储中获取所有文档,按来源分组并统计每个来源的文档块数量。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含文档列表和统计信息的响应数据。 + """ try: kb = get_knowledge_base() if not kb or not hasattr(kb, '_embedding_store'): @@ -108,6 +174,7 @@ async def list_documents( }) offset += batch_size + # 按来源分组统计 seen_sources = {} for d in all_docs: src = d["source"] or "unknown" @@ -134,6 +201,17 @@ async def delete_document( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): + """按来源删除知识库中的文档块。 + + Args: + source: 文档来源标识。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含删除结果的响应数据。 + """ try: kb = get_knowledge_base() store = kb._embedding_store @@ -157,6 +235,16 @@ async def knowledge_stats( db: AsyncSession = Depends(get_db), current_user=Depends(get_current_user), ): + """获取知识库的统计信息,包括文档块数量和来源文件数量。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + current_user: 当前登录用户信息。 + + Returns: + dict: 包含知识库统计信息的响应数据。 + """ try: kb = get_knowledge_base() store = kb._embedding_store @@ -186,4 +274,4 @@ async def knowledge_stats( return {"code": 200, "data": stats_data} except Exception as e: - return {"code": 500, "message": f"获取统计信息失败: {e}"} \ No newline at end of file + return {"code": 500, "message": f"获取统计信息失败: {e}"} diff --git a/backend/modules/rbac/__init__.py b/backend/modules/rbac/__init__.py index e69de29..745ddd1 100644 --- a/backend/modules/rbac/__init__.py +++ b/backend/modules/rbac/__init__.py @@ -0,0 +1 @@ +"""权限管理模块(RBAC)。""" \ No newline at end of file diff --git a/backend/modules/system/router.py b/backend/modules/system/router.py index 8302311..5e6a8f5 100644 --- a/backend/modules/system/router.py +++ b/backend/modules/system/router.py @@ -1,3 +1,8 @@ +"""系统管理模块路由。 + +提供系统健康检查、使用统计、指标收集和缓存管理等功能。 +用于监控系统运行状态和资源使用情况。 +""" import time import uuid import psutil @@ -15,11 +20,23 @@ from middleware.rate_limiter import rate_limiter router = APIRouter(prefix="/api/system", tags=["system"]) -_start_time = time.time() +_start_time = time.time() # 服务启动时间戳 @router.get("/health", response_model=SystemHealthOut) async def health_check(request: Request, db: AsyncSession = Depends(get_db)): + """系统健康检查接口。 + + 检查数据库连接、Redis 连接、内存使用、CPU 使用率等系统健康指标。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + SystemHealthOut: 包含系统健康状态的响应数据。 + """ + # 检查数据库连接 db_ok = False try: await db.execute(select(func.count()).select_from(User)) @@ -27,9 +44,9 @@ async def health_check(request: Request, db: AsyncSession = Depends(get_db)): except Exception: pass - mem = psutil.Process(os.getpid()).memory_info() - cpu = psutil.cpu_percent(interval=0.1) - uptime = time.time() - _start_time + mem = psutil.Process(os.getpid()).memory_info() # 获取当前进程内存信息 + cpu = psutil.cpu_percent(interval=0.1) # 获取 CPU 使用率 + uptime = time.time() - _start_time # 计算服务运行时长 try: user_count = await db.execute(select(func.count(User.id))) @@ -38,22 +55,34 @@ async def health_check(request: Request, db: AsyncSession = Depends(get_db)): active_users = 0 return SystemHealthOut( - status="healthy" if db_ok and cache_manager.available else "degraded", + status="healthy" if db_ok and cache_manager.available else "degraded", # 数据库和 Redis 都正常则为 healthy service="enterprise-ai-platform", uptime_seconds=round(uptime, 1), db_connected=db_ok, redis_connected=cache_manager.available, active_users=active_users, - memory_mb=round(mem.rss / 1024 / 1024, 1), + memory_mb=round(mem.rss / 1024 / 1024, 1), # 转换为 MB cpu_percent=round(cpu, 1), ) @router.get("/stats", response_model=UsageStatsOut) async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)): - today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) + """获取系统使用统计信息。 + + 统计用户数、会话数、消息数、任务数、流程数、API 调用数等关键指标。 + + Args: + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + UsageStatsOut: 包含系统使用统计的响应数据。 + """ + today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) # 今日零点 total_users = (await db.execute(select(func.count(User.id)))).scalar() or 0 + # 今日活跃用户数(有新建会话的用户) active_today = (await db.execute( select(func.count(func.distinct(User.id))) .join(ChatSession, ChatSession.user_id == User.id) @@ -66,6 +95,7 @@ async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)): published = (await db.execute( select(func.count(FlowDefinition.id)).where(FlowDefinition.status == "published") )).scalar() or 0 + # 今日 API 调用数(有新建执行记录的流程) api_calls = (await db.execute( select(func.count(FlowExecution.id)).where(FlowExecution.started_at >= today) )).scalar() or 0 @@ -85,6 +115,16 @@ async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)): @router.post("/metrics") async def collect_metrics(payload: dict, request: Request, db: AsyncSession = Depends(get_db)): + """收集并存储系统指标数据。 + + Args: + payload: 请求体,包含 metric_type、value、source 字段。 + request: HTTP 请求对象。 + db: 异步数据库会话。 + + Returns: + dict: 包含指标 ID 的响应数据。 + """ metric = SystemMetric( metric_type=payload.get("metric_type", "custom"), value={"data": payload.get("value", {}), "source": payload.get("source", "api")}, @@ -101,6 +141,17 @@ async def list_metrics( limit: int = 50, db: AsyncSession = Depends(get_db), ): + """查询系统指标历史数据。 + + Args: + request: HTTP 请求对象。 + metric_type: 可选的指标类型筛选条件。 + limit: 返回结果的最大数量。 + db: 异步数据库会话。 + + Returns: + dict: 包含指标列表的响应数据。 + """ q = select(SystemMetric).order_by(SystemMetric.collected_at.desc()) if metric_type: q = q.where(SystemMetric.metric_type == metric_type) @@ -120,6 +171,14 @@ async def list_metrics( @router.get("/cache/stats") async def cache_stats(request: Request): + """获取缓存系统状态信息。 + + Args: + request: HTTP 请求对象。 + + Returns: + dict: 包含 Redis 可用性状态的响应数据。 + """ return { "code": 200, "data": { @@ -130,6 +189,14 @@ async def cache_stats(request: Request): @router.get("/ratelimit/stats") async def ratelimit_stats(request: Request): + """获取速率限制状态信息。 + + Args: + request: HTTP 请求对象。 + + Returns: + dict: 包含速率限制配置的响应数据。 + """ remaining = await rate_limiter.remaining("global") return { "code": 200, @@ -143,5 +210,14 @@ async def ratelimit_stats(request: Request): @router.post("/cache/clear") async def clear_cache(request: Request, pattern: str = "*"): + """清除缓存数据。 + + Args: + request: HTTP 请求对象。 + pattern: 缓存键匹配模式,默认清除所有缓存。 + + Returns: + dict: 操作结果响应。 + """ await cache_manager.delete_pattern(pattern) - return {"code": 200, "message": "缓存已清除"} \ No newline at end of file + return {"code": 200, "message": "缓存已清除"} diff --git a/backend/modules/task/__init__.py b/backend/modules/task/__init__.py index e69de29..d086150 100644 --- a/backend/modules/task/__init__.py +++ b/backend/modules/task/__init__.py @@ -0,0 +1 @@ +"""任务管理模块。""" \ No newline at end of file diff --git a/backend/modules/wecom/router.py b/backend/modules/wecom/router.py index ddd5c0e..0f411b6 100644 --- a/backend/modules/wecom/router.py +++ b/backend/modules/wecom/router.py @@ -1,3 +1,8 @@ +"""企业微信模块路由。 + +提供企业微信集成相关功能,包括回调消息处理、配置管理和消息发送。 +支持企业微信用户通过企微直接与 AI 助手对话。 +""" import uuid from fastapi import APIRouter, Depends, HTTPException, Request from sqlalchemy import select @@ -11,18 +16,26 @@ router = APIRouter(prefix="/api/wecom", tags=["wecom"]) @router.post("/callback") async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): - """ - 接收企业微信回调消息,路由到AI助手处理并回复。 - 企微配置的回调URL指向此端点。 + """接收企业微信回调消息,路由到AI助手处理并回复。 + + 企业微信配置的回调 URL 指向此端点。接收企微用户消息后, + 查找对应的系统用户,创建或复用聊天会话,调用 AI 智能体处理消息并返回回复。 + + Args: + request: HTTP 请求对象,包含企业微信回调消息体。 + db: 异步数据库会话。 + + Returns: + dict: 包含消息类型、用户 ID 和 AI 回复的响应数据。 """ try: body = await request.json() except Exception: body = await request.body() - msg_type = "text" - wecom_user_id = "" - content = "" + msg_type = "text" # 消息类型 + wecom_user_id = "" # 企业微信用户 ID + content = "" # 消息内容 if isinstance(body, dict): msg_type = body.get("msg_type", body.get("MsgType", "text")) wecom_user_id = body.get("user_id", body.get("FromUserName", "")) @@ -31,6 +44,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): if not wecom_user_id or not content: return {"code": 200, "message": "received"} + # 根据企业微信用户 ID 查找系统用户 user_result = await db.execute( select(User).where(User.wecom_user_id == wecom_user_id) ) @@ -40,6 +54,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): from agentscope.message import Msg + # 查找或创建聊天会话 session_result = await db.execute( select(ChatSession) .where(ChatSession.user_id == user.id, ChatSession.agent_type == "employee") @@ -56,6 +71,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): db.add(session) await db.flush() + # 保存用户消息 user_msg = ChatMessage( session_id=session.id, user_id=user.id, role="user", content=content, @@ -63,6 +79,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): db.add(user_msg) await db.flush() + # 创建 AI 智能体并处理消息 from agentscope_integration.factory import AgentFactory agent = await AgentFactory.create_agent( agent_type="employee", @@ -76,6 +93,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response) + # 保存 AI 回复消息 ai_msg = ChatMessage( session_id=session.id, user_id=user.id, role="assistant", content=reply_text, @@ -95,6 +113,14 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): @router.get("/config") async def get_wecom_config(request: Request): + """获取企业微信当前配置信息。 + + Args: + request: HTTP 请求对象。 + + Returns: + dict: 包含机器人名称、状态、CorpID、功能列表等配置信息。 + """ return { "code": 200, "data": { @@ -109,6 +135,17 @@ async def get_wecom_config(request: Request): @router.put("/config") async def update_wecom_config(request: Request, payload: dict): + """更新企业微信配置并持久化到 .env 文件。 + + 支持更新 CorpID、AppSecret、AgentID、Token 和 EncodingAESKey 等配置项。 + + Args: + request: HTTP 请求对象。 + payload: 请求体,包含企业微信配置字段。 + + Returns: + dict: 操作结果响应。 + """ import os env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env') updates = {} @@ -146,9 +183,18 @@ async def update_wecom_config(request: Request, payload: dict): @router.post("/send") async def send_wecom_message(request: Request, payload: dict): - to_user = payload.get("to_user", "@all") - msg_type = payload.get("msg_type", "text") - content = payload.get("content", "") + """通过企业微信发送消息给指定用户。 + + Args: + request: HTTP 请求对象。 + payload: 请求体,包含 to_user、msg_type、content 字段。 + + Returns: + dict: 操作结果响应。 + """ + to_user = payload.get("to_user", "@all") # 目标用户,默认 @all 表示所有人 + msg_type = payload.get("msg_type", "text") # 消息类型 + content = payload.get("content", "") # 消息内容 if not content: return {"code": 400, "message": "消息内容不能为空"} @@ -158,4 +204,4 @@ async def send_wecom_message(request: Request, payload: dict): result = send_notification(to_user, content, msg_type) return {"code": 200, "message": result} except Exception as e: - return {"code": 500, "message": f"发送失败: {e}"} \ No newline at end of file + return {"code": 500, "message": f"发送失败: {e}"} diff --git a/backend/schemas/__init__.py b/backend/schemas/__init__.py index 15de640..59d8018 100644 --- a/backend/schemas/__init__.py +++ b/backend/schemas/__init__.py @@ -1,3 +1,20 @@ +"""Pydantic 请求/响应模型定义。 + +所有 API 的请求体和响应体均在此定义,涵盖: +- 认证(登录/令牌) +- 用户、部门、角色、权限的增删改查 +- 任务管理 +- 审批流引擎(Flow):节点/边/版本/执行 +- 自定义工具(CustomTool) +- MCP 服务 +- Agent 配置 +- 通知模板 +- 文档上传与解析 +- 审计日志 +- 系统监控指标 +- 通用 API 响应包装 +""" + import uuid from datetime import datetime from pydantic import BaseModel, Field, ConfigDict @@ -5,11 +22,13 @@ from pydantic import BaseModel, Field, ConfigDict # --- Auth --- class LoginRequest(BaseModel): + """登录请求体:用户名 + 密码。""" username: str password: str class TokenResponse(BaseModel): + """令牌响应体:Bearer 令牌 + 用户信息。""" access_token: str token_type: str = "bearer" user: "UserOut" @@ -17,6 +36,7 @@ class TokenResponse(BaseModel): # --- User --- class UserCreate(BaseModel): + """创建用户请求体。""" username: str password: str display_name: str @@ -30,6 +50,7 @@ class UserCreate(BaseModel): class UserUpdate(BaseModel): + """更新用户请求体(所有字段可选)。""" display_name: str | None = None email: str | None = None phone: str | None = None @@ -41,6 +62,7 @@ class UserUpdate(BaseModel): class UserOut(BaseModel): + """用户响应体(ORM 映射)。""" id: uuid.UUID username: str display_name: str @@ -60,18 +82,21 @@ class UserOut(BaseModel): # --- Department --- class DepartmentCreate(BaseModel): + """创建部门请求体。""" name: str parent_id: uuid.UUID | None = None sort_order: int = 0 class DepartmentUpdate(BaseModel): + """更新部门请求体。""" name: str | None = None parent_id: uuid.UUID | None = None sort_order: int | None = None class DepartmentOut(BaseModel): + """部门响应体,含嵌套子部门列表。""" id: uuid.UUID name: str parent_id: uuid.UUID | None = None @@ -86,6 +111,7 @@ class DepartmentOut(BaseModel): # --- Role --- class RoleCreate(BaseModel): + """创建角色请求体。""" name: str code: str = "" description: str | None = None @@ -94,6 +120,7 @@ class RoleCreate(BaseModel): class RoleUpdate(BaseModel): + """更新角色请求体。""" name: str | None = None description: str | None = None data_scope: str | None = None @@ -101,6 +128,7 @@ class RoleUpdate(BaseModel): class RoleOut(BaseModel): + """角色响应体,含权限编码列表。""" id: uuid.UUID name: str code: str = "" @@ -115,6 +143,7 @@ class RoleOut(BaseModel): # --- Permission --- class PermissionOut(BaseModel): + """权限响应体。""" id: uuid.UUID code: str name: str @@ -127,6 +156,7 @@ class PermissionOut(BaseModel): # --- Task --- class TaskCreate(BaseModel): + """创建任务请求体。""" title: str content: str | None = None assignee_id: uuid.UUID @@ -136,6 +166,7 @@ class TaskCreate(BaseModel): class TaskUpdate(BaseModel): + """更新任务请求体。""" title: str | None = None content: str | None = None status: str | None = None @@ -144,6 +175,7 @@ class TaskUpdate(BaseModel): class TaskOut(BaseModel): + """任务响应体。""" id: uuid.UUID title: str content: str | None = None @@ -161,6 +193,7 @@ class TaskOut(BaseModel): # --- Employee Analysis --- class EmployeeAnalysis(BaseModel): + """员工分析报告响应体。""" employee_name: str department: str period: str @@ -177,11 +210,13 @@ class EmployeeAnalysis(BaseModel): # --- Flow --- class TriggerNodeConfig(BaseModel): + """触发节点配置。""" event_type: str = "text_message" channels: list[str] = ["wecom"] callback_url: str = "" class LLMNodeConfig(BaseModel): + """LLM 调用节点配置。""" system_prompt: str = "" model: str = "gpt-4o-mini" temperature: float = 0.7 @@ -193,6 +228,7 @@ class LLMNodeConfig(BaseModel): tool_call: bool = False class ToolNodeConfig(BaseModel): + """内置工具节点配置。""" tool_name: str = "" tool_type: str = "" tool_params: dict = {} @@ -201,6 +237,7 @@ class ToolNodeConfig(BaseModel): error_handling: str = "throw" class MCPNodeConfig(BaseModel): + """MCP 服务节点配置。""" mcp_server: str = "" tool_name: str = "" input_params: dict = {} @@ -209,6 +246,7 @@ class MCPNodeConfig(BaseModel): error_handling: str = "throw" class NotifyNodeConfig(BaseModel): + """通知节点配置。""" channels: dict = {"wecom": True, "web": False} message_template: str = "" web_template: str = "" @@ -218,6 +256,7 @@ class NotifyNodeConfig(BaseModel): error_handling: str = "throw" class ConditionNodeConfig(BaseModel): + """条件分支节点配置。""" condition: str = "" condition_type: str = "expression" true_label: str = "是" @@ -225,6 +264,7 @@ class ConditionNodeConfig(BaseModel): default_branch: str = "false" class RAGNodeConfig(BaseModel): + """RAG 检索节点配置。""" knowledge_base: str = "" top_k: int = 5 search_mode: str = "hybrid" @@ -233,6 +273,7 @@ class RAGNodeConfig(BaseModel): include_metadata: bool = True class OutputNodeConfig(BaseModel): + """输出节点配置。""" format: str = "text" output_template: str = "" indent: int = 2 @@ -241,18 +282,21 @@ class OutputNodeConfig(BaseModel): max_length: int = 2000 class LoopNodeConfig(BaseModel): + """循环节点配置。""" loop_type: str = "fixed" max_iterations: int = 10 count: int = 3 iterator_variable: str = "item" class CodeNodeConfig(BaseModel): + """代码执行节点配置。""" language: str = "python" code: str = "" timeout: int = 30 sandbox: bool = True class FlowNode(BaseModel): + """流程图中单个节点的定义。""" id: str | None = None type: str label: str | None = None @@ -260,6 +304,7 @@ class FlowNode(BaseModel): class FlowEdge(BaseModel): + """流程图中连接边的定义。""" source: str | None = None target: str | None = None from_field: str | None = Field(None, alias="from") @@ -270,6 +315,7 @@ class FlowEdge(BaseModel): class FlowDefinitionCreate(BaseModel): + """创建流程定义请求体。""" name: str description: str | None = None trigger: dict = {} @@ -279,6 +325,7 @@ class FlowDefinitionCreate(BaseModel): class FlowDefinitionUpdate(BaseModel): + """更新流程定义请求体。""" name: str | None = None description: str | None = None nodes: list[FlowNode] | None = None @@ -287,6 +334,7 @@ class FlowDefinitionUpdate(BaseModel): class FlowDefinitionOut(BaseModel): + """流程定义响应体。""" id: uuid.UUID name: str description: str | None = None @@ -305,6 +353,7 @@ class FlowDefinitionOut(BaseModel): class FlowVersionOut(BaseModel): + """流程版本响应体。""" id: uuid.UUID flow_id: uuid.UUID version: int @@ -320,10 +369,12 @@ class FlowVersionOut(BaseModel): class FlowApiKeyCreate(BaseModel): + """创建流程 API 密钥请求体。""" name: str class FlowApiKeyOut(BaseModel): + """流程 API 密钥响应体。""" id: uuid.UUID flow_id: uuid.UUID name: str @@ -336,12 +387,14 @@ class FlowApiKeyOut(BaseModel): class FlowExecuteRequest(BaseModel): + """流程执行请求体。""" input_text: str = "" session_id: str | None = None user_id: str | None = None class FlowChatMessageRequest(BaseModel): + """流程聊天消息请求体。""" query: str inputs: dict = {} response_mode: str = "blocking" @@ -350,11 +403,13 @@ class FlowChatMessageRequest(BaseModel): class OpenAPIImportRequest(BaseModel): + """OpenAPI 导入请求体。""" openapi_url: str base_url_override: str | None = None class CustomToolCreate(BaseModel): + """创建自定义工具请求体。""" model_config = ConfigDict(protected_namespaces=()) name: str description: str | None = None @@ -369,6 +424,7 @@ class CustomToolCreate(BaseModel): class CustomToolUpdate(BaseModel): + """更新自定义工具请求体。""" model_config = ConfigDict(protected_namespaces=()) name: str | None = None description: str | None = None @@ -383,6 +439,7 @@ class CustomToolUpdate(BaseModel): class CustomToolOut(BaseModel): + """自定义工具响应体(ORM 映射)。""" model_config = ConfigDict(from_attributes=True, populate_by_name=True, protected_namespaces=()) id: uuid.UUID name: str @@ -398,6 +455,7 @@ class CustomToolOut(BaseModel): # --- MCP --- class MCPServiceCreate(BaseModel): + """创建 MCP 服务请求体。""" name: str transport: str = "http" url: str | None = None @@ -407,6 +465,7 @@ class MCPServiceCreate(BaseModel): class MCPServiceUpdate(BaseModel): + """更新 MCP 服务请求体。""" transport: str | None = None url: str | None = None command: str | None = None @@ -415,6 +474,7 @@ class MCPServiceUpdate(BaseModel): class MCPServiceOut(BaseModel): + """MCP 服务响应体。""" id: uuid.UUID name: str transport: str @@ -431,6 +491,7 @@ class MCPServiceOut(BaseModel): # --- Agent Config --- class AgentConfigCreate(BaseModel): + """创建 Agent 配置请求体。""" name: str description: str | None = None system_prompt: str = "" @@ -440,6 +501,7 @@ class AgentConfigCreate(BaseModel): class AgentConfigUpdate(BaseModel): + """更新 Agent 配置请求体。""" name: str | None = None description: str | None = None system_prompt: str | None = None @@ -450,6 +512,7 @@ class AgentConfigUpdate(BaseModel): class AgentConfigOut(BaseModel): + """Agent 配置响应体。""" id: uuid.UUID name: str description: str | None = None @@ -468,6 +531,7 @@ class AgentConfigOut(BaseModel): # --- Notification --- class NotificationTemplateCreate(BaseModel): + """创建通知模板请求体。""" name: str code: str channel: str = "wecom" @@ -477,6 +541,7 @@ class NotificationTemplateCreate(BaseModel): class NotificationTemplateOut(BaseModel): + """通知模板响应体。""" id: uuid.UUID name: str code: str @@ -492,6 +557,7 @@ class NotificationTemplateOut(BaseModel): # --- Document --- class DocumentUploadOut(BaseModel): + """文档上传结果响应体。""" file_id: uuid.UUID filename: str file_size: int @@ -500,6 +566,7 @@ class DocumentUploadOut(BaseModel): class DocumentParseResult(BaseModel): + """文档解析结果响应体。""" file_id: uuid.UUID filename: str content: str @@ -508,6 +575,7 @@ class DocumentParseResult(BaseModel): # --- Audit --- class AuditQueryParams(BaseModel): + """审计日志查询参数。""" page: int = 1 page_size: int = 20 action: str | None = None @@ -518,6 +586,7 @@ class AuditQueryParams(BaseModel): class AuditLogOut(BaseModel): + """审计日志条目响应体。""" id: uuid.UUID operator_id: uuid.UUID | None = None action: str @@ -532,6 +601,7 @@ class AuditLogOut(BaseModel): class AuditLogPage(BaseModel): + """审计日志分页响应体。""" items: list[AuditLogOut] total: int page: int @@ -540,6 +610,7 @@ class AuditLogPage(BaseModel): # --- System Metrics --- class SystemMetricOut(BaseModel): + """系统监控指标响应体。""" id: uuid.UUID metric_type: str value: dict @@ -550,6 +621,7 @@ class SystemMetricOut(BaseModel): class SystemHealthOut(BaseModel): + """系统健康状态响应体。""" status: str service: str uptime_seconds: float @@ -561,6 +633,7 @@ class SystemHealthOut(BaseModel): class UsageStatsOut(BaseModel): + """使用统计响应体。""" total_users: int active_users_today: int total_sessions: int @@ -574,6 +647,7 @@ class UsageStatsOut(BaseModel): # --- Generic Response --- class ApiResponse(BaseModel): + """通用 API 响应包装。""" code: int = 200 message: str = "success" data: dict | list | None = None \ No newline at end of file diff --git a/backend/websocket_manager.py b/backend/websocket_manager.py index 40eb776..ef70ebb 100644 --- a/backend/websocket_manager.py +++ b/backend/websocket_manager.py @@ -3,14 +3,22 @@ from typing import Dict, Set import json import logging -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 当前模块的日志记录器 class WebSocketManager: + """WebSocket 连接管理器。 + + 管理用户与多个 WebSocket 连接的生命周期, + 支持单用户多连接、定向推送和全局广播功能。 + """ + def __init__(self): - self.active_connections: Dict[str, Set[WebSocket]] = {} + """初始化 WebSocket 连接池。""" + self.active_connections: Dict[str, Set[WebSocket]] = {} # 活跃连接池:{user_id: {WebSocket, ...}} async def connect(self, websocket: WebSocket, user_id: str): + """接受 WebSocket 连接并注册到指定用户的连接池中。""" await websocket.accept() if user_id not in self.active_connections: self.active_connections[user_id] = set() @@ -18,6 +26,7 @@ class WebSocketManager: logger.info(f"WebSocket 用户 {user_id} 已连接") def disconnect(self, websocket: WebSocket, user_id: str): + """断开 WebSocket 连接并从用户连接池中移除。""" if user_id in self.active_connections: self.active_connections[user_id].discard(websocket) if not self.active_connections[user_id]: @@ -25,6 +34,7 @@ class WebSocketManager: logger.info(f"WebSocket 用户 {user_id} 已断开") async def send_to_user(self, user_id: str, message: dict): + """向指定用户的所有活跃连接发送消息,自动清理断开连接。""" if user_id not in self.active_connections: return False dead_connections = set() @@ -42,8 +52,9 @@ class WebSocketManager: return sent_count > 0 async def broadcast(self, message: dict): + """向当前所有活跃用户广播消息,遍历所有用户并调用 send_to_user。""" for user_id in list(self.active_connections.keys()): await self.send_to_user(user_id, message) -ws_manager = WebSocketManager() \ No newline at end of file +ws_manager = WebSocketManager() # 全局 WebSocket 管理器单例 \ No newline at end of file