You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
318 lines
12 KiB
318 lines
12 KiB
import json
|
|
import uuid
|
|
from collections import deque
|
|
from agentscope.agent import AgentBase
|
|
from agentscope.message import Msg
|
|
from agentscope.tool import Toolkit
|
|
from config import settings
|
|
|
|
|
|
class FlowEngine:
    """Run a flow definition (nodes + edges) as a sequential agent pipeline.

    Nodes are executed one at a time in topological order of the edges; the
    reply of each node becomes the input message of the next node in that
    order. A record for every executed node is stored in
    ``context["_node_results"]`` keyed by node id, which input mappings of
    later nodes can reference through ``{{node_id.output}}``-style templates.
    """

    def __init__(self, flow_definition: dict):
        """Index the nodes and edges of ``flow_definition``.

        Args:
            flow_definition: Mapping with optional ``"nodes"`` (list of node
                dicts, each carrying an ``"id"``) and ``"edges"`` (list of
                edge dicts). A missing or ``None`` entry is treated as empty.
                When two nodes share an id, the later one wins.
        """
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {
            node["id"]: node for node in (flow_definition.get("nodes") or [])
        }
        self.edges: list[dict] = flow_definition.get("edges") or []
        # One agent per node id, built lazily and reused across execute() calls.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run every node once, in topological order, and return the final message.

        Args:
            input_msg: Message fed to the first node.
            context: Mutable execution context shared by all nodes. Its
                ``"_node_results"`` entry is created/updated with one record
                per node (``status`` plus ``output`` or ``error``).

        Returns:
            The reply of the last node, a system message describing the
            failure if the last node raised, or ``input_msg`` unchanged when
            the flow has no nodes.
        """
        current_msg = input_msg

        for node_id in self._topological_sort():
            agent = await self._get_or_create_agent(node_id, context)
            node = self.nodes[node_id]

            enriched_content = self._resolve_input_mapping(node, current_msg, context)
            if enriched_content:
                current_msg = self._enrich_message(current_msg, enriched_content)

            record = {
                "node_id": node_id,
                "node_type": node.get("type"),
                "label": node.get("label"),
            }
            # Only the agent call is guarded: a failing node is recorded and a
            # failure notice is passed downstream instead of aborting the flow.
            try:
                result = await agent.reply(current_msg)
            except Exception as e:
                record.update(status="error", error=str(e))
                context.setdefault("_node_results", {})[node_id] = record
                current_msg = Msg(
                    name="system",
                    content=f"[节点 {node.get('label', node_id)} 执行失败: {e}]",
                    role="system",
                )
                continue

            # Output is truncated to keep the per-node record small.
            record.update(status="success", output=self._text_of(result)[:500])
            context.setdefault("_node_results", {})[node_id] = record
            current_msg = result

        return current_msg

    async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase:
        """Return the cached agent for ``node_id``, building it on first use.

        ``context`` is passed to the factory only when the agent is first
        built; later calls return the cached instance unchanged.
        """
        if node_id not in self._agent_cache:
            self._agent_cache[node_id] = await _create_node_agent(self.nodes[node_id], context)
        return self._agent_cache[node_id]

    def _topological_sort(self) -> list[str]:
        """Order node ids so that each edge's source precedes its target (Kahn's algorithm).

        Edges may use ``from``/``to`` or ``source``/``target`` keys; edges that
        reference unknown nodes are ignored. Nodes that never reach in-degree
        zero (i.e. those on or downstream of a cycle) are appended at the end
        in definition order.
        """
        in_degree: dict[str, int] = {nid: 0 for nid in self.nodes}
        adj: dict[str, list[str]] = {nid: [] for nid in self.nodes}

        for edge in self.edges:
            source = edge.get("from") or edge.get("source")
            target = edge.get("to") or edge.get("target")
            if source and target and source in self.nodes and target in self.nodes:
                adj[source].append(target)
                in_degree[target] += 1

        queue = deque(nid for nid, deg in in_degree.items() if deg == 0)
        order: list[str] = []
        while queue:
            nid = queue.popleft()
            order.append(nid)
            for neighbor in adj[nid]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Set membership keeps this step linear instead of scanning `order` per node.
        visited = set(order)
        order.extend(nid for nid in self.nodes if nid not in visited)
        return order

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``config["input_mapping"]`` as ``"key: value"`` lines.

        String values containing ``{{`` are expanded with ``_resolve_template``;
        other values are stringified as-is. Returns ``""`` when the node has no
        config, or no (or an empty) mapping.
        """
        config = node.get("config") or {}
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""

        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {value!s}")
        return "\n".join(lines)

    def _enrich_message(self, current_msg, enriched_content: str) -> Msg:
        """Build a user message with ``enriched_content`` prefixed to the current text."""
        if hasattr(current_msg, "get_text_content"):
            return Msg(
                name=getattr(current_msg, "name", "user"),
                content=enriched_content + "\n\n---\n" + self._text_of(current_msg),
                role="user",
            )
        # Non-message input: the mapped content replaces it entirely.
        return Msg(name="user", content=enriched_content, role="user")

    @staticmethod
    def _text_of(msg) -> str:
        """Return the plain text of ``msg``, or ``""`` when it has none."""
        text = msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)
        # Msg.get_text_content() can return None (e.g. no text blocks); guard so
        # slicing/concatenation by callers cannot raise.
        return text or ""
