You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1012 lines
41 KiB

import json
import uuid
import logging
import re
import asyncio
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.agent._react_agent import ReActAgent
from config import settings
logger = logging.getLogger(__name__)
class FlowSessionMemory:
    """Conversation history shared by the nodes of a single flow execution.

    Messages are stored in memory only as ``{"role": ..., "content": ...}`` dicts,
    in insertion order.
    """

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return a copy of the most recent messages.

        Args:
            limit: Number of exchanges to return; up to ``2 * limit`` messages are
                returned (one exchange is assumed to be a user/assistant pair).
                ``limit <= 0`` returns an empty list — previously ``0`` returned
                the entire history because ``lst[-0:]`` is the whole list.
        """
        if limit <= 0:
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one message with the given role and content."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of all stored messages."""
        return list(self._messages)
class FlowEngine:
    """Interpret a flow definition (nodes + edges) as a graph of agents and run it.

    ``flow_definition`` is read for:

    * ``nodes``: dicts with at least ``id``; ``type``, ``label`` and ``config`` are
      read when present.
    * ``edges``: ``source``/``target`` (or ``from``/``to``), plus an optional label
      taken from ``condition`` or ``sourceHandle``. The label values used here are
      ``"true"``/``"false"`` for condition nodes and ``"loop_body"``/``"loop_done"``
      for loop nodes.
    """

    # Upper bound on node visits in one execute() call; guards against runaway loops.
    MAX_TOTAL_ITERATIONS = 200
    # Wall-clock limit for one execute() call, in seconds.
    FLOW_TIMEOUT_SECONDS = 300

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {}
        for node in flow_definition.get("nodes", []):
            self.nodes[node["id"]] = node
        self.edges: list[dict] = flow_definition.get("edges", [])
        # Agents reused across visits of the same node; loop nodes bypass this cache.
        self._agent_cache: dict[str, AgentBase] = {}

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text of a message-like object, never None."""
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run the flow depth-first from its start node.

        The start node is the first node without incoming edges, or the first
        declared node when every node has one. Per-node outcomes are written to
        ``context["_node_results"]``, and a fresh ``FlowSessionMemory`` is exposed
        as ``context["_memory"]``.

        Returns:
            The last message produced by any node, a system ``Msg`` describing a
            timeout / iteration-limit abort, or ``input_msg`` when nothing ran.
        """
        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph)
        if not start_nodes:
            start_nodes = list(self.nodes.keys())[:1]
        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        memory = FlowSessionMemory(session_id=session_id, user_id=user_id)
        context["_memory"] = memory
        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}
        total_iterations = 0
        last_result: Msg | None = None

        def record(node_id: str, node: dict, node_type: str, **fields) -> None:
            """Store the execution record for ``node_id`` in the shared context."""
            context.setdefault("_node_results", {})[node_id] = {
                "node_id": node_id,
                "node_type": node_type,
                "label": node.get("label"),
                **fields,
            }

        async def traverse(node_id: str, incoming_msg: Msg, loop_context: dict | None = None) -> None:
            # loop_context carries the current loop iteration to body nodes; it is
            # not consumed anywhere yet.
            nonlocal last_result, total_iterations
            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return
            node = self.nodes.get(node_id)
            if not node:
                return
            node_type = node.get("type", "")
            is_loop = node_type == "loop"
            outgoing = graph.get(node_id, [])

            if not is_loop:
                # Ordinary nodes run at most once, until an enclosing loop resets them.
                if node_id in visited:
                    return
                visited.add(node_id)
            else:
                loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                max_iter = node.get("config", {}).get("max_iterations", 10)
                if loop_iterations[node_id] > max_iter:
                    logger.warning(f"循环节点 {node.get('label', node_id)} 超过最大迭代次数 {max_iter}")
                    visited.add(node_id)
                    for target_id, edge_cond in outgoing:
                        if edge_cond == "loop_done":
                            await traverse(target_id, incoming_msg)
                    return

            # Everything that can fail for THIS node stays inside the try, so a
            # failure is recorded against this node rather than escaping into the
            # parent node's handler (or out of execute() for the start node).
            try:
                if is_loop:
                    loop_node_config = dict(node.get("config", {}))
                    # LoopNodeAgent reads _engine_iteration as the number of
                    # *completed* iterations and adds 1 itself, so pass the
                    # 0-based count (previously it was passed already incremented).
                    loop_node_config["_engine_iteration"] = loop_iterations[node_id] - 1
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                current_msg = incoming_msg
                if enriched_content.strip():
                    user_text = self._msg_text(current_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")
                result = await agent.reply(current_msg)
                output_text = self._msg_text(result)
                loop_done = is_loop and self._check_loop_done(result, node, loop_iterations[node_id])
                cond_result = self._parse_condition_result(result) if node_type == "condition" else None
            except Exception as e:
                label = node.get('label', node_id)
                logger.error(f"节点 {label} 执行失败: {e}")
                record(node_id, node, node_type, status="error", error=str(e))
                last_result = Msg(name="system", content=f"[节点 {label} 执行失败: {e}]", role="system")
                return

            record(node_id, node, node_type, status="success", output=output_text[:500])
            last_result = result

            if is_loop:
                if loop_done:
                    visited.add(node_id)
                    for target_id, edge_cond in outgoing:
                        if edge_cond == "loop_done":
                            await traverse(target_id, result)
                else:
                    # Make the WHOLE body runnable again, not just its first node;
                    # otherwise multi-node bodies stop after the first iteration.
                    self._reset_loop_body(graph, node_id, visited, loop_iterations)
                    for target_id, edge_cond in outgoing:
                        if edge_cond == "loop_body" and target_id in self.nodes:
                            await traverse(target_id, result, {"iteration": loop_iterations[node_id]})
                return

            for target_id, edge_cond in outgoing:
                if node_type == "condition":
                    # Condition nodes follow only the edge labelled with their verdict.
                    if edge_cond and edge_cond == cond_result:
                        await traverse(target_id, result)
                elif edge_cond in ("loop_body", "loop_done"):
                    continue
                else:
                    await traverse(target_id, result)

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")
        return last_result or input_msg

    def _reset_loop_body(self, graph: dict, loop_id: str, visited: set[str], loop_iterations: dict[str, int]) -> None:
        """Make every node in ``loop_id``'s body runnable again.

        The body is every node reachable from the loop's ``loop_body`` edges
        without passing back through ``loop_id``. The body nodes are removed from
        ``visited``. Nested loop nodes inside the body also get their iteration
        counters cleared, so they restart on each outer iteration.
        """
        stack = [target for target, cond in graph.get(loop_id, []) if cond == "loop_body"]
        body: set[str] = set()
        while stack:
            nid = stack.pop()
            if nid == loop_id or nid in body:
                continue
            body.add(nid)
            stack.extend(target for target, _ in graph.get(nid, []))
        visited.difference_update(body)
        for nid in body:
            if self.nodes.get(nid, {}).get("type") == "loop":
                loop_iterations.pop(nid, None)

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build an adjacency list ``{node_id: [(target_id, edge_label), ...]}``.

        Edges whose endpoints are not known nodes are dropped. The generic
        ``"source"`` handle is treated as an unlabelled edge.
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return the ids of nodes without incoming edges, in declaration order."""
        target_nodes = {target_id for targets in graph.values() for target_id, _ in targets}
        return [nid for nid in self.nodes if nid not in target_nodes]

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Extract ``"true"``/``"false"`` from a ``condition:<verdict>`` marker, else None."""
        m = re.search(r'condition:(true|false)', self._msg_text(result))
        if m:
            return m.group(1)
        return None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node has finished.

        ``fixed`` loops (the default) finish once ``iteration >= config["count"]``
        (default 3). Other loop types finish when the node's output contains
        ``loop:stop``, ``loop:done`` or ``loop:break`` (case-insensitive).
        """
        config = node.get("config", {})
        loop_type = config.get("loop_type", "fixed")
        if loop_type == "fixed":
            count = config.get("count", 3)
            return iteration >= count
        return bool(re.search(r'loop:(stop|done|break)', self._msg_text(result), re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict = None) -> AgentBase:
        """Return the agent for ``node_id``.

        Agents are cached per node, except when ``override_config`` is given.
        Loop nodes pass a per-iteration config, so they always get a fresh agent.
        """
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]
        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render ``config["input_mapping"]`` as ``key: value`` lines.

        Returns "" when no mapping is configured. String values containing
        ``{{...}}`` are expanded with ``_resolve_template``.
        """
        config = node.get("config", {})
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""
        resolved = {}
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            resolved[key] = str(value)
        return "\n".join([f"{k}: {v}" for k, v in resolved.items()])
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
node_type = node.get("type", "")
node_id = node.get("id", "")
config = node.get("config", {})
if node_type == "trigger":
return PassThroughAgent(node_id)
elif node_type == "llm":
model_config = config.get("model", settings.LLM_MODEL)
temperature = config.get("temperature", 0.7)
system_prompt = config.get("system_prompt", "你是AI助手。")
max_tokens = config.get("max_tokens", 2000)
stream = config.get("stream", True)
agent = LLMNodeAgent(
node_id=node_id,
system_prompt=system_prompt,
model_name=model_config,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
)
memory = context.get("_memory")
if memory:
agent.set_memory(memory)
return agent
elif node_type == "tool":
tool_name = config.get("tool_name", "")
tool_params = config.get("tool_params", {})
timeout = config.get("timeout", 30)
retry_count = config.get("retry_count", 0)
error_handling = config.get("error_handling", "throw")
return ToolNodeAgent(
node_id=node_id,
tool_name=tool_name,
tool_params=tool_params,
timeout=timeout,
retry_count=retry_count,
error_handling=error_handling,
)
elif node_type == "mcp":
mcp_server = config.get("mcp_server", "")
tool_name = config.get("tool_name", "")
timeout = config.get("timeout", 30)
error_handling = config.get("error_handling", "throw")
return MCPNodeAgent(
node_id=node_id,
server_name=mcp_server,
tool_name=tool_name,
timeout=timeout,
error_handling=error_handling,
)
elif node_type in ("wecom_notify", "notify"):
return NotifyAgent(node_id=node_id, config=config)
elif node_type == "condition":
condition_expr = config.get("condition", "")
condition_type = config.get("condition_type", "expression")
return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type)
elif node_type == "rag":
return RAGNodeAgent(node_id=node_id, config=config)
elif node_type == "output":
return OutputNodeAgent(node_id=node_id, config=config)
elif node_type == "merge":
return ParallelMergeNodeAgent(node_id=node_id, config=config)
elif node_type == "loop":
return LoopNodeAgent(node_id=node_id, config=config)
elif node_type == "code":
language = config.get("language", "python")
code = config.get("code", "")
timeout = config.get("timeout", 30)
sandbox = config.get("sandbox", True)
return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox)
else:
return PassThroughAgent(node_id)
class PassThroughAgent(AgentBase):
    """No-op agent that forwards its input unchanged (trigger / unknown node types)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return ``msg`` itself if it is a ``Msg``; otherwise wrap ``str(msg)``."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        return None
class LLMNodeAgent(AgentBase):
    """Flow node that answers its input with the shared chat model.

    The prompt is the node's system prompt, then recent session history (when a
    memory is attached via ``set_memory``), then the incoming text.
    """

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        # NOTE(review): model_name / temperature / max_tokens / stream are stored
        # but not passed to the model from AgentFactory._get_model(); confirm
        # whether per-node overrides are meant to apply.
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self._memory = None

    def set_memory(self, memory):
        """Attach a session memory (``get_history`` / ``add``) used for context."""
        self._memory = memory

    @staticmethod
    async def _response_text(res) -> str:
        """Best-effort extraction of reply text from a model call result.

        Handles these result shapes:
        * an async stream — the last chunk is used, assuming chunks are
          cumulative as in agentscope 1.x streaming (TODO confirm);
        * a list — its first element is used;
        * a Msg-like object with ``get_text_content``;
        * an object whose ``content`` is a list of ``{"type": "text", "text": ...}``
          blocks (agentscope 1.x ``ChatResponse``).

        Anything else falls back to ``str(res)``.
        """
        if hasattr(res, "__aiter__"):
            last = None
            async for chunk in res:
                last = chunk
            res = last
        if isinstance(res, list):
            res = res[0] if res else None
        if res is None:
            return ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        blocks = getattr(res, "content", None)
        if isinstance(blocks, list):
            texts = [b.get("text", "") for b in blocks if isinstance(b, dict) and b.get("type") == "text"]
            if texts:
                return "".join(texts)
        return str(res)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Answer ``msg`` with the chat model and record the turn in session memory.

        If the model call fails, a placeholder reply is returned. The placeholder
        is not stored in memory, so it cannot pollute the context of later LLM
        nodes.
        """
        import inspect
        from agentscope_integration.factory import AgentFactory
        from agentscope.formatter import OpenAIChatFormatter

        model = AgentFactory._get_model()
        user_text = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        formatter = OpenAIChatFormatter()
        messages = [Msg("system", self.system_prompt, "system")]
        if self._memory:
            for h in self._memory.get_history(limit=5):
                role = h.get("role", "user")
                # Cap each history entry to keep the prompt bounded.
                content = (h.get("content", "") or "")[:2000]
                messages.append(Msg(role, content, role))
        messages.append(Msg("user", user_text, "user"))
        try:
            # agentscope 1.x formatters are async; older ones may return a list directly.
            prompt = formatter.format(messages)
            if inspect.isawaitable(prompt):
                prompt = await prompt
            res_text = await self._response_text(await model(prompt))
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            return Msg(self.name, f"[LLM 调用失败] 已接收输入: {user_text[:200]}", "assistant")
        if self._memory:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)
        return Msg(self.name, res_text, "assistant")

    async def observe(self, msg) -> None:
        pass
