You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

1510 lines
60 KiB

"""流引擎核心模块。
定义 FlowEngine 流程执行引擎及各类节点 Agent,包括:
- FlowEngine:流程图的解析与遍历执行器
- LLMNodeAgent / ToolNodeAgent / MCPNodeAgent 等各类节点处理器
"""
import json
import uuid
import logging
import re
import asyncio
import httpx
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.agent._react_agent import ReActAgent
from config import settings
logger = logging.getLogger(__name__)
async def _resolve_model_instance(model_instance_id: str) -> dict | None:
    """Resolve a model instance ID to its connection settings.

    Looks up the active ``model_instances`` row joined with its active
    ``model_providers`` row.

    Args:
        model_instance_id: UUID (string) of the model instance.

    Returns:
        ``{"model", "base_url", "api_key", "params"}`` where ``params`` is the
        instance's ``default_params`` (``{}`` when NULL), or ``None`` when the
        ID is malformed, no active row matches, or the lookup fails. Failures
        are logged rather than raised so callers can fall back to the global
        settings.
    """
    try:
        uid = uuid.UUID(str(model_instance_id))
    except (ValueError, TypeError, AttributeError):
        logger.warning(f"无效的模型实例 ID: {model_instance_id!r}")
        return None

    try:
        from database import AsyncSessionLocal
        from sqlalchemy import text

        async with AsyncSessionLocal() as db:
            result = await db.execute(
                text("""
                    SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key
                    FROM model_instances mi
                    JOIN model_providers mp ON mi.provider_id = mp.id
                    WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true
                """),
                {"id": uid},
            )
            row = result.fetchone()
    except Exception as e:
        # Best effort: callers fall back to settings, but the failure must be
        # visible in the logs instead of being silently swallowed.
        logger.warning(f"解析模型实例 {model_instance_id} 失败: {e}")
        return None

    if not row:
        return None
    return {
        "model": row[0],
        "base_url": row[2],
        "api_key": row[3],
        "params": row[1] or {},
    }
class FlowSessionMemory:
    """Per-session short-term memory holding the messages of the current flow run."""

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        # Chronological list of {"role": ..., "content": ...} dicts.
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return the most recent ``limit`` turns, oldest first.

        A turn is counted as two messages (user + assistant), so at most
        ``limit * 2`` messages are returned. A non-positive ``limit`` returns
        an empty list.
        """
        if limit <= 0:
            # Guard: ``self._messages[-0:]`` would return the *entire* history.
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one message to the history."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of all stored messages."""
        return list(self._messages)
