"""Flow execution engine and per-node agents.

``FlowEngine`` walks a flow definition (``{"nodes": [...], "edges": [...]}``) and runs
one agent per node; ``_create_node_agent`` maps each node ``type`` to its agent class.
"""

import asyncio
import inspect
import json
import logging
import os
import re
import sys
import tempfile
import uuid

import httpx
from agentscope.agent import AgentBase
from agentscope.agent._react_agent import ReActAgent
from agentscope.message import Msg
from agentscope.tool import Toolkit

from config import settings

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Small shared helpers
# ---------------------------------------------------------------------------

def _msg_text(msg) -> str:
    """Return the text of *msg* as a ``str`` (never ``None``).

    Objects exposing ``get_text_content()`` use it, with a ``None`` result normalised
    to ``""`` because callers slice and regex-search the text; anything else is
    stringified.
    """
    if hasattr(msg, "get_text_content"):
        return msg.get_text_content() or ""
    return str(msg)


def _as_int(value, default: int) -> int:
    """Coerce a config value (e.g. a numeric string from JSON) to ``int``, else *default*."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _coerce_dict(value) -> dict:
    """Return *value* as a dict: dicts pass through, JSON-object strings are parsed,
    anything else (None, lists, unparsable strings) becomes ``{}``."""
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, ValueError):
            return {}
    return value if isinstance(value, dict) else {}


def _extract_json_object(text: str) -> dict | None:
    """Parse *text* as a JSON object, or else the outermost ``{...}`` span inside it.

    Tolerates models that wrap the object in prose or code fences. Returns ``None``
    when no JSON object can be parsed.
    """
    try:
        parsed = json.loads(text)
    except (json.JSONDecodeError, ValueError):
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            return None
        try:
            parsed = json.loads(match.group())
        except (json.JSONDecodeError, ValueError):
            return None
    return parsed if isinstance(parsed, dict) else None


async def _chat_completion(
    messages: list[dict],
    *,
    model: str,
    base_url: str,
    api_key: str,
    temperature: float | None = None,
    max_tokens: int | None = None,
    timeout: float = 60.0,
) -> str:
    """Send one non-streaming OpenAI-compatible ``/chat/completions`` request.

    Args:
        messages: Chat messages (``{"role", "content"}`` dicts).
        model / base_url / api_key: Endpoint and credentials.
        temperature / max_tokens: Sent only when not ``None`` (provider default otherwise).
        timeout: Read timeout in seconds (connect timeout is 10s).

    Returns:
        The first choice's message content, or ``""`` when it is empty.

    Raises:
        httpx.HTTPError: transport failure or non-2xx status.
        ValueError: the response contains no choices.
    """
    url = f"{base_url.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    body: dict = {"model": model, "messages": messages, "stream": False}
    if temperature is not None:
        body["temperature"] = temperature
    if max_tokens is not None:
        body["max_tokens"] = max_tokens

    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout, connect=10.0)) as client:
        resp = await client.post(url, json=body, headers=headers)
        resp.raise_for_status()
        data = resp.json()

    choices = data.get("choices") or []
    if not choices:
        raise ValueError(f"LLM response has no choices: {str(data)[:200]}")
    return (choices[0].get("message") or {}).get("content") or ""


async def _resolve_model_instance(model_instance_id: str) -> dict | None:
    """Load an active model instance together with its (active) provider settings.

    Returns ``{"model", "base_url", "api_key", "params"}`` or ``None`` when the id is
    malformed, no active row matches, or the lookup fails. Failures are logged rather
    than raised so callers can fall back to the global ``settings``.
    """
    try:
        from database import AsyncSessionLocal
        from sqlalchemy import text

        uid = uuid.UUID(str(model_instance_id))
        async with AsyncSessionLocal() as db:
            result = await db.execute(
                text(
                    """
                    SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key
                    FROM model_instances mi
                    JOIN model_providers mp ON mi.provider_id = mp.id
                    WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true
                    """
                ),
                {"id": uid},
            )
            row = result.fetchone()
            if row:
                return {
                    "model": row[0],
                    "base_url": row[2],
                    "api_key": row[3],
                    "params": row[1] or {},
                }
    except Exception as e:
        logger.warning(f"解析模型实例 {model_instance_id} 失败: {e}")
    return None


class FlowSessionMemory:
    """In-process chat history shared by the LLM nodes of one flow execution."""

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return the last ``limit`` exchanges, i.e. at most ``2 * limit`` messages.

        A non-positive ``limit`` returns ``[]`` (a ``[-0:]`` slice would otherwise
        return the entire history).
        """
        if limit <= 0:
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one ``{"role", "content"}`` message."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of all stored messages."""
        return list(self._messages)


