You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1495 lines
60 KiB

import json
import uuid
import logging
import re
import asyncio
import httpx
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import Toolkit
from agentscope.agent._react_agent import ReActAgent
from config import settings
# Module logger used by the flow engine and the node agents in this file.
logger = logging.getLogger(__name__)
async def _resolve_model_instance(model_instance_id: str) -> dict | None:
try:
from database import AsyncSessionLocal
from sqlalchemy import text
import uuid as _uuid
uid = _uuid.UUID(model_instance_id)
async with AsyncSessionLocal() as db:
result = await db.execute(
text("""
SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key
FROM model_instances mi
JOIN model_providers mp ON mi.provider_id = mp.id
WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true
"""),
{"id": uid},
)
row = result.fetchone()
if row:
return {
"model": row[0],
"base_url": row[2],
"api_key": row[3],
"params": row[1] or {},
}
except Exception:
pass
return None
class FlowSessionMemory:
    """In-process conversation history for one flow execution.

    Messages are stored as ``{"role": ..., "content": ...}`` dicts in insertion
    order; nothing is persisted.
    """

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return the last ``limit`` exchanges, i.e. at most ``2 * limit`` messages.

        A non-positive ``limit`` returns an empty list. The guard is explicit
        because the slice ``[-0:]`` would otherwise return the whole history.
        """
        if limit <= 0:
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one message to the history."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of all stored messages."""
        return list(self._messages)
