"""Flow execution engine.

Runs a flow definition (a list of nodes plus the edges between them) as a
chain of agentscope agents: nodes are executed one after another in
topological order, and each node receives the previous node's output message.
"""

import inspect
import json
import logging
import re
import uuid
from collections import deque
from collections.abc import Callable
from typing import Any

from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.agent._react_agent import ReActAgent

from config import settings

logger = logging.getLogger(__name__)

# Matches ``{{ placeholder }}`` markers inside input-mapping templates.
_PLACEHOLDER_RE = re.compile(r"\{\{(.+?)\}\}")


def _msg_text(msg: Any) -> str:
    """Return the text of a message-like object as a ``str`` (never ``None``).

    Objects exposing ``get_text_content()`` (such as ``Msg``) use it, and a
    falsy result is mapped to ``""`` so callers can always slice/format it;
    any other object is converted with ``str()``.
    """
    if hasattr(msg, "get_text_content"):
        return msg.get_text_content() or ""
    return str(msg)


def _response_text(res: Any) -> str:
    """Extract text from a model response.

    A list response contributes its first element (an empty list yields ``""``);
    anything else is handled by :func:`_msg_text`.
    """
    if isinstance(res, list):
        return _msg_text(res[0]) if res else ""
    return _msg_text(res)


class FlowEngine:
    """Execute a flow definition node by node.

    ``flow_definition`` is a dict with a ``"nodes"`` list (each node a dict with
    an ``"id"`` and optional ``"type"``, ``"label"`` and ``"config"``) and an
    ``"edges"`` list (each edge a dict naming its endpoints with
    ``"from"``/``"to"`` or ``"source"``/``"target"``).
    """

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        # Node id -> node definition.
        self.nodes: dict[str, dict] = {}
        for node in flow_definition.get("nodes") or []:
            self.nodes[node["id"]] = node
        self.edges: list[dict] = flow_definition.get("edges") or []
        # One agent per node id, created lazily and reused on later executions.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run every node once, in topological order, and return the final message.

        Args:
            input_msg: Message fed to the first node.
            context: Shared run state. ``context["trigger_data"]`` is read by
                input-mapping templates; an execution record for each node is
                written to ``context["_node_results"][node_id]``.

        Returns:
            The last node's output message. A failing node does not abort the
            run: its error is recorded and a system message describing the
            failure is passed to the next node instead.
        """
        current_msg = input_msg
        for node_id in self._topological_sort():
            agent = await self._get_or_create_agent(node_id, context)
            node = self.nodes[node_id]
            display_name = node.get("label", node_id)

            # Prepend the node's resolved input-mapping lines (if any) to the incoming text.
            enriched_content = self._resolve_input_mapping(node, current_msg, context)
            if enriched_content.strip():
                user_text = _msg_text(current_msg)
                current_msg = Msg(
                    name="user",
                    content=f"{enriched_content}\n\n---\n{user_text}",
                    role="user",
                )

            try:
                result = await agent(current_msg)
                exec_record = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "success",
                    # Only the first 500 characters of the output are recorded.
                    "output": _msg_text(result)[:500],
                }
                current_msg = result
            except Exception as e:
                # logger.exception also records the traceback.
                logger.exception("节点 %s 执行失败: %s", display_name, e)
                exec_record = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                current_msg = Msg(
                    name="system",
                    content=f"[节点 {display_name} 执行失败: {e}]",
                    role="system",
                )
            context.setdefault("_node_results", {})[node_id] = exec_record
        return current_msg

    async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase:
        """Return the cached agent for ``node_id``, creating it on first use."""
        if node_id in self._agent_cache:
            return self._agent_cache[node_id]
        agent = await _create_node_agent(self.nodes[node_id], context)
        self._agent_cache[node_id] = agent
        return agent

    def _topological_sort(self) -> list[str]:
        """Order node ids with Kahn's algorithm.

        Edges whose endpoints are missing or not known nodes are ignored. Nodes
        that cannot be ordered because they are on, or downstream of, a cycle
        are appended at the end in declaration order, so every node still runs
        exactly once.
        """
        in_degree: dict[str, int] = {nid: 0 for nid in self.nodes}
        adj: dict[str, list[str]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("from") or edge.get("source")
            target = edge.get("to") or edge.get("target")
            if source and target and source in self.nodes and target in self.nodes:
                adj[source].append(target)
                in_degree[target] += 1

        queue = deque(nid for nid, deg in in_degree.items() if deg == 0)
        order: list[str] = []
        while queue:
            node_id = queue.popleft()
            order.append(node_id)
            for neighbor in adj[node_id]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Set membership keeps this step linear instead of scanning the list per node.
        placed = set(order)
        remaining = [nid for nid in self.nodes if nid not in placed]
        if remaining:
            logger.warning("流程存在循环依赖,以下节点将按定义顺序追加执行: %s", remaining)
        order.extend(remaining)
        return order

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render ``node["config"]["input_mapping"]`` as ``"key: value"`` lines.

        String values containing ``{{`` are expanded with :func:`_resolve_template`;
        other values are used as-is. Returns ``""`` when there is no mapping.
        """
        config = node.get("config") or {}
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""
        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {str(value)}")
        return "\n".join(lines)