class FlowEngine:
    """Run a flow definition (``{"nodes": [...], "edges": [...]}``) node by node.

    Traversal is depth-first from one start node. Each non-loop node runs at most once
    per execution, so a node with several predecessors runs on its first arrival only.
    Loop nodes are re-entered through their ``loop_body`` edges until done, bounded per
    node by ``config["max_iterations"]`` and globally by ``MAX_TOTAL_ITERATIONS``; the
    whole run is bounded by ``FLOW_TIMEOUT_SECONDS``.
    """

    MAX_TOTAL_ITERATIONS = 200
    FLOW_TIMEOUT_SECONDS = 300

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {}
        for node in flow_definition.get("nodes", []):
            self.nodes[node["id"]] = node
        self.edges: list[dict] = flow_definition.get("edges", [])
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run the flow on *input_msg* and return the last produced message.

        *context* is mutated: ``_memory`` receives this run's ``FlowSessionMemory`` and
        ``_node_results`` collects one execution record per node. Returns *input_msg*
        when no node produced a result.
        """
        # Agents capture per-run state (memory, stream callback, context) at creation,
        # so agents cached by a previous execute() call must not be reused.
        self._agent_cache.clear()

        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph)
        if not start_nodes:
            start_nodes = list(self.nodes.keys())[:1]

        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        memory = FlowSessionMemory(session_id=session_id, user_id=user_id)
        context["_memory"] = memory

        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}
        total_iterations = 0
        last_result: Msg | None = None

        def record(node_id: str, node: dict, status: str, **details) -> None:
            # Execution record exposed to the caller through context["_node_results"].
            context.setdefault("_node_results", {})[node_id] = {
                "node_id": node_id,
                "node_type": node.get("type", ""),
                "label": node.get("label"),
                "status": status,
                **details,
            }

        async def traverse(node_id: str, incoming_msg: Msg, loop_context: dict = None) -> None:
            # loop_context is informational only (current loop iteration); unused here.
            nonlocal last_result, total_iterations

            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return

            node = self.nodes.get(node_id)
            if not node:
                return
            node_type = node.get("type", "")
            label = node.get("label", node_id)
            is_loop = node_type == "loop"

            if is_loop:
                loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                max_iter = _as_int(node.get("config", {}).get("max_iterations"), 10)
                if loop_iterations[node_id] > max_iter:
                    logger.warning(f"循环节点 {label} 超过最大迭代次数 {max_iter}")
                    visited.add(node_id)
                    for target_id, edge_cond in graph.get(node_id, []):
                        if edge_cond == "loop_done":
                            await traverse(target_id, incoming_msg)
                    return
            elif node_id in visited:
                return
            else:
                visited.add(node_id)

            # Everything that belongs to *this* node (agent creation, input mapping and
            # the reply) is guarded here, so failures are attributed to this node and
            # never to a parent whose traversal called us.
            try:
                if is_loop:
                    loop_node_config = dict(node.get("config", {}))
                    # Number of completed passes before this one (0 on first entry);
                    # LoopNodeAgent adds 1, so its iteration equals loop_iterations[node_id].
                    loop_node_config["_engine_iteration"] = loop_iterations[node_id] - 1
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)

                current_msg = incoming_msg
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                if enriched_content.strip():
                    user_text = _msg_text(current_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")

                result = await agent.reply(current_msg, context=context)
            except Exception as e:
                logger.error(f"节点 {label} 执行失败: {e}")
                record(node_id, node, "error", error=str(e))
                last_result = Msg(name="system", content=f"[节点 {label} 执行失败: {e}]", role="system")
                return

            record(node_id, node, "success", output=_msg_text(result)[:500])
            last_result = result

            if is_loop:
                if self._check_loop_done(result, node, loop_iterations[node_id]):
                    visited.add(node_id)
                    for target_id, edge_cond in graph.get(node_id, []):
                        if edge_cond == "loop_done":
                            await traverse(target_id, result)
                else:
                    body_targets = [
                        target_id
                        for target_id, edge_cond in graph.get(node_id, [])
                        if edge_cond == "loop_body" and target_id in self.nodes
                    ]
                    # Re-arm the whole body (not just its first node) so multi-node
                    # bodies run again each iteration; nested loops restart their count.
                    for body_id in self._collect_loop_body(graph, node_id, body_targets):
                        visited.discard(body_id)
                        if self.nodes[body_id].get("type") == "loop":
                            loop_iterations.pop(body_id, None)
                    for target_id in body_targets:
                        await traverse(target_id, result, {"iteration": loop_iterations[node_id]})
                return

            is_condition = node_type == "condition"
            cond_result = self._parse_condition_result(result) if is_condition else None
            for target_id, edge_cond in graph.get(node_id, []):
                if is_condition:
                    # Condition nodes follow only the edge labelled with their outcome.
                    if edge_cond and edge_cond == cond_result:
                        await traverse(target_id, result)
                elif edge_cond in ("loop_body", "loop_done"):
                    continue
                else:
                    await traverse(target_id, result)

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")

        return last_result if last_result is not None else input_msg

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build ``{source_id: [(target_id, condition), ...]}`` from the edge list.

        Accepts ``source``/``from`` and ``target``/``to`` keys; the edge condition is
        ``condition`` or ``sourceHandle``, where the plain handle ``"source"`` means
        "unconditional" (``None``). Edges touching unknown nodes are dropped.
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return nodes without incoming edges, ``trigger`` nodes first.

        ``execute`` starts from the first entry only, so an explicit trigger must win
        over, e.g., a disconnected node that appears earlier in the definition. The
        sort is stable, so definition order is otherwise preserved.
        """
        target_nodes = {target_id for targets in graph.values() for target_id, _ in targets}
        roots = [nid for nid in self.nodes if nid not in target_nodes]
        return sorted(roots, key=lambda nid: self.nodes[nid].get("type") != "trigger")

    @staticmethod
    def _collect_loop_body(graph: dict, loop_id: str, starts: list[str]) -> set[str]:
        """Return every node reachable from *starts* without passing through *loop_id*."""
        body: set[str] = set()
        stack = [s for s in starts if s != loop_id]
        while stack:
            current = stack.pop()
            if current in body:
                continue
            body.add(current)
            for target_id, _ in graph.get(current, []):
                if target_id != loop_id and target_id not in body:
                    stack.append(target_id)
        return body

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Extract ``"true"``/``"false"`` from a ``condition:<bool>`` marker, else ``None``."""
        m = re.search(r"condition:(true|false)", _msg_text(result))
        return m.group(1) if m else None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node has finished after its *iteration*-th pass.

        ``fixed`` loops finish once ``iteration >= count`` (default 3); every other
        loop type finishes when the reply contains ``loop:stop``/``done``/``break``.
        """
        config = node.get("config", {})
        if config.get("loop_type", "fixed") == "fixed":
            return iteration >= _as_int(config.get("count"), 3)
        return bool(re.search(r"loop:(stop|done|break)", _msg_text(result), re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict = None) -> AgentBase:
        """Return the node's agent, cached per node unless *override_config* is given."""
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]
        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render ``config["input_mapping"]`` as ``"key: value"`` lines.

        String values containing ``{{`` are expanded with ``_resolve_template``.
        Returns ``""`` when no (dict-shaped) mapping is configured.
        """
        input_mapping = node.get("config", {}).get("input_mapping")
        if not input_mapping or not isinstance(input_mapping, dict):
            return ""
        resolved = {}
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            resolved[key] = str(value)
        return "\n".join(f"{k}: {v}" for k, v in resolved.items())


async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Instantiate the agent for *node* according to its ``type``.

    Unknown types fall back to ``PassThroughAgent``. ``llm`` and ``rag`` nodes may
    reference a DB model instance via ``config["model_instance_id"]``.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    config = node.get("config", {})

    if node_type == "trigger":
        return PassThroughAgent(node_id)
    elif node_type == "llm":
        model_config = config.get("model", settings.LLM_MODEL)
        base_url = settings.LLM_API_BASE
        api_key = settings.LLM_API_KEY
        model_instance_id = config.get("model_instance_id")
        if model_instance_id:
            resolved = await _resolve_model_instance(model_instance_id)
            if resolved:
                model_config = resolved.get("model", model_config)
                base_url = resolved.get("base_url", base_url)
                api_key = resolved.get("api_key", api_key)
        agent = LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=model_config,
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 2000),
            stream=config.get("stream", True),
            stream_callback=context.get("_stream_callback"),
            base_url=base_url,
            api_key=api_key,
        )
        memory = context.get("_memory")
        if memory:
            agent.set_memory(memory)
        return agent
    elif node_type == "tool":
        return ToolNodeAgent(
            node_id=node_id,
            tool_name=config.get("tool_name", ""),
            tool_params=config.get("tool_params", {}),
            timeout=config.get("timeout", 30),
            retry_count=config.get("retry_count", 0),
            error_handling=config.get("error_handling", "throw"),
        )
    elif node_type == "mcp":
        return MCPNodeAgent(
            node_id=node_id,
            server_name=config.get("mcp_server", ""),
            tool_name=config.get("tool_name", ""),
            timeout=config.get("timeout", 30),
            error_handling=config.get("error_handling", "throw"),
        )
    elif node_type in ("wecom_notify", "notify"):
        return NotifyAgent(node_id=node_id, config=config)
    elif node_type == "condition":
        return ConditionNodeAgent(
            node_id=node_id,
            condition=config.get("condition", ""),
            condition_type=config.get("condition_type", "expression"),
        )
    elif node_type == "rag":
        model_instance_id = config.get("model_instance_id")
        if model_instance_id:
            resolved = await _resolve_model_instance(model_instance_id)
            if resolved:
                config = dict(config)
                config["_resolved_model"] = resolved
        return RAGNodeAgent(node_id=node_id, config=config)
    elif node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)
    elif node_type == "merge":
        return ParallelMergeNodeAgent(node_id=node_id, config=config)
    elif node_type == "loop":
        return LoopNodeAgent(node_id=node_id, config=config)
    elif node_type == "code":
        return CodeNodeAgent(
            node_id=node_id,
            language=config.get("language", "python"),
            code=config.get("code", ""),
            timeout=config.get("timeout", 30),
            sandbox=config.get("sandbox", True),
        )
    elif node_type == "http_request":
        return HttpRequestNodeAgent(node_id=node_id, config=config)
    elif node_type == "question_classifier":
        return QuestionClassifierNodeAgent(node_id=node_id, config=config)
    elif node_type == "variable_assigner":
        return VariableAssignerNodeAgent(node_id=node_id, config=config, context=context)
    elif node_type == "template_transform":
        return TemplateTransformNodeAgent(node_id=node_id, config=config)
    elif node_type == "iteration":
        return IterationNodeAgent(node_id=node_id, config=config)
    elif node_type == "question_optimiser":
        return QuestionOptimiserNodeAgent(node_id=node_id, config=config)
    else:
        return PassThroughAgent(node_id)


class PassThroughAgent(AgentBase):
    """Return the incoming message unchanged (non-``Msg`` input is wrapped)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        pass


class HttpRequestNodeAgent(AgentBase):
    """Perform a configured HTTP request and reply with ``{"status_code", "body"}`` JSON.

    On failure (after ``retry_count`` retries, 1s apart) the reply is
    ``{"status_code": 0, "error": ..., "body": null}``. Non-2xx responses are returned
    as normal results, not treated as failures.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"HttpRequest_{node_id}"
        self.config = config or {}

    def _build_headers(self) -> dict:
        """Merge configured headers with bearer / api_key / basic auth headers."""
        headers = _coerce_dict(self.config.get("headers", {}))
        auth_type = self.config.get("auth_type", "none")
        auth_config = _coerce_dict(self.config.get("auth_config", {}))

        if auth_type == "bearer" and auth_config.get("token"):
            headers["Authorization"] = f"Bearer {auth_config['token']}"
        elif auth_type == "api_key" and auth_config.get("api_key"):
            headers[auth_config.get("key_name", "X-API-Key")] = auth_config["api_key"]
        elif auth_type == "basic" and auth_config.get("username"):
            import base64

            credentials = f"{auth_config['username']}:{auth_config.get('password', '')}"
            headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}"
        return headers

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        method = self.config.get("method", "GET").upper()
        url = self.config.get("url", "")
        body = self.config.get("body", "")
        timeout = self.config.get("timeout", 30)
        retry_count = self.config.get("retry_count", 0)

        request_kwargs: dict = {"headers": self._build_headers()}
        if method in ("POST", "PUT", "PATCH"):
            if isinstance(body, dict):
                request_kwargs["json"] = body
            elif body and isinstance(body, str):
                # Send JSON when the body parses as JSON, raw text otherwise.
                try:
                    request_kwargs["json"] = json.loads(body)
                except (json.JSONDecodeError, ValueError):
                    request_kwargs["content"] = body

        last_error = None
        for attempt in range(retry_count + 1):
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
                    response = await client.request(method, url, **request_kwargs)
                try:
                    resp_body = json.loads(response.text)
                except (json.JSONDecodeError, ValueError):
                    resp_body = response.text
                result = json.dumps({
                    "status_code": response.status_code,
                    "body": resp_body,
                }, ensure_ascii=False)
                return Msg(self.name, result, "assistant")
            except Exception as e:
                last_error = str(e)
                if attempt < retry_count:
                    await asyncio.sleep(1)

        error_result = json.dumps({
            "status_code": 0,
            "error": last_error or "unknown error",
            "body": None,
        }, ensure_ascii=False)
        return Msg(self.name, error_result, "assistant")

    async def observe(self, msg) -> None:
        pass


class QuestionClassifierNodeAgent(AgentBase):
    """Ask the LLM to pick one of ``config["categories"]`` for the input.

    Replies with a JSON object such as ``{"category": ..., "confidence": ...}``; falls
    back to category ``"default"`` when there are no categories, the model output is
    not a JSON object, or the call fails.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Classifier_{node_id}"
        self.config = config or {}

    @staticmethod
    def _describe_category(category) -> str:
        # Categories are normally {"name", "description"} dicts; plain strings are accepted too.
        if isinstance(category, dict):
            return f'{category.get("name")}: {category.get("description", "")}'
        return str(category)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        categories = self.config.get("categories", [])
        instruction = self.config.get("instruction", "")
        model_name = self.config.get("model", settings.LLM_MODEL)
        temperature = self.config.get("temperature", 0.3)
        user_text = _msg_text(msg)

        if not categories:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant")

        category_desc = "\n".join(self._describe_category(c) for c in categories)
        prompt = f"""请对以下用户输入进行意图分类。

分类选项:
{category_desc}

{instruction}

用户输入:{user_text}

请只返回一个JSON对象,格式为:{{"category": "分类名称", "confidence": 0.0~1.0}}"""

        try:
            result_text = (await _chat_completion(
                [{"role": "user", "content": prompt}],
                model=model_name,
                base_url=settings.LLM_API_BASE,
                api_key=settings.LLM_API_KEY,
                temperature=temperature,
                max_tokens=200,
            )).strip()
        except Exception as e:
            logger.error(f"QuestionClassifier error: {e}")
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}, ensure_ascii=False), "assistant")

        parsed = _extract_json_object(result_text)
        if parsed is None:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}, ensure_ascii=False), "assistant")
        return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")

    async def observe(self, msg) -> None:
        pass


