You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1495 lines
60 KiB
1495 lines
60 KiB
import json
|
|
import uuid
|
|
import logging
|
|
import re
|
|
import asyncio
|
|
import httpx
|
|
from agentscope.agent import AgentBase
|
|
from agentscope.message import Msg
|
|
from agentscope.tool import Toolkit
|
|
from agentscope.agent._react_agent import ReActAgent
|
|
from config import settings
|
|
|
|
# Module-level logger shared by the flow engine and the node agents defined in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _resolve_model_instance(model_instance_id: str) -> dict | None:
    """Resolve an active model instance to its model name and provider connection.

    Joins ``model_instances`` with ``model_providers`` and only matches rows
    where both the instance and its provider are active.

    Args:
        model_instance_id: The instance's UUID, usually given as a string.

    Returns:
        ``{"model", "base_url", "api_key", "params"}`` (``params`` is the
        instance's ``default_params``, or ``{}`` when NULL), or ``None`` when
        the id is not a valid UUID, nothing matches, or the lookup fails.
        Failures are logged instead of being silently swallowed, because
        callers fall back to the global model settings when ``None`` is
        returned and the misconfiguration would otherwise go unnoticed.
    """
    try:
        uid = uuid.UUID(str(model_instance_id))
    except ValueError:
        logger.warning(f"无效的模型实例 ID: {model_instance_id!r}")
        return None

    try:
        # Lazy imports keep the DB layer out of this module's import-time dependencies.
        from database import AsyncSessionLocal
        from sqlalchemy import text

        async with AsyncSessionLocal() as db:
            result = await db.execute(
                text("""
                    SELECT mi.model_name, mi.default_params, mp.base_url, mp.api_key
                    FROM model_instances mi
                    JOIN model_providers mp ON mi.provider_id = mp.id
                    WHERE mi.id = :id AND mi.is_active = true AND mp.is_active = true
                """),
                {"id": uid},
            )
            row = result.fetchone()
    except Exception as e:
        logger.warning(f"解析模型实例 {model_instance_id} 失败: {e}")
        return None

    if row is None:
        return None

    model_name, default_params, base_url, api_key = row
    return {
        "model": model_name,
        "base_url": base_url,
        "api_key": api_key,
        "params": default_params or {},
    }
|
|
|
|
|
|
class FlowSessionMemory:
    """In-process conversation history for a single flow run.

    Messages are kept in insertion order as ``{"role": ..., "content": ...}``
    dicts. ``session_id`` and ``user_id`` are stored as metadata only; this
    class does not read them and persists nothing.
    """

    def __init__(self, session_id: str = "", user_id: str = ""):
        self.session_id = session_id
        self.user_id = user_id
        self._messages: list[dict] = []

    def get_history(self, limit: int = 10) -> list[dict]:
        """Return the most recent ``limit`` exchanges (at most ``2 * limit`` messages).

        A non-positive ``limit`` returns an empty list. The explicit guard is
        needed because ``self._messages[-0:]`` would return the whole history.
        """
        if limit <= 0:
            return []
        return self._messages[-limit * 2:]

    def add(self, role: str, content: str):
        """Append one message with the given role and content."""
        self._messages.append({"role": role, "content": content})

    def to_list(self) -> list[dict]:
        """Return a shallow copy of every stored message."""
        return list(self._messages)
