You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

318 lines
12 KiB

import json
import uuid
from collections import deque
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from config import settings
class FlowEngine:
    """Run a flow definition (nodes + edges) as a sequential agent pipeline.

    Nodes are executed one after another in topological order of the edges.
    The message produced by each node becomes the input of the next node, and
    the message produced by the last node is the result of :meth:`execute`.
    """

    def __init__(self, flow_definition: dict):
        """Index the flow's nodes by id and keep its edge list.

        Args:
            flow_definition: Mapping with a ``"nodes"`` list (each node a dict
                with at least an ``"id"`` key) and an ``"edges"`` list. Missing
                or ``None`` lists are treated as empty.
        """
        self.definition = flow_definition
        # Node id -> node dict; a later node with a duplicate id replaces an earlier one.
        self.nodes: dict[str, dict] = {
            node["id"]: node for node in (flow_definition.get("nodes") or [])
        }
        self.edges: list[dict] = flow_definition.get("edges") or []
        # One agent per node id, created lazily on first use.
        self._agent_cache: dict[str, AgentBase] = {}

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text carried by *msg* as a ``str``.

        ``Msg.get_text_content()`` may return ``None`` (no text blocks); that
        is normalised to ``""`` so callers can slice/concatenate safely.
        Objects without ``get_text_content`` are converted with ``str()``.
        """
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run every node in execution order and return the final message.

        The outcome of each node is stored in
        ``context["_node_results"][node_id]`` (created on demand). If a node's
        ``reply`` raises, the flow is not aborted: an error record is stored
        and a system message describing the failure is passed to the next node.
        Exceptions raised while creating an agent are not caught.

        Args:
            input_msg: Message fed to the first node.
            context: Mutable run context; read by input-mapping templates and
                updated with per-node results.

        Returns:
            The message produced by the last executed node (or ``input_msg``
            when the flow has no nodes).
        """
        current_msg = input_msg
        for node_id in self._topological_sort():
            agent = await self._get_or_create_agent(node_id, context)
            node = self.nodes[node_id]

            enriched_content = self._resolve_input_mapping(node, current_msg, context)
            if enriched_content:
                if hasattr(current_msg, "get_text_content"):
                    # Prepend the resolved mapping to the incoming text.
                    current_msg = Msg(
                        name=getattr(current_msg, "name", "user"),
                        content=enriched_content + "\n\n---\n" + self._msg_text(current_msg),
                        role="user",
                    )
                else:
                    current_msg = Msg(name="user", content=enriched_content, role="user")

            # Keep the try minimal: only the agent call may fail "softly".
            try:
                result = await agent.reply(current_msg)
            except Exception as e:
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                current_msg = Msg(
                    name="system",
                    content=f"[节点 {node.get('label', node_id)} 执行失败: {e}]",
                    role="system",
                )
            else:
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "success",
                    # Truncated preview; a reply with no text blocks records "".
                    "output": self._msg_text(result)[:500],
                }
                current_msg = result
        return current_msg

    async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase:
        """Return the cached agent for *node_id*, creating it on first use.

        ``context`` is only forwarded to the factory on the first call for a
        given node; later calls reuse the cached agent.
        """
        if node_id not in self._agent_cache:
            self._agent_cache[node_id] = await _create_node_agent(self.nodes[node_id], context)
        return self._agent_cache[node_id]

    def _topological_sort(self) -> list[str]:
        """Return node ids in an order that respects the edges (Kahn's algorithm).

        An edge may name its endpoints ``from``/``to`` or ``source``/``target``.
        Edges with a missing endpoint or one that is not a known node are ignored.
        Nodes that never reach in-degree 0 (nodes on or downstream of a cycle)
        are appended at the end in definition order instead of being dropped.
        """
        in_degree: dict[str, int] = {nid: 0 for nid in self.nodes}
        adj: dict[str, list[str]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("from") or edge.get("source")
            target = edge.get("to") or edge.get("target")
            if source and target and source in self.nodes and target in self.nodes:
                adj[source].append(target)
                in_degree[target] += 1

        queue = deque(nid for nid, deg in in_degree.items() if deg == 0)
        order: list[str] = []
        while queue:
            nid = queue.popleft()
            order.append(nid)
            for neighbor in adj[nid]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Set lookup keeps the leftover pass linear instead of O(n^2).
        placed = set(order)
        order.extend(nid for nid in self.nodes if nid not in placed)
        return order

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``config.input_mapping`` as ``"key: value"`` lines.

        String values containing ``{{`` are expanded with ``_resolve_template``;
        every other value is used as-is. Returns ``""`` when the node has no
        (or an empty) mapping, or its ``config`` is missing/``None``.
        """
        node_config = node.get("config") or {}
        input_mapping = node_config.get("input_mapping")
        if not input_mapping:
            return ""
        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {value!s}")
        return "\n".join(lines)
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Build the agent implementing *node* according to its ``"type"``.

    Known types: ``trigger``, ``llm``, ``tool``, ``mcp``, ``wecom_notify``,
    ``condition``, ``rag`` and ``output``. Any other or missing type falls
    back to a ``PassThroughAgent``.

    Args:
        node: Node dict; ``"id"``, ``"type"`` and ``"config"`` are read.
            A missing or ``None`` config is treated as ``{}``.
        context: Run context. Not used by this function at present.

    Returns:
        A new agent instance for the node.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    # Nodes serialised from a UI may carry "config": null.
    config = node.get("config") or {}

    # Lazy factories so only the selected agent (and its config reads) is built.
    factories = {
        "trigger": lambda: PassThroughAgent(node_id),
        "llm": lambda: LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
        ),
        "tool": lambda: ToolNodeAgent(node_id=node_id, tool_name=config.get("tool_name", "")),
        "mcp": lambda: MCPNodeAgent(node_id=node_id, server_name=config.get("mcp_server", "")),
        "wecom_notify": lambda: WeComNotifyAgent(node_id=node_id, config=config),
        "condition": lambda: ConditionNodeAgent(node_id=node_id, condition=config.get("condition", "")),
        "rag": lambda: RAGNodeAgent(node_id=node_id, config=config),
        "output": lambda: OutputNodeAgent(node_id=node_id, config=config),
    }
    factory = factories.get(node_type)
    if factory is None:
        return PassThroughAgent(node_id)
    return factory()
class PassThroughAgent(AgentBase):
    """Flow node that forwards its input unchanged.

    A ``Msg`` input is returned as the very same object; anything else is
    wrapped in an ``assistant`` message whose content is ``str(msg)``.
    """

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return *msg* itself when it is a ``Msg``, otherwise wrap it."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class LLMNodeAgent(AgentBase):
    """Flow node that sends the incoming text to an OpenAI-compatible chat model.

    The node's system prompt and the incoming message text are formatted as a
    two-message chat. The model's text answer is returned as an ``assistant``
    message. If the model call itself fails, a placeholder text that echoes
    the input is returned instead, and the failure is logged.
    """

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        # An empty model name falls back to the globally configured model.
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response.

        Handles a list of responses (first element), message-like objects with
        ``get_text_content``, and responses whose ``content`` is either a string
        or a list of ``{"type": "text", "text": ...}`` blocks (agentscope
        ``ChatResponse``). Anything else is converted with ``str()``.
        """
        if isinstance(res, list):
            # An empty list raises IndexError, which the caller's try turns into the fallback.
            res = res[0]
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "\n".join(
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        return str(res)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Ask the model about *msg*'s text and return its answer as a ``Msg``."""
        import logging

        from agentscope.model import OpenAIChatModel
        from agentscope.formatter import OpenAIChatFormatter

        # OpenAIChatModel (agentscope 1.x) takes the endpoint via client_args and
        # sampling parameters via generate_kwargs; stream=False makes the awaited
        # call return a single ChatResponse instead of an async generator.
        model = OpenAIChatModel(
            model_name=self.model_name,
            api_key=settings.LLM_API_KEY,
            stream=False,
            client_args={"base_url": settings.LLM_API_BASE} if settings.LLM_API_BASE else None,
            generate_kwargs={"temperature": self.temperature},
        )
        if hasattr(msg, "get_text_content"):
            # get_text_content() returns None when the message has no text blocks.
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)
        formatter = OpenAIChatFormatter()
        prompt = await formatter.format([
            Msg("system", self.system_prompt, "system"),
            Msg("user", user_text, "user"),
        ])
        try:
            res_text = self._response_text(await model(prompt))
        except Exception:
            # Deliberate best-effort fallback, but no longer silent.
            logging.getLogger(__name__).exception(
                "LLM node %s: model call failed, using fallback output", self.name
            )
            res_text = f"[LLM 调用失败,使用模拟输出] 已处理: {user_text[:200]}"
        return Msg(self.name, res_text, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class ToolNodeAgent(AgentBase):
    """Placeholder flow node for a named tool.

    No tool is executed: ``reply`` ignores the incoming message and returns a
    fixed acknowledgement naming the configured tool.
    """

    def __init__(self, node_id: str, tool_name: str = ""):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a canned acknowledgement; *msg* is not used."""
        # NOTE(review): stub — the configured tool is never actually invoked.
        output = f"[工具 {self.tool_name}] 已处理输入,返回结果。"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class MCPNodeAgent(AgentBase):
    """Placeholder flow node for an MCP server call.

    No server is contacted: ``reply`` ignores the incoming message and returns
    a fixed completion notice naming the configured server.
    """

    def __init__(self, node_id: str, server_name: str = ""):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a canned completion notice; *msg* is not used."""
        # NOTE(review): stub — no MCP request is made.
        output = f"[MCP {self.server_name}] 调用完成,返回数据。"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class WeComNotifyAgent(AgentBase):
    """Placeholder flow node that reports a WeCom notification.

    Nothing is sent. ``reply`` only describes the push it would make, using
    the ``target`` and ``message_template`` keys of the node config. The
    template preview is limited to 100 characters.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"WeComNotify_{node_id}"
        self.config = {} if not config else config

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a description of the simulated notification; *msg* is unused."""
        settings_map = self.config
        body = settings_map.get("message_template", "通知: 任务处理完成")
        recipient = settings_map.get("target", "")
        recipient_label = recipient if recipient else '用户'
        summary = f"[企微通知] 已向 {recipient_label} 推送消息: {body[:100]}"
        return Msg(self.name, summary, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class ConditionNodeAgent(AgentBase):
    """Flow node intended for conditional branching.

    The ``condition`` string is stored but never evaluated: ``reply`` behaves
    exactly like a pass-through. A ``Msg`` is returned unchanged; anything else
    is wrapped in an ``assistant`` message.
    """

    def __init__(self, node_id: str, condition: str = ""):
        super().__init__()
        self.name = f"Condition_{node_id}"
        # NOTE(review): not consulted anywhere in this class yet.
        self.condition = condition

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Forward *msg* (wrapping non-``Msg`` inputs)."""
        if not isinstance(msg, Msg):
            return Msg(self.name, str(msg), "assistant")
        return msg

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class RAGNodeAgent(AgentBase):
    """Placeholder flow node for knowledge-base retrieval.

    No retrieval happens: ``reply`` ignores the incoming message and returns a
    fixed notice. ``config`` is stored but not read by this class.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a canned retrieval notice; *msg* is not used."""
        # NOTE(review): stub — no knowledge base is queried.
        output = "[RAG检索] 已从知识库检索相关内容。"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
class OutputNodeAgent(AgentBase):
    """Terminal flow node that hands its input through as the flow result.

    A ``Msg`` is returned as-is; other values are wrapped in an ``assistant``
    message. ``config`` is kept on the instance but not read here.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = {} if not config else config

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Forward *msg* (wrapping non-``Msg`` inputs)."""
        forwarded = msg
        if not isinstance(forwarded, Msg):
            forwarded = Msg(self.name, str(msg), "assistant")
        return forwarded

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
result = template
import re
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
for placeholder in placeholders:
parts = placeholder.strip().split(".")
value = ""
if parts[0] == "trigger":
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
elif parts[0] in context.get("_node_results", {}):
node_result = context.get("_node_results", {}).get(parts[0], {})
if len(parts) > 1:
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
else:
value = str(node_result.get("output", ""))
result = result.replace("{{" + placeholder + "}}", value)
return result