class VariableAssignerNodeAgent(AgentBase):
    """Compute named values and reply with them as a JSON object.

    Each assignment's ``source_type`` is ``constant`` (default), ``upstream_output``,
    ``template`` or ``expression``. The values are only returned in the reply; this
    agent does not write them into ``context``.
    """

    def __init__(self, node_id: str, config: dict = None, context: dict = None):
        super().__init__()
        self.name = f"VarAssign_{node_id}"
        self.config = config or {}
        # Keep the caller's dict even when it is empty, so it stays the shared object.
        self._context = context if context is not None else {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        results = {}
        for assignment in self.config.get("assignments", []):
            target_var = assignment.get("target_var", "")
            source_type = assignment.get("source_type", "constant")
            source_value = assignment.get("source_value", "")

            if source_type == "upstream_output":
                value = _msg_text(msg)
            elif source_type == "template":
                value = _resolve_template(source_value, self._context, msg)
            elif source_type == "expression":
                # NOTE(security): eval() with empty __builtins__ is NOT a sandbox; an
                # expression can reach arbitrary objects through `msg`/`context`
                # attributes. Only use with trusted flow authors.
                try:
                    safe_locals = {"msg": msg, "context": self._context}
                    value = str(eval(source_value, {"__builtins__": {}}, safe_locals))
                except Exception as e:
                    value = f"[expression error: {e}]"
            else:
                # "constant" and unknown source types use the configured value verbatim.
                value = source_value
            results[target_var] = value

        return Msg(self.name, json.dumps(results, ensure_ascii=False), "assistant")

    async def observe(self, msg) -> None:
        pass


class LLMNodeAgent(AgentBase):
    """Call an OpenAI-compatible chat endpoint with the system prompt and history.

    History comes from ``context["_memory_context"]`` (summary + recent messages)
    when present, otherwise from the attached ``FlowSessionMemory``. Streaming is used
    only when ``stream`` is true *and* a ``stream_callback`` is set.
    """

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7,
                 max_tokens: int = 2000, stream: bool = True, stream_callback=None, base_url: str = "",
                 api_key: str = ""):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.stream_callback = stream_callback
        self.base_url = base_url or settings.LLM_API_BASE
        self.api_key = api_key or settings.LLM_API_KEY
        self._memory = None

    def set_memory(self, memory):
        """Attach the flow's ``FlowSessionMemory`` used for history and recording."""
        self._memory = memory

    def _build_messages(self, user_text: str, memory_ctx: dict) -> list[dict]:
        """Assemble system prompt, history (each entry clipped to 2000 chars) and input."""
        messages = [{"role": "system", "content": self.system_prompt}]
        if memory_ctx:
            summary = memory_ctx.get("summary", "")
            if summary:
                messages.append({"role": "system", "content": f"[历史对话摘要]\n{summary}"})
            history = memory_ctx.get("recent_messages", [])[-10:]
        elif self._memory:
            history = self._memory.get_history(limit=5)
        else:
            history = []
        for entry in history:
            content = entry.get("content", "") or ""
            messages.append({"role": entry.get("role", "user"), "content": content[:2000]})
        messages.append({"role": "user", "content": user_text})
        return messages

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        context = kwargs.get("context") or {}
        memory_ctx = context.get("_memory_context") or {}
        messages = self._build_messages(user_text, memory_ctx)

        try:
            if self.stream and self.stream_callback:
                res_text = await self._stream_llm_call(messages)
            else:
                res_text = await self._blocking_llm_call(messages)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"

        # External memory context, when supplied, is managed by the caller.
        if self._memory and not memory_ctx:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)
        return Msg(self.name, res_text, "assistant")

    async def _blocking_llm_call(self, messages: list[dict]) -> str:
        """Non-streaming call; failures are returned as ``[LLM 调用失败: ...]`` text."""
        try:
            return await _chat_completion(
                messages,
                model=self.model_name or settings.LLM_MODEL,
                base_url=self.base_url,
                api_key=self.api_key,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
        except Exception as e:
            logger.warning(f"LLM 阻塞调用失败: {e}")
            return f"[LLM 调用失败: {e}]"

    async def _stream_llm_call(self, messages: list[dict]) -> str:
        """Streaming (SSE) call forwarding each token to ``stream_callback("text_chunk", ...)``.

        Returns the accumulated text; on failure returns what was received so far, or
        an error marker when nothing was received.
        """
        url = f"{self.base_url.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model_name or settings.LLM_MODEL,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": True,
        }
        parts: list[str] = []
        try:
            timeout = httpx.Timeout(60.0, connect=10.0)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("POST", url, json=body, headers=headers) as response:
                    # Surface HTTP errors instead of silently returning an empty answer.
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        json_str = line[5:].strip()
                        if json_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(json_str)
                        except json.JSONDecodeError:
                            continue
                        # Some providers send a final usage chunk with empty "choices".
                        choices = chunk.get("choices") or [{}]
                        token = (choices[0].get("delta") or {}).get("content")
                        if token:
                            parts.append(token)
                            await self.stream_callback("text_chunk", {"content": token})
        except httpx.TimeoutException:
            logger.warning("LLM 流式调用超时")
            if not parts:
                return "[LLM 超时]"
        except Exception as e:
            logger.warning(f"LLM 流式调用失败: {e}")
            if not parts:
                return f"[LLM 调用失败: {e}]"
        return "".join(parts)

    async def observe(self, msg) -> None:
        pass