class FlowEngine:
    """Flow execution engine.

    Builds an adjacency list from a flow definition (``{"nodes": [...],
    "edges": [...]}``) and executes it as a depth-first traversal starting
    from the first node without incoming edges.

    Routing rules implemented by ``execute``:

    * Non-loop nodes run at most once per execution (except when re-armed as
      part of a loop body).
    * ``condition`` nodes follow only the edge whose condition label equals
      the ``condition:true`` / ``condition:false`` marker in their output.
    * ``loop`` nodes follow ``loop_body`` edges until ``_check_loop_done``
      reports completion (or ``max_iterations`` is exceeded), then follow
      ``loop_done`` edges.
    * All other nodes follow every outgoing edge except ``loop_body`` /
      ``loop_done``.
    """

    MAX_TOTAL_ITERATIONS = 200  # Global cap on node visits (guards against runaway cycles).
    FLOW_TIMEOUT_SECONDS = 300  # Wall-clock limit for a single execution.

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {}
        for node in flow_definition.get("nodes", []):
            self.nodes[node["id"]] = node
        self.edges: list[dict] = flow_definition.get("edges", [])
        # Agents of non-loop nodes are built once and reused.
        self._agent_cache: dict[str, AgentBase] = {}

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Execute the flow for ``input_msg``.

        Args:
            input_msg: Message fed to the start node.
            context: Mutable per-run context shared with every node agent.
                A ``FlowSessionMemory`` is stored under ``"_memory"`` and
                per-node execution records under ``"_node_results"``.

        Returns:
            The message produced by the last node that ran, a system message
            describing a timeout / iteration-limit abort, or ``input_msg``
            when nothing ran.
        """
        graph = self._build_graph()
        # Fallback for graphs where every node has an incoming edge (pure cycle).
        start_nodes = self._find_start_nodes(graph) or list(self.nodes.keys())[:1]

        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        context["_memory"] = FlowSessionMemory(session_id=session_id, user_id=user_id)

        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}
        total_iterations = 0
        last_result: Msg | None = None

        async def traverse(node_id: str, incoming_msg: Msg, loop_context: dict | None = None) -> None:
            # ``loop_context`` ({"iteration": n}) is passed to loop bodies but
            # is not consumed yet.
            nonlocal last_result, total_iterations
            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return

            node = self.nodes.get(node_id)
            if not node:
                return
            node_type = node.get("type", "")
            is_loop = node_type == "loop"
            outgoing = graph.get(node_id, [])

            if not is_loop:
                if node_id in visited:
                    return
                visited.add(node_id)
            else:
                loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                max_iter = (node.get("config") or {}).get("max_iterations", 10)
                if loop_iterations[node_id] > max_iter:
                    logger.warning(f"循环节点 {node.get('label', node_id)} 超过最大迭代次数 {max_iter}")
                    visited.add(node_id)
                    for target_id, edge_cond in outgoing:
                        if edge_cond == "loop_done":
                            await traverse(target_id, incoming_msg)
                    return

            # Everything that can fail for *this* node (agent creation, input
            # mapping, the reply itself) is inside the try, so a failure is
            # recorded against this node instead of propagating into the
            # caller and being blamed on the upstream node.
            try:
                if is_loop:
                    loop_node_config = dict(node.get("config") or {})
                    loop_node_config["_engine_iteration"] = loop_iterations.get(node_id, 0)
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)

                current_msg = incoming_msg
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                if enriched_content.strip():
                    if hasattr(incoming_msg, "get_text_content"):
                        user_text = incoming_msg.get_text_content() or ""
                    else:
                        user_text = str(incoming_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")

                result = await agent.reply(current_msg, context=context)
                if hasattr(result, "get_text_content"):
                    # get_text_content() may return None for non-text messages.
                    output_text = result.get_text_content() or ""
                else:
                    output_text = str(result)
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node_type,
                    "label": node.get("label"),
                    "status": "success",
                    "output": output_text[:500],
                }
                last_result = result
            except Exception as e:
                logger.error(f"节点 {node.get('label', node_id)} 执行失败: {e}")
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node_type,
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                last_result = Msg(name="system", content=f"[节点 {node.get('label', node_id)} 执行失败: {e}]", role="system")
                return

            # Successor dispatch happens outside the try: each successor
            # records its own failures, so they cannot overwrite this node's
            # success record.
            if is_loop:
                if self._check_loop_done(result, node, loop_iterations.get(node_id, 0)):
                    visited.add(node_id)
                    for target_id, edge_cond in outgoing:
                        if edge_cond == "loop_done":
                            await traverse(target_id, result)
                else:
                    for target_id, edge_cond in outgoing:
                        if edge_cond == "loop_body" and target_id in self.nodes:
                            # Re-arm the whole body, not just its entry node, so
                            # multi-node bodies run again on every iteration.
                            visited.difference_update(self._collect_loop_body(graph, node_id, target_id))
                            await traverse(target_id, result, {"iteration": loop_iterations[node_id]})
                return

            is_condition = node_type == "condition"
            cond_result = self._parse_condition_result(result) if is_condition else None
            for target_id, edge_cond in outgoing:
                if is_condition:
                    if edge_cond and edge_cond == cond_result:
                        await traverse(target_id, result)
                elif edge_cond not in ("loop_body", "loop_done"):
                    await traverse(target_id, result)

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")
        return last_result or input_msg

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build ``{source_id: [(target_id, condition), ...]}`` from the edges.

        Accepts ``source``/``from`` and ``target``/``to`` keys. The edge
        condition comes from ``condition`` or ``sourceHandle``; the generic
        handle name ``"source"`` means "unconditional" and becomes ``None``.
        Edges that reference unknown nodes are dropped.
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return node IDs with no incoming edge, in definition order."""
        targets = {target_id for edges in graph.values() for target_id, _ in edges}
        return [nid for nid in self.nodes if nid not in targets]

    @staticmethod
    def _collect_loop_body(graph: dict, loop_id: str, entry_id: str) -> set[str]:
        """Return the node IDs reachable from ``entry_id`` without passing through ``loop_id``."""
        body: set[str] = set()
        stack = [entry_id]
        while stack:
            current = stack.pop()
            if current == loop_id or current in body:
                continue
            body.add(current)
            stack.extend(target_id for target_id, _ in graph.get(current, []))
        return body

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Extract ``"true"``/``"false"`` from a ``condition:<flag>`` marker, or ``None``."""
        if hasattr(result, "get_text_content"):
            content = result.get_text_content() or ""
        else:
            content = str(result)
        m = re.search(r'condition:(true|false)', content)
        return m.group(1) if m else None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node has finished.

        ``loop_type == "fixed"`` (default): done once ``iteration`` reaches
        ``count`` (default 3). Any other type: done when the output contains
        ``loop:stop`` / ``loop:done`` / ``loop:break`` (case-insensitive).
        """
        config = node.get("config") or {}
        if config.get("loop_type", "fixed") == "fixed":
            return iteration >= config.get("count", 3)
        if hasattr(result, "get_text_content"):
            content = result.get_text_content() or ""
        else:
            content = str(result)
        return bool(re.search(r'loop:(stop|done|break)', content, re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict | None = None) -> AgentBase:
        """Return the agent for ``node_id``.

        Without ``override_config`` the agent is cached per node. With it (used
        for loop nodes, whose config carries the current iteration) a fresh,
        uncached agent is built from a copy of the node using that config.
        """
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]
        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render the node's ``input_mapping`` as ``"key: value"`` lines.

        String values containing ``{{`` are rendered via ``_resolve_template``;
        other values are stringified as-is. Returns ``""`` when there is no
        mapping.
        """
        input_mapping = (node.get("config") or {}).get("input_mapping")
        if not input_mapping:
            return ""
        resolved = {}
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            resolved[key] = str(value)
        return "\n".join(f"{k}: {v}" for k, v in resolved.items())
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
node_type = node.get("type", "")
node_id = node.get("id", "")
config = node.get("config", {})
if node_type == "trigger":
return PassThroughAgent(node_id)
elif node_type == "llm":
model_config = config.get("model", settings.LLM_MODEL)
model_instance_id = config.get("model_instance_id")
base_url = settings.LLM_API_BASE
api_key = settings.LLM_API_KEY
if model_instance_id:
resolved = await _resolve_model_instance(model_instance_id)
if resolved:
model_config = resolved.get("model", model_config)
base_url = resolved.get("base_url", base_url)
api_key = resolved.get("api_key", api_key)
temperature = config.get("temperature", 0.7)
system_prompt = config.get("system_prompt", "你是AI助手。")
max_tokens = config.get("max_tokens", 2000)
stream = config.get("stream", True)
stream_cb = context.get("_stream_callback")
agent = LLMNodeAgent(
node_id=node_id,
system_prompt=system_prompt,
model_name=model_config,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
stream_callback=stream_cb,
base_url=base_url,
api_key=api_key,
)
memory = context.get("_memory")
if memory:
agent.set_memory(memory)
return agent
elif node_type == "tool":
tool_name = config.get("tool_name", "")
tool_params = config.get("tool_params", {})
timeout = config.get("timeout", 30)
retry_count = config.get("retry_count", 0)
error_handling = config.get("error_handling", "throw")
return ToolNodeAgent(
node_id=node_id,
tool_name=tool_name,
tool_params=tool_params,
timeout=timeout,
retry_count=retry_count,
error_handling=error_handling,
)
elif node_type == "mcp":
mcp_server = config.get("mcp_server", "")
tool_name = config.get("tool_name", "")
timeout = config.get("timeout", 30)
error_handling = config.get("error_handling", "throw")
return MCPNodeAgent(
node_id=node_id,
server_name=mcp_server,
tool_name=tool_name,
timeout=timeout,
error_handling=error_handling,
)
elif node_type in ("wecom_notify", "notify"):
return NotifyAgent(node_id=node_id, config=config)
elif node_type == "condition":
condition_expr = config.get("condition", "")
condition_type = config.get("condition_type", "expression")
return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type)
elif node_type == "rag":
model_instance_id = config.get("model_instance_id")
if model_instance_id:
resolved = await _resolve_model_instance(model_instance_id)
if resolved:
config = dict(config)
config["_resolved_model"] = resolved
return RAGNodeAgent(node_id=node_id, config=config)
elif node_type == "output":
return OutputNodeAgent(node_id=node_id, config=config)
elif node_type == "merge":
return ParallelMergeNodeAgent(node_id=node_id, config=config)
elif node_type == "loop":
return LoopNodeAgent(node_id=node_id, config=config)
elif node_type == "code":
language = config.get("language", "python")
code = config.get("code", "")
timeout = config.get("timeout", 30)
sandbox = config.get("sandbox", True)
return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox)
elif node_type == "http_request":
return HttpRequestNodeAgent(node_id=node_id, config=config)
elif node_type == "question_classifier":
return QuestionClassifierNodeAgent(node_id=node_id, config=config)
elif node_type == "variable_assigner":
return VariableAssignerNodeAgent(node_id=node_id, config=config, context=context)
elif node_type == "template_transform":
return TemplateTransformNodeAgent(node_id=node_id, config=config)
elif node_type == "iteration":
return IterationNodeAgent(node_id=node_id, config=config)
elif node_type == "question_optimiser":
return QuestionOptimiserNodeAgent(node_id=node_id, config=config)
else:
return PassThroughAgent(node_id)
class PassThroughAgent(AgentBase):
    """Agent that forwards its input unchanged (trigger nodes and unknown node types)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return ``msg`` itself if it is a ``Msg``; otherwise wrap ``str(msg)`` as an assistant message."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """No-op: this agent keeps no state."""
        return None
class HttpRequestNodeAgent(AgentBase):
    """Flow node that performs a configured HTTP request.

    Config keys read: ``method`` (default GET), ``url``, ``headers`` (dict
    or JSON string), ``body`` (dict or string, sent only for
    POST/PUT/PATCH), ``auth_type`` (``bearer`` / ``api_key`` / ``basic`` /
    ``none``), ``auth_config``, ``timeout`` (seconds), ``retry_count``.

    The reply content is a JSON string: ``{"status_code", "body"}`` on any
    HTTP response (``body`` is decoded JSON when possible), or
    ``{"status_code": 0, "error", "body": null}`` after all attempts raised.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"HttpRequest_{node_id}"
        self.config = config or {}

    @staticmethod
    def _coerce_mapping(value) -> dict:
        """Return ``value`` as a new dict.

        JSON strings are decoded first. ``None``, undecodable strings and
        values ``dict()`` cannot convert all yield ``{}``.
        """
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except (json.JSONDecodeError, ValueError):
                return {}
        if not value:
            return {}
        try:
            return dict(value)
        except (TypeError, ValueError):
            return {}

    def _build_headers(self) -> dict:
        """Merge the configured headers with the header required by ``auth_type``."""
        headers = self._coerce_mapping(self.config.get("headers", {}))
        auth_type = self.config.get("auth_type", "none")
        # Null / JSON-string auth_config used to crash on ``.get``.
        auth = self._coerce_mapping(self.config.get("auth_config", {}))
        if auth_type == "bearer" and auth.get("token"):
            headers["Authorization"] = f"Bearer {auth['token']}"
        elif auth_type == "api_key" and auth.get("api_key"):
            headers[auth.get("key_name", "X-API-Key")] = auth["api_key"]
        elif auth_type == "basic" and auth.get("username"):
            import base64
            credentials = f"{auth['username']}:{auth.get('password', '')}"
            headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}"
        return headers

    def _build_request_kwargs(self, method: str, headers: dict) -> dict:
        """Build ``httpx`` request kwargs; string bodies are sent as JSON when they parse, else raw."""
        body = self.config.get("body", "")
        request_kwargs = {"headers": headers}
        if method in ("POST", "PUT", "PATCH"):
            if isinstance(body, dict):
                request_kwargs["json"] = body
            elif body and isinstance(body, str):
                try:
                    request_kwargs["json"] = json.loads(body)
                except (json.JSONDecodeError, ValueError):
                    request_kwargs["content"] = body
        return request_kwargs

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Send the request, retrying (1 s apart) on transport errors up to ``retry_count`` times."""
        method = str(self.config.get("method") or "GET").upper()
        url = self.config.get("url", "")
        timeout = self.config.get("timeout", 30)
        retry_count = self.config.get("retry_count", 0) or 0
        # Loop-invariant: build headers/body once, not on every attempt.
        request_kwargs = self._build_request_kwargs(method, self._build_headers())

        last_error = None
        for attempt in range(retry_count + 1):
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
                    response = await client.request(method, url, **request_kwargs)
                    resp_status = response.status_code
                    resp_text = response.text
                try:
                    resp_body = json.loads(resp_text)
                except (json.JSONDecodeError, ValueError):
                    resp_body = resp_text
                result = json.dumps({
                    "status_code": resp_status,
                    "body": resp_body,
                }, ensure_ascii=False)
                return Msg(self.name, result, "assistant")
            except Exception as e:
                # Some httpx errors (e.g. ReadTimeout) stringify to "", which
                # would otherwise be reported as "unknown error".
                last_error = str(e) or type(e).__name__
                if attempt < retry_count:
                    await asyncio.sleep(1)
        error_result = json.dumps({
            "status_code": 0,
            "error": last_error or "unknown error",
            "body": None,
        }, ensure_ascii=False)
        return Msg(self.name, error_result, "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        pass
class QuestionClassifierNodeAgent(AgentBase):
    """Flow node that asks the LLM to classify the input into configured categories.

    Config keys read: ``categories`` (list of ``{"name", "description"}``
    dicts; plain values are used as-is), ``instruction``, ``model``
    (default ``settings.LLM_MODEL``), ``temperature`` (default 0.3).

    The reply content is a JSON string. On success it is the JSON the model
    returned (expected ``{"category", "confidence"}``). Otherwise it is a
    ``"default"`` category: confidence 1.0 with no categories, 0.5 plus
    ``raw`` for unparsable output, 0.0 plus ``error`` when the call fails.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Classifier_{node_id}"
        self.config = config or {}

    @staticmethod
    def _describe_category(category) -> str:
        """Render one category as ``"name: description"`` (non-dicts via ``str``)."""
        if isinstance(category, dict):
            return f'{category.get("name")}: {category.get("description", "")}'
        return str(category)

    @staticmethod
    def _extract_json(text: str) -> tuple[bool, object]:
        """Parse ``text`` as JSON, falling back to its outermost ``{...}`` span.

        LLMs often wrap JSON in Markdown code fences or prose. Returns
        ``(True, value)`` on success and ``(False, None)`` otherwise.
        """
        try:
            return True, json.loads(text)
        except (json.JSONDecodeError, ValueError):
            pass
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            try:
                return True, json.loads(match.group())
            except (json.JSONDecodeError, ValueError):
                pass
        return False, None

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Classify the incoming message text and return the classification JSON."""
        categories = self.config.get("categories") or []
        instruction = self.config.get("instruction", "")
        model_name = self.config.get("model", settings.LLM_MODEL)
        temperature = self.config.get("temperature", 0.3)
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if not categories:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant")

        category_desc = "\n".join(self._describe_category(c) for c in categories)
        prompt = (
            "请对以下用户输入进行意图分类。\n"
            "分类选项:\n"
            f"{category_desc}\n"
            f"{instruction}\n"
            f"用户输入:{user_text}\n"
            '请只返回一个JSON对象,格式为:{"category": "分类名称", "confidence": 0.0~1.0}'
        )
        try:
            import openai
            # ``async with`` closes the client's HTTP connection pool; the
            # previous version leaked one client per call.
            async with openai.AsyncOpenAI(
                base_url=settings.LLM_API_BASE,
                api_key=settings.LLM_API_KEY,
            ) as client:
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=temperature,
                    max_tokens=200,
                )
            # ``content`` may be None (e.g. refusals / tool calls).
            result_text = (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.error(f"QuestionClassifier error: {e}")
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}), "assistant")

        ok, parsed = self._extract_json(result_text)
        if ok:
            return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
        return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}), "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        pass
class VariableAssignerNodeAgent(AgentBase):
    """Flow node that computes named values and returns them as a JSON object.

    Each entry of ``config["assignments"]`` has ``target_var``,
    ``source_type`` and ``source_value``. Supported source types:
    ``constant`` (the raw value), ``upstream_output`` (incoming message
    text), ``template`` (rendered with ``_resolve_template``) and
    ``expression`` (``eval``'d, result stringified). Unknown types fall back
    to the raw value. The values are only returned in the reply; this node
    does not write them into the flow context.
    """

    def __init__(self, node_id: str, config: dict = None, context: dict = None):
        super().__init__()
        self.name = f"VarAssign_{node_id}"
        self.config = config or {}
        self._context = context or {}

    def _compute_value(self, source_type, source_value, msg):
        """Compute one assignment's value for the given source type."""
        if source_type == "upstream_output":
            return msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if source_type == "template":
            return _resolve_template(source_value, self._context, msg)
        if source_type == "expression":
            # NOTE(security): eval() with an empty __builtins__ is NOT a
            # sandbox (attribute traversal can reach arbitrary objects).
            # source_value comes from the flow definition; if flow authors are
            # not fully trusted, replace this with a restricted evaluator.
            try:
                return str(eval(source_value, {"__builtins__": {}}, {"msg": msg, "context": self._context}))
            except Exception as e:
                return f"[expression error: {e}]"
        # "constant" and any unrecognised source_type use the raw value.
        return source_value

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Evaluate every assignment in order and return ``{target_var: value}`` as JSON."""
        assigned = {}
        for spec in self.config.get("assignments", []):
            var_name = spec.get("target_var", "")
            assigned[var_name] = self._compute_value(
                spec.get("source_type", "constant"),
                spec.get("source_value", ""),
                msg,
            )
        return Msg(self.name, json.dumps(assigned, ensure_ascii=False), "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        return None
class LLMNodeAgent(AgentBase):
    """Flow node that calls an OpenAI-compatible ``/chat/completions`` endpoint.

    Prompt layout: the node's system prompt; then prior conversation, taken
    from ``context["_memory_context"]`` (optional summary plus the last 10
    recent messages) when present, otherwise from the attached
    ``FlowSessionMemory`` (last 5 turns); then the incoming text as the user
    turn. Each history message is truncated to 2000 characters.

    When ``stream`` is enabled and a ``stream_callback`` is set, tokens are
    forwarded as ``stream_callback("text_chunk", {"content": token})``.
    Failures are returned as bracketed error text, never raised.
    """

    _HISTORY_CHAR_LIMIT = 2000  # per-message cap for injected history

    def __init__(
        self,
        node_id: str,
        system_prompt: str,
        model_name: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2000,
        stream: bool = True,
        stream_callback=None,
        base_url: str = "",
        api_key: str = "",
    ):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.stream_callback = stream_callback
        self.base_url = base_url or settings.LLM_API_BASE
        self.api_key = api_key or settings.LLM_API_KEY
        self._memory = None

    def set_memory(self, memory):
        """Attach a session memory used for history when no ``_memory_context`` is given."""
        self._memory = memory

    @classmethod
    def _append_history(cls, messages: list[dict], history) -> None:
        """Append history entries to ``messages``, truncating each content."""
        for item in history:
            content = item.get("content") or ""
            messages.append({"role": item.get("role", "user"), "content": content[:cls._HISTORY_CHAR_LIMIT]})

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Build the prompt, call the model and record the exchange in session memory."""
        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)
        context = kwargs.get("context") or {}
        memory_ctx = context.get("_memory_context") or {}

        messages: list[dict] = [{"role": "system", "content": self.system_prompt}]
        if memory_ctx:
            summary = memory_ctx.get("summary", "")
            if summary:
                messages.append({"role": "system", "content": f"[历史对话摘要]\n{summary}"})
            self._append_history(messages, (memory_ctx.get("recent_messages") or [])[-10:])
        elif self._memory:
            self._append_history(messages, self._memory.get_history(limit=5))
        messages.append({"role": "user", "content": user_text})

        try:
            # Honour the node's ``stream`` flag; previously any node with a
            # callback streamed even when configured with stream=False.
            if self.stream and self.stream_callback:
                res_text = await self._stream_llm_call(messages)
            else:
                res_text = await self._blocking_llm_call(messages)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"

        # With an external memory context the caller owns persistence.
        if self._memory and not memory_ctx:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)
        return Msg(self.name, res_text, "assistant")

    def _request_parts(self, messages: list[dict], stream: bool) -> tuple[str, dict, dict]:
        """Return ``(url, headers, json_body)`` for a chat-completions request."""
        url = f"{self.base_url.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model_name or settings.LLM_MODEL,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }
        return url, headers, body

    async def _blocking_llm_call(self, messages: list[dict]) -> str:
        """Make a non-streaming request and return the reply text, or an error marker on failure."""
        url, headers, body = self._request_parts(messages, stream=False)
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
                resp = await client.post(url, json=body, headers=headers)
                # An error payload has no "choices" and would otherwise be
                # returned as an empty reply.
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            return (choices[0].get("message") or {}).get("content") or ""
        except Exception as e:
            logger.warning(f"LLM 阻塞调用失败: {e}")
            return f"[LLM 调用失败: {e}]"

    async def _stream_llm_call(self, messages: list[dict]) -> str:
        """Stream a completion over SSE, forwarding tokens to ``stream_callback``.

        Returns the concatenated text. On timeout or failure the text received
        so far is returned; if nothing arrived, an error marker is returned.
        """
        url, headers, body = self._request_parts(messages, stream=True)
        parts: list[str] = []
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
                async with client.stream("POST", url, json=body, headers=headers) as response:
                    # An HTTP error body is not SSE; without this check it is
                    # silently skipped and an empty reply returned.
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # SSE permits "data:" with or without a following space.
                        if not line.startswith("data:"):
                            continue
                        payload = line[5:].strip()
                        if payload == "[DONE]":
                            break
                        try:
                            chunk = json.loads(payload)
                        except json.JSONDecodeError:
                            continue
                        # Some providers send chunks with an empty "choices" list
                        # (usage / content-filter metadata); skip them instead of
                        # aborting the whole stream with IndexError.
                        choices = chunk.get("choices") or []
                        if not choices:
                            continue
                        token = (choices[0].get("delta") or {}).get("content")
                        if token:
                            parts.append(token)
                            await self.stream_callback("text_chunk", {"content": token})
        except httpx.TimeoutException:
            logger.warning("LLM 流式调用超时")
            if not parts:
                return "[LLM 超时]"
        except Exception as e:
            logger.warning(f"LLM 流式调用失败: {e}")
            if not parts:
                return f"[LLM 调用失败: {e}]"
        return "".join(parts)

    async def observe(self, msg) -> None:
        """No-op: history is managed in ``reply``."""
        pass
