You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

454 lines
17 KiB

import json
import uuid
import logging
from collections import deque
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.agent._react_agent import ReActAgent
from config import settings
# Module-level logger used by the flow engine and every node agent in this file.
logger = logging.getLogger(name=__name__)
class FlowEngine:
    """Run a flow definition (nodes + edges) as a sequential agent pipeline.

    Nodes are executed one after another in topological order of the edges.
    The message produced by each node becomes the input of the next one, and
    a per-node execution record is stored in ``context["_node_results"]``.

    NOTE(review): edges only determine the execution *order*; there is no
    branching or fan-in, so every node receives the previous node's output.
    """

    def __init__(self, flow_definition: dict):
        """Index the flow's nodes by id and keep its edge list.

        Args:
            flow_definition: dict with a ``"nodes"`` list (each node a dict
                with at least an ``"id"`` key) and an optional ``"edges"``
                list. Missing or ``None`` lists are treated as empty.
        """
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {
            node["id"]: node for node in flow_definition.get("nodes") or []
        }
        self.edges: list[dict] = flow_definition.get("edges") or []
        # Agents are built lazily on first use and reused by later executions.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run every node in topological order and return the final message.

        A node that raises does not abort the flow: the failure is logged and
        recorded, and a ``system`` message describing it is handed to the next
        node.

        Args:
            input_msg: message fed to the first node.
            context: mutable execution state; one record per executed node is
                written to ``context["_node_results"][node_id]``.

        Returns:
            The message produced by the last node (or the failure message when
            the last node raised).
        """
        current_msg = input_msg
        for node_id in self._topological_sort():
            agent = await self._get_or_create_agent(node_id, context)
            node = self.nodes[node_id]
            label = node.get("label", node_id)
            enriched_content = self._resolve_input_mapping(node, current_msg, context)
            if enriched_content.strip():
                # Prepend the mapped inputs to the text coming from the previous node.
                user_text = self._msg_text(current_msg)
                current_msg = Msg(
                    name="user",
                    content=f"{enriched_content}\n\n---\n{user_text}",
                    role="user",
                )
            try:
                result = await agent(current_msg)
            except Exception as e:
                logger.exception("节点 %s 执行失败: %s", label, e)
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                current_msg = Msg(
                    name="system",
                    content=f"[节点 {label} 执行失败: {e}]",
                    role="system",
                )
                continue
            # Built outside the try so a formatting problem can never be
            # misreported as a failure of the node itself.
            context.setdefault("_node_results", {})[node_id] = {
                "node_id": node_id,
                "node_type": node.get("type"),
                "label": node.get("label"),
                "status": "success",
                "output": self._msg_text(result)[:500],
            }
            current_msg = result
        return current_msg

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text of *msg* (``str(msg)`` for objects that are not messages).

        ``get_text_content()`` may return ``None`` (e.g. for a message without
        text blocks); that is normalised to ``""`` so callers can slice it.
        """
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase:
        """Return the cached agent for *node_id*, creating it on first use.

        NOTE(review): the agent is built with the *context* of the first
        execution that needs it and then reused for later executions.
        """
        if node_id in self._agent_cache:
            return self._agent_cache[node_id]
        agent = await _create_node_agent(self.nodes[node_id], context)
        self._agent_cache[node_id] = agent
        return agent

    def _topological_sort(self) -> list[str]:
        """Order node ids with Kahn's algorithm.

        Edges may use ``from``/``to`` or ``source``/``target`` keys. Edges that
        reference unknown nodes are ignored. Nodes that never reach in-degree
        0 (nodes on a cycle or downstream of one) are appended in definition
        order so that they still run, and a warning is logged.
        """
        in_degree: dict[str, int] = {nid: 0 for nid in self.nodes}
        adj: dict[str, list[str]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("from") or edge.get("source")
            target = edge.get("to") or edge.get("target")
            if source and target and source in self.nodes and target in self.nodes:
                adj[source].append(target)
                in_degree[target] += 1
        queue = deque(nid for nid, deg in in_degree.items() if deg == 0)
        order: list[str] = []
        while queue:
            node_id = queue.popleft()
            order.append(node_id)
            for neighbor in adj[node_id]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)
        if len(order) < len(self.nodes):
            placed = set(order)
            remaining = [nid for nid in self.nodes if nid not in placed]
            logger.warning("流程中存在环路,以下节点将按定义顺序追加执行: %s", remaining)
            order.extend(remaining)
        return order

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``config.input_mapping`` as ``"key: value"`` lines.

        String values that contain ``{{`` are expanded with
        ``_resolve_template``. Other values are stringified as they are.
        Returns ``""`` when the node has no (or an empty) mapping or its
        ``config`` is missing or ``None``.
        """
        config = node.get("config") or {}
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""
        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {str(value)}")
        return "\n".join(lines)
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Build the agent that implements a single flow node.

    ``trigger`` nodes and unknown node types become a pass-through agent.

    Args:
        node: node definition with ``type``, ``id`` and an optional
            ``config`` dict (an explicit ``None`` is treated as empty).
        context: flow execution context (not used by any constructor here).

    Returns:
        A freshly constructed agent for the node.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    # A serialized flow may carry "config": null; treat it like a missing config.
    config = node.get("config") or {}

    if node_type == "llm":
        return LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
        )
    if node_type == "tool":
        return ToolNodeAgent(
            node_id=node_id,
            tool_name=config.get("tool_name", ""),
            tool_params=config.get("tool_params", {}),
        )
    if node_type == "mcp":
        return MCPNodeAgent(
            node_id=node_id,
            server_name=config.get("mcp_server", ""),
            tool_name=config.get("tool_name", ""),
        )
    if node_type == "wecom_notify":
        return WeComNotifyAgent(node_id=node_id, config=config)
    if node_type == "condition":
        return ConditionNodeAgent(node_id=node_id, condition=config.get("condition", ""))
    if node_type == "rag":
        return RAGNodeAgent(node_id=node_id, config=config)
    if node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)
    # "trigger" and any unrecognised type simply forward their input.
    return PassThroughAgent(node_id)
class PassThroughAgent(AgentBase):
    """Agent that forwards its input unchanged (trigger and unknown node types)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return *msg* itself when it is a ``Msg``; otherwise wrap ``str(msg)``."""
        if not isinstance(msg, Msg):
            return Msg(self.name, str(msg), "assistant")
        return msg

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no memory."""
        return None
class LLMNodeAgent(AgentBase):
    """Flow node that sends the incoming text to an OpenAI-compatible chat model."""

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature

    def _build_model(self):
        """Create the chat model client for this node.

        Non-streaming mode is requested so that awaiting the model yields a
        single response object. The node's temperature is passed through
        ``generate_kwargs``, and the configured API base URL is passed to the
        underlying client as ``base_url``.
        """
        from agentscope.model import OpenAIChatModel

        api_base = settings.LLM_API_BASE
        return OpenAIChatModel(
            model_name=self.model_name,
            api_key=settings.LLM_API_KEY,
            stream=False,
            client_args={"base_url": api_base} if api_base else None,
            generate_kwargs={"temperature": self.temperature},
        )

    @staticmethod
    def _response_text(res) -> str:
        """Extract the text from a model response.

        Handles objects exposing ``get_text_content()``, objects whose
        ``content`` is a string or a list of ``{"type": "text", "text": ...}``
        blocks, and (legacy) lists of responses. Falls back to ``str(res)``.
        """
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            texts = [
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            if texts:
                return "\n".join(texts)
        return str(res)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Answer *msg* with the configured system prompt.

        Any failure (client construction, formatting, or the API call) is
        logged and turned into a fallback message that echoes the input.
        """
        from agentscope.formatter import OpenAIChatFormatter

        user_text = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
        try:
            model = self._build_model()
            prompt = await OpenAIChatFormatter().format([
                Msg("system", self.system_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await model(prompt)
            res_text = self._response_text(res)
        except Exception as e:
            logger.warning("LLM 调用失败: %s", e)
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"
        return Msg(self.name, res_text, "assistant")

    async def observe(self, msg) -> None:
        pass
class ToolNodeAgent(AgentBase):
    """Flow node that invokes one function from a fixed local tool registry."""

    # Name -> tool function, populated lazily by _init_registry() and shared
    # by all instances.
    _TOOL_REGISTRY: dict[str, callable] = {}

    @classmethod
    def _init_registry(cls):
        """Import the project tool functions once and index them by name.

        If any import fails, the registry stays empty, so the imports are
        retried (and the warning logged again) on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard
            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}

    @staticmethod
    def _positional_args(tool_func, user_text: str, tool_params: dict) -> tuple:
        """Choose the positional arguments for calling *tool_func*.

        Prefer calling with only the configured keyword params. If the
        signature cannot accept that, pass the incoming text as the first
        positional argument. The choice is made by binding the signature
        *before* the call, so a TypeError raised inside the tool never causes
        a second invocation (and duplicated side effects).
        """
        import inspect

        try:
            signature = inspect.signature(tool_func)
        except (TypeError, ValueError):
            # Not introspectable (e.g. some builtins): call with params only.
            return ()
        for candidate in ((), (user_text,)):
            try:
                signature.bind(*candidate, **tool_params)
            except TypeError:
                continue
            return candidate
        # Neither form fits; the call will raise and be reported as a failure.
        return (user_text,)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Run the configured tool once and return its stringified result.

        Async tools are awaited. Failures are logged and returned as a
        ``[工具执行失败: ...]`` message rather than raised.
        """
        import inspect

        self._init_registry()
        user_text = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")
        args = self._positional_args(tool_func, user_text, self.tool_params)
        try:
            result = tool_func(*args, **self.tool_params)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            logger.warning("工具 %s 执行失败: %s", self.tool_name, e)
            return Msg(self.name, f"[工具执行失败: {e}]", "assistant")
        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        pass
class ConditionNodeAgent(AgentBase):
    """Placeholder condition node.

    NOTE(review): the condition expression is never evaluated. When a
    condition is configured, the node always reports it as satisfied. With no
    condition, the input message is passed through unchanged.
    """

    def __init__(self, node_id: str, condition: str = ""):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        if not self.condition:
            return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")
        # get_text_content() may yield None for messages without text; keep slicing safe.
        user_text = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
        result_text = f"[条件判断: {self.condition[:80]}]\n输入: {user_text[:300]}\n结果: 条件满足,继续执行。"
        return Msg(self.name, result_text, "assistant")

    async def observe(self, msg) -> None:
        pass
class MCPNodeAgent(AgentBase):
    """Flow node that calls one tool on a named MCP server.

    The incoming text is sent as ``{"input": <text>}``. Missing
    configuration, an unavailable client, and call failures are reported in
    the returned message (never as a fake success) instead of being raised.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = ""):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        if not self.tool_name:
            return Msg(self.name, f"[MCP] 服务 {self.server_name} 未指定工具名称", "assistant")
        try:
            # NOTE(review): import path and client API of MCPClientManager are
            # not verifiable from this file — confirm against agentscope_runtime.
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return Msg(self.name, f"[MCP] 服务 {self.server_name} 调用失败: MCP 客户端不可用", "assistant")
        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return Msg(self.name, f"[MCP] 服务 {self.server_name} 不可用", "assistant")
            result = await client.call_tool(self.tool_name, {"input": user_text})
        except Exception as e:
            logger.warning("MCP 调用失败: %s", e)
            return Msg(self.name, f"[MCP] 服务 {self.server_name} 调用失败: {e}", "assistant")
        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        pass
class WeComNotifyAgent(AgentBase):
    """Flow node that sends a WeCom notification.

    The message is ``config["message_template"]`` when it is set, otherwise
    the first 500 characters of the incoming text. The recipient is
    ``config["target"]`` (default ``"user"``). When the notification tool
    is unavailable or sending fails, a failure message is returned instead of
    claiming the message was sent.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"WeComNotify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        import inspect

        user_text = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
        template = self.config.get("message_template", "")
        target = self.config.get("target", "")
        message = template or user_text[:500]
        try:
            from agentscope_integration.tools.wecom_tools import send_notification
            result = send_notification(to_user=target or "user", message=message)
            if inspect.isawaitable(result):
                result = await result
        except ImportError:
            logger.warning("企微通知工具不可用")
            error = "企微通知工具不可用"
        except Exception as e:
            logger.warning("企微通知发送失败: %s", e)
            error = str(e)
        else:
            return Msg(self.name, str(result), "assistant")
        return Msg(self.name, f"[企微通知] 向 {target or '用户'} 发送失败: {error}", "assistant")

    async def observe(self, msg) -> None:
        pass
class RAGNodeAgent(AgentBase):
    """Flow node that answers the incoming question from the knowledge base.

    It retrieves ``top_k`` results (config, default 5) through
    ``modules.rag.knowledge.retrieve_for_agent`` and asks the chat model to
    answer based on them. On any failure, a plain-text summary of the query
    is returned instead.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
        top_k = self.config.get("top_k", 5)
        try:
            from modules.rag.knowledge import retrieve_for_agent

            kb_result = await retrieve_for_agent(user_text, limit=top_k)
            rag_prompt = (
                "你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。\n"
                "知识库检索结果:\n"
                f"{kb_result}\n"
                f"用户问题: {user_text}\n"
                "请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"
            )
            # The formatter's format() is a coroutine that takes Msg objects;
            # running it via to_thread with raw dicts never produced a prompt.
            messages = await self._get_formatter().format([
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await self._get_model()(messages)
            return Msg(self.name, self._response_text(res), "assistant")
        except Exception as e:
            logger.warning("RAG 节点执行失败: %s", e)
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Create a non-streaming chat model client using the global LLM settings."""
        from agentscope.model import OpenAIChatModel

        api_base = settings.LLM_API_BASE
        return OpenAIChatModel(
            model_name=settings.LLM_MODEL,
            api_key=settings.LLM_API_KEY,
            stream=False,
            client_args={"base_url": api_base} if api_base else None,
        )

    def _get_formatter(self):
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    @staticmethod
    def _response_text(res) -> str:
        """Extract the text from a model response.

        Handles objects exposing ``get_text_content()``, objects whose
        ``content`` is a string or a list of ``{"type": "text", "text": ...}``
        blocks, and (legacy) lists of responses. Falls back to ``str(res)``.
        """
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            texts = [
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            if texts:
                return "\n".join(texts)
        return str(res)

    async def observe(self, msg) -> None:
        pass
class OutputNodeAgent(AgentBase):
    """Final flow node that optionally pretty-prints JSON output.

    With ``config["format"] == "json"`` and an input that parses as JSON,
    the JSON is re-emitted with two-space indentation. In every other case
    the input message is passed through unchanged.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        if self.config.get("format", "text") == "json":
            # get_text_content() may return None; json.loads(None) would raise
            # TypeError, which the ValueError handler below does not catch.
            content = (msg.get_text_content() if hasattr(msg, "get_text_content") else str(msg)) or ""
            try:
                parsed = json.loads(content)
            except ValueError:  # includes json.JSONDecodeError
                pass
            else:
                return Msg(self.name, json.dumps(parsed, indent=2, ensure_ascii=False), "assistant")
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        pass
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
result = template
import re
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
for placeholder in placeholders:
parts = placeholder.strip().split(".")
value = ""
if parts[0] == "trigger":
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
elif parts[0] in context.get("_node_results", {}):
node_result = context.get("_node_results", {}).get(parts[0], {})
if len(parts) > 1:
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
else:
value = str(node_result.get("output", ""))
result = result.replace("{{" + placeholder + "}}", value)
return result