You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

522 lines
20 KiB

import json
import uuid
import logging
import re
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.agent._react_agent import ReActAgent
from config import settings
logger = logging.getLogger(__name__)
class FlowEngine:
    """Run a flow definition (nodes + edges) by depth-first traversal.

    ``flow_definition`` is read for a ``"nodes"`` list (dicts with at least an
    ``"id"``; optionally ``"type"``, ``"label"`` and ``"config"``) and an
    ``"edges"`` list (dicts using ``source``/``target`` or ``from``/``to``,
    optionally labelled through ``condition`` or ``sourceHandle``).
    Per-node execution records are written to ``context["_node_results"]``.
    """

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        # Node id -> node definition; a later duplicate id overwrites an earlier one.
        self.nodes: dict[str, dict] = {
            node["id"]: node for node in flow_definition.get("nodes") or []
        }
        self.edges: list[dict] = flow_definition.get("edges") or []
        # Node id -> agent, created lazily on the node's first visit.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Execute the flow for ``input_msg``.

        Traversal starts at the first node without incoming edges (or the first
        declared node when every node has one). Each node runs at most once.
        After a ``condition`` node only outgoing edges whose label equals the
        parsed verdict (``"true"``/``"false"``) are followed; any other node
        follows all of its outgoing edges. A node that fails is recorded with
        ``status: "error"`` and its successors are skipped.

        NOTE(review): only ``start_nodes[0]`` is traversed; nodes reachable
        solely from another start node never run — confirm this is intended.

        Args:
            input_msg: Message fed to the start node.
            context: Mutable execution context; ``_node_results`` is filled in.

        Returns:
            The reply of the most recently finished node (an error ``Msg`` if it
            failed), or ``input_msg`` when no node produced a result.
        """
        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph)
        if not start_nodes:
            # Every node has an incoming edge (a cycle): start from the first declared node.
            start_nodes = list(self.nodes)[:1]

        visited: set[str] = set()
        last_result: Msg | None = None

        async def traverse(node_id: str, incoming_msg: Msg) -> None:
            nonlocal last_result
            if node_id in visited:
                return
            visited.add(node_id)
            node = self.nodes.get(node_id)
            if not node:
                return
            label = node.get("label", node_id)

            # Only this node's own work is inside the try, so a failure is
            # attributed to the node that actually failed and never to an
            # ancestor whose traversal happens to be on the call stack.
            try:
                agent = await self._get_or_create_agent(node_id, context)
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                current_msg = incoming_msg
                if enriched_content.strip():
                    user_text = self._message_text(current_msg)
                    current_msg = Msg(
                        name="user",
                        content=f"{enriched_content}\n\n---\n{user_text}",
                        role="user",
                    )
                result = await agent.reply(current_msg)
            except Exception as e:
                logger.error(f"节点 {label} 执行失败: {e}")
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                last_result = Msg(name="system", content=f"[节点 {label} 执行失败: {e}]", role="system")
                return

            context.setdefault("_node_results", {})[node_id] = {
                "node_id": node_id,
                "node_type": node.get("type"),
                "label": node.get("label"),
                "status": "success",
                "output": self._message_text(result)[:500],
            }
            last_result = result

            is_condition = node.get("type") == "condition"
            verdict = self._parse_condition_result(result) if is_condition else None
            for target_id, edge_cond in graph.get(node_id, []):
                # A condition node follows only edges labelled with its verdict;
                # unlabelled edges out of a condition node are never followed.
                if is_condition and (not edge_cond or edge_cond != verdict):
                    continue
                await traverse(target_id, result)

        if start_nodes:
            await traverse(start_nodes[0], input_msg)
        return last_result or input_msg

    @staticmethod
    def _message_text(msg) -> str:
        """Return the text carried by *msg*, never ``None``.

        Objects with ``get_text_content()`` use it (a ``None`` result becomes
        ``""``); anything else is converted with ``str()``.
        """
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        return text or ""

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build an adjacency list ``source -> [(target, edge_label), ...]``.

        Edges referencing unknown nodes are dropped. The generic handle name
        ``"source"`` is treated as "no label".
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return the ids of nodes with no incoming edge, in declaration order."""
        has_incoming = {target_id for targets in graph.values() for target_id, _ in targets}
        return [nid for nid in self.nodes if nid not in has_incoming]

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Return ``"true"``/``"false"`` from a ``condition:<verdict>`` marker, else ``None``."""
        m = re.search(r'condition:(true|false)', self._message_text(result))
        return m.group(1) if m else None

    async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase:
        """Return the cached agent for *node_id*, creating it on first use."""
        if node_id in self._agent_cache:
            return self._agent_cache[node_id]
        node = self.nodes[node_id]
        agent = await _create_node_agent(node, context)
        self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``config.input_mapping`` as ``key: value`` lines.

        String values containing ``{{`` are expanded with ``_resolve_template``;
        other values are stringified. Returns ``""`` when there is no mapping.
        """
        config = node.get("config") or {}
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""
        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {value!s}")
        return "\n".join(lines)
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Instantiate the agent implementing *node* according to its ``type``.

    Unknown types (and ``"trigger"``) get a ``PassThroughAgent``. ``context``
    is accepted for interface compatibility but is not used here.

    Args:
        node: Node definition with ``id``, ``type`` and an optional ``config`` dict.
        context: Flow execution context (unused).

    Returns:
        A freshly constructed agent for the node.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    # "config": null in a stored flow must not crash every branch below.
    config = node.get("config") or {}

    if node_type == "trigger":
        return PassThroughAgent(node_id)
    if node_type == "llm":
        return LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
        )
    if node_type == "tool":
        return ToolNodeAgent(
            node_id=node_id,
            tool_name=config.get("tool_name", ""),
            tool_params=config.get("tool_params", {}),
        )
    if node_type == "mcp":
        return MCPNodeAgent(
            node_id=node_id,
            server_name=config.get("mcp_server", ""),
            tool_name=config.get("tool_name", ""),
        )
    if node_type == "wecom_notify":
        return WeComNotifyAgent(node_id=node_id, config=config)
    if node_type == "condition":
        return ConditionNodeAgent(node_id=node_id, condition=config.get("condition", ""))
    if node_type == "rag":
        return RAGNodeAgent(node_id=node_id, config=config)
    if node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)
    return PassThroughAgent(node_id)
