Add File
This commit is contained in:
306
src/summeryanyfile/graph/nodes.py
Normal file
306
src/summeryanyfile/graph/nodes.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
图节点实现 - 定义LangGraph工作流中的各个节点
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, Literal
|
||||
import logging
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from ..core.models import PPTState
|
||||
from ..core.json_parser import JSONParser
|
||||
from ..generators.chains import ChainManager, ChainExecutor
|
||||
from ..utils.logger import LoggerMixin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphNodes(LoggerMixin):
|
||||
"""图节点集合,包含所有工作流节点的实现"""
|
||||
|
||||
def __init__(self, chain_manager: ChainManager, config=None):
|
||||
self.chain_manager = chain_manager
|
||||
self.chain_executor = ChainExecutor(chain_manager)
|
||||
self.json_parser = JSONParser()
|
||||
self.config = config # 添加配置参数
|
||||
|
||||
def _get_slides_range_text(self, state: Dict[str, Any]) -> str:
|
||||
"""根据状态中的页数模式生成页数约束文本"""
|
||||
page_count_mode = state.get("page_count_mode", "ai_decide")
|
||||
min_pages = state.get("min_pages")
|
||||
max_pages = state.get("max_pages")
|
||||
fixed_pages = state.get("fixed_pages")
|
||||
|
||||
if page_count_mode == "fixed" and fixed_pages:
|
||||
result = f"【强制要求】必须生成恰好{fixed_pages}页的PPT,不能多也不能少"
|
||||
elif page_count_mode == "custom_range" and min_pages and max_pages:
|
||||
result = f"【强制要求】必须严格控制在{min_pages}-{max_pages}页范围内,最少{min_pages}页,最多{max_pages}页,不能超出此范围"
|
||||
else: # ai_decide
|
||||
result = "根据内容的复杂度、深度和逻辑结构,自主决定最合适的页数,确保内容充实且逻辑清晰"
|
||||
|
||||
return result
|
||||
|
||||
async def analyze_structure(self, state: PPTState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
分析文档结构节点
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
config: 运行配置
|
||||
|
||||
Returns:
|
||||
更新的状态字段
|
||||
"""
|
||||
self.logger.info("开始分析文档结构...")
|
||||
|
||||
try:
|
||||
# 获取第一个文档块
|
||||
first_chunk = state["document_chunks"][0] if state["document_chunks"] else ""
|
||||
|
||||
if not first_chunk.strip():
|
||||
self.logger.warning("第一个文档块为空,使用默认结构")
|
||||
structure = {
|
||||
"title": "文档分析",
|
||||
"type": "通用文档",
|
||||
"sections": [],
|
||||
"key_concepts": [],
|
||||
"language": "中文",
|
||||
"complexity": "中等"
|
||||
}
|
||||
else:
|
||||
# 调用结构分析链
|
||||
structure_response = await self.chain_executor.execute_with_retry(
|
||||
"structure_analysis",
|
||||
{
|
||||
"content": first_chunk,
|
||||
"project_topic": state.get("project_topic", ""),
|
||||
"project_scenario": state.get("project_scenario", "general"),
|
||||
"project_requirements": state.get("project_requirements", ""),
|
||||
"target_audience": state.get("target_audience", "普通大众"),
|
||||
"custom_audience": state.get("custom_audience", ""),
|
||||
"ppt_style": state.get("ppt_style", "general"),
|
||||
"custom_style_prompt": state.get("custom_style_prompt", "")
|
||||
},
|
||||
config
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
structure = self.json_parser.extract_json_from_response(structure_response)
|
||||
|
||||
# 验证结构
|
||||
if not isinstance(structure, dict):
|
||||
raise ValueError("结构分析返回的不是有效的字典")
|
||||
|
||||
self.logger.info(f"文档结构分析完成: {structure.get('title', '未知标题')}")
|
||||
|
||||
return {
|
||||
"document_structure": structure,
|
||||
"accumulated_context": first_chunk[:500] # 保留前500字作为上下文
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"文档结构分析失败: {e}")
|
||||
# 返回默认结构
|
||||
return {
|
||||
"document_structure": {
|
||||
"title": "文档分析",
|
||||
"type": "通用文档",
|
||||
"sections": [],
|
||||
"key_concepts": [],
|
||||
"language": "中文",
|
||||
"complexity": "中等"
|
||||
},
|
||||
"accumulated_context": first_chunk[:500] if state["document_chunks"] else ""
|
||||
}
|
||||
|
||||
async def generate_initial_outline(self, state: PPTState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
生成初始PPT框架节点
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
config: 运行配置
|
||||
|
||||
Returns:
|
||||
更新的状态字段
|
||||
"""
|
||||
self.logger.info("开始生成初始PPT框架...")
|
||||
|
||||
try:
|
||||
# 准备输入
|
||||
structure_json = json.dumps(state["document_structure"], ensure_ascii=False)
|
||||
first_chunk = state["document_chunks"][0] if state["document_chunks"] else ""
|
||||
|
||||
# 准备输入参数,包含页数范围、目标语言和项目信息
|
||||
chain_inputs = {
|
||||
"structure": structure_json,
|
||||
"content": first_chunk,
|
||||
"project_topic": state.get("project_topic", ""),
|
||||
"project_scenario": state.get("project_scenario", "general"),
|
||||
"project_requirements": state.get("project_requirements", ""),
|
||||
"target_audience": state.get("target_audience", "普通大众"),
|
||||
"custom_audience": state.get("custom_audience", ""),
|
||||
"ppt_style": state.get("ppt_style", "general"),
|
||||
"custom_style_prompt": state.get("custom_style_prompt", "")
|
||||
}
|
||||
|
||||
# 添加页数范围信息
|
||||
slides_range_text = self._get_slides_range_text(state)
|
||||
chain_inputs["slides_range"] = slides_range_text
|
||||
if self.config:
|
||||
chain_inputs["target_language"] = self.config.target_language
|
||||
else:
|
||||
chain_inputs["target_language"] = "zh" # 默认中文
|
||||
|
||||
# 调用初始大纲生成链
|
||||
outline_response = await self.chain_executor.execute_with_retry(
|
||||
"initial_outline",
|
||||
chain_inputs,
|
||||
config
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
outline = self.json_parser.extract_json_from_response(outline_response)
|
||||
|
||||
# 验证和修复大纲结构
|
||||
outline = self.json_parser.validate_ppt_structure(outline)
|
||||
|
||||
self.logger.info(f"初始PPT框架生成完成: {outline.get('title', '未知标题')}")
|
||||
|
||||
return {
|
||||
**state, # 保留所有原始状态
|
||||
"ppt_title": outline.get("title", "学术演示"),
|
||||
"total_pages": outline.get("total_pages", 15),
|
||||
"page_count_mode": state.get("page_count_mode", "estimated"), # 保持原始页数模式
|
||||
"slides": outline.get("slides", []),
|
||||
"current_index": 1
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"初始PPT框架生成失败: {e}")
|
||||
# 返回默认框架
|
||||
return {
|
||||
"ppt_title": "学术演示",
|
||||
"total_pages": 15,
|
||||
"page_count_mode": "estimated",
|
||||
"slides": [
|
||||
{
|
||||
"page_number": 1,
|
||||
"title": "标题页",
|
||||
"content_points": ["演示标题", "演示者", "日期"],
|
||||
"slide_type": "title",
|
||||
"description": "PPT开场标题页"
|
||||
}
|
||||
],
|
||||
"current_index": 1
|
||||
}
|
||||
|
||||
async def refine_outline(self, state: PPTState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
细化PPT大纲节点
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
config: 运行配置
|
||||
|
||||
Returns:
|
||||
更新的状态字段
|
||||
"""
|
||||
current_index = state["current_index"]
|
||||
total_chunks = len(state["document_chunks"])
|
||||
|
||||
self.logger.info(f"正在细化PPT大纲 ({current_index + 1}/{total_chunks})...")
|
||||
|
||||
# 检查是否还有内容需要处理
|
||||
if current_index >= total_chunks:
|
||||
self.logger.info("所有文档块已处理完成")
|
||||
return state
|
||||
|
||||
try:
|
||||
# 获取当前文档块
|
||||
current_content = state["document_chunks"][current_index]
|
||||
|
||||
# 准备现有大纲
|
||||
existing_outline = {
|
||||
"title": state["ppt_title"],
|
||||
"total_pages": state["total_pages"],
|
||||
"slides": state["slides"]
|
||||
}
|
||||
existing_outline_json = json.dumps(existing_outline, ensure_ascii=False)
|
||||
|
||||
# 准备输入参数,包含页数范围、目标语言和项目信息
|
||||
chain_inputs = {
|
||||
"existing_outline": existing_outline_json,
|
||||
"new_content": current_content,
|
||||
"context": state["accumulated_context"],
|
||||
"project_topic": state.get("project_topic", ""),
|
||||
"project_scenario": state.get("project_scenario", "general"),
|
||||
"project_requirements": state.get("project_requirements", ""),
|
||||
"target_audience": state.get("target_audience", "普通大众"),
|
||||
"custom_audience": state.get("custom_audience", ""),
|
||||
"ppt_style": state.get("ppt_style", "general"),
|
||||
"custom_style_prompt": state.get("custom_style_prompt", "")
|
||||
}
|
||||
|
||||
# 添加页数范围信息和目标语言
|
||||
slides_range_text = self._get_slides_range_text(state)
|
||||
chain_inputs["slides_range"] = slides_range_text
|
||||
if self.config:
|
||||
chain_inputs["target_language"] = self.config.target_language
|
||||
else:
|
||||
chain_inputs["target_language"] = "zh" # 默认中文
|
||||
|
||||
# 调用细化链
|
||||
refined_response = await self.chain_executor.execute_with_retry(
|
||||
"refine_outline",
|
||||
chain_inputs,
|
||||
config
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
refined_outline = self.json_parser.extract_json_from_response(refined_response)
|
||||
|
||||
# 验证和修复结构
|
||||
refined_outline = self.json_parser.validate_ppt_structure(refined_outline)
|
||||
|
||||
# 更新累积上下文
|
||||
new_context = state["accumulated_context"] + "\n" + current_content[:300]
|
||||
if len(new_context) > 2000: # 限制上下文长度
|
||||
new_context = new_context[-2000:]
|
||||
|
||||
return {
|
||||
**state, # 保留所有原始状态
|
||||
"ppt_title": refined_outline.get("title", state["ppt_title"]),
|
||||
"total_pages": refined_outline.get("total_pages", state["total_pages"]),
|
||||
"slides": refined_outline.get("slides", state["slides"]),
|
||||
"current_index": current_index + 1,
|
||||
"accumulated_context": new_context
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"PPT大纲细化失败: {e}")
|
||||
# 继续处理下一个块
|
||||
return {
|
||||
**state,
|
||||
"current_index": current_index + 1
|
||||
}
|
||||
|
||||
def should_continue_refining(self, state: PPTState) -> Literal["refine_outline", "end"]:
|
||||
"""
|
||||
判断是否继续细化的条件函数
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称
|
||||
"""
|
||||
current_index = state["current_index"]
|
||||
total_chunks = len(state["document_chunks"])
|
||||
|
||||
if current_index >= total_chunks:
|
||||
self.logger.info("所有文档块已处理,完成大纲生成")
|
||||
return "end"
|
||||
else:
|
||||
self.logger.debug(f"继续处理文档块 {current_index + 1}/{total_chunks}")
|
||||
return "refine_outline"
|
||||
Reference in New Issue
Block a user