async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Instantiate the agent implementing ``node["type"]``.

    ``"trigger"`` nodes and unknown types become a :class:`PassThroughAgent`.
    ``context`` is accepted for interface compatibility but is not used here.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    config = node.get("config") or {}

    if node_type == "trigger":
        return PassThroughAgent(node_id)
    if node_type == "llm":
        return LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
        )
    if node_type == "tool":
        return ToolNodeAgent(
            node_id=node_id,
            tool_name=config.get("tool_name", ""),
            tool_params=config.get("tool_params", {}),
        )
    if node_type == "mcp":
        return MCPNodeAgent(
            node_id=node_id,
            server_name=config.get("mcp_server", ""),
            tool_name=config.get("tool_name", ""),
        )
    if node_type == "wecom_notify":
        return WeComNotifyAgent(node_id=node_id, config=config)
    if node_type == "condition":
        return ConditionNodeAgent(node_id=node_id, condition=config.get("condition", ""))
    if node_type == "rag":
        return RAGNodeAgent(node_id=node_id, config=config)
    if node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)
    return PassThroughAgent(node_id)


class PassThroughAgent(AgentBase):
    """Forward the incoming message unchanged (non-``Msg`` input is wrapped)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        pass


class LLMNodeAgent(AgentBase):
    """Single-turn LLM call: configured system prompt + incoming text -> model reply."""

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        # NOTE(review): stored but not forwarded to the model in reply(); confirm
        # how the installed agentscope OpenAIChatModel accepts sampling params
        # and wire this through.
        self.temperature = temperature

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call the model once; on a model-call error, reply with a fallback text."""
        from agentscope.model import OpenAIChatModel
        from agentscope.formatter import OpenAIChatFormatter

        # NOTE(review): config_name/api_base are passed as-is; verify they match
        # the constructor of the installed agentscope version.
        model = OpenAIChatModel(
            config_name=f"flow_llm_{self.name}",
            model_name=self.model_name,
            api_key=settings.LLM_API_KEY,
            api_base=settings.LLM_API_BASE,
        )
        user_text = _msg_text(msg)
        formatter = OpenAIChatFormatter()
        prompt = await formatter.format([
            Msg("system", self.system_prompt, "system"),
            Msg("user", user_text, "user"),
        ])
        try:
            res = await model(prompt)
            res_text = _response_text(res)
        except Exception as e:
            logger.warning("LLM 调用失败: %s", e)
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"
        return Msg(self.name, res_text, "assistant")

    async def observe(self, msg) -> None:
        pass


