import asyncio
import contextlib
import inspect
import json
import logging
import re
import sys
import uuid
from collections.abc import Callable
from typing import Any

from agentscope.agent import AgentBase
from agentscope.agent._react_agent import ReActAgent  # noqa: F401  unused here; kept in case importers rely on it — verify before removing
from agentscope.message import Msg
from agentscope.tool import Toolkit  # noqa: F401  unused here; kept in case importers rely on it — verify before removing

from config import settings

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _msg_text(msg: Any) -> str:
    """Return the plain-text content of *msg*, never ``None``.

    Message-like objects are read through ``get_text_content()``; a ``None``
    result (a message without text blocks) becomes ``""``. Anything else is
    stringified.
    """
    if hasattr(msg, "get_text_content"):
        text = msg.get_text_content()
        return text if text is not None else ""
    return str(msg)


async def _format_prompt(formatter: Any, msgs: list) -> Any:
    """Format *msgs* with an agentscope formatter, awaiting the result if it is awaitable.

    agentscope 1.x formatters expose ``async def format``; calling it without
    ``await`` hands a coroutine object to the model.
    """
    prompt = formatter.format(msgs)
    if inspect.isawaitable(prompt):
        prompt = await prompt
    return prompt


async def _response_text(res: Any) -> str:
    """Extract the text from a model call result.

    Handles, in order:
    - async iterables of partial responses (the last one is kept);
    - lists of messages (the first one is used);
    - objects exposing ``get_text_content()``;
    - objects whose ``content`` is a string or a list of
      ``{"type": "text", "text": ...}`` blocks.

    Anything else is stringified.
    """
    if hasattr(res, "__aiter__"):
        # NOTE(review): assumes each streamed chunk carries the text accumulated
        # so far (believed true for agentscope 1.x chat models) — confirm for the
        # model returned by AgentFactory.
        last = None
        async for chunk in res:
            last = chunk
        res = last
    if isinstance(res, list):
        res = res[0] if res else None
    if res is None:
        return ""
    if hasattr(res, "get_text_content"):
        return res.get_text_content() or ""
    content = getattr(res, "content", None)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        ]
        if texts:
            return "".join(texts)
    return str(res)


def _build_chat_model() -> Any:
    """Create the non-streaming OpenAI-compatible chat model used by condition and RAG nodes.

    Uses the agentscope 1.x constructor (``model_name`` / ``api_key`` /
    ``stream`` / ``client_args``). The previous ``config_name`` / ``api_base``
    keywords are not accepted by that constructor, so every call fell into the
    node's fallback branch.
    NOTE(review): confirm against the installed agentscope version.
    """
    from agentscope.model import OpenAIChatModel

    return OpenAIChatModel(
        model_name=settings.LLM_MODEL,
        api_key=settings.LLM_API_KEY,
        stream=False,
        client_args={"base_url": settings.LLM_API_BASE},
    )


# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------

class FlowSessionMemory:
    """In-process chat history for one flow execution.

    Messages are stored in insertion order as ``{"role": ..., "content": ...}`` dicts.
    """

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return the last ``2 * limit`` messages (presumably ``limit`` user/assistant pairs).

        A non-positive ``limit`` returns ``[]``. Without this guard the slice
        ``[-0:]`` would return the whole history.
        """
        if limit <= 0:
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one message."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of all stored messages."""
        return list(self._messages)


