Browse Source

添加注释

master
MSI-7950X\刘泽明 1 day ago
parent
commit
6bf296eb0a
  1. 1
      backend/__init__.py
  2. 1
      backend/agentscope_integration/__init__.py
  3. 139
      backend/agentscope_integration/factory.py
  4. 34
      backend/agentscope_integration/hooks/rbac_hook.py
  5. 52
      backend/agentscope_integration/memory/user_memory.py
  6. 52
      backend/agentscope_integration/tools/document_tools.py
  7. 59
      backend/agentscope_integration/tools/manager_tools.py
  8. 82
      backend/agentscope_integration/tools/task_tools.py
  9. 60
      backend/agentscope_integration/tools/wecom_tools.py
  10. 38
      backend/config.py
  11. 6
      backend/database.py
  12. 16
      backend/dependencies.py
  13. 49
      backend/main.py
  14. 31
      backend/middleware/apikey_auth.py
  15. 73
      backend/middleware/cache_manager.py
  16. 80
      backend/middleware/rate_limiter.py
  17. 35
      backend/middleware/rbac_middleware.py
  18. 446
      backend/models/__init__.py
  19. 1
      backend/modules/__init__.py
  20. 1
      backend/modules/agent_manager/__init__.py
  21. 85
      backend/modules/agent_manager/router.py
  22. 1
      backend/modules/audit/__init__.py
  23. 62
      backend/modules/audit/router.py
  24. 4
      backend/modules/auth/__init__.py
  25. 123
      backend/modules/auth/router.py
  26. 1
      backend/modules/chat/__init__.py
  27. 52
      backend/modules/chat/router.py
  28. 5
      backend/modules/custom_tool/__init__.py
  29. 73
      backend/modules/custom_tool/executor.py
  30. 81
      backend/modules/custom_tool/parser.py
  31. 114
      backend/modules/custom_tool/router.py
  32. 83
      backend/modules/document/router.py
  33. 19
      backend/modules/flow_engine/engine.py
  34. 8
      backend/modules/flow_engine/gateway.py
  35. 6
      backend/modules/flow_engine/router.py
  36. 109
      backend/modules/mcp_registry/router.py
  37. 135
      backend/modules/memory/manager.py
  38. 5
      backend/modules/memory/router.py
  39. 5
      backend/modules/memory/schemas.py
  40. 118
      backend/modules/model_provider/router.py
  41. 5
      backend/modules/monitor/__init__.py
  42. 87
      backend/modules/monitor/router.py
  43. 1
      backend/modules/notification/__init__.py
  44. 143
      backend/modules/notification/router.py
  45. 5
      backend/modules/org/__init__.py
  46. 186
      backend/modules/org/router.py
  47. 74
      backend/modules/rag/knowledge.py
  48. 94
      backend/modules/rag/router.py
  49. 1
      backend/modules/rbac/__init__.py
  50. 90
      backend/modules/system/router.py
  51. 1
      backend/modules/task/__init__.py
  52. 64
      backend/modules/wecom/router.py
  53. 74
      backend/schemas/__init__.py
  54. 17
      backend/websocket_manager.py

1
backend/__init__.py

@ -0,0 +1 @@
"""后端应用包。"""

1
backend/agentscope_integration/__init__.py

@ -0,0 +1 @@
"""AgentScope 集成模块初始化。"""

139
backend/agentscope_integration/factory.py

@ -1,3 +1,8 @@
"""AgentScope 智能体工厂模块。
提供统一的智能体创建接口根据用户类型员工/管理者/任务/文档创建对应的 AI 智能体实例
支持智能体缓存以减少重复创建的开销
"""
from agentscope.agent import AgentBase from agentscope.agent import AgentBase
from agentscope.agent._react_agent import ReActAgent from agentscope.agent._react_agent import ReActAgent
from agentscope.model import OpenAIChatModel from agentscope.model import OpenAIChatModel
@ -10,13 +15,23 @@ from .hooks.rbac_hook import register_rbac_hooks_for_user
class AgentFactory: class AgentFactory:
_model: OpenAIChatModel | None = None """智能体工厂类,负责创建和管理不同类型的 AI 智能体实例。
_formatter: OpenAIChatFormatter | None = None
_agent_cache: dict[str, AgentBase] = {} 采用类级别的单例模式缓存模型和格式化器实例
_MAX_CACHE_SIZE = 50 同时为每个用户缓存已创建的智能体避免重复初始化
"""
_model: OpenAIChatModel | None = None # 缓存的大语言模型实例
_formatter: OpenAIChatFormatter | None = None # 缓存的消息格式化器实例
_agent_cache: dict[str, AgentBase] = {} # 智能体缓存:{agent_type_user_id: AgentBase}
_MAX_CACHE_SIZE = 50 # 智能体缓存上限
@classmethod @classmethod
def _get_model(cls) -> OpenAIChatModel: def _get_model(cls) -> OpenAIChatModel:
"""获取或创建全局共享的大语言模型实例。
Returns:
OpenAIChatModel: 配置好的大语言模型实例
"""
if cls._model is None: if cls._model is None:
cls._model = OpenAIChatModel( cls._model = OpenAIChatModel(
config_name="enterprise_model", config_name="enterprise_model",
@ -28,6 +43,11 @@ class AgentFactory:
@classmethod @classmethod
def _get_formatter(cls) -> OpenAIChatFormatter: def _get_formatter(cls) -> OpenAIChatFormatter:
"""获取或创建全局共享的消息格式化器实例。
Returns:
OpenAIChatFormatter: OpenAI 聊天格式化器实例
"""
if cls._formatter is None: if cls._formatter is None:
cls._formatter = OpenAIChatFormatter() cls._formatter = OpenAIChatFormatter()
return cls._formatter return cls._formatter
@ -40,7 +60,21 @@ class AgentFactory:
user_name: str, user_name: str,
department_id: str | None = None, department_id: str | None = None,
) -> AgentBase: ) -> AgentBase:
cache_key = f"{agent_type}_{user_id}" """根据智能体类型和用户信息创建对应的 AI 智能体。
优先从缓存中获取已存在的智能体实例如果缓存中不存在则创建新实例
缓存满时会自动淘汰最旧的智能体实例
Args:
agent_type: 智能体类型支持 employee/manager/task/document
user_id: 用户唯一标识
user_name: 用户显示名称
department_id: 所属部门 ID可选
Returns:
AgentBase: 创建或缓存的 AI 智能体实例
"""
cache_key = f"{agent_type}_{user_id}" # 缓存键:智能体类型_用户ID
if cache_key in cls._agent_cache: if cache_key in cls._agent_cache:
return cls._agent_cache[cache_key] return cls._agent_cache[cache_key]
@ -66,18 +100,33 @@ class AgentFactory:
@classmethod @classmethod
async def _create_employee_agent(cls, user_id, user_name, department_id, model, formatter): async def _create_employee_agent(cls, user_id, user_name, department_id, model, formatter):
"""创建员工专属 AI 助手智能体。
该智能体具备文档处理通知发送知识库查询等功能
数据权限范围限定为仅能访问员工自己的数据
Args:
user_id: 用户唯一标识
user_name: 用户显示名称
department_id: 所属部门 ID
model: 大语言模型实例
formatter: 消息格式化器实例
Returns:
ReActAgent: 配置好的员工 AI 智能体
"""
from .tools.wecom_tools import send_notification from .tools.wecom_tools import send_notification
from .tools.document_tools import parse_document, format_correction from .tools.document_tools import parse_document, format_correction
toolkit = Toolkit() toolkit = Toolkit()
toolkit.register_tool_function(send_notification) toolkit.register_tool_function(send_notification) # 注册企业微信通知工具
toolkit.register_tool_function(parse_document) toolkit.register_tool_function(parse_document) # 注册文档解析工具
toolkit.register_tool_function(format_correction) toolkit.register_tool_function(format_correction) # 注册格式修正工具
knowledge = None knowledge = None
try: try:
from modules.rag.knowledge import get_knowledge_base from modules.rag.knowledge import get_knowledge_base
knowledge = get_knowledge_base() knowledge = get_knowledge_base() # 尝试获取知识库
except Exception: except Exception:
pass pass
@ -108,22 +157,36 @@ class AgentFactory:
"user_name": user_name, "user_name": user_name,
"role": "employee", "role": "employee",
"department_id": department_id or "", "department_id": department_id or "",
"data_scope": "self_only", "data_scope": "self_only", # 数据权限:仅限本人
}) })
return agent return agent
@classmethod @classmethod
async def _create_manager_agent(cls, user_id, user_name, model, formatter): async def _create_manager_agent(cls, user_id, user_name, model, formatter):
"""创建管理者专属 AI 分析助手智能体。
该智能体具备下属管理团队效率分析任务统计等管理功能
数据权限范围限定为仅能访问其下属员工的数据
Args:
user_id: 用户唯一标识
user_name: 用户显示名称
model: 大语言模型实例
formatter: 消息格式化器实例
Returns:
ReActAgent: 配置好的管理者 AI 智能体
"""
from .tools.manager_tools import list_subordinates, get_employee_dashboard, generate_efficiency_report, get_task_statistics from .tools.manager_tools import list_subordinates, get_employee_dashboard, generate_efficiency_report, get_task_statistics
from .tools.wecom_tools import send_notification from .tools.wecom_tools import send_notification
toolkit = Toolkit() toolkit = Toolkit()
toolkit.register_tool_function(list_subordinates) toolkit.register_tool_function(list_subordinates) # 注册下属列表查询工具
toolkit.register_tool_function(get_employee_dashboard) toolkit.register_tool_function(get_employee_dashboard) # 注册员工看板查询工具
toolkit.register_tool_function(generate_efficiency_report) toolkit.register_tool_function(generate_efficiency_report) # 注册效率报告生成工具
toolkit.register_tool_function(get_task_statistics) toolkit.register_tool_function(get_task_statistics) # 注册任务统计查询工具
toolkit.register_tool_function(send_notification) toolkit.register_tool_function(send_notification) # 注册企业微信通知工具
agent = ReActAgent( agent = ReActAgent(
name=f"ManagerAI_{user_name}", name=f"ManagerAI_{user_name}",
@ -150,22 +213,36 @@ class AgentFactory:
"user_id": user_id, "user_id": user_id,
"user_name": user_name, "user_name": user_name,
"role": "dept_manager", "role": "dept_manager",
"data_scope": "subordinate_only", "data_scope": "subordinate_only", # 数据权限:仅限下属
}) })
return agent return agent
@classmethod @classmethod
async def _create_task_agent(cls, user_id, user_name, model, formatter): async def _create_task_agent(cls, user_id, user_name, model, formatter):
"""创建任务管理专属 AI 助手智能体。
该智能体专注于任务的创建查询更新和通知推送
帮助用户高效管理日常工作事务
Args:
user_id: 用户唯一标识
user_name: 用户显示名称
model: 大语言模型实例
formatter: 消息格式化器实例
Returns:
ReActAgent: 配置好的任务管理 AI 智能体
"""
from .tools.task_tools import list_tasks, create_task, get_task, update_task from .tools.task_tools import list_tasks, create_task, get_task, update_task
from .tools.wecom_tools import send_notification from .tools.wecom_tools import send_notification
toolkit = Toolkit() toolkit = Toolkit()
toolkit.register_tool_function(list_tasks) toolkit.register_tool_function(list_tasks) # 注册任务列表查询工具
toolkit.register_tool_function(create_task) toolkit.register_tool_function(create_task) # 注册任务创建工具
toolkit.register_tool_function(get_task) toolkit.register_tool_function(get_task) # 注册任务详情查询工具
toolkit.register_tool_function(update_task) toolkit.register_tool_function(update_task) # 注册任务更新工具
toolkit.register_tool_function(send_notification) toolkit.register_tool_function(send_notification) # 注册企业微信通知工具
agent = ReActAgent( agent = ReActAgent(
name=f"TaskAI_{user_name}", name=f"TaskAI_{user_name}",
@ -192,16 +269,30 @@ class AgentFactory:
@classmethod @classmethod
async def _create_document_agent(cls, user_id, user_name, model, formatter): async def _create_document_agent(cls, user_id, user_name, model, formatter):
"""创建文档处理专属 AI 助手智能体。
该智能体专注于各类办公文档的解析格式修正和内容提取
支持 PDFWordExcel 等常见格式
Args:
user_id: 用户唯一标识
user_name: 用户显示名称
model: 大语言模型实例
formatter: 消息格式化器实例
Returns:
ReActAgent: 配置好的文档处理 AI 智能体
"""
from .tools.document_tools import parse_document, format_correction from .tools.document_tools import parse_document, format_correction
toolkit = Toolkit() toolkit = Toolkit()
toolkit.register_tool_function(parse_document) toolkit.register_tool_function(parse_document) # 注册文档解析工具
toolkit.register_tool_function(format_correction) toolkit.register_tool_function(format_correction) # 注册格式修正工具
knowledge = None knowledge = None
try: try:
from modules.rag.knowledge import get_knowledge_base from modules.rag.knowledge import get_knowledge_base
knowledge = get_knowledge_base() knowledge = get_knowledge_base() # 尝试获取知识库
except Exception: except Exception:
pass pass

34
backend/agentscope_integration/hooks/rbac_hook.py

@ -1,9 +1,34 @@
"""RBAC 权限钩子模块。
提供 AgentScope 智能体的 RBAC基于角色的访问控制权限钩子
在智能体回复前自动注入用户上下文信息用户ID角色部门数据权限范围到消息元数据中
"""
from agentscope.agent import AgentBase from agentscope.agent import AgentBase
from agentscope.message import Msg from agentscope.message import Msg
def create_rbac_pre_reply_hook(user_context: dict): def create_rbac_pre_reply_hook(user_context: dict):
"""创建 RBAC 预回复钩子函数。
该钩子会在智能体每次回复前执行将用户的身份信息注入到消息元数据中
以便后续的工具调用和权限校验能够获取正确的用户上下文
Args:
user_context: 用户上下文信息字典包含 user_idroledepartment_iddata_scope 等字段
Returns:
callable: 异步钩子函数用于注册到智能体的 pre_reply 钩子点
"""
async def rbac_pre_reply_hook(self: AgentBase, kwargs: dict) -> dict: async def rbac_pre_reply_hook(self: AgentBase, kwargs: dict) -> dict:
"""RBAC 预回复钩子内部实现。
Args:
self: 智能体实例
kwargs: 传递给智能体 reply 方法的参数字典
Returns:
dict: 修改后的参数字典消息元数据中已注入用户上下文信息
"""
msg = kwargs.get("msg") msg = kwargs.get("msg")
if msg and isinstance(msg, Msg): if msg and isinstance(msg, Msg):
msg.metadata = msg.metadata or {} msg.metadata = msg.metadata or {}
@ -18,6 +43,15 @@ def create_rbac_pre_reply_hook(user_context: dict):
def register_rbac_hooks_for_user(agent: AgentBase, user_context: dict): def register_rbac_hooks_for_user(agent: AgentBase, user_context: dict):
"""为指定智能体注册 RBAC 权限钩子。
将用户上下文信息绑定到智能体的 pre_reply 钩子点
确保智能体在处理每条消息时都能携带正确的用户身份信息
Args:
agent: 目标智能体实例
user_context: 用户上下文信息字典
"""
hook = create_rbac_pre_reply_hook(user_context) hook = create_rbac_pre_reply_hook(user_context)
hook_name = f"rbac_{user_context['user_id']}" hook_name = f"rbac_{user_context['user_id']}"
agent.register_instance_hook("pre_reply", hook_name, hook) agent.register_instance_hook("pre_reply", hook_name, hook)

52
backend/agentscope_integration/memory/user_memory.py

@ -1,13 +1,39 @@
"""用户隔离记忆模块。
提供基于用户 ID 隔离的记忆存储机制确保每个用户只能访问自己的对话历史
通过包装 AgentScope MemoryBase 实现用户级别的记忆隔离
"""
from agentscope.memory import MemoryBase, InMemoryMemory from agentscope.memory import MemoryBase, InMemoryMemory
from agentscope.message import Msg from agentscope.message import Msg
class UserIsolatedMemory(MemoryBase): class UserIsolatedMemory(MemoryBase):
"""用户隔离记忆类,确保每个用户只能访问自己对话历史的记忆管理器。
通过在消息元数据中标记用户 ID在获取记忆时过滤出当前用户的消息
实现多用户环境下的对话历史隔离
Attributes:
user_id: 当前记忆实例绑定的用户唯一标识
_backend: 底层记忆存储实例默认为 InMemoryMemory
"""
def __init__(self, user_id: str, backend_memory: MemoryBase | None = None): def __init__(self, user_id: str, backend_memory: MemoryBase | None = None):
"""初始化用户隔离记忆实例。
Args:
user_id: 用户唯一标识
backend_memory: 可选的底层记忆存储实例不提供则使用内存存储
"""
self.user_id = user_id self.user_id = user_id
self._backend = backend_memory or InMemoryMemory() self._backend = backend_memory or InMemoryMemory()
async def add(self, msg: Msg | list[Msg] | None) -> None: async def add(self, msg: Msg | list[Msg] | None) -> None:
"""添加消息到记忆中,自动标记当前用户 ID。
Args:
msg: 要添加的消息可以是单条消息消息列表或 None
"""
if msg is None: if msg is None:
return return
msgs = msg if isinstance(msg, list) else [msg] msgs = msg if isinstance(msg, list) else [msg]
@ -17,14 +43,40 @@ class UserIsolatedMemory(MemoryBase):
await self._backend.add(msg) await self._backend.add(msg)
async def get_memory(self, **kwargs) -> list[Msg]: async def get_memory(self, **kwargs) -> list[Msg]:
"""获取当前用户的记忆历史。
从底层存储中获取所有消息后过滤出属于当前用户的消息
Args:
**kwargs: 传递给底层存储的额外参数
Returns:
list[Msg]: 属于当前用户的消息列表
"""
all_msgs = await self._backend.get_memory(**kwargs) all_msgs = await self._backend.get_memory(**kwargs)
return [m for m in all_msgs if m.metadata.get("_user_id") == self.user_id] return [m for m in all_msgs if m.metadata.get("_user_id") == self.user_id]
async def delete_by_mark(self, mark: str) -> None: async def delete_by_mark(self, mark: str) -> None:
"""根据标记删除消息。
Args:
mark: 要删除的消息标记
"""
await self._backend.delete_by_mark(mark) await self._backend.delete_by_mark(mark)
async def update_messages_mark(self, msg_ids: list[str], new_mark: str) -> None: async def update_messages_mark(self, msg_ids: list[str], new_mark: str) -> None:
"""更新消息的标记。
Args:
msg_ids: 要更新标记的消息 ID 列表
new_mark: 新的标记字符串
"""
await self._backend.update_messages_mark(msg_ids, new_mark) await self._backend.update_messages_mark(msg_ids, new_mark)
async def update_compressed_summary(self, summary: str) -> None: async def update_compressed_summary(self, summary: str) -> None:
"""更新压缩后的记忆摘要。
Args:
summary: 新的记忆摘要字符串
"""
await self._backend.update_compressed_summary(summary) await self._backend.update_compressed_summary(summary)

52
backend/agentscope_integration/tools/document_tools.py

@ -1,12 +1,22 @@
"""文档处理工具模块。
提供多种办公文档格式的解析和格式修正功能支持 PDFWordExcel 等格式
采用延迟导入策略仅在需要时才尝试加载相应的依赖库
"""
import os import os
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
_IMPORT_ERRORS: dict[str, str] = {} _IMPORT_ERRORS: dict[str, str] = {} # 记录各库的导入错误信息,避免重复尝试
def _try_import_pdf() -> bool: def _try_import_pdf() -> bool:
"""尝试导入 PDF 解析库 PyPDF2。
Returns:
bool: 导入成功返回 True失败返回 False
"""
global _IMPORT_ERRORS global _IMPORT_ERRORS
if "pdf" in _IMPORT_ERRORS: if "pdf" in _IMPORT_ERRORS:
return False return False
@ -19,6 +29,11 @@ def _try_import_pdf() -> bool:
def _try_import_docx() -> bool: def _try_import_docx() -> bool:
"""尝试导入 Word 文档解析库 python-docx。
Returns:
bool: 导入成功返回 True失败返回 False
"""
global _IMPORT_ERRORS global _IMPORT_ERRORS
if "docx" in _IMPORT_ERRORS: if "docx" in _IMPORT_ERRORS:
return False return False
@ -31,6 +46,11 @@ def _try_import_docx() -> bool:
def _try_import_excel() -> bool: def _try_import_excel() -> bool:
"""尝试导入 Excel 表格解析库 openpyxl。
Returns:
bool: 导入成功返回 True失败返回 False
"""
global _IMPORT_ERRORS global _IMPORT_ERRORS
if "excel" in _IMPORT_ERRORS: if "excel" in _IMPORT_ERRORS:
return False return False
@ -43,8 +63,20 @@ def _try_import_excel() -> bool:
def parse_document(file_path: str, file_type: str = "auto") -> str: def parse_document(file_path: str, file_type: str = "auto") -> str:
ext = os.path.splitext(file_path)[1].lower() """解析各类办公文档,提取文本内容。
自动根据文件扩展名识别文档类型支持 PDFWordExcelPPT 和纯文本
Args:
file_path: 文档文件的完整路径
file_type: 文档类型auto 表示自动识别
Returns:
str: 提取的文档文本内容或错误信息
"""
ext = os.path.splitext(file_path)[1].lower() # 获取文件扩展名
# 根据扩展名自动识别文件类型
if file_type == "auto": if file_type == "auto":
if ext in (".pdf",): if ext in (".pdf",):
file_type = "pdf" file_type = "pdf"
@ -101,7 +133,7 @@ def parse_document(file_path: str, file_type: str = "auto") -> str:
import openpyxl import openpyxl
try: try:
wb = openpyxl.load_workbook(file_path, data_only=True) wb = openpyxl.load_workbook(file_path, data_only=True) # data_only 获取计算后的值
result_parts = [] result_parts = []
for sheet_name in wb.sheetnames: for sheet_name in wb.sheetnames:
ws = wb[sheet_name] ws = wb[sheet_name]
@ -118,6 +150,7 @@ def parse_document(file_path: str, file_type: str = "auto") -> str:
if file_type in ("ppt", "pptx"): if file_type in ("ppt", "pptx"):
return "PPT 解析暂不支持,请将内容复制到 Word 或 PDF 后重试。" return "PPT 解析暂不支持,请将内容复制到 Word 或 PDF 后重试。"
# 尝试以纯文本方式读取文件
try: try:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
return f.read() return f.read()
@ -135,6 +168,17 @@ def parse_document(file_path: str, file_type: str = "auto") -> str:
def format_correction(content: str, format_rules: str = "standard") -> str: def format_correction(content: str, format_rules: str = "standard") -> str:
"""对文档内容进行格式修正。
根据指定的格式规则对文本进行标准化处理支持标准和企业公文两种模式
Args:
content: 待修正的原始文本内容
format_rules: 格式规则standard 为标准模式enterprise 为企业公文模式
Returns:
str: 格式修正后的文本内容
"""
parts = [] parts = []
parts.append(f"[格式规则: {format_rules}]\n") parts.append(f"[格式规则: {format_rules}]\n")

59
backend/agentscope_integration/tools/manager_tools.py

@ -1,3 +1,8 @@
"""管理者工具模块。
提供管理者专属的工具函数包括下属员工查询员工看板数据获取团队效率报告生成和任务统计等功能
通过内部 HTTP API 与后端组织监控服务通信使用 JWT 服务令牌进行认证
"""
import httpx import httpx
import logging import logging
import os import os
@ -5,13 +10,18 @@ import jwt
import time import time
from config import settings from config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
_INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") _INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") # 内部 API 基础地址
_client: httpx.Client | None = None _client: httpx.Client | None = None # 全局复用的 HTTP 客户端实例
def _get_client() -> httpx.Client: def _get_client() -> httpx.Client:
"""获取或创建全局复用的 HTTP 客户端实例。
Returns:
httpx.Client: 配置了超时时间的 HTTP 客户端
"""
global _client global _client
if _client is None: if _client is None:
_client = httpx.Client(timeout=30) _client = httpx.Client(timeout=30)
@ -19,6 +29,11 @@ def _get_client() -> httpx.Client:
def _get_service_token() -> str | None: def _get_service_token() -> str | None:
"""生成用于内部服务间调用的 JWT 令牌。
Returns:
str | None: 编码后的 JWT 令牌生成失败返回 None
"""
try: try:
payload = {"sub": "system_tool", "exp": int(time.time()) + 3600, "type": "service"} payload = {"sub": "system_tool", "exp": int(time.time()) + 3600, "type": "service"}
token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256") token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256")
@ -28,10 +43,19 @@ def _get_service_token() -> str | None:
def _headers(token: str | None = None) -> dict: def _headers(token: str | None = None) -> dict:
"""构建 HTTP 请求头,包含认证令牌。
Args:
token: 可选的自定义令牌不提供则自动生成服务令牌
Returns:
dict: 包含 Authorization 头的字典
"""
t = token or _get_service_token() t = token or _get_service_token()
return {"Authorization": f"Bearer {t}"} if t else {} return {"Authorization": f"Bearer {t}"} if t else {}
# 工具函数描述 Schema,用于 AgentScope 工具注册
SCHEMAS = { SCHEMAS = {
"list_subordinates": { "list_subordinates": {
"name": "list_subordinates", "name": "list_subordinates",
@ -67,6 +91,11 @@ SCHEMAS = {
def list_subordinates() -> str: def list_subordinates() -> str:
"""查询当前管理者名下的下属员工列表。
Returns:
str: 格式化的下属员工列表文本或错误信息
"""
try: try:
resp = _get_client().get(f"{_INTERNAL_BASE}/org/subordinates", headers=_headers()) resp = _get_client().get(f"{_INTERNAL_BASE}/org/subordinates", headers=_headers())
users = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) users = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", [])
@ -81,6 +110,14 @@ def list_subordinates() -> str:
def get_employee_dashboard(employee_id: str) -> str: def get_employee_dashboard(employee_id: str) -> str:
"""查询指定员工的工作看板数据,包括任务完成率、响应时间等指标。
Args:
employee_id: 员工唯一标识 ID
Returns:
str: 格式化的员工看板数据或错误信息
"""
try: try:
resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employee/{employee_id}/dashboard", headers=_headers()) resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employee/{employee_id}/dashboard", headers=_headers())
data = resp.json() data = resp.json()
@ -90,6 +127,14 @@ def get_employee_dashboard(employee_id: str) -> str:
def generate_efficiency_report(department_id: str | None = None) -> str: def generate_efficiency_report(department_id: str | None = None) -> str:
"""生成团队效率分析报告,包含各员工的任务数和完成率统计。
Args:
department_id: 可选的部门 ID用于限定报告范围
Returns:
str: 格式化的团队效率报告或错误信息
"""
try: try:
resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employees", headers=_headers()) resp = _get_client().get(f"{_INTERNAL_BASE}/monitor/employees", headers=_headers())
employees = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) employees = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", [])
@ -109,6 +154,14 @@ def generate_efficiency_report(department_id: str | None = None) -> str:
def get_task_statistics(employee_id: str | None = None) -> str: def get_task_statistics(employee_id: str | None = None) -> str:
"""查询任务统计数据,支持按员工筛选。
Args:
employee_id: 可选的员工 ID用于筛选特定员工的任务
Returns:
str: 格式化的任务统计信息或错误信息
"""
try: try:
resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers()) resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers())
tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", [])

82
backend/agentscope_integration/tools/task_tools.py