class PassThroughAgent(AgentBase):
    """Agent that forwards its input unchanged (used for trigger and unknown node types)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return *msg* itself if it is a ``Msg``; otherwise wrap ``str(msg)`` in one."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        return None
class LLMNodeAgent(AgentBase):
    """Flow node that sends the incoming text to the shared chat model with a system prompt."""

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        # NOTE(review): model_name and temperature are stored, but reply() always
        # uses AgentFactory._get_model(), so these per-node settings are not
        # applied — confirm whether they should be forwarded to the model.
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response.

        A list uses its first element. Objects with ``get_text_content()`` use
        it. Objects exposing a ``content`` list of ``{"type": "text", "text": ...}``
        blocks are joined; this is presumably agentscope's ``ChatResponse``, so
        confirm. Anything else falls back to ``str()``.
        """
        item = res[0] if isinstance(res, list) else res
        if hasattr(item, 'get_text_content'):
            return item.get_text_content() or ""
        blocks = getattr(item, "content", None)
        if isinstance(blocks, list):
            texts = [
                block.get("text", "")
                for block in blocks
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            if texts:
                return "".join(texts)
        return str(item)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Answer *msg* with the model; on a model-call error return a fallback text.

        Errors raised while obtaining the model or formatting the prompt are
        not caught here and propagate to the caller.
        """
        from agentscope_integration.factory import AgentFactory
        from agentscope.formatter import OpenAIChatFormatter

        model = AgentFactory._get_model()
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        # get_text_content() may yield None; Msg content and slicing need a str.
        user_text = text or ""
        formatter = OpenAIChatFormatter()
        prompt = await formatter.format([
            Msg("system", self.system_prompt, "system"),
            Msg("user", user_text, "user"),
        ])
        try:
            res = await model(prompt)
            res_text = self._response_text(res)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"
        return Msg(self.name, res_text, "assistant")

    async def observe(self, msg) -> None:
        pass
class ToolNodeAgent(AgentBase):
    """Flow node that invokes one registered project tool by name."""

    # Tool name -> callable; filled lazily (and shared by all instances) by _init_registry().
    _TOOL_REGISTRY: dict[str, callable] = {}

    @classmethod
    def _init_registry(cls):
        """Populate the shared tool registry once.

        If any import fails, a warning is logged and the registry stays empty.
        The import is then retried on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard
            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}

    @staticmethod
    def _binds(func, params: dict) -> bool:
        """Return True if ``func(**params)`` matches *func*'s signature.

        Callables without an introspectable signature are assumed to accept
        the keyword call.
        """
        import inspect

        try:
            inspect.signature(func).bind(**params)
        except TypeError:
            return False
        except ValueError:
            return True
        return True

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Run the configured tool and return its stringified result.

        The tool is called with the configured ``tool_params`` as keywords. If
        those alone do not fit its signature, the incoming text is also passed
        as the first positional argument. The tool runs exactly once: a
        TypeError raised *inside* the tool is reported and does not trigger a
        second (possibly side-effecting) call. Awaitable results are awaited.
        """
        import inspect

        self._init_registry()
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = text or ""
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        args = () if self._binds(tool_func, self.tool_params) else (user_text,)
        try:
            result = tool_func(*args, **self.tool_params)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            return Msg(self.name, f"[工具执行失败: {e}]", "assistant")
        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        pass
class ConditionNodeAgent(AgentBase):
    """Flow node that asks the chat model whether a condition holds for its input.

    The reply text has the form ``condition:<true|false>|<reason>``; the flow
    engine's ``_parse_condition_result`` reads the ``condition:true`` /
    ``condition:false`` prefix to pick a branch. Each of these defaults to
    ``condition:true``: an empty condition, a model failure, or a model answer
    without a parseable JSON object.
    """

    def __init__(self, node_id: str, condition: str = ""):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response.

        A list uses its first element. Objects with ``get_text_content()`` use
        it. Objects exposing a ``content`` list of ``{"type": "text", "text": ...}``
        blocks are joined; this is presumably agentscope's ``ChatResponse``, so
        confirm. Anything else falls back to ``str()``.
        """
        item = res[0] if isinstance(res, list) else res
        if hasattr(item, 'get_text_content'):
            return item.get_text_content() or ""
        blocks = getattr(item, "content", None)
        if isinstance(blocks, list):
            texts = [
                block.get("text", "")
                for block in blocks
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            if texts:
                return "".join(texts)
        return str(item)

    @staticmethod
    def _as_bool(value) -> bool:
        """Interpret a JSON ``result`` value; the string ``"false"`` is False."""
        if isinstance(value, str):
            return value.strip().lower() == "true"
        return bool(value)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Evaluate ``self.condition`` against the text of *msg*."""
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = text or ""
        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")
        try:
            # Use the same shared model as LLMNodeAgent. The keyword arguments
            # previously passed to OpenAIChatModel (config_name, api_base) do
            # not appear to be accepted by agentscope 1.x — confirm.
            from agentscope_integration.factory import AgentFactory
            from agentscope.formatter import OpenAIChatFormatter

            model = AgentFactory._get_model()
            formatter = OpenAIChatFormatter()
            condition_prompt = (
                "你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。\n"
                f"条件表达式: {self.condition}\n"
                "输入内容:\n"
                f"{user_text[:2000]}\n"
                "请严格只输出一行 JSON:\n"
                '{"result": true/false, "reason": "简要原因"}'
            )
            prompt = await formatter.format([
                Msg("system", condition_prompt, "system"),
                Msg("user", user_text[:2000], "user"),
            ])
            res = await model(prompt)
            res_text = self._response_text(res)
            json_match = re.search(r'\{[^}]+\}', res_text)
            if json_match:
                parsed = json.loads(json_match.group())
                matched = self._as_bool(parsed.get("result", False))
                reason = parsed.get("reason", "")
                result_flag = "true" if matched else "false"
                return Msg(self.name, f"condition:{result_flag}|{reason}", "assistant")
            # Previously this path fell off the end and returned None.
            logger.warning(f"条件判断结果无法解析: {res_text[:200]}")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    async def observe(self, msg) -> None:
        pass
class MCPNodeAgent(AgentBase):
    """Flow node that forwards the incoming text to a tool on a named MCP server.

    When the MCP client cannot be used, the node returns a placeholder text
    instead of raising, and a warning is logged.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = ""):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call ``tool_name`` on ``server_name`` with ``{"input": <text>}``."""
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = text or ""
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
            client = MCPClientManager.get_http_client(self.server_name)
            if client and self.tool_name:
                result = await client.call_tool(self.tool_name, {"input": user_text})
                return Msg(self.name, str(result), "assistant")
            # Previously this case fell through silently to the placeholder below.
            logger.warning(f"MCP 服务 {self.server_name} 不可用或未指定工具名称")
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
        # NOTE(review): this placeholder reports completion even though no tool ran.
        output = f"[MCP] 服务 {self.server_name} 调用完成: 已处理输入"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        pass
class WeComNotifyAgent(AgentBase):
    """Flow node that sends a WeCom notification via ``send_notification``.

    Config keys read: ``message_template`` (used verbatim when non-empty,
    otherwise the first 500 characters of the incoming text) and ``target``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"WeComNotify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Send the notification and return the tool's result as text."""
        import inspect

        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = text or ""
        template = self.config.get("message_template", "")
        target = self.config.get("target", "")
        message = template or user_text[:500]
        try:
            from agentscope_integration.tools.wecom_tools import send_notification
            result = send_notification(to_user=target or "user", message=message)
            # Support async tool implementations and non-str results (ToolNodeAgent
            # likewise stringifies this same tool's result).
            if inspect.isawaitable(result):
                result = await result
            return Msg(self.name, str(result), "assistant")
        except ImportError:
            logger.warning("企微通知工具不可用")
        except Exception as e:
            logger.warning(f"企微通知发送失败: {e}")
        # NOTE(review): this placeholder claims the message was sent even on failure.
        result = f"[企微通知] 已向 {target or '用户'} 发送: {message[:100]}"
        return Msg(self.name, result, "assistant")

    async def observe(self, msg) -> None:
        pass
