|
|
|
@ -18,9 +18,10 @@ from datetime import datetime, timezone |
|
|
|
from typing import Callable |
|
|
|
|
|
|
|
from redis.asyncio import Redis |
|
|
|
from sqlalchemy import text |
|
|
|
from sqlalchemy import text, select |
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession |
|
|
|
from config import settings |
|
|
|
from models import ModelInstance, ModelProvider |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
@ -65,6 +66,62 @@ class MemoryManager: |
|
|
|
self.db_factory = db_factory |
|
|
|
self.redis = redis |
|
|
|
self._extract_tasks: dict[str, asyncio.Task] = {} # 后台提取任务追踪 |
|
|
|
self._llm_config_cache: dict | None = None # LLM 配置缓存 |
|
|
|
|
|
|
|
async def _get_llm_config(self) -> dict:
    """Resolve the LLM endpoint configuration, with a short-lived cache.

    Lookup order:
      1. the active LLM row in ``model_instances`` flagged as default
         (joined to its provider);
      2. any active LLM row;
      3. the environment-backed values on ``settings``.

    Empty provider fields (``base_url`` / ``api_key``) fall back to the
    corresponding ``settings`` value.

    A successfully resolved config is cached on the instance and cleared
    again after 5 minutes, so most calls do not touch the database. If the
    lookup raises, the ``settings`` fallback is returned *without* being
    cached, so the next call retries the database.

    Returns:
        dict: ``{"api_base": ..., "model": ..., "api_key": ...}`` where
        ``api_base`` has any trailing ``/`` stripped.
    """
    if self._llm_config_cache is not None:
        return self._llm_config_cache

    def env_config() -> dict:
        # Built lazily so settings are only read when the fallback is needed.
        return {
            "api_base": settings.LLM_API_BASE.rstrip("/"),
            "model": settings.LLM_MODEL,
            "api_key": settings.LLM_API_KEY,
        }

    try:
        # Scope the session so it is always closed, even when a query fails.
        async with self.db_factory() as db:
            active_llms = (
                select(ModelInstance, ModelProvider)
                .join(ModelProvider, ModelInstance.provider_id == ModelProvider.id)
                .where(ModelInstance.model_type == 'llm')
                .where(ModelInstance.is_active == True)  # noqa: E712 (SQL expression)
            )

            # Prefer the instance flagged as default ...
            result = await db.execute(
                active_llms.where(ModelInstance.is_default == True)  # noqa: E712
                .limit(1)
            )
            row = result.first()
            if not row:
                # ... otherwise accept any active LLM instance.
                result = await db.execute(active_llms.limit(1))
                row = result.first()

            # Read ORM attributes while the session is still open.
            if row:
                instance, provider = row
                config = {
                    "api_base": (provider.base_url or settings.LLM_API_BASE).rstrip("/"),
                    "model": instance.model_name,
                    "api_key": provider.api_key or settings.LLM_API_KEY,
                }
            else:
                config = env_config()

        self._llm_config_cache = config
        # Drop the cached value after 5 minutes so DB changes are picked up.
        asyncio.get_running_loop().call_later(
            300, lambda: setattr(self, '_llm_config_cache', None)
        )
        return config
    except Exception:
        # Best-effort: memory features keep working on the env config, but
        # the failure is logged instead of being silently swallowed.
        logger.warning(
            "Failed to load LLM config from database; falling back to settings",
            exc_info=True,
        )
        return env_config()