@ -1,3 +1,8 @@
"""任务管理工具模块。
提供任务相关操作的封装包括任务列表查询创建获取详情更新状态等功能
通过内部 HTTP API 与后端任务服务通信使用 JWT 服务令牌进行认证
"""
import httpx import httpx
import logging import logging
import os import os
@ -5,13 +10,18 @@ import jwt
import time import time
from config import settings from config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
_INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") _INTERNAL_BASE = os.getenv("INTERNAL_API_BASE", "http://127.0.0.1:8000/api") # 内部 API 基础地址
_client: httpx.Client | None = None _client: httpx.Client | None = None # 全局复用的 HTTP 客户端实例
def _get_client() -> httpx.Client: def _get_client() -> httpx.Client:
"""获取或创建全局复用的 HTTP 客户端实例。
Returns:
httpx.Client: 配置了超时时间的 HTTP 客户端
"""
global _client global _client
if _client is None: if _client is None:
_client = httpx.Client(timeout=30) _client = httpx.Client(timeout=30)
@ -19,11 +29,16 @@ def _get_client() -> httpx.Client:
def _get_service_token() -> str | None: def _get_service_token() -> str | None:
"""生成用于内部服务间调用的 JWT 令牌。
Returns:
str | None: 编码后的 JWT 令牌生成失败返回 None
"""
try: try:
payload = { payload = {
"sub": "system_tool", "sub": "system_tool", # 令牌主体标识为系统工具
"exp": int(time.time()) + 3600, "exp": int(time.time()) + 3600, # 1 小时后过期
"type": "service", "type": "service", # 令牌类型为服务令牌
} }
token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256") token = jwt.encode(payload, settings.JWT_SECRET, algorithm="HS256")
return token return token
@ -32,10 +47,19 @@ def _get_service_token() -> str | None:
def _headers(token: str | None = None) -> dict: def _headers(token: str | None = None) -> dict:
"""构建 HTTP 请求头,包含认证令牌。
Args:
token: 可选的自定义令牌不提供则自动生成服务令牌
Returns:
dict: 包含 Authorization 头的字典
"""
t = token or _get_service_token() t = token or _get_service_token()
return {"Authorization": f"Bearer {t}"} if t else {} return {"Authorization": f"Bearer {t}"} if t else {}
# 工具函数描述 Schema,用于 AgentScope 工具注册
SCHEMAS = { SCHEMAS = {
"list_tasks": { "list_tasks": {
"name": "list_tasks", "name": "list_tasks",
@ -97,6 +121,14 @@ SCHEMAS = {
def list_tasks(status: str | None = None) -> str: def list_tasks(status: str | None = None) -> str:
"""查询任务列表,支持按状态筛选。
Args:
status: 可选的任务状态筛选条件todo/in_progress/done
Returns:
str: 格式化的任务列表文本或错误信息
"""
try: try:
resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers()) resp = _get_client().get(f"{_INTERNAL_BASE}/tasks", headers=_headers())
tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", []) tasks = resp.json() if isinstance(resp.json(), list) else resp.json().get("data", [])
@ -118,6 +150,18 @@ def list_tasks(status: str | None = None) -> str:
def create_task(title: str, description: str = "", assignee_id: str = "", priority: str = "medium", deadline: str | None = None) -> str: def create_task(title: str, description: str = "", assignee_id: str = "", priority: str = "medium", deadline: str | None = None) -> str:
"""创建新任务。
Args:
title: 任务标题必填
description: 任务描述
assignee_id: 负责人用户 ID
priority: 任务优先级默认 medium
deadline: 截止日期
Returns:
str: 创建结果描述或错误信息
"""
try: try:
body = {"title": title, "description": description, "assignee_id": assignee_id, "priority": priority, "deadline": deadline} body = {"title": title, "description": description, "assignee_id": assignee_id, "priority": priority, "deadline": deadline}
resp = _get_client().post(f"{_INTERNAL_BASE}/tasks", json=body, headers=_headers()) resp = _get_client().post(f"{_INTERNAL_BASE}/tasks", json=body, headers=_headers())
@ -128,6 +172,14 @@ def create_task(title: str, description: str = "", assignee_id: str = "", priori
def get_task(task_id: str) -> str: def get_task(task_id: str) -> str:
"""查询指定任务的详细信息。
Args:
task_id: 任务唯一标识 ID
Returns:
str: 格式化的任务详情文本或错误信息
"""
try: try:
resp = _get_client().get(f"{_INTERNAL_BASE}/tasks/{task_id}", headers=_headers()) resp = _get_client().get(f"{_INTERNAL_BASE}/tasks/{task_id}", headers=_headers())
t = resp.json() t = resp.json()
@ -137,6 +189,16 @@ def get_task(task_id: str) -> str:
def update_task(task_id: str, status: str | None = None, description: str | None = None) -> str: def update_task(task_id: str, status: str | None = None, description: str | None = None) -> str:
"""更新任务状态或描述。
Args:
task_id: 任务唯一标识 ID
status: 新的任务状态todo/in_progress/done
description: 新的任务描述
Returns:
str: 更新结果描述或错误信息
"""
try: try:
body = {} body = {}
if status: if status:
@ -150,6 +212,14 @@ def update_task(task_id: str, status: str | None = None, description: str | None
def push_task_to_wecom(task_id: str) -> str: def push_task_to_wecom(task_id: str) -> str:
"""将任务通知推送到企业微信。
Args:
task_id: 任务唯一标识 ID
Returns:
str: 推送结果描述或错误信息
"""
try: try:
resp = _get_client().post(f"{_INTERNAL_BASE}/tasks/{task_id}/push", headers=_headers()) resp = _get_client().post(f"{_INTERNAL_BASE}/tasks/{task_id}/push", headers=_headers())
return f"任务 {task_id[:8]} 已推送至企业微信" return f"任务 {task_id[:8]} 已推送至企业微信"

60
backend/agentscope_integration/tools/wecom_tools.py

@ -1,12 +1,29 @@
"""企业微信工具模块。
提供企业微信 API 的封装支持发送消息查询用户信息群消息发送等功能
包含 access_token 的自动获取和缓存机制
"""
import httpx import httpx
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
_WECOM_ACCESS_TOKEN: dict = {"token": None, "expires_at": 0} _WECOM_ACCESS_TOKEN: dict = {"token": None, "expires_at": 0} # 企业微信 access_token 缓存
def _get_access_token(corp_id: str, app_secret: str) -> str | None: def _get_access_token(corp_id: str, app_secret: str) -> str | None:
"""获取或刷新企业微信 access_token。
优先使用缓存中的 token如果已过期或不存在则重新请求
Token 过期前 5 分钟会自动刷新
Args:
corp_id: 企业微信 CorpID
app_secret: 企业微信应用 Secret
Returns:
str | None: 有效的 access_token获取失败返回 None
"""
if not corp_id or not app_secret: if not corp_id or not app_secret:
logger.warning("WECOM_CORP_ID 或 WECOM_APP_SECRET 未配置,无法发送企微通知") logger.warning("WECOM_CORP_ID 或 WECOM_APP_SECRET 未配置,无法发送企微通知")
return None return None
@ -22,7 +39,7 @@ def _get_access_token(corp_id: str, app_secret: str) -> str | None:
data = resp.json() data = resp.json()
if data.get("errcode") == 0: if data.get("errcode") == 0:
_WECOM_ACCESS_TOKEN["token"] = data["access_token"] _WECOM_ACCESS_TOKEN["token"] = data["access_token"]
_WECOM_ACCESS_TOKEN["expires_at"] = now + data.get("expires_in", 7200) - 300 _WECOM_ACCESS_TOKEN["expires_at"] = now + data.get("expires_in", 7200) - 300 # 提前 5 分钟过期
return _WECOM_ACCESS_TOKEN["token"] return _WECOM_ACCESS_TOKEN["token"]
else: else:
logger.error(f"获取企微 token 失败: {data}") logger.error(f"获取企微 token 失败: {data}")
@ -33,11 +50,28 @@ def _get_access_token(corp_id: str, app_secret: str) -> str | None:
def _get_config(): def _get_config():
"""从全局配置中获取企业微信 CorpID 和 AppSecret。
Returns:
tuple: (corp_id, app_secret) 元组
"""
from config import settings from config import settings
return settings.WECOM_CORP_ID, settings.WECOM_APP_SECRET return settings.WECOM_CORP_ID, settings.WECOM_APP_SECRET
def send_notification(to_user: str, message: str, msg_type: str = "text") -> str: def send_notification(to_user: str, message: str, msg_type: str = "text") -> str:
"""向指定企业微信用户发送通知消息。
支持文本和文本卡片两种消息类型
Args:
to_user: 接收消息的企业微信用户 ID
message: 消息内容
msg_type: 消息类型支持 text/textcard
Returns:
str: 发送结果描述信息
"""
corp_id, app_secret = _get_config() corp_id, app_secret = _get_config()
token = _get_access_token(corp_id, app_secret) token = _get_access_token(corp_id, app_secret)
if not token: if not token:
@ -78,6 +112,14 @@ def send_notification(to_user: str, message: str, msg_type: str = "text") -> str
def query_wecom_user(user_id: str) -> str: def query_wecom_user(user_id: str) -> str:
"""查询企业微信用户的详细信息。
Args:
user_id: 企业微信用户 ID
Returns:
str: 用户信息描述或错误信息
"""
corp_id, app_secret = _get_config() corp_id, app_secret = _get_config()
token = _get_access_token(corp_id, app_secret) token = _get_access_token(corp_id, app_secret)
if not token: if not token:
@ -97,6 +139,18 @@ def query_wecom_user(user_id: str) -> str:
def send_wecom_group_message(message: str, group_id: str | None = None, msg_type: str = "text") -> str: def send_wecom_group_message(message: str, group_id: str | None = None, msg_type: str = "text") -> str:
"""向企业微信群发送消息。
支持文本和 Markdown 两种消息格式
Args:
message: 消息内容
group_id: 企业微信群聊 ID
msg_type: 消息类型支持 text/markdown
Returns:
str: 发送结果描述信息
"""
corp_id, app_secret = _get_config() corp_id, app_secret = _get_config()
token = _get_access_token(corp_id, app_secret) token = _get_access_token(corp_id, app_secret)
if not token: if not token:

38
backend/config.py

@ -3,27 +3,29 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings): class Settings(BaseSettings):
"""全局配置类,从环境变量加载所有应用配置项,支持通过 .env 文件覆盖。"""
DATABASE_URL: str = os.getenv( DATABASE_URL: str = os.getenv(
"DATABASE_URL", "DATABASE_URL",
"postgresql+asyncpg://enterprise:enterprise123@localhost:5432/enterprise_ai", "postgresql+asyncpg://enterprise:enterprise123@localhost:5432/enterprise_ai",
) ) # PostgreSQL 数据库连接 URL(asyncpg 异步驱动)
REDIS_URL: str = os.getenv("REDIS_URL", "redis://:redis123@localhost:6379/0") REDIS_URL: str = os.getenv("REDIS_URL", "redis://:redis123@localhost:6379/0") # Redis 连接 URL
JWT_SECRET: str = os.getenv("JWT_SECRET", "dev-secret-change-me") JWT_SECRET: str = os.getenv("JWT_SECRET", "dev-secret-change-me") # JWT 令牌签名密钥
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256" # JWT 签名算法
JWT_EXPIRE_MINUTES: int = 1440 JWT_EXPIRE_MINUTES: int = 1440 # JWT 令牌过期时间(分钟),默认 24 小时
LLM_API_KEY: str = os.getenv("LLM_API_KEY", "sk-placeholder") LLM_API_KEY: str = os.getenv("LLM_API_KEY", "sk-placeholder") # 大语言模型 API 密钥
LLM_API_BASE: str = os.getenv("LLM_API_BASE", "https://api.openai.com/v1") LLM_API_BASE: str = os.getenv("LLM_API_BASE", "https://api.openai.com/v1") # 大语言模型 API 基础地址
LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o-mini") LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o-mini") # 默认使用的大语言模型名称
RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) # 每分钟请求速率限制
RATE_LIMIT_BURST: int = int(os.getenv("RATE_LIMIT_BURST", "10")) RATE_LIMIT_BURST: int = int(os.getenv("RATE_LIMIT_BURST", "10")) # 速率限制突发上限
UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads") UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads") # 文件上传存储目录
MAX_UPLOAD_SIZE_MB: int = int(os.getenv("MAX_UPLOAD_SIZE_MB", "50")) MAX_UPLOAD_SIZE_MB: int = int(os.getenv("MAX_UPLOAD_SIZE_MB", "50")) # 最大上传文件大小(MB)
WECOM_CORP_ID: str = os.getenv("WECOM_CORP_ID", "") WECOM_CORP_ID: str = os.getenv("WECOM_CORP_ID", "") # 企业微信 CorpID
WECOM_APP_SECRET: str = os.getenv("WECOM_APP_SECRET", "") WECOM_APP_SECRET: str = os.getenv("WECOM_APP_SECRET", "") # 企业微信应用 Secret
WECOM_TOKEN: str = os.getenv("WECOM_TOKEN", "") WECOM_TOKEN: str = os.getenv("WECOM_TOKEN", "") # 企业微信 Token(用于回调验证)
WECOM_AES_KEY: str = os.getenv("WECOM_AES_KEY", "") WECOM_AES_KEY: str = os.getenv("WECOM_AES_KEY", "") # 企业微信 AES 密钥(用于回调消息解密)
METRICS_COLLECTION_INTERVAL: int = 60 METRICS_COLLECTION_INTERVAL: int = 60 # 系统指标采集间隔(秒)
settings = Settings() settings = Settings() # 全局配置单例实例

6
backend/database.py

@ -5,9 +5,11 @@ from config import settings
class Base(DeclarativeBase): class Base(DeclarativeBase):
"""SQLAlchemy ORM 基类,所有数据库模型均继承此类。"""
pass pass
# 异步数据库引擎,连接池大小 20,最大溢出 40,启用连接健康检查
async_engine = create_async_engine( async_engine = create_async_engine(
settings.DATABASE_URL, settings.DATABASE_URL,
pool_size=20, pool_size=20,
@ -17,6 +19,7 @@ async_engine = create_async_engine(
echo=False, echo=False,
) )
# 异步数据库会话工厂,用于创建数据库会话实例
AsyncSessionLocal = async_sessionmaker( AsyncSessionLocal = async_sessionmaker(
async_engine, async_engine,
class_=AsyncSession, class_=AsyncSession,
@ -25,6 +28,7 @@ AsyncSessionLocal = async_sessionmaker(
async def init_db(): async def init_db():
"""初始化数据库:创建所有表并执行增量迁移。"""
async with async_engine.begin() as conn: async with async_engine.begin() as conn:
from models import Base as MBase from models import Base as MBase
await conn.run_sync(MBase.metadata.create_all) await conn.run_sync(MBase.metadata.create_all)
@ -33,6 +37,7 @@ async def init_db():
async def _run_migrations(): async def _run_migrations():
"""执行数据库增量迁移,在已有表上安全添加新字段。"""
async with async_engine.begin() as conn: async with async_engine.begin() as conn:
await conn.execute(text( await conn.execute(text(
"ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_version_id UUID REFERENCES flow_versions(id)" "ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_version_id UUID REFERENCES flow_versions(id)"
@ -43,6 +48,7 @@ async def _run_migrations():
async def get_db(): async def get_db():
"""FastAPI 依赖注入函数,提供数据库会话,自动提交或回滚事务。"""
async with AsyncSessionLocal() as session: async with AsyncSessionLocal() as session:
try: try:
yield session yield session

16
backend/dependencies.py

@ -6,20 +6,26 @@ from database import AsyncSessionLocal
from models import User, UserRole, Role, RolePermission, Permission from models import User, UserRole, Role, RolePermission, Permission
from config import settings from config import settings
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False) # HTTP Bearer 令牌认证方案
async def get_current_user( async def get_current_user(
request: Request, request: Request,
credentials: HTTPAuthorizationCredentials | None = Depends(security), credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> dict: ) -> dict:
"""获取当前登录用户及其角色、权限信息。
优先从 request.state.user 读取 RBAC 中间件预填充
否则通过 JWT 令牌解析用户身份并从数据库查询角色权限
返回包含 idusernamedisplay_namerolepermissionsdata_scope 的字典
"""
if hasattr(request.state, "user") and request.state.user: if hasattr(request.state, "user") and request.state.user:
return request.state.user return request.state.user
if credentials: if credentials:
try: try:
payload = jwt.decode( payload = jwt.decode(
credentials.credentials, credentials=credentials.credentials,
settings.JWT_SECRET, settings.JWT_SECRET,
algorithms=[settings.JWT_ALGORITHM], algorithms=[settings.JWT_ALGORITHM],
) )
@ -33,11 +39,13 @@ async def get_current_user(
if not user: if not user:
raise HTTPException(401, "用户不存在") raise HTTPException(401, "用户不存在")
# 查询用户关联的所有角色
ur_result = await db.execute( ur_result = await db.execute(
select(Role).join(UserRole).where(UserRole.user_id == user.id) select(Role).join(UserRole).where(UserRole.user_id == user.id)
) )
roles = ur_result.scalars().all() roles = ur_result.scalars().all()
# 收集所有权限编码和数据权限范围
permissions = [] permissions = []
data_scopes = [] data_scopes = []
for role in roles: for role in roles:
@ -45,7 +53,7 @@ async def get_current_user(
rp_result = await db.execute( rp_result = await db.execute(
select(Permission.code) select(Permission.code)
.join(RolePermission) .join(RolePermission)
.where(RolePermission.role_id == role.id) .where(RolePermission.role_id == role.id)
) )
perms = rp_result.scalars().all() perms = rp_result.scalars().all()
permissions.extend(perms) permissions.extend(perms)
@ -68,6 +76,7 @@ async def get_current_user(
def require_permission(perm_code: str): def require_permission(perm_code: str):
"""权限检查依赖注入工厂,根据权限编码校验当前用户是否拥有该权限。"""
async def checker(user: dict = Depends(get_current_user)) -> dict: async def checker(user: dict = Depends(get_current_user)) -> dict:
if perm_code not in user.get("permissions", []) and "*:*" not in user.get("permissions", []): if perm_code not in user.get("permissions", []) and "*:*" not in user.get("permissions", []):
raise HTTPException(403, f"缺少权限: {perm_code}") raise HTTPException(403, f"缺少权限: {perm_code}")
@ -76,6 +85,7 @@ def require_permission(perm_code: str):
async def get_db(): async def get_db():
"""FastAPI 依赖注入函数,提供异步数据库会话,自动提交或回滚。"""
async with AsyncSessionLocal() as session: async with AsyncSessionLocal() as session:
try: try:
yield session yield session

49
backend/main.py

@ -30,6 +30,7 @@ from database import AsyncSessionLocal
@asynccontextmanager @asynccontextmanager
async def lifespan(app: AgentApp): async def lifespan(app: AgentApp):
"""应用生命周期管理器,启动时初始化数据库和缓存,关闭时清理资源。"""
await init_db() await init_db()
await cache_manager.connect() await cache_manager.connect()
await init_memory_manager(AsyncSessionLocal) await init_memory_manager(AsyncSessionLocal)
@ -44,32 +45,34 @@ async def lifespan(app: AgentApp):
app = AgentApp( app = AgentApp(
app_name="Enterprise AI Platform", app_name="Enterprise AI Platform", # 应用名称
app_description="企业级 AI Agent 平台 - 双RBAC/企微集成/无代码流编排", app_description="企业级 AI Agent 平台 - 双RBAC/企微集成/无代码流编排", # 应用描述
lifespan=lifespan, lifespan=lifespan,
docs_url="/docs", docs_url="/docs",
redoc_url=None, redoc_url=None,
) )
app.middleware("http")(rate_limit_middleware) # 注册全局 HTTP 中间件
app.middleware("http")(rbac_middleware) app.middleware("http")(rate_limit_middleware) # 速率限制中间件
app.middleware("http")(rbac_middleware) # RBAC 权限中间件
app.include_router(auth_router) # 注册所有业务模块的路由
app.include_router(org_router) app.include_router(auth_router) # 认证模块
app.include_router(rbac_router) app.include_router(org_router) # 组织架构模块
app.include_router(wecom_router) app.include_router(rbac_router) # 权限管理模块
app.include_router(agent_manager_router) app.include_router(wecom_router) # 企业微信模块
app.include_router(task_router) app.include_router(agent_manager_router) # 智能体管理模块
app.include_router(monitor_router) app.include_router(task_router) # 任务管理模块
app.include_router(mcp_router) app.include_router(monitor_router) # 监控模块
app.include_router(flow_router) app.include_router(mcp_router) # MCP 服务注册模块
app.include_router(gateway_router) app.include_router(flow_router) # 流程定义管理模块
app.include_router(audit_router) app.include_router(gateway_router) # 流程 API 网关模块
app.include_router(document_router) app.include_router(audit_router) # 审计日志模块
app.include_router(notification_router) app.include_router(document_router) # 文档管理模块
app.include_router(system_router) app.include_router(notification_router) # 通知模块
app.include_router(rag_router) app.include_router(system_router) # 系统设置模块
app.include_router(chat_router) app.include_router(rag_router) # 知识库模块
app.include_router(custom_tool_router) app.include_router(chat_router) # 对话模块
app.include_router(memory_router) app.include_router(custom_tool_router) # 自定义工具模块
app.include_router(model_provider_router) app.include_router(memory_router) # 记忆管理模块
app.include_router(model_provider_router) # 模型供应商管理模块

31
backend/middleware/apikey_auth.py

@ -1,3 +1,8 @@
"""API 密钥认证中间件模块。
提供基于 API 密钥的流程访问认证功能
主要用于外部系统通过 API Key 调用已发布的 AI 流程
"""
import hashlib import hashlib
from datetime import datetime from datetime import datetime
from fastapi import Request, HTTPException from fastapi import Request, HTTPException
@ -8,15 +13,29 @@ from database import get_db
async def authenticate_api_key(request: Request) -> dict: async def authenticate_api_key(request: Request) -> dict:
"""验证请求中的 API 密钥并返回关联的流程信息。
Authorization 请求头中提取 API Key验证其有效性并更新最后使用时间
API Key 必须以 "flow-" 开头验证时对其 SHA-256 哈希值进行数据库匹配
Args:
request: 当前 HTTP 请求对象
Returns:
dict: 包含 flow_idapi_key_id auth_type 的认证信息字典
Raises:
HTTPException: 当缺少认证信息API Key 格式无效或密钥不存在时抛出 401 异常
"""
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "): if not auth_header.startswith("Bearer "):
raise HTTPException(401, "缺少认证信息") raise HTTPException(401, "缺少认证信息")
raw_key = auth_header[7:] raw_key = auth_header[7:] # 提取 Bearer 后的密钥部分
if not raw_key.startswith("flow-"): if not raw_key.startswith("flow-"):
raise HTTPException(401, "无效的API Key格式") raise HTTPException(401, "无效的API Key格式")
key_hash = hashlib.sha256(raw_key.encode()).hexdigest() key_hash = hashlib.sha256(raw_key.encode()).hexdigest() # 计算 SHA-256 哈希值
db_gen = get_db() db_gen = get_db()
db: AsyncSession = await db_gen.__anext__() db: AsyncSession = await db_gen.__anext__()
@ -28,13 +47,13 @@ async def authenticate_api_key(request: Request) -> dict:
if not api_key: if not api_key:
raise HTTPException(401, "API Key无效或已删除") raise HTTPException(401, "API Key无效或已删除")
api_key.last_used_at = datetime.utcnow() api_key.last_used_at = datetime.utcnow() # 更新最后使用时间
await db.flush() await db.flush()
return { return {
"flow_id": str(api_key.flow_id), "flow_id": str(api_key.flow_id), # 关联的流程 ID
"api_key_id": str(api_key.id), "api_key_id": str(api_key.id), # API Key 记录 ID
"auth_type": "api_key", "auth_type": "api_key", # 认证类型标识
} }
finally: finally:
try: try:

73
backend/middleware/cache_manager.py

@ -1,3 +1,8 @@
"""缓存管理器模块。
提供二级缓存机制Redis + 内存用于缓存 API 响应和计算结果
Redis 不可用时自动降级为纯内存缓存保证系统的高可用性
"""
import json import json
import time import time
import asyncio import asyncio
@ -7,13 +12,30 @@ from config import settings
class CacheManager: class CacheManager:
"""二级缓存管理器类,优先使用 Redis 缓存,降级时使用内存缓存。
提供 get/set/delete/delete_pattern 四种基本操作
支持 TTL 过期时间和模式匹配批量删除
Attributes:
_local: 内存缓存存储结构为 {key: (expire_timestamp, value)}
_redis: Redis 异步客户端实例
_redis_available: Redis 是否可用标志
_lock: 异步锁保证内存缓存操作的并发安全
"""
def __init__(self): def __init__(self):
self._local: dict[str, tuple[float, Any]] = {} """初始化缓存管理器实例。"""
self._redis: Redis | None = None self._local: dict[str, tuple[float, Any]] = {} # 内存缓存:{key: (过期时间戳, 值)}
self._redis_available = False self._redis: Redis | None = None # Redis 异步客户端
self._lock = asyncio.Lock() self._redis_available = False # Redis 可用性标志
self._lock = asyncio.Lock() # 异步锁
async def connect(self): async def connect(self):
"""连接到 Redis 服务器。
尝试从配置中的 REDIS_URL 建立连接如果连接失败则标记 Redis 不可用
后续操作将自动降级为纯内存缓存
"""
try: try:
self._redis = Redis.from_url(settings.REDIS_URL, decode_responses=True) self._redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
await self._redis.ping() await self._redis.ping()
@ -22,14 +44,31 @@ class CacheManager:
self._redis_available = False self._redis_available = False
async def disconnect(self): async def disconnect(self):
"""断开 Redis 连接,释放资源。"""
if self._redis: if self._redis:
await self._redis.close() await self._redis.close()
@property @property
def available(self) -> bool: def available(self) -> bool:
"""检查缓存是否可用(Redis 或内存至少一个可用)。
Returns:
bool: Redis 可用时返回 True否则返回 False
"""
return self._redis_available return self._redis_available
async def get(self, key: str) -> Any | None: async def get(self, key: str) -> Any | None:
"""从缓存中获取指定键的值。
优先从 Redis 获取如果 Redis 不可用或未找到则从内存缓存获取
内存缓存中的过期条目会被自动清理
Args:
key: 缓存键
Returns:
Any | None: 缓存值未找到或已过期返回 None
"""
if self._redis_available and self._redis: if self._redis_available and self._redis:
try: try:
val = await self._redis.get(key) val = await self._redis.get(key)
@ -48,6 +87,15 @@ class CacheManager:
return None return None
async def set(self, key: str, value: Any, ttl: int = 300): async def set(self, key: str, value: Any, ttl: int = 300):
"""将值设置到缓存中,指定过期时间。
同时写入 Redis 和内存缓存Redis 写入失败不影响内存缓存
Args:
key: 缓存键
value: 要缓存的值支持任意可 JSON 序列化的类型
ttl: 过期时间默认 300 5 分钟
"""
if self._redis_available and self._redis: if self._redis_available and self._redis:
try: try:
await self._redis.setex(key, ttl, json.dumps(value, default=str)) await self._redis.setex(key, ttl, json.dumps(value, default=str))
@ -56,6 +104,7 @@ class CacheManager:
async with self._lock: async with self._lock:
self._local[key] = (time.time() + ttl, value) self._local[key] = (time.time() + ttl, value)
# 当内存缓存条目超过上限时清理过期条目
if len(self._local) > 10000: if len(self._local) > 10000:
now = time.time() now = time.time()
expired = [k for k, (t, v) in self._local.items() if now >= t] expired = [k for k, (t, v) in self._local.items() if now >= t]
@ -63,6 +112,13 @@ class CacheManager:
del self._local[k] del self._local[k]
async def delete(self, key: str): async def delete(self, key: str):
"""从缓存中删除指定键。
同时从 Redis 和内存缓存中删除任一删除失败不影响另一个
Args:
key: 要删除的缓存键
"""
if self._redis_available and self._redis: if self._redis_available and self._redis:
try: try:
await self._redis.delete(key) await self._redis.delete(key)
@ -72,6 +128,13 @@ class CacheManager:
self._local.pop(key, None) self._local.pop(key, None)
async def delete_pattern(self, pattern: str): async def delete_pattern(self, pattern: str):
"""按模式匹配批量删除缓存键。
Redis 中使用 keys 命令匹配内存缓存中使用字符串包含匹配
Args:
pattern: 匹配模式支持通配符 *
"""
if self._redis_available and self._redis: if self._redis_available and self._redis:
try: try:
keys = await self._redis.keys(pattern) keys = await self._redis.keys(pattern)
@ -85,4 +148,4 @@ class CacheManager:
del self._local[k] del self._local[k]
cache_manager = CacheManager() cache_manager = CacheManager() # 全局缓存管理器单例实例

80
backend/middleware/rate_limiter.py

@ -1,3 +1,8 @@
"""速率限制中间件模块。
提供基于令牌桶算法的 HTTP 请求速率限制功能
采用内存中的滑动窗口机制限制每个 IP 地址在指定时间窗口内的请求数量
"""
import time import time
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
@ -6,14 +11,31 @@ from config import settings
class RateLimiter: class RateLimiter:
"""内存速率限制器类,使用滑动窗口算法限制请求频率。
为每个唯一键通常是 IP 地址维护一个时间戳列表
在每次请求时清理过期时间戳并检查是否超过限制
Attributes:
MAX_KEYS: 最大缓存的键数量防止内存无限增长
_buckets: 滑动窗口桶存储每个键的请求时间戳列表
_lock: 异步锁保证并发安全
_last_cleanup: 上次清理缓存的时间戳
"""
MAX_KEYS = 10000 MAX_KEYS = 10000
def __init__(self): def __init__(self):
self._buckets: dict[str, list[float]] = defaultdict(list) """初始化速率限制器实例。"""
self._lock = asyncio.Lock() self._buckets: dict[str, list[float]] = defaultdict(list) # 滑动窗口桶:{key: [timestamp, ...]}
self._last_cleanup = time.time() self._lock = asyncio.Lock() # 异步锁,保证并发安全
self._last_cleanup = time.time() # 上次清理缓存的时间戳
async def _cleanup(self): async def _cleanup(self):
"""清理过期和空闲的键,释放内存空间。
仅在距上次清理超过 60 秒时执行实际清理操作
删除空桶或最后一个请求超过 120 秒的桶
"""
now = time.time() now = time.time()
if now - self._last_cleanup < 60: if now - self._last_cleanup < 60:
return return
@ -23,21 +45,33 @@ class RateLimiter:
del self._buckets[k] del self._buckets[k]
async def check(self, key: str) -> bool: async def check(self, key: str) -> bool:
"""检查指定键是否允许新的请求。
使用滑动窗口算法清理窗口外的时间戳后检查是否超过限制
如果超过限制则拒绝请求否则记录当前时间戳并允许通过
Args:
key: 速率限制键通常为 IP 地址
Returns:
bool: 允许请求返回 True拒绝请求返回 False
"""
now = time.time() now = time.time()
limit = settings.RATE_LIMIT_PER_MINUTE limit = settings.RATE_LIMIT_PER_MINUTE # 每分钟请求限制数
window = 60.0 window = 60.0 # 时间窗口(秒)
async with self._lock: async with self._lock:
await self._cleanup() await self._cleanup()
bucket = self._buckets[key] bucket = self._buckets[key]
bucket = [t for t in bucket if now - t < window] bucket = [t for t in bucket if now - t < window] # 过滤窗口内的时间戳
self._buckets[key] = bucket self._buckets[key] = bucket
if len(bucket) >= limit: if len(bucket) >= limit:
return False return False # 超过限制,拒绝请求
bucket.append(now) bucket.append(now) # 记录当前请求时间戳
# 当缓存键数量超过上限时,淘汰最旧的键
if len(self._buckets) > self.MAX_KEYS: if len(self._buckets) > self.MAX_KEYS:
oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2] oldest_keys = sorted(self._buckets, key=lambda k: self._buckets[k][0] if self._buckets[k] else 0)[:len(self._buckets) - self.MAX_KEYS // 2]
for k in oldest_keys: for k in oldest_keys:
@ -46,27 +80,49 @@ class RateLimiter:
return True return True
async def remaining(self, key: str) -> int: async def remaining(self, key: str) -> int:
"""获取指定键剩余的请求次数。
Args:
key: 速率限制键
Returns:
int: 当前时间窗口内剩余的请求次数
"""
now = time.time() now = time.time()
async with self._lock: async with self._lock:
bucket = [t for t in self._buckets.get(key, []) if now - t < 60] bucket = [t for t in self._buckets.get(key, []) if now - t < 60]
return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket)) return max(0, settings.RATE_LIMIT_PER_MINUTE - len(bucket))
rate_limiter = RateLimiter() rate_limiter = RateLimiter() # 全局速率限制器单例实例
async def rate_limit_middleware(request: Request, call_next): async def rate_limit_middleware(request: Request, call_next):
"""速率限制中间件。
对每个 HTTP 请求进行速率限制检查
1. 跳过公开路径健康检查登录等
2. 基于客户端 IP 地址进行速率限制
3. 在响应头中添加剩余请求次数信息
Args:
request: 当前 HTTP 请求对象
call_next: 下一个中间件或路由处理函数
Returns:
Response: 如果未超限则返回后续处理结果否则返回 429 错误响应
"""
path = request.url.path path = request.url.path
if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]: if path in ["/health", "/api/auth/login", "/docs", "/openapi.json"]:
return await call_next(request) return await call_next(request)
client_ip = request.client.host if request.client else "unknown" client_ip = request.client.host if request.client else "unknown" # 客户端 IP 地址
key = f"ratelimit:{client_ip}" key = f"ratelimit:{client_ip}" # 速率限制键
if not await rate_limiter.check(key): if not await rate_limiter.check(key):
raise HTTPException(429, "请求过于频繁,请稍后再试") raise HTTPException(429, "请求过于频繁,请稍后再试")
response = await call_next(request) response = await call_next(request)
remaining = await rate_limiter.remaining(key) remaining = await rate_limiter.remaining(key)
response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Remaining"] = str(remaining) # 响应头中添加剩余请求次数
return response return response

35
backend/middleware/rbac_middleware.py

@ -1,3 +1,9 @@
"""RBAC 权限中间件模块。
提供全局 HTTP 请求的 RBAC基于角色的访问控制权限校验中间件
每个请求都会经过此中间件解析 JWT 令牌并查询用户的角色和权限信息
将用户上下文存储到 request.state.user 中供后续路由使用
"""
import jwt import jwt
from fastapi import Request, HTTPException from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -8,34 +14,55 @@ from sqlalchemy import select
async def rbac_middleware(request: Request, call_next): async def rbac_middleware(request: Request, call_next):
"""RBAC 权限校验中间件。
对每个 HTTP 请求进行权限校验
1. 跳过公开路径登录健康检查等
2. 解析 JWT 令牌获取用户身份
3. 从数据库查询用户的角色权限和数据权限范围
4. 将用户上下文存储到 request.state.user
Args:
request: 当前 HTTP 请求对象
call_next: 下一个中间件或路由处理函数
Returns:
Response: 如果权限校验通过则返回后续处理结果否则返回 401 错误响应
"""
# 公开路径列表,无需认证即可访问
public_paths = ["/api/auth/login", "/api/auth/wecom", "/health", "/docs", "/openapi.json", "/wecom/callback"] public_paths = ["/api/auth/login", "/api/auth/wecom", "/health", "/docs", "/openapi.json", "/wecom/callback"]
if any(request.url.path.startswith(p) for p in public_paths): if any(request.url.path.startswith(p) for p in public_paths):
return await call_next(request) return await call_next(request)
# 从 Authorization 头中提取 JWT 令牌
token = request.headers.get("Authorization", "").replace("Bearer ", "") token = request.headers.get("Authorization", "").replace("Bearer ", "")
if not token: if not token:
return JSONResponse({"code": 401, "message": "未提供认证令牌"}, 401) return JSONResponse({"code": 401, "message": "未提供认证令牌"}, 401)
# 解析 JWT 令牌获取用户 ID
try: try:
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]) payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
user_id = payload.get("sub") user_id = payload.get("sub")
except jwt.PyJWTError: except jwt.PyJWTError:
return JSONResponse({"code": 401, "message": "令牌无效或已过期"}, 401) return JSONResponse({"code": 401, "message": "令牌无效或已过期"}, 401)
# 从数据库查询用户信息和权限
async with AsyncSessionLocal() as db: async with AsyncSessionLocal() as db:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user or user.status != "active": if not user or user.status != "active":
return JSONResponse({"code": 401, "message": "用户不存在或已禁用"}, 401) return JSONResponse({"code": 401, "message": "用户不存在或已禁用"}, 401)
# 查询用户关联的所有角色
ur_result = await db.execute( ur_result = await db.execute(
select(Role).join(UserRole).where(UserRole.user_id == user.id) select(Role).join(UserRole).where(UserRole.user_id == user.id)
) )
roles = ur_result.scalars().all() roles = ur_result.scalars().all()
role_codes = [r.code for r in roles] role_codes = [r.code for r in roles] # 角色编码列表
is_root = "root" in role_codes is_root = "root" in role_codes # 是否为超级管理员
# 收集所有权限编码和数据权限范围
permissions = [] permissions = []
data_scopes = [] data_scopes = []
for role in roles: for role in roles:
@ -46,11 +73,13 @@ async def rbac_middleware(request: Request, call_next):
perms = rp_result.scalars().all() perms = rp_result.scalars().all()
permissions.extend([p.code for p in perms]) permissions.extend([p.code for p in perms])
unique_perms = list(set(permissions)) unique_perms = list(set(permissions)) # 去重后的权限列表
# 超级管理员自动拥有所有权限
if is_root and "*:*" not in unique_perms: if is_root and "*:*" not in unique_perms:
unique_perms.insert(0, "*:*") unique_perms.insert(0, "*:*")
# 将用户上下文存储到 request.state 中
request.state.user = { request.state.user = {
"id": str(user.id), "id": str(user.id),
"username": user.username, "username": user.username,

446
backend/models/__init__.py

@ -1,3 +1,9 @@
"""
数据库 ORM 模型模块
本模块定义了所有数据库表对应的 SQLAlchemy ORM 模型
每个类映射到一张数据库表类属性映射到表字段
"""
import uuid import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey, Integer, Boolean, JSON, Text, Float from sqlalchemy import Column, String, DateTime, ForeignKey, Integer, Boolean, JSON, Text, Float
@ -7,389 +13,201 @@ from database import Base
class Department(Base): class Department(Base):
"""部门表 (departments),存储企业部门层级结构,支持多级树形组织架构。"""
__tablename__ = "departments" __tablename__ = "departments"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 部门唯一标识 UUID
name = Column(String(100), nullable=False) name = Column(String(100), nullable=False) # 部门名称
parent_id = Column(UUID(as_uuid=True), ForeignKey("departments.id"), nullable=True) parent_id = Column(UUID(as_uuid=True), ForeignKey("departments.id"), nullable=True) # 上级部门 ID,用于构建树形结构
path = Column(String(500), default="/") path = Column(String(500), default="/") # 部门路径,从根节点到当前节点的路径字符串
level = Column(Integer, default=0) level = Column(Integer, default=0) # 部门层级深度(根部门为 0)
sort_order = Column(Integer, default=0) sort_order = Column(Integer, default=0) # 排序序号,同级部门按此字段排列顺序
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间
children = relationship("Department", backref="parent", remote_side=[id]) children = relationship("Department", backref="parent", remote_side=[id]) # 子部门列表(一对多自引用)
users = relationship("User", back_populates="department") users = relationship("User", back_populates="department") # 部门下的用户列表
class User(Base): class User(Base):
"""用户表 (users),存储系统用户信息,包括账号、身份、组织归属。"""
__tablename__ = "users" __tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 用户唯一标识 UUID
username = Column(String(50), unique=True, nullable=False) username = Column(String(50), unique=True, nullable=False) # 登录用户名(唯一)
password_hash = Column(String(255), nullable=False) password_hash = Column(String(255), nullable=False) # 密码哈希值(bcrypt 加密)
display_name = Column(String(100), nullable=False) display_name = Column(String(100), nullable=False) # 用户显示名称
email = Column(String(100)) email = Column(String(100)) # 电子邮箱
phone = Column(String(20)) phone = Column(String(20)) # 手机号码
wecom_user_id = Column(String(100), unique=True) wecom_user_id = Column(String(100), unique=True) # 企业微信用户 ID(唯一)
department_id = Column(UUID(as_uuid=True), ForeignKey("departments.id")) department_id = Column(UUID(as_uuid=True), ForeignKey("departments.id")) # 所属部门 ID
position = Column(String(100)) position = Column(String(100)) # 职位/岗位名称
manager_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) manager_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 直接上级用户 ID
status = Column(String(20), default="active") status = Column(String(20), default="active") # 用户状态:active/inactive/disabled
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间
department = relationship("Department", back_populates="users") department = relationship("Department", back_populates="users") # 所属部门(多对一)
roles = relationship("UserRole", back_populates="user") roles = relationship("UserRole", back_populates="user") # 用户角色列表(通过中间表关联)
manager = relationship("User", remote_side=[id], backref="subordinates") manager = relationship("User", remote_side=[id], backref="subordinates") # 直接上级(自引用)
class Role(Base): class Role(Base):
"""角色表 (roles),存储系统角色定义,用于 RBAC 权限管理。"""
__tablename__ = "roles" __tablename__ = "roles"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 角色唯一标识 UUID
name = Column(String(50), unique=True, nullable=False) name = Column(String(50), unique=True, nullable=False) # 角色名称(唯一)
code = Column(String(50), unique=True, nullable=False, default="") code = Column(String(50), unique=True, nullable=False, default="") # 角色编码(唯一,如 admin/user)
description = Column(String(200)) description = Column(String(200)) # 角色描述
is_system = Column(Boolean, default=False) is_system = Column(Boolean, default=False) # 是否为系统内置角色(不可删除)
data_scope = Column(String(50), default="self_only") data_scope = Column(String(50), default="self_only") # 数据权限范围:self_only/department/all
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
permissions = relationship("RolePermission", back_populates="role") permissions = relationship("RolePermission", back_populates="role") # 角色权限关联列表
class Permission(Base): class Permission(Base):
"""权限表 (permissions),存储系统中每个可操作的权限点。"""
__tablename__ = "permissions" __tablename__ = "permissions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 权限唯一标识 UUID
code = Column(String(100), unique=True, nullable=False) code = Column(String(100), unique=True, nullable=False) # 权限编码(唯一)
name = Column(String(100), nullable=False) name = Column(String(100), nullable=False) # 权限名称
resource = Column(String(100), nullable=False) resource = Column(String(100), nullable=False) # 所属资源名称(如 user/role)
action = Column(String(50), nullable=False) action = Column(String(50), nullable=False) # 操作类型(create/read/update/delete)
description = Column(String(200)) description = Column(String(200)) # 权限描述
class RolePermission(Base): class RolePermission(Base):
"""角色-权限关联表 (role_permissions),多对多关联中间表。"""
__tablename__ = "role_permissions" __tablename__ = "role_permissions"
role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) # 角色 ID(级联删除)
permission_id = Column(UUID(as_uuid=True), ForeignKey("permissions.id", ondelete="CASCADE"), primary_key=True) permission_id = Column(UUID(as_uuid=True), ForeignKey("permissions.id", ondelete="CASCADE"), primary_key=True) # 权限 ID(级联删除)
role = relationship("Role", back_populates="permissions") role = relationship("Role", back_populates="permissions")
permission = relationship("Permission") permission = relationship("Permission")
class UserRole(Base): class UserRole(Base):
"""用户-角色关联表 (user_roles),多对多关联中间表。"""
__tablename__ = "user_roles" __tablename__ = "user_roles"
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True) # 用户 ID(级联删除)
role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) role_id = Column(UUID(as_uuid=True), ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True) # 角色 ID(级联删除)
user = relationship("User", back_populates="roles") user = relationship("User", back_populates="roles")
role = relationship("Role") role = relationship("Role")
class ChatSession(Base): class ChatSession(Base):
"""聊天会话表 (chat_sessions),存储用户与 AI 智能体的对话会话记录。"""
__tablename__ = "chat_sessions" __tablename__ = "chat_sessions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 会话唯一标识 UUID
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) # 所属用户 ID
agent_type = Column(String(50), nullable=False) agent_type = Column(String(50), nullable=False) # 智能体类型(chat/flow/rag)
session_id = Column(String(100), unique=True, nullable=False) session_id = Column(String(100), unique=True, nullable=False) # 外部会话 ID(对客户端暴露)
status = Column(String(20), default="active") status = Column(String(20), default="active") # 会话状态:active/closed
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间
class ChatMessage(Base): class ChatMessage(Base):
"""聊天消息表 (chat_messages),存储聊天会话中的每条消息内容。"""
__tablename__ = "chat_messages" __tablename__ = "chat_messages"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 消息唯一标识 UUID
session_id = Column(UUID(as_uuid=True), ForeignKey("chat_sessions.id", ondelete="CASCADE")) session_id = Column(UUID(as_uuid=True), ForeignKey("chat_sessions.id", ondelete="CASCADE")) # 所属会话 ID
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")) # 发送者用户 ID
role = Column(String(20), nullable=False) role = Column(String(20), nullable=False) # 消息角色:user/assistant/system
content = Column(Text, nullable=False) content = Column(Text, nullable=False) # 消息内容文本
metadata_ = Column("metadata", JSON, default=dict) metadata_ = Column("metadata", JSON, default=dict) # 元数据(额外信息 JSON)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
class Task(Base): class Task(Base):
"""任务表 (tasks),存储分配给用户的待办任务信息。"""
__tablename__ = "tasks" __tablename__ = "tasks"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 任务唯一标识 UUID
title = Column(String(200), nullable=False) title = Column(String(200), nullable=False) # 任务标题
content = Column(Text) content = Column(Text) # 任务内容描述
assigner_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) assigner_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 任务分配者(发起人)ID
assignee_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) assignee_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) # 任务执行人 ID
status = Column(String(20), default="pending") status = Column(String(20), default="pending") # 任务状态:pending/in_progress/completed/cancelled
priority = Column(String(20), default="normal") priority = Column(String(20), default="normal") # 优先级:low/normal/high/urgent
deadline = Column(DateTime) deadline = Column(DateTime) # 截止日期时间
wecom_message_id = Column(String(100)) wecom_message_id = Column(String(100)) # 企业微信消息 ID
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间
class FlowDefinition(Base): class FlowDefinition(Base):
"""流程定义表 (flow_definitions),存储可执行 AI 工作流的节点和连线配置。"""
__tablename__ = "flow_definitions" __tablename__ = "flow_definitions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 流程唯一标识 UUID
name = Column(String(200), nullable=False) name = Column(String(200), nullable=False) # 流程名称
description = Column(Text) description = Column(Text) # 流程描述
version = Column(Integer, default=1) version = Column(Integer, default=1) # 当前版本号
status = Column(String(20), default="draft") status = Column(String(20), default="draft") # 流程状态:draft/published/archived
definition_json = Column(JSON, nullable=False, default=dict) definition_json = Column(JSON, nullable=False, default=dict) # 已发布的节点和连线配置 JSON
published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True) published_version_id = Column(UUID(as_uuid=True), ForeignKey("flow_versions.id"), nullable=True) # 已发布版本 ID
draft_definition_json = Column(JSON, nullable=True, default=None) draft_definition_json = Column(JSON, nullable=True, default=None) # 草稿编辑中的配置 JSON
creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 创建者用户 ID
flow_mode = Column(String(20), default="chatflow") flow_mode = Column(String(20), default="chatflow") # 流程模式:chatflow/workflow
published_to_wecom = Column(Boolean, default=False) published_to_wecom = Column(Boolean, default=False) # 是否已发布到企业微信
published_to_web = Column(Boolean, default=False) published_to_web = Column(Boolean, default=False) # 是否已发布到 Web 端
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间
published_version = relationship("FlowVersion", foreign_keys=[published_version_id], post_update=True) published_version = relationship("FlowVersion", foreign_keys=[published_version_id], post_update=True)
class FlowVersion(Base): class FlowVersion(Base):
"""流程版本表 (flow_versions),存储流程定义的历史版本快照。"""
__tablename__ = "flow_versions" __tablename__ = "flow_versions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 版本唯一标识 UUID
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) # 所属流程定义 ID
version = Column(Integer, nullable=False) version = Column(Integer, nullable=False) # 版本号(同一流程内递增)
definition_json = Column(JSON, nullable=False, default=dict) definition_json = Column(JSON, nullable=False, default=dict) # 该版本的流程定义 JSON 快照
changelog = Column(Text, default="") changelog = Column(Text, default="") # 版本变更日志
published_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) published_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 发布者用户 ID
published_to_wecom = Column(Boolean, default=False) published_to_wecom = Column(Boolean, default=False) # 是否发布到企业微信
published_to_web = Column(Boolean, default=False) published_to_web = Column(Boolean, default=False) # 是否发布到 Web 端
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
class FlowApiKey(Base): class FlowApiKey(Base):
"""流程 API 密钥表 (flow_api_keys),存储用于外部调用流程的 API 密钥。"""
__tablename__ = "flow_api_keys" __tablename__ = "flow_api_keys"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 密钥唯一标识 UUID
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False) # 所属流程 ID
name = Column(String(100), nullable=False) name = Column(String(100), nullable=False) # 密钥名称
key_hash = Column(String(64), nullable=False) key_hash = Column(String(64), nullable=False) # 密钥哈希值(SHA-256 加密存储)
key_prefix = Column(String(10), nullable=False) key_prefix = Column(String(10), nullable=False) # 密钥前缀(用于显示识别)
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 创建者用户 ID
last_used_at = Column(DateTime, nullable=True) last_used_at = Column(DateTime, nullable=True) # 最后使用时间
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
class FlowTemplate(Base): class FlowTemplate(Base):
"""流程模板表 (flow_templates),存储预定义的流程模板,可供用户快速创建流程。"""
__tablename__ = "flow_templates" __tablename__ = "flow_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # 模板唯一标识 UUID
name = Column(String(200), nullable=False) name = Column(String(200), nullable=False) # 模板名称
description = Column(Text, default="") description = Column(Text, default="") # 模板描述
category = Column(String(50), default="") category = Column(String(50), default="") # 模板分类
definition_json = Column(JSON, nullable=False, default=dict) definition_json = Column(JSON, nullable=False, default=dict) # 模板的流程定义 JSON
icon = Column(String(50), default="") icon = Column(String(50), default="") # 模板图标名称
sort_order = Column(Integer, default=0) sort_order = Column(Integer, default=0) # 排序序号
is_builtin = Column(Boolean, default=False) is_builtin = Column(Boolean, default=False) # 是否为系统内置模板
usage_count = Column(Integer, default=0) usage_count = Column(Integer, default=0) # 使用次数统计
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) created_by = Column(UUID(as_uuid=True), ForeignKey("users.id")) # 创建者用户 ID
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow) # 记录创建时间
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) # 记录更新时间
class CustomTool(Base):
__tablename__ = "custom_tools"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
description = Column(Text)
schema_json = Column(JSON, nullable=False, default=dict)
endpoint_url = Column(String(500), nullable=False)
method = Column(String(10), default="GET")
path = Column(String(500), default="")
headers_json = Column(JSON, default=dict)
auth_type = Column(String(20), default="none")
auth_config = Column(JSON, default=dict)
created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"))
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class FlowExecution(Base):
__tablename__ = "flow_executions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"))
version = Column(Integer, nullable=True)
trigger_type = Column(String(50))
trigger_user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"))
input_data = Column(JSON)
output_data = Column(JSON)
status = Column(String(20), default="running")
token_usage = Column(JSON, default=dict)
latency_ms = Column(Integer, nullable=True)
error_message = Column(Text, nullable=True)
started_at = Column(DateTime, default=datetime.utcnow)
finished_at = Column(DateTime)
class MCPService(Base):
__tablename__ = "mcp_services"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), unique=True, nullable=False)
transport = Column(String(20), default="http")
url = Column(String(500))
command = Column(String(500))
args = Column(JSON, default=list)
env = Column(JSON, default=dict)
status = Column(String(20), default="disconnected")
tools = Column(JSON, default=list)
creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id"))
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class NotificationTemplate(Base):
__tablename__ = "notification_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
code = Column(String(100), unique=True, nullable=False)
channel = Column(String(20), default="wecom")
title_template = Column(String(500))
body_template = Column(Text, nullable=False)
variables = Column(JSON, default=list)
is_system = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
class SystemMetric(Base):
__tablename__ = "system_metrics"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
metric_type = Column(String(50), nullable=False)
value = Column(JSON, nullable=False)
collected_at = Column(DateTime, default=datetime.utcnow)
class AgentConfig(Base):
__tablename__ = "agent_configs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
description = Column(String(500))
system_prompt = Column(Text, default="")
model = Column(String(50), default="gpt-4o-mini")
model_instance_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True)
embedding_model_id = Column(UUID(as_uuid=True), ForeignKey("model_instances.id"), nullable=True)
temperature = Column(Float, default=0.7)
tools = Column(JSON, default=list)
status = Column(String(20), default="active")
creator_id = Column(UUID(as_uuid=True), ForeignKey("users.id"))
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class AuditLog(Base):
__tablename__ = "audit_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
operator_id = Column(UUID(as_uuid=True), ForeignKey("users.id"))
action = Column(String(100), nullable=False)
resource = Column(String(100))
resource_id = Column(String(100))
detail = Column(JSON, default=dict)
ip_address = Column(String(50))
created_at = Column(DateTime, default=datetime.utcnow)
class MemoryMessage(Base):
__tablename__ = "memory_messages"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False)
session_id = Column(UUID(as_uuid=True), nullable=False)
role = Column(String(20), nullable=False)
content = Column(Text, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow)
class MemoryAtom(Base):
__tablename__ = "memory_atoms"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True)
atom_type = Column(String(20), nullable=False)
content = Column(Text, nullable=False)
priority = Column(Integer, default=50)
source_session_id = Column(UUID(as_uuid=True), nullable=True)
metadata_ = Column("metadata", JSON, default=dict)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class MemoryScene(Base):
__tablename__ = "memory_scenes"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="SET NULL"), nullable=True)
scene_name = Column(String(200), nullable=False)
summary = Column(Text, nullable=False)
heat = Column(Integer, default=0)
content = Column(JSON, default=dict)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class MemoryPersona(Base):
__tablename__ = "memory_personas"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True)
content = Column(JSON, default=dict, nullable=False)
raw_text = Column(Text, default="")
version = Column(Integer, default=1)
updated_at = Column(DateTime, default=datetime.utcnow)
class MemorySession(Base):
__tablename__ = "memory_sessions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
flow_id = Column(UUID(as_uuid=True), ForeignKey("flow_definitions.id", ondelete="CASCADE"), nullable=False)
session_id = Column(UUID(as_uuid=True), nullable=False)
flow_name = Column(String(200), default="")
message_count = Column(Integer, default=0)
last_active_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime, default=datetime.utcnow)
class ModelProvider(Base):
__tablename__ = "model_providers"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
provider_type = Column(String(50), nullable=False)
base_url = Column(String(500))
api_key = Column(Text)
extra_config = Column(JSON, default=dict)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
class ModelInstance(Base):
__tablename__ = "model_instances"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
provider_id = Column(UUID(as_uuid=True), ForeignKey("model_providers.id", ondelete="CASCADE"))
model_name = Column(String(100), nullable=False)
model_type = Column(String(30), nullable=False)
display_name = Column(String(200))
capabilities = Column(JSON, default=dict)
default_params = Column(JSON, default=dict)
is_default = Column(Boolean, default=False)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)

