diff --git a/src/summeryanyfile/generators/chains.py b/src/summeryanyfile/generators/chains.py new file mode 100644 index 0000000..677f048 --- /dev/null +++ b/src/summeryanyfile/generators/chains.py @@ -0,0 +1,220 @@ +""" +处理链管理器 - 管理LangChain处理链 +""" + +from typing import Dict, Any +import logging +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import Runnable + +from ..config.prompts import PromptTemplates +from ..utils.logger import LoggerMixin + +logger = logging.getLogger(__name__) + + +class ChainManager(LoggerMixin): + """处理链管理器""" + + def __init__(self, llm: BaseChatModel): + self.llm = llm + self.prompt_templates = PromptTemplates() + self._chains: Dict[str, Runnable] = {} + self._setup_chains() + + def _setup_chains(self): + """设置所有处理链""" + self.logger.debug("正在设置处理链...") + + # 文档结构分析链 + self._chains["structure_analysis"] = ( + self.prompt_templates.get_structure_analysis_prompt() + | self.llm + | StrOutputParser() + ) + + # 初始PPT框架生成链 + self._chains["initial_outline"] = ( + self.prompt_templates.get_initial_outline_prompt() + | self.llm + | StrOutputParser() + ) + + # 内容细化链 + self._chains["refine_outline"] = ( + self.prompt_templates.get_refine_outline_prompt() + | self.llm + | StrOutputParser() + ) + + # 错误恢复链 + self._chains["error_recovery"] = ( + self.prompt_templates.get_error_recovery_prompt() + | self.llm + | StrOutputParser() + ) + + self.logger.debug(f"已设置 {len(self._chains)} 个处理链") + + def get_chain(self, chain_name: str) -> Runnable: + """ + 获取指定的处理链 + + Args: + chain_name: 链名称 + + Returns: + 处理链实例 + + Raises: + KeyError: 链不存在 + """ + if chain_name not in self._chains: + raise KeyError(f"处理链不存在: {chain_name}") + + return self._chains[chain_name] + + async def invoke_chain( + self, + chain_name: str, + inputs: Dict[str, Any], + config: Dict[str, Any] = None + ) -> str: + """ + 调用指定的处理链 + + Args: + chain_name: 链名称 + inputs: 输入参数 + config: 运行配置 + + Returns: + 处理结果 + """ + chain = self.get_chain(chain_name) + + try: + self.logger.debug(f"调用处理链: {chain_name}") + result = await chain.ainvoke(inputs, config or {}) + self.logger.debug(f"处理链 {chain_name} 执行成功") + return result + except Exception as e: + self.logger.error(f"处理链 {chain_name} 执行失败: {e}") + raise + + def list_chains(self) -> list: + """列出所有可用的处理链""" + return list(self._chains.keys()) + + def update_llm(self, llm: BaseChatModel): + """更新LLM并重新设置处理链""" + self.logger.info("更新LLM并重新设置处理链") + self.llm = llm + self._setup_chains() + + def add_custom_chain(self, name: str, chain: Runnable): + """ + 添加自定义处理链 + + Args: + name: 链名称 + chain: 处理链实例 + """ + self._chains[name] = chain + self.logger.info(f"已添加自定义处理链: {name}") + + def remove_chain(self, name: str): + """ + 移除处理链 + + Args: + name: 链名称 + """ + if name in self._chains: + del self._chains[name] + self.logger.info(f"已移除处理链: {name}") + else: + self.logger.warning(f"尝试移除不存在的处理链: {name}") + + +class ChainExecutor: + """处理链执行器,提供重试和错误处理功能""" + + def __init__(self, chain_manager: ChainManager, max_retries: int = 3): + self.chain_manager = chain_manager + self.max_retries = max_retries + self.logger = logging.getLogger(self.__class__.__name__) + + async def execute_with_retry( + self, + chain_name: str, + inputs: Dict[str, Any], + config: Dict[str, Any] = None + ) -> str: + """ + 带重试的链执行 + + Args: + chain_name: 链名称 + inputs: 输入参数 + config: 运行配置 + + Returns: + 处理结果 + + Raises: + Exception: 所有重试都失败后抛出最后一个异常 + """ + last_exception = None + + for attempt in range(self.max_retries): + try: + result = await self.chain_manager.invoke_chain(chain_name, inputs, config) + if attempt > 0: + self.logger.info(f"处理链 {chain_name} 在第 {attempt + 1} 次尝试后成功") + return result + except Exception as e: + last_exception = e + self.logger.warning(f"处理链 {chain_name} 第 {attempt + 1} 次尝试失败: {e}") + + if attempt < self.max_retries - 1: + # 可以在这里添加退避策略 + import asyncio + await asyncio.sleep(1 * (attempt + 1)) # 简单的线性退避 + + self.logger.error(f"处理链 {chain_name} 在 {self.max_retries} 次尝试后仍然失败") + raise last_exception + + async def execute_with_fallback( + self, + primary_chain: str, + fallback_chain: str, + inputs: Dict[str, Any], + fallback_inputs: Dict[str, Any] = None, + config: Dict[str, Any] = None + ) -> str: + """ + 带回退的链执行 + + Args: + primary_chain: 主要处理链 + fallback_chain: 回退处理链 + inputs: 主要链的输入参数 + fallback_inputs: 回退链的输入参数 + config: 运行配置 + + Returns: + 处理结果 + """ + try: + return await self.execute_with_retry(primary_chain, inputs, config) + except Exception as e: + self.logger.warning(f"主要处理链 {primary_chain} 失败,尝试回退链 {fallback_chain}: {e}") + + fallback_inputs = fallback_inputs or inputs + try: + return await self.execute_with_retry(fallback_chain, fallback_inputs, config) + except Exception as fallback_e: + self.logger.error(f"回退处理链 {fallback_chain} 也失败了: {fallback_e}") + raise fallback_e