class ToolNodeAgent(AgentBase):
    """Run a built-in (registry) or custom HTTP tool with timeout, retries and
    ``error_handling`` ("skip" / "default" / anything else = report the error)."""

    _TOOL_REGISTRY: dict[str, callable] = {}
    _TOOL_SCHEMAS: dict[str, dict] = {}
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}

    @classmethod
    def _init_registry(cls):
        """Import the built-in tools once; schema modules are loaded best-effort."""
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import (
                list_tasks, create_task, get_task, update_task, push_task_to_wecom,
            )
            from agentscope_integration.tools.manager_tools import (
                list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard,
            )

            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
            for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]:
                try:
                    mod = __import__(f"agentscope_integration.tools.{mod_name}", fromlist=["SCHEMAS"])
                    if hasattr(mod, "SCHEMAS"):
                        cls._TOOL_SCHEMAS.update(mod.SCHEMAS)
                except Exception as e:
                    logger.debug(f"tool schema module {mod_name} not loaded: {e}")
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Register every active ``CustomTool`` row (definition + schema) from *db*."""
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Imported only to verify the executor module is available.
            from modules.custom_tool.executor import CustomToolExecutor  # noqa: F401

            result = await db.execute(
                select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712 (SQL expression)
            )
            for tool in result.scalars().all():
                cls._CUSTOM_TOOL_DEFS[tool.name] = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas (built-in + custom)."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30,
                 retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    def _failure_msg(self, skipped_text: str, error_text: str) -> Msg:
        """Reply for a failed call: "skip" -> *skipped_text*, "default" -> ``"{}"``,
        anything else -> *error_text*."""
        if self.error_handling == "skip":
            return Msg(self.name, skipped_text, "assistant")
        if self.error_handling == "default":
            return Msg(self.name, "{}", "assistant")
        return Msg(self.name, error_text, "assistant")

    def _call_tool_sync(self, tool_func, user_text: str):
        """Call a built-in tool as ``tool_func(**tool_params)``, or as
        ``tool_func(user_text, **tool_params)`` when only that form fits its signature.

        The call form is chosen by binding against the signature *before* calling, so
        a ``TypeError`` raised inside the tool no longer makes it run a second time.
        Any failure is re-raised as ``RuntimeError("工具执行失败: ...")``.
        """
        params = self.tool_params
        try:
            sig = inspect.signature(tool_func)
        except (TypeError, ValueError):
            sig = None
        try:
            if sig is None:
                # No introspectable signature: keep the legacy try-then-fallback call.
                try:
                    return tool_func(**params)
                except TypeError:
                    return tool_func(user_text, **params)
            try:
                sig.bind(**params)
            except TypeError:
                try:
                    sig.bind(user_text, **params)
                except TypeError:
                    pass  # neither form fits; the call below raises a descriptive TypeError
                else:
                    return tool_func(user_text, **params)
            return tool_func(**params)
        except Exception as e:
            raise RuntimeError(f"工具执行失败: {e}") from e

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        self._init_registry()
        user_text = _msg_text(msg)

        # Custom tools take precedence over built-ins with the same name.
        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)

        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        for attempt in range(self.retry_count + 1):
            try:
                # NOTE: on timeout the worker thread keeps running; a retry may overlap it.
                result = await asyncio.wait_for(
                    asyncio.to_thread(self._call_tool_sync, tool_func, user_text),
                    timeout=self.timeout,
                )
                if inspect.isawaitable(result):
                    # Async tools return a coroutine from the thread; run it on this loop.
                    result = await asyncio.wait_for(result, timeout=self.timeout)
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                return self._failure_msg(
                    f"[工具 {self.tool_name} 超时,已跳过]",
                    f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)",
                )
            except RuntimeError as e:
                if attempt < self.retry_count:
                    continue
                return self._failure_msg(f"[工具 {self.tool_name} 失败,已跳过: {e}]", f"[{e}]")
            except Exception as e:
                if attempt < self.retry_count:
                    continue
                return self._failure_msg(f"[工具 {self.tool_name} 失败,已跳过: {e}]", f"[工具执行失败: {e}]")
        return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Run a registered custom tool through ``CustomToolExecutor`` with retries."""
        from modules.custom_tool.executor import CustomToolExecutor

        executor = CustomToolExecutor(custom_def)
        for attempt in range(self.retry_count + 1):
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                return self._failure_msg(
                    f"[自定义工具 {self.tool_name} 超时,已跳过]",
                    f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)",
                )
            except Exception as e:
                if attempt < self.retry_count:
                    continue
                return self._failure_msg(
                    f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]",
                    f"[自定义工具执行失败: {e}]",
                )
        return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant")

    async def observe(self, msg) -> None:
        pass


