You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
454 lines
17 KiB
454 lines
17 KiB
import json
|
|
import uuid
|
|
import logging
|
|
from collections import deque
|
|
from agentscope.agent import AgentBase
|
|
from agentscope.message import Msg
|
|
from agentscope.tool import Toolkit
|
|
from agentscope.agent._react_agent import ReActAgent
|
|
from config import settings
|
|
|
|
# Module-level logger shared by the flow engine and the node agents in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FlowEngine:
    """Run a flow definition (nodes + edges) as a sequential agent pipeline.

    Nodes execute one at a time in topological order of the edge graph; the
    message produced by each node is the input of the next one. A record of
    every executed node is stored in ``context["_node_results"][node_id]``.
    """

    def __init__(self, flow_definition: dict):
        """Index the flow's nodes by id and keep its edge list.

        Args:
            flow_definition: Dict with optional ``"nodes"`` (each node must
                carry an ``"id"``) and ``"edges"`` lists. Missing or null
                lists are treated as empty.
        """
        self.definition = flow_definition
        # A later node with a duplicate id overwrites an earlier one.
        self.nodes: dict[str, dict] = {
            node["id"]: node for node in flow_definition.get("nodes") or []
        }
        self.edges: list[dict] = flow_definition.get("edges") or []
        # Agents are created lazily per node and reused across execute() calls.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run every node in topological order and return the final message.

        Args:
            input_msg: Message fed to the first node.
            context: Mutable run context. It is read by input-mapping
                templates and receives one execution record per node under
                ``context["_node_results"][node_id]``.

        Returns:
            The message produced by the last executed node. A node that raises
            does not abort the run; its error is recorded and forwarded to the
            next node as a system message.

        Raises:
            Whatever agent creation or input-mapping resolution raises; only
            the agent call itself is guarded.
        """
        current_msg = input_msg

        for node_id in self._topological_sort():
            agent = await self._get_or_create_agent(node_id, context)
            node = self.nodes[node_id]
            label = node.get("label", node_id)

            enriched_content = self._resolve_input_mapping(node, current_msg, context)
            if enriched_content.strip():
                # Prefix the upstream text with this node's resolved inputs.
                user_text = self._msg_text(current_msg)
                current_msg = Msg(
                    name="user",
                    content=f"{enriched_content}\n\n---\n{user_text}",
                    role="user",
                )

            try:
                result = await agent(current_msg)
            except Exception as e:
                logger.error(f"节点 {label} 执行失败: {e}", exc_info=True)
                exec_record = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                current_msg = Msg(
                    name="system",
                    content=f"[节点 {label} 执行失败: {e}]",
                    role="system",
                )
            else:
                # Built outside the try so a text-less result is not
                # misreported as a node failure.
                exec_record = {
                    "node_id": node_id,
                    "node_type": node.get("type"),
                    "label": node.get("label"),
                    "status": "success",
                    "output": self._msg_text(result)[:500],
                }
                current_msg = result
            context.setdefault("_node_results", {})[node_id] = exec_record

        return current_msg

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text of *msg*; ``""`` when it carries no text content.

        Objects without ``get_text_content`` are converted with ``str()``.
        """
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    async def _get_or_create_agent(self, node_id: str, context: dict) -> AgentBase:
        """Return the cached agent for *node_id*, creating it on first use."""
        if node_id in self._agent_cache:
            return self._agent_cache[node_id]

        agent = await _create_node_agent(self.nodes[node_id], context)
        self._agent_cache[node_id] = agent
        return agent

    def _topological_sort(self) -> list[str]:
        """Return node ids in Kahn topological order of the edge graph.

        Edges may use ``from``/``to`` or ``source``/``target`` keys; edges
        whose endpoints are missing or unknown are ignored. Nodes that cannot
        be ordered (on or downstream of a cycle) are appended afterwards in
        definition order, so every node still appears exactly once.
        """
        in_degree = dict.fromkeys(self.nodes, 0)
        adj: dict[str, list[str]] = {nid: [] for nid in self.nodes}

        for edge in self.edges:
            source = edge.get("from") or edge.get("source")
            target = edge.get("to") or edge.get("target")
            if source and target and source in adj and target in adj:
                adj[source].append(target)
                in_degree[target] += 1

        queue = deque(nid for nid, deg in in_degree.items() if deg == 0)
        order: list[str] = []
        while queue:
            node_id = queue.popleft()
            order.append(node_id)
            for neighbor in adj[node_id]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(order) < len(self.nodes):
            # Set lookup keeps this linear; a cycle is worth surfacing.
            placed = set(order)
            remaining = [nid for nid in self.nodes if nid not in placed]
            logger.warning("流程图存在环, 以下节点按定义顺序追加执行: %s", remaining)
            order.extend(remaining)
        return order

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``config["input_mapping"]`` as ``key: value`` lines.

        String values containing ``{{`` are expanded with
        :func:`_resolve_template`; other values are used as-is.

        Returns:
            Newline-joined ``key: value`` lines, or ``""`` when the node has no
            input mapping.
        """
        input_mapping = (node.get("config") or {}).get("input_mapping")
        if not input_mapping:
            return ""

        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {str(value)}")
        return "\n".join(lines)
|
|
|
|
|
|
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Build the agent that implements one flow node.

    Args:
        node: Node definition with ``"type"``, ``"id"`` and an optional
            ``"config"`` dict (a null config is treated as empty).
        context: Run context; accepted for interface compatibility, not read.

    Returns:
        An agent for the node's type. ``"trigger"`` nodes and unknown types
        get a :class:`PassThroughAgent` that forwards its input unchanged.
    """
    node_type = node.get("type", "")
    node_id = node.get("id", "")
    # "config" may be present but null in JSON flow definitions.
    config = node.get("config") or {}

    if node_type == "trigger":
        return PassThroughAgent(node_id)

    if node_type == "llm":
        return LLMNodeAgent(
            node_id=node_id,
            system_prompt=config.get("system_prompt", "你是AI助手。"),
            model_name=config.get("model", settings.LLM_MODEL),
            temperature=config.get("temperature", 0.7),
        )

    if node_type == "tool":
        return ToolNodeAgent(
            node_id=node_id,
            tool_name=config.get("tool_name", ""),
            tool_params=config.get("tool_params", {}),
        )

    if node_type == "mcp":
        return MCPNodeAgent(
            node_id=node_id,
            server_name=config.get("mcp_server", ""),
            tool_name=config.get("tool_name", ""),
        )

    if node_type == "wecom_notify":
        return WeComNotifyAgent(node_id=node_id, config=config)

    if node_type == "condition":
        return ConditionNodeAgent(node_id=node_id, condition=config.get("condition", ""))

    if node_type == "rag":
        return RAGNodeAgent(node_id=node_id, config=config)

    if node_type == "output":
        return OutputNodeAgent(node_id=node_id, config=config)

    # Unknown types degrade to a pass-through, but say so.
    logger.warning("未知节点类型 %r (节点 %s), 按直通节点处理", node_type, node_id)
    return PassThroughAgent(node_id)
|
|
|
|
|
|
class PassThroughAgent(AgentBase):
    """No-op flow node: hands its input message straight back."""

    def __init__(self, node_id: str):
        """Create the agent, naming it after the node id."""
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return *msg* unchanged, wrapping non-Msg input in an assistant Msg."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages; this node keeps no memory."""
        return None
|
|
|
|
|
|
class LLMNodeAgent(AgentBase):
    """Flow node that sends its input text to an OpenAI-compatible chat model."""

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7):
        """Configure the node.

        Args:
            node_id: Flow node id, used to build the agent name.
            system_prompt: System message placed before the user text.
            model_name: Model identifier; ``settings.LLM_MODEL`` when empty.
            temperature: Sampling temperature forwarded to the model.
        """
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Ask the model about the incoming text and return its answer.

        Any failure (model construction, formatting or the API call) is logged
        and replaced by a fallback message echoing the start of the input.
        """
        from agentscope.formatter import OpenAIChatFormatter
        from agentscope.model import OpenAIChatModel

        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)

        try:
            model = OpenAIChatModel(
                model_name=self.model_name,
                api_key=settings.LLM_API_KEY,
                # One complete response is expected, not an async stream.
                stream=False,
                client_args={"base_url": settings.LLM_API_BASE},
                generate_kwargs={"temperature": self.temperature},
            )
            prompt = await OpenAIChatFormatter().format([
                Msg("system", self.system_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await model(prompt)
            res_text = self._response_text(res)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"

        return Msg(self.name, res_text, "assistant")

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response.

        Accepts objects exposing ``get_text_content()``, objects whose
        ``content`` is a string or a list of ``{"type": "text", "text": ...}``
        blocks (as in an agentscope ``ChatResponse``), or a list of either
        (the first element is used). Anything else is ``str()``-ed.
        """
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, (list, tuple)):
            return "".join(
                str(block.get("text", ""))
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        if isinstance(content, str):
            return content
        return str(res)

    async def observe(self, msg) -> None:
        """Ignore observed messages; the node is stateless between calls."""
        pass
|
|
|
|
|
|
class ToolNodeAgent(AgentBase):
    """Flow node that calls one function from a lazily built tool registry."""

    # Tool name -> callable, shared by all instances; filled on first reply().
    _TOOL_REGISTRY: dict[str, callable] = {}

    @classmethod
    def _init_registry(cls):
        """Populate ``_TOOL_REGISTRY`` from the integration tool modules.

        Does nothing once the registry is non-empty. On ``ImportError`` a
        warning is logged, the registry stays empty and the import is retried
        on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard

            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None):
        """Configure the node.

        Args:
            node_id: Flow node id, used to build the agent name.
            tool_name: Registry key of the tool to call.
            tool_params: Keyword arguments passed to the tool (default: none).
        """
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call the configured tool once and return its ``str()``-ed result.

        The tool receives ``**tool_params``; if its signature requires it, the
        incoming text is also passed as the first positional argument. Async
        tools are awaited. Errors are returned as a failure message, not raised.
        """
        import inspect

        self._init_registry()
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if tool_func is None:
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)

        args = self._positional_args(tool_func, user_text)
        try:
            result = tool_func(*args, **self.tool_params)
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            return Msg(self.name, f"[工具执行失败: {e}]", "assistant")
        return Msg(self.name, str(result), "assistant")

    def _positional_args(self, tool_func, user_text: str) -> tuple:
        """Decide, before calling, whether the input text is passed positionally.

        Returns ``()`` when the tool's signature accepts ``**tool_params``
        alone, ``(user_text,)`` when it additionally needs one leading
        positional argument, and ``()`` otherwise (the call then reports the
        binding error). Deciding up front instead of retrying after a
        ``TypeError`` ensures a tool is never executed twice, even when the
        tool itself raises ``TypeError``.
        """
        import inspect

        try:
            signature = inspect.signature(tool_func)
        except (TypeError, ValueError):
            return ()
        try:
            signature.bind(**self.tool_params)
            return ()
        except TypeError:
            pass
        try:
            signature.bind(user_text, **self.tool_params)
            return (user_text,)
        except TypeError:
            return ()

    async def observe(self, msg) -> None:
        """Ignore observed messages; the node is stateless between calls."""
        pass