class FlowEngine:
    """Execute a node/edge flow definition by depth-first traversal.

    Each node becomes an agent (see ``_create_node_agent``) whose reply is
    forwarded along outgoing edges:

    - Condition nodes follow only edges whose condition equals the parsed
      ``condition:true|false`` result.
    - Loop nodes follow ``loop_body`` edges while iterating and ``loop_done``
      edges when finished.
    - All other nodes follow every edge except ``loop_body`` / ``loop_done``.
    """

    MAX_TOTAL_ITERATIONS = 200  # hard cap on traverse() calls per execution
    FLOW_TIMEOUT_SECONDS = 300  # wall-clock budget for one execution

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {node["id"]: node for node in flow_definition.get("nodes", [])}
        self.edges: list[dict] = flow_definition.get("edges", [])
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run the flow from its first start node and return the last produced message.

        Args:
            input_msg: Message handed to the start node.
            context: Mutable per-run context. Reads ``session_id`` / ``user_id``
                (and ``_stream_callback`` / ``trigger_data`` indirectly). Writes
                ``_memory`` and ``_node_results`` (one execution record per node).

        Returns:
            The last node reply, a system message describing an overrun, error or
            timeout, or ``input_msg`` when no node produced anything.
        """
        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph) or list(self.nodes)[:1]

        # Agents hold per-run state (session memory, stream callback, merge
        # buffers). Reusing them across executions would leak one session into
        # the next, so the cache lives for a single run.
        self._agent_cache = {}

        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        context["_memory"] = FlowSessionMemory(session_id=session_id, user_id=user_id)

        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}  # loop node id -> visits so far (1-based)
        total_iterations = 0
        last_result: Msg | None = None

        async def traverse(
            node_id: str,
            incoming_msg: Msg,
            loop_context: dict | None = None,
            source_id: str | None = None,
        ) -> None:
            """Run ``node_id`` on ``incoming_msg`` and recurse into its successors.

            ``loop_context`` is supplied for loop-body calls but not consumed yet.
            ``source_id`` is the upstream node id, forwarded to agents as
            ``source_node_id`` (merge nodes key their inputs by it).
            """
            nonlocal last_result, total_iterations
            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return

            node = self.nodes.get(node_id)
            if not node:
                return
            node_type = node.get("type", "")
            label = node.get("label", node_id)
            is_loop = node_type == "loop"
            is_merge = node_type == "merge"

            # Loop nodes are re-entered every iteration. Every other node runs
            # once, except merge nodes, which are only marked visited after they
            # emit their merged result.
            if not is_loop and node_id in visited:
                return
            if not is_loop and not is_merge:
                visited.add(node_id)

            if is_loop:
                loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                max_iter = node.get("config", {}).get("max_iterations", 10)
                # max_iterations caps loop-node *visits*, including the final
                # visit that reports completion. A loop that needs more visits
                # leaves through this branch, forwarding the last body output.
                if loop_iterations[node_id] > max_iter:
                    logger.warning(f"循环节点 {label} 超过最大迭代次数 {max_iter}")
                    visited.add(node_id)
                    for target_id, edge_cond in graph.get(node_id, []):
                        if edge_cond == "loop_done":
                            await traverse(target_id, incoming_msg, source_id=node_id)
                    return

            try:
                # Agent creation and input mapping are inside the try so that
                # their failures are recorded against this node, not its parent.
                if is_loop:
                    loop_node_config = dict(node.get("config", {}))
                    # LoopNodeAgent expects the number of *completed* iterations.
                    loop_node_config["_engine_iteration"] = loop_iterations[node_id] - 1
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)

                current_msg = incoming_msg
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                if enriched_content.strip():
                    user_text = _msg_text(current_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")

                result = await agent.reply(current_msg, context=context, source_node_id=source_id)

                if is_merge:
                    if not getattr(agent, "is_complete", True):
                        return  # still waiting for other branches
                    visited.add(node_id)

                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node_type,
                    "label": node.get("label"),
                    "status": "success",
                    "output": _msg_text(result)[:500],
                }
                last_result = result

                if is_loop:
                    iteration = loop_iterations[node_id]
                    if self._check_loop_done(result, node, iteration):
                        visited.add(node_id)
                        for target_id, edge_cond in graph.get(node_id, []):
                            if edge_cond == "loop_done":
                                await traverse(target_id, result, source_id=node_id)
                    else:
                        for target_id, edge_cond in graph.get(node_id, []):
                            if edge_cond == "loop_body" and target_id in self.nodes:
                                # Re-arm every node of the body, not only the first,
                                # so that multi-node bodies run (and return here) again.
                                visited.difference_update(self._loop_body_nodes(node_id, target_id, graph))
                                await traverse(target_id, result, {"iteration": iteration}, source_id=node_id)
                    return

                is_condition = node_type == "condition"
                cond_result = self._parse_condition_result(result) if is_condition else None
                for target_id, edge_cond in graph.get(node_id, []):
                    if is_condition:
                        if edge_cond and edge_cond == cond_result:
                            await traverse(target_id, result, source_id=node_id)
                    elif edge_cond in ("loop_body", "loop_done"):
                        continue
                    else:
                        await traverse(target_id, result, source_id=node_id)
            except Exception as e:
                logger.error(f"节点 {label} 执行失败: {e}")
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node_type,
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                last_result = Msg(name="system", content=f"[节点 {label} 执行失败: {e}]", role="system")

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")
        return last_result if last_result is not None else input_msg

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build an adjacency list ``source -> [(target, condition)]``.

        Accepts ``source``/``from`` and ``target``/``to`` edge keys. The
        condition is taken from ``condition`` or ``sourceHandle``; the generic
        ``"source"`` handle means "no condition". Edges touching unknown nodes
        are dropped.
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return nodes without incoming edges, in definition order."""
        target_nodes = {target_id for targets in graph.values() for target_id, _ in targets}
        return [nid for nid in self.nodes if nid not in target_nodes]

    @staticmethod
    def _loop_body_nodes(loop_id: str, start_id: str, graph: dict) -> set[str]:
        """Return the nodes forming the body of loop ``loop_id`` entered at ``start_id``.

        A node belongs to the body when both hold:
        - it is reachable from ``start_id`` without passing through the loop node;
        - it can reach the loop node again (it lies on a cycle through the loop).

        Exit paths and upstream nodes are therefore not re-armed. ``start_id``
        itself is always included.
        """
        forward: set[str] = set()
        stack = [start_id]
        while stack:
            nid = stack.pop()
            if nid == loop_id or nid in forward:
                continue
            forward.add(nid)
            stack.extend(target for target, _ in graph.get(nid, []))

        predecessors: dict[str, list[str]] = {}
        for src, targets in graph.items():
            for target, _ in targets:
                predecessors.setdefault(target, []).append(src)
        backward: set[str] = set()
        stack = [loop_id]
        while stack:
            nid = stack.pop()
            for src in predecessors.get(nid, []):
                if src != loop_id and src not in backward:
                    backward.add(src)
                    stack.append(src)

        return (forward & backward) | {start_id}

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Return ``"true"``/``"false"`` from a ``condition:<flag>`` marker, else ``None``."""
        m = re.search(r"condition:(true|false)", _msg_text(result))
        return m.group(1) if m else None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node has finished after its ``iteration``-th visit (1-based).

        Fixed loops finish once ``count`` body iterations have run, i.e. on visit
        ``count + 1``; this matches ``LoopNodeAgent``. Other loop types finish
        when the loop agent's reply contains ``loop:stop|done|break``.
        """
        config = node.get("config", {})
        if config.get("loop_type", "fixed") == "fixed":
            return iteration > config.get("count", 3)
        return bool(re.search(r"loop:(stop|done|break)", _msg_text(result), re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict = None) -> AgentBase:
        """Return the cached agent for ``node_id``, creating it on first use.

        When ``override_config`` is given, a fresh agent is built with that config
        and is not cached (used for per-iteration loop agents).
        """
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]
        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``input_mapping`` as ``key: value`` lines ("" if none).

        String values containing ``{{...}}`` are resolved with ``_resolve_template``.
        """
        input_mapping = node.get("config", {}).get("input_mapping")
        if not input_mapping:
            return ""
        resolved = {}
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            resolved[key] = str(value)
        return "\n".join(f"{k}: {v}" for k, v in resolved.items())


async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Instantiate the agent for a node based on its ``type`` and ``config``.

    Unknown types fall back to a pass-through agent.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    config = node.get("config", {})

    if node_type == "trigger":
        return PassThroughAgent(node_id)
    elif node_type == "llm":
        agent = LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 2000),
            stream=config.get("stream", True),
            stream_callback=context.get("_stream_callback"),
        )
        memory = context.get("_memory")
        if memory:
            agent.set_memory(memory)
        return agent
    elif node_type == "tool":
        return ToolNodeAgent(
            node_id=node_id,
            tool_name=config.get("tool_name", ""),
            tool_params=config.get("tool_params", {}),
            timeout=config.get("timeout", 30),
            retry_count=config.get("retry_count", 0),
            error_handling=config.get("error_handling", "throw"),
        )
    elif node_type == "mcp":
        return MCPNodeAgent(
            node_id=node_id,
            server_name=config.get("mcp_server", ""),
            tool_name=config.get("tool_name", ""),
            timeout=config.get("timeout", 30),
            error_handling=config.get("error_handling", "throw"),
        )
    elif node_type in ("wecom_notify", "notify"):
        return NotifyAgent(node_id=node_id, config=config)
    elif node_type == "condition":
        return ConditionNodeAgent(
            node_id=node_id,
            condition=config.get("condition", ""),
            condition_type=config.get("condition_type", "expression"),
        )
    elif node_type == "rag":
        return RAGNodeAgent(node_id=node_id, config=config)
    elif node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)
    elif node_type == "merge":
        return ParallelMergeNodeAgent(node_id=node_id, config=config)
    elif node_type == "loop":
        return LoopNodeAgent(node_id=node_id, config=config)
    elif node_type == "code":
        return CodeNodeAgent(
            node_id=node_id,
            language=config.get("language", "python"),
            code=config.get("code", ""),
            timeout=config.get("timeout", 30),
            sandbox=config.get("sandbox", True),
        )
    else:
        return PassThroughAgent(node_id)


# ---------------------------------------------------------------------------
# Node agents
# ---------------------------------------------------------------------------

class PassThroughAgent(AgentBase):
    """Forward the incoming message unchanged (wrapping non-``Msg`` input)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        pass