class ToolNodeAgent(AgentBase):
    """Flow node that executes a built-in tool or a registered custom tool.

    Custom tools (``_CUSTOM_TOOL_DEFS``) take precedence over built-in tools
    (``_TOOL_REGISTRY``) of the same name. Built-in tools are synchronous
    callables run in a worker thread; custom tools run through
    ``CustomToolExecutor``. Each attempt is bounded by ``timeout`` seconds
    and retried up to ``retry_count`` extra times.

    ``error_handling`` decides the reply after the final failed attempt:
    ``"skip"`` returns a "skipped" marker, ``"default"`` returns ``"{}"``,
    anything else (default ``"throw"``) returns an error description. ``reply``
    does not raise for tool failures.
    """

    # Class-level registries, shared by every instance.
    _TOOL_REGISTRY: dict[str, callable] = {}  # built-in tool name -> callable
    _TOOL_SCHEMAS: dict[str, dict] = {}  # tool name -> schema (built-in and custom)
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}  # custom tool name -> executor definition

    @classmethod
    def _init_registry(cls):
        """Lazily register the built-in tools and merge each tool module's ``SCHEMAS``.

        No-op once the registry is populated. If the tool modules cannot be
        imported the failure is logged and the registry stays empty, so the
        import is attempted again on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard
            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
            for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]:
                try:
                    mod = __import__(f"agentscope_integration.tools.{mod_name}", fromlist=["SCHEMAS"])
                    if hasattr(mod, "SCHEMAS"):
                        cls._TOOL_SCHEMAS.update(mod.SCHEMAS)
                except Exception as e:
                    # Schemas are optional metadata; a broken module must not
                    # block tool registration, but should not vanish silently.
                    logger.debug(f"加载工具 schema 失败 ({mod_name}): {e}")
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Register every active ``CustomTool`` row from the given DB session.

        Each tool's HTTP definition (with a fixed 30 s timeout) is stored in
        ``_CUSTOM_TOOL_DEFS`` and its ``schema_json`` in ``_TOOL_SCHEMAS``.
        Errors are logged, not raised.
        """
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Imported only to verify the executor is available before
            # registering tools that depend on it.
            from modules.custom_tool.executor import CustomToolExecutor  # noqa: F401
            result = await db.execute(
                select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712 (SQLAlchemy expression)
            )
            for tool in result.scalars().all():
                cls._CUSTOM_TOOL_DEFS[tool.name] = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas (initialising built-ins first)."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30, retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    def _failure_reply(self, skip_text: str, error_text: str) -> Msg:
        """Build the reply for a final failure according to ``error_handling``."""
        if self.error_handling == "skip":
            return Msg(self.name, skip_text, "assistant")
        if self.error_handling == "default":
            return Msg(self.name, "{}", "assistant")
        return Msg(self.name, error_text, "assistant")

    @staticmethod
    def _invoke_tool(tool_func, params: dict, user_text: str):
        """Call ``tool_func(**params)``, or ``tool_func(user_text, **params)`` if that is what its signature accepts.

        The call form is chosen from the signature *before* calling, so the
        tool runs at most once. The old approach retried on any
        ``TypeError``, re-running tools whose own body raised one.

        Raises:
            RuntimeError: wrapping any exception raised by the tool.
        """
        import inspect

        try:
            signature = inspect.signature(tool_func)
        except (TypeError, ValueError):
            signature = None

        if signature is None:
            # Not introspectable (some builtins / C callables): fall back to
            # trial calls.
            try:
                try:
                    return tool_func(**params)
                except TypeError:
                    return tool_func(user_text, **params)
            except Exception as e:
                raise RuntimeError(f"工具执行失败: {e}") from e

        args: tuple = ()
        try:
            signature.bind(**params)
        except TypeError:
            try:
                signature.bind(user_text, **params)
                args = (user_text,)
            except TypeError:
                pass  # Neither form fits; the call below raises a descriptive TypeError.
        try:
            return tool_func(*args, **params)
        except Exception as e:
            raise RuntimeError(f"工具执行失败: {e}") from e

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Run the configured tool on the incoming message and return its result as text."""
        self._init_registry()
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        for attempt in range(self.retry_count + 1):
            is_last = attempt >= self.retry_count
            try:
                # NOTE: on timeout the worker thread keeps running; a retry
                # starts a second concurrent invocation.
                result = await asyncio.wait_for(
                    asyncio.to_thread(self._invoke_tool, tool_func, self.tool_params, user_text),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if not is_last:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                return self._failure_reply(
                    f"[工具 {self.tool_name} 超时,已跳过]",
                    f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)",
                )
            except RuntimeError as e:
                if not is_last:
                    logger.warning(f"工具 {self.tool_name} 执行失败,重试 {attempt + 1}/{self.retry_count}: {e}")
                    continue
                return self._failure_reply(f"[工具 {self.tool_name} 失败,已跳过: {e}]", f"[{e}]")
            except Exception as e:
                if not is_last:
                    continue
                return self._failure_reply(f"[工具 {self.tool_name} 失败,已跳过: {e}]", f"[工具执行失败: {e}]")
        return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Run a custom tool through ``CustomToolExecutor`` with ``self.tool_params``.

        ``user_text`` is currently not forwarded to the executor.
        """
        from modules.custom_tool.executor import CustomToolExecutor
        executor = CustomToolExecutor(custom_def)
        for attempt in range(self.retry_count + 1):
            is_last = attempt >= self.retry_count
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if not is_last:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                return self._failure_reply(
                    f"[自定义工具 {self.tool_name} 超时,已跳过]",
                    f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)",
                )
            except Exception as e:
                if not is_last:
                    continue
                return self._failure_reply(f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]", f"[自定义工具执行失败: {e}]")
        return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        pass
