From eca5563f2601f001c324e66e71f5017be204e89f Mon Sep 17 00:00:00 2001 From: 13315423919 <13315423919@qq.com> Date: Fri, 7 Nov 2025 09:05:39 +0800 Subject: [PATCH] Add File --- src/summeryanyfile/graph/nodes.py | 306 ++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 src/summeryanyfile/graph/nodes.py diff --git a/src/summeryanyfile/graph/nodes.py b/src/summeryanyfile/graph/nodes.py new file mode 100644 index 0000000..a42573e --- /dev/null +++ b/src/summeryanyfile/graph/nodes.py @@ -0,0 +1,306 @@ +""" +图节点实现 - 定义LangGraph工作流中的各个节点 +""" + +import json +from typing import Dict, Any, Literal +import logging +from langchain_core.runnables import RunnableConfig + +from ..core.models import PPTState +from ..core.json_parser import JSONParser +from ..generators.chains import ChainManager, ChainExecutor +from ..utils.logger import LoggerMixin + +logger = logging.getLogger(__name__) + + +class GraphNodes(LoggerMixin): + """图节点集合,包含所有工作流节点的实现""" + + def __init__(self, chain_manager: ChainManager, config=None): + self.chain_manager = chain_manager + self.chain_executor = ChainExecutor(chain_manager) + self.json_parser = JSONParser() + self.config = config # 添加配置参数 + + def _get_slides_range_text(self, state: Dict[str, Any]) -> str: + """根据状态中的页数模式生成页数约束文本""" + page_count_mode = state.get("page_count_mode", "ai_decide") + min_pages = state.get("min_pages") + max_pages = state.get("max_pages") + fixed_pages = state.get("fixed_pages") + + if page_count_mode == "fixed" and fixed_pages: + result = f"【强制要求】必须生成恰好{fixed_pages}页的PPT,不能多也不能少" + elif page_count_mode == "custom_range" and min_pages and max_pages: + result = f"【强制要求】必须严格控制在{min_pages}-{max_pages}页范围内,最少{min_pages}页,最多{max_pages}页,不能超出此范围" + else: # ai_decide + result = "根据内容的复杂度、深度和逻辑结构,自主决定最合适的页数,确保内容充实且逻辑清晰" + + return result + + async def analyze_structure(self, state: PPTState, config: RunnableConfig) -> Dict[str, Any]: + """ + 分析文档结构节点 + + Args: + state: 当前状态 + config: 运行配置 + + Returns: + 更新的状态字段 + """ + self.logger.info("开始分析文档结构...") + + try: + # 获取第一个文档块 + first_chunk = state["document_chunks"][0] if state["document_chunks"] else "" + + if not first_chunk.strip(): + self.logger.warning("第一个文档块为空,使用默认结构") + structure = { + "title": "文档分析", + "type": "通用文档", + "sections": [], + "key_concepts": [], + "language": "中文", + "complexity": "中等" + } + else: + # 调用结构分析链 + structure_response = await self.chain_executor.execute_with_retry( + "structure_analysis", + { + "content": first_chunk, + "project_topic": state.get("project_topic", ""), + "project_scenario": state.get("project_scenario", "general"), + "project_requirements": state.get("project_requirements", ""), + "target_audience": state.get("target_audience", "普通大众"), + "custom_audience": state.get("custom_audience", ""), + "ppt_style": state.get("ppt_style", "general"), + "custom_style_prompt": state.get("custom_style_prompt", "") + }, + config + ) + + # 解析JSON响应 + structure = self.json_parser.extract_json_from_response(structure_response) + + # 验证结构 + if not isinstance(structure, dict): + raise ValueError("结构分析返回的不是有效的字典") + + self.logger.info(f"文档结构分析完成: {structure.get('title', '未知标题')}") + + return { + "document_structure": structure, + "accumulated_context": first_chunk[:500] # 保留前500字作为上下文 + } + + except Exception as e: + self.logger.error(f"文档结构分析失败: {e}") + # 返回默认结构 + return { + "document_structure": { + "title": "文档分析", + "type": "通用文档", + "sections": [], + "key_concepts": [], + "language": "中文", + "complexity": "中等" + }, + "accumulated_context": first_chunk[:500] if state["document_chunks"] else "" + } + + async def generate_initial_outline(self, state: PPTState, config: RunnableConfig) -> Dict[str, Any]: + """ + 生成初始PPT框架节点 + + Args: + state: 当前状态 + config: 运行配置 + + Returns: + 更新的状态字段 + """ + self.logger.info("开始生成初始PPT框架...") + + try: + # 准备输入 + structure_json = json.dumps(state["document_structure"], ensure_ascii=False) + first_chunk = state["document_chunks"][0] if state["document_chunks"] else "" + + # 准备输入参数,包含页数范围、目标语言和项目信息 + chain_inputs = { + "structure": structure_json, + "content": first_chunk, + "project_topic": state.get("project_topic", ""), + "project_scenario": state.get("project_scenario", "general"), + "project_requirements": state.get("project_requirements", ""), + "target_audience": state.get("target_audience", "普通大众"), + "custom_audience": state.get("custom_audience", ""), + "ppt_style": state.get("ppt_style", "general"), + "custom_style_prompt": state.get("custom_style_prompt", "") + } + + # 添加页数范围信息 + slides_range_text = self._get_slides_range_text(state) + chain_inputs["slides_range"] = slides_range_text + if self.config: + chain_inputs["target_language"] = self.config.target_language + else: + chain_inputs["target_language"] = "zh" # 默认中文 + + # 调用初始大纲生成链 + outline_response = await self.chain_executor.execute_with_retry( + "initial_outline", + chain_inputs, + config + ) + + # 解析JSON响应 + outline = self.json_parser.extract_json_from_response(outline_response) + + # 验证和修复大纲结构 + outline = self.json_parser.validate_ppt_structure(outline) + + self.logger.info(f"初始PPT框架生成完成: {outline.get('title', '未知标题')}") + + return { + **state, # 保留所有原始状态 + "ppt_title": outline.get("title", "学术演示"), + "total_pages": outline.get("total_pages", 15), + "page_count_mode": state.get("page_count_mode", "estimated"), # 保持原始页数模式 + "slides": outline.get("slides", []), + "current_index": 1 + } + + except Exception as e: + self.logger.error(f"初始PPT框架生成失败: {e}") + # 返回默认框架 + return { + "ppt_title": "学术演示", + "total_pages": 15, + "page_count_mode": "estimated", + "slides": [ + { + "page_number": 1, + "title": "标题页", + "content_points": ["演示标题", "演示者", "日期"], + "slide_type": "title", + "description": "PPT开场标题页" + } + ], + "current_index": 1 + } + + async def refine_outline(self, state: PPTState, config: RunnableConfig) -> Dict[str, Any]: + """ + 细化PPT大纲节点 + + Args: + state: 当前状态 + config: 运行配置 + + Returns: + 更新的状态字段 + """ + current_index = state["current_index"] + total_chunks = len(state["document_chunks"]) + + self.logger.info(f"正在细化PPT大纲 ({current_index + 1}/{total_chunks})...") + + # 检查是否还有内容需要处理 + if current_index >= total_chunks: + self.logger.info("所有文档块已处理完成") + return state + + try: + # 获取当前文档块 + current_content = state["document_chunks"][current_index] + + # 准备现有大纲 + existing_outline = { + "title": state["ppt_title"], + "total_pages": state["total_pages"], + "slides": state["slides"] + } + existing_outline_json = json.dumps(existing_outline, ensure_ascii=False) + + # 准备输入参数,包含页数范围、目标语言和项目信息 + chain_inputs = { + "existing_outline": existing_outline_json, + "new_content": current_content, + "context": state["accumulated_context"], + "project_topic": state.get("project_topic", ""), + "project_scenario": state.get("project_scenario", "general"), + "project_requirements": state.get("project_requirements", ""), + "target_audience": state.get("target_audience", "普通大众"), + "custom_audience": state.get("custom_audience", ""), + "ppt_style": state.get("ppt_style", "general"), + "custom_style_prompt": state.get("custom_style_prompt", "") + } + + # 添加页数范围信息和目标语言 + slides_range_text = self._get_slides_range_text(state) + chain_inputs["slides_range"] = slides_range_text + if self.config: + chain_inputs["target_language"] = self.config.target_language + else: + chain_inputs["target_language"] = "zh" # 默认中文 + + # 调用细化链 + refined_response = await self.chain_executor.execute_with_retry( + "refine_outline", + chain_inputs, + config + ) + + # 解析JSON响应 + refined_outline = self.json_parser.extract_json_from_response(refined_response) + + # 验证和修复结构 + refined_outline = self.json_parser.validate_ppt_structure(refined_outline) + + # 更新累积上下文 + new_context = state["accumulated_context"] + "\n" + current_content[:300] + if len(new_context) > 2000: # 限制上下文长度 + new_context = new_context[-2000:] + + return { + **state, # 保留所有原始状态 + "ppt_title": refined_outline.get("title", state["ppt_title"]), + "total_pages": refined_outline.get("total_pages", state["total_pages"]), + "slides": refined_outline.get("slides", state["slides"]), + "current_index": current_index + 1, + "accumulated_context": new_context + } + + except Exception as e: + self.logger.error(f"PPT大纲细化失败: {e}") + # 继续处理下一个块 + return { + **state, + "current_index": current_index + 1 + } + + def should_continue_refining(self, state: PPTState) -> Literal["refine_outline", "end"]: + """ + 判断是否继续细化的条件函数 + + Args: + state: 当前状态 + + Returns: + 下一个节点名称 + """ + current_index = state["current_index"] + total_chunks = len(state["document_chunks"]) + + if current_index >= total_chunks: + self.logger.info("所有文档块已处理,完成大纲生成") + return "end" + else: + self.logger.debug(f"继续处理文档块 {current_index + 1}/{total_chunks}") + return "refine_outline"