class FlowEngine:
    """Execute a flow definition (nodes + edges) by depth-first traversal.

    ``flow_definition["nodes"]`` items need an ``id`` and may carry ``type``,
    ``label`` and ``config``. ``flow_definition["edges"]`` items connect
    ``source``/``from`` to ``target``/``to`` with an optional ``condition`` or
    ``sourceHandle`` label. Ordinary nodes run at most once per execution.
    ``loop`` nodes may be re-entered, and every node of their body is re-armed
    on each iteration so the body can run again.
    """

    # Upper bound on node visits per execution, guarding against runaway cycles.
    MAX_TOTAL_ITERATIONS = 200
    # Wall-clock limit for a single execute() call, in seconds.
    FLOW_TIMEOUT_SECONDS = 300

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {}
        for node in flow_definition.get("nodes", []):
            self.nodes[node["id"]] = node
        self.edges: list[dict] = flow_definition.get("edges", [])
        # Agents for non-loop nodes are created once per engine and reused.
        self._agent_cache: dict[str, AgentBase] = {}

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text of ``msg``; ``""`` when a message has no text content."""
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    @staticmethod
    def _as_int(value, default: int) -> int:
        """Coerce a config value (possibly a string from a form) to ``int``."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _collect_loop_body(graph: dict, loop_id: str, start_id: str) -> set[str]:
        """Return every node reachable from ``start_id`` without passing through ``loop_id``.

        These are the nodes that must be allowed to run again on the next
        iteration of the loop node ``loop_id``.
        """
        body: set[str] = set()
        stack = [start_id]
        while stack:
            current = stack.pop()
            if current == loop_id or current in body:
                continue
            body.add(current)
            stack.extend(target for target, _ in graph.get(current, []))
        return body

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run the flow from its first entry node and return the final message.

        Only the first entry node (a node without incoming edges) is traversed.

        Args:
            input_msg: Message handed to the entry node.
            context: Per-run mutable context shared with every node agent. A
                fresh ``FlowSessionMemory`` is stored under ``"_memory"`` and each
                node's execution record under ``"_node_results"[node_id]``.

        Returns:
            The most recently produced message (a system message on node
            failure, iteration overflow or timeout), or ``input_msg`` if no node
            produced anything.
        """
        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph)
        if not start_nodes:
            # Every node has an incoming edge (e.g. a pure cycle): start at the first one.
            start_nodes = list(self.nodes.keys())[:1]
        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        memory = FlowSessionMemory(session_id=session_id, user_id=user_id)
        context["_memory"] = memory
        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}
        total_iterations = 0
        last_result: Msg | None = None

        async def traverse(node_id: str, incoming_msg: Msg) -> None:
            nonlocal last_result, total_iterations
            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return
            node = self.nodes.get(node_id)
            if not node:
                return
            node_type = node.get("type", "")
            is_loop = node_type == "loop"
            if not is_loop:
                # Ordinary nodes run once; loop bodies are re-armed per iteration below.
                if node_id in visited:
                    return
                visited.add(node_id)
            else:
                loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                max_iter = self._as_int(node.get("config", {}).get("max_iterations", 10), 10)
                if loop_iterations[node_id] > max_iter:
                    logger.warning(f"循环节点 {node.get('label', node_id)} 超过最大迭代次数 {max_iter}")
                    visited.add(node_id)
                    for target_id, edge_cond in graph.get(node_id, []):
                        if edge_cond == "loop_done":
                            await traverse(target_id, incoming_msg)
                    return
            try:
                if is_loop:
                    loop_node_config = dict(node.get("config", {}))
                    # loop_iterations is 1-based (incremented above) while the loop
                    # agent treats _engine_iteration as 0-based and adds 1 itself,
                    # so pass a 0-based index: the first run must see item 0.
                    loop_node_config["_engine_iteration"] = loop_iterations[node_id] - 1
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                current_msg = incoming_msg
                if enriched_content.strip():
                    user_text = self._msg_text(current_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")
                result = await agent.reply(current_msg, context=context)
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node_type,
                    "label": node.get("label"),
                    "status": "success",
                    "output": self._msg_text(result)[:500],
                }
                last_result = result
                if is_loop:
                    if self._check_loop_done(result, node, loop_iterations[node_id]):
                        visited.add(node_id)
                        for target_id, edge_cond in graph.get(node_id, []):
                            if edge_cond == "loop_done":
                                await traverse(target_id, result)
                    else:
                        body_starts = [
                            target_id
                            for target_id, edge_cond in graph.get(node_id, [])
                            if edge_cond == "loop_body" and self.nodes.get(target_id)
                        ]
                        # Re-arm the whole body (not just its first node) so that
                        # multi-node bodies run again and reach the edge back here.
                        for start_id in body_starts:
                            visited.difference_update(self._collect_loop_body(graph, node_id, start_id))
                        for start_id in body_starts:
                            await traverse(start_id, result)
                    return
                is_condition = node_type == "condition"
                cond_result = self._parse_condition_result(result) if is_condition else None
                for target_id, edge_cond in graph.get(node_id, []):
                    if is_condition:
                        # Condition nodes follow only the edge labelled with their outcome.
                        if edge_cond and edge_cond == cond_result:
                            await traverse(target_id, result)
                    elif edge_cond in ("loop_body", "loop_done"):
                        # Loop-control handles only make sense on loop nodes.
                        continue
                    else:
                        await traverse(target_id, result)
            except Exception as e:
                logger.error(f"节点 {node.get('label', node_id)} 执行失败: {e}")
                context.setdefault("_node_results", {})[node_id] = {
                    "node_id": node_id,
                    "node_type": node_type,
                    "label": node.get("label"),
                    "status": "error",
                    "error": str(e),
                }
                last_result = Msg(name="system", content=f"[节点 {node.get('label', node_id)} 执行失败: {e}]", role="system")

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")
        return last_result or input_msg

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build an adjacency list ``node_id -> [(target_id, edge_condition)]``.

        Edges referencing unknown nodes are dropped. A ``"source"`` handle
        label is treated as an unconditional edge (condition ``None``).
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return the ids of nodes that have no incoming edge, in definition order."""
        target_nodes: set[str] = set()
        for targets in graph.values():
            for target_id, _ in targets:
                target_nodes.add(target_id)
        return [nid for nid in self.nodes if nid not in target_nodes]

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Extract ``"true"``/``"false"`` from a ``condition:<bool>`` marker, else ``None``."""
        m = re.search(r'condition:(true|false)', self._msg_text(result))
        if m:
            return m.group(1)
        return None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node should stop and follow its ``loop_done`` edges.

        ``iteration`` is the 1-based number of times the loop node has run. A
        ``fixed`` loop (the default) runs its body ``count`` times (default 3),
        so it is done on the following run, when ``iteration > count``. Any
        other ``loop_type`` is done once the loop agent's output contains
        ``loop:stop``, ``loop:done`` or ``loop:break`` (case-insensitive).
        """
        config = node.get("config", {})
        loop_type = config.get("loop_type", "fixed")
        if loop_type == "fixed":
            count = self._as_int(config.get("count", 3), 3)
            return iteration > count
        return bool(re.search(r'loop:(stop|done|break)', self._msg_text(result), re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict | None = None) -> AgentBase:
        """Return the agent for ``node_id``.

        Agents built from the node's own config are cached. An
        ``override_config`` (used for loop nodes) always builds a fresh,
        uncached agent.
        """
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]
        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render ``config["input_mapping"]`` as ``key: value`` lines.

        Values containing ``{{`` are expanded with ``_resolve_template``. Returns
        ``""`` when the node has no mapping.
        """
        config = node.get("config", {})
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""
        resolved = {}
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            resolved[key] = str(value)
        return "\n".join([f"{k}: {v}" for k, v in resolved.items()])
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
node_type = node.get("type", "")
node_id = node.get("id", "")
config = node.get("config", {})
if node_type == "trigger":
return PassThroughAgent(node_id)
elif node_type == "llm":
model_config = config.get("model", settings.LLM_MODEL)
model_instance_id = config.get("model_instance_id")
base_url = settings.LLM_API_BASE
api_key = settings.LLM_API_KEY
if model_instance_id:
resolved = await _resolve_model_instance(model_instance_id)
if resolved:
model_config = resolved.get("model", model_config)
base_url = resolved.get("base_url", base_url)
api_key = resolved.get("api_key", api_key)
temperature = config.get("temperature", 0.7)
system_prompt = config.get("system_prompt", "你是AI助手。")
max_tokens = config.get("max_tokens", 2000)
stream = config.get("stream", True)
stream_cb = context.get("_stream_callback")
agent = LLMNodeAgent(
node_id=node_id,
system_prompt=system_prompt,
model_name=model_config,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
stream_callback=stream_cb,
base_url=base_url,
api_key=api_key,
)
memory = context.get("_memory")
if memory:
agent.set_memory(memory)
return agent
elif node_type == "tool":
tool_name = config.get("tool_name", "")
tool_params = config.get("tool_params", {})
timeout = config.get("timeout", 30)
retry_count = config.get("retry_count", 0)
error_handling = config.get("error_handling", "throw")
return ToolNodeAgent(
node_id=node_id,
tool_name=tool_name,
tool_params=tool_params,
timeout=timeout,
retry_count=retry_count,
error_handling=error_handling,
)
elif node_type == "mcp":
mcp_server = config.get("mcp_server", "")
tool_name = config.get("tool_name", "")
timeout = config.get("timeout", 30)
error_handling = config.get("error_handling", "throw")
return MCPNodeAgent(
node_id=node_id,
server_name=mcp_server,
tool_name=tool_name,
timeout=timeout,
error_handling=error_handling,
)
elif node_type in ("wecom_notify", "notify"):
return NotifyAgent(node_id=node_id, config=config)
elif node_type == "condition":
condition_expr = config.get("condition", "")
condition_type = config.get("condition_type", "expression")
return ConditionNodeAgent(node_id=node_id, condition=condition_expr, condition_type=condition_type)
elif node_type == "rag":
model_instance_id = config.get("model_instance_id")
if model_instance_id:
resolved = await _resolve_model_instance(model_instance_id)
if resolved:
config = dict(config)
config["_resolved_model"] = resolved
return RAGNodeAgent(node_id=node_id, config=config)
elif node_type == "output":
return OutputNodeAgent(node_id=node_id, config=config)
elif node_type == "merge":
return ParallelMergeNodeAgent(node_id=node_id, config=config)
elif node_type == "loop":
return LoopNodeAgent(node_id=node_id, config=config)
elif node_type == "code":
language = config.get("language", "python")
code = config.get("code", "")
timeout = config.get("timeout", 30)
sandbox = config.get("sandbox", True)
return CodeNodeAgent(node_id=node_id, language=language, code=code, timeout=timeout, sandbox=sandbox)
elif node_type == "http_request":
return HttpRequestNodeAgent(node_id=node_id, config=config)
elif node_type == "question_classifier":
return QuestionClassifierNodeAgent(node_id=node_id, config=config)
elif node_type == "variable_assigner":
return VariableAssignerNodeAgent(node_id=node_id, config=config, context=context)
elif node_type == "template_transform":
return TemplateTransformNodeAgent(node_id=node_id, config=config)
elif node_type == "iteration":
return IterationNodeAgent(node_id=node_id, config=config)
elif node_type == "question_optimiser":
return QuestionOptimiserNodeAgent(node_id=node_id, config=config)
else:
return PassThroughAgent(node_id)
class PassThroughAgent(AgentBase):
    """Agent that forwards its input unchanged (used for triggers and unknown node types)."""

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        """Return ``msg`` itself when it is a ``Msg``; otherwise wrap its ``str()`` form."""
        if isinstance(msg, Msg):
            return msg
        return Msg(self.name, str(msg), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages."""
        return None
class HttpRequestNodeAgent(AgentBase):
    """Flow node that performs one HTTP request described by its config.

    Config keys:

    * ``method`` (default ``GET``) and ``url``.
    * ``headers``: a dict or a JSON string.
    * ``body``: a dict/list, a JSON string or raw text. It is sent only for
      POST/PUT/PATCH.
    * ``auth_type`` (``none``/``bearer``/``api_key``/``basic``) with ``auth_config``.
    * ``timeout`` seconds and ``retry_count``.

    The incoming message is not used. The reply is the JSON string
    ``{"status_code", "body"}``. If every attempt raises, it is
    ``{"status_code": 0, "error", "body": null}``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"HttpRequest_{node_id}"
        self.config = config or {}

    def _build_headers(self) -> dict:
        """Merge the configured headers with the configured authentication header."""
        headers = self.config.get("headers") or {}
        if isinstance(headers, str):
            try:
                headers = json.loads(headers)
            except (json.JSONDecodeError, ValueError):
                headers = {}
        if not isinstance(headers, dict):
            headers = {}
        # httpx accepts only str/bytes header values; JSON configs may carry numbers.
        request_headers = {
            str(k): v if isinstance(v, (str, bytes)) else str(v)
            for k, v in headers.items()
            if v is not None
        }
        auth_type = self.config.get("auth_type", "none")
        auth_config = self.config.get("auth_config") or {}
        if auth_type == "bearer" and auth_config.get("token"):
            request_headers["Authorization"] = f"Bearer {auth_config['token']}"
        elif auth_type == "api_key" and auth_config.get("api_key"):
            key_name = auth_config.get("key_name", "X-API-Key")
            request_headers[key_name] = auth_config["api_key"]
        elif auth_type == "basic" and auth_config.get("username"):
            import base64
            credentials = f"{auth_config['username']}:{auth_config.get('password', '')}"
            request_headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}"
        return request_headers

    @staticmethod
    def _body_kwargs(method: str, body) -> dict:
        """Return the httpx body argument (``json`` or ``content``) for ``method``."""
        if method not in ("POST", "PUT", "PATCH"):
            return {}
        if isinstance(body, (dict, list)):
            return {"json": body}
        if body and isinstance(body, str):
            try:
                return {"json": json.loads(body)}
            except (json.JSONDecodeError, ValueError):
                # Not JSON: send the text verbatim.
                return {"content": body}
        return {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        method = (self.config.get("method") or "GET").upper()
        url = self.config.get("url", "")
        timeout = self.config.get("timeout", 30)
        retry_count = self.config.get("retry_count", 0)
        request_kwargs = {
            "headers": self._build_headers(),
            **self._body_kwargs(method, self.config.get("body", "")),
        }
        last_error = None
        for attempt in range(retry_count + 1):
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
                    response = await client.request(method, url, **request_kwargs)
                    resp_status = response.status_code
                    resp_text = response.text
                try:
                    resp_body = json.loads(resp_text)
                except (json.JSONDecodeError, ValueError):
                    resp_body = resp_text
                result = json.dumps({
                    "status_code": resp_status,
                    "body": resp_body,
                }, ensure_ascii=False)
                return Msg(self.name, result, "assistant")
            except Exception as e:
                # Some httpx errors have an empty message; keep at least the type name.
                last_error = str(e) or type(e).__name__
                if attempt < retry_count:
                    logger.warning(f"HTTP 请求失败,重试 {attempt + 1}/{retry_count}: {last_error}")
                    await asyncio.sleep(1)
        error_result = json.dumps({
            "status_code": 0,
            "error": last_error or "unknown error",
            "body": None,
        }, ensure_ascii=False)
        return Msg(self.name, error_result, "assistant")

    async def observe(self, msg) -> None:
        pass
class QuestionClassifierNodeAgent(AgentBase):
    """Flow node that asks the LLM to classify the incoming text into a configured category.

    Config keys:

    * ``categories``: a list of ``{"name", "description"}`` dicts. Plain strings
      are accepted as names.
    * Optional ``instruction``, ``model`` and ``temperature``.

    The reply is a JSON string. On success it is the model's parsed answer.
    Otherwise it is a fallback with ``category`` set to ``"default"``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Classifier_{node_id}"
        self.config = config or {}

    @staticmethod
    def _describe_categories(categories) -> str:
        """Render one ``name: description`` line per category."""
        lines = []
        for c in categories:
            if isinstance(c, dict):
                lines.append(f'{c.get("name")}: {c.get("description", "")}')
            else:
                lines.append(f"{c}: ")
        return "\n".join(lines)

    @staticmethod
    def _parse_classification(text: str):
        """Parse the model's JSON answer, tolerating code fences or surrounding prose.

        Returns ``None`` when no JSON value can be extracted.
        """
        try:
            return json.loads(text)
        except (json.JSONDecodeError, ValueError):
            pass
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group())
            except (json.JSONDecodeError, ValueError):
                pass
        return None

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        categories = self.config.get("categories") or []
        instruction = self.config.get("instruction", "")
        model_name = self.config.get("model", settings.LLM_MODEL)
        temperature = self.config.get("temperature", 0.3)
        user_text = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        if not categories:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant")
        category_desc = self._describe_categories(categories)
        prompt = f"""请对以下用户输入进行意图分类。
分类选项:
{category_desc}
{instruction}
用户输入:{user_text}
请只返回一个JSON对象,格式为:{{"category": "分类名称", "confidence": 0.0~1.0}}"""
        try:
            import openai
            # The context manager closes the client's HTTP connections when done.
            async with openai.AsyncOpenAI(
                base_url=settings.LLM_API_BASE,
                api_key=settings.LLM_API_KEY,
            ) as client:
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=temperature,
                    max_tokens=200,
                )
            result_text = (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.error(f"QuestionClassifier error: {e}")
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}, ensure_ascii=False), "assistant")
        parsed = self._parse_classification(result_text)
        if parsed is None:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}, ensure_ascii=False), "assistant")
        return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")

    async def observe(self, msg) -> None:
        pass