class ConditionNodeAgent(AgentBase):
    """Flow node that evaluates a condition and emits ``condition:<true|false>|<detail>``.

    ``FlowEngine`` reads the ``condition:true`` / ``condition:false`` prefix
    to choose the outgoing edge. ``condition_type`` selects the evaluator:

    * ``"regex"``: ``re.search(condition, input)``.
    * ``"json_path"``: dotted path such as ``$.a.b`` into a JSON-object
      input; true when the value found is truthy.
    * anything else (``"expression"``): asks the configured LLM to judge
      the natural-language condition.

    An empty condition, an invalid regex, unparsable JSON and LLM failures
    all default to ``true``.
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Evaluate the condition against the incoming message text."""
        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)
        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")
        if self.condition_type == "regex":
            return self._eval_regex(user_text)
        if self.condition_type == "json_path":
            return self._eval_json_path(user_text)
        return await self._eval_llm(user_text)

    def _eval_regex(self, user_text: str) -> Msg:
        """Regex evaluator; an invalid pattern defaults to true."""
        try:
            matched = bool(re.search(self.condition, user_text))
        except re.error as e:
            return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")
        # The previous detail text was truncated ("" / "未命").
        detail = "命中" if matched else "未命中"
        return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{detail}", "assistant")

    def _eval_json_path(self, user_text: str) -> Msg:
        """Dotted-path evaluator over a JSON-object input; parse errors default to true."""
        try:
            data = json.loads(user_text) if user_text.strip().startswith('{') else {}
            val = data
            for part in self.condition.strip('$.').split('.'):
                val = val.get(part, None) if isinstance(val, dict) else None
        except Exception:  # was a bare ``except:``, which also trapped cancellation
            return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")
        return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")

    @staticmethod
    async def _chat(messages: list[dict]) -> str:
        """Send a non-streaming chat-completions request with the global LLM settings and return the reply text."""
        url = f"{(settings.LLM_API_BASE or '').rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {settings.LLM_API_KEY}",
            "Content-Type": "application/json",
        }
        body = {"model": settings.LLM_MODEL, "messages": messages, "stream": False}
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
            resp = await client.post(url, json=body, headers=headers)
            resp.raise_for_status()
            data = resp.json()
        choices = data.get("choices") or [{}]
        return (choices[0].get("message") or {}).get("content") or ""

    async def _eval_llm(self, user_text: str) -> Msg:
        """LLM evaluator; any failure or unparsable answer defaults to true.

        Uses a direct OpenAI-compatible HTTP call (as ``LLMNodeAgent`` does)
        instead of ``agentscope.model.OpenAIChatModel``, whose previous usage
        (``config_name``/``api_base`` kwargs, sync ``formatter.format``) did
        not match the agentscope API and always fell back to "true".
        """
        condition_prompt = (
            "你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。\n"
            f"条件表达式: {self.condition}\n"
            "输入内容:\n"
            f"{user_text[:2000]}\n"
            "请严格只输出一行 JSON:\n"
            '{"result": true/false, "reason": "简要原因"}'
        )
        messages = [
            {"role": "system", "content": condition_prompt},
            {"role": "user", "content": user_text[:2000]},
        ]
        try:
            res_text = await self._chat(messages)
            json_match = re.search(r'\{[^}]+\}', res_text)
            if json_match:
                parsed = json.loads(json_match.group())
                matched = parsed.get("result", False)
                if isinstance(matched, str):
                    # "false" as a string is truthy; interpret it explicitly.
                    matched = matched.strip().lower() == "true"
                reason = parsed.get("reason", "")
                return Msg(self.name, f"condition:{'true' if matched else 'false'}|{reason}", "assistant")
            logger.warning(f"条件判断LLM返回无法解析: {res_text[:200]}")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        pass