class ToolNodeAgent(AgentBase):
    """Flow node that invokes a registered Python tool or a custom HTTP tool.

    Built-in tools are imported lazily into ``_TOOL_REGISTRY``. Custom tools come
    from ``load_custom_tools`` / ``register_custom_tool`` and take precedence over
    built-ins with the same name. Errors are returned as messages, never raised.
    """

    # Process-wide registries shared by every ToolNodeAgent instance.
    _TOOL_REGISTRY: dict[str, callable] = {}
    _TOOL_SCHEMAS: dict[str, dict] = {}
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}

    @classmethod
    def _init_registry(cls):
        """Import the built-in tools (once) and collect each module's ``SCHEMAS``.

        Best effort: an ImportError is logged and leaves the registry empty, so the
        import is retried on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        import importlib
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")
            return
        cls._TOOL_REGISTRY = {
            "parse_document": parse_document,
            "format_correction": format_correction,
            "send_notification": send_notification,
            "query_wecom_user": query_wecom_user,
            "list_tasks": list_tasks,
            "create_task": create_task,
            "get_task": get_task,
            "update_task": update_task,
            "push_task_to_wecom": push_task_to_wecom,
            "list_subordinates": list_subordinates,
            "generate_efficiency_report": generate_efficiency_report,
            "get_task_statistics": get_task_statistics,
            "get_employee_dashboard": get_employee_dashboard,
        }
        for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]:
            try:
                mod = importlib.import_module(f"agentscope_integration.tools.{mod_name}")
                schemas = getattr(mod, "SCHEMAS", None)
                if schemas:
                    cls._TOOL_SCHEMAS.update(schemas)
            except Exception as e:
                # Schemas are optional metadata; log instead of silently ignoring.
                logger.debug(f"工具 schema 加载失败 ({mod_name}): {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Register every active ``CustomTool`` row as a custom tool.

        ``db`` is presumably an async SQLAlchemy session, since it is used as
        ``await db.execute(select(...))``. Failures are logged, not raised.
        """
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Imported only to confirm the executor module exists before
            # registering tools that will need it at run time.
            from modules.custom_tool.executor import CustomToolExecutor  # noqa: F401
            result = await db.execute(
                select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712 (SQLAlchemy column expression)
            )
            for tool in result.scalars().all():
                tool_def = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._CUSTOM_TOOL_DEFS[tool.name] = tool_def
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas (built-in and custom)."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30, retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    @staticmethod
    def _invoke_tool(tool_func, params: dict, user_text: str):
        """Call ``tool_func(**params)``, or ``tool_func(user_text, **params)`` if that fits instead.

        The choice is made by checking the signature BEFORE calling. Previously, a
        TypeError raised *inside* the tool made it run a second time with
        different arguments, duplicating any side effects.
        """
        import inspect
        try:
            sig = inspect.signature(tool_func)
        except (TypeError, ValueError):
            sig = None
        if sig is None:
            # No introspectable signature: keep the legacy try-then-fallback order.
            try:
                return tool_func(**params)
            except TypeError:
                return tool_func(user_text, **params)
        try:
            sig.bind(**params)
        except TypeError:
            # Raises TypeError (reported as a tool failure) if this form fails too.
            sig.bind(user_text, **params)
            return tool_func(user_text, **params)
        return tool_func(**params)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Run the configured tool and return its result as text.

        Failures and timeouts are retried ``retry_count`` times. After that they
        are reported according to ``error_handling``:
        * ``"skip"`` — a skip notice;
        * ``"default"`` — ``"{}"`` (failures only);
        * anything else — an error message.
        """
        import inspect

        self._init_registry()
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        def _call_sync():
            try:
                return self._invoke_tool(tool_func, self.tool_params, user_text)
            except Exception as e:
                raise RuntimeError(f"工具执行失败: {e}") from e

        # A negative retry_count previously skipped execution entirely.
        attempts = max(self.retry_count, 0) + 1
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                # NOTE: wait_for cannot stop the worker thread; a timed-out tool keeps
                # running in the background, so a retry may overlap with it.
                result = await asyncio.wait_for(asyncio.to_thread(_call_sync), timeout=self.timeout)
                if inspect.isawaitable(result):
                    # Coroutine tools only create their coroutine in the thread; run it here.
                    result = await asyncio.wait_for(result, timeout=self.timeout)
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if not is_last:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 超时,已跳过]", "assistant")
                return Msg(self.name, f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except Exception as e:
                if not is_last:
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                if isinstance(e, RuntimeError):
                    return Msg(self.name, f"[{e}]", "assistant")
                return Msg(self.name, f"[工具执行失败: {e}]", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Run a custom HTTP tool with ``tool_params``, using the same retry/error policy as ``reply``."""
        from modules.custom_tool.executor import CustomToolExecutor
        executor = CustomToolExecutor(custom_def)
        attempts = max(self.retry_count, 0) + 1
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if not is_last:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 超时,已跳过]", "assistant")
                return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except Exception as e:
                if not is_last:
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[自定义工具执行失败: {e}]", "assistant")

    async def observe(self, msg) -> None:
        pass