|
|
|
|
|
|
class ConditionNodeAgent(AgentBase):
    """Placeholder condition node.

    NOTE(review): the configured condition is only echoed, never evaluated;
    every non-empty condition is reported as satisfied and the flow continues.
    """

    def __init__(self, node_id: str, condition: str = ""):
        """Store the node name and its (unevaluated) condition expression."""
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Pass the input through when no condition is set, else report it met.

        Returns:
            The input itself (wrapped in a Msg if needed) when the condition is
            empty; otherwise a message quoting the condition (first 80 chars)
            and the input text (first 300 chars).
        """
        if not self.condition:
            return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

        # get_text_content() yields None for messages without text blocks.
        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)

        result_text = f"[条件判断: {self.condition[:80]}]\n输入: {user_text[:300]}\n结果: 条件满足,继续执行。"
        return Msg(self.name, result_text, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; the node is stateless between calls."""
        pass
|
|
|
|
|
|
class MCPNodeAgent(AgentBase):
    """Flow node that forwards its input text to a tool on a named MCP server.

    NOTE(review): when no result comes back (no client, no tool name, import
    failure or call error) the node still answers with the generic completion
    text, so downstream nodes cannot tell success from failure.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = ""):
        """Store the node name, target MCP server and tool name."""
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call ``tool_name`` on ``server_name`` with ``{"input": <text>}``.

        Returns the ``str()`` of the tool result when a call is made and
        succeeds; otherwise a fixed status message. Import and call errors are
        logged as warnings, never raised.
        """
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)

        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")

        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager

            http_client = MCPClientManager.get_http_client(self.server_name)
            can_call = http_client and self.tool_name
            if can_call:
                tool_output = await http_client.call_tool(self.tool_name, {"input": text})
                return Msg(self.name, str(tool_output), "assistant")
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
        except Exception as err:
            logger.warning(f"MCP 调用失败: {err}")

        status = f"[MCP] 服务 {self.server_name} 调用完成: 已处理输入"
        return Msg(self.name, status, "assistant")

    async def observe(self, msg) -> None:
        """Discard observed messages; this node keeps no memory."""
        return None