class RAGNodeAgent(AgentBase):
    """Flow node that retrieves knowledge-base context and answers with the chat model.

    Reads ``top_k`` (default 5) from its config. On any failure it returns a
    placeholder text describing the query instead of raising.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response.

        A list uses its first element. Objects with ``get_text_content()`` use
        it. Objects exposing a ``content`` list of ``{"type": "text", "text": ...}``
        blocks are joined; this is presumably agentscope's ``ChatResponse``, so
        confirm. Anything else falls back to ``str()``.
        """
        item = res[0] if isinstance(res, list) else res
        if hasattr(item, 'get_text_content'):
            return item.get_text_content() or ""
        blocks = getattr(item, "content", None)
        if isinstance(blocks, list):
            texts = [
                block.get("text", "")
                for block in blocks
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            if texts:
                return "".join(texts)
        return str(item)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Retrieve up to ``top_k`` knowledge entries for the input and answer from them."""
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = text or ""
        top_k = self.config.get("top_k", 5)
        try:
            from modules.rag.knowledge import retrieve_for_agent
            kb_result = await retrieve_for_agent(user_text, limit=top_k)
            model = self._get_model()
            formatter = self._get_formatter()
            rag_prompt = (
                "你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。\n"
                "知识库检索结果:\n"
                f"{kb_result}\n"
                f"用户问题: {user_text}\n"
                "请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"
            )
            # formatter.format is awaited directly with Msg objects, as in
            # LLMNodeAgent; running it via asyncio.to_thread only produced an
            # un-awaited coroutine.
            messages = await formatter.format([
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await model(messages)
            return Msg(self.name, self._response_text(res), "assistant")
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Return the shared chat model, the same one LLMNodeAgent uses."""
        from agentscope_integration.factory import AgentFactory
        return AgentFactory._get_model()

    def _get_formatter(self):
        """Return a fresh OpenAI chat formatter."""
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass
class OutputNodeAgent(AgentBase):
    """Final flow node; optionally pretty-prints JSON output.

    With ``config["format"] == "json"`` and a message whose text parses as
    JSON, returns that JSON re-serialized with 2-space indentation. In every
    other case the incoming message is returned as-is (non-``Msg`` input is
    wrapped).
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        if self.config.get("format", "text") == "json":
            text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
            # A message without text used to reach json.loads(None) and raise
            # an uncaught TypeError; "" simply fails to parse and falls through.
            content = text or ""
            try:
                parsed = json.loads(content)
            except (json.JSONDecodeError, ValueError):
                pass
            else:
                formatted = json.dumps(parsed, indent=2, ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        pass
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
result = template
import re
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
for placeholder in placeholders:
parts = placeholder.strip().split(".")
value = ""
if parts[0] == "trigger":
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
elif parts[0] in context.get("_node_results", {}):
node_result = context.get("_node_results", {}).get(parts[0], {})
if len(parts) > 1:
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
else:
value = str(node_result.get("output", ""))
result = result.replace("{{" + placeholder + "}}", value)
return result