class ConditionNodeAgent(AgentBase):
    """Evaluate a condition and reply ``condition:<true|false>|<reason>``.

    ``condition_type`` is ``regex``, ``json_path`` (dotted path into a JSON-object
    input, truthiness of the value) or anything else, which asks the LLM. Empty
    conditions and evaluation failures default to ``true``.
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")
        if self.condition_type == "regex":
            try:
                matched = bool(re.search(self.condition, user_text))
            except re.error as e:
                return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")
            return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命' if matched else '未命'}中", "assistant")
        if self.condition_type == "json_path":
            return self._eval_json_path(user_text)
        return await self._eval_with_llm(user_text)

    def _eval_json_path(self, user_text: str) -> Msg:
        """Follow the dotted path (``$.`` prefix optional) through nested dicts."""
        try:
            data = json.loads(user_text) if user_text.strip().startswith("{") else {}
            val = data
            for part in self.condition.strip("$.").split("."):
                val = val.get(part, None) if isinstance(val, dict) else None
            return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")
        except Exception:
            return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")

    async def _eval_with_llm(self, user_text: str) -> Msg:
        """Ask the configured LLM for ``{"result": bool, "reason": str}``."""
        condition_prompt = f"""你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。

条件表达式: {self.condition}

输入内容:
{user_text[:2000]}