class MCPNodeAgent(AgentBase):
    """Flow node that calls ``tool_name`` on the MCP server ``server_name``.

    The tool receives ``{"input": <incoming text>}`` and its result is
    returned as text. Missing configuration, an unavailable client, timeouts
    and errors are reported as failures. With ``error_handling == "skip"``
    they become "skipped" markers. Previously several failure paths replied
    "调用完成" (call completed).
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30, error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    def _failure_reply(self, detail: str) -> Msg:
        """Build a failure reply honouring ``error_handling``."""
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, f"[MCP] 服务 {self.server_name} {detail}", "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Invoke the MCP tool with the incoming text and return its result."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        if not self.tool_name:
            return self._failure_reply("未指定工具名称")
        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return self._failure_reply("调用失败: MCP 客户端不可用")
        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return self._failure_reply("不可用: 未找到对应的 MCP 客户端")
            result = await asyncio.wait_for(
                client.call_tool(self.tool_name, {"input": user_text}),
                timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            return self._failure_reply(f"调用失败: {e}")
        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        pass
class NotifyAgent(AgentBase):
    """Flow node that sends a notification via WeCom and/or the web socket channel.

    Config keys read: ``channels`` (default ``{"wecom": True, "web": False}``),
    ``message_template`` for WeCom, ``web_template`` for web, and ``target``
    (recipient) for both. Without a template the first 500 characters of the
    incoming text are sent. ``message_type`` is currently not used.

    The reply holds one status entry per enabled channel joined by ``" | "``.
    Failures are reported as failures; previously several failure paths
    still claimed the notification was sent.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def _notify_wecom(self, user_text: str) -> str:
        """Send the WeCom notification and return a status entry."""
        message = self.config.get("message_template", "") or user_text[:500]
        target = self.config.get("target", "")
        try:
            from agentscope_integration.tools.wecom_tools import send_notification
        except ImportError:
            logger.warning("企微通知模块不可用")
            return "企微通知: 模块不可用,未发送"
        try:
            # send_notification is synchronous; keep it off the event loop.
            result = await asyncio.to_thread(send_notification, to_user=target or "user", message=message)
        except Exception as e:
            logger.warning(f"企微通知发送失败: {e}")
            return f"企微通知: 发送失败({e})"
        return f"企微通知: {str(result)[:100]}"

    async def _notify_web(self, user_text: str) -> str:
        """Push the web notification to ``target`` and return a status entry."""
        web_message = self.config.get("web_template", "") or user_text[:500]
        target_user = self.config.get("target", "")
        if not target_user:
            return "Web通知: 未指定接收人,未推送"
        try:
            from websocket_manager import ws_manager
        except ImportError:
            logger.warning("websocket_manager 不可用")
            return "Web通知: 模块不可用,未推送"
        try:
            await ws_manager.send_to_user(target_user, {
                "type": "flow_notification",
                "message": web_message,
                "level": "info",
            })
        except Exception as e:
            logger.warning(f"Web通知推送失败: {e}")
            return f"Web通知: 推送失败({e})"
        return "Web通知: 已推送"

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Notify every enabled channel and summarise the per-channel outcome."""
        if hasattr(msg, "get_text_content"):
            user_text = msg.get_text_content() or ""
        else:
            user_text = str(msg)
        channels = self.config.get("channels") or {"wecom": True, "web": False}
        results = []
        if channels.get("wecom", True):
            results.append(await self._notify_wecom(user_text))
        if channels.get("web", False):
            results.append(await self._notify_web(user_text))
        if not results:
            results.append("未启用通知渠道")
        return Msg(self.name, " | ".join(results), "assistant")

    async def observe(self, msg) -> None:
        """No-op: this node keeps no state."""
        pass
class LoopNodeAgent(AgentBase):
    """Flow node that advances a loop by one step per visit.

    The current step is ``config["_engine_iteration"] + 1``; this class only
    reads that counter (presumably the engine updates it between visits —
    confirm in FlowEngine). Each call returns either the text for the current
    step or a ``loop:stop|<reason>`` marker.

    Loop types:
        - ``"array"``: one step per element of ``config["items"]`` (a list,
          a JSON array string, or newline-separated text).
        - ``"fixed"``: stops once the step reaches ``config["count"]``.
        - anything else: never emits a stop marker itself.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    @staticmethod
    def _to_int(value, default: int) -> int:
        """Coerce a config value (possibly a string or None) to int, else *default*."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Emit the current loop step, or a stop marker when the loop is done."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        loop_type = self.config.get("loop_type", "fixed")
        # Config values may not be ints (e.g. strings from a JSON form); comparing
        # or adding them directly would raise TypeError.
        count = self._to_int(self.config.get("count", 3), 3)
        iterator_variable = self.config.get("iterator_variable", "item")
        iteration = self._to_int(self.config.get("_engine_iteration", 0), 0) + 1

        if loop_type == "array":
            items = self.config.get("items", [])
            if isinstance(items, str):
                try:
                    items = json.loads(items)
                except ValueError:
                    items = [line.strip() for line in items.split("\n") if line.strip()]
            if isinstance(items, list) and items:
                idx = iteration - 1
                if idx >= len(items):
                    return Msg(self.name, "loop:stop|数组迭代完成", "assistant")
                current_item = items[idx]
                return Msg(self.name, f"[迭代 {iteration}/{len(items)}] 变量: {iterator_variable}={current_item}\n输入: {user_text[:200]}", "assistant")
            return Msg(self.name, "loop:stop|无数组数据", "assistant")

        if loop_type == "fixed":
            # NOTE(review): with count=N this emits N-1 step messages and stops on
            # step N, whereas "array" mode emits one step per item. Confirm whether
            # the engine runs the loop body once before the first visit.
            if iteration >= count:
                return Msg(self.name, f"loop:stop|固定次数循环完成: {iteration}/{count}", "assistant")
            return Msg(self.name, f"[循环 {iteration}/{count}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

        return Msg(self.name, f"[循环 {iteration}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

    async def observe(self, msg) -> None:
        pass
class CodeNodeAgent(AgentBase):
    """Flow node that runs user-supplied Python code in a child process.

    The script sees ``INPUT_TEXT`` bound to the incoming message text and has
    ``json`` already imported. Its stdout becomes the node output. A non-zero
    exit status yields ``[错误] <stderr>``.

    NOTE(review): ``language`` and ``sandbox`` are stored but not enforced.
    The code always runs as Python in an unrestricted child process that
    inherits this server's environment variables. Only trusted flow authors
    should be able to configure this node.
    """

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30, sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        self.language = language
        self.code = code
        self.timeout = timeout
        self.sandbox = sandbox

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Execute ``self.code`` and return its stdout, or an error/timeout message.

        Empty code passes the input text through unchanged.
        """
        import os
        import sys
        import tempfile

        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        code = self.code or ""
        if not code.strip():
            return Msg(self.name, user_text, "assistant")

        # repr() gives a Python literal that evaluates back to exactly user_text.
        # The previous `json.loads(<json.dumps output>)` decoded the text twice and
        # failed for any input that was not itself valid JSON.
        script = f"import json\nINPUT_TEXT = {user_text!r}\n\n{code}"

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
                temp_path = f.name  # record first so a failed write is still cleaned up
                f.write(script)
            proc = await asyncio.create_subprocess_exec(
                sys.executable or "python",
                temp_path,
                stdin=asyncio.subprocess.DEVNULL,  # input() must not block on the server's stdin
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the timeout and the kill
                await proc.wait()  # reap the child so it does not linger as a zombie
                return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            if temp_path:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

        # Judge success by exit status. Warnings on stderr alone must not discard output.
        if proc.returncode != 0:
            error_text = stderr.decode("utf-8", errors="replace") or f"exit code {proc.returncode}"
            return Msg(self.name, f"[错误] {error_text[:500]}", "assistant")
        output = stdout.decode("utf-8", errors="replace").strip()
        return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")

    async def observe(self, msg) -> None:
        pass
class RAGNodeAgent(AgentBase):
    """Flow node that answers the incoming question from knowledge-base retrieval.

    Retrieval uses ``modules.rag.knowledge.retrieve_for_agent``. The answer is
    generated by POSTing to an OpenAI-compatible ``/chat/completions``
    endpoint with httpx. Model name, base URL and API key come from
    ``config["_resolved_model"]`` when present, otherwise ``settings.LLM_*``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Retrieve knowledge for the message text and generate an answer.

        If retrieval fails, a placeholder describing the query is returned. If
        only generation fails (or yields nothing), the raw retrieval result is
        returned so the knowledge is not lost.
        """
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        top_k = self.config.get("top_k", 5)
        # NOTE(review): similarity_threshold / search_mode / include_metadata exist
        # in the node config but retrieve_for_agent is only given `limit`.
        try:
            from modules.rag.knowledge import retrieve_for_agent
            kb_result = await retrieve_for_agent(user_text, limit=top_k)
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")

        rag_prompt = (
            "你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。\n"
            "知识库检索结果:\n"
            f"{kb_result}\n"
            f"用户问题: {user_text}\n"
            "请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"
        )
        raw_hits = f"[RAG检索] 知识库检索结果:\n{kb_result}"
        try:
            answer = await self._chat_completion(rag_prompt, user_text)
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            return Msg(self.name, raw_hits, "assistant")
        return Msg(self.name, answer or raw_hits, "assistant")

    async def _chat_completion(self, system_prompt: str, user_text: str) -> str:
        """Send system prompt + question to ``{base_url}/chat/completions`` and return the reply text.

        Raises:
            httpx.HTTPError: on transport errors or a non-2xx response.
            KeyError / IndexError: if the response lacks ``choices[0].message``.
        """
        import httpx

        resolved = self.config.get("_resolved_model") or {}
        model_name = resolved.get("model") or settings.LLM_MODEL
        api_key = resolved.get("api_key") or settings.LLM_API_KEY
        api_base = (resolved.get("base_url") or settings.LLM_API_BASE).rstrip("/")
        async with httpx.AsyncClient(timeout=self.config.get("timeout", 60)) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json={
                    "model": model_name,
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_text},
                    ],
                },
                headers={"Authorization": f"Bearer {api_key}"},
            )
            resp.raise_for_status()
            data = resp.json()
        content = data["choices"][0]["message"].get("content") or ""
        return content.strip()

    def _get_model(self):
        """Build an agentscope chat model from the resolved model config.

        NOTE(review): no longer used by reply(); kept for any external callers.
        The keyword arguments (config_name / api_base) should be checked
        against the installed agentscope version before relying on this.
        """
        from agentscope.model import OpenAIChatModel
        resolved = self.config.get("_resolved_model", {})
        return OpenAIChatModel(
            config_name=f"rag_{self.name}",
            model_name=resolved.get("model", settings.LLM_MODEL),
            api_key=resolved.get("api_key", settings.LLM_API_KEY),
            api_base=resolved.get("base_url", settings.LLM_API_BASE),
        )

    def _get_formatter(self):
        """Return an agentscope OpenAI chat formatter (not used by reply())."""
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass
class OutputNodeAgent(AgentBase):
    """Flow node that formats the final flow output.

    Config keys:
        format: ``"json"`` pretty-prints content that parses as JSON
            (``indent`` sets the indentation); other values keep the text.
        output_template: if set, ``{{output}}`` in it is replaced by the content.
        truncate / max_length: cut content longer than ``max_length``
            characters and append a truncation marker.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Apply truncation, JSON formatting and the output template to the message."""
        output_format = self.config.get("format", "text")
        output_template = self.config.get("output_template", "")
        truncate = self.config.get("truncate", False)
        try:
            max_length = int(self.config.get("max_length", 2000))
        except (TypeError, ValueError):
            max_length = 2000
        content = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""

        truncated = False
        if truncate and len(content) > max_length:
            content = content[:max_length] + "\n\n[内容已截断]"
            truncated = True

        if output_format == "json":
            try:
                parsed = json.loads(content)
            except (json.JSONDecodeError, ValueError):
                pass  # not JSON: fall back to text handling
            else:
                indent = self.config.get("indent", 2)
                formatted = json.dumps(parsed, indent=indent, ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")

        if output_template and isinstance(output_template, str):
            return Msg(self.name, output_template.replace("{{output}}", content), "assistant")

        # Returning the original msg here would silently drop the truncation.
        if truncated or not isinstance(msg, Msg):
            return Msg(self.name, content, "assistant")
        return msg

    async def observe(self, msg) -> None:
        pass
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
result = template
import re
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
for placeholder in placeholders:
parts = placeholder.strip().split(".")
value = ""
if parts[0] == "trigger":
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
elif parts[0] in context.get("_node_results", {}):
node_result = context.get("_node_results", {}).get(parts[0], {})
if len(parts) > 1:
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
else:
value = str(node_result.get("output", ""))
result = result.replace("{{" + placeholder + "}}", value)
return result
class TemplateTransformNodeAgent(AgentBase):
    """Flow node that renders ``config["template"]`` into its output text.

    ``{{input}}`` is replaced with the incoming message text. Other
    placeholders are resolved by ``_resolve_template`` against
    ``kwargs["context"]``. With ``output_type == "json"``, a rendered JSON
    document is re-serialised compactly, and invalid JSON is returned as-is.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Template_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Render the configured template for the incoming message."""
        template = self.config.get("template") or ""
        output_type = self.config.get("output_type", "string")
        context = kwargs.get("context") or {}
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        try:
            # Substitute {{input}} around the resolver rather than after it, so it
            # is not treated as an unknown placeholder, and placeholder-like text
            # in the user's input is never expanded.
            segments = template.split("{{input}}")
            rendered = user_text.join(_resolve_template(seg, context, msg) for seg in segments)
            if not rendered:
                rendered = template.replace("{{input}}", user_text)
        except Exception:
            rendered = template.replace("{{input}}", user_text) if "{{input}}" in template else user_text

        if output_type == "json":
            try:
                parsed = json.loads(rendered)
                return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
            except (json.JSONDecodeError, ValueError):
                pass
        return Msg(self.name, rendered, "assistant")

    async def observe(self, msg) -> None:
        pass
class IterationNodeAgent(AgentBase):
    """Flow node that expands its input into a list of ``{"index", "item"}`` records.

    Items come from ``config["input_array_source"]`` (a template resolved
    against the flow context) when set. Otherwise they come from the incoming
    message: a JSON array is used as-is, and anything else is split into
    non-empty lines. At most ``config["max_iterations"]`` items (default 20)
    are kept. The output is the records serialised as a JSON array.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Iteration_{node_id}"
        self.config = config or {}
        self._results: list = []  # items expanded by the most recent reply()

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Expand the input into indexed items and return them as JSON."""
        input_array_source = self.config.get("input_array_source", "")
        try:
            max_iterations = max(int(self.config.get("max_iterations", 20)), 0)
        except (TypeError, ValueError):
            max_iterations = 20
        context = kwargs.get("context") or {}

        if input_array_source:
            resolved = _resolve_template(input_array_source, context, msg)
            try:
                items = json.loads(resolved)
                if not isinstance(items, list):
                    items = [resolved]
            except (json.JSONDecodeError, ValueError):
                items = [resolved] if resolved else []
        else:
            user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
            try:
                items = json.loads(user_text)
                if not isinstance(items, list):
                    items = [user_text]
            except (json.JSONDecodeError, ValueError):
                items = [item.strip() for item in user_text.split("\n") if item.strip()]
            if not items:
                items = [user_text]

        items = items[:max_iterations]
        # Keep the expanded items so get_iteration_items() reflects the last run.
        self._results = list(items)
        records = [{"index": i, "item": item} for i, item in enumerate(items)]
        return Msg(self.name, json.dumps(records, ensure_ascii=False), "assistant")

    def get_iteration_items(self) -> list:
        """Return the items expanded by the most recent reply() call."""
        return self._results

    async def observe(self, msg) -> None:
        pass
class QuestionOptimiserNodeAgent(AgentBase):
    """Flow node that asks the LLM to rewrite or expand the user's question.

    ``config["optimization_type"]`` selects ``"rewrite"`` (default) or
    ``"expand"``; any other value passes the input through unchanged.
    Persona, known facts ("atoms") and recent messages from
    ``context["_memory"]`` are added to the prompt when present. On any LLM
    failure, or an empty LLM reply, the original question is returned.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"QOpt_{node_id}"
        self.config = config or {}

    @staticmethod
    def _build_context_block(context: dict) -> str:
        """Render persona / atoms / recent history from ``context["_memory"]`` as prompt lines.

        Malformed or missing memory entries are skipped instead of raising.
        """
        memory = context.get("_memory") or {}
        persona = memory.get("persona") or {}
        atoms = memory.get("atoms") or []
        history = memory.get("recent_messages") or []

        parts = []
        persona_text = persona.get("raw_text", "") if isinstance(persona, dict) else ""
        if persona_text:
            parts.append(f"用户画像: {persona_text}")
        atoms_text = "\n".join(
            a.get("content", "") if isinstance(a, dict) else str(a) for a in atoms[:5]
        )
        if atoms_text:
            parts.append(f"已知信息: {atoms_text}")
        history_text = "\n".join(
            f"{m.get('role', '')}: {str(m.get('content') or '')[:200]}"
            for m in history[-6:]
            if isinstance(m, dict)
        )
        if history_text:
            parts.append(f"近期对话: {history_text}")
        return "\n".join(parts)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return the optimised question, or the original text on failure."""
        optimization_type = self.config.get("optimization_type", "rewrite")
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if optimization_type not in ("rewrite", "expand"):
            return Msg(self.name, user_text, "assistant")

        context_block = self._build_context_block(kwargs.get("context") or {})
        if optimization_type == "rewrite":
            prompt = (
                f"{context_block}\n"
                f"原始问题: {user_text}\n"
                "请将以上问题进行优化改写,使其更清晰、具体、完整。补充可能缺失的上下文信息。\n"
                "只返回优化后的问题,不要其他内容。"
            )
        else:
            prompt = (
                f"{context_block}\n"
                f"简短问题: {user_text}\n"
                "请将以上问题扩展为更详细的版本,添加必要的背景和细节。\n"
                "只返回扩展后的问题,不要其他内容。"
            )

        try:
            import httpx

            # Prefer an engine-resolved model instance, as RAGNodeAgent does.
            resolved = self.config.get("_resolved_model") or {}
            model_name = resolved.get("model") or self.config.get("model") or settings.LLM_MODEL
            api_key = resolved.get("api_key") or settings.LLM_API_KEY
            api_base = (resolved.get("base_url") or settings.LLM_API_BASE).rstrip("/")
            async with httpx.AsyncClient(timeout=30) as client:
                resp = await client.post(
                    f"{api_base}/chat/completions",
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": prompt}],
                        "max_tokens": 300,
                        "temperature": 0.3,
                    },
                    headers={"Authorization": f"Bearer {api_key}"},
                )
                resp.raise_for_status()
                data = resp.json()
            result = (data["choices"][0]["message"].get("content") or "").strip()
        except Exception as e:
            logger.warning(f"问题优化失败: {e}")
            return Msg(self.name, user_text, "assistant")
        # An empty completion must not erase the user's question.
        return Msg(self.name, result or user_text, "assistant")

    async def observe(self, msg) -> None:
        pass