1
backend/modules/__init__.py

@ -0,0 +1 @@
"""业务模块包。"""

1
backend/modules/agent_manager/__init__.py

@ -0,0 +1 @@
"""Agent 管理模块。"""

85
backend/modules/agent_manager/router.py

@ -1,3 +1,8 @@
"""智能体管理模块路由。
提供 AI 智能体的配置管理对话交互和聊天历史记录查询功能
支持多种智能体类型员工/管理者/任务/文档的动态创建和对话
"""
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select from sqlalchemy import select
@ -17,16 +22,32 @@ async def agent_chat(
payload: dict, payload: dict,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""与指定类型的 AI 智能体进行对话。
根据用户身份创建或获取对应的智能体实例处理用户消息并返回 AI 回复
同时会自动创建或复用聊天会话并保存对话消息到数据库
Args:
agent_type: 智能体类型employee/manager/task/document
request: HTTP 请求对象用于获取当前用户信息
payload: 请求体包含 message 和可选的 session_id
db: 异步数据库会话
Returns:
dict: 包含 session_idreply role 的响应数据
"""
user_ctx = request.state.user user_ctx = request.state.user
user_id = uuid.UUID(user_ctx["id"]) user_id = uuid.UUID(user_ctx["id"])
msg_content = payload.get("message", "") msg_content = payload.get("message", "") # 用户消息内容
session_id = payload.get("session_id", f"session_{uuid.uuid4().hex[:12]}") session_id = payload.get("session_id", f"session_{uuid.uuid4().hex[:12]}") # 会话 ID,未提供则自动生成
# 查询当前用户信息
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
raise HTTPException(404, "用户不存在") raise HTTPException(404, "用户不存在")
# 查询或创建聊天会话
session_result = await db.execute( session_result = await db.execute(
select(ChatSession).where(ChatSession.session_id == session_id) select(ChatSession).where(ChatSession.session_id == session_id)
) )
@ -39,6 +60,7 @@ async def agent_chat(
db.add(session) db.add(session)
await db.flush() await db.flush()
# 保存用户消息到数据库
user_msg = ChatMessage( user_msg = ChatMessage(
session_id=session.id, user_id=user.id, session_id=session.id, user_id=user.id,
role="user", content=msg_content, role="user", content=msg_content,
@ -46,6 +68,7 @@ async def agent_chat(
db.add(user_msg) db.add(user_msg)
await db.flush() await db.flush()
# 创建对应类型的 AI 智能体
agent = await AgentFactory.create_agent( agent = await AgentFactory.create_agent(
agent_type=agent_type, agent_type=agent_type,
user_id=str(user.id), user_id=str(user.id),
@ -53,12 +76,14 @@ async def agent_chat(
department_id=str(user.department_id) if user.department_id else None, department_id=str(user.department_id) if user.department_id else None,
) )
# 构造消息并调用智能体回复
from agentscope.message import Msg from agentscope.message import Msg
input_msg = Msg(name="user", content=msg_content, role="user") input_msg = Msg(name="user", content=msg_content, role="user")
response = await agent.reply(input_msg) response = await agent.reply(input_msg)
reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response) reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response)
# 保存 AI 回复消息到数据库
ai_msg = ChatMessage( ai_msg = ChatMessage(
session_id=session.id, user_id=user.id, session_id=session.id, user_id=user.id,
role="assistant", content=reply_text, role="assistant", content=reply_text,
@ -78,6 +103,11 @@ async def agent_chat(
@router.get("/list") @router.get("/list")
async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)): async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)):
"""获取所有处于活跃状态的智能体配置列表。
Returns:
dict: 包含智能体配置列表的响应数据
"""
result = await db.execute( result = await db.execute(
select(AgentConfig).where(AgentConfig.status == "active").order_by(AgentConfig.updated_at.desc()) select(AgentConfig).where(AgentConfig.status == "active").order_by(AgentConfig.updated_at.desc())
) )
@ -99,6 +129,16 @@ async def get_agent_list(request: Request, db: AsyncSession = Depends(get_db)):
@router.post("/", response_model=AgentConfigOut) @router.post("/", response_model=AgentConfigOut)
async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSession = Depends(get_db)): async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSession = Depends(get_db)):
"""创建新的智能体配置。
Args:
req: 智能体配置创建请求体
request: HTTP 请求对象
db: 异步数据库会话
Returns:
AgentConfigOut: 创建后的智能体配置响应
"""
user_ctx = request.state.user user_ctx = request.state.user
agent = AgentConfig( agent = AgentConfig(
name=req.name, name=req.name,
@ -123,6 +163,16 @@ async def create_agent(req: AgentConfigCreate, request: Request, db: AsyncSessio
@router.get("/{agent_id}", response_model=AgentConfigOut) @router.get("/{agent_id}", response_model=AgentConfigOut)
async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
"""获取指定智能体的配置详情。
Args:
agent_id: 智能体唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
Returns:
AgentConfigOut: 智能体配置响应
"""
result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id)) result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
agent = result.scalar_one_or_none() agent = result.scalar_one_or_none()
if not agent: if not agent:
@ -139,6 +189,17 @@ async def get_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = De
@router.put("/{agent_id}", response_model=AgentConfigOut) @router.put("/{agent_id}", response_model=AgentConfigOut)
async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Request, db: AsyncSession = Depends(get_db)): async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Request, db: AsyncSession = Depends(get_db)):
"""更新指定智能体的配置。
Args:
agent_id: 智能体唯一标识 ID
req: 智能体配置更新请求体
request: HTTP 请求对象
db: 异步数据库会话
Returns:
AgentConfigOut: 更新后的智能体配置响应
"""
result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id)) result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
agent = result.scalar_one_or_none() agent = result.scalar_one_or_none()
if not agent: if not agent:
@ -170,6 +231,16 @@ async def update_agent(agent_id: uuid.UUID, req: AgentConfigUpdate, request: Req
@router.delete("/{agent_id}") @router.delete("/{agent_id}")
async def delete_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): async def delete_agent(agent_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
"""删除指定的智能体配置。
Args:
agent_id: 智能体唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 操作结果响应
"""
result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id)) result = await db.execute(select(AgentConfig).where(AgentConfig.id == agent_id))
agent = result.scalar_one_or_none() agent = result.scalar_one_or_none()
if not agent: if not agent:
@ -184,6 +255,16 @@ async def get_chat_history(
request: Request, request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""获取指定会话的完整聊天历史记录。
Args:
session_id: 会话唯一标识
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含消息列表的响应数据按时间顺序排列
"""
session_result = await db.execute( session_result = await db.execute(
select(ChatSession).where(ChatSession.session_id == session_id) select(ChatSession).where(ChatSession.session_id == session_id)
) )

1
backend/modules/audit/__init__.py

@ -0,0 +1 @@
"""审计日志模块。"""

62
backend/modules/audit/router.py

@ -1,3 +1,8 @@
"""审计日志模块路由。
提供审计日志的查询统计和导出功能
记录系统中所有重要操作的详细信息支持按操作类型资源操作人和时间范围筛选
"""
import uuid import uuid
import csv import csv
import io import io
@ -26,6 +31,22 @@ async def list_logs(
date_to: datetime | None = Query(None), date_to: datetime | None = Query(None),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""分页查询审计日志列表,支持多条件筛选。
Args:
request: HTTP 请求对象
page: 页码 1 开始
page_size: 每页数量最大 100
action: 可选的操作类型筛选条件
resource: 可选的资源类型筛选条件
operator_id: 可选的操作人 ID 筛选条件
date_from: 可选的起始时间筛选条件
date_to: 可选的结束时间筛选条件
db: 异步数据库会话
Returns:
AuditLogPage: 分页的审计日志响应数据
"""
conditions = [] conditions = []
if action: if action:
conditions.append(AuditLog.action == action) conditions.append(AuditLog.action == action)
@ -38,14 +59,16 @@ async def list_logs(
if date_to: if date_to:
conditions.append(AuditLog.created_at <= date_to) conditions.append(AuditLog.created_at <= date_to)
where = and_(*conditions) if conditions else None where = and_(*conditions) if conditions else None # 组合所有筛选条件
# 查询总数
count_q = select(func.count(AuditLog.id)) count_q = select(func.count(AuditLog.id))
if where is not None: if where is not None:
count_q = count_q.where(where) count_q = count_q.where(where)
total_result = await db.execute(count_q) total_result = await db.execute(count_q)
total = total_result.scalar() or 0 total = total_result.scalar() or 0
# 分页查询
q = select(AuditLog).order_by(AuditLog.created_at.desc()) q = select(AuditLog).order_by(AuditLog.created_at.desc())
if where is not None: if where is not None:
q = q.where(where) q = q.where(where)
@ -63,6 +86,15 @@ async def list_logs(
@router.get("/actions") @router.get("/actions")
async def list_action_types(request: Request, db: AsyncSession = Depends(get_db)): async def list_action_types(request: Request, db: AsyncSession = Depends(get_db)):
"""获取所有操作类型及其出现次数统计。
Args:
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含操作类型统计列表的响应数据
"""
result = await db.execute( result = await db.execute(
select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action) select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action)
) )
@ -74,15 +106,25 @@ async def list_action_types(request: Request, db: AsyncSession = Depends(get_db)
@router.get("/stats") @router.get("/stats")
async def audit_stats(request: Request, db: AsyncSession = Depends(get_db)): async def audit_stats(request: Request, db: AsyncSession = Depends(get_db)):
"""获取审计日志的统计摘要,包括总数、今日数量和 TOP 排行。
Args:
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含审计统计摘要的响应数据
"""
total_result = await db.execute(select(func.count(AuditLog.id))) total_result = await db.execute(select(func.count(AuditLog.id)))
total = total_result.scalar() or 0 total = total_result.scalar() or 0
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) # 今日零点
today_result = await db.execute( today_result = await db.execute(
select(func.count(AuditLog.id)).where(AuditLog.created_at >= today_start) select(func.count(AuditLog.id)).where(AuditLog.created_at >= today_start)
) )
today = today_result.scalar() or 0 today = today_result.scalar() or 0
# 最常见的操作类型 TOP 10
top_result = await db.execute( top_result = await db.execute(
select(AuditLog.action, func.count(AuditLog.id)) select(AuditLog.action, func.count(AuditLog.id))
.group_by(AuditLog.action) .group_by(AuditLog.action)
@ -91,6 +133,7 @@ async def audit_stats(request: Request, db: AsyncSession = Depends(get_db)):
) )
top_actions = [{"action": r[0], "count": r[1]} for r in top_result.all()] top_actions = [{"action": r[0], "count": r[1]} for r in top_result.all()]
# 最常见的资源类型 TOP 10
top_resources = await db.execute( top_resources = await db.execute(
select(AuditLog.resource, func.count(AuditLog.id)) select(AuditLog.resource, func.count(AuditLog.id))
.group_by(AuditLog.resource) .group_by(AuditLog.resource)
@ -117,6 +160,17 @@ async def export_logs(
date_to: datetime | None = Query(None), date_to: datetime | None = Query(None),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""导出审计日志为 CSV 文件。
Args:
request: HTTP 请求对象
date_from: 可选的起始时间筛选条件
date_to: 可选的结束时间筛选条件
db: 异步数据库会话
Returns:
StreamingResponse: CSV 格式的文件流响应
"""
conditions = [] conditions = []
if date_from: if date_from:
conditions.append(AuditLog.created_at >= date_from) conditions.append(AuditLog.created_at >= date_from)
@ -126,13 +180,13 @@ async def export_logs(
q = select(AuditLog).order_by(AuditLog.created_at.desc()) q = select(AuditLog).order_by(AuditLog.created_at.desc())
if conditions: if conditions:
q = q.where(and_(*conditions)) q = q.where(and_(*conditions))
q = q.limit(10000) q = q.limit(10000) # 最多导出 10000 条
result = await db.execute(q) result = await db.execute(q)
logs = result.scalars().all() logs = result.scalars().all()
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerow(["ID", "操作时间", "操作人ID", "操作", "资源", "资源ID", "详情", "IP地址"]) writer.writerow(["ID", "操作时间", "操作人ID", "操作", "资源", "资源ID", "详情", "IP地址"]) # CSV 表头
for log in logs: for log in logs:
writer.writerow([ writer.writerow([
str(log.id), str(log.id),

4
backend/modules/auth/__init__.py

@ -0,0 +1,4 @@
"""认证模块。
提供用户登录JWT 令牌生成企业微信 OAuth 授权个人信息修改和密码修改等认证功能
"""

123
backend/modules/auth/router.py

@ -1,3 +1,8 @@
"""认证模块路由。
提供用户登录JWT 令牌生成企业微信 OAuth 授权个人信息查询/修改和密码修改等功能
支持基于用户名密码和企业微信 OAuth 两种认证方式
"""
import uuid import uuid
import secrets import secrets
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -12,17 +17,36 @@ from models import User, UserRole, Role, RolePermission, Permission
from schemas import LoginRequest, TokenResponse, UserOut, RoleOut from schemas import LoginRequest, TokenResponse, UserOut, RoleOut
from config import settings from config import settings
# OAuth 状态存储,用于防止 CSRF 攻击
_oauth_states: dict[str, float] = {} _oauth_states: dict[str, float] = {}
_OAUTH_STATE_TTL = 600 _OAUTH_STATE_TTL = 600 # OAuth 状态有效期(秒)
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
"""对密码进行 bcrypt 哈希加密。
Args:
password: 明文密码字符串
Returns:
str: bcrypt 加密后的哈希字符串
"""
return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"]) # 认证路由前缀
async def get_permission_codes(db: AsyncSession, role_ids: list[uuid.UUID]) -> list[str]: async def get_permission_codes(db: AsyncSession, role_ids: list[uuid.UUID]) -> list[str]:
"""根据角色 ID 列表获取所有关联的权限代码。
Args:
db: 异步数据库会话
role_ids: 角色 ID 列表
Returns:
list[str]: 去重后的权限代码列表
"""
result = await db.execute( result = await db.execute(
select(Permission.code) select(Permission.code)
.join(RolePermission) .join(RolePermission)
@ -32,12 +56,24 @@ async def get_permission_codes(db: AsyncSession, role_ids: list[uuid.UUID]) -> l
async def get_user_roles(db: AsyncSession, user_id: uuid.UUID) -> list[RoleOut]: async def get_user_roles(db: AsyncSession, user_id: uuid.UUID) -> list[RoleOut]:
"""获取用户的所有角色及其权限信息。
查询用户关联的所有角色并为每个角色查询其关联的权限代码
Args:
db: 异步数据库会话
user_id: 用户唯一标识 ID
Returns:
list[RoleOut]: 角色信息列表每个角色包含名称代码描述和权限代码列表
"""
result = await db.execute( result = await db.execute(
select(Role).join(UserRole).where(UserRole.user_id == user_id) select(Role).join(UserRole).where(UserRole.user_id == user_id)
) )
roles = result.scalars().all() roles = result.scalars().all()
out = [] out = []
for role in roles: for role in roles:
# 查询该角色关联的所有权限代码
rp_result = await db.execute( rp_result = await db.execute(
select(Permission.code) select(Permission.code)
.join(RolePermission) .join(RolePermission)
@ -58,6 +94,20 @@ async def get_user_roles(db: AsyncSession, user_id: uuid.UUID) -> list[RoleOut]:
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)): async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
"""用户登录接口。
验证用户名和密码验证通过后生成 JWT 令牌并返回用户信息
Args:
req: 登录请求体包含用户名和密码
db: 异步数据库会话
Returns:
TokenResponse: 包含访问令牌和用户信息的响应
Raises:
HTTPException: 用户名或密码错误账户被禁用时抛出异常
"""
result = await db.execute(select(User).where(User.username == req.username)) result = await db.execute(select(User).where(User.username == req.username))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user or not bcrypt.checkpw(req.password.encode('utf-8'), user.password_hash.encode('utf-8')): if not user or not bcrypt.checkpw(req.password.encode('utf-8'), user.password_hash.encode('utf-8')):
@ -65,8 +115,9 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
if user.status != "active": if user.status != "active":
raise HTTPException(403, "账户已被禁用") raise HTTPException(403, "账户已被禁用")
roles = await get_user_roles(db, user.id) roles = await get_user_roles(db, user.id) # 获取用户角色信息
# 生成 JWT 令牌
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES) expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
token = jwt.encode( token = jwt.encode(
{"sub": str(user.id), "username": user.username, "exp": expire}, {"sub": str(user.id), "username": user.username, "exp": expire},
@ -88,12 +139,24 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
@router.get("/me", response_model=UserOut) @router.get("/me", response_model=UserOut)
async def get_me(request: Request, db: AsyncSession = Depends(get_db)): async def get_me(request: Request, db: AsyncSession = Depends(get_db)):
user_ctx = request.state.user """获取当前登录用户的详细信息。
Args:
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
UserOut: 当前用户信息包含角色列表
Raises:
HTTPException: 用户不存在时抛出异常
"""
user_ctx = request.state.user # 从请求状态中获取当前用户上下文
result = await db.execute(select(User).where(User.id == user_ctx["id"])) result = await db.execute(select(User).where(User.id == user_ctx["id"]))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
raise HTTPException(404, "用户不存在") raise HTTPException(404, "用户不存在")
roles = await get_user_roles(db, user.id) roles = await get_user_roles(db, user.id) # 获取用户角色信息
return UserOut( return UserOut(
id=user.id, username=user.username, display_name=user.display_name, id=user.id, username=user.username, display_name=user.display_name,
email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id, email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id,
@ -105,17 +168,29 @@ async def get_me(request: Request, db: AsyncSession = Depends(get_db)):
@router.get("/wecom/oauth-url") @router.get("/wecom/oauth-url")
async def get_wecom_oauth_url(request: Request): async def get_wecom_oauth_url(request: Request):
"""获取企业微信 OAuth 授权 URL。
生成用于企业微信网页授权登录的 URL包含随机 state 参数用于防 CSRF 攻击
Args:
request: HTTP 请求对象用于获取基础 URL 构建回调地址
Returns:
dict: 包含 OAuth 授权 URL state 参数的响应数据
"""
corp_id = settings.WECOM_CORP_ID or "" corp_id = settings.WECOM_CORP_ID or ""
if not corp_id: if not corp_id:
return {"code": 400, "message": "请先配置 WECOM_CORP_ID"} return {"code": 400, "message": "请先配置 WECOM_CORP_ID"}
base_url = str(request.base_url).rstrip("/") base_url = str(request.base_url).rstrip("/")
redirect_uri = f"{base_url}/api/auth/wecom/callback" redirect_uri = f"{base_url}/api/auth/wecom/callback" # OAuth 回调地址
state = secrets.token_urlsafe(32) state = secrets.token_urlsafe(32) # 生成随机 state 用于防 CSRF
import time import time
_oauth_states[state] = time.time() _oauth_states[state] = time.time() # 存储 state 及其创建时间
# 清理过期的 state
expired = [k for k, v in _oauth_states.items() if time.time() - v > _OAUTH_STATE_TTL] expired = [k for k, v in _oauth_states.items() if time.time() - v > _OAUTH_STATE_TTL]
for k in expired: for k in expired:
del _oauth_states[k] del _oauth_states[k]
# 拼接企业微信 OAuth 授权 URL
url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={corp_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_base&state={state}#wechat_redirect" url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={corp_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_base&state={state}#wechat_redirect"
return {"code": 200, "data": {"url": url, "state": state}} return {"code": 200, "data": {"url": url, "state": state}}
@ -126,6 +201,21 @@ async def update_me(
payload: dict, payload: dict,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""更新当前用户的个人信息。
支持修改显示名称邮箱和手机号
Args:
request: HTTP 请求对象包含当前用户上下文
payload: 更新字段字典可包含 display_nameemailphone
db: 异步数据库会话
Returns:
UserOut: 更新后的用户信息
Raises:
HTTPException: 用户不存在时抛出异常
"""
user_ctx = request.state.user user_ctx = request.state.user
result = await db.execute(select(User).where(User.id == user_ctx["id"])) result = await db.execute(select(User).where(User.id == user_ctx["id"]))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@ -156,6 +246,21 @@ async def change_password(
payload: dict, payload: dict,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""修改当前用户的登录密码。
需要验证旧密码正确性新密码至少 6
Args:
request: HTTP 请求对象包含当前用户上下文
payload: 包含 old_password new_password 的字典
db: 异步数据库会话
Returns:
dict: 修改成功的响应数据
Raises:
HTTPException: 用户不存在旧密码错误或新密码长度不足时抛出异常
"""
user_ctx = request.state.user user_ctx = request.state.user
result = await db.execute(select(User).where(User.id == user_ctx["id"])) result = await db.execute(select(User).where(User.id == user_ctx["id"]))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@ -169,6 +274,6 @@ async def change_password(
if len(new_pw) < 6: if len(new_pw) < 6:
raise HTTPException(400, "新密码至少6位") raise HTTPException(400, "新密码至少6位")
user.password_hash = hash_password(new_pw) user.password_hash = hash_password(new_pw) # 更新为新密码哈希
await db.commit() await db.commit()
return {"code": 200, "message": "密码已修改"} return {"code": 200, "message": "密码已修改"}

1
backend/modules/chat/__init__.py

@ -0,0 +1 @@
"""聊天模块。"""

52
backend/modules/chat/router.py

@ -1,3 +1,8 @@
"""对话模块路由。
提供基于流程的聊天功能支持 WebSocket 实时通信和 HTTP 消息发送
可以执行已发布的 AI 流程并将结果返回给客户端
"""
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, Request
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -12,7 +17,15 @@ router = APIRouter(prefix="/api/chat", tags=["chat"])
@router.websocket("/ws") @router.websocket("/ws")
async def chat_websocket(websocket: WebSocket): async def chat_websocket(websocket: WebSocket):
user_id = websocket.query_params.get("user_id", "anonymous") """WebSocket 聊天连接处理器。
接受客户端的 WebSocket 连接并将其注册到 WebSocket 管理器中
持续接收消息并回显给发送者断开时自动清理连接
Args:
websocket: WebSocket 连接对象
"""
user_id = websocket.query_params.get("user_id", "anonymous") # 从查询参数获取用户 ID
await ws_manager.connect(websocket, user_id) await ws_manager.connect(websocket, user_id)
try: try:
while True: while True:
@ -29,6 +42,23 @@ async def chat_message(
payload: dict, payload: dict,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""向指定的已发布流程发送消息并获取 AI 回复。
加载流程定义后使用 FlowEngine 执行将用户消息作为输入
返回流程执行结果
Args:
flow_id: 流程定义的唯一标识 ID
request: HTTP 请求对象用于获取当前用户信息
payload: 请求体包含 message query 字段作为输入文本
db: 异步数据库会话
Returns:
dict: 包含 AI 回复和节点执行结果的响应数据
Raises:
HTTPException: 流程不存在未发布或执行失败时抛出异常
"""
try: try:
import uuid as _uuid import uuid as _uuid
fid = _uuid.UUID(flow_id) fid = _uuid.UUID(flow_id)
@ -40,7 +70,8 @@ async def chat_message(
if not flow or flow.status != "published": if not flow or flow.status != "published":
raise HTTPException(404, "流不存在或未发布") raise HTTPException(404, "流不存在或未发布")
definition = flow.definition_json definition = flow.definition_json # 流程定义 JSON
# 如果有已发布版本,优先使用版本的定义
published_version_id = getattr(flow, 'published_version_id', None) published_version_id = getattr(flow, 'published_version_id', None)
if published_version_id: if published_version_id:
ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == published_version_id)) ver_result = await db.execute(select(FlowVersion).where(FlowVersion.id == published_version_id))
@ -50,7 +81,7 @@ async def chat_message(
definition = json.loads(json.dumps(published.definition_json)) definition = json.loads(json.dumps(published.definition_json))
user_ctx = request.state.user user_ctx = request.state.user
input_text = payload.get("message", payload.get("query", "")) input_text = payload.get("message", payload.get("query", "")) # 用户输入文本
if not input_text: if not input_text:
raise HTTPException(400, "请输入消息内容") raise HTTPException(400, "请输入消息内容")
@ -59,9 +90,9 @@ async def chat_message(
context = { context = {
"user_id": user_ctx.get("id", "web_user"), "user_id": user_ctx.get("id", "web_user"),
"username": user_ctx.get("username", "网页访客"), "username": user_ctx.get("username", "网页访客"),
"trigger_data": {"channel": "web_chat"}, "trigger_data": {"channel": "web_chat"}, # 触发渠道为网页聊天
"_node_results": {}, "_node_results": {}, # 存储各节点的执行结果
"session_id": payload.get("session_id", str(uuid.uuid4())), "session_id": payload.get("session_id", str(_uuid.uuid4())),
} }
try: try:
@ -81,6 +112,15 @@ async def chat_message(
@router.get("/flows") @router.get("/flows")
async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)): async def list_chat_flows(request: Request, db: AsyncSession = Depends(get_db)):
"""列出所有已发布的流程,供聊天界面选择使用。
Args:
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含已发布流程列表的响应数据
"""
result = await db.execute( result = await db.execute(
select(FlowDefinition).where(FlowDefinition.status == "published") select(FlowDefinition).where(FlowDefinition.status == "published")
) )

5
backend/modules/custom_tool/__init__.py

@ -1,3 +1,8 @@
"""自定义工具模块。
提供自定义工具的创建管理导入和执行功能
支持从 OpenAPI 规范自动导入工具定义以及手动创建自定义 HTTP 工具
"""
from .router import router from .router import router
__all__ = ["router"] __all__ = ["router"]

73
backend/modules/custom_tool/executor.py

@ -1,33 +1,69 @@
"""自定义工具执行器。
提供执行自定义 HTTP 工具的功能支持多种认证方式API KeyBearer Token HTTP 方法
"""
import httpx import httpx
import json import json
class CustomToolExecutor: class CustomToolExecutor:
"""自定义工具执行器类。
根据工具定义端点 URLHTTP 方法认证配置等执行 HTTP 请求
Attributes:
endpoint_url: API 端点的基础 URL
method: HTTP 请求方法
path: API 路径
headers: 请求头字典
auth_type: 认证类型none/api_key/bearer
auth_config: 认证配置信息
timeout: 请求超时时间
"""
def __init__(self, tool_def: dict): def __init__(self, tool_def: dict):
self.endpoint_url = tool_def.get("endpoint_url", "") """初始化工具执行器。
self.method = tool_def.get("method", "GET")
self.path = tool_def.get("path", "") Args:
self.headers = dict(tool_def.get("headers_json", {})) tool_def: 工具定义字典包含 endpoint_urlmethodpathheaders_jsonauth_typeauth_configtimeout 等字段
self.auth_type = tool_def.get("auth_type", "none") """
self.auth_config = dict(tool_def.get("auth_config", {})) self.endpoint_url = tool_def.get("endpoint_url", "") # API 基础 URL
self.timeout = int(tool_def.get("timeout", 30)) self.method = tool_def.get("method", "GET") # HTTP 请求方法,默认为 GET
self.path = tool_def.get("path", "") # API 路径
self.headers = dict(tool_def.get("headers_json", {})) # 请求头字典
self.auth_type = tool_def.get("auth_type", "none") # 认证类型,默认为无认证
self.auth_config = dict(tool_def.get("auth_config", {})) # 认证配置信息
self.timeout = int(tool_def.get("timeout", 30)) # 请求超时时间(秒),默认 30 秒
async def execute(self, params: dict) -> str: async def execute(self, params: dict) -> str:
"""执行自定义工具请求。
根据工具定义构造完整的 URL应用认证信息发送 HTTP 请求并返回响应结果
Args:
params: 请求参数GET 请求作为查询参数其他方法作为 JSON 请求体
Returns:
str: 响应内容的字符串表示最大长度 4000 字符优先返回 JSON 格式否则返回纯文本
"""
# 构造完整 URL(确保基础 URL 和路径之间只有一个斜杠)
url = f"{self.endpoint_url.rstrip('/')}/{self.path.lstrip('/')}" url = f"{self.endpoint_url.rstrip('/')}/{self.path.lstrip('/')}"
headers = dict(self.headers) headers = dict(self.headers) # 复制请求头
req_params = dict(params) req_params = dict(params) # 复制请求参数
# 根据认证类型添加认证信息
if self.auth_type == "api_key": if self.auth_type == "api_key":
key = self.auth_config.get("key", "") key = self.auth_config.get("key", "") # API Key
loc = self.auth_config.get("location", "header") loc = self.auth_config.get("location", "header") # 认证位置(header/query)
name = self.auth_config.get("name", "X-API-Key") name = self.auth_config.get("name", "X-API-Key") # 认证参数名
if loc == "header": if loc == "header":
headers[name] = key headers[name] = key # 添加到请求头
else: else:
req_params[name] = key req_params[name] = key # 添加到查询参数
elif self.auth_type == "bearer": elif self.auth_type == "bearer":
headers["Authorization"] = f"Bearer {self.auth_config.get('token', '')}" headers["Authorization"] = f"Bearer {self.auth_config.get('token', '')}" # Bearer Token 认证
timeout = httpx.Timeout(self.timeout) timeout = httpx.Timeout(self.timeout) # 创建超时配置
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
if self.method == "GET": if self.method == "GET":
resp = await client.get(url, params=req_params, headers=headers) resp = await client.get(url, params=req_params, headers=headers)
@ -36,8 +72,9 @@ class CustomToolExecutor:
self.method, url, json=req_params, headers=headers self.method, url, json=req_params, headers=headers
) )
# 尝试解析 JSON 响应,否则返回纯文本
try: try:
data = resp.json() data = resp.json()
return json.dumps(data, ensure_ascii=False, indent=2)[:4000] return json.dumps(data, ensure_ascii=False, indent=2)[:4000] # 格式化 JSON,限制最大长度
except Exception: except Exception:
return resp.text[:4000] return resp.text[:4000] # 返回纯文本响应,限制最大长度

81
backend/modules/custom_tool/parser.py

@ -1,21 +1,48 @@
"""OpenAPI 规范解析器。
提供从 OpenAPI/Swagger 规范文档中自动解析 API 端点并转换为自定义工具定义的功能
"""
import json import json
from typing import Any from typing import Any
class OpenAPIParser: class OpenAPIParser:
"""OpenAPI 规范解析器类。
解析 OpenAPI 3.0 规范文档提取其中的 API 端点信息并转换为自定义工具定义
Attributes:
spec: OpenAPI 规范文档的字典表示
base_url: API 服务的基础 URL
"""
def __init__(self, spec: dict): def __init__(self, spec: dict):
self.spec = spec """初始化 OpenAPI 解析器。
self.base_url = ""
Args:
spec: OpenAPI 规范文档的字典表示包含 serverspaths 等字段
"""
self.spec = spec # OpenAPI 规范文档内容
self.base_url = "" # API 基础 URL
servers = spec.get("servers", [{}]) servers = spec.get("servers", [{}])
if servers and isinstance(servers, list): if servers and isinstance(servers, list):
self.base_url = servers[0].get("url", "") self.base_url = servers[0].get("url", "") # 获取第一个服务器 URL 作为基础地址
def parse_tools(self) -> list[dict]: def parse_tools(self) -> list[dict]:
"""解析 OpenAPI 规范中的所有 API 端点。
遍历 paths 中的所有 HTTP 方法将每个端点转换为工具定义
Returns:
list[dict]: 工具定义列表每个工具包含 namedescriptionparameterspathmethod 等信息
"""
tools = [] tools = []
paths = self.spec.get("paths", {}) paths = self.spec.get("paths", {}) # 获取所有 API 路径
for path, methods in paths.items(): for path, methods in paths.items():
if not isinstance(methods, dict): if not isinstance(methods, dict):
continue continue
for method, operation in methods.items(): for method, operation in methods.items():
# 只处理标准的 HTTP 方法
if method in ("get", "post", "put", "delete", "patch") and isinstance(operation, dict): if method in ("get", "post", "put", "delete", "patch") and isinstance(operation, dict):
tool = self._parse_endpoint(path, method, operation) tool = self._parse_endpoint(path, method, operation)
if tool: if tool:
@ -23,47 +50,71 @@ class OpenAPIParser:
return tools return tools
def _parse_endpoint(self, path: str, method: str, operation: dict) -> dict | None: def _parse_endpoint(self, path: str, method: str, operation: dict) -> dict | None:
"""解析单个 API 端点的详细信息。
Args:
path: API 路径 "/users/{id}"
method: HTTP 方法 "get""post"
operation: 端点的操作定义包含 operationIdsummaryparameters
Returns:
dict | None: 工具定义字典包含名称描述参数等信息如果解析失败返回 None
"""
# 生成工具名称:优先使用 operationId,否则从路径生成
op_id = operation.get("operationId", "") op_id = operation.get("operationId", "")
if not op_id: if not op_id:
op_id = f"{method}_{path.replace('/', '_').strip('_')}" op_id = f"{method}_{path.replace('/', '_').strip('_')}"
# 生成工具描述:优先使用 summary,其次 description,最后使用方法和路径
description = operation.get("summary") or operation.get("description") or f"{method.upper()} {path}" description = operation.get("summary") or operation.get("description") or f"{method.upper()} {path}"
properties = self._parse_parameters(operation) properties = self._parse_parameters(operation) # 解析参数
required = [] required = []
for param in operation.get("parameters", []): for param in operation.get("parameters", []):
if isinstance(param, dict) and param.get("required"): if isinstance(param, dict) and param.get("required"):
required.append(param["name"]) required.append(param["name"]) # 收集必填参数名
return { return {
"name": op_id, "name": op_id, # 工具名称
"description": description, "description": description, # 工具描述
"parameters": { "parameters": { # 参数 Schema
"type": "object", "type": "object",
"properties": properties, "properties": properties,
"required": required, "required": required,
}, },
"path": path, "path": path, # API 路径
"method": method.upper(), "method": method.upper(), # HTTP 方法(大写)
} }
def _parse_parameters(self, operation: dict) -> dict[str, Any]: def _parse_parameters(self, operation: dict) -> dict[str, Any]:
"""解析 API 端点的参数定义。
包括查询参数路径参数请求头参数和请求体参数
Args:
operation: 端点的操作定义
Returns:
dict[str, Any]: 参数属性字典键为参数名值为参数类型和描述
"""
props = {} props = {}
# 解析 query/path/header 参数
for param in operation.get("parameters", []): for param in operation.get("parameters", []):
if not isinstance(param, dict): if not isinstance(param, dict):
continue continue
pname = param.get("name", "") pname = param.get("name", "")
if not pname: if not pname:
continue continue
schema = param.get("schema", {}) schema = param.get("schema", {}) # 参数的 Schema 定义
if not isinstance(schema, dict): if not isinstance(schema, dict):
schema = {} schema = {}
props[pname] = { props[pname] = {
"type": schema.get("type", "string"), "type": schema.get("type", "string"), # 参数类型,默认为 string
"description": param.get("description", ""), "description": param.get("description", ""), # 参数描述
} }
if "enum" in schema: if "enum" in schema: # 如果有限定值列表
props[pname]["enum"] = schema["enum"] props[pname]["enum"] = schema["enum"]
# 解析请求体(requestBody)中的 JSON Schema 属性
body = ( body = (
operation.get("requestBody", {}) operation.get("requestBody", {})
.get("content", {}) .get("content", {})

114
backend/modules/custom_tool/router.py

@ -1,3 +1,8 @@
"""自定义工具模块路由。
提供自定义工具的创建管理导入和执行功能
支持从 OpenAPI 规范自动导入工具定义以及手动创建自定义 HTTP 工具
"""
import uuid import uuid
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
@ -12,19 +17,34 @@ from modules.flow_engine.engine import ToolNodeAgent
from dependencies import get_current_user from dependencies import get_current_user
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"]) router = APIRouter(prefix="/api/custom-tools", tags=["custom_tools"])
@router.post("/import-openapi") @router.post("/import-openapi")
async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)): async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncSession = Depends(get_db)):
"""从 OpenAPI 规范 URL 导入工具定义。
自动下载 OpenAPI 文档解析其中的 API 端点并创建对应的自定义工具
Args:
req: OpenAPI 导入请求体包含 openapi_url 和可选的 base_url_override
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含导入成功工具列表的响应数据
Raises:
HTTPException: 获取 OpenAPI 文档失败或解析不到工具时抛出异常
"""
user_ctx = request.state.user user_ctx = request.state.user
try: try:
async with httpx.AsyncClient(timeout=30) as client: async with httpx.AsyncClient(timeout=30) as client:
resp = await client.get(req.openapi_url) resp = await client.get(req.openapi_url)
resp.raise_for_status() resp.raise_for_status()
spec = resp.json() spec = resp.json() # 解析 OpenAPI 规范 JSON
except httpx.HTTPError as e: except httpx.HTTPError as e:
raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}") raise HTTPException(400, f"获取 OpenAPI 文档失败: {e}")
except ValueError: except ValueError:
@ -35,7 +55,7 @@ async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncS
if not tools: if not tools:
raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具") raise HTTPException(400, "未能从 OpenAPI 文档中解析出任何工具")
base_url = req.base_url_override or parser.base_url base_url = req.base_url_override or parser.base_url # 优先使用用户指定的基础 URL
if not base_url: if not base_url:
raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override") raise HTTPException(400, "未能确定 API 基础 URL,请提供 base_url_override")
@ -45,7 +65,7 @@ async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncS
select(CustomTool).where(CustomTool.name == t["name"]) select(CustomTool).where(CustomTool.name == t["name"])
) )
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
continue continue # 跳过已存在的同名工具
tool = CustomTool( tool = CustomTool(
name=t["name"], name=t["name"],
@ -79,6 +99,19 @@ async def import_openapi(req: OpenAPIImportRequest, request: Request, db: AsyncS
@router.post("/", response_model=CustomToolOut) @router.post("/", response_model=CustomToolOut)
async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)): async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncSession = Depends(get_db)):
"""创建新的自定义工具,支持手动创建或从 OpenAPI 导入。
Args:
req: 自定义工具创建请求体
request: HTTP 请求对象
db: 异步数据库会话
Returns:
CustomToolOut: 创建后的自定义工具响应
Raises:
HTTPException: 获取 OpenAPI 文档失败或创建失败时抛出异常
"""
user_ctx = request.state.user user_ctx = request.state.user
user_id = uuid.UUID(user_ctx["id"]) user_id = uuid.UUID(user_ctx["id"])
@ -121,6 +154,7 @@ async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncS
await db.flush() await db.flush()
return created_tool return created_tool
# 手动创建模式
schema_json = req.tool_schema or {} schema_json = req.tool_schema or {}
if not schema_json and req.endpoint_url: if not schema_json and req.endpoint_url:
schema_json = { schema_json = {
@ -161,6 +195,14 @@ async def create_custom_tool(req: CustomToolCreate, request: Request, db: AsyncS
@router.get("/", response_model=list[CustomToolOut]) @router.get("/", response_model=list[CustomToolOut])
async def list_custom_tools(db: AsyncSession = Depends(get_db)): async def list_custom_tools(db: AsyncSession = Depends(get_db)):
"""列出所有处于活跃状态的自定义工具。
Args:
db: 异步数据库会话
Returns:
list[CustomToolOut]: 活跃自定义工具列表
"""
result = await db.execute( result = await db.execute(
select(CustomTool).where(CustomTool.is_active == True).order_by(CustomTool.updated_at.desc()) select(CustomTool).where(CustomTool.is_active == True).order_by(CustomTool.updated_at.desc())
) )
@ -169,6 +211,18 @@ async def list_custom_tools(db: AsyncSession = Depends(get_db)):
@router.get("/{tool_id}", response_model=CustomToolOut) @router.get("/{tool_id}", response_model=CustomToolOut)
async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)): async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
"""获取指定自定义工具的详细信息。
Args:
tool_id: 自定义工具唯一标识 ID
db: 异步数据库会话
Returns:
CustomToolOut: 自定义工具详细信息
Raises:
HTTPException: 工具不存在时抛出异常
"""
tool = await db.get(CustomTool, tool_id) tool = await db.get(CustomTool, tool_id)
if not tool: if not tool:
raise HTTPException(404, "工具不存在") raise HTTPException(404, "工具不存在")
@ -177,6 +231,19 @@ async def get_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)
@router.put("/{tool_id}", response_model=CustomToolOut) @router.put("/{tool_id}", response_model=CustomToolOut)
async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)): async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: AsyncSession = Depends(get_db)):
"""更新自定义工具的配置信息。
Args:
tool_id: 自定义工具唯一标识 ID
req: 自定义工具更新请求体
db: 异步数据库会话
Returns:
CustomToolOut: 更新后的自定义工具响应
Raises:
HTTPException: 工具不存在时抛出异常
"""
tool = await db.get(CustomTool, tool_id) tool = await db.get(CustomTool, tool_id)
if not tool: if not tool:
raise HTTPException(404, "工具不存在") raise HTTPException(404, "工具不存在")
@ -206,16 +273,45 @@ async def update_custom_tool(tool_id: uuid.UUID, req: CustomToolUpdate, db: Asyn
@router.delete("/{tool_id}") @router.delete("/{tool_id}")
async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)): async def delete_custom_tool(tool_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
"""删除(停用)自定义工具。
采用软删除方式将工具标记为非活跃状态而非真正删除
Args:
tool_id: 自定义工具唯一标识 ID
db: 异步数据库会话
Returns:
dict: 操作结果响应
Raises:
HTTPException: 工具不存在时抛出异常
"""
tool = await db.get(CustomTool, tool_id) tool = await db.get(CustomTool, tool_id)
if not tool: if not tool:
raise HTTPException(404, "工具不存在") raise HTTPException(404, "工具不存在")
tool.is_active = False tool.is_active = False # 软删除:标记为非活跃
await db.flush() await db.flush()
return {"code": 200, "message": "工具已停用"} return {"code": 200, "message": "工具已停用"}
@router.post("/{tool_id}/test") @router.post("/{tool_id}/test")
async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)): async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSession = Depends(get_db)):
"""测试执行自定义工具。
使用自定义工具的配置信息创建执行器并执行返回执行结果
Args:
tool_id: 自定义工具唯一标识 ID
params: 测试参数传递给工具执行的参数体
db: 异步数据库会话
Returns:
dict: 包含执行结果的响应数据
Raises:
HTTPException: 工具不存在或执行失败时抛出异常
"""
tool = await db.get(CustomTool, tool_id) tool = await db.get(CustomTool, tool_id)
if not tool: if not tool:
raise HTTPException(404, "工具不存在") raise HTTPException(404, "工具不存在")
@ -239,6 +335,14 @@ async def test_custom_tool(tool_id: uuid.UUID, params: dict = None, db: AsyncSes
@router.get("/schemas/all") @router.get("/schemas/all")
async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)): async def get_all_tool_schemas(db: AsyncSession = Depends(get_db)):
"""获取所有活跃自定义工具的参数 Schema。
Args:
db: 异步数据库会话
Returns:
dict: 包含所有工具 Schema 的响应数据格式为 {工具名: schema}
"""
result = await db.execute( result = await db.execute(
select(CustomTool).where(CustomTool.is_active == True) select(CustomTool).where(CustomTool.is_active == True)
) )

83
backend/modules/document/router.py

@ -1,3 +1,8 @@
"""文档处理模块路由。
提供文档上传解析格式修正和删除功能
支持多种文档格式文本PDFWordExcel 的处理
"""
import os import os
import uuid import uuid
import shutil import shutil
@ -20,7 +25,22 @@ async def upload_document(
request: Request = None, request: Request = None,
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
max_size = settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024 """上传文档文件到服务器。
检查文件大小限制后保存到上传目录
Args:
file: 上传的文件对象
request: HTTP 请求对象
user: 当前登录用户信息
Returns:
DocumentUploadOut: 包含文件 ID文件名大小等信息的响应
Raises:
HTTPException: 文件大小超过限制时抛出异常
"""
max_size = settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024 # 最大允许上传大小(字节)
content = await file.read() content = await file.read()
if len(content) > max_size: if len(content) > max_size:
raise HTTPException(400, f"文件大小超过限制 ({settings.MAX_UPLOAD_SIZE_MB}MB)") raise HTTPException(400, f"文件大小超过限制 ({settings.MAX_UPLOAD_SIZE_MB}MB)")
@ -28,8 +48,8 @@ async def upload_document(
file_id = uuid.uuid4() file_id = uuid.uuid4()
os.makedirs(settings.UPLOAD_DIR, exist_ok=True) os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
ext = os.path.splitext(file.filename or "unknown")[1] ext = os.path.splitext(file.filename or "unknown")[1] # 获取文件扩展名
stored_name = f"{file_id}{ext}" stored_name = f"{file_id}{ext}" # 使用 UUID 作为存储文件名
file_path = os.path.join(settings.UPLOAD_DIR, stored_name) file_path = os.path.join(settings.UPLOAD_DIR, stored_name)
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
@ -51,6 +71,22 @@ async def parse_document(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""解析已上传的文档文件,提取文本内容。
根据文件扩展名选择合适的解析方式支持纯文本PDFWordExcel 等格式
Args:
file_id: 文件唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
DocumentParseResult: 包含文件内容和元数据的解析结果
Raises:
HTTPException: 文件不存在时抛出异常
"""
ext_map = {".txt", ".md", ".py", ".js", ".ts", ".json", ".xml", ".yaml", ".yml", ".csv", ".html", ".css", ".java", ".go", ".rs"} ext_map = {".txt", ".md", ".py", ".js", ".ts", ".json", ".xml", ".yaml", ".yml", ".csv", ".html", ".css", ".java", ".go", ".rs"}
os.makedirs(settings.UPLOAD_DIR, exist_ok=True) os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
@ -65,7 +101,7 @@ async def parse_document(
if not found_file: if not found_file:
raise HTTPException(404, "文件不存在") raise HTTPException(404, "文件不存在")
ext = os.path.splitext(found_filename)[1].lower() ext = os.path.splitext(found_filename)[1].lower() # 获取文件扩展名
content = "" content = ""
metadata = {"file_size": os.path.getsize(found_file), "extension": ext} metadata = {"file_size": os.path.getsize(found_file), "extension": ext}
@ -87,6 +123,7 @@ async def parse_document(
content = f"[不支持的文件类型 .{ext}] 文件: {found_filename}" content = f"[不支持的文件类型 .{ext}] 文件: {found_filename}"
metadata["type"] = "unsupported" metadata["type"] = "unsupported"
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="document.parse", action="document.parse",
@ -113,6 +150,20 @@ async def delete_document(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""删除已上传的文档文件。
Args:
file_id: 文件唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 操作结果响应
Raises:
HTTPException: 文件不存在时抛出异常
"""
os.makedirs(settings.UPLOAD_DIR, exist_ok=True) os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
deleted = False deleted = False
@ -125,6 +176,7 @@ async def delete_document(
if not deleted: if not deleted:
raise HTTPException(404, "文件不存在") raise HTTPException(404, "文件不存在")
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="document.delete", action="document.delete",
@ -145,11 +197,25 @@ async def format_document(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""对文档内容进行格式修正。
支持 standardmarkdownjson 三种格式类型
Args:
payload: 请求体包含 content format_type 字段
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含格式化后内容的响应数据
"""
content = payload.get("content", "") content = payload.get("content", "")
format_type = payload.get("format_type", "standard") format_type = payload.get("format_type", "standard")
result = _apply_formatting(content, format_type) result = _apply_formatting(content, format_type)
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="document.format", action="document.format",
@ -165,6 +231,15 @@ async def format_document(
def _apply_formatting(content: str, format_type: str) -> str: def _apply_formatting(content: str, format_type: str) -> str:
"""应用指定的格式规则对文本内容进行格式化。
Args:
content: 待格式化的原始文本内容
format_type: 格式类型支持 standardmarkdownjson
Returns:
str: 格式化后的文本内容
"""
lines = content.splitlines() lines = content.splitlines()
result = [] result = []

19
backend/modules/flow_engine/engine.py

@ -1,3 +1,10 @@
"""流引擎核心模块。
定义 FlowEngine 流程执行引擎及各类节点 Agent包括
- FlowEngine流程图的解析与遍历执行器
- LLMNodeAgent / ToolNodeAgent / MCPNodeAgent 等各类节点处理器
"""
import json import json
import uuid import uuid
import logging import logging
@ -14,6 +21,7 @@ logger = logging.getLogger(__name__)
async def _resolve_model_instance(model_instance_id: str) -> dict | None: async def _resolve_model_instance(model_instance_id: str) -> dict | None:
"""根据模型实例 ID 从数据库解析模型配置(模型名、base_url、api_key)。"""
try: try:
from database import AsyncSessionLocal from database import AsyncSessionLocal
from sqlalchemy import text from sqlalchemy import text
@ -43,24 +51,31 @@ async def _resolve_model_instance(model_instance_id: str) -> dict | None:
class FlowSessionMemory: class FlowSessionMemory:
"""流程会话级短期记忆,存储当前对话轮次的消息列表。"""
def __init__(self, session_id: str = "", user_id: str = ""): def __init__(self, session_id: str = "", user_id: str = ""):
self.session_id = session_id self.session_id = session_id
self.user_id = user_id self.user_id = user_id
self._messages: list[dict] = [] self._messages: list[dict] = []
def get_history(self, limit: int = 10) -> list[dict]: def get_history(self, limit: int = 10) -> list[dict]:
"""获取最近的消息历史。"""
return self._messages[-limit * 2:] return self._messages[-limit * 2:]
def add(self, role: str, content: str): def add(self, role: str, content: str):
"""添加一条消息到历史。"""
self._messages.append({"role": role, "content": content}) self._messages.append({"role": role, "content": content})
def to_list(self) -> list[dict]: def to_list(self) -> list[dict]:
"""返回全部消息列表。"""
return list(self._messages) return list(self._messages)
class FlowEngine: class FlowEngine:
MAX_TOTAL_ITERATIONS = 200 """流程执行引擎,解析流程图定义并按拓扑顺序遍历执行各节点。"""
FLOW_TIMEOUT_SECONDS = 300
MAX_TOTAL_ITERATIONS = 200 # 全局最大迭代次数(防止死循环)
FLOW_TIMEOUT_SECONDS = 300 # 单次流程执行超时时间
def __init__(self, flow_definition: dict): def __init__(self, flow_definition: dict):
self.definition = flow_definition self.definition = flow_definition

8
backend/modules/flow_engine/gateway.py

@ -1,3 +1,11 @@
"""流引擎网关路由。
提供符合 OpenAI Dify 兼容格式的外部 API 网关包括
- /v1/chat-messages: 对话型流程执行支持 blocking / streaming
- /v1/workflows/run: 工作流型流程执行
- /v1/flows/{flow_id}/parameters: 查询流程输入参数
"""
import uuid import uuid
import time import time
import json import json

6
backend/modules/flow_engine/router.py

@ -1,3 +1,9 @@
"""流程定义管理路由。
提供流程的 CRUD发布/下架版本管理执行SSE 流式执行
API Key 管理执行历史查询以及市场模板等功能
"""
import uuid import uuid
import time import time
import json import json

109
backend/modules/mcp_registry/router.py

@ -1,3 +1,8 @@
"""MCP 服务注册模块路由。
提供 Model Context Protocol (MCP) 服务的注册管理测试和工具发现功能
支持 HTTP 传输方式的 MCP 服务接入
"""
import uuid import uuid
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
@ -13,6 +18,16 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"])
@router.get("/servers", response_model=list[MCPServiceOut]) @router.get("/servers", response_model=list[MCPServiceOut])
async def list_servers(request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): async def list_servers(request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
"""列出所有已注册的 MCP 服务。
Args:
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
list[MCPServiceOut]: MCP 服务列表
"""
result = await db.execute( result = await db.execute(
select(MCPService).order_by(MCPService.updated_at.desc()) select(MCPService).order_by(MCPService.updated_at.desc())
) )
@ -21,6 +36,20 @@ async def list_servers(request: Request, db: AsyncSession = Depends(get_db), use
@router.get("/servers/{server_id}", response_model=MCPServiceOut) @router.get("/servers/{server_id}", response_model=MCPServiceOut)
async def get_server(server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): async def get_server(server_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
"""获取指定 MCP 服务的详细信息。
Args:
server_id: MCP 服务唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
MCPServiceOut: MCP 服务详细信息
Raises:
HTTPException: 服务不存在时抛出异常
"""
result = await db.execute(select(MCPService).where(MCPService.id == server_id)) result = await db.execute(select(MCPService).where(MCPService.id == server_id))
server = result.scalar_one_or_none() server = result.scalar_one_or_none()
if not server: if not server:
@ -35,6 +64,20 @@ async def register_server(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""注册新的 MCP 服务。
Args:
req: MCP 服务创建请求体
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
MCPServiceOut: 注册后的 MCP 服务响应
Raises:
HTTPException: 服务名称已存在时抛出异常
"""
existing = await db.execute(select(MCPService).where(MCPService.name == req.name)) existing = await db.execute(select(MCPService).where(MCPService.name == req.name))
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise HTTPException(400, "服务名称已存在") raise HTTPException(400, "服务名称已存在")
@ -50,6 +93,7 @@ async def register_server(
) )
db.add(server) db.add(server)
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="mcp.register", action="mcp.register",
@ -70,6 +114,21 @@ async def update_server(
request: Request, db: AsyncSession = Depends(get_db), request: Request, db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""更新 MCP 服务的配置信息。
Args:
server_id: MCP 服务唯一标识 ID
req: MCP 服务更新请求体
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
MCPServiceOut: 更新后的 MCP 服务响应
Raises:
HTTPException: 服务不存在时抛出异常
"""
result = await db.execute(select(MCPService).where(MCPService.id == server_id)) result = await db.execute(select(MCPService).where(MCPService.id == server_id))
server = result.scalar_one_or_none() server = result.scalar_one_or_none()
if not server: if not server:
@ -86,6 +145,7 @@ async def update_server(
if req.env is not None: if req.env is not None:
server.env = req.env server.env = req.env
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="mcp.update", action="mcp.update",
@ -104,12 +164,27 @@ async def delete_server(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""注销指定的 MCP 服务。
Args:
server_id: MCP 服务唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 操作结果响应
Raises:
HTTPException: 服务不存在时抛出异常
"""
result = await db.execute(select(MCPService).where(MCPService.id == server_id)) result = await db.execute(select(MCPService).where(MCPService.id == server_id))
server = result.scalar_one_or_none() server = result.scalar_one_or_none()
if not server: if not server:
raise HTTPException(404, "MCP服务不存在") raise HTTPException(404, "MCP服务不存在")
await db.delete(server) await db.delete(server)
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="mcp.delete", action="mcp.delete",
@ -129,6 +204,22 @@ async def test_connection(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""测试 MCP 服务的连接状态并发现可用工具。
对于 HTTP 传输的服务访问 /.well-known/mcp 端点进行连接测试
Args:
server_id: MCP 服务唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含连接测试结果和工具列表的响应数据
Raises:
HTTPException: 服务不存在时抛出异常
"""
result = await db.execute(select(MCPService).where(MCPService.id == server_id)) result = await db.execute(select(MCPService).where(MCPService.id == server_id))
server = result.scalar_one_or_none() server = result.scalar_one_or_none()
if not server: if not server:
@ -141,7 +232,7 @@ async def test_connection(
async with httpx.AsyncClient(timeout=10.0) as client: async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(server.url.rstrip("/") + "/.well-known/mcp") resp = await client.get(server.url.rstrip("/") + "/.well-known/mcp")
if resp.status_code == 200: if resp.status_code == 200:
test_results["connectivity"] = True test_results["connectivity"] = True # 连接成功
data = resp.json() data = resp.json()
tools = data.get("tools", []) tools = data.get("tools", [])
test_results["tools_discovered"] = len(tools) test_results["tools_discovered"] = len(tools)
@ -155,6 +246,7 @@ async def test_connection(
test_results["error"] = str(e) test_results["error"] = str(e)
server.status = "error" server.status = "error"
# 记录审计日志
audit = AuditLog( audit = AuditLog(
operator_id=uuid.UUID(user["id"]), operator_id=uuid.UUID(user["id"]),
action="mcp.test", action="mcp.test",
@ -174,6 +266,21 @@ async def discover_tools(
server_id: uuid.UUID, request: Request, server_id: uuid.UUID, request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""从 MCP 服务中发现并注册可用工具。
对于 HTTP 传输的服务调用 /tools/list 端点获取工具列表
Args:
server_id: MCP 服务唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含发现的工具列表和数量的响应数据
Raises:
HTTPException: 服务不存在或工具发现失败时抛出异常
"""
result = await db.execute(select(MCPService).where(MCPService.id == server_id)) result = await db.execute(select(MCPService).where(MCPService.id == server_id))
server = result.scalar_one_or_none() server = result.scalar_one_or_none()
if not server: if not server:

135
backend/modules/memory/manager.py

@ -1,3 +1,15 @@
"""三级记忆系统管理器。
记忆架构分为三个层次L1/L2/L3
- L1原子层从对话中提取关键信息原子用户偏好事件指令
- L2场景层对同类原子进行归纳形成场景摘要
- L3画像层综合所有信息生成用户画像Persona
数据存储
- PG主存储持久化记忆消息原子场景画像
- Redis缓存近期消息缓存对话摘要缓存
"""
import json import json
import asyncio import asyncio
import uuid import uuid
@ -16,6 +28,7 @@ _memory_manager: "MemoryManager | None" = None
def get_memory_manager() -> "MemoryManager": def get_memory_manager() -> "MemoryManager":
"""获取全局 MemoryManager 单例。"""
global _memory_manager global _memory_manager
if _memory_manager is None: if _memory_manager is None:
raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()") raise RuntimeError("MemoryManager 未初始化,请先调用 init_memory_manager()")
@ -23,6 +36,7 @@ def get_memory_manager() -> "MemoryManager":
async def init_memory_manager(db_factory: Callable[[], AsyncSession]): async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
"""初始化记忆管理器,创建 Redis 连接并实例化 MemoryManager。"""
global _memory_manager global _memory_manager
redis = Redis.from_url(settings.REDIS_URL, decode_responses=True) redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
await redis.ping() await redis.ping()
@ -30,19 +44,27 @@ async def init_memory_manager(db_factory: Callable[[], AsyncSession]):
class MemoryManager: class MemoryManager:
MAX_HISTORY = 40 """三级记忆管理器,负责记忆的存储、检索、提取与归纳。"""
REDIS_CACHE_SIZE = 10
REDIS_CACHE_TTL = 300 MAX_HISTORY = 40 # 单次注入的最大历史消息数
SUMMARY_CACHE_KEY = "mem:summary" REDIS_CACHE_SIZE = 10 # Redis 缓存保留的最近消息数
MSG_CACHE_KEY = "mem:cache:msgs" REDIS_CACHE_TTL = 300 # Redis 缓存 TTL(秒)
ATOM_EXTRACT_EVERY = 10 SUMMARY_CACHE_KEY = "mem:summary" # 摘要缓存 Redis key 前缀
SCENE_EXTRACT_EVERY = 50 MSG_CACHE_KEY = "mem:cache:msgs" # 消息缓存 Redis key 前缀
PERSONA_UPDATE_EVERY = 30 ATOM_EXTRACT_EVERY = 10 # 每 N 条消息触发一次 L1 原子提取
SCENE_EXTRACT_EVERY = 50 # 每 N 条原子触发一次 L2 场景提取
PERSONA_UPDATE_EVERY = 30 # 每 N 条消息触发一次 L3 画像更新
def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis): def __init__(self, db_factory: Callable[[], AsyncSession], redis: Redis):
"""初始化记忆管理器。
Args:
db_factory: 异步数据库会话工厂
redis: Redis 异步客户端实例
"""
self.db_factory = db_factory self.db_factory = db_factory
self.redis = redis self.redis = redis
self._extract_tasks: dict[str, asyncio.Task] = {} self._extract_tasks: dict[str, asyncio.Task] = {} # 后台提取任务追踪
async def inject_memory( async def inject_memory(
self, self,
@ -51,6 +73,17 @@ class MemoryManager:
session_id: str, session_id: str,
context: dict, context: dict,
): ):
"""向对话上下文中注入三层记忆信息。
PG/Redis 中获取近期消息摘要原子记忆和画像
合并后注入到 context["_memory_context"] 中供 LLM 使用
Args:
user_id: 用户 ID
flow_id: 流程 ID
session_id: 会话 ID
context: 对话上下文字典会在原地被修改
"""
uid = uuid.UUID(user_id) uid = uuid.UUID(user_id)
fid = uuid.UUID(flow_id) fid = uuid.UUID(flow_id)
sid = uuid.UUID(session_id) sid = uuid.UUID(session_id)
@ -84,6 +117,19 @@ class MemoryManager:
assistant_msg: str, assistant_msg: str,
flow_name: str = "", flow_name: str = "",
): ):
"""记录一次用户-助手对话交换。
将用户消息和助手消息写入 PG同时更新 Redis 缓存
并异步触发 L1/L1/L2/L3 记忆提取任务
Args:
user_id: 用户 ID
flow_id: 流程 ID
session_id: 会话 ID
user_msg: 用户消息内容
assistant_msg: 助手回复内容
flow_name: 流程名称可选用于会话记录
"""
uid = uuid.UUID(user_id) uid = uuid.UUID(user_id)
fid = uuid.UUID(flow_id) fid = uuid.UUID(flow_id)
sid = uuid.UUID(session_id) sid = uuid.UUID(session_id)
@ -167,12 +213,24 @@ class MemoryManager:
async def get_conversation_history( async def get_conversation_history(
self, user_id: str, flow_id: str, session_id: str, limit: int = 20 self, user_id: str, flow_id: str, session_id: str, limit: int = 20
) -> list[dict]: ) -> list[dict]:
"""获取指定会话的对话历史。
Args:
user_id: 用户 ID
flow_id: 流程 ID
session_id: 会话 ID
limit: 返回的最大消息数
Returns:
消息列表每项含 role/content/ts 字段
"""
uid = uuid.UUID(user_id) uid = uuid.UUID(user_id)
sid = uuid.UUID(session_id) sid = uuid.UUID(session_id)
fid = uuid.UUID(flow_id) if flow_id else None fid = uuid.UUID(flow_id) if flow_id else None
return await self._pg_get_recent(uid, fid, sid, limit) return await self._pg_get_recent(uid, fid, sid, limit)
async def delete_session(self, user_id: str, session_id: str): async def delete_session(self, user_id: str, session_id: str):
"""删除指定会话的所有记忆数据(PG + Redis)。"""
uid = uuid.UUID(user_id) uid = uuid.UUID(user_id)
sid = uuid.UUID(session_id) sid = uuid.UUID(session_id)
@ -199,6 +257,14 @@ class MemoryManager:
pass pass
async def list_user_sessions(self, user_id: str) -> list[dict]: async def list_user_sessions(self, user_id: str) -> list[dict]:
"""列出用户的所有记忆会话。
Args:
user_id: 用户 ID
Returns:
会话列表每项含 session_id/flow_id/flow_name/last_active_at
"""
uid = uuid.UUID(user_id) uid = uuid.UUID(user_id)
try: try:
async with self.db_factory() as db: async with self.db_factory() as db:
@ -226,6 +292,7 @@ class MemoryManager:
return [] return []
async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]: async def _pg_get_recent(self, uid: uuid.UUID, fid: uuid.UUID | None, sid: uuid.UUID, limit: int) -> list[dict]:
"""从 PG 查询最近的对话消息。"""
try: try:
async with self.db_factory() as db: async with self.db_factory() as db:
if fid: if fid:
@ -259,6 +326,7 @@ class MemoryManager:
return [] return []
async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None: async def _redis_get_recent(self, uid: uuid.UUID, sid: uuid.UUID) -> list[dict] | None:
"""从 Redis 读取缓存的消息列表。"""
try: try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
raw = await self.redis.get(cache_key) raw = await self.redis.get(cache_key)
@ -269,6 +337,7 @@ class MemoryManager:
return None return None
async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]): async def _redis_set_recent(self, uid: uuid.UUID, sid: uuid.UUID, messages: list[dict]):
"""将消息列表写入 Redis 缓存。"""
try: try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
top = messages[-self.REDIS_CACHE_SIZE:] top = messages[-self.REDIS_CACHE_SIZE:]
@ -277,6 +346,7 @@ class MemoryManager:
pass pass
async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]): async def _redis_append_recent(self, uid: uuid.UUID, sid: uuid.UUID, new_msgs: list[dict]):
"""追加新消息到 Redis 缓存。"""
try: try:
cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}" cache_key = f"{self.MSG_CACHE_KEY}:{uid}:{sid}"
existing = await self._redis_get_recent(uid, sid) or [] existing = await self._redis_get_recent(uid, sid) or []
@ -287,6 +357,7 @@ class MemoryManager:
pass pass
async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str: async def _get_summary(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID) -> str:
"""从 Redis 读取对话摘要。"""
try: try:
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
val = await self.redis.get(summary_key) val = await self.redis.get(summary_key)
@ -295,6 +366,10 @@ class MemoryManager:
return "" return ""
async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID): async def _maybe_summarize(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
"""L1 条件触发:对话摘要生成。
当消息数 >= 30 且尚无摘要时调用 LLM 生成摘要并缓存到 Redis
"""
try: try:
summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}" summary_key = f"{self.SUMMARY_CACHE_KEY}:{uid}:{sid}"
existing = await self.redis.get(summary_key) existing = await self.redis.get(summary_key)
@ -342,6 +417,7 @@ class MemoryManager:
pass pass
async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]: async def _get_relevant_atoms(self, uid: uuid.UUID, fid: uuid.UUID) -> list[dict]:
"""从 PG 查询与用户/流程相关的高优先级原子记忆。"""
try: try:
async with self.db_factory() as db: async with self.db_factory() as db:
result = await db.execute( result = await db.execute(
@ -363,6 +439,7 @@ class MemoryManager:
return [] return []
async def _get_persona(self, uid: uuid.UUID) -> dict: async def _get_persona(self, uid: uuid.UUID) -> dict:
"""从 PG 查询用户画像。"""
try: try:
async with self.db_factory() as db: async with self.db_factory() as db:
result = await db.execute( result = await db.execute(
@ -377,6 +454,10 @@ class MemoryManager:
return {} return {}
async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID): async def _maybe_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID):
"""L1 条件触发:原子记忆提取。
当消息数达到 ATOM_EXTRACT_EVERY 的整数倍时调用 LLM 从对话中提取信息原子
"""
try: try:
task_key = f"extract_atoms:{uid}:{fid}" task_key = f"extract_atoms:{uid}:{fid}"
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
@ -413,6 +494,7 @@ class MemoryManager:
pass pass
async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str): async def _do_extract_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, dialogue: str):
"""执行 L1 原子记忆提取:调用 LLM 从对话中提取结构化记忆原子。"""
try: try:
prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。 prompt = f"""请从以下对话中提取关键的结构化记忆原子。每个原子是一个独立的、可检索的事实或信息片段。
@ -460,6 +542,12 @@ class MemoryManager:
logger.warning(f"L1原子提取失败: {e}") logger.warning(f"L1原子提取失败: {e}")
async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]): async def _dedup_and_store_atoms(self, uid: uuid.UUID, fid: uuid.UUID, sid: uuid.UUID, atoms: list[dict]):
"""对提取的原子进行去重并存储到 PG。
使用文本相似度判断是否与已有原子重复
- 相似度 > 75% 时更新优先级和元数据
- 否则插入新原子记录
"""
try: try:
async with self.db_factory() as db: async with self.db_factory() as db:
existing = await db.execute( existing = await db.execute(
@ -483,7 +571,7 @@ class MemoryManager:
await db.execute( await db.execute(
text(""" text("""
UPDATE memory_atoms UPDATE memory_atoms
SET priority = GREATEST(priority, :priority), SET priority = GREATER(priority, :priority),
updated_at = NOW(), updated_at = NOW(),
metadata = metadata || :meta metadata = metadata || :meta
WHERE id = :id WHERE id = :id
@ -520,6 +608,7 @@ class MemoryManager:
@staticmethod @staticmethod
def _text_similarity(a: str, b: str) -> float: def _text_similarity(a: str, b: str) -> float:
"""计算两段文本的 Jaccard 相似度(基于单词集合)。"""
a_words = set(a.lower().split()) a_words = set(a.lower().split())
b_words = set(b.lower().split()) b_words = set(b.lower().split())
if not a_words or not b_words: if not a_words or not b_words:
@ -529,6 +618,11 @@ class MemoryManager:
return len(intersection) / len(union) return len(intersection) / len(union)
async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID): async def _maybe_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID):
"""L2 条件触发:场景提取。
当原子数达到 SCENE_EXTRACT_EVERY 且距上次提取超过 12 小时时
调用 LLM 对已有原子进行场景归纳
"""
try: try:
task_key = f"extract_scenes:{uid}:{fid}" task_key = f"extract_scenes:{uid}:{fid}"
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
@ -564,7 +658,7 @@ class MemoryManager:
latest_scene = atoms_result.fetchone() latest_scene = atoms_result.fetchone()
if latest_scene: if latest_scene:
from datetime import timezone, timedelta from datetime import timezone, timedelta
ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone.utc) ago = datetime.now(timezone.utc) - latest_scene[0].replace(tzinfo=timezone=timezone.utc)
if ago < timedelta(hours=12): if ago < timedelta(hours=12):
return return
@ -587,6 +681,7 @@ class MemoryManager:
pass pass
async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]): async def _do_extract_scenes(self, uid: uuid.UUID, fid: uuid.UUID, atoms: list[dict]):
"""执行 L2 场景提取:调用 LLM 将原子记忆归纳为场景块。"""
try: try:
atoms_text = "\n".join( atoms_text = "\n".join(
f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms f"[{a['type']}/{a['priority']}] {a['content']}" for a in atoms
@ -651,6 +746,11 @@ class MemoryManager:
logger.warning(f"L2场景提取失败: {e}") logger.warning(f"L2场景提取失败: {e}")
async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID): async def _maybe_update_persona(self, uid: uuid.UUID, fid: uuid.UUID):
"""L3 条件触发:用户画像更新。
当消息数达到 PERSONA_UPDATE_EVERY 且距上次更新超过 6 小时时
基于已有 persona 类型原子重新生成画像
"""
try: try:
task_key = f"update_persona:{uid}" task_key = f"update_persona:{uid}"
if task_key in self._extract_tasks and not self._extract_tasks[task_key].done(): if task_key in self._extract_tasks and not self._extract_tasks[task_key].done():
@ -699,6 +799,7 @@ class MemoryManager:
pass pass
async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int): async def _do_update_persona(self, uid: uuid.UUID, persona_text: str, version: int):
"""执行 L3 画像更新:调用 LLM 生成结构化用户画像并持久化。"""
try: try:
prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。 prompt = f"""请根据以下用户信息片段,生成一份结构化的用户画像。
@ -780,6 +881,18 @@ class MemoryManager:
top_k: int = 10, top_k: int = 10,
embedding: list[float] = None, embedding: list[float] = None,
) -> list[dict]: ) -> list[dict]:
"""混合检索记忆:向量相似度 + 全文检索,使用 RRF 算法融合排序。
Args:
uid: 用户 ID
query: 搜索查询文本
fid: 流程 ID可选过滤范围
top_k: 返回结果数
embedding: 查询向量可选启用向量检索
Returns:
RRF 分数降序排列的记忆原子列表
"""
results = [] results = []
try: try:
async with self.db_factory() as db: async with self.db_factory() as db:

5
backend/modules/memory/router.py

@ -1,3 +1,5 @@
"""记忆管理路由:会话列表查询、历史获取、会话清除。"""
from fastapi import APIRouter, Request, Depends, HTTPException from fastapi import APIRouter, Request, Depends, HTTPException
from dependencies import get_current_user from dependencies import get_current_user
from modules.memory.manager import get_memory_manager from modules.memory.manager import get_memory_manager
@ -7,6 +9,7 @@ router = APIRouter(prefix="/api/memory", tags=["记忆管理"])
@router.get("/sessions") @router.get("/sessions")
async def list_sessions(request: Request, user=Depends(get_current_user)): async def list_sessions(request: Request, user=Depends(get_current_user)):
"""获取当前用户的所有记忆会话列表。"""
mm = get_memory_manager() mm = get_memory_manager()
sessions = await mm.list_user_sessions(str(user.id)) sessions = await mm.list_user_sessions(str(user.id))
return {"code": 200, "data": sessions} return {"code": 200, "data": sessions}
@ -14,6 +17,7 @@ async def list_sessions(request: Request, user=Depends(get_current_user)):
@router.get("/sessions/{session_id}") @router.get("/sessions/{session_id}")
async def get_session(session_id: str, request: Request, flow_id: str = "", user=Depends(get_current_user)): async def get_session(session_id: str, request: Request, flow_id: str = "", user=Depends(get_current_user)):
"""获取指定会话的对话历史。"""
mm = get_memory_manager() mm = get_memory_manager()
history = await mm.get_conversation_history( history = await mm.get_conversation_history(
user_id=str(user.id), user_id=str(user.id),
@ -25,6 +29,7 @@ async def get_session(session_id: str, request: Request, flow_id: str = "", user
@router.delete("/sessions/{session_id}") @router.delete("/sessions/{session_id}")
async def clear_session(session_id: str, request: Request, user=Depends(get_current_user)): async def clear_session(session_id: str, request: Request, user=Depends(get_current_user)):
"""清除指定会话的所有记忆数据。"""
mm = get_memory_manager() mm = get_memory_manager()
await mm.delete_session(str(user.id), session_id) await mm.delete_session(str(user.id), session_id)
return {"code": 200, "message": "记忆已清除"} return {"code": 200, "message": "记忆已清除"}

5
backend/modules/memory/schemas.py

@ -1,8 +1,11 @@
"""记忆管理模块的 Pydantic 请求/响应模型。"""
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from datetime import datetime from datetime import datetime
class MemorySessionOut(BaseModel): class MemorySessionOut(BaseModel):
"""记忆会话概要响应体。"""
session_id: str session_id: str
flow_id: str flow_id: str
flow_name: str flow_name: str
@ -10,10 +13,12 @@ class MemorySessionOut(BaseModel):
class ConversationMessage(BaseModel): class ConversationMessage(BaseModel):
"""单条对话消息。"""
role: str role: str
content: str content: str
ts: str = "" ts: str = ""
class ClearSessionRequest(BaseModel): class ClearSessionRequest(BaseModel):
"""清除会话请求体。"""
session_id: str session_id: str

118
backend/modules/model_provider/router.py

@ -1,3 +1,8 @@
"""模型供应商模块路由。
提供模型供应商和模型实例的 CRUD 管理功能
支持多供应商接入和模型能力的统一管理
"""
import uuid import uuid
import logging import logging
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
@ -7,12 +12,21 @@ from database import get_db
from models import ModelProvider, ModelInstance from models import ModelProvider, ModelInstance
from dependencies import get_current_user from dependencies import get_current_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"]) router = APIRouter(prefix="/api/model-providers", tags=["模型供应商"])
@router.get("") @router.get("")
async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出所有已注册的模型供应商。
Args:
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含模型供应商列表的响应数据
"""
result = await db.execute( result = await db.execute(
select(ModelProvider).order_by(ModelProvider.created_at.desc()) select(ModelProvider).order_by(ModelProvider.created_at.desc())
) )
@ -34,6 +48,19 @@ async def list_providers(db: AsyncSession = Depends(get_db), user=Depends(get_cu
@router.post("") @router.post("")
async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""注册新的模型供应商。
Args:
payload: 请求体包含 nameprovider_typebase_urlapi_keyextra_config 字段
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含新供应商 ID 的响应数据
Raises:
HTTPException: 相同 base_url 的供应商已存在时抛出异常
"""
existing = await db.execute( existing = await db.execute(
select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", "")) select(ModelProvider).where(ModelProvider.base_url == payload.get("base_url", ""))
) )
@ -54,6 +81,20 @@ async def create_provider(payload: dict, db: AsyncSession = Depends(get_db), use
@router.put("/{provider_id}") @router.put("/{provider_id}")
async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def update_provider(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""更新模型供应商的配置信息。
Args:
provider_id: 模型供应商唯一标识 ID
payload: 请求体包含要更新的供应商字段
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含更新后供应商 ID 的响应数据
Raises:
HTTPException: 供应商不存在时抛出异常
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id)) p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p: if not p:
raise HTTPException(404, "供应商不存在") raise HTTPException(404, "供应商不存在")
@ -70,6 +111,19 @@ async def update_provider(provider_id: str, payload: dict, db: AsyncSession = De
@router.delete("/{provider_id}") @router.delete("/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""删除指定的模型供应商。
Args:
provider_id: 模型供应商唯一标识 ID
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 操作结果响应
Raises:
HTTPException: 供应商不存在时抛出异常
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id)) p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p: if not p:
raise HTTPException(404, "供应商不存在") raise HTTPException(404, "供应商不存在")
@ -80,6 +134,15 @@ async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db),
@router.get("/models/all") @router.get("/models/all")
async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出所有处于活跃状态的模型实例。
Args:
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含所有活跃模型实例列表的响应数据
"""
result = await db.execute( result = await db.execute(
select(ModelInstance) select(ModelInstance)
.where(ModelInstance.is_active == True) .where(ModelInstance.is_active == True)
@ -101,6 +164,16 @@ async def list_all_models(db: AsyncSession = Depends(get_db), user=Depends(get_c
@router.get("/{provider_id}/models") @router.get("/{provider_id}/models")
async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""列出指定供应商下的所有模型实例。
Args:
provider_id: 模型供应商唯一标识 ID
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含模型实例列表的响应数据
"""
result = await db.execute( result = await db.execute(
select(ModelInstance) select(ModelInstance)
.where(ModelInstance.provider_id == uuid.UUID(provider_id)) .where(ModelInstance.provider_id == uuid.UUID(provider_id))
@ -127,6 +200,20 @@ async def list_models(provider_id: str, db: AsyncSession = Depends(get_db), user
@router.post("/{provider_id}/models") @router.post("/{provider_id}/models")
async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""在指定供应商下添加新的模型实例。
Args:
provider_id: 模型供应商唯一标识 ID
payload: 请求体包含 model_namemodel_typedisplay_namecapabilitiesdefault_paramsis_default 字段
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含新模型 ID 的响应数据
Raises:
HTTPException: 供应商不存在或相同名称的模型已存在时抛出异常
"""
p = await db.get(ModelProvider, uuid.UUID(provider_id)) p = await db.get(ModelProvider, uuid.UUID(provider_id))
if not p: if not p:
raise HTTPException(404, "供应商不存在") raise HTTPException(404, "供应商不存在")
@ -156,6 +243,21 @@ async def create_model(provider_id: str, payload: dict, db: AsyncSession = Depen
@router.put("/{provider_id}/models/{model_id}") @router.put("/{provider_id}/models/{model_id}")
async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def update_model(provider_id: str, model_id: str, payload: dict, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""更新模型实例的配置信息。
Args:
provider_id: 模型供应商唯一标识 ID
model_id: 模型实例唯一标识 ID
payload: 请求体包含要更新的模型字段
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 包含更新后模型 ID 的响应数据
Raises:
HTTPException: 模型不存在时抛出异常
"""
m = await db.get(ModelInstance, uuid.UUID(model_id)) m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id: if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在") raise HTTPException(404, "模型不存在")
@ -173,6 +275,20 @@ async def update_model(provider_id: str, model_id: str, payload: dict, db: Async
@router.delete("/{provider_id}/models/{model_id}") @router.delete("/{provider_id}/models/{model_id}")
async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)): async def delete_model(provider_id: str, model_id: str, db: AsyncSession = Depends(get_db), user=Depends(get_current_user)):
"""删除指定的模型实例。
Args:
provider_id: 模型供应商唯一标识 ID
model_id: 模型实例唯一标识 ID
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 操作结果响应
Raises:
HTTPException: 模型不存在时抛出异常
"""
m = await db.get(ModelInstance, uuid.UUID(model_id)) m = await db.get(ModelInstance, uuid.UUID(model_id))
if not m or str(m.provider_id) != provider_id: if not m or str(m.provider_id) != provider_id:
raise HTTPException(404, "模型不存在") raise HTTPException(404, "模型不存在")

5
backend/modules/monitor/__init__.py

@ -0,0 +1,5 @@
"""监控模块。
提供员工监控功能包括员工列表查询个人数据看板AI 辅助的员工交互分析等
支持基于数据权限范围的访问控制
"""

87
backend/modules/monitor/router.py

@ -1,3 +1,8 @@
"""监控模块路由。
提供员工监控功能包括员工列表查询个人数据看板AI 辅助的员工交互分析等
支持基于数据权限范围all/subordinate_only/self_only的访问控制
"""
import uuid import uuid
import json import json
from datetime import datetime from datetime import datetime
@ -9,25 +14,42 @@ from models import User, ChatSession, ChatMessage
from modules.org.router import _get_subordinate_ids, _user_to_out from modules.org.router import _get_subordinate_ids, _user_to_out
from schemas import EmployeeAnalysis, UserOut from schemas import EmployeeAnalysis, UserOut
router = APIRouter(prefix="/api/monitor", tags=["monitor"]) router = APIRouter(prefix="/api/monitor", tags=["monitor"]) # 监控模块路由前缀
@router.get("/employees", response_model=list[UserOut]) @router.get("/employees", response_model=list[UserOut])
async def get_monitor_employees(request: Request, db: AsyncSession = Depends(get_db)): async def get_monitor_employees(request: Request, db: AsyncSession = Depends(get_db)):
"""获取可监控的员工列表。
根据当前用户的数据权限范围返回不同的员工列表
- all返回所有活跃员工
- subordinate_only返回当前用户及其所有下级员工
- self_only仅返回当前用户
Args:
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
list[UserOut]: 员工信息列表
"""
user_ctx = request.state.user user_ctx = request.state.user
cur_id = uuid.UUID(user_ctx["id"]) cur_id = uuid.UUID(user_ctx["id"])
if user_ctx["data_scope"] == "all": if user_ctx["data_scope"] == "all":
# 数据权限为全部:返回所有活跃员工
result = await db.execute(select(User).where(User.status == "active")) result = await db.execute(select(User).where(User.status == "active"))
return [await _user_to_out(db, u) for u in result.scalars().all()] return [await _user_to_out(db, u) for u in result.scalars().all()]
elif user_ctx["data_scope"] == "subordinate_only": elif user_ctx["data_scope"] == "subordinate_only":
# 数据权限为下级:返回当前用户及其所有下级
sub_ids = await _get_subordinate_ids(db, cur_id) sub_ids = await _get_subordinate_ids(db, cur_id)
sub_ids.add(cur_id) sub_ids.add(cur_id)
result = await db.execute(select(User).where(User.id.in_(sub_ids))) result = await db.execute(select(User).where(User.id.in_(sub_ids)))
return [await _user_to_out(db, u) for u in result.scalars().all()] return [await _user_to_out(db, u) for u in result.scalars().all()]
else: else:
# 数据权限为仅自己:仅返回当前用户
result = await db.execute(select(User).where(User.id == cur_id)) result = await db.execute(select(User).where(User.id == cur_id))
return [await _user_to_out(db, u) for u in result.scalars().all()] return [await _user_to_out(db, u) for u in result.scalars().all()]
@ -36,9 +58,26 @@ async def get_monitor_employees(request: Request, db: AsyncSession = Depends(get
async def get_employee_dashboard( async def get_employee_dashboard(
emp_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db) emp_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)
): ):
"""获取指定员工的个人数据看板。
包括员工基本信息消息统计会话统计活跃天数消息分类和最近交互记录
需要进行数据权限校验
Args:
emp_id: 员工唯一标识 ID
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
dict: 包含员工信息和统计数据的看板数据
Raises:
HTTPException: 无权查看或员工不存在时抛出异常
"""
user_ctx = request.state.user user_ctx = request.state.user
cur_id = uuid.UUID(user_ctx["id"]) cur_id = uuid.UUID(user_ctx["id"])
# 数据权限校验
if user_ctx["data_scope"] != "all": if user_ctx["data_scope"] != "all":
if user_ctx["data_scope"] == "self_only" and str(emp_id) != user_ctx["id"]: if user_ctx["data_scope"] == "self_only" and str(emp_id) != user_ctx["id"]:
raise HTTPException(403, "无权查看此员工数据") raise HTTPException(403, "无权查看此员工数据")
@ -48,21 +87,25 @@ async def get_employee_dashboard(
if emp_id not in sub_ids: if emp_id not in sub_ids:
raise HTTPException(403, "无权查看此员工数据") raise HTTPException(403, "无权查看此员工数据")
# 查询员工信息
emp_result = await db.execute(select(User).where(User.id == emp_id)) emp_result = await db.execute(select(User).where(User.id == emp_id))
emp = emp_result.scalar_one_or_none() emp = emp_result.scalar_one_or_none()
if not emp: if not emp:
raise HTTPException(404, "员工不存在") raise HTTPException(404, "员工不存在")
# 统计总消息数
total_msgs_result = await db.execute( total_msgs_result = await db.execute(
select(func.count(ChatMessage.id)).where(ChatMessage.user_id == emp_id) select(func.count(ChatMessage.id)).where(ChatMessage.user_id == emp_id)
) )
total_messages = total_msgs_result.scalar() or 0 total_messages = total_msgs_result.scalar() or 0
# 统计总会话数
session_result = await db.execute( session_result = await db.execute(
select(func.count(ChatSession.id)).where(ChatSession.user_id == emp_id) select(func.count(ChatSession.id)).where(ChatSession.user_id == emp_id)
) )
total_sessions = session_result.scalar() or 0 total_sessions = session_result.scalar() or 0
# 查询最近 50 条消息
recent_msgs_result = await db.execute( recent_msgs_result = await db.execute(
select(ChatMessage) select(ChatMessage)
.where(ChatMessage.user_id == emp_id) .where(ChatMessage.user_id == emp_id)
@ -71,13 +114,14 @@ async def get_employee_dashboard(
) )
recent = recent_msgs_result.scalars().all() recent = recent_msgs_result.scalars().all()
# 统计话题分布和活跃天数
topics = {} topics = {}
active_days = set() active_days = set()
for msg in recent: for msg in recent:
if msg.created_at: if msg.created_at:
active_days.add(msg.created_at.strftime("%Y-%m-%d")) active_days.add(msg.created_at.strftime("%Y-%m-%d")) # 记录活跃日期
role = msg.role role = msg.role
topics[role] = topics.get(role, 0) + 1 topics[role] = topics.get(role, 0) + 1 # 按角色统计消息数
return { return {
"code": 200, "code": 200,
@ -89,13 +133,13 @@ async def get_employee_dashboard(
"position": emp.position or "", "position": emp.position or "",
}, },
"stats": { "stats": {
"total_messages": total_messages, "total_messages": total_messages, # 总消息数
"total_sessions": total_sessions, "total_sessions": total_sessions, # 总会话数
"active_days": len(active_days), "active_days": len(active_days), # 活跃天数
"message_breakdown": topics, "message_breakdown": topics, # 消息角色分布
"recent_interactions": [ "recent_interactions": [
{"role": m.role, "content": m.content[:200], "created_at": str(m.created_at)} {"role": m.role, "content": m.content[:200], "created_at": str(m.created_at)}
for m in recent[:10] for m in recent[:10] # 最近 10 条交互记录
], ],
}, },
}, },
@ -106,9 +150,26 @@ async def get_employee_dashboard(
async def get_employee_analysis( async def get_employee_analysis(
emp_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db) emp_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)
): ):
"""获取指定员工的 AI 分析报告。
基于员工与 AI 的交互记录使用大语言模型生成员工分析报告
包括任务完成率活跃度主要话题效率趋势优势和改进建议等
Args:
emp_id: 员工唯一标识 ID
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
EmployeeAnalysis: 员工分析报告
Raises:
HTTPException: 无权查看或员工不存在时抛出异常
"""
user_ctx = request.state.user user_ctx = request.state.user
cur_id = uuid.UUID(user_ctx["id"]) cur_id = uuid.UUID(user_ctx["id"])
# 数据权限校验
if user_ctx["data_scope"] != "all": if user_ctx["data_scope"] != "all":
if user_ctx["data_scope"] == "self_only" and str(emp_id) != user_ctx["id"]: if user_ctx["data_scope"] == "self_only" and str(emp_id) != user_ctx["id"]:
raise HTTPException(403, "无权查看此员工数据") raise HTTPException(403, "无权查看此员工数据")
@ -118,6 +179,7 @@ async def get_employee_analysis(
if emp_id not in sub_ids: if emp_id not in sub_ids:
raise HTTPException(403, "无权查看此员工数据") raise HTTPException(403, "无权查看此员工数据")
# 查询员工信息
emp_result = await db.execute(select(User).where(User.id == emp_id)) emp_result = await db.execute(select(User).where(User.id == emp_id))
emp = emp_result.scalar_one_or_none() emp = emp_result.scalar_one_or_none()
if not emp: if not emp:
@ -128,6 +190,7 @@ async def get_employee_analysis(
from agentscope.formatter import OpenAIChatFormatter from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg from agentscope.message import Msg
# 查询最近 100 条消息作为分析数据
msgs_result = await db.execute( msgs_result = await db.execute(
select(ChatMessage) select(ChatMessage)
.where(ChatMessage.user_id == emp_id) .where(ChatMessage.user_id == emp_id)
@ -136,10 +199,12 @@ async def get_employee_analysis(
) )
messages = msgs_result.scalars().all() messages = msgs_result.scalars().all()
# 构建交互记录文本
interaction_log = "\n".join([ interaction_log = "\n".join([
f"[{m.role}] {m.content[:300]}" for m in messages f"[{m.role}] {m.content[:300]}" for m in messages
]) ])
# 初始化 LLM 模型
model = OpenAIChatModel( model = OpenAIChatModel(
config_name="analysis_model", config_name="analysis_model",
model_name=settings.LLM_MODEL, model_name=settings.LLM_MODEL,
@ -148,6 +213,7 @@ async def get_employee_analysis(
) )
formatter = OpenAIChatFormatter() formatter = OpenAIChatFormatter()
# 构建分析提示词
prompt = formatter.format([ prompt = formatter.format([
Msg("system", f"""你是一个企业管理者分析助手。请根据员工与AI的交互记录,生成一个JSON格式的分析报告。 Msg("system", f"""你是一个企业管理者分析助手。请根据员工与AI的交互记录,生成一个JSON格式的分析报告。
@ -166,7 +232,7 @@ async def get_employee_analysis(
]) ])
try: try:
res = await model(prompt) res = await model(prompt) # 调用 LLM 生成分析
res_text = "" res_text = ""
if isinstance(res, list): if isinstance(res, list):
res_text = res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0]) res_text = res[0].get_text_content() if hasattr(res[0], 'get_text_content') else str(res[0])
@ -174,8 +240,9 @@ async def get_employee_analysis(
res_text = res.get_text_content() res_text = res.get_text_content()
else: else:
res_text = str(res) res_text = str(res)
analysis_data = json.loads(res_text) analysis_data = json.loads(res_text) # 解析 JSON 响应
except Exception: except Exception:
# LLM 调用失败时使用默认分析数据
analysis_data = { analysis_data = {
"task_completion_rate": 0.7, "task_completion_rate": 0.7,
"active_days": 0, "active_days": 0,

1
backend/modules/notification/__init__.py

@ -0,0 +1 @@
"""通知模块。"""

143
backend/modules/notification/router.py

@ -1,3 +1,8 @@
"""通知模块路由。
提供实时通知推送功能支持 WebSocket 连接消息广播定向发送
支持通知模板管理和企业微信推送集成
"""
import uuid import uuid
import json import json
import asyncio import asyncio
@ -15,22 +20,50 @@ router = APIRouter(prefix="/api/notification", tags=["notification"])
class WebSocketManager: class WebSocketManager:
"""WebSocket 连接管理器类,管理所有用户的 WebSocket 连接。
支持按用户 ID 管理多个连接提供定向发送和广播功能
Attributes:
connections: 用户 ID WebSocket 连接列表的映射字典
"""
def __init__(self): def __init__(self):
"""初始化 WebSocket 管理器实例。"""
self.connections: dict[str, list[WebSocket]] = {} self.connections: dict[str, list[WebSocket]] = {}
async def connect(self, user_id: str, ws: WebSocket): async def connect(self, user_id: str, ws: WebSocket):
"""接受并注册新的 WebSocket 连接。
Args:
user_id: 用户唯一标识
ws: WebSocket 连接对象
"""
await ws.accept() await ws.accept()
if user_id not in self.connections: if user_id not in self.connections:
self.connections[user_id] = [] self.connections[user_id] = []
self.connections[user_id].append(ws) self.connections[user_id].append(ws)
def disconnect(self, user_id: str, ws: WebSocket): def disconnect(self, user_id: str, ws: WebSocket):
"""断开并移除指定的 WebSocket 连接。
Args:
user_id: 用户唯一标识
ws: 要移除的 WebSocket 连接对象
"""
if user_id in self.connections: if user_id in self.connections:
self.connections[user_id].remove(ws) self.connections[user_id].remove(ws)
if not self.connections[user_id]: if not self.connections[user_id]:
del self.connections[user_id] del self.connections[user_id]
async def send_to_user(self, user_id: str, message: dict): async def send_to_user(self, user_id: str, message: dict):
"""向指定用户的所有连接发送消息。
自动清理失效的连接
Args:
user_id: 目标用户唯一标识
message: 要发送的消息字典
"""
connections = self.connections.get(user_id, []) connections = self.connections.get(user_id, [])
dead = [] dead = []
for ws in connections: for ws in connections:
@ -42,19 +75,37 @@ class WebSocketManager:
self.disconnect(user_id, ws) self.disconnect(user_id, ws)
async def broadcast(self, message: dict): async def broadcast(self, message: dict):
"""向所有在线用户广播消息。
Args:
message: 要广播的消息字典
"""
for user_id in list(self.connections.keys()): for user_id in list(self.connections.keys()):
await self.send_to_user(user_id, message) await self.send_to_user(user_id, message)
@property @property
def active_count(self) -> int: def active_count(self) -> int:
"""获取当前活跃的 WebSocket 连接总数。
Returns:
int: 活跃连接数量
"""
return sum(len(v) for v in self.connections.values()) return sum(len(v) for v in self.connections.values())
ws_manager = WebSocketManager() ws_manager = WebSocketManager() # 全局 WebSocket 管理器单例实例
@router.websocket("/ws/{user_id}") @router.websocket("/ws/{user_id}")
async def notification_websocket(ws: WebSocket, user_id: str): async def notification_websocket(ws: WebSocket, user_id: str):
"""通知 WebSocket 连接处理器。
接受客户端的 WebSocket 连接处理心跳 ping/pong 消息
Args:
ws: WebSocket 连接对象
user_id: 用户唯一标识从路径参数获取
"""
await ws_manager.connect(user_id, ws) await ws_manager.connect(user_id, ws)
try: try:
while True: while True:
@ -71,11 +122,24 @@ async def notification_websocket(ws: WebSocket, user_id: str):
@router.post("/send") @router.post("/send")
async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): async def send_notification(payload: dict, request: Request, db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
user_id = payload.get("user_id", "") """发送实时通知,支持定向发送和广播。
target_all = payload.get("target_all", False)
title = payload.get("title", "系统通知") 支持推送到企业微信并记录审计日志
body = payload.get("message", "")
notify_type = payload.get("type", "info") Args:
payload: 请求体包含 user_idtarget_alltitlemessagetypepush_to_wecom 等字段
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 操作结果响应
"""
user_id = payload.get("user_id", "") # 目标用户 ID
target_all = payload.get("target_all", False) # 是否广播给所有用户
title = payload.get("title", "系统通知") # 通知标题
body = payload.get("message", "") # 通知内容
notify_type = payload.get("type", "info") # 通知类型
msg = { msg = {
"type": notify_type, "type": notify_type,
@ -85,10 +149,11 @@ async def send_notification(payload: dict, request: Request, db: AsyncSession =
} }
if target_all: if target_all:
await ws_manager.broadcast(msg) await ws_manager.broadcast(msg) # 广播给所有在线用户
elif user_id: elif user_id:
await ws_manager.send_to_user(user_id, msg) await ws_manager.send_to_user(user_id, msg) # 定向发送给指定用户
# 推送到企业微信
if payload.get("push_to_wecom"): if payload.get("push_to_wecom"):
await _push_to_wecom(title, body, user_id) await _push_to_wecom(title, body, user_id)
@ -107,6 +172,15 @@ async def send_notification(payload: dict, request: Request, db: AsyncSession =
@router.get("/templates", response_model=list[NotificationTemplateOut]) @router.get("/templates", response_model=list[NotificationTemplateOut])
async def list_templates(request: Request, db: AsyncSession = Depends(get_db)): async def list_templates(request: Request, db: AsyncSession = Depends(get_db)):
"""列出所有通知模板。
Args:
request: HTTP 请求对象
db: 异步数据库会话
Returns:
list[NotificationTemplateOut]: 通知模板列表
"""
result = await db.execute( result = await db.execute(
select(NotificationTemplate).order_by(NotificationTemplate.created_at.desc()) select(NotificationTemplate).order_by(NotificationTemplate.created_at.desc())
) )
@ -120,6 +194,20 @@ async def create_template(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""创建新的通知模板。
Args:
req: 通知模板创建请求体
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
NotificationTemplateOut: 创建后的通知模板响应
Raises:
HTTPException: 模板编码已存在时抛出异常
"""
existing = await db.execute( existing = await db.execute(
select(NotificationTemplate).where(NotificationTemplate.code == req.code) select(NotificationTemplate).where(NotificationTemplate.code == req.code)
) )
@ -141,6 +229,19 @@ async def create_template(
@router.get("/templates/{template_id}", response_model=NotificationTemplateOut) @router.get("/templates/{template_id}", response_model=NotificationTemplateOut)
async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): async def get_template(template_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
"""获取指定通知模板的详细信息。
Args:
template_id: 通知模板唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
Returns:
NotificationTemplateOut: 通知模板详细信息
Raises:
HTTPException: 模板不存在时抛出异常
"""
result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id)) result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id))
template = result.scalar_one_or_none() template = result.scalar_one_or_none()
if not template: if not template:
@ -154,6 +255,20 @@ async def delete_template(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user: dict = Depends(get_current_user), user: dict = Depends(get_current_user),
): ):
"""删除指定的通知模板。
Args:
template_id: 通知模板唯一标识 ID
request: HTTP 请求对象
db: 异步数据库会话
user: 当前登录用户信息
Returns:
dict: 操作结果响应
Raises:
HTTPException: 模板不存在或为系统模板时抛出异常
"""
result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id)) result = await db.execute(select(NotificationTemplate).where(NotificationTemplate.id == template_id))
template = result.scalar_one_or_none() template = result.scalar_one_or_none()
if not template: if not template:
@ -167,10 +282,22 @@ async def delete_template(
@router.get("/ws/stats") @router.get("/ws/stats")
async def ws_stats(): async def ws_stats():
"""获取 WebSocket 连接统计信息。
Returns:
dict: 包含活跃连接数的响应数据
"""
return {"code": 200, "data": {"active_connections": ws_manager.active_count}} return {"code": 200, "data": {"active_connections": ws_manager.active_count}}
async def _push_to_wecom(title: str, body: str, user_id: str): async def _push_to_wecom(title: str, body: str, user_id: str):
"""将通知推送到企业微信。
Args:
title: 通知标题
body: 通知内容
user_id: 目标企业微信用户 ID
"""
if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET: if not settings.WECOM_CORP_ID or not settings.WECOM_APP_SECRET:
return return

5
backend/modules/org/__init__.py

@ -0,0 +1,5 @@
"""组织管理模块。
提供部门和用户的 CRUD 操作树形部门结构查询下级用户递归查询等功能
支持基于数据权限范围的访问控制
"""

186
backend/modules/org/router.py

@ -1,3 +1,8 @@
"""组织管理模块路由。
提供部门和用户的 CRUD 操作树形部门结构查询下级用户递归查询等功能
支持基于数据权限范围all/subordinate_only/self_only的访问控制
"""
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select from sqlalchemy import select
@ -11,11 +16,22 @@ from schemas import (
) )
from modules.auth.router import hash_password, get_user_roles from modules.auth.router import hash_password, get_user_roles
router = APIRouter(prefix="/api/org", tags=["org"]) router = APIRouter(prefix="/api/org", tags=["org"]) # 组织管理模块路由前缀
@router.get("/departments", response_model=list[DepartmentOut]) @router.get("/departments", response_model=list[DepartmentOut])
async def get_departments(request: Request, db: AsyncSession = Depends(get_db)): async def get_departments(request: Request, db: AsyncSession = Depends(get_db)):
"""获取树形部门结构。
查询所有顶级部门parent_id NULL并递归构建完整的部门树
Args:
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
list[DepartmentOut]: 树形部门结构列表
"""
result = await db.execute( result = await db.execute(
select(Department).where(Department.parent_id.is_(None)).order_by(Department.sort_order) select(Department).where(Department.parent_id.is_(None)).order_by(Department.sort_order)
) )
@ -24,11 +40,26 @@ async def get_departments(request: Request, db: AsyncSession = Depends(get_db)):
async def _build_department_tree(db: AsyncSession, dept: Department, _visited: set[uuid.UUID] = None) -> DepartmentOut: async def _build_department_tree(db: AsyncSession, dept: Department, _visited: set[uuid.UUID] = None) -> DepartmentOut:
"""递归构建部门树形结构。
查询当前部门的所有子部门并递归构建子部门的子部门树
使用 _visited 集合防止循环引用导致无限递归
Args:
db: 异步数据库会话
dept: 当前部门对象
_visited: 已访问的部门 ID 集合用于防止循环引用
Returns:
DepartmentOut: 包含子部门列表的部门信息
"""
if _visited is None: if _visited is None:
_visited = set() _visited = set()
if dept.id in _visited: if dept.id in _visited:
# 检测到循环引用,返回不包含子部门的部门信息
return DepartmentOut(id=dept.id, name=dept.name, parent_id=dept.parent_id, path=dept.path, level=dept.level, sort_order=dept.sort_order, children=[]) return DepartmentOut(id=dept.id, name=dept.name, parent_id=dept.parent_id, path=dept.path, level=dept.level, sort_order=dept.sort_order, children=[])
_visited.add(dept.id) _visited.add(dept.id)
# 查询当前部门的所有子部门
children_result = await db.execute( children_result = await db.execute(
select(Department).where(Department.parent_id == dept.id).order_by(Department.sort_order) select(Department).where(Department.parent_id == dept.id).order_by(Department.sort_order)
) )
@ -36,7 +67,7 @@ async def _build_department_tree(db: AsyncSession, dept: Department, _visited: s
return DepartmentOut( return DepartmentOut(
id=dept.id, name=dept.name, parent_id=dept.parent_id, id=dept.id, name=dept.name, parent_id=dept.parent_id,
path=dept.path, level=dept.level, sort_order=dept.sort_order, path=dept.path, level=dept.level, sort_order=dept.sort_order,
children=[await _build_department_tree(db, c, _visited) for c in children], children=[await _build_department_tree(db, c, _visited) for c in children], # 递归构建子部门
) )
@ -44,19 +75,35 @@ async def _build_department_tree(db: AsyncSession, dept: Department, _visited: s
async def create_department( async def create_department(
req: DepartmentCreate, request: Request, db: AsyncSession = Depends(get_db) req: DepartmentCreate, request: Request, db: AsyncSession = Depends(get_db)
): ):
"""创建新部门。
根据父部门信息计算新部门的层级和路径
Args:
req: 部门创建请求体包含名称父部门 ID 和排序权重
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
DepartmentOut: 创建后的部门信息
Raises:
HTTPException: 父部门不存在时抛出异常
"""
parent_path = "/" parent_path = "/"
level = 0 level = 0
if req.parent_id: if req.parent_id:
# 查询父部门信息以计算层级和路径
parent_result = await db.execute(select(Department).where(Department.id == req.parent_id)) parent_result = await db.execute(select(Department).where(Department.id == req.parent_id))
parent = parent_result.scalar_one_or_none() parent = parent_result.scalar_one_or_none()
if not parent: if not parent:
raise HTTPException(404, "父部门不存在") raise HTTPException(404, "父部门不存在")
parent_path = parent.path parent_path = parent.path
level = parent.level + 1 level = parent.level + 1 # 新部门层级为父部门层级 + 1
dept = Department( dept = Department(
name=req.name, parent_id=req.parent_id, name=req.name, parent_id=req.parent_id,
path=f"{parent_path}/{req.name}".replace("//", "/"), path=f"{parent_path}/{req.name}".replace("//", "/"), # 构建部门路径
level=level, sort_order=req.sort_order, level=level, sort_order=req.sort_order,
) )
db.add(dept) db.add(dept)
@ -73,6 +120,20 @@ async def update_department(
dept_id: uuid.UUID, req: DepartmentUpdate, dept_id: uuid.UUID, req: DepartmentUpdate,
request: Request, db: AsyncSession = Depends(get_db), request: Request, db: AsyncSession = Depends(get_db),
): ):
"""更新部门信息。
Args:
dept_id: 部门唯一标识 ID
req: 部门更新请求体包含可更新的字段
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
DepartmentOut: 更新后的部门信息
Raises:
HTTPException: 部门不存在时抛出异常
"""
result = await db.execute(select(Department).where(Department.id == dept_id)) result = await db.execute(select(Department).where(Department.id == dept_id))
dept = result.scalar_one_or_none() dept = result.scalar_one_or_none()
if not dept: if not dept:
@ -92,6 +153,19 @@ async def update_department(
@router.delete("/departments/{dept_id}") @router.delete("/departments/{dept_id}")
async def delete_department(dept_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): async def delete_department(dept_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
"""删除部门。
Args:
dept_id: 部门唯一标识 ID
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
dict: 操作结果响应
Raises:
HTTPException: 部门不存在时抛出异常
"""
result = await db.execute(select(Department).where(Department.id == dept_id)) result = await db.execute(select(Department).where(Department.id == dept_id))
dept = result.scalar_one_or_none() dept = result.scalar_one_or_none()
if not dept: if not dept:
@ -102,10 +176,22 @@ async def delete_department(dept_id: uuid.UUID, request: Request, db: AsyncSessi
@router.get("/users", response_model=list[UserOut]) @router.get("/users", response_model=list[UserOut])
async def get_users(request: Request, db: AsyncSession = Depends(get_db)): async def get_users(request: Request, db: AsyncSession = Depends(get_db)):
"""获取用户列表。
根据当前用户的数据权限范围返回不同的用户列表
Args:
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
list[UserOut]: 用户信息列表
"""
user_ctx = request.state.user user_ctx = request.state.user
result = await db.execute(select(User)) result = await db.execute(select(User))
users = result.scalars().all() users = result.scalars().all()
# 根据数据权限范围过滤用户
if user_ctx["data_scope"] == "self_only": if user_ctx["data_scope"] == "self_only":
users = [u for u in users if str(u.id) == user_ctx["id"]] users = [u for u in users if str(u.id) == user_ctx["id"]]
elif user_ctx["data_scope"] == "subordinate_only": elif user_ctx["data_scope"] == "subordinate_only":
@ -118,6 +204,19 @@ async def get_users(request: Request, db: AsyncSession = Depends(get_db)):
@router.get("/users/{user_id}", response_model=UserOut) @router.get("/users/{user_id}", response_model=UserOut)
async def get_user(user_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)): async def get_user(user_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
"""获取指定用户的详细信息。
Args:
user_id: 用户唯一标识 ID
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
UserOut: 用户详细信息
Raises:
HTTPException: 用户不存在时抛出异常
"""
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
@ -127,13 +226,28 @@ async def get_user(user_id: uuid.UUID, request: Request, db: AsyncSession = Depe
@router.post("/users", response_model=UserOut) @router.post("/users", response_model=UserOut)
async def create_user(req: UserCreate, request: Request, db: AsyncSession = Depends(get_db)): async def create_user(req: UserCreate, request: Request, db: AsyncSession = Depends(get_db)):
"""创建新用户。
支持设置用户名密码部门职位上级用户和角色等信息
Args:
req: 用户创建请求体
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
UserOut: 创建后的用户信息
Raises:
HTTPException: 用户名已存在时抛出异常
"""
existing = await db.execute(select(User).where(User.username == req.username)) existing = await db.execute(select(User).where(User.username == req.username))
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise HTTPException(400, "用户名已存在") raise HTTPException(400, "用户名已存在")
user = User( user = User(
username=req.username, username=req.username,
password_hash=hash_password(req.password), password_hash=hash_password(req.password), # 密码哈希存储
display_name=req.display_name, display_name=req.display_name,
email=req.email, phone=req.phone, email=req.email, phone=req.phone,
wecom_user_id=req.wecom_user_id, wecom_user_id=req.wecom_user_id,
@ -144,6 +258,7 @@ async def create_user(req: UserCreate, request: Request, db: AsyncSession = Depe
await db.flush() await db.flush()
if req.role_ids: if req.role_ids:
# 为用户分配角色
for role_id in req.role_ids: for role_id in req.role_ids:
db.add(UserRole(user_id=user.id, role_id=role_id)) db.add(UserRole(user_id=user.id, role_id=role_id))
await db.flush() await db.flush()
@ -156,6 +271,22 @@ async def update_user(
user_id: uuid.UUID, req: UserUpdate, user_id: uuid.UUID, req: UserUpdate,
request: Request, db: AsyncSession = Depends(get_db), request: Request, db: AsyncSession = Depends(get_db),
): ):
"""更新用户信息。
支持修改显示名称邮箱手机号部门职位上级用户状态和角色
Args:
user_id: 用户唯一标识 ID
req: 用户更新请求体
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
UserOut: 更新后的用户信息
Raises:
HTTPException: 用户不存在时抛出异常
"""
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
@ -177,12 +308,13 @@ async def update_user(
user.status = req.status user.status = req.status
if req.role_ids is not None: if req.role_ids is not None:
await db.execute(select(UserRole).where(UserRole.user_id == user.id)) # 先删除用户现有角色关联
existing_urs = (await db.execute( existing_urs = (await db.execute(
select(UserRole).where(UserRole.user_id == user.id) select(UserRole).where(UserRole.user_id == user.id)
)).scalars().all() )).scalars().all()
for ur in existing_urs: for ur in existing_urs:
await db.delete(ur) await db.delete(ur)
# 重新分配角色
for role_id in req.role_ids: for role_id in req.role_ids:
db.add(UserRole(user_id=user.id, role_id=role_id)) db.add(UserRole(user_id=user.id, role_id=role_id))
@ -191,9 +323,20 @@ async def update_user(
@router.get("/subordinates", response_model=list[UserOut]) @router.get("/subordinates", response_model=list[UserOut])
async def get_subordinates(request: Request, db: AsyncSession = Depends(get_db)): async def get_subordinates(request: Request, db: AsyncSession = Depends(get_db)):
"""获取当前用户的所有下级用户(递归)。
递归查询所有直接或间接以当前用户为上级的用户
Args:
request: HTTP 请求对象包含当前用户上下文
db: 异步数据库会话
Returns:
list[UserOut]: 下级用户列表
"""
user_ctx = request.state.user user_ctx = request.state.user
manager_id = uuid.UUID(user_ctx["id"]) manager_id = uuid.UUID(user_ctx["id"])
sub_ids = await _get_subordinate_ids(db, manager_id) sub_ids = await _get_subordinate_ids(db, manager_id) # 递归获取所有下级 ID
result = await db.execute(select(User).where(User.id.in_(sub_ids))) result = await db.execute(select(User).where(User.id.in_(sub_ids)))
users = result.scalars().all() users = result.scalars().all()
@ -201,21 +344,44 @@ async def get_subordinates(request: Request, db: AsyncSession = Depends(get_db))
async def _get_subordinate_ids(db: AsyncSession, manager_id: uuid.UUID, _visited: set[uuid.UUID] = None) -> set[uuid.UUID]: async def _get_subordinate_ids(db: AsyncSession, manager_id: uuid.UUID, _visited: set[uuid.UUID] = None) -> set[uuid.UUID]:
"""递归获取指定管理者的所有下级用户 ID。
递归查询直接或间接以指定用户为上级的所有用户 ID
使用 _visited 集合防止循环引用导致无限递归
Args:
db: 异步数据库会话
manager_id: 管理者用户 ID
_visited: 已访问的用户 ID 集合用于防止循环引用
Returns:
set[uuid.UUID]: 所有下级用户 ID 的集合
"""
if _visited is None: if _visited is None:
_visited = set() _visited = set()
if manager_id in _visited: if manager_id in _visited:
return set() return set() # 检测到循环引用,返回空集合
_visited.add(manager_id) _visited.add(manager_id)
# 查询直接下级
result = await db.execute(select(User).where(User.manager_id == manager_id)) result = await db.execute(select(User).where(User.manager_id == manager_id))
direct = result.scalars().all() direct = result.scalars().all()
ids = {u.id for u in direct} ids = {u.id for u in direct}
for sub in direct: for sub in direct:
ids.update(await _get_subordinate_ids(db, sub.id, _visited)) ids.update(await _get_subordinate_ids(db, sub.id, _visited)) # 递归获取子级下级
return ids return ids
async def _user_to_out(db: AsyncSession, user: User) -> UserOut: async def _user_to_out(db: AsyncSession, user: User) -> UserOut:
roles = await get_user_roles(db, user.id) """将用户数据库对象转换为 UserOut 响应模型。
Args:
db: 异步数据库会话用于查询用户角色信息
user: 用户数据库对象
Returns:
UserOut: 用户响应模型包含角色列表
"""
roles = await get_user_roles(db, user.id) # 获取用户角色信息
return UserOut( return UserOut(
id=user.id, username=user.username, display_name=user.display_name, id=user.id, username=user.username, display_name=user.display_name,
email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id, email=user.email, phone=user.phone, wecom_user_id=user.wecom_user_id,

74
backend/modules/rag/knowledge.py

@ -1,3 +1,9 @@
"""知识库模块。
提供基于 AgentScope 的企业知识库管理功能包括文档索引文本索引和语义检索
支持多种文档格式PDFWordExcel纯文本的自动解析和向量化存储
使用 Qdrant 作为向量存储后端OpenAI Embedding 作为向量化模型
"""
import os import os
import asyncio import asyncio
import logging import logging
@ -5,15 +11,20 @@ from agentscope.embedding import OpenAITextEmbedding
from agentscope.rag import SimpleKnowledge, QdrantStore, TextReader, PDFReader, WordReader, ExcelReader from agentscope.rag import SimpleKnowledge, QdrantStore, TextReader, PDFReader, WordReader, ExcelReader
from config import settings from config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
_knowledge_base: SimpleKnowledge | None = None _knowledge_base: SimpleKnowledge | None = None # 全局知识库单例实例
_STORE_PATH = os.path.join(settings.UPLOAD_DIR, "..", "data", "qdrant") _STORE_PATH = os.path.join(settings.UPLOAD_DIR, "..", "data", "qdrant") # Qdrant 向量存储路径
_COLLECTION_NAME = "enterprise_knowledge" _COLLECTION_NAME = "enterprise_knowledge" # Qdrant 集合名称
_VECTOR_DIM = 1536 _VECTOR_DIM = 1536 # 向量维度(text-embedding-3-small 标准维度)
def _get_embedding_model(): def _get_embedding_model():
"""创建并返回 OpenAI 文本 Embedding 模型实例。
Returns:
OpenAITextEmbedding: 配置好的 Embedding 模型
"""
return OpenAITextEmbedding( return OpenAITextEmbedding(
api_key=settings.LLM_API_KEY, api_key=settings.LLM_API_KEY,
model_name="text-embedding-3-small", model_name="text-embedding-3-small",
@ -22,6 +33,14 @@ def _get_embedding_model():
def get_knowledge_base() -> SimpleKnowledge: def get_knowledge_base() -> SimpleKnowledge:
"""获取或创建全局知识库实例。
采用单例模式首次调用时初始化 Qdrant 向量存储和 Embedding 模型
后续调用直接返回已创建的实例
Returns:
SimpleKnowledge: 初始化好的知识库实例
"""
global _knowledge_base global _knowledge_base
if _knowledge_base is None: if _knowledge_base is None:
os.makedirs(_STORE_PATH, exist_ok=True) os.makedirs(_STORE_PATH, exist_ok=True)
@ -39,10 +58,23 @@ def get_knowledge_base() -> SimpleKnowledge:
async def add_document(file_path: str, file_type: str = "auto") -> str: async def add_document(file_path: str, file_type: str = "auto") -> str:
"""将文档文件添加到知识库中进行索引。
自动根据文件类型选择合适的解析器将文档切分为多个文本块后
进行向量化并存储到知识库中
Args:
file_path: 文档文件的完整路径
file_type: 文档类型auto 表示自动识别
Returns:
str: 索引结果描述或错误信息
"""
try: try:
ext = os.path.splitext(file_path)[1].lower() ext = os.path.splitext(file_path)[1].lower()
kb = get_knowledge_base() kb = get_knowledge_base()
# 根据文件类型选择对应的解析器
if file_type == "auto": if file_type == "auto":
if ext == ".pdf": if ext == ".pdf":
reader = PDFReader(chunk_size=1024, split_by="sentence") reader = PDFReader(chunk_size=1024, split_by="sentence")
@ -83,6 +115,15 @@ async def add_document(file_path: str, file_type: str = "auto") -> str:
async def add_text(text: str, source: str = "manual") -> str: async def add_text(text: str, source: str = "manual") -> str:
"""将纯文本内容添加到知识库中进行索引。
Args:
text: 要索引的文本内容
source: 文本来源标识默认为 manual手动录入
Returns:
str: 索引结果描述或错误信息
"""
try: try:
kb = get_knowledge_base() kb = get_knowledge_base()
reader = TextReader(chunk_size=1024, split_by="sentence") reader = TextReader(chunk_size=1024, split_by="sentence")
@ -97,6 +138,18 @@ async def add_text(text: str, source: str = "manual") -> str:
async def search(query: str, limit: int = 5, score_threshold: float = 0.3) -> list[dict]: async def search(query: str, limit: int = 5, score_threshold: float = 0.3) -> list[dict]:
"""在知识库中执行语义检索。
根据查询文本的向量相似度从知识库中检索最相关的文档片段
Args:
query: 查询文本
limit: 返回结果的最大数量默认 5
score_threshold: 相似度分数阈值低于此值的结果将被过滤默认 0.3
Returns:
list[dict]: 检索结果列表每项包含 idcontentscoresource 字段
"""
try: try:
kb = get_knowledge_base() kb = get_knowledge_base()
if not kb or not hasattr(kb, 'retrieve'): if not kb or not hasattr(kb, 'retrieve'):
@ -124,6 +177,17 @@ async def search(query: str, limit: int = 5, score_threshold: float = 0.3) -> li
async def retrieve_for_agent(query: str, limit: int = 5) -> str: async def retrieve_for_agent(query: str, limit: int = 5) -> str:
"""为 AI 智能体执行知识库检索并返回格式化的结果文本。
该函数专为 AgentScope 智能体调用设计返回人类可读的检索结果
Args:
query: 查询文本
limit: 返回结果的最大数量默认 5
Returns:
str: 格式化的检索结果文本包含相关度分数
"""
results = await search(query, limit=limit) results = await search(query, limit=limit)
if not results: if not results:
return "未找到相关文档。" return "未找到相关文档。"

94
backend/modules/rag/router.py

@ -1,3 +1,8 @@
"""知识库(RAG)模块路由。
提供知识库的文档上传文本索引语义检索和文档管理功能
支持文件上传后自动解析和向量化以及基于相似度的知识检索
"""
from fastapi import APIRouter, Depends, UploadFile, File, Request from fastapi import APIRouter, Depends, UploadFile, File, Request
from database import get_db from database import get_db
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -18,8 +23,22 @@ async def rag_upload(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""上传文档文件并自动索引到知识库。
将上传的文件保存到服务器后调用文档解析器将其切分为文本块
并进行向量化存储到知识库中
Args:
request: HTTP 请求对象
file: 上传的文件对象
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含索引结果和文件信息的响应数据
"""
os.makedirs(settings.UPLOAD_DIR, exist_ok=True) os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
filename = f"{uuid.uuid4().hex}_{file.filename}" filename = f"{uuid.uuid4().hex}_{file.filename}" # 生成唯一文件名,避免冲突
file_path = os.path.join(settings.UPLOAD_DIR, filename) file_path = os.path.join(settings.UPLOAD_DIR, filename)
content = await file.read() content = await file.read()
@ -37,8 +56,19 @@ async def rag_index_text(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
text = payload.get("text", "") """将纯文本内容索引到知识库。
source = payload.get("source", "manual")
Args:
request: HTTP 请求对象
payload: 请求体包含 text 和可选的 source 字段
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含索引结果的响应数据
"""
text = payload.get("text", "") # 待索引的文本内容
source = payload.get("source", "manual") # 文本来源标识
if not text: if not text:
return {"code": 400, "message": "文本内容不能为空"} return {"code": 400, "message": "文本内容不能为空"}
result = await add_text(text, source) result = await add_text(text, source)
@ -53,6 +83,18 @@ async def rag_search(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""在知识库中执行语义检索。
Args:
request: HTTP 请求对象
q: 查询文本
limit: 返回结果的最大数量
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含检索结果列表的响应数据
"""
if not q: if not q:
return {"code": 400, "message": "查询内容不能为空"} return {"code": 400, "message": "查询内容不能为空"}
results = await search(q, limit=limit) results = await search(q, limit=limit)
@ -67,6 +109,18 @@ async def rag_retrieve(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""为 AI 智能体执行知识库检索,返回格式化的结果文本。
Args:
request: HTTP 请求对象
q: 查询文本
limit: 返回结果的最大数量
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含格式化检索结果的响应数据
"""
if not q: if not q:
return {"code": 400, "message": "查询内容不能为空"} return {"code": 400, "message": "查询内容不能为空"}
result = await retrieve_for_agent(q, limit=limit) result = await retrieve_for_agent(q, limit=limit)
@ -79,6 +133,18 @@ async def list_documents(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""列出知识库中已索引的所有文档来源及其统计信息。
从向量存储中获取所有文档按来源分组并统计每个来源的文档块数量
Args:
request: HTTP 请求对象
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含文档列表和统计信息的响应数据
"""
try: try:
kb = get_knowledge_base() kb = get_knowledge_base()
if not kb or not hasattr(kb, '_embedding_store'): if not kb or not hasattr(kb, '_embedding_store'):
@ -108,6 +174,7 @@ async def list_documents(
}) })
offset += batch_size offset += batch_size
# 按来源分组统计
seen_sources = {} seen_sources = {}
for d in all_docs: for d in all_docs:
src = d["source"] or "unknown" src = d["source"] or "unknown"
@ -134,6 +201,17 @@ async def delete_document(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""按来源删除知识库中的文档块。
Args:
source: 文档来源标识
request: HTTP 请求对象
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含删除结果的响应数据
"""
try: try:
kb = get_knowledge_base() kb = get_knowledge_base()
store = kb._embedding_store store = kb._embedding_store
@ -157,6 +235,16 @@ async def knowledge_stats(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""获取知识库的统计信息,包括文档块数量和来源文件数量。
Args:
request: HTTP 请求对象
db: 异步数据库会话
current_user: 当前登录用户信息
Returns:
dict: 包含知识库统计信息的响应数据
"""
try: try:
kb = get_knowledge_base() kb = get_knowledge_base()
store = kb._embedding_store store = kb._embedding_store

1
backend/modules/rbac/__init__.py

@ -0,0 +1 @@
"""权限管理模块(RBAC)。"""

90
backend/modules/system/router.py

@ -1,3 +1,8 @@
"""系统管理模块路由。
提供系统健康检查使用统计指标收集和缓存管理等功能
用于监控系统运行状态和资源使用情况
"""
import time import time
import uuid import uuid
import psutil import psutil
@ -15,11 +20,23 @@ from middleware.rate_limiter import rate_limiter
router = APIRouter(prefix="/api/system", tags=["system"]) router = APIRouter(prefix="/api/system", tags=["system"])
_start_time = time.time() _start_time = time.time() # 服务启动时间戳
@router.get("/health", response_model=SystemHealthOut) @router.get("/health", response_model=SystemHealthOut)
async def health_check(request: Request, db: AsyncSession = Depends(get_db)): async def health_check(request: Request, db: AsyncSession = Depends(get_db)):
"""系统健康检查接口。
检查数据库连接Redis 连接内存使用CPU 使用率等系统健康指标
Args:
request: HTTP 请求对象
db: 异步数据库会话
Returns:
SystemHealthOut: 包含系统健康状态的响应数据
"""
# 检查数据库连接
db_ok = False db_ok = False
try: try:
await db.execute(select(func.count()).select_from(User)) await db.execute(select(func.count()).select_from(User))
@ -27,9 +44,9 @@ async def health_check(request: Request, db: AsyncSession = Depends(get_db)):
except Exception: except Exception:
pass pass
mem = psutil.Process(os.getpid()).memory_info() mem = psutil.Process(os.getpid()).memory_info() # 获取当前进程内存信息
cpu = psutil.cpu_percent(interval=0.1) cpu = psutil.cpu_percent(interval=0.1) # 获取 CPU 使用率
uptime = time.time() - _start_time uptime = time.time() - _start_time # 计算服务运行时长
try: try:
user_count = await db.execute(select(func.count(User.id))) user_count = await db.execute(select(func.count(User.id)))
@ -38,22 +55,34 @@ async def health_check(request: Request, db: AsyncSession = Depends(get_db)):
active_users = 0 active_users = 0
return SystemHealthOut( return SystemHealthOut(
status="healthy" if db_ok and cache_manager.available else "degraded", status="healthy" if db_ok and cache_manager.available else "degraded", # 数据库和 Redis 都正常则为 healthy
service="enterprise-ai-platform", service="enterprise-ai-platform",
uptime_seconds=round(uptime, 1), uptime_seconds=round(uptime, 1),
db_connected=db_ok, db_connected=db_ok,
redis_connected=cache_manager.available, redis_connected=cache_manager.available,
active_users=active_users, active_users=active_users,
memory_mb=round(mem.rss / 1024 / 1024, 1), memory_mb=round(mem.rss / 1024 / 1024, 1), # 转换为 MB
cpu_percent=round(cpu, 1), cpu_percent=round(cpu, 1),
) )
@router.get("/stats", response_model=UsageStatsOut) @router.get("/stats", response_model=UsageStatsOut)
async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)): async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)):
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) """获取系统使用统计信息。
统计用户数会话数消息数任务数流程数API 调用数等关键指标
Args:
request: HTTP 请求对象
db: 异步数据库会话
Returns:
UsageStatsOut: 包含系统使用统计的响应数据
"""
today = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) # 今日零点
total_users = (await db.execute(select(func.count(User.id)))).scalar() or 0 total_users = (await db.execute(select(func.count(User.id)))).scalar() or 0
# 今日活跃用户数(有新建会话的用户)
active_today = (await db.execute( active_today = (await db.execute(
select(func.count(func.distinct(User.id))) select(func.count(func.distinct(User.id)))
.join(ChatSession, ChatSession.user_id == User.id) .join(ChatSession, ChatSession.user_id == User.id)
@ -66,6 +95,7 @@ async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)):
published = (await db.execute( published = (await db.execute(
select(func.count(FlowDefinition.id)).where(FlowDefinition.status == "published") select(func.count(FlowDefinition.id)).where(FlowDefinition.status == "published")
)).scalar() or 0 )).scalar() or 0
# 今日 API 调用数(有新建执行记录的流程)
api_calls = (await db.execute( api_calls = (await db.execute(
select(func.count(FlowExecution.id)).where(FlowExecution.started_at >= today) select(func.count(FlowExecution.id)).where(FlowExecution.started_at >= today)
)).scalar() or 0 )).scalar() or 0
@ -85,6 +115,16 @@ async def usage_stats(request: Request, db: AsyncSession = Depends(get_db)):
@router.post("/metrics") @router.post("/metrics")
async def collect_metrics(payload: dict, request: Request, db: AsyncSession = Depends(get_db)): async def collect_metrics(payload: dict, request: Request, db: AsyncSession = Depends(get_db)):
"""收集并存储系统指标数据。
Args:
payload: 请求体包含 metric_typevaluesource 字段
request: HTTP 请求对象
db: 异步数据库会话
Returns:
dict: 包含指标 ID 的响应数据
"""
metric = SystemMetric( metric = SystemMetric(
metric_type=payload.get("metric_type", "custom"), metric_type=payload.get("metric_type", "custom"),
value={"data": payload.get("value", {}), "source": payload.get("source", "api")}, value={"data": payload.get("value", {}), "source": payload.get("source", "api")},
@ -101,6 +141,17 @@ async def list_metrics(
limit: int = 50, limit: int = 50,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""查询系统指标历史数据。
Args:
request: HTTP 请求对象
metric_type: 可选的指标类型筛选条件
limit: 返回结果的最大数量
db: 异步数据库会话
Returns:
dict: 包含指标列表的响应数据
"""
q = select(SystemMetric).order_by(SystemMetric.collected_at.desc()) q = select(SystemMetric).order_by(SystemMetric.collected_at.desc())
if metric_type: if metric_type:
q = q.where(SystemMetric.metric_type == metric_type) q = q.where(SystemMetric.metric_type == metric_type)
@ -120,6 +171,14 @@ async def list_metrics(
@router.get("/cache/stats") @router.get("/cache/stats")
async def cache_stats(request: Request): async def cache_stats(request: Request):
"""获取缓存系统状态信息。
Args:
request: HTTP 请求对象
Returns:
dict: 包含 Redis 可用性状态的响应数据
"""
return { return {
"code": 200, "code": 200,
"data": { "data": {
@ -130,6 +189,14 @@ async def cache_stats(request: Request):
@router.get("/ratelimit/stats") @router.get("/ratelimit/stats")
async def ratelimit_stats(request: Request): async def ratelimit_stats(request: Request):
"""获取速率限制状态信息。
Args:
request: HTTP 请求对象
Returns:
dict: 包含速率限制配置的响应数据
"""
remaining = await rate_limiter.remaining("global") remaining = await rate_limiter.remaining("global")
return { return {
"code": 200, "code": 200,
@ -143,5 +210,14 @@ async def ratelimit_stats(request: Request):
@router.post("/cache/clear") @router.post("/cache/clear")
async def clear_cache(request: Request, pattern: str = "*"): async def clear_cache(request: Request, pattern: str = "*"):
"""清除缓存数据。
Args:
request: HTTP 请求对象
pattern: 缓存键匹配模式默认清除所有缓存
Returns:
dict: 操作结果响应
"""
await cache_manager.delete_pattern(pattern) await cache_manager.delete_pattern(pattern)
return {"code": 200, "message": "缓存已清除"} return {"code": 200, "message": "缓存已清除"}

1
backend/modules/task/__init__.py

@ -0,0 +1 @@
"""任务管理模块。"""

64
backend/modules/wecom/router.py

@ -1,3 +1,8 @@
"""企业微信模块路由。
提供企业微信集成相关功能包括回调消息处理配置管理和消息发送
支持企业微信用户通过企微直接与 AI 助手对话
"""
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select from sqlalchemy import select
@ -11,18 +16,26 @@ router = APIRouter(prefix="/api/wecom", tags=["wecom"])
@router.post("/callback") @router.post("/callback")
async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)): async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
""" """接收企业微信回调消息,路由到AI助手处理并回复。
接收企业微信回调消息路由到AI助手处理并回复
企微配置的回调URL指向此端点 企业微信配置的回调 URL 指向此端点接收企微用户消息后
查找对应的系统用户创建或复用聊天会话调用 AI 智能体处理消息并返回回复
Args:
request: HTTP 请求对象包含企业微信回调消息体
db: 异步数据库会话
Returns:
dict: 包含消息类型用户 ID AI 回复的响应数据
""" """
try: try:
body = await request.json() body = await request.json()
except Exception: except Exception:
body = await request.body() body = await request.body()
msg_type = "text" msg_type = "text" # 消息类型
wecom_user_id = "" wecom_user_id = "" # 企业微信用户 ID
content = "" content = "" # 消息内容
if isinstance(body, dict): if isinstance(body, dict):
msg_type = body.get("msg_type", body.get("MsgType", "text")) msg_type = body.get("msg_type", body.get("MsgType", "text"))
wecom_user_id = body.get("user_id", body.get("FromUserName", "")) wecom_user_id = body.get("user_id", body.get("FromUserName", ""))
@ -31,6 +44,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
if not wecom_user_id or not content: if not wecom_user_id or not content:
return {"code": 200, "message": "received"} return {"code": 200, "message": "received"}
# 根据企业微信用户 ID 查找系统用户
user_result = await db.execute( user_result = await db.execute(
select(User).where(User.wecom_user_id == wecom_user_id) select(User).where(User.wecom_user_id == wecom_user_id)
) )
@ -40,6 +54,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
from agentscope.message import Msg from agentscope.message import Msg
# 查找或创建聊天会话
session_result = await db.execute( session_result = await db.execute(
select(ChatSession) select(ChatSession)
.where(ChatSession.user_id == user.id, ChatSession.agent_type == "employee") .where(ChatSession.user_id == user.id, ChatSession.agent_type == "employee")
@ -56,6 +71,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
db.add(session) db.add(session)
await db.flush() await db.flush()
# 保存用户消息
user_msg = ChatMessage( user_msg = ChatMessage(
session_id=session.id, user_id=user.id, session_id=session.id, user_id=user.id,
role="user", content=content, role="user", content=content,
@ -63,6 +79,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
db.add(user_msg) db.add(user_msg)
await db.flush() await db.flush()
# 创建 AI 智能体并处理消息
from agentscope_integration.factory import AgentFactory from agentscope_integration.factory import AgentFactory
agent = await AgentFactory.create_agent( agent = await AgentFactory.create_agent(
agent_type="employee", agent_type="employee",
@ -76,6 +93,7 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response) reply_text = response.get_text_content() if hasattr(response, 'get_text_content') else str(response)
# 保存 AI 回复消息
ai_msg = ChatMessage( ai_msg = ChatMessage(
session_id=session.id, user_id=user.id, session_id=session.id, user_id=user.id,
role="assistant", content=reply_text, role="assistant", content=reply_text,
@ -95,6 +113,14 @@ async def wecom_callback(request: Request, db: AsyncSession = Depends(get_db)):
@router.get("/config") @router.get("/config")
async def get_wecom_config(request: Request): async def get_wecom_config(request: Request):
"""获取企业微信当前配置信息。
Args:
request: HTTP 请求对象
Returns:
dict: 包含机器人名称状态CorpID功能列表等配置信息
"""
return { return {
"code": 200, "code": 200,
"data": { "data": {
@ -109,6 +135,17 @@ async def get_wecom_config(request: Request):
@router.put("/config") @router.put("/config")
async def update_wecom_config(request: Request, payload: dict): async def update_wecom_config(request: Request, payload: dict):
"""更新企业微信配置并持久化到 .env 文件。
支持更新 CorpIDAppSecretAgentIDToken EncodingAESKey 等配置项
Args:
request: HTTP 请求对象
payload: 请求体包含企业微信配置字段
Returns:
dict: 操作结果响应
"""
import os import os
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env') env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
updates = {} updates = {}
@ -146,9 +183,18 @@ async def update_wecom_config(request: Request, payload: dict):
@router.post("/send") @router.post("/send")
async def send_wecom_message(request: Request, payload: dict): async def send_wecom_message(request: Request, payload: dict):
to_user = payload.get("to_user", "@all") """通过企业微信发送消息给指定用户。
msg_type = payload.get("msg_type", "text")
content = payload.get("content", "") Args:
request: HTTP 请求对象
payload: 请求体包含 to_usermsg_typecontent 字段
Returns:
dict: 操作结果响应
"""
to_user = payload.get("to_user", "@all") # 目标用户,默认 @all 表示所有人
msg_type = payload.get("msg_type", "text") # 消息类型
content = payload.get("content", "") # 消息内容
if not content: if not content:
return {"code": 400, "message": "消息内容不能为空"} return {"code": 400, "message": "消息内容不能为空"}

74
backend/schemas/__init__.py

@ -1,3 +1,20 @@
"""Pydantic 请求/响应模型定义。
所有 API 的请求体和响应体均在此定义涵盖
- 认证登录/令牌
- 用户部门角色权限的增删改查
- 任务管理
- 审批流引擎Flow节点//版本/执行
- 自定义工具CustomTool
- MCP 服务
- Agent 配置
- 通知模板
- 文档上传与解析
- 审计日志
- 系统监控指标
- 通用 API 响应包装
"""
import uuid import uuid
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict from pydantic import BaseModel, Field, ConfigDict
@ -5,11 +22,13 @@ from pydantic import BaseModel, Field, ConfigDict
# --- Auth --- # --- Auth ---
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
"""登录请求体:用户名 + 密码。"""
username: str username: str
password: str password: str
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
"""令牌响应体:Bearer 令牌 + 用户信息。"""
access_token: str access_token: str
token_type: str = "bearer" token_type: str = "bearer"
user: "UserOut" user: "UserOut"
@ -17,6 +36,7 @@ class TokenResponse(BaseModel):
# --- User --- # --- User ---
class UserCreate(BaseModel): class UserCreate(BaseModel):
"""创建用户请求体。"""
username: str username: str
password: str password: str
display_name: str display_name: str
@ -30,6 +50,7 @@ class UserCreate(BaseModel):
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
"""更新用户请求体(所有字段可选)。"""
display_name: str | None = None display_name: str | None = None
email: str | None = None email: str | None = None
phone: str | None = None phone: str | None = None
@ -41,6 +62,7 @@ class UserUpdate(BaseModel):
class UserOut(BaseModel): class UserOut(BaseModel):
"""用户响应体(ORM 映射)。"""
id: uuid.UUID id: uuid.UUID
username: str username: str
display_name: str display_name: str
@ -60,18 +82,21 @@ class UserOut(BaseModel):
# --- Department --- # --- Department ---
class DepartmentCreate(BaseModel): class DepartmentCreate(BaseModel):
"""创建部门请求体。"""
name: str name: str
parent_id: uuid.UUID | None = None parent_id: uuid.UUID | None = None
sort_order: int = 0 sort_order: int = 0
class DepartmentUpdate(BaseModel): class DepartmentUpdate(BaseModel):
"""更新部门请求体。"""
name: str | None = None name: str | None = None
parent_id: uuid.UUID | None = None parent_id: uuid.UUID | None = None
sort_order: int | None = None sort_order: int | None = None
class DepartmentOut(BaseModel): class DepartmentOut(BaseModel):
"""部门响应体,含嵌套子部门列表。"""
id: uuid.UUID id: uuid.UUID
name: str name: str
parent_id: uuid.UUID | None = None parent_id: uuid.UUID | None = None
@ -86,6 +111,7 @@ class DepartmentOut(BaseModel):
# --- Role --- # --- Role ---
class RoleCreate(BaseModel): class RoleCreate(BaseModel):
"""创建角色请求体。"""
name: str name: str
code: str = "" code: str = ""
description: str | None = None description: str | None = None
@ -94,6 +120,7 @@ class RoleCreate(BaseModel):
class RoleUpdate(BaseModel): class RoleUpdate(BaseModel):
"""更新角色请求体。"""
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
data_scope: str | None = None data_scope: str | None = None
@ -101,6 +128,7 @@ class RoleUpdate(BaseModel):
class RoleOut(BaseModel): class RoleOut(BaseModel):
"""角色响应体,含权限编码列表。"""
id: uuid.UUID id: uuid.UUID
name: str name: str
code: str = "" code: str = ""
@ -115,6 +143,7 @@ class RoleOut(BaseModel):
# --- Permission --- # --- Permission ---
class PermissionOut(BaseModel): class PermissionOut(BaseModel):
"""权限响应体。"""
id: uuid.UUID id: uuid.UUID
code: str code: str
name: str name: str
@ -127,6 +156,7 @@ class PermissionOut(BaseModel):
# --- Task --- # --- Task ---
class TaskCreate(BaseModel): class TaskCreate(BaseModel):
"""创建任务请求体。"""
title: str title: str
content: str | None = None content: str | None = None
assignee_id: uuid.UUID assignee_id: uuid.UUID
@ -136,6 +166,7 @@ class TaskCreate(BaseModel):
class TaskUpdate(BaseModel): class TaskUpdate(BaseModel):
"""更新任务请求体。"""
title: str | None = None title: str | None = None
content: str | None = None content: str | None = None
status: str | None = None status: str | None = None
@ -144,6 +175,7 @@ class TaskUpdate(BaseModel):
class TaskOut(BaseModel): class TaskOut(BaseModel):
"""任务响应体。"""
id: uuid.UUID id: uuid.UUID
title: str title: str
content: str | None = None content: str | None = None
@ -161,6 +193,7 @@ class TaskOut(BaseModel):
# --- Employee Analysis --- # --- Employee Analysis ---
class EmployeeAnalysis(BaseModel): class EmployeeAnalysis(BaseModel):
"""员工分析报告响应体。"""
employee_name: str employee_name: str
department: str department: str
period: str period: str
@ -177,11 +210,13 @@ class EmployeeAnalysis(BaseModel):
# --- Flow --- # --- Flow ---
class TriggerNodeConfig(BaseModel): class TriggerNodeConfig(BaseModel):
"""触发节点配置。"""
event_type: str = "text_message" event_type: str = "text_message"
channels: list[str] = ["wecom"] channels: list[str] = ["wecom"]
callback_url: str = "" callback_url: str = ""
class LLMNodeConfig(BaseModel): class LLMNodeConfig(BaseModel):
"""LLM 调用节点配置。"""
system_prompt: str = "" system_prompt: str = ""
model: str = "gpt-4o-mini" model: str = "gpt-4o-mini"
temperature: float = 0.7 temperature: float = 0.7
@ -193,6 +228,7 @@ class LLMNodeConfig(BaseModel):
tool_call: bool = False tool_call: bool = False
class ToolNodeConfig(BaseModel): class ToolNodeConfig(BaseModel):
"""内置工具节点配置。"""
tool_name: str = "" tool_name: str = ""
tool_type: str = "" tool_type: str = ""
tool_params: dict = {} tool_params: dict = {}
@ -201,6 +237,7 @@ class ToolNodeConfig(BaseModel):
error_handling: str = "throw" error_handling: str = "throw"
class MCPNodeConfig(BaseModel): class MCPNodeConfig(BaseModel):
"""MCP 服务节点配置。"""
mcp_server: str = "" mcp_server: str = ""
tool_name: str = "" tool_name: str = ""
input_params: dict = {} input_params: dict = {}
@ -209,6 +246,7 @@ class MCPNodeConfig(BaseModel):
error_handling: str = "throw" error_handling: str = "throw"
class NotifyNodeConfig(BaseModel): class NotifyNodeConfig(BaseModel):
"""通知节点配置。"""
channels: dict = {"wecom": True, "web": False} channels: dict = {"wecom": True, "web": False}
message_template: str = "" message_template: str = ""
web_template: str = "" web_template: str = ""
@ -218,6 +256,7 @@ class NotifyNodeConfig(BaseModel):
error_handling: str = "throw" error_handling: str = "throw"
class ConditionNodeConfig(BaseModel): class ConditionNodeConfig(BaseModel):
"""条件分支节点配置。"""
condition: str = "" condition: str = ""
condition_type: str = "expression" condition_type: str = "expression"
true_label: str = "" true_label: str = ""
@ -225,6 +264,7 @@ class ConditionNodeConfig(BaseModel):
default_branch: str = "false" default_branch: str = "false"
class RAGNodeConfig(BaseModel): class RAGNodeConfig(BaseModel):
"""RAG 检索节点配置。"""
knowledge_base: str = "" knowledge_base: str = ""
top_k: int = 5 top_k: int = 5
search_mode: str = "hybrid" search_mode: str = "hybrid"
@ -233,6 +273,7 @@ class RAGNodeConfig(BaseModel):
include_metadata: bool = True include_metadata: bool = True
class OutputNodeConfig(BaseModel): class OutputNodeConfig(BaseModel):
"""输出节点配置。"""
format: str = "text" format: str = "text"
output_template: str = "" output_template: str = ""
indent: int = 2 indent: int = 2
@ -241,18 +282,21 @@ class OutputNodeConfig(BaseModel):
max_length: int = 2000 max_length: int = 2000
class LoopNodeConfig(BaseModel): class LoopNodeConfig(BaseModel):
"""循环节点配置。"""
loop_type: str = "fixed" loop_type: str = "fixed"
max_iterations: int = 10 max_iterations: int = 10
count: int = 3 count: int = 3
iterator_variable: str = "item" iterator_variable: str = "item"
class CodeNodeConfig(BaseModel): class CodeNodeConfig(BaseModel):
"""代码执行节点配置。"""
language: str = "python" language: str = "python"
code: str = "" code: str = ""
timeout: int = 30 timeout: int = 30
sandbox: bool = True sandbox: bool = True
class FlowNode(BaseModel): class FlowNode(BaseModel):
"""流程图中单个节点的定义。"""
id: str | None = None id: str | None = None
type: str type: str
label: str | None = None label: str | None = None
@ -260,6 +304,7 @@ class FlowNode(BaseModel):
class FlowEdge(BaseModel): class FlowEdge(BaseModel):
"""流程图中连接边的定义。"""
source: str | None = None source: str | None = None
target: str | None = None target: str | None = None
from_field: str | None = Field(None, alias="from") from_field: str | None = Field(None, alias="from")
@ -270,6 +315,7 @@ class FlowEdge(BaseModel):
class FlowDefinitionCreate(BaseModel): class FlowDefinitionCreate(BaseModel):
"""创建流程定义请求体。"""
name: str name: str
description: str | None = None description: str | None = None
trigger: dict = {} trigger: dict = {}
@ -279,6 +325,7 @@ class FlowDefinitionCreate(BaseModel):
class FlowDefinitionUpdate(BaseModel): class FlowDefinitionUpdate(BaseModel):
"""更新流程定义请求体。"""
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
nodes: list[FlowNode] | None = None nodes: list[FlowNode] | None = None
@ -287,6 +334,7 @@ class FlowDefinitionUpdate(BaseModel):
class FlowDefinitionOut(BaseModel): class FlowDefinitionOut(BaseModel):
"""流程定义响应体。"""
id: uuid.UUID id: uuid.UUID
name: str name: str
description: str | None = None description: str | None = None
@ -305,6 +353,7 @@ class FlowDefinitionOut(BaseModel):
class FlowVersionOut(BaseModel): class FlowVersionOut(BaseModel):
"""流程版本响应体。"""
id: uuid.UUID id: uuid.UUID
flow_id: uuid.UUID flow_id: uuid.UUID
version: int version: int
@ -320,10 +369,12 @@ class FlowVersionOut(BaseModel):
class FlowApiKeyCreate(BaseModel): class FlowApiKeyCreate(BaseModel):
"""创建流程 API 密钥请求体。"""
name: str name: str
class FlowApiKeyOut(BaseModel): class FlowApiKeyOut(BaseModel):
"""流程 API 密钥响应体。"""
id: uuid.UUID id: uuid.UUID
flow_id: uuid.UUID flow_id: uuid.UUID
name: str name: str
@ -336,12 +387,14 @@ class FlowApiKeyOut(BaseModel):
class FlowExecuteRequest(BaseModel): class FlowExecuteRequest(BaseModel):
"""流程执行请求体。"""
input_text: str = "" input_text: str = ""
session_id: str | None = None session_id: str | None = None
user_id: str | None = None user_id: str | None = None
class FlowChatMessageRequest(BaseModel): class FlowChatMessageRequest(BaseModel):
"""流程聊天消息请求体。"""
query: str query: str
inputs: dict = {} inputs: dict = {}
response_mode: str = "blocking" response_mode: str = "blocking"
@ -350,11 +403,13 @@ class FlowChatMessageRequest(BaseModel):
class OpenAPIImportRequest(BaseModel): class OpenAPIImportRequest(BaseModel):
"""OpenAPI 导入请求体。"""
openapi_url: str openapi_url: str
base_url_override: str | None = None base_url_override: str | None = None
class CustomToolCreate(BaseModel): class CustomToolCreate(BaseModel):
"""创建自定义工具请求体。"""
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
name: str name: str
description: str | None = None description: str | None = None
@ -369,6 +424,7 @@ class CustomToolCreate(BaseModel):
class CustomToolUpdate(BaseModel): class CustomToolUpdate(BaseModel):
"""更新自定义工具请求体。"""
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
@ -383,6 +439,7 @@ class CustomToolUpdate(BaseModel):
class CustomToolOut(BaseModel): class CustomToolOut(BaseModel):
"""自定义工具响应体(ORM 映射)。"""
model_config = ConfigDict(from_attributes=True, populate_by_name=True, protected_namespaces=()) model_config = ConfigDict(from_attributes=True, populate_by_name=True, protected_namespaces=())
id: uuid.UUID id: uuid.UUID
name: str name: str
@ -398,6 +455,7 @@ class CustomToolOut(BaseModel):
# --- MCP --- # --- MCP ---
class MCPServiceCreate(BaseModel): class MCPServiceCreate(BaseModel):
"""创建 MCP 服务请求体。"""
name: str name: str
transport: str = "http" transport: str = "http"
url: str | None = None url: str | None = None
@ -407,6 +465,7 @@ class MCPServiceCreate(BaseModel):
class MCPServiceUpdate(BaseModel): class MCPServiceUpdate(BaseModel):
"""更新 MCP 服务请求体。"""
transport: str | None = None transport: str | None = None
url: str | None = None url: str | None = None
command: str | None = None command: str | None = None
@ -415,6 +474,7 @@ class MCPServiceUpdate(BaseModel):
class MCPServiceOut(BaseModel): class MCPServiceOut(BaseModel):
"""MCP 服务响应体。"""
id: uuid.UUID id: uuid.UUID
name: str name: str
transport: str transport: str
@ -431,6 +491,7 @@ class MCPServiceOut(BaseModel):
# --- Agent Config --- # --- Agent Config ---
class AgentConfigCreate(BaseModel): class AgentConfigCreate(BaseModel):
"""创建 Agent 配置请求体。"""
name: str name: str
description: str | None = None description: str | None = None
system_prompt: str = "" system_prompt: str = ""
@ -440,6 +501,7 @@ class AgentConfigCreate(BaseModel):
class AgentConfigUpdate(BaseModel): class AgentConfigUpdate(BaseModel):
"""更新 Agent 配置请求体。"""
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
system_prompt: str | None = None system_prompt: str | None = None
@ -450,6 +512,7 @@ class AgentConfigUpdate(BaseModel):
class AgentConfigOut(BaseModel): class AgentConfigOut(BaseModel):
"""Agent 配置响应体。"""
id: uuid.UUID id: uuid.UUID
name: str name: str
description: str | None = None description: str | None = None
@ -468,6 +531,7 @@ class AgentConfigOut(BaseModel):
# --- Notification --- # --- Notification ---
class NotificationTemplateCreate(BaseModel): class NotificationTemplateCreate(BaseModel):
"""创建通知模板请求体。"""
name: str name: str
code: str code: str
channel: str = "wecom" channel: str = "wecom"
@ -477,6 +541,7 @@ class NotificationTemplateCreate(BaseModel):
class NotificationTemplateOut(BaseModel): class NotificationTemplateOut(BaseModel):
"""通知模板响应体。"""
id: uuid.UUID id: uuid.UUID
name: str name: str
code: str code: str
@ -492,6 +557,7 @@ class NotificationTemplateOut(BaseModel):
# --- Document --- # --- Document ---
class DocumentUploadOut(BaseModel): class DocumentUploadOut(BaseModel):
"""文档上传结果响应体。"""
file_id: uuid.UUID file_id: uuid.UUID
filename: str filename: str
file_size: int file_size: int
@ -500,6 +566,7 @@ class DocumentUploadOut(BaseModel):
class DocumentParseResult(BaseModel): class DocumentParseResult(BaseModel):
"""文档解析结果响应体。"""
file_id: uuid.UUID file_id: uuid.UUID
filename: str filename: str
content: str content: str
@ -508,6 +575,7 @@ class DocumentParseResult(BaseModel):
# --- Audit --- # --- Audit ---
class AuditQueryParams(BaseModel): class AuditQueryParams(BaseModel):
"""审计日志查询参数。"""
page: int = 1 page: int = 1
page_size: int = 20 page_size: int = 20
action: str | None = None action: str | None = None
@ -518,6 +586,7 @@ class AuditQueryParams(BaseModel):
class AuditLogOut(BaseModel): class AuditLogOut(BaseModel):
"""审计日志条目响应体。"""
id: uuid.UUID id: uuid.UUID
operator_id: uuid.UUID | None = None operator_id: uuid.UUID | None = None
action: str action: str
@ -532,6 +601,7 @@ class AuditLogOut(BaseModel):
class AuditLogPage(BaseModel): class AuditLogPage(BaseModel):
"""审计日志分页响应体。"""
items: list[AuditLogOut] items: list[AuditLogOut]
total: int total: int
page: int page: int
@ -540,6 +610,7 @@ class AuditLogPage(BaseModel):
# --- System Metrics --- # --- System Metrics ---
class SystemMetricOut(BaseModel): class SystemMetricOut(BaseModel):
"""系统监控指标响应体。"""
id: uuid.UUID id: uuid.UUID
metric_type: str metric_type: str
value: dict value: dict
@ -550,6 +621,7 @@ class SystemMetricOut(BaseModel):
class SystemHealthOut(BaseModel): class SystemHealthOut(BaseModel):
"""系统健康状态响应体。"""
status: str status: str
service: str service: str
uptime_seconds: float uptime_seconds: float
@ -561,6 +633,7 @@ class SystemHealthOut(BaseModel):
class UsageStatsOut(BaseModel): class UsageStatsOut(BaseModel):
"""使用统计响应体。"""
total_users: int total_users: int
active_users_today: int active_users_today: int
total_sessions: int total_sessions: int
@ -574,6 +647,7 @@ class UsageStatsOut(BaseModel):
# --- Generic Response --- # --- Generic Response ---
class ApiResponse(BaseModel): class ApiResponse(BaseModel):
"""通用 API 响应包装。"""
code: int = 200 code: int = 200
message: str = "success" message: str = "success"
data: dict | list | None = None data: dict | list | None = None

17
backend/websocket_manager.py

@ -3,14 +3,22 @@ from typing import Dict, Set
import json import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) # 当前模块的日志记录器
class WebSocketManager: class WebSocketManager:
"""WebSocket 连接管理器。
管理用户与多个 WebSocket 连接的生命周期
支持单用户多连接定向推送和全局广播功能
"""
def __init__(self): def __init__(self):
self.active_connections: Dict[str, Set[WebSocket]] = {} """初始化 WebSocket 连接池。"""
self.active_connections: Dict[str, Set[WebSocket]] = {} # 活跃连接池:{user_id: {WebSocket, ...}}
async def connect(self, websocket: WebSocket, user_id: str): async def connect(self, websocket: WebSocket, user_id: str):
"""接受 WebSocket 连接并注册到指定用户的连接池中。"""
await websocket.accept() await websocket.accept()
if user_id not in self.active_connections: if user_id not in self.active_connections:
self.active_connections[user_id] = set() self.active_connections[user_id] = set()
@ -18,6 +26,7 @@ class WebSocketManager:
logger.info(f"WebSocket 用户 {user_id} 已连接") logger.info(f"WebSocket 用户 {user_id} 已连接")
def disconnect(self, websocket: WebSocket, user_id: str): def disconnect(self, websocket: WebSocket, user_id: str):
"""断开 WebSocket 连接并从用户连接池中移除。"""
if user_id in self.active_connections: if user_id in self.active_connections:
self.active_connections[user_id].discard(websocket) self.active_connections[user_id].discard(websocket)
if not self.active_connections[user_id]: if not self.active_connections[user_id]:
@ -25,6 +34,7 @@ class WebSocketManager:
logger.info(f"WebSocket 用户 {user_id} 已断开") logger.info(f"WebSocket 用户 {user_id} 已断开")
async def send_to_user(self, user_id: str, message: dict): async def send_to_user(self, user_id: str, message: dict):
"""向指定用户的所有活跃连接发送消息,自动清理断开连接。"""
if user_id not in self.active_connections: if user_id not in self.active_connections:
return False return False
dead_connections = set() dead_connections = set()
@ -42,8 +52,9 @@ class WebSocketManager:
return sent_count > 0 return sent_count > 0
async def broadcast(self, message: dict): async def broadcast(self, message: dict):
"""向当前所有活跃用户广播消息,遍历所有用户并调用 send_to_user。"""
for user_id in list(self.active_connections.keys()): for user_id in list(self.active_connections.keys()):
await self.send_to_user(user_id, message) await self.send_to_user(user_id, message)
ws_manager = WebSocketManager() ws_manager = WebSocketManager() # 全局 WebSocket 管理器单例
Loading…
Cancel
Save