请严格只输出一行 JSON: {{"result": true/false, "reason": "简要原因"}}"""
        try:
            res_text = await _chat_completion(
                [
                    {"role": "system", "content": condition_prompt},
                    {"role": "user", "content": user_text[:2000]},
                ],
                model=settings.LLM_MODEL,
                base_url=settings.LLM_API_BASE,
                api_key=settings.LLM_API_KEY,
                temperature=0.0,
                max_tokens=300,
            )
            parsed = _extract_json_object(res_text)
            if parsed is not None:
                # Accept both JSON booleans and "true"/"false" strings.
                matched = str(parsed.get("result", False)).strip().lower() == "true"
                reason = parsed.get("reason", "")
                return Msg(self.name, f"condition:{'true' if matched else 'false'}|{reason}", "assistant")
            logger.warning(f"条件判断LLM返回无法解析: {res_text[:200]}")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    async def observe(self, msg) -> None:
        pass


class MCPNodeAgent(AgentBase):
    """Call ``tool_name`` on an MCP server with ``{"input": <upstream text>}``."""

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30,
                 error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")

        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager

            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                failure = "未找到该服务"
            elif not self.tool_name:
                failure = "未指定工具名称"
            else:
                result = await asyncio.wait_for(
                    client.call_tool(self.tool_name, {"input": user_text}),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            failure = "MCP 客户端不可用"
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            failure = f"调用失败: {e}"

        # Report the failure instead of claiming the call completed.
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, f"[MCP] 服务 {self.server_name} {failure}", "assistant")

    async def observe(self, msg) -> None:
        pass


class NotifyAgent(AgentBase):
    """Send the upstream text (or a template) via WeCom and/or a web push.

    Replies with one status fragment per enabled channel, joined by ``" | "``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        channels = self.config.get("channels") or {"wecom": True, "web": False}
        target = self.config.get("target", "")
        results = []

        if channels.get("wecom", True):
            message = self.config.get("message_template", "") or user_text[:500]
            try:
                from agentscope_integration.tools.wecom_tools import send_notification
            except ImportError:
                logger.warning("企微通知模块不可用")
                results.append("企微通知: 模块不可用")
            else:
                try:
                    # Run the synchronous sender off the event loop.
                    result = await asyncio.to_thread(send_notification, to_user=target or "user", message=message)
                    if inspect.isawaitable(result):
                        result = await result
                    results.append(f"企微通知: {str(result)[:100]}")
                except Exception as e:
                    logger.warning(f"企微通知发送失败: {e}")
                    results.append(f"企微通知: 发送失败({e})")

        if channels.get("web", False):
            web_message = self.config.get("web_template", "") or user_text[:500]
            if not target:
                results.append("Web通知: 未指定接收人")
            else:
                try:
                    from websocket_manager import ws_manager

                    await ws_manager.send_to_user(target, {
                        "type": "flow_notification",
                        "message": web_message,
                        "level": "info",
                    })
                    results.append("Web通知: 已推送")
                except ImportError:
                    logger.warning("websocket_manager 不可用")
                    results.append("Web通知: 推送模块不可用")
                except Exception as e:
                    logger.warning(f"Web通知推送失败: {e}")
                    results.append(f"Web通知: 推送失败({e})")

        if not results:
            results.append("未启用任何通知渠道")
        return Msg(self.name, " | ".join(results), "assistant")

    async def observe(self, msg) -> None:
        pass