class ToolNodeAgent(AgentBase):
    """Invoke one of the registered integration tools by name."""

    # Tool name -> callable; filled lazily by _init_registry().
    _TOOL_REGISTRY: dict[str, Callable[..., Any]] = {}

    @classmethod
    def _init_registry(cls):
        """Populate the tool registry on first use (retried while it stays empty)."""
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard

            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
        except ImportError as e:
            logger.warning("工具注册失败: %s", e)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}

    def _resolve_call_args(self, tool_func: Callable[..., Any], user_text: str) -> tuple:
        """Choose positional args for ``tool_func`` before calling it.

        Prefers calling with the configured params only; if the signature does
        not accept that, passes the incoming text as the first positional arg.
        Deciding up front (instead of retrying after a ``TypeError``) prevents a
        side-effecting tool from being executed twice.

        Raises:
            TypeError: if neither call form matches the tool's signature.
        """
        try:
            signature = inspect.signature(tool_func)
        except (TypeError, ValueError):
            # Not introspectable: call with the configured params only.
            return ()
        try:
            signature.bind(**self.tool_params)
            return ()
        except TypeError:
            signature.bind(user_text, **self.tool_params)
            return (user_text,)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Run the configured tool and reply with its stringified result."""
        self._init_registry()
        user_text = _msg_text(msg)
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")
        try:
            args = self._resolve_call_args(tool_func, user_text)
            result = tool_func(*args, **self.tool_params)
            # Support async tools as well as plain functions.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            return Msg(self.name, f"[工具执行失败: {e}]", "assistant")
        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        pass


class ConditionNodeAgent(AgentBase):
    """Condition node.

    NOTE(review): the condition expression is never evaluated; when one is set
    the node always reports it as satisfied and the flow continues.
    """

    def __init__(self, node_id: str, condition: str = ""):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.condition:
            return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")
        result_text = f"[条件判断: {self.condition[:80]}]\n输入: {user_text[:300]}\n结果: 条件满足,继续执行。"
        return Msg(self.name, result_text, "assistant")

    async def observe(self, msg) -> None:
        pass


class MCPNodeAgent(AgentBase):
    """Call a tool on a named MCP server with the incoming text as ``{"input": ...}``."""

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = ""):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager

            client = MCPClientManager.get_http_client(self.server_name)
            if client and self.tool_name:
                result = await client.call_tool(self.tool_name, {"input": user_text})
                return Msg(self.name, str(result), "assistant")
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
        except Exception as e:
            logger.warning("MCP 调用失败: %s", e)
        # NOTE(review): this fallback reports completion even when no call was
        # made or the call failed.
        output = f"[MCP] 服务 {self.server_name} 调用完成: 已处理输入"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        pass


class WeComNotifyAgent(AgentBase):
    """Send a WeCom notification built from config or the incoming text."""

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"WeComNotify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        template = self.config.get("message_template", "")
        target = self.config.get("target", "")
        # A configured template is sent verbatim (no placeholder expansion here).
        message = template or user_text[:500]
        try:
            from agentscope_integration.tools.wecom_tools import send_notification

            result = send_notification(to_user=target or "user", message=message)
            if inspect.isawaitable(result):
                result = await result
            return Msg(self.name, str(result), "assistant")
        except ImportError:
            pass
        except Exception as e:
            logger.warning("企微通知发送失败: %s", e)
        # NOTE(review): this fallback reports the message as sent even though the
        # integration was unavailable or failed.
        result = f"[企微通知] 已向 {target or '用户'} 发送: {message[:100]}"
        return Msg(self.name, result, "assistant")

    async def observe(self, msg) -> None:
        pass


class RAGNodeAgent(AgentBase):
    """Retrieve knowledge-base context for the incoming text and answer with the LLM."""

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Answer from retrieved context; on any error, reply with a summary stub."""
        user_text = _msg_text(msg)
        top_k = self.config.get("top_k", 5)
        try:
            from modules.rag.knowledge import retrieve_for_agent

            kb_result = await retrieve_for_agent(user_text, limit=top_k)
            model = self._get_model()
            formatter = self._get_formatter()
            rag_prompt = f"""你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。 知识库检索结果: {kb_result} 用户问题: {user_text} 请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"""
            # Same formatter usage as LLMNodeAgent: awaited, with Msg objects.
            messages = await formatter.format([
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await model(messages)
            return Msg(self.name, _response_text(res), "assistant")
        except Exception as e:
            logger.warning("RAG 节点执行失败: %s", e)
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Build the chat model from global settings (same kwargs as LLMNodeAgent)."""
        from agentscope.model import OpenAIChatModel

        return OpenAIChatModel(
            config_name=f"rag_{self.name}",
            model_name=settings.LLM_MODEL,
            api_key=settings.LLM_API_KEY,
            api_base=settings.LLM_API_BASE,
        )

    def _get_formatter(self):
        from agentscope.formatter import OpenAIChatFormatter

        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass


class OutputNodeAgent(AgentBase):
    """Final node; with ``format == "json"`` pretty-prints JSON text, else passes through."""

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        output_format = self.config.get("format", "text")
        if output_format == "json":
            content = _msg_text(msg)
            try:
                parsed = json.loads(content)
            except ValueError:  # includes json.JSONDecodeError
                pass
            else:
                formatted = json.dumps(parsed, indent=2, ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        pass


def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
    """Expand ``{{...}}`` placeholders in ``template``.

    Supported placeholders:
      * ``{{trigger.<key>}}``: ``context["trigger_data"]`` looked up with the
        rest of the dotted path as one flat key.
      * ``{{<node_id>}}`` / ``{{<node_id>.<field>}}``: the ``"output"`` (or
        ``<field>``) of that node's record in ``context["_node_results"]``.

    Unknown placeholders become ``""``. ``current_msg`` is accepted for
    interface compatibility but is not used.
    """
    trigger_data = context.get("trigger_data") or {}
    node_results = context.get("_node_results") or {}

    def _lookup(match: re.Match) -> str:
        parts = match.group(1).strip().split(".")
        head = parts[0]
        if head == "trigger":
            return str(trigger_data.get(".".join(parts[1:]), ""))
        if head in node_results:
            field = parts[1] if len(parts) > 1 else "output"
            return str(node_results[head].get(field, ""))
        return ""

    # Single pass: text substituted for one placeholder is never re-expanded.
    return _PLACEHOLDER_RE.sub(_lookup, template)