class VariableAssignerNodeAgent(AgentBase):
def __init__(self, node_id: str, config: dict = None, context: dict = None):
super().__init__()
self.name = f"VarAssign_{node_id}"
self.config = config or {}
self._context = context or {}
async def reply(self, msg: Msg, **kwargs) -> Msg:
assignments = self.config.get("assignments", [])
results = {}
for assignment in assignments:
target_var = assignment.get("target_var", "")
source_type = assignment.get("source_type", "constant")
source_value = assignment.get("source_value", "")
if source_type == "constant":
value = source_value
elif source_type == "upstream_output":
value = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
elif source_type == "template":
value = _resolve_template(source_value, self._context, msg)
elif source_type == "expression":
try:
safe_locals = {"msg": msg, "context": self._context}
value = eval(source_value, {"__builtins__": {}}, safe_locals)
value = str(value)
except Exception as e:
value = f"[expression error: {e}]"
else:
value = source_value
results[target_var] = value
output = json.dumps(results, ensure_ascii=False)
return Msg(self.name, output, "assistant")
async def observe(self, msg) -> None:
pass
class LLMNodeAgent(AgentBase):
    """Flow node that answers via an OpenAI-compatible ``/chat/completions`` endpoint.

    The request holds the system prompt, optional conversation history and
    the incoming message text. History comes from ``context["_memory_context"]``
    (``summary`` plus ``recent_messages``) when the caller provides it.
    Otherwise it comes from the attached session memory, which is then updated
    with this turn.

    When ``stream`` is true and a stream callback is set, tokens are forwarded
    to the callback as ``("text_chunk", {"content": token})`` as they arrive.
    Call failures are returned as bracketed error text instead of being raised.
    """

    # Maximum number of characters kept from each history message.
    HISTORY_CONTENT_LIMIT = 2000

    def __init__(
        self,
        node_id: str,
        system_prompt: str,
        model_name: str = "",
        temperature: float = 0.7,
        max_tokens: int = 2000,
        stream: bool = True,
        stream_callback=None,
        base_url: str = "",
        api_key: str = "",
    ):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.stream_callback = stream_callback
        self.base_url = base_url or settings.LLM_API_BASE
        self.api_key = api_key or settings.LLM_API_KEY
        self._memory = None

    def set_memory(self, memory):
        """Attach a session memory providing ``get_history``/``add`` (e.g. ``FlowSessionMemory``)."""
        self._memory = memory

    def _build_messages(self, user_text: str, memory_ctx: dict) -> list[dict]:
        """Assemble the chat messages: system prompt, history, then the user turn."""
        messages = [{"role": "system", "content": self.system_prompt}]
        if memory_ctx:
            summary = memory_ctx.get("summary", "")
            if summary:
                messages.append({"role": "system", "content": f"[历史对话摘要]\n{summary}"})
            history = (memory_ctx.get("recent_messages") or [])[-10:]
        elif self._memory:
            history = self._memory.get_history(limit=5)
        else:
            history = []
        for item in history:
            content = item.get("content", "") or ""
            messages.append({
                "role": item.get("role", "user"),
                "content": content[:self.HISTORY_CONTENT_LIMIT],
            })
        messages.append({"role": "user", "content": user_text})
        return messages

    def _request_parts(self, messages: list[dict], stream: bool) -> tuple[str, dict, dict]:
        """Return ``(url, headers, json_body)`` for a chat-completions request."""
        url = f"{self.base_url.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model_name or settings.LLM_MODEL,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }
        return url, headers, body

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        context = kwargs.get("context") or {}
        memory_ctx = context.get("_memory_context") or {}
        messages = self._build_messages(user_text, memory_ctx)
        try:
            # Honour the node's "stream" setting; a callback alone is not enough.
            if self.stream and self.stream_callback:
                res_text = await self._stream_llm_call(messages)
            else:
                res_text = await self._blocking_llm_call(messages)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"
        # An explicit memory context is managed by the caller; only record into
        # the attached session memory when it was the history source.
        if self._memory and not memory_ctx:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)
        return Msg(self.name, res_text, "assistant")

    async def _blocking_llm_call(self, messages: list[dict]) -> str:
        """Perform a non-streaming completion and return the answer text."""
        url, headers, body = self._request_parts(messages, stream=False)
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
                resp = await client.post(url, json=body, headers=headers)
                # Surface HTTP errors instead of silently returning an empty answer.
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            return (choices[0].get("message") or {}).get("content") or ""
        except Exception as e:
            logger.warning(f"LLM 阻塞调用失败: {e}")
            return f"[LLM 调用失败: {e}]"

    async def _stream_llm_call(self, messages: list[dict]) -> str:
        """Perform a streaming (SSE) completion, forwarding tokens to the callback.

        Returns the accumulated text. On timeout or error, returns whatever was
        received so far, or a bracketed error marker if nothing arrived.
        """
        url, headers, body = self._request_parts(messages, stream=True)
        accumulated = ""
        try:
            timeout = httpx.Timeout(60.0, connect=10.0)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("POST", url, json=body, headers=headers) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # SSE allows an optional space after "data:".
                        if not line.startswith("data:"):
                            continue
                        payload = line[5:].strip()
                        if payload == "[DONE]":
                            break
                        try:
                            chunk = json.loads(payload)
                        except json.JSONDecodeError:
                            continue
                        # Usage-only chunks may carry an empty "choices" list.
                        choices = chunk.get("choices") or [{}]
                        token = (choices[0].get("delta") or {}).get("content")
                        if token:
                            accumulated += token
                            await self.stream_callback("text_chunk", {"content": token})
        except httpx.TimeoutException:
            logger.warning("LLM 流式调用超时")
            if not accumulated:
                accumulated = "[LLM 超时]"
        except Exception as e:
            logger.warning(f"LLM 流式调用失败: {e}")
            if not accumulated:
                accumulated = f"[LLM 调用失败: {e}]"
        return accumulated

    async def observe(self, msg) -> None:
        pass
