From 7c71d0613838369e6779f87018a9b1a36f2c7370 Mon Sep 17 00:00:00 2001 From: 13315423919 <13315423919@qq.com> Date: Fri, 7 Nov 2025 09:05:20 +0800 Subject: [PATCH] Add File --- .../services/speech_script_repository.py | 213 ++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 src/landppt/services/speech_script_repository.py diff --git a/src/landppt/services/speech_script_repository.py b/src/landppt/services/speech_script_repository.py new file mode 100644 index 0000000..2f086f6 --- /dev/null +++ b/src/landppt/services/speech_script_repository.py @@ -0,0 +1,213 @@ +""" +Speech Script Repository +数据访问层,处理演讲稿的数据库操作 +""" + +import time +from typing import List, Optional, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy import desc, and_ + +from ..database.models import SpeechScript +from ..database.database import get_db, SessionLocal + + +class SpeechScriptRepository: + """演讲稿数据访问层""" + + def __init__(self, db: Session = None): + self.db = db + self._should_close_db = db is None + if self.db is None: + self.db = SessionLocal() + + async def save_speech_script( + self, + project_id: str, + slide_index: int, + slide_title: str, + script_content: str, + generation_params: Dict[str, Any], + estimated_duration: Optional[str] = None, + speaker_notes: Optional[str] = None + ) -> SpeechScript: + """保存演讲稿到数据库,如果已存在则覆盖""" + + max_retries = 3 + retry_delay = 0.5 + + for attempt in range(max_retries): + try: + # 先查找是否已存在该页面的演讲稿 + existing_script = self.db.query(SpeechScript).filter( + and_( + SpeechScript.project_id == project_id, + SpeechScript.slide_index == slide_index + ) + ).first() + + if existing_script: + # 更新现有记录 + existing_script.slide_title = slide_title + existing_script.script_content = script_content + existing_script.estimated_duration = estimated_duration + existing_script.speaker_notes = speaker_notes + existing_script.generation_type = generation_params.get('generation_type', 'single') + existing_script.tone = generation_params.get('tone', 'conversational') + existing_script.target_audience = generation_params.get('target_audience', 'general_public') + existing_script.custom_audience = generation_params.get('custom_audience') + existing_script.language_complexity = generation_params.get('language_complexity', 'moderate') + existing_script.speaking_pace = generation_params.get('speaking_pace', 'normal') + existing_script.custom_style_prompt = generation_params.get('custom_style_prompt') + existing_script.include_transitions = generation_params.get('include_transitions', True) + existing_script.include_timing_notes = generation_params.get('include_timing_notes', False) + existing_script.updated_at = time.time() + + self.db.commit() + self.db.refresh(existing_script) + return existing_script + else: + # 创建新记录 + speech_script = SpeechScript( + project_id=project_id, + slide_index=slide_index, + slide_title=slide_title, + script_content=script_content, + estimated_duration=estimated_duration, + speaker_notes=speaker_notes, + generation_type=generation_params.get('generation_type', 'single'), + tone=generation_params.get('tone', 'conversational'), + target_audience=generation_params.get('target_audience', 'general_public'), + custom_audience=generation_params.get('custom_audience'), + language_complexity=generation_params.get('language_complexity', 'moderate'), + speaking_pace=generation_params.get('speaking_pace', 'normal'), + custom_style_prompt=generation_params.get('custom_style_prompt'), + include_transitions=generation_params.get('include_transitions', True), + include_timing_notes=generation_params.get('include_timing_notes', False), + created_at=time.time(), + updated_at=time.time() + ) + + self.db.add(speech_script) + self.db.commit() + self.db.refresh(speech_script) + return speech_script + + except Exception as e: + self.db.rollback() + if "database is locked" in str(e) and attempt < max_retries - 1: + import asyncio + await asyncio.sleep(retry_delay * (attempt + 1)) + continue + else: + raise + + async def get_speech_scripts_by_project( + self, + project_id: str, + limit: Optional[int] = None + ) -> List[SpeechScript]: + """获取项目的所有演讲稿""" + + query = self.db.query(SpeechScript).filter( + SpeechScript.project_id == project_id + ).order_by( + SpeechScript.slide_index.asc(), + SpeechScript.created_at.desc() + ) + + if limit: + query = query.limit(limit) + + return query.all() + + async def get_speech_scripts_by_slide( + self, + project_id: str, + slide_index: int, + limit: Optional[int] = None + ) -> List[SpeechScript]: + """获取特定幻灯片的演讲稿历史""" + + query = self.db.query(SpeechScript).filter( + and_( + SpeechScript.project_id == project_id, + SpeechScript.slide_index == slide_index + ) + ).order_by(desc(SpeechScript.created_at)) + + if limit: + query = query.limit(limit) + + return query.all() + + async def get_current_speech_scripts_by_project( + self, + project_id: str + ) -> List[SpeechScript]: + """获取项目每个幻灯片的当前演讲稿(每页只有一个)""" + + return self.db.query(SpeechScript).filter( + SpeechScript.project_id == project_id + ).order_by(SpeechScript.slide_index.asc()).all() + + async def get_speech_script_by_slide( + self, + project_id: str, + slide_index: int + ) -> Optional[SpeechScript]: + """获取特定幻灯片的演讲稿""" + + return self.db.query(SpeechScript).filter( + and_( + SpeechScript.project_id == project_id, + SpeechScript.slide_index == slide_index + ) + ).first() + + async def get_speech_script_by_id(self, script_id: int) -> Optional[SpeechScript]: + """根据ID获取演讲稿""" + return self.db.query(SpeechScript).filter(SpeechScript.id == script_id).first() + + async def delete_speech_script(self, script_id: int) -> bool: + """删除演讲稿""" + script = await self.get_speech_script_by_id(script_id) + if script: + self.db.delete(script) + self.db.commit() + return True + return False + + async def delete_speech_scripts_by_project(self, project_id: str) -> int: + """删除项目的所有演讲稿""" + count = self.db.query(SpeechScript).filter( + SpeechScript.project_id == project_id + ).count() + + self.db.query(SpeechScript).filter( + SpeechScript.project_id == project_id + ).delete() + + self.db.commit() + return count + + async def get_speech_scripts_grouped_by_slide( + self, + project_id: str + ) -> Dict[int, List[SpeechScript]]: + """获取按幻灯片分组的演讲稿""" + + scripts = await self.get_speech_scripts_by_project(project_id) + grouped = {} + + for script in scripts: + if script.slide_index not in grouped: + grouped[script.slide_index] = [] + grouped[script.slide_index].append(script) + + return grouped + + def close(self): + """关闭数据库连接""" + if self._should_close_db and self.db: + self.db.close()