|
|
|
|
|
|
class FlowEngine:
    """Execute a flow definition (nodes + edges) by depth-first traversal.

    ``flow_definition`` is read as ``{"nodes": [...], "edges": [...]}``. Nodes
    are indexed by their ``"id"``. Edges may use ``source``/``target`` or
    ``from``/``to`` and may carry a branch label in ``condition`` or
    ``sourceHandle``. Every node is executed by an agent built through
    ``_create_node_agent``.

    A run is bounded by ``MAX_TOTAL_ITERATIONS`` node visits and by a
    ``FLOW_TIMEOUT_SECONDS`` wall-clock limit.
    """

    # Upper bound on node visits per run; guards against runaway loops/cycles.
    MAX_TOTAL_ITERATIONS = 200
    # Wall-clock limit for a whole run, in seconds (passed to asyncio.wait_for).
    FLOW_TIMEOUT_SECONDS = 300

    def __init__(self, flow_definition: dict):
        self.definition = flow_definition
        self.nodes: dict[str, dict] = {}
        for node in flow_definition.get("nodes", []):
            self.nodes[node["id"]] = node
        self.edges: list[dict] = flow_definition.get("edges", [])
        # Agents of non-loop nodes are built once and reused (see _get_or_create_agent).
        self._agent_cache: dict[str, AgentBase] = {}

    @staticmethod
    def _msg_text(msg) -> str:
        """Return the text of ``msg``, never ``None``.

        ``get_text_content()`` is treated as possibly returning ``None``
        (e.g. a message without text blocks). Slicing or regex-matching that
        ``None`` would fail a node that actually succeeded.
        """
        if hasattr(msg, "get_text_content"):
            return msg.get_text_content() or ""
        return str(msg)

    async def execute(self, input_msg: Msg, context: dict) -> Msg:
        """Run the flow from its first root node and return the final message.

        Args:
            input_msg: Message delivered to the start node.
            context: Mutable per-run dict shared with every node agent.
                ``"session_id"`` and ``"user_id"`` are read from it when
                present. This method stores the run's ``FlowSessionMemory``
                under ``"_memory"`` and per-node execution records under
                ``"_node_results"``.

        Returns:
            The most recent message produced during the run: a node reply, or
            a system notice for a node error, the iteration cap or the timeout.
            Returns ``input_msg`` when nothing was produced.
        """
        graph = self._build_graph()
        start_nodes = self._find_start_nodes(graph)
        if not start_nodes:
            # Every node has an incoming edge (the flow is a cycle), so start
            # at the first declared node.
            start_nodes = list(self.nodes.keys())[:1]

        session_id = context.get("session_id", str(uuid.uuid4()))
        user_id = context.get("user_id", "")
        memory = FlowSessionMemory(session_id=session_id, user_id=user_id)
        context["_memory"] = memory

        visited: set[str] = set()
        loop_iterations: dict[str, int] = {}
        total_iterations = 0
        last_result: Msg | None = None

        def record(node_id: str, node: dict, **fields) -> None:
            """Store the execution record of ``node_id`` in ``context["_node_results"]``."""
            context.setdefault("_node_results", {})[node_id] = {
                "node_id": node_id,
                "node_type": node.get("type", ""),
                "label": node.get("label"),
                **fields,
            }

        async def traverse(node_id: str, incoming_msg: Msg, loop_context: dict | None = None) -> None:
            # ``loop_context`` is supplied for loop-body entries but is not
            # consumed yet; the parameter keeps the call shape stable.
            nonlocal last_result, total_iterations

            total_iterations += 1
            if total_iterations > self.MAX_TOTAL_ITERATIONS:
                logger.warning(f"流执行超过最大总迭代次数 {self.MAX_TOTAL_ITERATIONS},强制终止")
                last_result = Msg(name="system", content="[流执行超限: 超过最大迭代次数,已强制终止]", role="system")
                return

            node = self.nodes.get(node_id)
            if not node:
                return

            node_type = node.get("type", "")
            is_loop = node_type == "loop"
            label = node.get("label", node_id)

            if not is_loop:
                # Non-loop nodes run once per pass, so a node with several
                # incoming edges is not executed once per edge.
                if node_id in visited:
                    return
                visited.add(node_id)

            # Everything that can fail for THIS node (agent creation, input
            # mapping, the reply itself) is inside this try. Otherwise the error
            # would surface in the caller's frame and be recorded against the
            # parent node.
            try:
                if is_loop:
                    loop_iterations[node_id] = loop_iterations.get(node_id, 0) + 1
                    node_config = node.get("config") or {}
                    max_iter = node_config.get("max_iterations", 10)
                    if loop_iterations[node_id] > max_iter:
                        logger.warning(f"循环节点 {label} 超过最大迭代次数 {max_iter}")
                        visited.add(node_id)
                        for target_id, edge_cond in graph.get(node_id, []):
                            if edge_cond == "loop_done":
                                await traverse(target_id, incoming_msg)
                        return
                    # Loop agents are rebuilt on each pass so they see the current iteration.
                    loop_node_config = dict(node_config)
                    loop_node_config["_engine_iteration"] = loop_iterations[node_id]
                    agent = await self._get_or_create_agent(node_id, context, override_config=loop_node_config)
                else:
                    agent = await self._get_or_create_agent(node_id, context)

                current_msg = incoming_msg
                enriched_content = self._resolve_input_mapping(node, incoming_msg, context)
                if enriched_content.strip():
                    user_text = self._msg_text(current_msg)
                    current_msg = Msg(name="user", content=f"{enriched_content}\n\n---\n{user_text}", role="user")

                result = await agent.reply(current_msg, context=context)
                record(node_id, node, status="success", output=self._msg_text(result)[:500])
                last_result = result

                if is_loop:
                    if self._check_loop_done(result, node, loop_iterations.get(node_id, 0)):
                        visited.add(node_id)
                        for target_id, edge_cond in graph.get(node_id, []):
                            if edge_cond == "loop_done":
                                await traverse(target_id, result)
                    else:
                        for target_id, edge_cond in graph.get(node_id, []):
                            if edge_cond == "loop_body" and target_id in self.nodes:
                                # Re-arm the whole body: every node reachable from
                                # the body entry without passing back through this
                                # loop node. Clearing only the entry node made
                                # multi-node bodies stall after the first pass.
                                visited.difference_update(self._collect_loop_body(graph, node_id, target_id))
                                await traverse(target_id, result, {"iteration": loop_iterations[node_id]})
                    return

                is_condition = node_type == "condition"
                cond_result = self._parse_condition_result(result) if is_condition else None

                for target_id, edge_cond in graph.get(node_id, []):
                    if is_condition:
                        # Follow only the edge whose label matches the verdict.
                        if edge_cond and edge_cond == cond_result:
                            await traverse(target_id, result)
                    elif edge_cond in ("loop_body", "loop_done"):
                        # Loop handles are only meaningful on loop nodes.
                        continue
                    else:
                        await traverse(target_id, result)
            except Exception as e:
                logger.error(f"节点 {label} 执行失败: {e}")
                record(node_id, node, status="error", error=str(e))
                last_result = Msg(name="system", content=f"[节点 {label} 执行失败: {e}]", role="system")

        if start_nodes:
            try:
                await asyncio.wait_for(
                    traverse(start_nodes[0], input_msg),
                    timeout=self.FLOW_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                logger.error(f"流执行超时 ({self.FLOW_TIMEOUT_SECONDS}s),强制终止")
                last_result = Msg(name="system", content=f"[流执行超时: 超过{self.FLOW_TIMEOUT_SECONDS}秒限制]", role="system")

        return last_result or input_msg

    def _build_graph(self) -> dict[str, list[tuple[str, str | None]]]:
        """Build an adjacency list ``{source_id: [(target_id, edge_label), ...]}``.

        Edges that reference unknown nodes are dropped. A label of ``"source"``
        is treated as an unlabelled edge (``None``); it appears to be the
        editor's default handle name — confirm.
        """
        graph: dict[str, list[tuple[str, str | None]]] = {nid: [] for nid in self.nodes}
        for edge in self.edges:
            source = edge.get("source") or edge.get("from")
            target = edge.get("target") or edge.get("to")
            cond = edge.get("condition") or edge.get("sourceHandle")
            if cond == "source":
                cond = None
            if source and target and source in self.nodes and target in self.nodes:
                graph[source].append((target, cond))
        return graph

    def _find_start_nodes(self, graph: dict) -> list[str]:
        """Return the ids of nodes without incoming edges, in declaration order."""
        target_nodes = {target_id for targets in graph.values() for target_id, _ in targets}
        return [nid for nid in self.nodes if nid not in target_nodes]

    @staticmethod
    def _collect_loop_body(graph: dict, loop_id: str, entry_id: str) -> set[str]:
        """Return every node reachable from ``entry_id`` without passing through ``loop_id``."""
        body: set[str] = set()
        pending = [entry_id]
        while pending:
            current = pending.pop()
            if current == loop_id or current in body:
                continue
            body.add(current)
            pending.extend(target for target, _ in graph.get(current, []))
        return body

    def _parse_condition_result(self, result: Msg) -> str | None:
        """Extract ``"true"``/``"false"`` from a ``condition:<bool>`` marker in ``result``."""
        m = re.search(r'condition:(true|false)', self._msg_text(result))
        return m.group(1) if m else None

    def _check_loop_done(self, result: Msg, node: dict, iteration: int) -> bool:
        """Decide whether a loop node should exit after ``iteration`` passes.

        ``loop_type == "fixed"`` (the default) exits once ``iteration`` reaches
        ``count`` (default 3). Any other type exits when the reply contains
        ``loop:stop``, ``loop:done`` or ``loop:break`` (case-insensitive).
        """
        config = node.get("config") or {}
        if config.get("loop_type", "fixed") == "fixed":
            return iteration >= config.get("count", 3)
        return bool(re.search(r'loop:(stop|done|break)', self._msg_text(result), re.IGNORECASE))

    async def _get_or_create_agent(self, node_id: str, context: dict, override_config: dict | None = None) -> AgentBase:
        """Return the agent for ``node_id``.

        Without ``override_config`` the agent is cached per node. With it, a
        fresh, uncached agent is built from a copy of the node whose
        ``config`` is replaced.
        """
        if node_id in self._agent_cache and override_config is None:
            return self._agent_cache[node_id]

        node = dict(self.nodes[node_id])
        if override_config is not None:
            node["config"] = override_config
        agent = await _create_node_agent(node, context)
        if override_config is None:
            self._agent_cache[node_id] = agent
        return agent

    def _resolve_input_mapping(self, node: dict, current_msg: Msg, context: dict) -> str:
        """Render ``config["input_mapping"]`` as ``"key: value"`` lines.

        String values containing ``{{`` are expanded with ``_resolve_template``.
        Returns ``""`` when the node has no mapping.
        """
        config = node.get("config") or {}
        input_mapping = config.get("input_mapping")
        if not input_mapping:
            return ""

        lines = []
        for key, template in input_mapping.items():
            value = template
            if isinstance(template, str) and "{{" in template:
                value = _resolve_template(template, context, current_msg)
            lines.append(f"{key}: {value}")
        return "\n".join(lines)
|
|
|
|
|
|
async def _create_node_agent(node: dict, context: dict) -> AgentBase:
    """Instantiate the agent that executes one flow node.

    The agent class is selected by ``node["type"]`` and configured from
    ``node["config"]``. ``"trigger"`` nodes and unrecognised types get a
    ``PassThroughAgent``. ``"llm"`` and ``"rag"`` nodes may name a
    ``model_instance_id``, which is looked up via ``_resolve_model_instance``
    before the agent is built.

    Args:
        node: Node definition (``id``, ``type``, optional ``config``).
        context: Per-run flow context. The LLM agent takes its stream callback
            (``"_stream_callback"``) and session memory (``"_memory"``) from
            it; the variable assigner receives the whole dict.
    """
    kind = node.get("type", "")
    nid = node.get("id", "")
    cfg = node.get("config", {})

    if kind == "llm":
        model_name = cfg.get("model", settings.LLM_MODEL)
        instance_id = cfg.get("model_instance_id")
        endpoint, secret = settings.LLM_API_BASE, settings.LLM_API_KEY
        if instance_id:
            instance = await _resolve_model_instance(instance_id)
            if instance:
                model_name = instance.get("model", model_name)
                endpoint = instance.get("base_url", endpoint)
                secret = instance.get("api_key", secret)
        llm_agent = LLMNodeAgent(
            node_id=nid,
            system_prompt=cfg.get("system_prompt", "你是AI助手。"),
            model_name=model_name,
            temperature=cfg.get("temperature", 0.7),
            max_tokens=cfg.get("max_tokens", 2000),
            stream=cfg.get("stream", True),
            stream_callback=context.get("_stream_callback"),
            base_url=endpoint,
            api_key=secret,
        )
        session_memory = context.get("_memory")
        if session_memory:
            llm_agent.set_memory(session_memory)
        return llm_agent

    if kind == "tool":
        return ToolNodeAgent(
            node_id=nid,
            tool_name=cfg.get("tool_name", ""),
            tool_params=cfg.get("tool_params", {}),
            timeout=cfg.get("timeout", 30),
            retry_count=cfg.get("retry_count", 0),
            error_handling=cfg.get("error_handling", "throw"),
        )

    if kind == "mcp":
        return MCPNodeAgent(
            node_id=nid,
            server_name=cfg.get("mcp_server", ""),
            tool_name=cfg.get("tool_name", ""),
            timeout=cfg.get("timeout", 30),
            error_handling=cfg.get("error_handling", "throw"),
        )

    if kind in ("wecom_notify", "notify"):
        return NotifyAgent(node_id=nid, config=cfg)

    if kind == "condition":
        return ConditionNodeAgent(
            node_id=nid,
            condition=cfg.get("condition", ""),
            condition_type=cfg.get("condition_type", "expression"),
        )

    if kind == "rag":
        instance_id = cfg.get("model_instance_id")
        if instance_id:
            instance = await _resolve_model_instance(instance_id)
            if instance:
                # Build a new dict so the shared node definition is not mutated.
                cfg = {**cfg, "_resolved_model": instance}
        return RAGNodeAgent(node_id=nid, config=cfg)

    if kind == "code":
        return CodeNodeAgent(
            node_id=nid,
            language=cfg.get("language", "python"),
            code=cfg.get("code", ""),
            timeout=cfg.get("timeout", 30),
            sandbox=cfg.get("sandbox", True),
        )

    if kind == "variable_assigner":
        return VariableAssignerNodeAgent(node_id=nid, config=cfg, context=context)

    # Node types whose agent is configured solely by the node config.
    if kind == "output":
        return OutputNodeAgent(node_id=nid, config=cfg)
    if kind == "merge":
        return ParallelMergeNodeAgent(node_id=nid, config=cfg)
    if kind == "loop":
        return LoopNodeAgent(node_id=nid, config=cfg)
    if kind == "http_request":
        return HttpRequestNodeAgent(node_id=nid, config=cfg)
    if kind == "question_classifier":
        return QuestionClassifierNodeAgent(node_id=nid, config=cfg)
    if kind == "template_transform":
        return TemplateTransformNodeAgent(node_id=nid, config=cfg)
    if kind == "iteration":
        return IterationNodeAgent(node_id=nid, config=cfg)
    if kind == "question_optimiser":
        return QuestionOptimiserNodeAgent(node_id=nid, config=cfg)

    # "trigger" and any unrecognised type forward their input unchanged.
    return PassThroughAgent(nid)
|
|
|
|
|
|
class PassThroughAgent(AgentBase):
    """No-op agent used for trigger nodes and unrecognised node types.

    ``reply`` hands ``Msg`` inputs back untouched and wraps anything else in
    an assistant ``Msg`` carrying its ``str()`` form.
    """

    def __init__(self, node_id: str):
        super().__init__()
        self.name = f"passthrough_{node_id}"

    async def reply(self, msg, **kwargs) -> Msg:
        if not isinstance(msg, Msg):
            return Msg(self.name, str(msg), "assistant")
        return msg

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no state."""
        return None
|
|
|
|
|
|
class HttpRequestNodeAgent(AgentBase):
    """Perform a configured HTTP request and report its status code and body.

    Recognised ``config`` keys:

    * ``method``: default ``GET``.
    * ``url``: request URL.
    * ``headers``: a dict or a JSON-object string.
    * ``body``: a dict or list, a JSON string, or a raw string. Only sent for
      POST/PUT/PATCH.
    * ``auth_type``: ``none`` / ``bearer`` / ``api_key`` / ``basic``, with
      credentials in ``auth_config`` (a dict or JSON-object string).
    * ``timeout``: seconds, default 30.
    * ``retry_count``: extra attempts after an exception, default 0. HTTP
      error statuses are returned, not retried.

    The reply content is the JSON string ``{"status_code", "body"}``. If every
    attempt raises, it is ``{"status_code": 0, "error", "body": null}``.
    """

    # Methods whose configured ``body`` is sent with the request.
    _BODY_METHODS = ("POST", "PUT", "PATCH")

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"HttpRequest_{node_id}"
        self.config = config or {}

    @staticmethod
    def _as_dict(value) -> dict:
        """Coerce a dict or JSON-object string into a new dict; anything else gives ``{}``.

        A ``null``, blank or malformed ``headers``/``auth_config`` therefore
        no longer crashes the node outside the retry loop.
        """
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except (json.JSONDecodeError, ValueError):
                return {}
        return dict(value) if isinstance(value, dict) else {}

    def _build_headers(self) -> dict:
        """Return the request headers with the configured authentication applied."""
        headers = self._as_dict(self.config.get("headers"))
        auth_type = self.config.get("auth_type", "none")
        auth = self._as_dict(self.config.get("auth_config"))
        if auth_type == "bearer" and auth.get("token"):
            headers["Authorization"] = f"Bearer {auth['token']}"
        elif auth_type == "api_key" and auth.get("api_key"):
            headers[auth.get("key_name", "X-API-Key")] = auth["api_key"]
        elif auth_type == "basic" and auth.get("username"):
            import base64

            credentials = f"{auth['username']}:{auth.get('password', '')}"
            headers["Authorization"] = f"Basic {base64.b64encode(credentials.encode()).decode()}"
        return headers

    def _build_body_kwargs(self, method: str) -> dict:
        """Return the ``json=``/``content=`` keyword for httpx; empty when no body is sent."""
        if method not in self._BODY_METHODS:
            return {}
        body = self.config.get("body", "")
        if isinstance(body, (dict, list)):
            return {"json": body}
        if body and isinstance(body, str):
            try:
                return {"json": json.loads(body)}
            except (json.JSONDecodeError, ValueError):
                return {"content": body}
        return {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        method = (self.config.get("method") or "GET").upper()
        url = self.config.get("url", "")
        timeout = self.config.get("timeout", 30)
        retry_count = self.config.get("retry_count", 0)
        # Headers and body do not change between attempts, so build them once.
        request_kwargs = {"headers": self._build_headers(), **self._build_body_kwargs(method)}

        last_error = None
        for attempt in range(retry_count + 1):
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
                    response = await client.request(method, url, **request_kwargs)
                resp_text = response.text
                try:
                    resp_body = json.loads(resp_text)
                except (json.JSONDecodeError, ValueError):
                    resp_body = resp_text
                result = json.dumps({
                    "status_code": response.status_code,
                    "body": resp_body,
                }, ensure_ascii=False)
                return Msg(self.name, result, "assistant")
            except Exception as e:
                last_error = str(e)
                if attempt < retry_count:
                    logger.warning(f"HTTP 请求失败,重试 {attempt + 1}/{retry_count}: {e}")
                    await asyncio.sleep(1)

        error_result = json.dumps({
            "status_code": 0,
            "error": last_error or "unknown error",
            "body": None,
        }, ensure_ascii=False)
        return Msg(self.name, error_result, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no state."""
        pass
|
|
|
|
|
|
class QuestionClassifierNodeAgent(AgentBase):
    """Classify the incoming text into one of the configured categories via an LLM.

    ``config["categories"]`` is a list of ``{"name", "description"}`` dicts.
    Plain strings are accepted as category names. The reply content is the
    model's JSON answer (expected to contain ``category`` and
    ``confidence``). Without categories, or on any failure, it falls back to
    a ``{"category": "default", ...}`` JSON object.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Classifier_{node_id}"
        self.config = config or {}

    @staticmethod
    def _describe_category(category) -> str:
        """Render one category as ``"name: description"``; a plain string is a bare name."""
        if isinstance(category, dict):
            return f'{category.get("name")}: {category.get("description", "")}'
        return f"{category}: "

    @staticmethod
    def _parse_reply(text: str):
        """Parse the model's JSON answer, or return ``None`` if none can be recovered.

        Models often wrap the object in Markdown fences or prose. If a direct
        ``json.loads`` fails, the outermost ``{...}`` span is tried instead.
        """
        try:
            return json.loads(text)
        except (json.JSONDecodeError, ValueError):
            pass
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group())
            except (json.JSONDecodeError, ValueError):
                pass
        return None

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        categories = self.config.get("categories", [])
        instruction = self.config.get("instruction", "")
        model_name = self.config.get("model", settings.LLM_MODEL)
        temperature = self.config.get("temperature", 0.3)

        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)

        if not categories:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 1.0}), "assistant")

        category_desc = "\n".join(self._describe_category(c) for c in categories)
        prompt = f"""请对以下用户输入进行意图分类。

分类选项:
{category_desc}

{instruction}

用户输入:{user_text}

请只返回一个JSON对象,格式为:{{"category": "分类名称", "confidence": 0.0~1.0}}"""

        try:
            import openai

            client = openai.AsyncOpenAI(
                base_url=settings.LLM_API_BASE,
                api_key=settings.LLM_API_KEY,
            )
            try:
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=temperature,
                    max_tokens=200,
                )
            finally:
                # A client is created per call; close it so its connection pool is released.
                await client.close()
            result_text = (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.error(f"QuestionClassifier error: {e}")
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.0, "error": str(e)}), "assistant")

        parsed = self._parse_reply(result_text)
        if parsed is None:
            return Msg(self.name, json.dumps({"category": "default", "confidence": 0.5, "raw": result_text}), "assistant")
        return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no state."""
        pass
|
|
|
|
|
|
class VariableAssignerNodeAgent(AgentBase):
    """Compute named values from constants, the upstream output, templates or expressions.

    Each entry of ``config["assignments"]`` has ``target_var``,
    ``source_type`` and ``source_value``. ``source_type`` is one of
    ``constant`` / ``upstream_output`` / ``template`` / ``expression``; any
    other value is treated like ``constant``. The reply content is a JSON
    object mapping each ``target_var`` to its value. This class does not
    write the values back into the flow context.
    """

    def __init__(self, node_id: str, config: dict = None, context: dict = None):
        super().__init__()
        self.name = f"VarAssign_{node_id}"
        self.config = config or {}
        self._context = context or {}

    def _evaluate(self, assignment: dict, msg: Msg):
        """Return the value for a single assignment entry."""
        source_type = assignment.get("source_type", "constant")
        source_value = assignment.get("source_value", "")

        if source_type == "upstream_output":
            return msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        if source_type == "template":
            return _resolve_template(source_value, self._context, msg)
        if source_type == "expression":
            # SECURITY(review): eval() with empty __builtins__ is NOT a sandbox.
            # Expressions can still reach arbitrary objects through attribute
            # traversal on ``msg``/``context`` (e.g. ``msg.__class__.__mro__``).
            # Only acceptable while flow definitions come from trusted authors.
            try:
                safe_locals = {"msg": msg, "context": self._context}
                return str(eval(source_value, {"__builtins__": {}}, safe_locals))
            except Exception as e:
                return f"[expression error: {e}]"
        # "constant" and unknown source types use the configured value verbatim.
        return source_value

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        results = {}
        # ``or []`` tolerates an explicit ``"assignments": null`` in the config.
        for assignment in self.config.get("assignments") or []:
            if not isinstance(assignment, dict):
                logger.warning(f"{self.name}: 忽略无效的赋值配置 {assignment!r}")
                continue
            target_var = assignment.get("target_var", "")
            results[target_var] = self._evaluate(assignment, msg)

        output = json.dumps(results, ensure_ascii=False)
        return Msg(self.name, output, "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no state."""
        pass
|
|
|
|
|
|
class LLMNodeAgent(AgentBase):
    """Call an OpenAI-compatible ``/chat/completions`` endpoint for one flow node.

    The request contains the node's system prompt, prior conversation and the
    incoming text. Prior conversation comes from ``context["_memory_context"]``
    when present, otherwise from the attached ``FlowSessionMemory``. When
    ``stream`` is enabled and a ``stream_callback`` is set, the response is
    streamed and every text delta is forwarded as
    ``stream_callback("text_chunk", {"content": token})``.
    """

    # Replayed history messages are truncated to this many characters each.
    _HISTORY_CHAR_LIMIT = 2000

    def __init__(self, node_id: str, system_prompt: str, model_name: str = "", temperature: float = 0.7, max_tokens: int = 2000, stream: bool = True, stream_callback=None, base_url: str = "", api_key: str = ""):
        super().__init__()
        self.name = f"LLM_{node_id}"
        self.system_prompt = system_prompt
        self.model_name = model_name or settings.LLM_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.stream_callback = stream_callback
        self.base_url = base_url or settings.LLM_API_BASE
        self.api_key = api_key or settings.LLM_API_KEY
        self._memory = None

    def set_memory(self, memory):
        """Attach a session memory used as history when no ``_memory_context`` is given."""
        self._memory = memory

    def _build_messages(self, user_text: str, memory_ctx: dict) -> list[dict]:
        """Assemble the chat message list: system prompt, history, then the user turn."""
        messages = [{"role": "system", "content": self.system_prompt}]

        if memory_ctx:
            summary = memory_ctx.get("summary", "")
            if summary:
                messages.append({"role": "system", "content": f"[历史对话摘要]\n{summary}"})
            history = (memory_ctx.get("recent_messages") or [])[-10:]
        elif self._memory:
            history = self._memory.get_history(limit=5)
        else:
            history = []

        for item in history:
            content = item.get("content") or ""
            messages.append({"role": item.get("role", "user"), "content": content[: self._HISTORY_CHAR_LIMIT]})

        messages.append({"role": "user", "content": user_text})
        return messages

    def _build_request(self, messages: list[dict], stream: bool) -> tuple[str, dict, dict]:
        """Return ``(url, headers, json_body)`` for a chat-completions call."""
        url = f"{self.base_url.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model_name or settings.LLM_MODEL,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }
        return url, headers, body

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)
        context = kwargs.get("context") or {}
        memory_ctx = context.get("_memory_context") or {}
        messages = self._build_messages(user_text, memory_ctx)

        try:
            # Honour the node's ``stream`` flag; a callback alone no longer forces streaming.
            if self.stream and self.stream_callback:
                res_text = await self._stream_llm_call(messages)
            else:
                res_text = await self._blocking_llm_call(messages)
        except Exception as e:
            logger.warning(f"LLM 调用失败: {e}")
            res_text = f"[LLM 调用失败] 已接收输入: {user_text[:200]}"

        # Record the turn only when the session memory was the history source.
        if self._memory and not memory_ctx:
            self._memory.add("user", user_text)
            self._memory.add("assistant", res_text)

        return Msg(self.name, res_text, "assistant")

    async def _blocking_llm_call(self, messages: list[dict]) -> str:
        """Perform a non-streaming call; errors are returned as an inline notice."""
        url, headers, body = self._build_request(messages, stream=False)
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
                resp = await client.post(url, json=body, headers=headers)
                # Surface 4xx/5xx (bad key, unknown model, ...) instead of returning "".
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            return (choices[0].get("message") or {}).get("content") or ""
        except Exception as e:
            logger.warning(f"LLM 阻塞调用失败: {e}")
            return f"[LLM 调用失败: {e}]"

    @staticmethod
    def _delta_text(chunk) -> str:
        """Return the text delta of one streamed chunk ("" for non-text chunks).

        Some providers send chunks with an empty ``choices`` list (e.g. usage
        or filter metadata); indexing ``[0]`` on those used to abort the stream.
        """
        if not isinstance(chunk, dict):
            return ""
        choices = chunk.get("choices") or []
        if not choices or not isinstance(choices[0], dict):
            return ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    async def _stream_llm_call(self, messages: list[dict]) -> str:
        """Stream a completion, forwarding each text delta to ``stream_callback``.

        Returns the accumulated text. If the call fails before any text
        arrived, returns an inline failure notice instead.
        """
        url, headers, body = self._build_request(messages, stream=True)
        tokens: list[str] = []
        fallback = ""
        try:
            timeout = httpx.Timeout(60.0, connect=10.0)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("POST", url, json=body, headers=headers) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # SSE allows "data:" with or without a following space.
                        if not line.startswith("data:"):
                            continue
                        payload = line[5:].strip()
                        if payload == "[DONE]":
                            break
                        try:
                            chunk = json.loads(payload)
                        except json.JSONDecodeError:
                            continue
                        token = self._delta_text(chunk)
                        if token:
                            tokens.append(token)
                            await self.stream_callback("text_chunk", {"content": token})
        except httpx.TimeoutException:
            logger.warning("LLM 流式调用超时")
            fallback = "[LLM 超时]"
        except Exception as e:
            logger.warning(f"LLM 流式调用失败: {e}")
            fallback = f"[LLM 调用失败: {e}]"

        # Partial output wins over the failure notice.
        return "".join(tokens) or fallback

    async def observe(self, msg) -> None:
        """Ignore observed messages; history is handled in ``reply``."""
        pass
|
|
|
|
|
|
class ToolNodeAgent(AgentBase):
    """Run a built-in tool or a DB-defined custom HTTP tool for one flow node.

    Built-in tools are Python callables registered lazily by
    ``_init_registry``. Custom tools are endpoint definitions loaded by
    ``load_custom_tools`` or ``register_custom_tool`` and executed through
    ``CustomToolExecutor``. A custom definition takes precedence over a
    built-in of the same name.

    ``error_handling`` controls failures:

    * ``"skip"`` returns a "skipped" notice.
    * ``"default"`` returns ``"{}"`` for execution errors.
    * Anything else (default ``"throw"``) returns an error notice.

    Errors are always reported as message content; nothing is raised to the
    caller.
    """

    # Class-wide registries shared by all instances: built-in tool callables,
    # tool schemas (built-in and custom) and custom HTTP tool definitions.
    _TOOL_REGISTRY: dict[str, callable] = {}
    _TOOL_SCHEMAS: dict[str, dict] = {}
    _CUSTOM_TOOL_DEFS: dict[str, dict] = {}

    @classmethod
    def _init_registry(cls):
        """Register the built-in tools and their schemas on first use.

        If the tool modules cannot be imported, the registry stays empty and a
        later call tries again.
        """
        if cls._TOOL_REGISTRY:
            return
        try:
            from agentscope_integration.tools.document_tools import parse_document, format_correction
            from agentscope_integration.tools.wecom_tools import send_notification, query_wecom_user
            from agentscope_integration.tools.task_tools import list_tasks, create_task, get_task, update_task, push_task_to_wecom
            from agentscope_integration.tools.manager_tools import list_subordinates, generate_efficiency_report, get_task_statistics, get_employee_dashboard

            cls._TOOL_REGISTRY = {
                "parse_document": parse_document,
                "format_correction": format_correction,
                "send_notification": send_notification,
                "query_wecom_user": query_wecom_user,
                "list_tasks": list_tasks,
                "create_task": create_task,
                "get_task": get_task,
                "update_task": update_task,
                "push_task_to_wecom": push_task_to_wecom,
                "list_subordinates": list_subordinates,
                "generate_efficiency_report": generate_efficiency_report,
                "get_task_statistics": get_task_statistics,
                "get_employee_dashboard": get_employee_dashboard,
            }

            import importlib

            for mod_name in ("task_tools", "manager_tools", "wecom_tools", "document_tools"):
                try:
                    mod = importlib.import_module(f"agentscope_integration.tools.{mod_name}")
                    if hasattr(mod, "SCHEMAS"):
                        cls._TOOL_SCHEMAS.update(mod.SCHEMAS)
                except Exception as e:
                    # Schemas are optional metadata; a broken module must not block registration.
                    logger.debug(f"加载工具 schema 失败 ({mod_name}): {e}")
        except ImportError as e:
            logger.warning(f"工具注册失败: {e}")

    @classmethod
    async def load_custom_tools(cls, db):
        """Load active ``CustomTool`` rows from ``db`` into the custom-tool registry.

        Errors are logged. Tools registered before the failure stay registered.
        """
        try:
            from sqlalchemy import select
            from models import CustomTool
            # Not used below. An ImportError here aborts loading, so no tool is
            # registered that could not be executed later.
            from modules.custom_tool.executor import CustomToolExecutor  # noqa: F401

            result = await db.execute(
                select(CustomTool).where(CustomTool.is_active == True)  # noqa: E712 (SQL expression)
            )
            for tool in result.scalars().all():
                cls._CUSTOM_TOOL_DEFS[tool.name] = {
                    "endpoint_url": tool.endpoint_url,
                    "method": tool.method,
                    "path": tool.path,
                    "headers_json": tool.headers_json,
                    "auth_type": tool.auth_type,
                    "auth_config": tool.auth_config,
                    "timeout": 30,
                }
                cls._TOOL_SCHEMAS[tool.name] = tool.schema_json
        except ImportError:
            logger.warning("无法加载自定义工具模块")
        except Exception as e:
            logger.warning(f"加载自定义工具失败: {e}")

    @classmethod
    def register_custom_tool(cls, name: str, schema: dict, tool_def: dict):
        """Register (or replace) a custom tool definition and its schema."""
        cls._CUSTOM_TOOL_DEFS[name] = tool_def
        cls._TOOL_SCHEMAS[name] = schema

    @classmethod
    def get_schemas(cls) -> dict:
        """Return a copy of all known tool schemas, initialising built-ins first."""
        cls._init_registry()
        return dict(cls._TOOL_SCHEMAS)

    def __init__(self, node_id: str, tool_name: str = "", tool_params: dict = None, timeout: int = 30, retry_count: int = 0, error_handling: str = "throw"):
        super().__init__()
        self.name = f"Tool_{node_id}"
        self.tool_name = tool_name
        self.tool_params = tool_params or {}
        self.timeout = timeout
        self.retry_count = retry_count
        self.error_handling = error_handling

    @staticmethod
    def _binds_without_text(func, params: dict) -> bool | None:
        """Report whether ``func(**params)`` fits ``func``'s signature.

        Returns ``None`` when the signature cannot be inspected. Deciding the
        call form up front avoids invoking the tool a second time when the
        tool itself raises ``TypeError``; a second call would duplicate side
        effects such as creating a task twice.
        """
        import inspect

        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            return None
        try:
            signature.bind(**params)
        except TypeError:
            return False
        return True

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        import inspect

        self._init_registry()
        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)

        custom_def = self._CUSTOM_TOOL_DEFS.get(self.tool_name)
        if custom_def:
            return await self._execute_custom_tool(custom_def, user_text)

        tool_func = self._TOOL_REGISTRY.get(self.tool_name)
        if not tool_func:
            if self.error_handling == "skip":
                return Msg(self.name, f"[工具 {self.tool_name} 未找到,已跳过]", "assistant")
            return Msg(self.name, f"[工具 {self.tool_name}] 未找到或在当前节点中不可用", "assistant")

        params = self.tool_params
        params_only = self._binds_without_text(tool_func, params)

        def _call_sync():
            try:
                if params_only is None:
                    # Signature unknown: fall back to retrying with the user
                    # text when the params-only call is rejected.
                    try:
                        return tool_func(**params)
                    except TypeError:
                        return tool_func(user_text, **params)
                if params_only:
                    return tool_func(**params)
                return tool_func(user_text, **params)
            except Exception as e:
                raise RuntimeError(f"工具执行失败: {e}") from e

        async def _run_once():
            result = await asyncio.to_thread(_call_sync)
            if inspect.isawaitable(result):
                # Coroutine tools: the worker thread only created the coroutine; run it here.
                try:
                    result = await result
                except Exception as e:
                    raise RuntimeError(f"工具执行失败: {e}") from e
            return result

        for attempt in range(self.retry_count + 1):
            try:
                result = await asyncio.wait_for(_run_once(), timeout=self.timeout)
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                # NOTE: a timed-out sync tool cannot be cancelled and keeps running in its worker thread.
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 超时,已跳过]", "assistant")
                return Msg(self.name, f"[工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except Exception as e:
                if attempt < self.retry_count:
                    logger.warning(f"工具 {self.tool_name} 执行失败,重试 {attempt + 1}/{self.retry_count}: {e}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                if isinstance(e, RuntimeError):
                    # Message already reads "工具执行失败: ...".
                    return Msg(self.name, f"[{e}]", "assistant")
                return Msg(self.name, f"[工具执行失败: {e}]", "assistant")

        # Only reached when retry_count is negative (no attempt was made).
        return Msg(self.name, f"[工具 {self.tool_name}] 执行失败", "assistant")

    async def _execute_custom_tool(self, custom_def: dict, user_text: str) -> Msg:
        """Execute a custom HTTP tool with this node's params, timeout and retry policy."""
        from modules.custom_tool.executor import CustomToolExecutor

        executor = CustomToolExecutor(custom_def)

        for attempt in range(self.retry_count + 1):
            try:
                result = await asyncio.wait_for(
                    executor.execute(self.tool_params),
                    timeout=self.timeout,
                )
                return Msg(self.name, str(result), "assistant")
            except asyncio.TimeoutError:
                if attempt < self.retry_count:
                    logger.warning(f"自定义工具 {self.tool_name} 超时,重试 {attempt + 1}/{self.retry_count}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 超时,已跳过]", "assistant")
                return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行超时 ({self.timeout}s)", "assistant")
            except Exception as e:
                if attempt < self.retry_count:
                    logger.warning(f"自定义工具 {self.tool_name} 执行失败,重试 {attempt + 1}/{self.retry_count}: {e}")
                    continue
                if self.error_handling == "skip":
                    return Msg(self.name, f"[自定义工具 {self.tool_name} 失败,已跳过: {e}]", "assistant")
                if self.error_handling == "default":
                    return Msg(self.name, "{}", "assistant")
                return Msg(self.name, f"[自定义工具执行失败: {e}]", "assistant")

        return Msg(self.name, f"[自定义工具 {self.tool_name}] 执行失败", "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no state."""
        pass
|
|
|
|
|
|
class ConditionNodeAgent(AgentBase):
    """Evaluate a branch condition and emit ``condition:true|...`` or ``condition:false|...``.

    ``condition_type`` selects how ``condition`` is evaluated:

    * ``"regex"``: ``re.search`` of ``condition`` against the input text.
    * ``"json_path"``: a dotted path such as ``$.a.b`` (numeric segments
      index lists, e.g. ``$.items.0``). It is looked up in the input parsed
      as JSON and is true when the value is truthy. Input that is not a JSON
      object/array evaluates to false.
    * anything else: the configured LLM judges whether the condition holds.

    An empty condition, an invalid regex, JSON that fails to parse, and a
    failed or unparsable LLM verdict all fall back to ``condition:true``
    (the branch is taken).
    """

    def __init__(self, node_id: str, condition: str = "", condition_type: str = "expression"):
        super().__init__()
        self.name = f"Condition_{node_id}"
        self.condition = condition
        self.condition_type = condition_type

    @staticmethod
    def _lookup_json_path(data, path: str):
        """Follow a ``$.a.b.0``-style path through dicts/lists; ``None`` when it breaks."""
        value = data
        for part in path.strip("$.").split("."):
            if isinstance(value, dict):
                value = value.get(part, None)
            elif isinstance(value, list) and part.isdecimal() and int(part) < len(value):
                value = value[int(part)]
            else:
                value = None
        return value

    async def _judge_with_llm(self, user_text: str) -> tuple[bool, str] | None:
        """Ask the configured OpenAI-compatible endpoint whether the condition holds.

        Returns ``(matched, reason)``, or ``None`` when the call fails or no
        JSON verdict can be extracted. Uses a plain chat-completions request
        (like ``LLMNodeAgent``) instead of the AgentScope model wrapper.
        """
        condition_prompt = f"""你是一个条件判断专家。请判断以下条件表达式是否基于输入内容满足。

条件表达式: {self.condition}

输入内容:
{user_text[:2000]}

请严格只输出一行 JSON:
{{"result": true/false, "reason": "简要原因"}}"""

        try:
            url = f"{settings.LLM_API_BASE.rstrip('/')}/chat/completions"
            headers = {
                "Authorization": f"Bearer {settings.LLM_API_KEY}",
                "Content-Type": "application/json",
            }
            body = {
                "model": settings.LLM_MODEL,
                "messages": [
                    {"role": "system", "content": condition_prompt},
                    {"role": "user", "content": user_text[:2000]},
                ],
                "temperature": 0.0,
                "max_tokens": 200,
                "stream": False,
            }
            async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
                resp = await client.post(url, json=body, headers=headers)
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            res_text = (choices[0].get("message") or {}).get("content") or ""

            json_match = re.search(r"\{.*\}", res_text, re.DOTALL)
            if not json_match:
                return None
            parsed = json.loads(json_match.group())
        except Exception as e:
            logger.warning(f"条件判断LLM调用失败: {e}")
            return None

        if not isinstance(parsed, dict):
            return None
        raw = parsed.get("result", False)
        # Models sometimes answer "false" as a string, which would be truthy.
        matched = raw.strip().lower() in ("true", "yes", "1") if isinstance(raw, str) else bool(raw)
        return matched, parsed.get("reason", "")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        user_text = (msg.get_text_content() or "") if hasattr(msg, 'get_text_content') else str(msg)

        if not self.condition:
            return Msg(self.name, "condition:true|条件为空,默认通过", "assistant")

        if self.condition_type == "regex":
            try:
                matched = bool(re.search(self.condition, user_text))
            except re.error as e:
                return Msg(self.name, f"condition:true|正则错误:{e},默认通过", "assistant")
            return Msg(self.name, f"condition:{'true' if matched else 'false'}|正则匹配:{'命' if matched else '未命'}中", "assistant")

        if self.condition_type == "json_path":
            try:
                stripped = user_text.strip()
                data = json.loads(stripped) if stripped.startswith(("{", "[")) else {}
                val = self._lookup_json_path(data, self.condition)
            except (ValueError, TypeError, AttributeError):
                # Narrow catch: the former bare ``except:`` also trapped BaseException.
                return Msg(self.name, "condition:true|JSON解析失败,默认通过", "assistant")
            return Msg(self.name, f"condition:{'true' if val else 'false'}|JSON路径:{self.condition}", "assistant")

        verdict = await self._judge_with_llm(user_text)
        if verdict is not None:
            matched, reason = verdict
            return Msg(self.name, f"condition:{'true' if matched else 'false'}|{reason}", "assistant")

        return Msg(self.name, f"condition:true|条件判断失败,默认通过: {self.condition[:80]}", "assistant")

    async def observe(self, msg) -> None:
        """Ignore observed messages; this agent keeps no state."""
        pass
|
|
|
|
|
|
class MCPNodeAgent(AgentBase):
    """Flow node that forwards the input text to one tool on a named MCP server.

    The tool is called with ``{"input": <text>}`` and its result is stringified
    into the reply. ``error_handling == "skip"`` turns failures into a short
    "skipped" notice. Any other value (default ``"throw"``) returns a failure
    message; nothing is raised.
    """

    def __init__(self, node_id: str, server_name: str = "", tool_name: str = "", timeout: int = 30, error_handling: str = "throw"):
        super().__init__()
        self.name = f"MCP_{node_id}"
        self.server_name = server_name
        self.tool_name = tool_name
        self.timeout = timeout
        self.error_handling = error_handling

    def _failure(self, reason: str) -> Msg:
        """Build the reply for a failed call, honouring ``error_handling``."""
        if self.error_handling == "skip":
            return Msg(self.name, f"[MCP] {self.server_name} 失败,已跳过", "assistant")
        return Msg(self.name, f"[MCP] {self.server_name} 调用失败: {reason}", "assistant")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Call ``tool_name`` on ``server_name`` with the message text, bounded by ``timeout`` seconds."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)

        if not self.server_name:
            return Msg(self.name, "[MCP] 未指定 MCP 服务名称", "assistant")
        if not self.tool_name:
            # Previously this fell through to a fabricated "调用完成" success message.
            return self._failure("未指定工具名称")

        try:
            from agentscope_runtime.engine.deployers.routing.task_engine_mixin import MCPClientManager
        except ImportError:
            logger.warning("agentscope_runtime MCP 客户端不可用")
            return self._failure("MCP 客户端不可用")

        try:
            client = MCPClientManager.get_http_client(self.server_name)
            if not client:
                return self._failure("未找到 MCP 客户端")
            result = await asyncio.wait_for(
                client.call_tool(self.tool_name, {"input": user_text}),
                timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            if self.error_handling == "skip":
                return Msg(self.name, f"[MCP] {self.server_name} 超时,已跳过", "assistant")
            return Msg(self.name, f"[MCP] {self.server_name} 调用超时 ({self.timeout}s)", "assistant")
        except Exception as e:
            logger.warning(f"MCP 调用失败: {e}")
            return self._failure(str(e))

        return Msg(self.name, str(result), "assistant")

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class NotifyAgent(AgentBase):
    """Flow node that sends the input text (or a configured template) as a notification.

    Config keys read here:

    * ``channels``: mapping of channel -> enabled. Defaults to
      ``{"wecom": True, "web": False}``.
    * ``message_template`` / ``web_template``: fixed message per channel.
      When empty, the first 500 characters of the input are sent.
    * ``target``: recipient user for both channels.

    The reply summarises each channel's outcome joined with `` | ``. Failures
    are reported as failures rather than as successful sends.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Notify_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Dispatch to every enabled channel and summarise the outcomes."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""
        channels = self.config.get("channels") or {"wecom": True, "web": False}
        target = self.config.get("target", "")

        results = []
        if channels.get("wecom", True):
            results.append(self._send_wecom(user_text, target))
        if channels.get("web", False):
            results.append(await self._send_web(user_text, target))

        if not results:
            results.append("通知未发送: 未启用任何通知渠道")
        return Msg(self.name, " | ".join(results), "assistant")

    def _send_wecom(self, user_text: str, target: str) -> str:
        """Send via the WeCom tool and return a one-line status for the summary."""
        message = self.config.get("message_template", "") or user_text[:500]
        try:
            from agentscope_integration.tools.wecom_tools import send_notification
        except ImportError:
            logger.warning("企微通知工具不可用")
            return "企微通知: 工具不可用,未发送"
        try:
            result = send_notification(to_user=target or "user", message=message)
        except Exception as e:
            logger.warning(f"企微通知发送失败: {e}")
            return f"企微通知: 发送失败({e})"
        # str(): the tool's return type is not visible here -- don't assume it slices.
        return f"企微通知: {str(result)[:100]}"

    async def _send_web(self, user_text: str, target: str) -> str:
        """Push over the websocket manager and return a one-line status for the summary."""
        web_message = self.config.get("web_template", "") or user_text[:500]
        if not target:
            return "Web通知: 未指定目标用户,未推送"
        try:
            from websocket_manager import ws_manager
        except ImportError:
            logger.warning("Web通知推送服务不可用")
            return "Web通知: 推送服务不可用"
        try:
            await ws_manager.send_to_user(target, {
                "type": "flow_notification",
                "message": web_message,
                "level": "info",
            })
        except Exception as e:
            logger.warning(f"Web通知推送失败: {e}")
            return f"Web通知: 推送失败({e})"
        return "Web通知: 已推送"

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class LoopNodeAgent(AgentBase):
    """Flow node that emits one loop step per call, or ``loop:stop|...`` when done.

    The current 1-based iteration is ``config["_engine_iteration"] + 1``.
    Presumably the engine bumps ``_engine_iteration`` between calls and stops
    on a ``loop:stop`` reply; verify this against FlowEngine.

    ``loop_type``:

    * ``"array"``: iterate over ``config["items"]`` (a list, a JSON string, or
      newline-separated text). The body runs once per item.
    * ``"fixed"`` (default): the body runs ``count`` (default 3) times.
    * anything else: never stops by itself.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Loop_{node_id}"
        self.config = config or {}

    @staticmethod
    def _as_int(value, default: int) -> int:
        """Coerce a config value to int, using ``default`` when that is impossible."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def _load_items(self) -> list:
        """Return ``config["items"]`` as a list, or ``[]`` when it is not list-like."""
        items = self.config.get("items", [])
        if isinstance(items, str):
            try:
                items = json.loads(items)
            except Exception:
                items = [line.strip() for line in items.split("\n") if line.strip()]
        return items if isinstance(items, list) else []

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Produce the step text for the current iteration, or a stop marker."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""
        loop_type = self.config.get("loop_type", "fixed")
        iterator_variable = self.config.get("iterator_variable", "item")
        iteration = self._as_int(self.config.get("_engine_iteration", 0), 0) + 1

        if loop_type == "array":
            items = self._load_items()
            if not items:
                return Msg(self.name, "loop:stop|无数组数据", "assistant")
            if iteration > len(items):
                return Msg(self.name, "loop:stop|数组迭代完成", "assistant")
            current_item = items[iteration - 1]
            return Msg(self.name, f"[迭代 {iteration}/{len(items)}] 变量: {iterator_variable}={current_item}\n输入: {user_text[:200]}", "assistant")

        if loop_type == "fixed":
            count = self._as_int(self.config.get("count", 3), 3)
            # Stop only once all `count` iterations have run. The previous `>=`
            # stopped on the count-th call, so count=1 never ran the body. This
            # also keeps the fixed branch consistent with the array branch.
            if iteration > count:
                return Msg(self.name, f"loop:stop|固定次数循环完成: {count}/{count}", "assistant")
            return Msg(self.name, f"[循环 {iteration}/{count}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

        return Msg(self.name, f"[循环 {iteration}] 变量: {iterator_variable}={iteration}\n输入: {user_text[:200]}", "assistant")

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class CodeNodeAgent(AgentBase):
    """Flow node that runs user-supplied Python code in a child interpreter.

    The incoming text is exposed to the script as the ``str`` ``INPUT_TEXT``,
    and ``json`` is pre-imported. The script's stdout (stripped) becomes the
    reply. A non-zero exit status yields ``[错误] ...`` with the captured
    stderr. Exceeding ``timeout`` seconds kills the child.

    SECURITY NOTE(review): ``sandbox`` and ``language`` are stored but not
    enforced here. The code always runs as Python, with this process's
    privileges. Only accept code from trusted authors until real isolation
    exists.
    """

    def __init__(self, node_id: str, language: str = "python", code: str = "", timeout: int = 30, sandbox: bool = True):
        super().__init__()
        self.name = f"Code_{node_id}"
        self.language = language
        self.code = code
        self.timeout = timeout
        self.sandbox = sandbox

    async def _run_script(self, path: str) -> tuple[bytes, bytes, int]:
        """Run ``path`` with the current interpreter; kill and reap it on timeout or cancellation."""
        import sys

        proc = await asyncio.create_subprocess_exec(
            # The interpreter running this service, not whatever "python" is on PATH.
            sys.executable or "python", path,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()  # reap the child so it does not linger as a zombie
            raise
        return stdout, stderr, proc.returncode

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Execute ``self.code`` against the message text and return its output."""
        import os
        import tempfile

        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)

        if not self.code.strip():
            return Msg(self.name, user_text, "assistant")

        # repr() of a str is always a valid Python literal that evaluates back to
        # the same str. The previous `json.loads(<json.dumps(...)>)` was decoded
        # twice (by the Python parser, then by json.loads) and raised for any
        # ordinary text.
        script = f"import json\nINPUT_TEXT = {(user_text or '')!r}\n\n{self.code}"

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
                temp_path = f.name  # record first so a failed write is still cleaned up
                f.write(script)
            stdout, stderr, returncode = await self._run_script(temp_path)
        except asyncio.TimeoutError:
            return Msg(self.name, f"[代码执行超时 ({self.timeout}s)]", "assistant")
        except Exception as e:
            return Msg(self.name, f"[代码执行失败: {e}]", "assistant")
        finally:
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

        # Judge success by exit status: warnings on stderr must not discard valid output.
        if returncode != 0:
            detail = stderr.decode('utf-8', errors='replace').strip() or f"exit code {returncode}"
            return Msg(self.name, f"[错误] {detail[:500]}", "assistant")

        output = stdout.decode('utf-8', errors='replace').strip()
        return Msg(self.name, output or "[代码执行完成,无输出]", "assistant")

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class RAGNodeAgent(AgentBase):
    """Flow node that answers the input from knowledge-base retrieval plus an LLM.

    Retrieval goes through ``modules.rag.knowledge.retrieve_for_agent`` with
    ``limit=config["top_k"]`` (default 5). The other retrieval settings
    (``similarity_threshold``, ``search_mode``, ``include_metadata``) are not
    applied here. Generation calls the OpenAI-compatible ``/chat/completions``
    endpoint described by ``config["_resolved_model"]`` (keys ``base_url``,
    ``api_key``, ``model``), falling back to ``settings``.

    Failure handling:

    * Retrieval fails: reply with a short placeholder.
    * Generation fails after retrieval succeeded: reply with the placeholder
      plus the retrieved text.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"RAG_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Retrieve knowledge for the message text and generate an answer from it."""
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""
        top_k = self.config.get("top_k", 5)
        placeholder = f"[RAG检索] 知识库检索:\n查询: {user_text[:200]}\nTopK: {top_k}"

        try:
            from modules.rag.knowledge import retrieve_for_agent
            kb_result = await retrieve_for_agent(user_text, limit=top_k)
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            return Msg(self.name, placeholder, "assistant")

        rag_prompt = (
            "你是一个知识检索助手。请基于以下知识库检索结果回答用户问题。\n"
            "\n"
            "知识库检索结果:\n"
            f"{kb_result}\n"
            "\n"
            f"用户问题: {user_text}\n"
            "\n"
            "请基于以上知识库内容给出专业回答。如果知识库中没有相关信息,请诚实说明。"
        )
        try:
            answer = await self._generate(rag_prompt, user_text)
        except Exception as e:
            logger.warning(f"RAG 节点执行失败: {e}")
            # Retrieval already succeeded; hand back the hits rather than nothing useful.
            return Msg(self.name, f"{placeholder}\n{kb_result}", "assistant")
        return Msg(self.name, answer, "assistant")

    async def _generate(self, system_prompt: str, user_text: str) -> str:
        """POST one chat completion and return its text (``""`` when the model returns none).

        Raises on transport errors, non-2xx responses and malformed bodies.
        """
        resolved = self.config.get("_resolved_model") or {}
        api_base = (resolved.get("base_url") or settings.LLM_API_BASE).rstrip("/")
        api_key = resolved.get("api_key") or settings.LLM_API_KEY
        model_name = resolved.get("model") or settings.LLM_MODEL
        # Answers can be long, so allow more time than the short rewrite/judge calls.
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                f"{api_base}/chat/completions",
                json={
                    "model": model_name,
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_text},
                    ],
                },
                headers={"Authorization": f"Bearer {api_key}"},
            )
            resp.raise_for_status()
            data = resp.json()
        return data["choices"][0]["message"].get("content") or ""

    def _get_model(self):
        """Legacy agentscope model factory, kept for backward compatibility.

        ``reply`` no longer uses it. NOTE(review): ``config_name``/``api_base``
        look like agentscope-0.x arguments; confirm before relying on this.
        """
        from agentscope.model import OpenAIChatModel
        resolved = self.config.get("_resolved_model") or {}
        return OpenAIChatModel(
            config_name=f"rag_{self.name}",
            model_name=resolved.get("model", settings.LLM_MODEL),
            api_key=resolved.get("api_key", settings.LLM_API_KEY),
            api_base=resolved.get("base_url", settings.LLM_API_BASE),
        )

    def _get_formatter(self):
        """Legacy formatter factory, kept for backward compatibility; unused by ``reply``."""
        from agentscope.formatter import OpenAIChatFormatter
        return OpenAIChatFormatter()

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class OutputNodeAgent(AgentBase):
    """Terminal flow node that shapes the final text.

    Steps, in order:

    1. Optionally truncate to ``max_length`` (default 2000) when ``truncate``
       is set.
    2. Pretty-print as JSON when ``format == "json"`` and the text parses.
    3. Otherwise substitute ``{{output}}`` in ``output_template``.
    4. Otherwise pass the message through unchanged. A new message is built
       only when the text was truncated or the input was not a ``Msg``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Output_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Apply truncation / JSON formatting / templating to the message text."""
        output_format = self.config.get("format", "text")
        output_template = self.config.get("output_template", "")
        truncate = self.config.get("truncate", False)
        max_length = self.config.get("max_length", 2000)

        content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        content = content or ""

        truncated = bool(truncate) and len(content) > max_length
        if truncated:
            content = content[:max_length] + "\n\n[内容已截断]"

        if output_format == "json":
            try:
                parsed = json.loads(content)
            except (json.JSONDecodeError, ValueError):
                pass  # not JSON: fall through to template / passthrough
            else:
                indent = self.config.get("indent", 2)
                return Msg(self.name, json.dumps(parsed, indent=indent, ensure_ascii=False), "assistant")

        if output_template:
            return Msg(self.name, output_template.replace("{{output}}", content), "assistant")

        # Returning the original Msg would silently drop the truncation above.
        if truncated or not isinstance(msg, Msg):
            return Msg(self.name, content, "assistant")
        return msg

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
def _resolve_template(template: str, context: dict, current_msg: "Msg") -> str:
    """Expand ``{{...}}`` placeholders in ``template`` (whitespace inside braces is ignored).

    Supported placeholders:

    * ``{{trigger.<key>}}``: ``context["trigger_data"][<key>]``. Everything
      after the first dot is looked up as one flat key.
    * ``{{<node_id>}}``: ``context["_node_results"][<node_id>]["output"]``.
    * ``{{<node_id>.<field>}}``: that field of the same node result.
    * ``{{input}}``: the text of ``current_msg``. Applies only when no node
      result is registered under the id ``input``.

    Anything else expands to ``""``. Substitution is single-pass, so text
    inserted for one placeholder is never itself re-expanded.

    Args:
        template: Text containing placeholders.
        context: Flow context; ``None`` is treated as empty.
        current_msg: Message providing ``{{input}}``; anything with
            ``get_text_content()`` or a useful ``str()``.

    Returns:
        The rendered string.
    """
    context = context or {}
    node_results = context.get("_node_results") or {}

    def _lookup(match: "re.Match") -> str:
        parts = match.group(1).strip().split(".")
        head = parts[0]
        if head == "trigger":
            trigger_data = context.get("trigger_data") or {}
            return str(trigger_data.get(".".join(parts[1:]), ""))
        if head in node_results:
            node_result = node_results.get(head, {})
            field = parts[1] if len(parts) > 1 else "output"
            return str(node_result.get(field, ""))
        if head == "input" and len(parts) == 1:
            # current_msg was previously accepted but ignored, so {{input}} became "".
            text = current_msg.get_text_content() if hasattr(current_msg, "get_text_content") else str(current_msg)
            return text or ""
        return ""

    return re.sub(r"\{\{(.+?)\}\}", _lookup, template)