|
|
|
|
|
|
class WeComNotifyAgent(AgentBase):
    """Flow node that sends a WeCom notification via ``send_notification``.

    NOTE(review): if the tool is unavailable or raises, the node still returns
    a "sent" confirmation text after logging a warning; confirm that this
    optimistic fallback is intended.
    """

    def __init__(self, node_id: str, config: dict = None):
        """Store the node name and its config (``message_template``, ``target``)."""
        super().__init__()
        self.name = f"WeComNotify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Send the configured template (or the input text) to ``target``.

        Returns the tool's result on success; otherwise a fallback
        confirmation message quoting the first 100 chars of the message.
        """
        import inspect

        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)
        template = self.config.get("message_template", "")
        target = self.config.get("target", "")
        # A configured template is sent verbatim; otherwise the upstream text.
        message = template or user_text[:500]

        try:
            from agentscope_integration.tools.wecom_tools import send_notification
        except ImportError as e:
            logger.warning(f"企微通知工具不可用: {e}")
        else:
            try:
                result = send_notification(to_user=target or "user", message=message)
                if inspect.isawaitable(result):
                    result = await result
                return Msg(self.name, result, "assistant")
            except Exception as e:
                logger.warning(f"企微通知发送失败: {e}")

        result = f"[企微通知] 已向 {target or '用户'} 发送: {message[:100]}"
        return Msg(self.name, result, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; the node is stateless between calls."""
        pass
