You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1012 lines
41 KiB
1012 lines
41 KiB
import json
|
|
import uuid
|
|
import logging
|
|
import re
|
|
import asyncio
|
|
from agentscope.agent import AgentBase
|
|
from agentscope.message import Msg
|
|
from agentscope.tool import Toolkit
|
|
from agentscope.agent._react_agent import ReActAgent
|
|
from config import settings
|
|
|
|
# Module-level logger shared by the flow engine and every node agent in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FlowSessionMemory:
    """In-process conversation history shared by the LLM nodes of one flow run.

    Messages are kept in insertion order as ``{"role": ..., "content": ...}`` dicts.
    """

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return the most recent ``limit * 2`` messages (about ``limit`` exchanges).

        A non-positive *limit* yields an empty list. This case is guarded
        explicitly because the slice ``[-0:]`` would return the whole history.
        """
        if limit <= 0:
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one message with the given *role* and *content*."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of all stored messages."""
        return list(self._messages)
|
|
|
|
|
|
class FlowEngine:
    """Execute a flow definition (nodes + edges) by depth-first traversal.

    ``flow_definition["nodes"]`` is a list of node dicts keyed by ``"id"``, with
    optional ``"type"``, ``"label"`` and ``"config"``. ``flow_definition["edges"]``
    is a list of edge dicts using ``source``/``from``, ``target``/``to`` and an
    optional ``condition``/``sourceHandle`` label.

    Edge labels drive routing:
      * condition nodes follow only edges labelled with their ``true``/``false``
        verdict;
      * loop nodes follow ``loop_body`` edges until the loop is done, then
        ``loop_done`` edges;
      * every other node follows all edges that are not loop edges.

    Non-loop nodes run at most once, except loop-body nodes, which are re-armed
    on every loop iteration.
    """

    # Hard cap on node visits per execute() call (guards against runaway cycles).
    MAX_TOTAL_ITERATIONS = 200
    # Wall-clock budget for one execute() call, in seconds.
    FLOW_TIMEOUT_SECONDS = 300

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        # A later node with a duplicate id replaces the earlier one.
        self.nodes: dict[str, dict] = {
            node["id"]: node for node in flow_definition.get("nodes", [])
        }
        self.edges: list[dict] = flow_definition.get("edges", [])
        # Agents of non-loop nodes, reused across visits and executions.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run the flow from its start node and return the last produced message.

        Args:
            input_msg: Message fed to the start node.
            context: Mutable execution context. ``session_id`` and ``user_id``
                are read from it. ``_memory`` (a FlowSessionMemory) and
                ``_node_results`` (per-node execution records) are written to it.

        Returns:
            One of:
              * the message of the last executed node;
              * a system message describing an abort (iteration limit or timeout);
              * ``input_msg`` when nothing produced output.
        """
        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph)
        if not start_nodes:
            # Every node has an incoming edge (pure cycle): fall back to the first node.
            start_nodes = list(self.nodes.keys())[:1]

        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        context["_memory"] = FlowSessionMemory(session_id=session_id, user_id=user_id)

        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}
        total_iterations = 0
        last_result: Msg | None = None

        def record(node_id: str, node: dict, node_type: str, **fields) -> None:
            """Store the execution record of *node_id* in ``context["_node_results"]``."""
            context.setdefault("_node_results", {})[node_id] = {
                "node_id": node_id,
                "node_type": node_type,
                "label": node.get("label"),
                **fields,
            }

        async def follow(node_id: str, msg: Msg, wanted: str) -> None:
            """Traverse every successor of *node_id* whose edge label equals *wanted*."""
            for target_id, edge_cond in graph.get(node_id, []):
                if edge_cond == wanted:
                    await traverse(target_id, msg)

        async def traverse(node_id: str, incoming_msg: Msg) -> None:
            """Run *node_id* on *incoming_msg*, record the result, then route onward."""
            nonlocal last_result, total_iterations

            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return

            node = self.nodes.get(node_id)
            if not node:
                return
            node_type = node.get("type", "")
            label = node.get("label", node_id)
            is_loop = node_type == "loop"

            if is_loop:
                loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                max_iter = (node.get("config") or {}).get("max_iterations", 10)
                if loop_iterations[node_id] > max_iter:
                    logger.warning(f"循环节点 {label} 超过最大迭代次数 {max_iter}")
                    visited.add(node_id)
                    await follow(node_id, incoming_msg, "loop_done")
                    return
            elif node_id in visited:
                return
            else:
                visited.add(node_id)

            # Everything that can fail for *this* node stays inside the try, so a
            # failure is recorded against this node and never against its caller.
            try:
                if is_loop:
                    loop_config = dict(node.get("config") or {})
                    # loop_iterations is one-based here; the loop agent treats
                    # _engine_iteration as zero-based (it adds 1 itself).
                    loop_config["_engine_iteration"] = loop_iterations[node_id] - 1
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)

                current_msg = incoming_msg
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                if enriched_content.strip():
                    user_text = self._msg_text(incoming_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")

                result = await agent.reply(current_msg)
            except Exception as e:
                logger.error(f"节点 {label} 执行失败: {e}")
                record(node_id, node, node_type, status="error", error=str(e))
                last_result = Msg(name="system", content=f"[节点 {label} 执行失败: {e}]", role="system")
                return

            record(node_id, node, node_type, status="success", output=self._msg_text(result)[:500])
            last_result = result

            if is_loop:
                iteration = loop_iterations[node_id]
                if self._check_loop_done(result, node, iteration):
                    visited.add(node_id)
                    await follow(node_id, result, "loop_done")
                    return
                for target_id, edge_cond in graph.get(node_id, []):
                    if edge_cond == "loop_body" and target_id in self.nodes:
                        # Re-arm the whole body, not only its first node. Otherwise the
                        # node that links back to this loop is skipped as "visited" and
                        # a multi-node body runs only once.
                        visited.difference_update(self._loop_body_nodes(graph, target_id, node_id))
                        await traverse(target_id, result)
                return

            if node_type == "condition":
                cond_result = self._parse_condition_result(result)
                for target_id, edge_cond in graph.get(node_id, []):
                    if edge_cond and edge_cond == cond_result:
                        await traverse(target_id, result)
                return

            for target_id, edge_cond in graph.get(node_id, []):
                if edge_cond in ("loop_body", "loop_done"):
                    continue
                await traverse(target_id, result)

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")

        return last_result or input_msg

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text of *msg*.

        Returns ``""`` when ``get_text_content()`` yields ``None``, and
        ``str(msg)`` for objects that are not messages.
        """
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    @staticmethod
    def _loop_body_nodes(graph: dict, body_start: str, loop_id: str) -> set[str]:
        """Return the node ids reachable from *body_start* without passing through *loop_id*."""
        body: set[str] = set()
        stack = [body_start]
        while stack:
            nid = stack.pop()
            if nid == loop_id or nid in body:
                continue
            body.add(nid)
            stack.extend(target for target, _ in graph.get(nid, []))
        return body

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build an adjacency list ``{source: [(target, edge_label_or_None), ...]}``.

        Edges referring to unknown nodes are ignored. The generic ``"source"``
        handle name is treated as "no label".
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return the ids of nodes with no incoming edge, in definition order."""
        target_nodes = {target_id for targets in graph.values() for target_id, _ in targets}
        return [nid for nid in self.nodes if nid not in target_nodes]

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Extract ``"true"``/``"false"`` from a ``condition:<verdict>`` marker, else ``None``."""
        m = re.search(r'condition:(true|false)', self._msg_text(result))
        return m.group(1) if m else None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node should stop after its *iteration*-th (one-based) visit.

        ``fixed`` loops stop once *iteration* reaches ``count``. All other loop
        types stop when the agent output contains ``loop:stop``, ``loop:done``
        or ``loop:break``.
        """
        config = node.get("config") or {}
        if config.get("loop_type", "fixed") == "fixed":
            return iteration >= config.get("count", 3)
        return bool(re.search(r'loop:(stop|done|break)', self._msg_text(result), re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict = None) -> AgentBase:
        """Return the agent for *node_id*.

        With *override_config* a fresh, uncached agent is built from that
        config. Otherwise the cached agent is reused, or created and cached.
        """
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]

        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``input_mapping`` config as ``"key: value"`` lines.

        String values containing ``{{...}}`` are expanded with
        ``_resolve_template``. Returns ``""`` when no mapping is configured.
        """
        input_mapping = (node.get("config") or {}).get("input_mapping")
        if not input_mapping:
            return ""

        resolved = {}
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            resolved[key] = str(value)

        return "\n".join(f"{k}: {v}" for k, v in resolved.items())