class ConditionNodeAgent(AgentBase):
    """Evaluate a condition against the incoming text.

    The reply always starts with ``condition:true`` or ``condition:false``,
    followed by ``|`` and a short reason; the engine routes on that marker.
    ``condition_type`` selects the evaluator:
    * ``regex`` — ``re.search``;
    * ``json_path`` — a dotted ``$.a.b`` lookup whose value is tested for truthiness;
    * anything else — an LLM judgement.

    Evaluation errors default to ``true``.
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    @staticmethod
    async def _response_text(res) -> str:
        """Best-effort extraction of reply text from a model call result.

        Handles an async stream (the last chunk is used; assumes cumulative chunks,
        TODO confirm), a list (its first element is used), a Msg-like object, or an
        object whose ``content`` is a list of ``{"type": "text", "text": ...}``
        blocks. Anything else falls back to ``str(res)``.
        """
        if hasattr(res, "__aiter__"):
            last = None
            async for chunk in res:
                last = chunk
            res = last
        if isinstance(res, list):
            res = res[0] if res else None
        if res is None:
            return ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        blocks = getattr(res, "content", None)
        if isinstance(blocks, list):
            texts = [b.get("text", "") for b in blocks if isinstance(b, dict) and b.get("type") == "text"]
            if texts:
                return "".join(texts)
        return str(res)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a ``condition:<true|false>|<reason>`` message for ``msg``."""
        user_text = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")
        if self.condition_type == "regex":
            try:
                matched = bool(re.search(self.condition, user_text))
            except re.error as e:
                return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")
            # The original literal was truncated ('' / '未命'); report hit / miss explicitly.
            return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命中' if matched else '未命中'}", "assistant")
        if self.condition_type == "json_path":
            return self._eval_json_path(user_text)
        return await self._eval_with_llm(user_text)

    def _eval_json_path(self, user_text: str) -> Msg:
        """Follow a ``$.a.b`` path through a JSON object and test the value's truthiness.

        Input that does not start with ``{`` is treated as an empty object.
        """
        try:
            data = json.loads(user_text) if user_text.strip().startswith('{') else {}
        except (ValueError, RecursionError):
            return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")
        val = data
        for part in self.condition.strip('$.').split('.'):
            val = val.get(part, None) if isinstance(val, dict) else None
        return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")

    async def _eval_with_llm(self, user_text: str) -> Msg:
        """Ask the shared chat model whether the condition holds; default to true on any failure."""
        import inspect
        try:
            # Use the same model source as LLMNodeAgent. The previous
            # OpenAIChatModel(config_name=..., api_base=...) call used an
            # argument set that does not match agentscope 1.x.
            from agentscope_integration.factory import AgentFactory
            from agentscope.formatter import OpenAIChatFormatter
            model = AgentFactory._get_model()
            formatter = OpenAIChatFormatter()
            condition_prompt = f"""你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。
条件表达式: {self.condition}
输入内容:
{user_text[:2000]}
请严格只输出一行 JSON:
{{"result": true/false, "reason": "简要原因"}}"""
            prompt = formatter.format([
                Msg("system", condition_prompt, "system"),
                Msg("user", user_text[:2000], "user"),
            ])
            if inspect.isawaitable(prompt):
                prompt = await prompt
            res_text = await self._response_text(await model(prompt))
            json_match = re.search(r'\{[^}]+\}', res_text)
            if json_match:
                parsed = json.loads(json_match.group())
                matched = parsed.get("result", False)
                if isinstance(matched, str):
                    # "false" is a truthy string; interpret textual booleans explicitly.
                    matched = matched.strip().lower() == "true"
                reason = parsed.get("reason", "")
                result_flag = "true" if matched else "false"
                return Msg(self.name, f"condition:{result_flag}|{reason}", "assistant")
            # Previously this path could return None, which broke routing.
            logger.warning(f"条件判断LLM未返回可解析的JSON: {res_text[:200]}")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    async def observe(self, msg) -> None:
        pass