|
|
|
|
|
|
|
async def inject_memory( |
|
|
|
self, |
|
|
|
@ -394,19 +451,19 @@ class MemoryManager: |
|
|
|
) |
|
|
|
|
|
|
|
import httpx |
|
|
|
api_base = settings.LLM_API_BASE.rstrip("/") |
|
|
|
llm = await self._get_llm_config() |
|
|
|
async with httpx.AsyncClient(timeout=30) as client: |
|
|
|
resp = await client.post( |
|
|
|
f"{api_base}/chat/completions", |
|
|
|
f"{llm['api_base']}/chat/completions", |
|
|
|
json={ |
|
|
|
"model": settings.LLM_MODEL, |
|
|
|
"model": llm["model"], |
|
|
|
"messages": [{ |
|
|
|
"role": "user", |
|
|
|
"content": f"请用一段话简要总结以下对话的关键内容。保留人名、任务、决策、时间等关键信息。\n\n{dialogue}" |
|
|
|
}], |
|
|
|
"max_tokens": 200, |
|
|
|
}, |
|
|
|
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, |
|
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"}, |
|
|
|
) |
|
|
|
data = resp.json() |
|
|
|
summary = data.get("choices", [{}])[0].get("message", {}).get("content", "") |
|
|
|
@ -515,17 +572,17 @@ class MemoryManager: |
|
|
|
只返回JSON数组,不要其他内容。如果没有可提取的信息返回空数组[]。""" |
|
|
|
|
|
|
|
import httpx |
|
|
|
api_base = settings.LLM_API_BASE.rstrip("/") |
|
|
|
llm = await self._get_llm_config() |
|
|
|
async with httpx.AsyncClient(timeout=60) as client: |
|
|
|
resp = await client.post( |
|
|
|
f"{api_base}/chat/completions", |
|
|
|
f"{llm['api_base']}/chat/completions", |
|
|
|
json={ |
|
|
|
"model": settings.LLM_MODEL, |
|
|
|
"model": llm["model"], |
|
|
|
"messages": [{"role": "user", "content": prompt}], |
|
|
|
"max_tokens": 800, |
|
|
|
"temperature": 0.3, |
|
|
|
}, |
|
|
|
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, |
|
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"}, |
|
|
|
) |
|
|
|
data = resp.json() |
|
|
|
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]") |
|
|
|
@ -704,17 +761,17 @@ class MemoryManager: |
|
|
|
只返回JSON数组,不要其他内容。""" |
|
|
|
|
|
|
|
import httpx |
|
|
|
api_base = settings.LLM_API_BASE.rstrip("/") |
|
|
|
llm = await self._get_llm_config() |
|
|
|
async with httpx.AsyncClient(timeout=60) as client: |
|
|
|
resp = await client.post( |
|
|
|
f"{api_base}/chat/completions", |
|
|
|
f"{llm['api_base']}/chat/completions", |
|
|
|
json={ |
|
|
|
"model": settings.LLM_MODEL, |
|
|
|
"model": llm["model"], |
|
|
|
"messages": [{"role": "user", "content": prompt}], |
|
|
|
"max_tokens": 500, |
|
|
|
"temperature": 0.3, |
|
|
|
}, |
|
|
|
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, |
|
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"}, |
|
|
|
) |
|
|
|
data = resp.json() |
|
|
|
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "[]") |
|
|
|
@ -817,17 +874,17 @@ class MemoryManager: |
|
|
|
只返回JSON对象,不要其他内容。""" |
|
|
|
|
|
|
|
import httpx |
|
|
|
api_base = settings.LLM_API_BASE.rstrip("/") |
|
|
|
llm = await self._get_llm_config() |
|
|
|
async with httpx.AsyncClient(timeout=60) as client: |
|
|
|
resp = await client.post( |
|
|
|
f"{api_base}/chat/completions", |
|
|
|
f"{llm['api_base']}/chat/completions", |
|
|
|
json={ |
|
|
|
"model": settings.LLM_MODEL, |
|
|
|
"model": llm["model"], |
|
|
|
"messages": [{"role": "user", "content": prompt}], |
|
|
|
"max_tokens": 500, |
|
|
|
"temperature": 0.3, |
|
|
|
}, |
|
|
|
headers={"Authorization": f"Bearer {settings.LLM_API_KEY}"}, |
|
|
|
headers={"Authorization": f"Bearer {llm['api_key']}"}, |
|
|
|
) |
|
|
|
data = resp.json() |
|
|
|
result_text = data.get("choices", [{}])[0].get("message", {}).get("content", "{}") |
|
|
|
|