class LLMNodeAgent(AgentBase):
    """Chat-completion node with optional history and token streaming.

    History comes from ``context["_memory_context"]`` (summary + recent
    messages) when present, otherwise from the attached ``FlowSessionMemory``.
    Tokens are streamed over raw HTTP when ``stream`` is true and a callback is
    set; otherwise the model from ``AgentFactory`` is called.
    NOTE(review): the non-streaming path does not apply ``model_name``,
    ``temperature`` or ``max_tokens``.
    """

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7,
                 max_tokens: int = 2000, stream: bool = True, stream_callback=None):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.stream_callback = stream_callback
        self._memory = None

    def set_memory(self, memory):
        """Attach a ``FlowSessionMemory`` used for history and to record this exchange."""
        self._memory = memory

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        context = kwargs.get("context") or {}
        memory_ctx = context.get("_memory_context") or {}

        messages = [{"role": "system", "content": self.system_prompt}]
        if memory_ctx:
            summary = memory_ctx.get("summary", "")
            if summary:
                messages.append({"role": "system", "content": f"[历史对话摘要]\n{summary}"})
            history = (memory_ctx.get("recent_messages") or [])[-10:]
        elif self._memory:
            history = self._memory.get_history(limit=5)
        else:
            history = []
        for item in history:
            # Each history message is capped at 2000 characters to bound prompt size.
            messages.append({"role": item.get("role", "user"), "content": (item.get("content") or "")[:2000]})
        messages.append({"role": "user", "content": user_text})

        try:
            # The `stream` config flag was previously ignored; honour it here.
            if self.stream and self.stream_callback:
                res_text = await self._stream_llm_call(messages)
            else:
                res_text = await self._blocking_llm_call(messages)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"

        if self._memory and not memory_ctx:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)
        return Msg(self.name, res_text, "assistant")

    async def _blocking_llm_call(self, messages: list[dict]) -> str:
        """Call the project's default model via agentscope and return its text."""
        from agentscope_integration.factory import AgentFactory
        from agentscope.formatter import OpenAIChatFormatter

        model = AgentFactory._get_model()
        formatter = OpenAIChatFormatter()
        scope_msgs = [Msg(m["role"], m["content"], m["role"]) for m in messages]
        prompt = await _format_prompt(formatter, scope_msgs)
        res = await model(prompt)
        return await _response_text(res)

    async def _stream_llm_call(self, messages: list[dict]) -> str:
        """Stream an OpenAI-compatible ``/chat/completions`` call, forwarding tokens.

        Each content delta is passed to ``stream_callback("text_chunk", {"content": token})``.
        Returns the accumulated text. On timeout or error it returns whatever was
        received, or a bracketed error note if nothing was.
        """
        import httpx

        url = f"{settings.LLM_API_BASE.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {settings.LLM_API_KEY}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model_name or settings.LLM_MODEL,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        pieces: list[str] = []
        try:
            timeout = httpx.Timeout(60.0, connect=10.0)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("POST", url, json=body, headers=headers) as response:
                    # Without this an HTTP error body is silently ignored and "" is returned.
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # SSE allows "data:" with or without a following space.
                        if not line.startswith("data:"):
                            continue
                        json_str = line[5:].strip()
                        if json_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(json_str)
                        except json.JSONDecodeError:
                            continue
                        # Some providers send a final chunk with an empty `choices` list.
                        choices = chunk.get("choices") or [{}]
                        token = (choices[0].get("delta") or {}).get("content") or ""
                        if token:
                            pieces.append(token)
                            await self.stream_callback("text_chunk", {"content": token})
        except httpx.TimeoutException:
            logger.warning("LLM 流式调用超时")
            if not pieces:
                return "[LLM 超时]"
        except Exception as e:
            logger.warning(f"LLM 流式调用失败: {e}")
            if not pieces:
                return f"[LLM 调用失败: {e}]"
        return "".join(pieces)

    async def observe(self, msg) -> None:
        pass