|
|
|
|
|
|
class RAGNodeAgent(AgentBase):
    """Flow node that answers the input using knowledge-base retrieval + an LLM."""

    def __init__(self, node_id: str, config: dict = None):
        """Store the node name and its config (``top_k``, default 5)."""
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Retrieve ``top_k`` knowledge snippets and let the model answer.

        Any failure (retrieval, formatting, model call) is logged and replaced
        by a fallback message describing the query and ``top_k``.
        """
        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)
        top_k = self.config.get("top_k", 5)

        try:
            from modules.rag.knowledge import retrieve_for_agent

            kb_result = await retrieve_for_agent(user_text, limit=top_k)

            rag_prompt = f"""你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。

知识库检索结果:
{kb_result}

用户问题: {user_text}

请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"""

            # The formatter's format() is a coroutine that takes Msg objects;
            # it must be awaited directly, not pushed to a worker thread.
            messages = await self._get_formatter().format([
                Msg("system", rag_prompt, "system"),
                Msg("user", user_text, "user"),
            ])
            res = await self._get_model()(messages)
            return Msg(self.name, self._response_text(res), "assistant")
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")

        output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
        return Msg(self.name, output, "assistant")

    def _get_model(self):
        """Build a non-streaming chat model from the global LLM settings."""
        from agentscope.model import OpenAIChatModel

        return OpenAIChatModel(
            model_name=settings.LLM_MODEL,
            api_key=settings.LLM_API_KEY,
            # One complete response is expected, not an async stream.
            stream=False,
            client_args={"base_url": settings.LLM_API_BASE},
        )

    def _get_formatter(self):
        """Return a fresh OpenAI chat-message formatter."""
        from agentscope.formatter import OpenAIChatFormatter

        return OpenAIChatFormatter()

    @staticmethod
    def _response_text(res) -> str:
        """Extract plain text from a model response.

        Accepts objects exposing ``get_text_content()``, objects whose
        ``content`` is a string or a list of ``{"type": "text", "text": ...}``
        blocks (as in an agentscope ``ChatResponse``), or a list of either
        (the first element is used). Anything else is ``str()``-ed.
        """
        if isinstance(res, list):
            res = res[0] if res else ""
        if hasattr(res, "get_text_content"):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, (list, tuple)):
            return "".join(
                str(block.get("text", ""))
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        if isinstance(content, str):
            return content
        return str(res)

    async def observe(self, msg) -> None:
        """Ignore observed messages; the node is stateless between calls."""
        pass
|
|
|
|
|
|
class OutputNodeAgent(AgentBase):
    """Final flow node that optionally pretty-prints JSON output."""

    def __init__(self, node_id: str, config: dict = None):
        """Store the node name and its config (``format``: ``"text"``/``"json"``)."""
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return the input, re-indented as JSON when ``format == "json"``.

        If the text is not valid JSON (or the message has no text at all) the
        input is returned unchanged, wrapped in a Msg when needed.
        """
        if self.config.get("format", "text") == "json":
            content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
            try:
                # TypeError covers content=None (a message without text).
                parsed = json.loads(content)
            except (TypeError, ValueError):
                pass
            else:
                formatted = json.dumps(parsed, indent=2, ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")
        return msg if isinstance(msg, Msg) else Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; the node is stateless between calls."""
        pass
|
|
|
|
|
|
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
|
|
result = template
|
|
import re
|
|
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
|
|
for placeholder in placeholders:
|
|
parts = placeholder.strip().split(".")
|
|
value = ""
|
|
if parts[0] == "trigger":
|
|
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
|
|
elif parts[0] in context.get("_node_results", {}):
|
|
node_result = context.get("_node_results", {}).get(parts[0], {})
|
|
if len(parts) > 1:
|
|
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
|
|
else:
|
|
value = str(node_result.get("output", ""))
|
|
result = result.replace("{{" + placeholder + "}}", value)
|
|
return result
|