class LoopNodeAgent(AgentBase):
    """Emit the current loop iteration (or ``loop:stop``) for ``FlowEngine``.

    The 1-based iteration is ``config["_engine_iteration"] + 1``, where the engine
    passes the number of passes completed before this one.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        loop_type = self.config.get("loop_type", "fixed")
        count = _as_int(self.config.get("count"), 3)
        iterator_variable = self.config.get("iterator_variable", "item")
        iteration = _as_int(self.config.get("_engine_iteration"), 0) + 1

        if loop_type == "array":
            items = self.config.get("items", [])
            if isinstance(items, str):
                try:
                    items = json.loads(items)
                except (json.JSONDecodeError, ValueError):
                    # Not JSON: treat each non-blank line as one item.
                    items = [line.strip() for line in items.split("\n") if line.strip()]
            if isinstance(items, list) and items:
                idx = iteration - 1
                if idx >= len(items):
                    return Msg(self.name, "loop:stop|数组迭代完成", "assistant")
                current_item = items[idx]
                return Msg(self.name, f"[迭代 {iteration}/{len(items)}] 变量: {iterator_variable}={current_item}\n输入: {user_text[:200]}", "assistant")
            return Msg(self.name, "loop:stop|无数组数据", "assistant")

        if loop_type == "fixed":
            if iteration >= count:
                return Msg(self.name, f"loop:stop|固定次数循环完成: {iteration}/{count}", "assistant")
            return Msg(self.name, f"[循环 {iteration}/{count}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

        return Msg(self.name, f"[循环 {iteration}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

    async def observe(self, msg) -> None:
        pass


class CodeNodeAgent(AgentBase):
    """Run user Python code in a subprocess with the upstream text bound to ``INPUT_TEXT``.

    Replies with the script's stripped stdout, ``[错误] ...`` on a non-zero exit, or a
    timeout marker. Empty code passes the input through unchanged.
    """

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30,
                 sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        self.language = language
        self.code = code
        self.timeout = timeout
        # NOTE(security): `sandbox` is stored but not enforced; the code runs as a plain
        # subprocess with this server's privileges and environment.
        self.sandbox = sandbox

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        if not self.code.strip():
            return Msg(self.name, user_text, "assistant")

        # repr() of a str is always a valid Python string literal. `json` stays
        # imported because user code may rely on it.
        script = f"import json\nINPUT_TEXT = {user_text!r}\n\n{self.code}"
        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
                f.write(script)
                temp_path = f.name

            proc = await asyncio.create_subprocess_exec(
                sys.executable, temp_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()  # reap the killed process
                return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

        err_text = stderr.decode("utf-8", errors="replace")
        if proc.returncode != 0:
            detail = err_text if err_text.strip() else f"exit code {proc.returncode}"
            return Msg(self.name, f"[错误] {detail[:500]}", "assistant")
        if err_text.strip():
            # Warnings on stderr do not make a successful run fail.
            logger.warning(f"代码节点 stderr: {err_text[:200]}")
        output = stdout.decode("utf-8", errors="replace").strip()
        return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")

    async def observe(self, msg) -> None:
        pass


class RAGNodeAgent(AgentBase):
    """Retrieve knowledge-base context and answer with the LLM.

    Uses ``config["_resolved_model"]`` (set by ``_create_node_agent``) when present,
    else the global settings. On failure it replies with a short retrieval summary.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = _msg_text(msg)
        top_k = self.config.get("top_k", 5)
        # NOTE(review): similarity_threshold / search_mode / include_metadata are present
        # in node config but are not yet passed to retrieval.
        try:
            from modules.rag.knowledge import retrieve_for_agent

            kb_result = await retrieve_for_agent(user_text, limit=top_k)
            rag_prompt = f"""你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。

知识库检索结果:
{kb_result}

用户问题: {user_text}

请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"""
            resolved = self.config.get("_resolved_model") or {}
            res_text = await _chat_completion(
                [
                    {"role": "system", "content": rag_prompt},
                    {"role": "user", "content": user_text},
                ],
                model=resolved.get("model") or settings.LLM_MODEL,
                base_url=resolved.get("base_url") or settings.LLM_API_BASE,
                api_key=resolved.get("api_key") or settings.LLM_API_KEY,
            )
            return Msg(self.name, res_text, "assistant")
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

    def _get_model(self):
        # NOTE(review): no longer used by reply(); kept for any external callers. The
        # constructor arguments (config_name/api_base) look like a pre-1.x agentscope
        # API — verify against the installed version before relying on this.
        from agentscope.model import OpenAIChatModel

        resolved = self.config.get("_resolved_model", {})
        return OpenAIChatModel(
            config_name=f"rag_{self.name}",
            model_name=resolved.get("model", settings.LLM_MODEL),
            api_key=resolved.get("api_key", settings.LLM_API_KEY),
            api_base=resolved.get("base_url", settings.LLM_API_BASE),
        )

    def _get_formatter(self):
        # NOTE(review): no longer used by reply(); kept for any external callers.
        from agentscope.formatter import OpenAIChatFormatter

        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass


class OutputNodeAgent(AgentBase):
    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