class ToolNodeAgent(AgentBase):
    """Run a built-in Python tool or a registered HTTP custom tool, with timeout and retries.

    ``error_handling`` chooses the final failure reply:
    - ``"skip"`` gives a "skipped" note;
    - ``"default"`` gives ``"{}"``;
    - anything else (e.g. ``"throw"``) gives an error message.

    Failures are returned as messages; nothing is raised to the engine.
    """

    _TOOL_REGISTRY: dict[str, Callable[..., Any]] = {}
    _TOOL_SCHEMAS: dict[str, dict] = {}
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}

    @classmethod
    def _init_registry(cls):
        """Lazily import the built-in tools and their ``SCHEMAS`` (retried until it succeeds)."""
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard

            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
            for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]:
                try:
                    mod = __import__(f"agentscope_integration.tools.{mod_name}", fromlist=["SCHEMAS"])
                    if hasattr(mod, "SCHEMAS"):
                        cls._TOOL_SCHEMAS.update(mod.SCHEMAS)
                except Exception as e:
                    # Schemas are optional metadata; a failure here must not block the tools.
                    logger.debug(f"工具 schema 加载失败 {mod_name}: {e}")
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Register every active ``CustomTool`` row as an HTTP custom tool."""
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Imported only to make sure the executor module is available before registering.
            from modules.custom_tool.executor import CustomToolExecutor  # noqa: F401

            result = await db.execute(
                select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712  SQL expression, not a bool test
            )
            for tool in result.scalars().all():
                cls._CUSTOM_TOOL_DEFS[tool.name] = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30,
                 retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        self._init_registry()
        user_text = _msg_text(msg)

        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)

        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        def _call_sync():
            try:
                return tool_func(**self.tool_params) if self.tool_params else tool_func()
            except TypeError:
                # Signature mismatch: retry with the upstream text as first positional argument.
                try:
                    return tool_func(user_text, **self.tool_params)
                except Exception as e:
                    raise RuntimeError(f"工具执行失败: {e}") from e
            except Exception as e:
                raise RuntimeError(f"工具执行失败: {e}") from e

        async def _invoke():
            # Tools are assumed synchronous and run off the event loop; a coroutine
            # tool returns an awaitable from the worker thread, which is awaited here.
            result = await asyncio.to_thread(_call_sync)
            if inspect.isawaitable(result):
                result = await result
            return result

        for attempt in range(self.retry_count + 1):
            can_retry = attempt < self.retry_count
            try:
                result = await asyncio.wait_for(_invoke(), timeout=self.timeout)
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if can_retry:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 超时,已跳过]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except RuntimeError as e:
                if can_retry:
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[{e}]", "assistant")
            except Exception as e:
                if can_retry:
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[工具执行失败: {e}]", "assistant")
        # Only reachable when retry_count is negative (no attempt made).
        return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Run a custom HTTP tool with ``tool_params`` and the node's retry/timeout policy.

        ``user_text`` is currently not forwarded to the executor.
        """
        from modules.custom_tool.executor import CustomToolExecutor

        executor = CustomToolExecutor(custom_def)
        for attempt in range(self.retry_count + 1):
            can_retry = attempt < self.retry_count
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if can_retry:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 超时,已跳过]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except Exception as e:
                if can_retry:
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[自定义工具执行失败: {e}]", "assistant")
        return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant")

    async def observe(self, msg) -> None:
        pass


