Add File
This commit is contained in:
220
src/summeryanyfile/generators/chains.py
Normal file
220
src/summeryanyfile/generators/chains.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
处理链管理器 - 管理LangChain处理链
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any
|
||||||
|
import logging
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
|
from ..config.prompts import PromptTemplates
|
||||||
|
from ..utils.logger import LoggerMixin
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChainManager(LoggerMixin):
|
||||||
|
"""处理链管理器"""
|
||||||
|
|
||||||
|
def __init__(self, llm: BaseChatModel):
|
||||||
|
self.llm = llm
|
||||||
|
self.prompt_templates = PromptTemplates()
|
||||||
|
self._chains: Dict[str, Runnable] = {}
|
||||||
|
self._setup_chains()
|
||||||
|
|
||||||
|
def _setup_chains(self):
|
||||||
|
"""设置所有处理链"""
|
||||||
|
self.logger.debug("正在设置处理链...")
|
||||||
|
|
||||||
|
# 文档结构分析链
|
||||||
|
self._chains["structure_analysis"] = (
|
||||||
|
self.prompt_templates.get_structure_analysis_prompt()
|
||||||
|
| self.llm
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始PPT框架生成链
|
||||||
|
self._chains["initial_outline"] = (
|
||||||
|
self.prompt_templates.get_initial_outline_prompt()
|
||||||
|
| self.llm
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 内容细化链
|
||||||
|
self._chains["refine_outline"] = (
|
||||||
|
self.prompt_templates.get_refine_outline_prompt()
|
||||||
|
| self.llm
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 错误恢复链
|
||||||
|
self._chains["error_recovery"] = (
|
||||||
|
self.prompt_templates.get_error_recovery_prompt()
|
||||||
|
| self.llm
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.debug(f"已设置 {len(self._chains)} 个处理链")
|
||||||
|
|
||||||
|
def get_chain(self, chain_name: str) -> Runnable:
|
||||||
|
"""
|
||||||
|
获取指定的处理链
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chain_name: 链名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理链实例
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 链不存在
|
||||||
|
"""
|
||||||
|
if chain_name not in self._chains:
|
||||||
|
raise KeyError(f"处理链不存在: {chain_name}")
|
||||||
|
|
||||||
|
return self._chains[chain_name]
|
||||||
|
|
||||||
|
async def invoke_chain(
|
||||||
|
self,
|
||||||
|
chain_name: str,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
config: Dict[str, Any] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
调用指定的处理链
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chain_name: 链名称
|
||||||
|
inputs: 输入参数
|
||||||
|
config: 运行配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
chain = self.get_chain(chain_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.logger.debug(f"调用处理链: {chain_name}")
|
||||||
|
result = await chain.ainvoke(inputs, config or {})
|
||||||
|
self.logger.debug(f"处理链 {chain_name} 执行成功")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"处理链 {chain_name} 执行失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def list_chains(self) -> list:
|
||||||
|
"""列出所有可用的处理链"""
|
||||||
|
return list(self._chains.keys())
|
||||||
|
|
||||||
|
def update_llm(self, llm: BaseChatModel):
|
||||||
|
"""更新LLM并重新设置处理链"""
|
||||||
|
self.logger.info("更新LLM并重新设置处理链")
|
||||||
|
self.llm = llm
|
||||||
|
self._setup_chains()
|
||||||
|
|
||||||
|
def add_custom_chain(self, name: str, chain: Runnable):
|
||||||
|
"""
|
||||||
|
添加自定义处理链
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 链名称
|
||||||
|
chain: 处理链实例
|
||||||
|
"""
|
||||||
|
self._chains[name] = chain
|
||||||
|
self.logger.info(f"已添加自定义处理链: {name}")
|
||||||
|
|
||||||
|
def remove_chain(self, name: str):
|
||||||
|
"""
|
||||||
|
移除处理链
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 链名称
|
||||||
|
"""
|
||||||
|
if name in self._chains:
|
||||||
|
del self._chains[name]
|
||||||
|
self.logger.info(f"已移除处理链: {name}")
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"尝试移除不存在的处理链: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
class ChainExecutor:
|
||||||
|
"""处理链执行器,提供重试和错误处理功能"""
|
||||||
|
|
||||||
|
def __init__(self, chain_manager: ChainManager, max_retries: int = 3):
|
||||||
|
self.chain_manager = chain_manager
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.logger = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
|
async def execute_with_retry(
|
||||||
|
self,
|
||||||
|
chain_name: str,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
config: Dict[str, Any] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
带重试的链执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chain_name: 链名称
|
||||||
|
inputs: 输入参数
|
||||||
|
config: 运行配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 所有重试都失败后抛出最后一个异常
|
||||||
|
"""
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
result = await self.chain_manager.invoke_chain(chain_name, inputs, config)
|
||||||
|
if attempt > 0:
|
||||||
|
self.logger.info(f"处理链 {chain_name} 在第 {attempt + 1} 次尝试后成功")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
self.logger.warning(f"处理链 {chain_name} 第 {attempt + 1} 次尝试失败: {e}")
|
||||||
|
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
# 可以在这里添加退避策略
|
||||||
|
import asyncio
|
||||||
|
await asyncio.sleep(1 * (attempt + 1)) # 简单的线性退避
|
||||||
|
|
||||||
|
self.logger.error(f"处理链 {chain_name} 在 {self.max_retries} 次尝试后仍然失败")
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
async def execute_with_fallback(
|
||||||
|
self,
|
||||||
|
primary_chain: str,
|
||||||
|
fallback_chain: str,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
fallback_inputs: Dict[str, Any] = None,
|
||||||
|
config: Dict[str, Any] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
带回退的链执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
primary_chain: 主要处理链
|
||||||
|
fallback_chain: 回退处理链
|
||||||
|
inputs: 主要链的输入参数
|
||||||
|
fallback_inputs: 回退链的输入参数
|
||||||
|
config: 运行配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await self.execute_with_retry(primary_chain, inputs, config)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"主要处理链 {primary_chain} 失败,尝试回退链 {fallback_chain}: {e}")
|
||||||
|
|
||||||
|
fallback_inputs = fallback_inputs or inputs
|
||||||
|
try:
|
||||||
|
return await self.execute_with_retry(fallback_chain, fallback_inputs, config)
|
||||||
|
except Exception as fallback_e:
|
||||||
|
self.logger.error(f"回退处理链 {fallback_chain} 也失败了: {fallback_e}")
|
||||||
|
raise fallback_e
|
||||||
Reference in New Issue
Block a user