class ToolNodeAgent(AgentBase):
    """Flow node that runs a built-in Python tool or a custom HTTP tool by name.

    Built-in tools are imported lazily into ``_TOOL_REGISTRY``. Custom tools
    come from ``load_custom_tools`` / ``register_custom_tool`` and are looked up
    first, so they shadow built-ins with the same name.

    ``error_handling`` selects the failure reply:

    * ``"skip"`` returns a "skipped" note.
    * ``"default"`` returns ``"{}"`` for execution errors.
    * Anything else (``"throw"``) returns an error message.
    """

    # Built-in tool name -> callable, filled by _init_registry().
    _TOOL_REGISTRY: dict[str, callable] = {}
    # Tool name -> schema, for built-in and custom tools.
    _TOOL_SCHEMAS: dict[str, dict] = {}
    # Custom tool name -> request definition consumed by CustomToolExecutor.
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}

    @classmethod
    def _init_registry(cls):
        """Populate the built-in registry and schemas once.

        While the imports keep failing, the registry stays empty and the
        population is retried on the next call.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard
            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }
            for mod_name in ["task_tools", "manager_tools", "wecom_tools", "document_tools"]:
                try:
                    mod = __import__(f"agentscope_integration.tools.{mod_name}", fromlist=["SCHEMAS"])
                    if hasattr(mod, "SCHEMAS"):
                        cls._TOOL_SCHEMAS.update(mod.SCHEMAS)
                except Exception as e:
                    # Schemas are optional metadata; a broken one must not block tool use.
                    logger.debug(f"工具 schema 加载失败 ({mod_name}): {e}")
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Load all active ``CustomTool`` rows from ``db`` into the custom tool tables."""
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Not used directly here; presumably imported so a missing executor
            # module aborts loading up front (ImportError branch below).
            from modules.custom_tool.executor import CustomToolExecutor
            result = await db.execute(
                select(CustomTool).where(CustomTool.is_active == True)
            )
            for tool in result.scalars().all():
                tool_def = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._CUSTOM_TOOL_DEFS[tool.name] = tool_def
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas, initialising built-ins first."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30, retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        self._init_registry()
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)
        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        def _call_sync():
            # Call with the configured params; on a TypeError (assumed signature
            # mismatch) retry with the input text as first positional argument.
            # NOTE(review): a TypeError raised inside the tool also triggers this
            # second call, which may repeat side effects.
            try:
                result = tool_func(**self.tool_params) if self.tool_params else tool_func()
                return result
            except TypeError:
                try:
                    result = tool_func(user_text, **self.tool_params)
                    return result
                except Exception as e:
                    raise RuntimeError(f"工具执行失败: {e}")
            except Exception as e:
                raise RuntimeError(f"工具执行失败: {e}")

        for attempt in range(self.retry_count + 1):
            try:
                # NOTE: wait_for stops waiting on timeout but cannot stop the worker
                # thread, so a retry may overlap a still-running previous attempt.
                result = await asyncio.wait_for(
                    asyncio.to_thread(_call_sync),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 超时,已跳过]", "assistant")
                return Msg(self.name, f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except RuntimeError as e:
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 执行失败,重试 {attempt + 1}/{self.retry_count}: {e}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[{e}]", "assistant")
            except Exception as e:
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 执行失败,重试 {attempt + 1}/{self.retry_count}: {e}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                # Same "default" contract as the RuntimeError branch and custom tools.
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[工具执行失败: {e}]", "assistant")
        return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Run a custom HTTP tool through ``CustomToolExecutor`` with timeout and retries."""
        from modules.custom_tool.executor import CustomToolExecutor
        executor = CustomToolExecutor(custom_def)
        for attempt in range(self.retry_count + 1):
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 超时,已跳过]", "assistant")
                return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except Exception as e:
                if attempt < self.retry_count:
                    logger.warning(f"自定义工具 {self.tool_name} 执行失败,重试 {attempt + 1}/{self.retry_count}: {e}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[自定义工具执行失败: {e}]", "assistant")
        return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant")

    async def observe(self, msg) -> None:
        pass
class ConditionNodeAgent(AgentBase):
    """Flow node that evaluates a branch condition against the incoming text.

    The reply text is ``condition:true`` or ``condition:false``, then ``|`` and
    a human-readable reason. ``condition_type`` selects the evaluator:

    * ``regex``: ``re.search`` of the condition pattern in the text.
    * ``json_path``: dotted ``$.a.b`` lookup in the text parsed as a JSON
      object; true when the value found is truthy.
    * anything else (``expression``): the LLM judges the condition.

    An empty condition and evaluation errors default to ``condition:true``.
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")
        if self.condition_type == "regex":
            return self._eval_regex(user_text)
        if self.condition_type == "json_path":
            return self._eval_json_path(user_text)
        return await self._eval_with_llm(user_text)

    def _eval_regex(self, user_text: str) -> Msg:
        """Match the condition as a regular expression; invalid patterns pass."""
        try:
            matched = bool(re.search(self.condition, user_text))
        except re.error as e:
            return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")
        return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命中' if matched else '未命中'}", "assistant")

    def _eval_json_path(self, user_text: str) -> Msg:
        """Walk a dotted ``$.a.b`` path through the text parsed as a JSON object."""
        try:
            data = json.loads(user_text) if user_text.strip().startswith('{') else {}
        except (json.JSONDecodeError, ValueError):
            return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")
        val = data
        for part in str(self.condition).strip('$.').split('.'):
            val = val.get(part, None) if isinstance(val, dict) else None
        return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")

    async def _eval_with_llm(self, user_text: str) -> Msg:
        """Ask the configured OpenAI-compatible chat endpoint to judge the condition."""
        condition_prompt = f"""你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。
条件表达式: {self.condition}
输入内容:
{user_text[:2000]}
请严格只输出一行 JSON:
{{"result": true/false, "reason": "简要原因"}}"""
        try:
            url = f"{settings.LLM_API_BASE.rstrip('/')}/chat/completions"
            headers = {
                "Authorization": f"Bearer {settings.LLM_API_KEY}",
                "Content-Type": "application/json",
            }
            body = {
                "model": settings.LLM_MODEL,
                "messages": [
                    {"role": "system", "content": condition_prompt},
                    {"role": "user", "content": user_text[:2000]},
                ],
                "stream": False,
            }
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
                resp = await client.post(url, json=body, headers=headers)
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            res_text = (choices[0].get("message") or {}).get("content") or ""
            json_match = re.search(r'\{[^}]+\}', res_text)
            if json_match:
                parsed = json.loads(json_match.group())
                matched = parsed.get("result", False)
                # A string "false" would otherwise be truthy.
                if isinstance(matched, str):
                    matched = matched.strip().lower() == "true"
                reason = parsed.get("reason", "")
                result_flag = "true" if matched else "false"
                return Msg(self.name, f"condition:{result_flag}|{reason}", "assistant")
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {str(self.condition)[:80]}", "assistant")

    async def observe(self, msg) -> None:
        pass
class MCPNodeAgent(AgentBase):
    """Flow node that calls one tool on a named MCP server via ``MCPClientManager``.

    The incoming text is passed to the tool as ``{"input": <text>}``. With
    ``error_handling="skip"``, failures yield a short "skipped" note.
    Otherwise the failure reason is returned as message text. Failures are no
    longer reported as a successful call.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30, error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    def _failure(self, message: str) -> Msg:
        """Build a failure reply, honouring ``error_handling="skip"``."""
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, message, "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        if not self.tool_name:
            return self._failure(f"[MCP] 服务 {self.server_name} 未指定工具名称")
        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return self._failure(f"[MCP] 服务 {self.server_name} 不可用: MCP 客户端未安装")
        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return self._failure(f"[MCP] 服务 {self.server_name} 未找到可用客户端")
            result = await asyncio.wait_for(
                client.call_tool(self.tool_name, {"input": user_text}),
                timeout=self.timeout
            )
            return Msg(self.name, str(result), "assistant")
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            return self._failure(f"[MCP] {self.server_name} 调用失败: {e}")

    async def observe(self, msg) -> None:
        pass
class NotifyAgent(AgentBase):
    """Flow node that sends the incoming text (or a configured template) as a notification.

    ``config["channels"]`` selects the channels; by default WeCom is on and
    Web is off.

    * ``wecom``: ``send_notification`` to ``config["target"]``, using
      ``message_template`` or the first 500 characters of the input.
    * ``web``: pushes a ``flow_notification`` event to ``config["target"]``
      through ``ws_manager``, using ``web_template`` or the input.

    The reply summarises the outcome of each channel that was attempted.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def _send_wecom(self, target: str, message: str) -> str | None:
        """Send via WeCom; return a summary line, or None if the tool is unavailable."""
        try:
            from agentscope_integration.tools.wecom_tools import send_notification
        except ImportError:
            return None
        try:
            # send_notification is called synchronously elsewhere in this file;
            # run it in a worker thread so it does not block the event loop.
            result = await asyncio.to_thread(send_notification, to_user=target or "user", message=message)
        except Exception as e:
            logger.warning(f"企微通知发送失败: {e}")
            return f"企微通知: 发送失败({e})"
        return f"企微通知: {str(result)[:100]}"

    async def _push_web(self, target_user: str, web_message: str) -> str | None:
        """Push over WebSocket; return a summary line, or None if unavailable."""
        try:
            from websocket_manager import ws_manager
        except ImportError:
            return None
        if not target_user:
            return "Web通知: 未指定接收用户,未推送"
        try:
            await ws_manager.send_to_user(target_user, {
                "type": "flow_notification",
                "message": web_message,
                "level": "info",
            })
        except Exception as e:
            logger.warning(f"Web通知推送失败: {e}")
            return f"Web通知: 推送失败({e})"
        return "Web通知: 已推送"

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)) or ""
        channels = self.config.get("channels") or {"wecom": True, "web": False}
        target = self.config.get("target", "")
        results = []
        if channels.get("wecom", True):
            message = self.config.get("message_template", "") or user_text[:500]
            outcome = await self._send_wecom(target, message)
            if outcome:
                results.append(outcome)
        if channels.get("web", False):
            web_message = self.config.get("web_template", "") or user_text[:500]
            outcome = await self._push_web(target, web_message)
            if outcome:
                results.append(outcome)
        if not results:
            results.append("通知已发送")
        return Msg(self.name, " | ".join(results), "assistant")

    async def observe(self, msg) -> None:
        pass