class ConditionNodeAgent(AgentBase):
    """Evaluate a condition and reply with ``condition:<true|false>|<reason>``.

    ``condition_type`` may be:
    - ``"regex"``: search the input text;
    - ``"json_path"``: a dotted path into a JSON object input, tested for truthiness;
    - anything else: ask the LLM.

    Every failure path defaults to ``true``.
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")
        if self.condition_type == "regex":
            return self._eval_regex(user_text)
        if self.condition_type == "json_path":
            return self._eval_json_path(user_text)
        return await self._eval_with_llm(user_text)

    def _eval_regex(self, user_text: str) -> Msg:
        """Match ``self.condition`` as a regex anywhere in the input."""
        try:
            matched = bool(re.search(self.condition, user_text))
        except re.error as e:
            return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")
        return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命' if matched else '未命'}中", "assistant")

    def _eval_json_path(self, user_text: str) -> Msg:
        """Follow a ``$.a.b``-style dotted path into a JSON object and test the value's truthiness."""
        try:
            data = json.loads(user_text) if user_text.strip().startswith("{") else {}
            val = data
            for part in self.condition.strip("$.").split("."):
                val = val.get(part, None) if isinstance(val, dict) else None
            return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")
        except Exception:
            return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")

    async def _eval_with_llm(self, user_text: str) -> Msg:
        """Ask the LLM for a ``{"result": bool, "reason": str}`` verdict."""
        try:
            from agentscope.formatter import OpenAIChatFormatter

            model = _build_chat_model()
            formatter = OpenAIChatFormatter()
            condition_prompt = f"""你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。

条件表达式: {self.condition}

输入内容: {user_text[:2000]}

请严格只输出一行 JSON: {{"result": true/false, "reason": "简要原因"}}"""
            prompt = await _format_prompt(formatter, [
                Msg("system", condition_prompt, "system"),
                Msg("user", user_text[:2000], "user"),
            ])
            res_text = await _response_text(await model(prompt))

            json_match = re.search(r"\{[^}]+\}", res_text)
            if json_match:
                parsed = json.loads(json_match.group())
                matched = parsed.get("result", False)
                # Models sometimes answer "false" as a string, which is truthy.
                if isinstance(matched, str):
                    matched = matched.strip().lower() == "true"
                reason = parsed.get("reason", "")
                return Msg(self.name, f"condition:{'true' if matched else 'false'}|{reason}", "assistant")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    async def observe(self, msg) -> None:
        pass


