This commit is contained in:
2025-11-07 09:05:41 +08:00
parent 6f4841f86d
commit 162ba8a7dd

View File

@@ -0,0 +1,220 @@
"""
处理链管理器 - 管理LangChain处理链
"""
from typing import Dict, Any
import logging
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable
from ..config.prompts import PromptTemplates
from ..utils.logger import LoggerMixin
logger = logging.getLogger(__name__)
class ChainManager(LoggerMixin):
"""处理链管理器"""
def __init__(self, llm: BaseChatModel):
self.llm = llm
self.prompt_templates = PromptTemplates()
self._chains: Dict[str, Runnable] = {}
self._setup_chains()
def _setup_chains(self):
"""设置所有处理链"""
self.logger.debug("正在设置处理链...")
# 文档结构分析链
self._chains["structure_analysis"] = (
self.prompt_templates.get_structure_analysis_prompt()
| self.llm
| StrOutputParser()
)
# 初始PPT框架生成链
self._chains["initial_outline"] = (
self.prompt_templates.get_initial_outline_prompt()
| self.llm
| StrOutputParser()
)
# 内容细化链
self._chains["refine_outline"] = (
self.prompt_templates.get_refine_outline_prompt()
| self.llm
| StrOutputParser()
)
# 错误恢复链
self._chains["error_recovery"] = (
self.prompt_templates.get_error_recovery_prompt()
| self.llm
| StrOutputParser()
)
self.logger.debug(f"已设置 {len(self._chains)} 个处理链")
def get_chain(self, chain_name: str) -> Runnable:
"""
获取指定的处理链
Args:
chain_name: 链名称
Returns:
处理链实例
Raises:
KeyError: 链不存在
"""
if chain_name not in self._chains:
raise KeyError(f"处理链不存在: {chain_name}")
return self._chains[chain_name]
async def invoke_chain(
self,
chain_name: str,
inputs: Dict[str, Any],
config: Dict[str, Any] = None
) -> str:
"""
调用指定的处理链
Args:
chain_name: 链名称
inputs: 输入参数
config: 运行配置
Returns:
处理结果
"""
chain = self.get_chain(chain_name)
try:
self.logger.debug(f"调用处理链: {chain_name}")
result = await chain.ainvoke(inputs, config or {})
self.logger.debug(f"处理链 {chain_name} 执行成功")
return result
except Exception as e:
self.logger.error(f"处理链 {chain_name} 执行失败: {e}")
raise
def list_chains(self) -> list:
"""列出所有可用的处理链"""
return list(self._chains.keys())
def update_llm(self, llm: BaseChatModel):
"""更新LLM并重新设置处理链"""
self.logger.info("更新LLM并重新设置处理链")
self.llm = llm
self._setup_chains()
def add_custom_chain(self, name: str, chain: Runnable):
"""
添加自定义处理链
Args:
name: 链名称
chain: 处理链实例
"""
self._chains[name] = chain
self.logger.info(f"已添加自定义处理链: {name}")
def remove_chain(self, name: str):
"""
移除处理链
Args:
name: 链名称
"""
if name in self._chains:
del self._chains[name]
self.logger.info(f"已移除处理链: {name}")
else:
self.logger.warning(f"尝试移除不存在的处理链: {name}")
class ChainExecutor:
"""处理链执行器,提供重试和错误处理功能"""
def __init__(self, chain_manager: ChainManager, max_retries: int = 3):
self.chain_manager = chain_manager
self.max_retries = max_retries
self.logger = logging.getLogger(self.__class__.__name__)
async def execute_with_retry(
self,
chain_name: str,
inputs: Dict[str, Any],
config: Dict[str, Any] = None
) -> str:
"""
带重试的链执行
Args:
chain_name: 链名称
inputs: 输入参数
config: 运行配置
Returns:
处理结果
Raises:
Exception: 所有重试都失败后抛出最后一个异常
"""
last_exception = None
for attempt in range(self.max_retries):
try:
result = await self.chain_manager.invoke_chain(chain_name, inputs, config)
if attempt > 0:
self.logger.info(f"处理链 {chain_name} 在第 {attempt + 1} 次尝试后成功")
return result
except Exception as e:
last_exception = e
self.logger.warning(f"处理链 {chain_name}{attempt + 1} 次尝试失败: {e}")
if attempt < self.max_retries - 1:
# 可以在这里添加退避策略
import asyncio
await asyncio.sleep(1 * (attempt + 1)) # 简单的线性退避
self.logger.error(f"处理链 {chain_name}{self.max_retries} 次尝试后仍然失败")
raise last_exception
async def execute_with_fallback(
self,
primary_chain: str,
fallback_chain: str,
inputs: Dict[str, Any],
fallback_inputs: Dict[str, Any] = None,
config: Dict[str, Any] = None
) -> str:
"""
带回退的链执行
Args:
primary_chain: 主要处理链
fallback_chain: 回退处理链
inputs: 主要链的输入参数
fallback_inputs: 回退链的输入参数
config: 运行配置
Returns:
处理结果
"""
try:
return await self.execute_with_retry(primary_chain, inputs, config)
except Exception as e:
self.logger.warning(f"主要处理链 {primary_chain} 失败,尝试回退链 {fallback_chain}: {e}")
fallback_inputs = fallback_inputs or inputs
try:
return await self.execute_with_retry(fallback_chain, fallback_inputs, config)
except Exception as fallback_e:
self.logger.error(f"回退处理链 {fallback_chain} 也失败了: {fallback_e}")
raise fallback_e