class LoopNodeAgent(AgentBase):
    """Flow node that drives a loop one pass per call.

    Emits ``"loop:stop|..."`` when the loop is finished; otherwise emits the
    current iteration's variable binding plus a preview of the input.
    ``loop_type`` is ``"fixed"`` (``count`` passes), ``"array"`` (one pass per
    element of ``items``) or anything else (never stops on its own).
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    @staticmethod
    def _parse_items(raw) -> list:
        """Normalise ``items`` config: a list, a JSON array string, or newline-separated text."""
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:  # includes json.JSONDecodeError
                raw = [line.strip() for line in raw.split("\n") if line.strip()]
        return raw if isinstance(raw, list) else []

    @staticmethod
    def _coerce_count(raw, default: int = 3) -> int:
        """Return *raw* as an int (config values may arrive as strings), else *default*."""
        try:
            return int(raw)
        except (TypeError, ValueError):
            return default

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""
        loop_type = self.config.get("loop_type", "fixed")
        iterator_variable = self.config.get("iterator_variable", "item")
        # 1-based number of the current pass; assumes the engine starts
        # ``_engine_iteration`` at 0 (the array branch relies on the same) — TODO confirm.
        iteration = self.config.get("_engine_iteration", 0) + 1

        if loop_type == "array":
            items = self._parse_items(self.config.get("items", []))
            if not items:
                return Msg(self.name, "loop:stop|无数组数据", "assistant")
            idx = iteration - 1
            if idx >= len(items):
                return Msg(self.name, "loop:stop|数组迭代完成", "assistant")
            current_item = items[idx]
            return Msg(self.name, f"[迭代 {iteration}/{len(items)}] 变量: {iterator_variable}={current_item}\n输入: {user_text[:200]}", "assistant")

        if loop_type == "fixed":
            count = self._coerce_count(self.config.get("count", 3))
            # Stop only once ``count`` passes have run (was ``>=``, which ran count-1
            # passes and none at all for count=1, unlike the array branch).
            if iteration > count:
                return Msg(self.name, f"loop:stop|固定次数循环完成: {iteration - 1}/{count}", "assistant")
            return Msg(self.name, f"[循环 {iteration}/{count}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

        return Msg(self.name, f"[循环 {iteration}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

    async def observe(self, msg) -> None:
        pass
class CodeNodeAgent(AgentBase):
    """Flow node that runs user-supplied Python code in a child interpreter.

    The incoming text is exposed to the script as the global ``INPUT_TEXT``;
    whatever the script prints to stdout becomes the node output.
    """

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30, sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        # NOTE(review): ``language`` is not consulted; the code is always run as Python.
        self.language = language
        self.code = code
        self.timeout = timeout
        # NOTE(security): ``sandbox`` is stored but not enforced by this class — the
        # script runs as an ordinary child process with the service's privileges.
        self.sandbox = sandbox

    @staticmethod
    def _build_script(code: str, user_text) -> str:
        """Prefix *code* with a preamble that defines ``INPUT_TEXT``.

        ``json.dumps`` yields JSON text (quotes included); ``!r`` embeds that text
        as a Python string literal, so ``json.loads`` decodes it back to exactly
        *user_text*. (Without ``!r`` the quotes were consumed by Python and
        ``json.loads`` received bare text, failing for ordinary input.)
        """
        safe_input = json.dumps(user_text)
        return f"import json\nINPUT_TEXT = json.loads({safe_input!r})\n\n{code}"

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        import contextlib
        import os
        import sys
        import tempfile

        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if not (self.code or "").strip():
            return Msg(self.name, user_text, "assistant")
        script = self._build_script(self.code, user_text)
        temp_path = None
        proc = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
                temp_path = f.name  # record first so the finally block always cleans up
                f.write(script)
            proc = await asyncio.create_subprocess_exec(
                sys.executable or "python", temp_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                # Output is decoded as UTF-8 below, so make the child emit UTF-8.
                env={**os.environ, "PYTHONIOENCODING": "utf-8"},
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                with contextlib.suppress(ProcessLookupError):
                    proc.kill()
                await proc.wait()  # reap the killed child
                return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")
            # Judge success by exit status: warnings on stderr must not fail a good run.
            if proc.returncode != 0:
                err_text = stderr.decode('utf-8', errors='replace').strip()
                # Tracebacks end with the actual error, so keep the tail.
                detail = err_text[-500:] if err_text else f"进程退出码 {proc.returncode}"
                return Msg(self.name, f"[错误] {detail}", "assistant")
            output = stdout.decode('utf-8', errors='replace').strip()
            return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            # On cancellation (e.g. a flow-level timeout) do not leave the child running.
            if proc is not None and proc.returncode is None:
                with contextlib.suppress(ProcessLookupError):
                    proc.kill()
            if temp_path is not None:
                with contextlib.suppress(OSError):
                    os.unlink(temp_path)

    async def observe(self, msg) -> None:
        pass
class RAGNodeAgent(AgentBase):
    """Flow node that answers the incoming question from knowledge-base hits.

    Retrieval goes through ``modules.rag.knowledge.retrieve_for_agent``; the hits
    are then summarised by an OpenAI-compatible chat model. If generation fails,
    the raw hits are returned; if retrieval fails, a placeholder is returned.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""
        top_k = self.config.get("top_k", 5)
        # NOTE(review): config may also carry similarity_threshold / search_mode /
        # include_metadata, but retrieve_for_agent is only given ``limit`` here.
        try:
            from modules.rag.knowledge import retrieve_for_agent
            kb_result = await retrieve_for_agent(user_text, limit=top_k)
        except Exception as e:
            logger.warning("RAG 节点执行失败: %s", e)
            output = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"
            return Msg(self.name, output, "assistant")
        try:
            answer = await self._generate(user_text, kb_result)
            return Msg(self.name, answer, "assistant")
        except Exception as e:
            logger.warning("RAG 节点执行失败: %s", e)
            # Retrieval succeeded: surface the hits instead of discarding them.
            return Msg(self.name, f"[RAG检索] 知识库检索结果:\n{kb_result}", "assistant")

    async def _generate(self, user_text: str, kb_result) -> str:
        """Ask the chat model to answer *user_text* grounded in *kb_result*."""
        import inspect

        model = self._get_model()
        formatter = self._get_formatter()
        rag_prompt = (
            "你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。\n"
            "知识库检索结果:\n"
            f"{kb_result}\n"
            f"用户问题: {user_text}\n"
            "请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"
        )
        prompt = formatter.format([
            Msg("system", rag_prompt, "system"),
            Msg("user", user_text, "user"),
        ])
        # agentscope 1.x formatters are async; stay compatible with a sync one.
        if inspect.isawaitable(prompt):
            prompt = await prompt
        res = await model(prompt)
        return self._extract_text(res)

    @staticmethod
    def _extract_text(res) -> str:
        """Pull plain text out of a model response (Msg, ChatResponse, or list of either)."""
        if isinstance(res, list):
            if not res:
                return ""
            res = res[0]
        if hasattr(res, 'get_text_content'):
            return res.get_text_content() or ""
        content = getattr(res, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # agentscope 1.x ChatResponse.content is a list of typed blocks (dicts).
            return "".join(
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        return str(res)

    def _get_model(self):
        from agentscope.model import OpenAIChatModel
        resolved = self.config.get("_resolved_model") or {}
        # agentscope 1.x signature: no config_name/api_base; the endpoint goes through
        # client_args (forwarded to the OpenAI client). stream=False so that awaiting
        # the call yields a single response rather than a stream.
        return OpenAIChatModel(
            model_name=resolved.get("model", settings.LLM_MODEL),
            api_key=resolved.get("api_key", settings.LLM_API_KEY),
            stream=False,
            client_args={"base_url": resolved.get("base_url", settings.LLM_API_BASE)},
        )

    def _get_formatter(self):
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        pass
class OutputNodeAgent(AgentBase):
    """Terminal flow node that formats the final answer.

    Config keys: ``format`` (``"json"`` pretty-prints JSON content),
    ``output_template`` (``{{output}}`` is replaced by the content),
    ``truncate``/``max_length`` and ``indent``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        output_format = self.config.get("format", "text")
        output_template = self.config.get("output_template", "")
        truncate = self.config.get("truncate", False)
        max_length = self.config.get("max_length", 2000)
        content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        content = content or ""
        truncated = False
        if truncate and len(content) > max_length:
            # NOTE: truncation happens first, so truncated JSON will not parse below.
            content = content[:max_length] + "\n\n[内容已截断]"
            truncated = True
        if output_format == "json":
            try:
                parsed = json.loads(content)
                indent = self.config.get("indent", 2)
                formatted = json.dumps(parsed, indent=indent, ensure_ascii=False)
                return Msg(self.name, formatted, "assistant")
            except (json.JSONDecodeError, ValueError):
                pass
        if output_template:
            try:
                return Msg(self.name, output_template.replace("{{output}}", content), "assistant")
            except (AttributeError, TypeError):  # non-string template in config
                pass
        # Returning the untouched ``msg`` here used to discard the truncation.
        if truncated or not isinstance(msg, Msg):
            return Msg(self.name, content, "assistant")
        return msg

    async def observe(self, msg) -> None:
        pass
def _resolve_template(template: str, context: dict, current_msg: Msg) -> str:
result = template
import re
placeholders = re.findall(r'\{\{(.+?)\}\}', template)
for placeholder in placeholders:
parts = placeholder.strip().split(".")
value = ""
if parts[0] == "trigger":
value = str(context.get("trigger_data", {}).get(".".join(parts[1:]), ""))
elif parts[0] in context.get("_node_results", {}):
node_result = context.get("_node_results", {}).get(parts[0], {})
if len(parts) > 1:
value = str(node_result.get("output", "") if parts[1] == "output" else node_result.get(parts[1], ""))
else:
value = str(node_result.get("output", ""))
result = result.replace("{{" + placeholder + "}}", value)
return result
class TemplateTransformNodeAgent(AgentBase):
    """Flow node that renders its ``template`` config against the flow context.

    ``{{...}}`` placeholders are resolved via ``_resolve_template``; ``{{input}}``
    is replaced with the incoming message text. When ``output_type`` is
    ``"json"`` and the rendered text parses as JSON, it is re-serialised compactly.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Template_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        tpl = self.config.get("template", "")
        kind = self.config.get("output_type", "string")
        ctx = kwargs.get("context", {})
        text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        try:
            # An empty rendering falls back to the raw template.
            rendered = _resolve_template(tpl, ctx, msg) or tpl
            rendered = rendered.replace("{{input}}", text)
        except Exception:
            if "{{input}}" in tpl:
                rendered = tpl.replace("{{input}}", text)
            else:
                rendered = text
        if kind == "json":
            try:
                payload = json.loads(rendered)
                return Msg(self.name, json.dumps(payload, ensure_ascii=False), "assistant")
            except (json.JSONDecodeError, ValueError):
                pass  # not valid JSON: emit the rendered text unchanged
        return Msg(self.name, rendered, "assistant")

    async def observe(self, msg) -> None:
        pass
class IterationNodeAgent(AgentBase):
    """Flow node that expands an array into indexed iteration records.

    Items come from ``input_array_source`` (a ``{{...}}`` template resolved
    against the flow context) when configured; otherwise from the incoming text
    (a JSON array, or one item per non-empty line). At most ``max_iterations``
    items are kept. Output is a JSON list of ``{"index": i, "item": item}``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Iteration_{node_id}"
        self.config = config or {}
        self._results: list = []

    @staticmethod
    def _coerce_limit(raw, default: int = 20) -> int:
        """Return *raw* as an int (config values may arrive as strings), else *default*."""
        try:
            return int(raw)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _items_from_source(source: str, context: dict, msg) -> list:
        """Resolve *source* and parse it as a JSON array; non-arrays become one item."""
        resolved = _resolve_template(source, context, msg)
        try:
            items = json.loads(resolved)
        except (json.JSONDecodeError, ValueError):
            return [resolved] if resolved else []
        return items if isinstance(items, list) else [resolved]

    @staticmethod
    def _items_from_text(user_text: str) -> list:
        """Parse *user_text* as a JSON array, falling back to non-empty lines."""
        try:
            items = json.loads(user_text)
        except (json.JSONDecodeError, ValueError):
            lines = [item.strip() for item in user_text.split("\n") if item.strip()]
            return lines or [user_text]
        return items if isinstance(items, list) else [user_text]

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        input_array_source = self.config.get("input_array_source", "")
        max_iterations = self._coerce_limit(self.config.get("max_iterations", 20))
        context = kwargs.get("context") or {}
        if input_array_source:
            items = self._items_from_source(input_array_source, context, msg)
        else:
            user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
            items = self._items_from_text(user_text or "")
        items = items[:max_iterations]
        # Remember the items so get_iteration_items() reflects the latest run
        # (it previously always returned an empty list).
        self._results = list(items)
        results = [{"index": i, "item": item} for i, item in enumerate(items)]
        return Msg(self.name, json.dumps(results, ensure_ascii=False), "assistant")

    def get_iteration_items(self) -> list:
        """Return the items produced by the most recent ``reply`` call."""
        return self._results

    async def observe(self, msg) -> None:
        pass
class QuestionOptimiserNodeAgent(AgentBase):
    """Flow node that rewrites or expands the user's question with an LLM.

    ``optimization_type`` is ``"rewrite"`` or ``"expand"``; any other value passes
    the question through unchanged, as does any LLM failure or empty answer.
    User memory in ``context["_memory"]`` (persona, atoms, recent messages) is
    added to the prompt when present.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"QOpt_{node_id}"
        self.config = config or {}

    @staticmethod
    def _build_context_block(context: dict) -> str:
        """Summarise persona, known facts and recent turns; tolerate missing/None pieces."""
        memory = context.get("_memory") or {}
        history = memory.get("recent_messages") or []
        persona = memory.get("persona") or {}
        atoms = memory.get("atoms") or []
        persona_text = persona.get("raw_text", "") if isinstance(persona, dict) else str(persona)
        atoms_text = "\n".join(
            str(a.get("content") or "") if isinstance(a, dict) else str(a)
            for a in atoms[:5]
        )
        history_text = "\n".join(
            f"{m.get('role', '')}: {str(m.get('content') or '')[:200]}"
            for m in history[-6:]
            if isinstance(m, dict)
        )
        parts = []
        if persona_text:
            parts.append(f"用户画像: {persona_text}")
        if atoms_text:
            parts.append(f"已知信息: {atoms_text}")
        if history_text:
            parts.append(f"近期对话: {history_text}")
        return "\n".join(parts)

    @staticmethod
    def _build_prompt(optimization_type: str, context_block: str, user_text: str) -> str:
        """Return the rewrite or expand instruction prompt."""
        if optimization_type == "rewrite":
            return (
                f"{context_block}\n"
                f"原始问题: {user_text}\n"
                "请将以上问题进行优化改写,使其更清晰、具体、完整。补充可能缺失的上下文信息。\n"
                "只返回优化后的问题,不要其他内容。"
            )
        return (
            f"{context_block}\n"
            f"简短问题: {user_text}\n"
            "请将以上问题扩展为更详细的版本,添加必要的背景和细节。\n"
            "只返回扩展后的问题,不要其他内容。"
        )

    def _endpoint(self) -> tuple[str, str, str]:
        """Return ``(model, base_url, api_key)``.

        Prefers an engine-resolved ``_resolved_model`` entry (the key RAGNodeAgent
        also reads), falling back to the global settings. The settings key is never
        sent to a resolved provider's URL.
        """
        resolved = self.config.get("_resolved_model") or {}
        if resolved.get("base_url"):
            return (
                resolved.get("model") or self.config.get("model", settings.LLM_MODEL),
                resolved["base_url"].rstrip("/"),
                resolved.get("api_key") or "",
            )
        return (
            self.config.get("model", settings.LLM_MODEL),
            settings.LLM_API_BASE.rstrip("/"),
            settings.LLM_API_KEY,
        )

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        optimization_type = self.config.get("optimization_type", "rewrite")
        context = kwargs.get("context") or {}
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if optimization_type not in ("rewrite", "expand"):
            return Msg(self.name, user_text, "assistant")
        try:
            prompt = self._build_prompt(optimization_type, self._build_context_block(context), user_text)
            model_name, api_base, api_key = self._endpoint()
            async with httpx.AsyncClient(timeout=30) as client:
                resp = await client.post(
                    f"{api_base}/chat/completions",
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": prompt}],
                        "max_tokens": 300,
                        "temperature": 0.3,
                    },
                    headers={"Authorization": f"Bearer {api_key}"},
                )
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            answer = ((choices[0].get("message") or {}).get("content") or "").strip()
            # An empty completion must not erase the user's question.
            return Msg(self.name, answer or user_text, "assistant")
        except Exception as e:
            logger.warning("问题优化失败: %s", e)
            return Msg(self.name, user_text, "assistant")

    async def observe(self, msg) -> None:
        pass
class ParallelMergeNodeAgent(AgentBase):
    """Flow node that joins the outputs of parallel branches.

    Each call records one branch's text. Once ``expected_branches`` distinct
    branches have arrived (or immediately when it is <= 0), the collected texts
    are merged according to ``merge_type``: ``"json"`` (object keyed by source
    node), ``"first_non_empty"``, or the default ``"concat"`` (``---``-separated).
    Until then an empty message is returned.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Merge_{node_id}"
        self.config = config or {}
        self._received: dict[str, str] = {}
        self.merge_type = self.config.get("merge_type", "concat")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        # Without a source id every branch used to share the "" key and overwrite
        # the previous one (so expected_branches could never be reached);
        # give anonymous branches their own slot instead.
        source_id = kwargs.get("source_node_id") or f"_anon_{len(self._received)}"
        self._received[source_id] = content or ""
        try:
            expected_count = int(self.config.get("expected_branches", 0))
        except (TypeError, ValueError):
            expected_count = 0
        if expected_count <= 0 or len(self._received) >= expected_count:
            return self._merge()
        return Msg(self.name, "", "assistant")

    def _merge(self) -> Msg:
        """Combine every branch received so far according to ``merge_type``."""
        values = list(self._received.values())
        if self.merge_type == "json":
            merged = json.dumps(self._received, ensure_ascii=False)
        elif self.merge_type == "first_non_empty":
            merged = next((v for v in values if v.strip()), "")
        else:
            merged = "\n\n---\n\n".join(values)
        return Msg(self.name, merged, "assistant")

    async def observe(self, msg) -> None:
        pass