class MCPNodeAgent(AgentBase):
    """Flow node that calls one tool on a named MCP server with ``{"input": <text>}``.

    With ``error_handling == "skip"``, failures return a skip notice. Otherwise
    they return a description of the failure. Failures are never reported as
    success.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30, error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    def _failure(self, detail: str) -> Msg:
        """Build the failure reply according to ``error_handling``."""
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, detail, "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call ``tool_name`` on ``server_name`` and return ``str(result)``.

        Previously every failure path fell through to a "调用完成" (call completed)
        message.
        """
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return self._failure("[MCP] MCP 客户端不可用")
        if not self.tool_name:
            return self._failure(f"[MCP] 服务 {self.server_name} 未指定工具名称")
        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return self._failure(f"[MCP] 未找到 MCP 服务 {self.server_name}")
            result = await asyncio.wait_for(
                client.call_tool(self.tool_name, {"input": user_text}),
                timeout=self.timeout
            )
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            return self._failure(f"[MCP] {self.server_name} 调用失败: {e}")
        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        pass
class NotifyAgent(AgentBase):
    """Send the incoming text (or a configured template) over the enabled channels.

    ``config["channels"]`` enables channels; the default is WeCom on, web off.
    The reply summarises, per channel, whether delivery actually happened.
    Previously, failures and missing modules were reported as "sent".
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Dispatch to each enabled channel and join the per-channel outcomes with `` | ``."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        channels = self.config.get("channels", {"wecom": True, "web": False})
        results: list[str] = []
        if channels.get("wecom", True):
            results.append(await self._send_wecom(user_text))
        if channels.get("web", False):
            results.append(await self._send_web(user_text))
        if not results:
            results.append("未启用任何通知渠道")
        return Msg(self.name, " | ".join(results), "assistant")

    async def _send_wecom(self, user_text: str) -> str:
        """Send ``message_template`` (or the first 500 chars of input) to ``target`` via WeCom."""
        import inspect
        template = self.config.get("message_template", "")
        target = self.config.get("target", "")
        message = template or user_text[:500]
        try:
            from agentscope_integration.tools.wecom_tools import send_notification
        except ImportError:
            logger.warning("企微通知模块不可用")
            return "企微通知: 模块不可用,未发送"
        try:
            # send_notification is called synchronously; run it off the event loop
            # (ToolNodeAgent runs the same tool through asyncio.to_thread).
            result = await asyncio.to_thread(send_notification, to_user=target or "user", message=message)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            logger.warning(f"企微通知发送失败: {e}")
            return f"企微通知: 发送失败({e})"
        return f"企微通知: {str(result)[:100]}"

    async def _send_web(self, user_text: str) -> str:
        """Push ``web_template`` (or the first 500 chars of input) to ``target`` over WebSocket."""
        web_message = self.config.get("web_template", "") or user_text[:500]
        target_user = self.config.get("target", "")
        if not target_user:
            return "Web通知: 未指定接收用户,未推送"
        try:
            from websocket_manager import ws_manager
        except ImportError:
            logger.warning("Web通知模块不可用")
            return "Web通知: 模块不可用,未推送"
        try:
            await ws_manager.send_to_user(target_user, {
                "type": "flow_notification",
                "message": web_message,
                "level": "info",
            })
        except Exception as e:
            logger.warning(f"Web通知推送失败: {e}")
            return f"Web通知: 推送失败({e})"
        return "Web通知: 已推送"

    async def observe(self, msg) -> None:
        pass
class LoopNodeAgent(AgentBase):
    """Agent behind a ``loop`` node: reports the current iteration or signals the end.

    ``config["_engine_iteration"]`` (injected by the engine) is read as the number
    of iterations already completed; this agent's iteration number is that value
    plus one. A reply starting with ``loop:stop`` marks the loop as finished.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Build this iteration's message.

        ``loop_type`` may be ``array``, ``fixed``, or anything else (treated as
        unbounded).
        """
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        cfg = self.config
        mode = cfg.get("loop_type", "fixed")
        limit = cfg.get("count", 3)
        var_name = cfg.get("iterator_variable", "item")
        step = cfg.get("_engine_iteration", 0) + 1

        if mode == "array":
            seq = cfg.get("items", [])
            if isinstance(seq, str):
                # Accept a JSON array, falling back to one item per non-blank line.
                try:
                    seq = json.loads(seq)
                except Exception:
                    seq = [ln.strip() for ln in seq.split("\n") if ln.strip()]
            if not isinstance(seq, list) or not seq:
                return Msg(self.name, "loop:stop|无数组数据", "assistant")
            pos = step - 1
            if pos >= len(seq):
                return Msg(self.name, "loop:stop|数组迭代完成", "assistant")
            return Msg(
                self.name,
                f"[迭代 {step}/{len(seq)}] 变量: {var_name}={seq[pos]}\n输入: {text[:200]}",
                "assistant",
            )

        if mode != "fixed":
            return Msg(self.name, f"[循环 {step}] 变量: {var_name}={step}\n输入: {text[:200]}", "assistant")
        if step >= limit:
            return Msg(self.name, f"loop:stop|固定次数循环完成: {step}/{limit}", "assistant")
        return Msg(self.name, f"[循环 {step}/{limit}] 变量: {var_name}={step}\n输入: {text[:200]}", "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        return None
class CodeNodeAgent(AgentBase):
    """Run user-supplied Python code in a child interpreter and return its stdout.

    The node input is exposed to the script as:
    * ``INPUT_RAW`` — the exact input text;
    * ``INPUT_TEXT`` — that text decoded as JSON when it is valid JSON, otherwise
      the raw text.

    ``json`` is pre-imported in the script.

    SECURITY: ``sandbox`` is accepted but NOT enforced. The code runs as an
    ordinary child process with this service's privileges, environment and
    filesystem access. Flow authors must be fully trusted until real isolation
    exists.
    """

    # Language names accepted as Python; other languages are rejected rather than
    # silently executed by the Python interpreter.
    _PYTHON_LANGUAGES = ("python", "python3", "py")

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30, sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        self.language = language
        self.code = code
        self.timeout = timeout
        self.sandbox = sandbox

    def _build_script(self, user_text: str) -> str:
        """Return the input-binding preamble followed by the user code.

        ``repr()`` gives a valid Python literal for any str. The previous preamble,
        ``json.loads(<json.dumps(text)>)``, raised JSONDecodeError for any
        non-JSON input. JSON inputs still yield the parsed value as before.
        """
        preamble = (
            "import json\n"
            f"INPUT_RAW = {user_text!r}\n"
            "try:\n"
            "    INPUT_TEXT = json.loads(INPUT_RAW)\n"
            "except (ValueError, RecursionError):\n"
            "    INPUT_TEXT = INPUT_RAW\n"
            "\n"
        )
        return preamble + self.code

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Execute ``self.code`` with a ``timeout`` and return its stdout.

        The input is passed through unchanged when the code is empty. A non-zero
        exit status is reported as ``[错误] <stderr>``. Stderr output from a
        successful run, such as warnings, no longer discards the result.
        """
        import contextlib
        import os
        import sys
        import tempfile

        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if not self.code.strip():
            return Msg(self.name, user_text, "assistant")
        if str(self.language).strip().lower() not in self._PYTHON_LANGUAGES:
            return Msg(self.name, f"[不支持的代码语言: {self.language}]", "assistant")

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
                temp_path = f.name
                f.write(self._build_script(user_text or ""))
            proc = await asyncio.create_subprocess_exec(
                # Use this interpreter rather than whatever "python" is on PATH.
                sys.executable or "python",
                temp_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                # Make the child's stdout/stderr UTF-8 so decoding below is correct.
                env={**os.environ, "PYTHONIOENCODING": "utf-8"},
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                with contextlib.suppress(ProcessLookupError):
                    proc.kill()
                await proc.wait()  # reap the killed child so it does not linger
                return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            # Always remove the script, including when spawning the process failed.
            if temp_path:
                with contextlib.suppress(OSError):
                    os.unlink(temp_path)

        if proc.returncode != 0:
            err_text = stderr.decode('utf-8', errors='replace')
            detail = err_text[:500] if err_text.strip() else f"退出码 {proc.returncode}"
            return Msg(self.name, f"[错误] {detail}", "assistant")
        output = stdout.decode('utf-8', errors='replace').strip()
        return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")

    async def observe(self, msg) -> None:
        pass
class RAGNodeAgent(AgentBase):
    """Flow node that answers the input using knowledge-base retrieval plus the chat model.

    On any failure it returns a placeholder describing the query instead of raising.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    @staticmethod
    async def _response_text(res) -> str:
        """Best-effort extraction of reply text from a model call result.

        Handles an async stream (the last chunk is used; assumes cumulative chunks,
        TODO confirm), a list (its first element is used), a Msg-like object, or an
        object whose ``content`` is a list of ``{"type": "text", "text": ...}``
        blocks. Anything else falls back to ``str(res)``.
        """
        if hasattr(res, "__aiter__"):
            last = None
            async for chunk in res:
                last = chunk
            res = last
        if isinstance(res, list):
            res = res[0] if res else None
        if res is None:
            return ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        blocks = getattr(res, "content", None)
        if isinstance(blocks, list):
            texts = [b.get("text", "") for b in blocks if isinstance(b, dict) and b.get("type") == "text"]
            if texts:
                return "".join(texts)
        return str(res)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Retrieve up to ``top_k`` knowledge snippets and ask the model to answer from them.

        NOTE(review): ``similarity_threshold``, ``search_mode`` and
        ``include_metadata`` may appear in the node config, but
        ``retrieve_for_agent`` only receives ``limit``. They are currently ignored.
        """
        import inspect
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        top_k = self.config.get("top_k", 5)
        try:
            from modules.rag.knowledge import retrieve_for_agent
            kb_result = await retrieve_for_agent(user_text, limit=top_k)
            model = self._get_model()
            formatter = self._get_formatter()
            rag_prompt = f"""你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。
知识库检索结果:
{kb_result}
用户问题: {user_text}
请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"""
            prompt = formatter.format([
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            # agentscope 1.x formatters are async; the coroutine was previously
            # passed to the model un-awaited.
            if inspect.isawaitable(prompt):
                prompt = await prompt
            res_text = await self._response_text(await model(prompt))
            return Msg(self.name, res_text, "assistant")
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Return the shared chat model from the same factory LLMNodeAgent uses."""
        from agentscope_integration.factory import AgentFactory
        return AgentFactory._get_model()

    def _get_formatter(self):
        """Return a new OpenAI-style chat formatter."""
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass
class OutputNodeAgent(AgentBase):
    """Final formatting node: optional truncation, JSON pretty-printing, or templating."""

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Format ``msg`` for output. The steps, in order:

        1. Truncate to ``max_length`` when ``truncate`` is set.
        2. If ``format == "json"`` and the text parses as JSON, return it
           pretty-printed.
        3. Else if ``output_template`` is set, return it with ``{{output}}``
           replaced by the text.
        4. Otherwise return the text. The original ``msg`` object is returned only
           when nothing was changed; previously, truncation was lost on this path.
        """
        output_format = self.config.get("format", "text")
        output_template = self.config.get("output_template", "")
        truncate = self.config.get("truncate", False)
        max_length = self.config.get("max_length", 2000)
        content = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        truncated = False
        if truncate and len(content) > max_length:
            content = content[:max_length] + "\n\n[内容已截断]"
            truncated = True
        if output_format == "json":
            try:
                parsed = json.loads(content)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                pass
            else:
                indent = self.config.get("indent", 2)
                return Msg(self.name, json.dumps(parsed, indent=indent, ensure_ascii=False), "assistant")
        if output_template and isinstance(output_template, str):
            return Msg(self.name, output_template.replace("{{output}}", content), "assistant")
        if isinstance(msg, Msg) and not truncated:
            return msg
        return Msg(self.name, content, "assistant")

    async def observe(self, msg) -> None:
        pass
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
result = template
import re
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
for placeholder in placeholders:
parts = placeholder.strip().split(".")
value = ""
if parts[0] == "trigger":
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
elif parts[0] in context.get("_node_results", {}):
node_result = context.get("_node_results", {}).get(parts[0], {})
if len(parts) > 1:
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
else:
value = str(node_result.get("output", ""))
result = result.replace("{{" + placeholder + "}}", value)
return result
class ParallelMergeNodeAgent(AgentBase):
    """Collect outputs from several branches and combine them.

    Each ``reply`` stores the message text under the ``source_node_id`` keyword
    argument, or ``""`` when that argument is absent. Merging happens once
    ``expected_branches`` inputs have arrived, or on every call when that setting
    is <= 0. The ``merge_type`` options are:
    * ``json`` — a JSON object keyed by source;
    * ``first_non_empty`` — the first non-blank text;
    * anything else — concatenation separated by ``---``.

    Until then an empty message is returned.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Merge_{node_id}"
        self.config = config or {}
        self._received: dict[str, str] = {}
        self.merge_type = self.config.get("merge_type", "concat")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Record this branch's text and merge once enough branches have reported."""
        origin = kwargs.get("source_node_id", "")
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        self._received[origin] = text
        needed = self.config.get("expected_branches", 0)
        if needed > 0 and len(self._received) < needed:
            return Msg(self.name, "", "assistant")
        return self._merge()

    def _merge(self) -> Msg:
        """Combine everything received so far according to ``merge_type``."""
        collected = self._received.values()
        if self.merge_type == "json":
            combined = json.dumps(self._received, ensure_ascii=False)
        elif self.merge_type == "first_non_empty":
            combined = next((text for text in collected if text.strip()), "")
        else:
            combined = "\n\n---\n\n".join(collected)
        return Msg(self.name, combined, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        return None