|
|
|
|
|
|
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
|
|
node_type = node.get("type", "")
|
|
node_id = node.get("id", "")
|
|
config = node.get("config", {})
|
|
|
|
if node_type == "trigger":
|
|
return PassThroughAgent(node_id)
|
|
|
|
elif node_type == "llm":
|
|
model_config = config.get("model", settings.LLM_MODEL)
|
|
temperature = config.get("temperature", 0.7)
|
|
system_prompt = config.get("system_prompt", "你是AI助手。")
|
|
max_tokens = config.get("max_tokens", 2000)
|
|
stream = config.get("stream", True)
|
|
agent = LLMNodeAgent(
|
|
node_id=node_id,
|
|
system_prompt=system_prompt,
|
|
model_name=model_config,
|
|
temperature=temperature,
|
|
max_tokens=max_tokens,
|
|
stream=stream,
|
|
)
|
|
memory = context.get("_memory")
|
|
if memory:
|
|
agent.set_memory(memory)
|
|
return agent
|
|
|
|
elif node_type == "tool":
|
|
tool_name = config.get("tool_name", "")
|
|
tool_params = config.get("tool_params", {})
|
|
timeout = config.get("timeout", 30)
|
|
retry_count = config.get("retry_count", 0)
|
|
error_handling = config.get("error_handling", "throw")
|
|
return ToolNodeAgent(
|
|
node_id=node_id,
|
|
tool_name=tool_name,
|
|
tool_params=tool_params,
|
|
timeout=timeout,
|
|
retry_count=retry_count,
|
|
error_handling=error_handling,
|
|
)
|
|
|
|
elif node_type == "mcp":
|
|
mcp_server = config.get("mcp_server", "")
|
|
tool_name = config.get("tool_name", "")
|
|
timeout = config.get("timeout", 30)
|
|
error_handling = config.get("error_handling", "throw")
|
|
return MCPNodeAgent(
|
|
node_id=node_id,
|
|
server_name=mcp_server,
|
|
tool_name=tool_name,
|
|
timeout=timeout,
|
|
error_handling=error_handling,
|
|
)
|
|
|
|
elif node_type in ("wecom_notify", "notify"):
|
|
return NotifyAgent(node_id=node_id, config=config)
|
|
|
|
elif node_type == "condition":
|
|
condition_expr = config.get("condition", "")
|
|
condition_type = config.get("condition_type", "expression")
|
|
return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type)
|
|
|
|
elif node_type == "rag":
|
|
return RAGNodeAgent(node_id=node_id, config=config)
|
|
|
|
elif node_type == "output":
|
|
return OutputNodeAgent(node_id=node_id, config=config)
|
|
|
|
elif node_type == "merge":
|
|
return ParallelMergeNodeAgent(node_id=node_id, config=config)
|
|
elif node_type == "loop":
|
|
return LoopNodeAgent(node_id=node_id, config=config)
|
|
|
|
elif node_type == "code":
|
|
language = config.get("language", "python")
|
|
code = config.get("code", "")
|
|
timeout = config.get("timeout", 30)
|
|
sandbox = config.get("sandbox", True)
|
|
return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox)
|
|
|
|
else:
|
|
return PassThroughAgent(node_id)
|
|
|
|
|
|
class PassThroughAgent(AgentBase):
    """Agent that forwards its input unchanged (a non-Msg input is wrapped in one)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = "passthrough_{}".format(node_id)

    async def reply(self, msg, **kwargs) -> Msg:
        """Return *msg* itself when it is a Msg, else a new assistant Msg holding ``str(msg)``."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class LLMNodeAgent(AgentBase):
    """Flow node that sends the incoming text to the shared chat model.

    The prompt has three parts, in order:
      * the node's system prompt;
      * up to the last 10 messages of the attached FlowSessionMemory, each cut
        to 2000 characters (only when a memory is attached via ``set_memory``);
      * the incoming text as the user turn.

    Successful exchanges are appended to that memory.
    """

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        # NOTE(review): model_name, temperature, max_tokens and stream are stored but
        # reply() uses AgentFactory._get_model() and does not forward them.
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self._memory = None

    def set_memory(self, memory):
        """Attach a FlowSessionMemory-like object (``get_history``/``add``)."""
        self._memory = memory

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call the model and return its text.

        On any formatter or model error, returns a
        ``[LLM 调用失败] 已接收输入: ...`` message. The failure text is not
        stored in memory.
        """
        import inspect

        from agentscope_integration.factory import AgentFactory
        from agentscope.formatter import OpenAIChatFormatter

        model = AgentFactory._get_model()
        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)

        messages = [Msg("system", self.system_prompt, "system")]
        if self._memory:
            for entry in self._memory.get_history(limit=5):
                role = entry.get("role", "user")
                content = (entry.get("content") or "")[:2000]
                messages.append(Msg(role, content, role))
        messages.append(Msg("user", user_text, "user"))

        try:
            prompt = OpenAIChatFormatter().format(messages)
            # AgentScope 1.x formatters are async; tolerate a sync formatter as well.
            if inspect.isawaitable(prompt):
                prompt = await prompt
            res_text = await self._response_text(await model(prompt))
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            return Msg(self.name, f"[LLM 调用失败] 已接收输入: {user_text[:200]}", "assistant")

        if self._memory:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)

        return Msg(self.name, res_text, "assistant")

    @staticmethod
    async def _response_text(res) -> str:
        """Extract plain text from a model response.

        Accepts the following shapes:
          * a list (first element used);
          * an object with ``get_text_content``;
          * an async iterator of chunks;
          * an object whose ``content`` is a str or a list of
            ``{"type": "text", "text": ...}`` blocks.
        """
        if hasattr(res, "__aiter__"):
            # Streamed response: keep the final chunk. NOTE(review): assumes chunks are
            # cumulative (each carries the full text so far) — confirm for the model in use.
            last = None
            async for chunk in res:
                last = chunk
            res = last
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            texts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
            if texts:
                return "".join(texts)
        return "" if res is None else str(res)

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class ToolNodeAgent(AgentBase):
    """Flow node that invokes a built-in Python tool or a DB-defined custom tool.

    Built-in tools are imported lazily into ``_TOOL_REGISTRY`` on first use.
    Custom HTTP tools are registered via ``load_custom_tools`` or
    ``register_custom_tool`` and take precedence over a built-in with the same
    name.

    Tool failures are reported as message text, governed by ``error_handling``:
      * ``"skip"`` returns a "skipped" notice;
      * ``"default"`` returns ``"{}"``;
      * any other value (e.g. ``"throw"``) returns an error description.
    """

    # name -> callable for the built-in tools (filled lazily by _init_registry).
    _TOOL_REGISTRY: dict[str, callable] = {}
    # name -> schema dict for built-in and custom tools.
    _TOOL_SCHEMAS: dict[str, dict] = {}
    # name -> HTTP definition for DB-defined custom tools.
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}

    @classmethod
    def _init_registry(cls):
        """Import the built-in tools into ``_TOOL_REGISTRY`` (no-op once populated).

        Each tool module's optional ``SCHEMAS`` mapping is merged into
        ``_TOOL_SCHEMAS``. An ImportError is logged and leaves the registry
        empty, so the import is retried on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            import importlib

            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard

            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }

            for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]:
                try:
                    mod = importlib.import_module(f"agentscope_integration.tools.{mod_name}")
                    if hasattr(mod, "SCHEMAS"):
                        cls._TOOL_SCHEMAS.update(mod.SCHEMAS)
                except Exception as e:
                    # Schemas are optional metadata; keep the registry usable without them.
                    logger.debug(f"加载工具 schema 失败 ({mod_name}): {e}")
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Register every active ``CustomTool`` row from *db* (an async SQLAlchemy session).

        Errors are logged and leave already-registered tools in place.
        """
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Not used here. Importing it up front means a missing executor module
            # disables custom-tool loading via the ImportError branch below.
            from modules.custom_tool.executor import CustomToolExecutor  # noqa: F401

            result = await db.execute(
                # SQLAlchemy column comparison, not a Python truth test.
                select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712
            )
            for tool in result.scalars().all():
                cls._CUSTOM_TOOL_DEFS[tool.name] = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas (built-ins are loaded first)."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30, retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    def _failure_msg(self, skipped_text: str, error_text: str) -> Msg:
        """Build the reply for a failed call according to ``error_handling``."""
        if self.error_handling == "skip":
            return Msg(self.name, skipped_text, "assistant")
        if self.error_handling == "default":
            return Msg(self.name, "{}", "assistant")
        return Msg(self.name, error_text, "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Run the configured tool with up to ``retry_count`` retries and return its result as text.

        The tool is first called with ``tool_params`` as keyword arguments. If
        that raises TypeError, it is called again with the input text as the
        first positional argument.
        """
        import inspect

        self._init_registry()
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)

        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)

        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        def _call_sync():
            try:
                return tool_func(**self.tool_params) if self.tool_params else tool_func()
            except TypeError:
                try:
                    return tool_func(user_text, **self.tool_params)
                except Exception as e:
                    raise RuntimeError(f"工具执行失败: {e}")
            except Exception as e:
                raise RuntimeError(f"工具执行失败: {e}")

        async def _call():
            # NOTE: a timed-out worker thread cannot be cancelled and keeps running;
            # a retry may therefore execute a side-effecting tool twice.
            result = await asyncio.to_thread(_call_sync)
            if inspect.isawaitable(result):
                # Coroutine tool: the thread only created the coroutine; run it here.
                result = await result
            return result

        for attempt in range(self.retry_count + 1):
            try:
                result = await asyncio.wait_for(_call(), timeout=self.timeout)
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                return self._failure_msg(
                    f"[工具 {self.tool_name} 超时,已跳过]",
                    f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)",
                )
            except RuntimeError as e:
                if attempt < self.retry_count:
                    continue
                return self._failure_msg(f"[工具 {self.tool_name} 失败,已跳过: {e}]", f"[{e}]")
            except Exception as e:
                if attempt < self.retry_count:
                    continue
                return self._failure_msg(f"[工具 {self.tool_name} 失败,已跳过: {e}]", f"[工具执行失败: {e}]")

        # Only reached when retry_count is negative (the loop body never ran).
        return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Run a custom HTTP tool via CustomToolExecutor with ``tool_params``.

        Retries and ``error_handling`` behave as in ``reply``. *user_text* is
        currently not forwarded to the executor.
        """
        from modules.custom_tool.executor import CustomToolExecutor

        executor = CustomToolExecutor(custom_def)

        for attempt in range(self.retry_count + 1):
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                return self._failure_msg(
                    f"[自定义工具 {self.tool_name} 超时,已跳过]",
                    f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)",
                )
            except Exception as e:
                if attempt < self.retry_count:
                    continue
                return self._failure_msg(
                    f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]",
                    f"[自定义工具执行失败: {e}]",
                )

        return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class ConditionNodeAgent(AgentBase):
    """Flow node that emits ``condition:true|<reason>`` or ``condition:false|<reason>``.

    ``condition_type`` selects the evaluator:
      * ``"regex"``: ``re.search(condition, input_text)``;
      * ``"json_path"``: a dotted path (leading/trailing ``$``/``.`` stripped)
        into the input parsed as a JSON object; true when the value is truthy;
      * anything else: the shared chat model judges the natural-language condition.

    An empty condition, an invalid regex or JSON, or a model failure all fall
    back to ``condition:true``.
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)

        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")

        if self.condition_type == "regex":
            try:
                matched = bool(re.search(self.condition, user_text))
                return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命' if matched else '未命'}中", "assistant")
            except re.error as e:
                return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")

        if self.condition_type == "json_path":
            try:
                data = json.loads(user_text) if user_text.strip().startswith('{') else {}
                val = data
                for part in self.condition.strip('$.').split('.'):
                    val = val.get(part, None) if isinstance(val, dict) else None
                return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")
            except Exception:
                return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")

        try:
            import inspect

            from agentscope.formatter import OpenAIChatFormatter
            from agentscope_integration.factory import AgentFactory

            # Use the same shared model as LLM nodes. Constructing OpenAIChatModel here
            # with config_name/api_base (v0-style arguments) failed under AgentScope 1.x,
            # which silently forced every condition to "true".
            model = AgentFactory._get_model()

            condition_prompt = f"""你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。

条件表达式: {self.condition}

输入内容:
{user_text[:2000]}

请严格只输出一行 JSON:
{{"result": true/false, "reason": "简要原因"}}"""

            prompt = OpenAIChatFormatter().format([
                Msg("system", condition_prompt, "system"),
                Msg("user", user_text[:2000], "user"),
            ])
            # AgentScope 1.x formatters are async; tolerate a sync formatter as well.
            if inspect.isawaitable(prompt):
                prompt = await prompt
            res_text = await self._response_text(await model(prompt))

            json_match = re.search(r'\{[^}]+\}', res_text)
            if json_match:
                parsed = json.loads(json_match.group())
                matched = parsed.get("result", False)
                if isinstance(matched, str):
                    # A quoted "false" would otherwise be truthy.
                    matched = matched.strip().lower() == "true"
                reason = parsed.get("reason", "")
                result_flag = "true" if matched else "false"
                return Msg(self.name, f"condition:{result_flag}|{reason}", "assistant")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")

        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    @staticmethod
    async def _response_text(res) -> str:
        """Extract plain text from a model response.

        Accepts the following shapes:
          * a list (first element used);
          * an object with ``get_text_content``;
          * an async iterator of chunks;
          * an object whose ``content`` is a str or a list of
            ``{"type": "text", "text": ...}`` blocks.
        """
        if hasattr(res, "__aiter__"):
            # NOTE(review): assumes streamed chunks are cumulative — confirm for the model in use.
            last = None
            async for chunk in res:
                last = chunk
            res = last
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            texts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
            if texts:
                return "".join(texts)
        return "" if res is None else str(res)

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class MCPNodeAgent(AgentBase):
    """Flow node that calls one tool on a named MCP server with ``{"input": <text>}``.

    Every failure is reported as such, and never as a completed call. This
    covers: no server or tool configured, runtime not installed, no client,
    timeout, and call error. With ``error_handling="skip"`` failures return a
    "skipped" notice instead.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30, error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    def _failure(self, reason: str) -> Msg:
        """Build the reply for a failed call according to ``error_handling``."""
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, f"[MCP] {self.server_name} {reason}", "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)

        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        if not self.tool_name:
            return Msg(self.name, "[MCP] 未指定 MCP 工具名称", "assistant")

        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return self._failure("调用失败: MCP 客户端不可用")

        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return self._failure("调用失败: 未找到可用的 MCP 客户端")
            result = await asyncio.wait_for(
                client.call_tool(self.tool_name, {"input": user_text}),
                timeout=self.timeout,
            )
            return Msg(self.name, str(result), "assistant")
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            return self._failure(f"调用失败: {e}")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class NotifyAgent(AgentBase):
    """Flow node that sends the incoming text, or a configured template, as a notification.

    Config keys read:
      * ``channels``: ``{"wecom": bool, "web": bool}``, default WeCom only;
      * ``target``: the recipient;
      * ``message_template``: WeCom text;
      * ``web_template``: web text.

    Without a template, the first 500 characters of the input are sent. The
    reply summarises the outcome of each enabled channel, failures included.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)
        channels = self.config.get("channels", {"wecom": True, "web": False})
        target = self.config.get("target", "")

        results: list[str] = []
        if channels.get("wecom", True):
            message = self.config.get("message_template", "") or user_text[:500]
            results.append(await self._send_wecom(message, target))
        if channels.get("web", False):
            web_message = self.config.get("web_template", "") or user_text[:500]
            results.append(await self._send_web(web_message, target))

        if not results:
            results.append("未启用通知渠道,未发送")
        return Msg(self.name, " | ".join(results), "assistant")

    async def _send_wecom(self, message: str, target: str) -> str:
        """Send *message* through the WeCom notification tool and describe the outcome."""
        import inspect

        try:
            from agentscope_integration.tools.wecom_tools import send_notification
        except ImportError:
            return "企微通知: 通知模块不可用"
        try:
            # Run off the event loop in case the tool performs blocking I/O. If it is
            # a coroutine function, the thread only creates the coroutine, awaited here.
            result = await asyncio.to_thread(send_notification, to_user=target or "user", message=message)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            logger.warning(f"企微通知发送失败: {e}")
            return f"企微通知: 发送失败({e})"
        return f"企微通知: {str(result)[:100]}"

    async def _send_web(self, message: str, target: str) -> str:
        """Push *message* to *target* over the websocket manager and describe the outcome."""
        try:
            from websocket_manager import ws_manager
        except ImportError:
            return "Web通知: 推送模块不可用"
        if not target:
            return "Web通知: 未指定接收人,未推送"
        try:
            await ws_manager.send_to_user(target, {
                "type": "flow_notification",
                "message": message,
                "level": "info",
            })
        except Exception as e:
            logger.warning(f"Web通知推送失败: {e}")
            return f"Web通知: 推送失败({e})"
        return "Web通知: 已推送"

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class LoopNodeAgent(AgentBase):
    """Flow node that reports the current loop step or a ``loop:stop|...`` marker.

    The step number is ``config["_engine_iteration"] + 1``. ``loop_type``
    controls the output:
      * ``"array"``: walks ``config["items"]`` (a list, a JSON string, or
        newline-separated text);
      * ``"fixed"``: stops once the step reaches ``count``;
      * any other type: only reports the step.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    def _array_step(self, step, var_name, text) -> str:
        """Describe array step *step* (one-based), or return a stop marker when exhausted."""
        raw = self.config.get("items", [])
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except Exception:
                raw = [ln.strip() for ln in raw.split("\n") if ln.strip()]
        if not (isinstance(raw, list) and raw):
            return "loop:stop|无数组数据"
        position = step - 1
        if position >= len(raw):
            return "loop:stop|数组迭代完成"
        return f"[迭代 {step}/{len(raw)}] 变量: {var_name}={raw[position]}\n输入: {text[:200]}"

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        cfg = self.config
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        mode = cfg.get("loop_type", "fixed")
        limit = cfg.get("count", 3)
        var_name = cfg.get("iterator_variable", "item")
        step = cfg.get("_engine_iteration", 0) + 1

        if mode == "array":
            body = self._array_step(step, var_name, text)
        elif mode == "fixed":
            if step >= limit:
                body = f"loop:stop|固定次数循环完成: {step}/{limit}"
            else:
                body = f"[循环 {step}/{limit}] 变量: {var_name}={step}\n输入: {text[:200]}"
        else:
            body = f"[循环 {step}] 变量: {var_name}={step}\n输入: {text[:200]}"
        return Msg(self.name, body, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class CodeNodeAgent(AgentBase):
    """Flow node that runs user-supplied Python code in a subprocess.

    The incoming text is available to the script as the ``INPUT_TEXT`` str, and
    ``json`` is pre-imported. The script's stdout becomes the reply. A non-zero
    exit status yields ``[错误] <stderr>``.
    """

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30, sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        # NOTE(review): `language` and `sandbox` are stored but not enforced — code always
        # runs as an unsandboxed Python subprocess with this service's privileges.
        # Only run flow code from trusted authors.
        self.language = language
        self.code = code
        self.timeout = timeout
        self.sandbox = sandbox

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        import os
        import sys
        import tempfile

        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)

        if not self.code.strip():
            return Msg(self.name, user_text, "assistant")

        # repr() produces a Python literal that round-trips any str. The previous
        # `json.loads(<json.dumps(...)>)` embedding handed already-decoded text to
        # json.loads and raised for ordinary, non-JSON input.
        code_with_input = f"import json\nINPUT_TEXT = {user_text!r}\n\n{self.code}"

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
                temp_path = f.name
                f.write(code_with_input)

            # sys.executable: the same interpreter as this service, rather than whatever
            # "python" resolves to on PATH.
            proc = await asyncio.create_subprocess_exec(
                sys.executable, temp_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()  # reap the killed process
                return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")

            # Judge success by exit status: warnings printed to stderr are not failures.
            if proc.returncode != 0:
                detail = stderr.decode('utf-8', errors='replace')[:500] or f"exit code {proc.returncode}"
                return Msg(self.name, f"[错误] {detail}", "assistant")

            output = stdout.decode('utf-8', errors='replace').strip()
            return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class RAGNodeAgent(AgentBase):
    """Flow node that retrieves knowledge-base passages and asks the chat model to answer from them.

    If the model step fails, the reply contains the query, ``top_k`` and any
    passages that were retrieved.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        import inspect

        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)
        top_k = self.config.get("top_k", 5)
        # NOTE(review): config keys similarity_threshold / search_mode / include_metadata
        # are not applied — retrieve_for_agent is only given the query and limit.

        kb_result = None
        try:
            from modules.rag.knowledge import retrieve_for_agent
            kb_result = await retrieve_for_agent(user_text, limit=top_k)

            model = self._get_model()
            formatter = self._get_formatter()

            rag_prompt = f"""你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。

知识库检索结果:
{kb_result}

用户问题: {user_text}

请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"""

            prompt = formatter.format([
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            # AgentScope 1.x formatters are async; tolerate a sync formatter as well.
            if inspect.isawaitable(prompt):
                prompt = await prompt

            res_text = await self._response_text(await model(prompt))
            return Msg(self.name, res_text, "assistant")
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")

        output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
        if kb_result:
            # Retrieval succeeded but the model step failed: still surface the passages.
            output += f"\n\n{kb_result}"
        return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Return the shared chat model used by LLM nodes.

        A direct ``OpenAIChatModel(config_name=..., api_base=...)`` used
        v0-style arguments that AgentScope 1.x does not accept.
        """
        from agentscope_integration.factory import AgentFactory
        return AgentFactory._get_model()

    def _get_formatter(self):
        """Return a fresh OpenAI chat formatter."""
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    @staticmethod
    async def _response_text(res) -> str:
        """Extract plain text from a model response.

        Accepts the following shapes:
          * a list (first element used);
          * an object with ``get_text_content``;
          * an async iterator of chunks;
          * an object whose ``content`` is a str or a list of
            ``{"type": "text", "text": ...}`` blocks.
        """
        if hasattr(res, "__aiter__"):
            # NOTE(review): assumes streamed chunks are cumulative — confirm for the model in use.
            last = None
            async for chunk in res:
                last = chunk
            res = last
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            texts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
            if texts:
                return "".join(texts)
        return "" if res is None else str(res)

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
class OutputNodeAgent(AgentBase):
    """Final flow node that shapes the flow's output text.

    Config keys read:
      * ``truncate`` and ``max_length`` (default 2000): cap the length of the text;
      * ``format="json"``: re-serialises valid JSON with ``indent`` (default 2);
      * ``output_template``: replaces ``{{output}}`` with the text.

    When nothing changes the text, the incoming Msg is returned as is.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        output_format = self.config.get("format", "text")
        output_template = self.config.get("output_template", "")
        truncate = self.config.get("truncate", False)
        max_length = self.config.get("max_length", 2000)

        if hasattr(msg, 'get_text_content'):
            content = msg.get_text_content() or ""
        else:
            content = str(msg)

        truncated = bool(truncate) and len(content) > max_length
        if truncated:
            content = content[:max_length] + "\n\n[内容已截断]"

        if output_format == "json":
            try:
                parsed = json.loads(content)
                indent = self.config.get("indent", 2)
                formatted = json.dumps(parsed, indent=indent, ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")
            except (json.JSONDecodeError, ValueError):
                pass  # not JSON: fall through to template / plain output

        if output_template and isinstance(output_template, str):
            return Msg(self.name, output_template.replace("{{output}}", content), "assistant")

        # Returning the original msg would discard the truncation applied above.
        if isinstance(msg, Msg) and not truncated:
            return msg
        return Msg(self.name, content, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|
|
|
|
|
|
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
    """Expand ``{{...}}`` placeholders in *template* from the flow *context*.

    Supported placeholders:
      * ``{{trigger.key}}``: ``context["trigger_data"]["key"]``. A dotted key
        such as ``{{trigger.a.b}}`` is looked up as the flat key ``"a.b"``
        first and, if absent, as the nested path ``["a"]["b"]``.
      * ``{{<node_id>}}`` or ``{{<node_id>.output}}``: the node's recorded
        ``output``.
      * ``{{<node_id>.<field>}}``: another field of that node's record in
        ``context["_node_results"]``.

    Unknown placeholders become ``""``. Substitution is a single pass, so text
    inserted for one placeholder is never itself re-expanded. *current_msg* is
    accepted for interface compatibility and is not used.
    """
    trigger_data = context.get("trigger_data") or {}
    node_results = context.get("_node_results") or {}

    def _trigger_value(path: list[str]) -> str:
        if not isinstance(trigger_data, dict):
            return ""
        flat_key = ".".join(path)
        if flat_key in trigger_data:
            return str(trigger_data[flat_key])
        if len(path) < 2:
            return ""
        value = trigger_data
        for part in path:
            if not isinstance(value, dict) or part not in value:
                return ""
            value = value[part]
        return str(value)

    def _substitute(match: "re.Match") -> str:
        head, *rest = match.group(1).strip().split(".")
        if head == "trigger":
            return _trigger_value(rest)
        if head in node_results:
            field = rest[0] if rest else "output"
            return str(node_results[head].get(field, ""))
        return ""

    return re.sub(r'\{\{(.+?)\}\}', _substitute, template)
|
|
|
|
|
|
class ParallelMergeNodeAgent(AgentBase):
    """Flow node that collects branch outputs and combines them into one message.

    Inputs are keyed by the ``source_node_id`` keyword argument, defaulting to
    ``""``. Once ``expected_branches`` inputs have arrived, or immediately when
    that value is <= 0, the collected texts are merged according to
    ``merge_type``:
      * ``"json"``: a JSON object keyed by source id;
      * ``"first_non_empty"``: the first non-blank text;
      * anything else: the texts joined with ``---`` separators.

    Before that point the reply is an empty message.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Merge_{node_id}"
        self.config = config or {}
        # NOTE(review): FlowEngine calls reply(msg) without source_node_id, so every
        # input lands under the "" key, and this dict persists for the cached agent's lifetime.
        self._received: dict[str, str] = {}
        self.merge_type = self.config.get("merge_type", "concat")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        self._received[kwargs.get("source_node_id", "")] = text
        needed = self.config.get("expected_branches", 0)
        ready = needed <= 0 or len(self._received) >= needed
        return self._merge() if ready else Msg(self.name, "", "assistant")

    def _merge(self) -> Msg:
        """Combine every received text according to ``merge_type``."""
        values = list(self._received.values())
        if self.merge_type == "json":
            merged = json.dumps(self._received, ensure_ascii=False)
        elif self.merge_type == "first_non_empty":
            merged = next((v for v in values if v.strip()), "")
        else:
            merged = "\n\n---\n\n".join(values)
        return Msg(self.name, merged, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        pass
|