class MCPNodeAgent(AgentBase):
    """Call a tool on an MCP server, passing the input text as ``{"input": ...}``.

    Failures are reported as failures: previously a missing client, a missing
    tool name or an error fell through to a "调用完成" (completed) message.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30,
                 error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    def _failure(self, reason: str) -> Msg:
        """Build the failure reply, honouring ``error_handling == "skip"``."""
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, f"[MCP] {self.server_name} 调用失败: {reason}", "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return self._failure("MCP 客户端不可用")

        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return self._failure("未找到 MCP 客户端")
            if not self.tool_name:
                return self._failure("未指定工具名称")
            result = await asyncio.wait_for(
                client.call_tool(self.tool_name, {"input": user_text}),
                timeout=self.timeout,
            )
            return Msg(self.name, str(result), "assistant")
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            return self._failure(str(e))

    async def observe(self, msg) -> None:
        pass


class NotifyAgent(AgentBase):
    """Send the input (or a configured template) to WeCom and/or the web socket channel.

    Replies with a ``" | "``-joined summary of what each channel did.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        channels = self.config.get("channels", {"wecom": True, "web": False})
        target = self.config.get("target", "")
        results = []

        if channels.get("wecom", True):
            message = self.config.get("message_template", "") or user_text[:500]
            try:
                from agentscope_integration.tools.wecom_tools import send_notification
            except ImportError:
                logger.warning("企微通知模块不可用")
            else:
                try:
                    # send_notification is assumed synchronous; run it off the event loop.
                    result = await asyncio.to_thread(send_notification, to_user=target or "user", message=message)
                    if inspect.isawaitable(result):
                        result = await result
                    results.append(f"企微通知: {str(result)[:100]}")
                except Exception as e:
                    # Previously this branch reported the message as sent.
                    logger.warning(f"企微通知发送失败: {e}")
                    results.append(f"企微通知: 发送失败({e})")

        if channels.get("web", False):
            web_message = self.config.get("web_template", "") or user_text[:500]
            try:
                from websocket_manager import ws_manager

                if target:
                    await ws_manager.send_to_user(target, {
                        "type": "flow_notification",
                        "message": web_message,
                        "level": "info",
                    })
                    results.append("Web通知: 已推送")
                else:
                    results.append("Web通知: 未指定接收用户,未推送")
            except ImportError:
                pass
            except Exception as e:
                logger.warning(f"Web通知推送失败: {e}")
                results.append(f"Web通知: 推送失败({e})")

        if not results:
            results.append("通知已发送")
        return Msg(self.name, " | ".join(results), "assistant")

    async def observe(self, msg) -> None:
        pass