|
|
|
|
|
|
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Instantiate the agent implementing one flow node.

    Args:
        node: Node definition; reads ``"type"``, ``"id"`` and ``"config"``
            (a missing or ``None`` config is treated as ``{}``).
        context: Execution context. Not read here; kept so the caller's
            signature stays stable.

    Returns:
        An agent matching ``node["type"]``. ``trigger`` nodes and unknown
        types get a ``PassThroughAgent``.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    # `"config": null` in a JSON flow definition would otherwise crash every `.get` below.
    config = node.get("config") or {}

    if node_type == "llm":
        return LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
        )
    if node_type == "tool":
        return ToolNodeAgent(node_id=node_id, tool_name=config.get("tool_name", ""))
    if node_type == "mcp":
        return MCPNodeAgent(node_id=node_id, server_name=config.get("mcp_server", ""))
    if node_type == "wecom_notify":
        return WeComNotifyAgent(node_id=node_id, config=config)
    if node_type == "condition":
        return ConditionNodeAgent(node_id=node_id, condition=config.get("condition", ""))
    if node_type == "rag":
        return RAGNodeAgent(node_id=node_id, config=config)
    if node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)

    # "trigger" nodes and any unrecognised type simply forward their input.
    return PassThroughAgent(node_id)
|
|
|
|
|
|
class PassThroughAgent(AgentBase):
    """Flow node that forwards its input without doing any work.

    ``Msg`` inputs are returned as the very same object; anything else is
    stringified and wrapped in an assistant message named after this node.
    """

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return ``msg`` unchanged when it is a ``Msg``, otherwise wrap it."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class LLMNodeAgent(AgentBase):
    """Flow node that answers the incoming text with a single chat-model call.

    The call is best-effort: any failure (import, configuration, network,
    unexpected response shape) is logged and replaced by a placeholder reply,
    so the flow keeps running.
    """

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        # An empty model name falls back to the globally configured model.
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Send ``system_prompt`` plus the text of ``msg`` to the model.

        Returns:
            An assistant ``Msg`` named after this node holding the model's
            text, or a placeholder text if any step of the call failed.
        """
        import logging

        raw_text = msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)
        # get_text_content() can return None (no text blocks); normalise to "".
        user_text = raw_text or ""

        # Everything that can fail lives inside the try so the fallback below
        # covers configuration errors too, not only the request itself.
        try:
            from agentscope.formatter import OpenAIChatFormatter
            from agentscope.model import OpenAIChatModel

            # NOTE(review): keyword names follow the agentscope 1.x
            # OpenAIChatModel signature (it has no `config_name`/`api_base`);
            # the base URL is passed through the OpenAI client arguments.
            api_base = settings.LLM_API_BASE
            model = OpenAIChatModel(
                model_name=self.model_name,
                api_key=settings.LLM_API_KEY,
                stream=False,  # expect one response object rather than a stream
                client_args={"base_url": api_base} if api_base else None,
                generate_kwargs={"temperature": self.temperature},
            )
            prompt = await OpenAIChatFormatter().format([
                Msg("system", self.system_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await model(prompt)
            res_text = self._response_text(res)
        except Exception:
            logging.getLogger(__name__).warning(
                "LLM node %s failed; returning fallback output", self.name, exc_info=True
            )
            res_text = f"[LLM 调用失败,使用模拟输出] 已处理: {user_text[:200]}"

        return Msg(self.name, res_text, "assistant")

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response whose concrete type may vary."""
        if isinstance(res, list):
            res = res[0]  # an empty list raises IndexError -> caller's fallback
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        blocks = getattr(res, "content", None)
        if isinstance(blocks, (list, tuple)):
            # Response content given as typed blocks: keep only the text ones.
            return "\n".join(
                block.get("text", "")
                for block in blocks
                if isinstance(block, dict) and block.get("type") == "text"
            )
        return str(res)

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class ToolNodeAgent(AgentBase):
    """Placeholder flow node for a tool invocation.

    No tool is executed: ``reply`` ignores the incoming message and returns a
    fixed acknowledgement naming the configured tool.
    """

    def __init__(self, node_id: str, tool_name: str = ""):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a canned assistant message; ``msg`` is not read."""
        # TODO(review): invoke the tool named by `tool_name` (e.g. through a
        # Toolkit) with the message text; until then the input is unused.
        output = f"[工具 {self.tool_name}] 已处理输入,返回结果。"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class MCPNodeAgent(AgentBase):
    """Placeholder flow node for a call to an MCP server.

    No server is contacted: ``reply`` ignores the incoming message and returns
    a fixed acknowledgement naming the configured server.
    """

    def __init__(self, node_id: str, server_name: str = ""):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a canned assistant message; ``msg`` is not read."""
        # TODO(review): forward the message text to the MCP server named by
        # `server_name`; until then the input is unused.
        output = f"[MCP {self.server_name}] 调用完成,返回数据。"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class WeComNotifyAgent(AgentBase):
    """Flow node standing in for a WeCom (企业微信) notification.

    Nothing is actually sent: ``reply`` only formats a confirmation text from
    the configured ``target`` and the first 100 characters of
    ``message_template``. The incoming message is ignored.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"WeComNotify_{node_id}"
        # A missing or empty config becomes a fresh dict.
        self.config = config if config else {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return the confirmation text as an assistant message."""
        cfg = self.config
        preview = cfg.get("message_template", "通知: 任务处理完成")[:100]
        recipient = cfg.get("target", "") or "用户"
        return Msg(self.name, f"[企微通知] 已向 {recipient} 推送消息: {preview}", "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class ConditionNodeAgent(AgentBase):
    """Flow node for a ``condition`` step.

    NOTE(review): ``condition`` is stored but never evaluated; ``reply`` only
    forwards its input, so no branching happens at this node.
    """

    def __init__(self, node_id: str, condition: str = ""):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition  # raw condition string from the node config; unused so far

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Forward ``msg``; non-``Msg`` input is wrapped in an assistant message."""
        forwarded = msg
        if not isinstance(forwarded, Msg):
            forwarded = Msg(self.name, str(forwarded), "assistant")
        return forwarded

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class RAGNodeAgent(AgentBase):
    """Placeholder flow node for knowledge-base retrieval.

    No retrieval is performed: ``reply`` ignores both the incoming message and
    ``config`` and returns a fixed acknowledgement.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return a canned assistant message; ``msg`` is not read."""
        # TODO(review): query the knowledge base with the message text using
        # `self.config`; until then both are unused.
        output = "[RAG检索] 已从知识库检索相关内容。"
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
class OutputNodeAgent(AgentBase):
    """Terminal flow node that hands its input back as the flow result.

    ``config`` is stored but not read by ``reply``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config if config else {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return ``msg`` as-is, wrapping non-``Msg`` input in an assistant message."""
        if not isinstance(msg, Msg):
            return Msg(name=self.name, content=str(msg), role="assistant")
        return msg

    async def observe(self, msg) -> None:
        """Discard observed messages."""
|
|
|
|
|
|
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
|
|
result = template
|
|
import re
|
|
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
|
|
for placeholder in placeholders:
|
|
parts = placeholder.strip().split(".")
|
|
value = ""
|
|
if parts[0] == "trigger":
|
|
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
|
|
elif parts[0] in context.get("_node_results", {}):
|
|
node_result = context.get("_node_results", {}).get(parts[0], {})
|
|
if len(parts) > 1:
|
|
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
|
|
else:
|
|
value = str(node_result.get("output", ""))
|
|
result = result.replace("{{" + placeholder + "}}", value)
|
|
return result
|