self.name = f"Output_{node_id}" self.config = config or {} async def reply(self, msg: Msg, **kwargs) -> Msg: output_format = self.config.get("format", "text") output_template = self.config.get("output_template", "") truncate = self.config.get("truncate", False) max_length = self.config.get("max_length", 2000) content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) if truncate and len(content) > max_length: content = content[:max_length] + "\n\n[内容已截断]" if output_format == "json": try: parsed = json.loads(content) indent = self.config.get("indent", 2) formatted = json.dumps(parsed, indent=indent, ensure_ascii=False) return Msg(self.name, formatted, "assistant") except (json.JSONDecodeError, ValueError): pass if output_template: try: resolved = output_template.replace("{{output}}", content) return Msg(self.name, resolved, "assistant") except: pass return msg if isinstance(msg, Msg) else Msg(self.name, str(content), "assistant") async def observe(self, msg) -> None: pass def _resolve_template(template: str, context: dict, current_msg: Msg) -> str: result = template import re placeholders = re.findall(r'\{\{(.+?)\}\}', template) for placeholder in placeholders: parts = placeholder.strip().split(".") value = "" if parts[0] == "trigger": value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), "")) elif parts[0] in context.get("_node_results", {}): node_result = context.get("_node_results", {}).get(parts[0], {}) if len(parts) > 1: value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], "")) else: value = str(node_result.get("output", "")) result = result.replace("{{" + placeholder + "}}", value) return result class TemplateTransformNodeAgent(AgentBase): def __init__(self, node_id: str, config: dict = None): super().__init__() self.name = f"Template_{node_id}" self.config = config or {} async def reply(self, msg: Msg, **kwargs) -> Msg: template = self.config.get("template", "") output_type = 
self.config.get("output_type", "string") context = kwargs.get("context", {}) user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) try: rendered = _resolve_template(template, context, msg) if not rendered: rendered = template rendered = rendered.replace("{{input}}", user_text) except Exception: rendered = template.replace("{{input}}", user_text) if "{{input}}" in template else user_text if output_type == "json": try: parsed = json.loads(rendered) return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant") except (json.JSONDecodeError, ValueError): pass return Msg(self.name, rendered, "assistant") async def observe(self, msg) -> None: pass class IterationNodeAgent(AgentBase): def __init__(self, node_id: str, config: dict = None): super().__init__() self.name = f"Iteration_{node_id}" self.config = config or {} self._results: list[str] = [] async def reply(self, msg: Msg, **kwargs) -> Msg: input_array_source = self.config.get("input_array_source", "") max_iterations = self.config.get("max_iterations", 20) context = kwargs.get("context", {}) items = [] if input_array_source: resolved = _resolve_template(input_array_source, context, msg) try: items = json.loads(resolved) if not isinstance(items, list): items = [resolved] except (json.JSONDecodeError, ValueError): items = [resolved] if resolved else [] else: user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) try: items = json.loads(user_text) if not isinstance(items, list): items = [user_text] except (json.JSONDecodeError, ValueError): items = [item.strip() for item in user_text.split("\n") if item.strip()] if not items: items = [user_text] items = items[:max_iterations] results = [] for i, item in enumerate(items): results.append({ "index": i, "item": item, }) output = json.dumps(results, ensure_ascii=False) return Msg(self.name, output, "assistant") def get_iteration_items(self) -> list: return self._results async def observe(self, msg) -> 
None: pass class QuestionOptimiserNodeAgent(AgentBase): def __init__(self, node_id: str, config: dict = None): super().__init__() self.name = f"QOpt_{node_id}" self.config = config or {} async def reply(self, msg: Msg, **kwargs) -> Msg: optimization_type = self.config.get("optimization_type", "rewrite") model_name = self.config.get("model", settings.LLM_MODEL) context = kwargs.get("context", {}) user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) history = context.get("_memory", {}).get("recent_messages", []) persona = context.get("_memory", {}).get("persona", {}) atoms = context.get("_memory", {}).get("atoms", []) persona_text = persona.get("raw_text", "") atoms_text = "\n".join([a.get("content", "") for a in atoms[:5]]) if atoms else "" history_text = "\n".join( f"{m['role']}: {m['content'][:200]}" for m in history[-6:] ) if history else "" context_parts = [] if persona_text: context_parts.append(f"用户画像: {persona_text}") if atoms_text: context_parts.append(f"已知信息: {atoms_text}") if history_text: context_parts.append(f"近期对话: {history_text}") context_block = "\n".join(context_parts) if context_parts else "" if optimization_type == "rewrite": prompt = f"""{context_block} 原始问题: {user_text} 请将以上问题进行优化改写,使其更清晰、具体、完整。补充可能缺失的上下文信息。 只返回优化后的问题,不要其他内容。""" elif optimization_type == "expand": prompt = f"""{context_block} 简短问题: {user_text} 请将以上问题扩展为更详细的版本,添加必要的背景和细节。 只返回扩展后的问题,不要其他内容。""" else: return Msg(self.name, user_text, "assistant") try: import httpx api_base = settings.LLM_API_BASE.rstrip("/") async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{api_base}/chat/completions", json={ "model": model_name, "messages": [{"role": "user", "content": prompt}], "max_tokens": 300, "temperature": 0.3, }, headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, ) data = resp.json() result = data.get("choices", [{}])[0].get("message", {}).get("content", user_text) return Msg(self.name, result.strip(), "assistant") 
except Exception as e: logger.warning(f"问题优化失败: {e}") return Msg(self.name, user_text, "assistant") async def observe(self, msg) -> None: pass class ParallelMergeNodeAgent(AgentBase): def __init__(self, node_id: str, config: dict = None): super().__init__() self.name = f"Merge_{node_id}" self.config = config or {} self._received: dict[str, str] = {} self.merge_type = self.config.get("merge_type", "concat") async def reply(self, msg: Msg, **kwargs) -> Msg: source_id = kwargs.get("source_node_id", "") content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg) self._received[source_id] = content expected_count = self.config.get("expected_branches", 0) if expected_count <= 0 or len(self._received) >= expected_count: return self._merge() return Msg(self.name, "", "assistant") def _merge(self) -> Msg: if self.merge_type == "json": merged = json.dumps(self._received, ensure_ascii=False) elif self.merge_type == "first_non_empty": merged = "" for v in self._received.values(): if v.strip(): merged = v break else: parts = [] for k, v in self._received.items(): parts.append(v) merged = "\n\n---\n\n".join(parts) return Msg(self.name, merged, "assistant") async def observe(self, msg) -> None: pass