class LoopNodeAgent(AgentBase):
    """Emit the current loop iteration, or a ``loop:stop|...`` marker when finished.

    ``config["_engine_iteration"]`` is the number of iterations already
    completed (set by ``FlowEngine``; ``0`` when absent), so the current
    iteration is that value + 1.
    - ``"array"`` loops yield ``items[iteration - 1]`` and stop after the last item.
    - ``"fixed"`` loops stop after ``count`` iterations.
    - Other types never stop by themselves.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        loop_type = self.config.get("loop_type", "fixed")
        count = self.config.get("count", 3)
        iterator_variable = self.config.get("iterator_variable", "item")
        iteration = self.config.get("_engine_iteration", 0) + 1  # 1-based current iteration

        if loop_type == "array":
            items = self.config.get("items", [])
            if isinstance(items, str):
                try:
                    items = json.loads(items)
                except Exception:
                    # Not JSON: treat it as one item per non-blank line.
                    items = [line.strip() for line in items.split("\n") if line.strip()]
            if isinstance(items, list) and items:
                idx = iteration - 1
                if idx >= len(items):
                    return Msg(self.name, "loop:stop|数组迭代完成", "assistant")
                return Msg(self.name, f"[迭代 {iteration}/{len(items)}] 变量: {iterator_variable}={items[idx]}\n输入: {user_text[:200]}", "assistant")
            return Msg(self.name, "loop:stop|无数组数据", "assistant")

        if loop_type == "fixed":
            if iteration > count:
                return Msg(self.name, f"loop:stop|固定次数循环完成: {iteration - 1}/{count}", "assistant")
            return Msg(self.name, f"[循环 {iteration}/{count}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

        return Msg(self.name, f"[循环 {iteration}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

    async def observe(self, msg) -> None:
        pass


class CodeNodeAgent(AgentBase):
    """Run user-supplied Python in a subprocess; the upstream text is exposed as ``INPUT_TEXT``.

    The reply is the script's stripped stdout, or ``[错误] <stderr>`` on a
    non-zero exit. Timeouts kill the process.
    """

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30,
                 sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        self.language = language
        self.code = code
        self.timeout = timeout
        self.sandbox = sandbox

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        import os
        import tempfile

        user_text = _msg_text(msg)
        if not self.code.strip():
            return Msg(self.name, user_text, "assistant")

        # SECURITY(review): `sandbox` and `language` are NOT enforced. The code
        # runs as plain Python with this process's privileges; only run flows
        # whose code comes from trusted authors.
        #
        # The text is JSON-encoded and then embedded with repr(), so INPUT_TEXT
        # is exactly the upstream string. Embedding the bare JSON literal made
        # json.loads receive the unquoted text and fail for ordinary input.
        safe_input = json.dumps(user_text)
        code_with_input = f"import json\nINPUT_TEXT = json.loads({safe_input!r})\n\n{self.code}"

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
                f.write(code_with_input)
                temp_path = f.name

            proc = await asyncio.create_subprocess_exec(
                sys.executable or "python", temp_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                with contextlib.suppress(ProcessLookupError):
                    proc.kill()
                await proc.wait()  # reap the killed child
                return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")

            # Judge success by exit status: warnings written to stderr by a
            # successful script are not errors.
            if proc.returncode != 0:
                error_text = (stderr or stdout).decode("utf-8", errors="replace")
                return Msg(self.name, f"[错误] {error_text[:500]}", "assistant")
            output = stdout.decode("utf-8", errors="replace").strip()
            return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            if temp_path:
                with contextlib.suppress(OSError):
                    os.unlink(temp_path)

    async def observe(self, msg) -> None:
        pass


class RAGNodeAgent(AgentBase):
    """Retrieve knowledge-base context for the input and answer it with the LLM.

    Falls back to a placeholder description if retrieval or the LLM call fails.
    NOTE(review): ``similarity_threshold``, ``search_mode`` and
    ``include_metadata`` may be present in the config but are not used;
    ``retrieve_for_agent`` only receives ``limit``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        top_k = self.config.get("top_k", 5)
        try:
            from modules.rag.knowledge import retrieve_for_agent

            kb_result = await retrieve_for_agent(user_text, limit=top_k)
            model = self._get_model()
            formatter = self._get_formatter()
            rag_prompt = f"""你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。

知识库检索结果:
{kb_result}

用户问题: {user_text}

请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"""
            prompt = await _format_prompt(formatter, [
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res_text = await _response_text(await model(prompt))
            return Msg(self.name, res_text, "assistant")
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Return the chat model used to answer (see ``_build_chat_model``)."""
        return _build_chat_model()

    def _get_formatter(self):
        """Return the OpenAI chat formatter."""
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass


class OutputNodeAgent(AgentBase):
    """Shape the final output: optional truncation, JSON pretty-printing or a ``{{output}}`` template."""

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        output_format = self.config.get("format", "text")
        output_template = self.config.get("output_template", "")
        truncate = self.config.get("truncate", False)
        max_length = self.config.get("max_length", 2000)

        content = _msg_text(msg)
        truncated = False
        if truncate and len(content) > max_length:
            content = content[:max_length] + "\n\n[内容已截断]"
            truncated = True

        if output_format == "json":
            try:
                parsed = json.loads(content)
                formatted = json.dumps(parsed, indent=self.config.get("indent", 2), ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")
            except ValueError:  # includes json.JSONDecodeError
                pass

        if output_template:
            return Msg(self.name, output_template.replace("{{output}}", content), "assistant")
        # Previously the original message was returned here, silently undoing truncation.
        if truncated or not isinstance(msg, Msg):
            return Msg(self.name, content, "assistant")
        return msg

    async def observe(self, msg) -> None:
        pass


_PLACEHOLDER_RE = re.compile(r"\{\{(.+?)\}\}")


def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
    """Replace ``{{...}}`` placeholders in *template*.

    Supported forms:
    - ``{{trigger.<key>}}`` reads ``context["trigger_data"]`` using the
      (flat) dotted key;
    - ``{{<node_id>}}`` and ``{{<node_id>.output}}`` read the node's recorded
      output;
    - ``{{<node_id>.<field>}}`` reads another field of the node's execution
      record.

    Unknown placeholders become ``""``. Each placeholder is substituted once,
    so values containing ``{{...}}`` are not re-expanded. ``current_msg`` is
    accepted for interface compatibility and is unused.
    """
    node_results = context.get("_node_results", {})

    def _lookup(match: re.Match) -> str:
        parts = match.group(1).strip().split(".")
        head = parts[0]
        if head == "trigger":
            return str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
        if head in node_results:
            node_result = node_results.get(head, {})
            if len(parts) > 1 and parts[1] != "output":
                return str(node_result.get(parts[1], ""))
            return str(node_result.get("output", ""))
        return ""

    return _PLACEHOLDER_RE.sub(_lookup, template)


class ParallelMergeNodeAgent(AgentBase):
    """Collect outputs from parallel branches and merge them once enough have arrived.

    Inputs are keyed by the ``source_node_id`` kwarg, falling back to the
    sender's ``Msg.name``. Merging happens when ``expected_branches`` inputs are
    present, or immediately if it is <= 0. ``is_complete`` tells the engine
    whether the last reply was the merged result (``True``) or an empty
    placeholder while waiting (``False``).
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Merge_{node_id}"
        self.config = config or {}
        self._received: dict[str, str] = {}
        self.merge_type = self.config.get("merge_type", "concat")
        self.is_complete = False

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        source_id = kwargs.get("source_node_id") or getattr(msg, "name", "") or ""
        self._received[source_id] = _msg_text(msg)
        expected_count = self.config.get("expected_branches", 0)
        if expected_count <= 0 or len(self._received) >= expected_count:
            self.is_complete = True
            return self._merge()
        self.is_complete = False
        return Msg(self.name, "", "assistant")

    def _merge(self) -> Msg:
        """Combine received inputs.

        - ``"json"``: a JSON object keyed by source.
        - ``"first_non_empty"``: the first non-blank input.
        - otherwise: inputs joined by ``---`` separators.
        """
        if self.merge_type == "json":
            merged = json.dumps(self._received, ensure_ascii=False)
        elif self.merge_type == "first_non_empty":
            merged = next((v for v in self._received.values() if v.strip()), "")
        else:
            merged = "\n\n---\n\n".join(self._received.values())
        return Msg(self.name, merged, "assistant")

    async def observe(self, msg) -> None:
        pass