|
|
|
|
|
|
class TemplateTransformNodeAgent(AgentBase):
    """Flow node that renders ``config["template"]`` with flow placeholders.

    ``{{input}}`` is replaced by the incoming message text. Other placeholders
    are expanded by ``_resolve_template`` against ``kwargs["context"]``. With
    ``output_type == "json"``, a rendered result that parses as JSON is
    re-serialised compactly; otherwise the rendered text is returned as-is.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Template_{node_id}"
        self.config = config or {}

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Render the template for this message and flow context."""
        template = self.config.get("template", "")
        output_type = self.config.get("output_type", "string")
        context = kwargs.get("context") or {}

        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""

        try:
            # Substitute {{input}} here, between independently resolved segments.
            # Running _resolve_template over the whole template first could blank
            # {{input}} as an unknown placeholder. Splitting also keeps
            # placeholder-like text inside the user's input from being expanded.
            segments = template.split("{{input}}")
            rendered = user_text.join(_resolve_template(seg, context, msg) for seg in segments)
            if not rendered:
                rendered = template.replace("{{input}}", user_text)
        except Exception:
            rendered = template.replace("{{input}}", user_text) if "{{input}}" in template else user_text

        if output_type == "json":
            try:
                parsed = json.loads(rendered)
                return Msg(self.name, json.dumps(parsed, ensure_ascii=False), "assistant")
            except (json.JSONDecodeError, ValueError):
                pass

        return Msg(self.name, rendered, "assistant")

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class IterationNodeAgent(AgentBase):
    """Flow node that turns its input into a list of items to iterate over.

    The item source is chosen as follows:

    * With ``input_array_source`` set, it is rendered through
      ``_resolve_template`` and parsed as a JSON array. Any other value
      becomes one item; an empty render becomes no items.
    * Otherwise the message text is parsed as a JSON array. A non-array JSON
      value becomes one item; non-JSON text is split into non-blank lines.
      Blank text still yields one item.

    At most ``max_iterations`` (default 20) items are kept. The reply is a
    JSON list of ``{"index": i, "item": item}``. The same items are returned
    by ``get_iteration_items()``.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Iteration_{node_id}"
        self.config = config or {}
        self._results: list = []  # items from the most recent reply()

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Compute the item list for this message and return it as JSON."""
        input_array_source = self.config.get("input_array_source", "")
        try:
            max_iterations = int(self.config.get("max_iterations", 20))
        except (TypeError, ValueError):
            max_iterations = 20
        context = kwargs.get("context", {})

        if input_array_source:
            resolved = _resolve_template(input_array_source, context, msg)
            try:
                items = json.loads(resolved)
                if not isinstance(items, list):
                    items = [resolved]
            except (json.JSONDecodeError, ValueError):
                items = [resolved] if resolved else []
        else:
            user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
            user_text = user_text or ""
            try:
                items = json.loads(user_text)
                if not isinstance(items, list):
                    items = [user_text]
            except (json.JSONDecodeError, ValueError):
                items = [line.strip() for line in user_text.split("\n") if line.strip()]
                if not items:
                    items = [user_text]

        items = items[:max_iterations]
        # Previously never assigned, so get_iteration_items() always returned [].
        self._results = list(items)
        results = [{"index": i, "item": item} for i, item in enumerate(items)]
        return Msg(self.name, json.dumps(results, ensure_ascii=False), "assistant")

    def get_iteration_items(self) -> list:
        """Return the items produced by the most recent ``reply`` call."""
        return self._results

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class QuestionOptimiserNodeAgent(AgentBase):
    """Flow node that rewrites or expands the user's question with an LLM.

    ``optimization_type``:

    * ``"rewrite"`` (default): clarify the question.
    * ``"expand"``: add detail.
    * anything else: pass the input through unchanged.

    Context from ``kwargs["context"]["_memory"]`` is prepended when present:

    * the persona's ``raw_text``;
    * up to 5 atoms;
    * the last 6 recent messages, each truncated to 200 characters.

    On any failure, or an empty model answer, the original question is
    returned.
    """

    # Per-type (label for the question line, instruction lines).
    _PROMPTS = {
        "rewrite": (
            "原始问题",
            "请将以上问题进行优化改写,使其更清晰、具体、完整。补充可能缺失的上下文信息。\n只返回优化后的问题,不要其他内容。",
        ),
        "expand": (
            "简短问题",
            "请将以上问题扩展为更详细的版本,添加必要的背景和细节。\n只返回扩展后的问题,不要其他内容。",
        ),
    }

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"QOpt_{node_id}"
        self.config = config or {}

    @staticmethod
    def _context_block(context: dict) -> str:
        """Render persona / known facts / recent turns from ``context["_memory"]`` as prompt text."""
        memory = (context or {}).get("_memory") or {}
        persona = memory.get("persona") or {}
        atoms = memory.get("atoms") or []
        history = memory.get("recent_messages") or []

        persona_text = persona.get("raw_text", "")
        atoms_text = "\n".join(a.get("content", "") for a in atoms[:5])
        history_text = "\n".join(
            f"{m.get('role', '')}: {(m.get('content') or '')[:200]}" for m in history[-6:]
        )

        context_parts = []
        if persona_text:
            context_parts.append(f"用户画像: {persona_text}")
        if atoms_text:
            context_parts.append(f"已知信息: {atoms_text}")
        if history_text:
            context_parts.append(f"近期对话: {history_text}")
        return "\n".join(context_parts)

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Return the optimised question, or the original one on failure."""
        optimization_type = self.config.get("optimization_type", "rewrite")
        model_name = self.config.get("model") or settings.LLM_MODEL
        user_text = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        user_text = user_text or ""

        if optimization_type not in self._PROMPTS:
            return Msg(self.name, user_text, "assistant")

        label, instruction = self._PROMPTS[optimization_type]
        context_block = self._context_block(kwargs.get("context", {}))
        prompt = f"{context_block}\n\n{label}: {user_text}\n\n{instruction}"

        try:
            api_base = settings.LLM_API_BASE.rstrip("/")
            async with httpx.AsyncClient(timeout=30) as client:
                resp = await client.post(
                    f"{api_base}/chat/completions",
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": prompt}],
                        "max_tokens": 300,
                        "temperature": 0.3,
                    },
                    headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"},
                )
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices") or [{}]
            result = ((choices[0].get("message") or {}).get("content") or "").strip()
        except Exception as e:
            logger.warning(f"问题优化失败: {e}")
            return Msg(self.name, user_text, "assistant")

        # An empty answer must not replace the user's question with nothing.
        return Msg(self.name, result or user_text, "assistant")

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|
|
|
|
|
|
class ParallelMergeNodeAgent(AgentBase):
    """Fan-in node that collects parallel branch outputs and merges them.

    Each ``reply`` call records one branch's text under
    ``kwargs["source_node_id"]``. Calls without an id get a unique synthetic
    key, so they no longer overwrite each other.

    Merging happens once ``expected_branches`` distinct branches have
    reported, or on every call when it is <= 0. ``merge_type`` selects the
    merge:

    * ``"json"``: an object keyed by source id.
    * ``"first_non_empty"``: the first non-blank output.
    * anything else (``"concat"``): outputs joined by ``---`` separators.

    Before the merge point the reply text is empty. Received outputs persist
    on the instance across calls.
    """

    def __init__(self, node_id: str, config: dict = None):
        super().__init__()
        self.name = f"Merge_{node_id}"
        self.config = config or {}
        self._received: dict[str, str] = {}  # source node id -> that branch's text
        self.merge_type = self.config.get("merge_type", "concat")

    async def reply(self, msg: Msg, **kwargs) -> Msg:
        """Record one branch output; return the merge once all expected branches arrived."""
        source_id = kwargs.get("source_node_id") or f"_branch_{len(self._received)}"
        content = msg.get_text_content() if hasattr(msg, 'get_text_content') else str(msg)
        self._received[source_id] = content or ""

        try:
            expected_count = int(self.config.get("expected_branches", 0) or 0)
        except (TypeError, ValueError):
            expected_count = 0
        if expected_count <= 0 or len(self._received) >= expected_count:
            return self._merge()
        return Msg(self.name, "", "assistant")

    def _merge(self) -> Msg:
        """Combine everything received so far according to ``merge_type``."""
        if self.merge_type == "json":
            merged = json.dumps(self._received, ensure_ascii=False)
        elif self.merge_type == "first_non_empty":
            merged = next((v for v in self._received.values() if v.strip()), "")
        else:
            merged = "\n\n---\n\n".join(self._received.values())
        return Msg(self.name, merged, "assistant")

    async def observe(self, msg) -> None:
        """Ignore broadcast messages; this node only reacts to ``reply``."""
        pass
|