class ParallelMergeNodeAgent(AgentBase):
    """Flow node that joins the outputs of parallel branches.

    Each call records one branch output keyed by ``kwargs["source_node_id"]``
    (falling back to the sender's ``msg.name``). With
    ``config["expected_branches"] > 0``, an empty message is returned until
    that many distinct branches have arrived; the merged result is then
    returned and the buffer reset. With ``<= 0``, every call returns the merge
    of everything received so far.

    Merge types (``config["merge_type"]``):
        - ``"concat"`` (default): outputs joined with a ``---`` separator.
        - ``"json"``: a JSON object mapping branch key to output.
        - ``"first_non_empty"``: the first output that is not blank.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Merge_{node_id}"
        self.config = config or {}
        self._received: dict[str, str] = {}
        self.merge_type = self.config.get("merge_type", "concat")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Record one branch output and return the merge once enough branches arrived."""
        content = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        # Without an explicit id every branch would overwrite the "" key, and a
        # merge waiting for more than one branch could never complete.
        source_id = kwargs.get("source_node_id") or getattr(msg, "name", "") or ""
        self._received[source_id] = content

        try:
            expected_count = int(self.config.get("expected_branches", 0))
        except (TypeError, ValueError):
            expected_count = 0
        if expected_count <= 0:
            return self._merge()
        if len(self._received) >= expected_count:
            merged = self._merge()
            # Start the next round (e.g. inside a loop) without stale branch outputs.
            self._received.clear()
            return merged
        return Msg(self.name, "", "assistant")

    def _merge(self) -> Msg:
        """Combine the recorded branch outputs according to ``self.merge_type``."""
        if self.merge_type == "json":
            merged = json.dumps(self._received, ensure_ascii=False)
        elif self.merge_type == "first_non_empty":
            merged = next((v for v in self._received.values() if v.strip()), "")
        else:
            merged = "\n\n---\n\n".join(self._received.values())
        return Msg(self.name, merged, "assistant")

    async def observe(self, msg) -> None:
        pass