Add File
This commit is contained in:
213
src/landppt/services/speech_script_repository.py
Normal file
213
src/landppt/services/speech_script_repository.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
Speech Script Repository
|
||||
数据访问层,处理演讲稿的数据库操作
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_
|
||||
|
||||
from ..database.models import SpeechScript
|
||||
from ..database.database import get_db, SessionLocal
|
||||
|
||||
|
||||
class SpeechScriptRepository:
|
||||
"""演讲稿数据访问层"""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
self.db = db
|
||||
self._should_close_db = db is None
|
||||
if self.db is None:
|
||||
self.db = SessionLocal()
|
||||
|
||||
async def save_speech_script(
|
||||
self,
|
||||
project_id: str,
|
||||
slide_index: int,
|
||||
slide_title: str,
|
||||
script_content: str,
|
||||
generation_params: Dict[str, Any],
|
||||
estimated_duration: Optional[str] = None,
|
||||
speaker_notes: Optional[str] = None
|
||||
) -> SpeechScript:
|
||||
"""保存演讲稿到数据库,如果已存在则覆盖"""
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 0.5
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# 先查找是否已存在该页面的演讲稿
|
||||
existing_script = self.db.query(SpeechScript).filter(
|
||||
and_(
|
||||
SpeechScript.project_id == project_id,
|
||||
SpeechScript.slide_index == slide_index
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_script:
|
||||
# 更新现有记录
|
||||
existing_script.slide_title = slide_title
|
||||
existing_script.script_content = script_content
|
||||
existing_script.estimated_duration = estimated_duration
|
||||
existing_script.speaker_notes = speaker_notes
|
||||
existing_script.generation_type = generation_params.get('generation_type', 'single')
|
||||
existing_script.tone = generation_params.get('tone', 'conversational')
|
||||
existing_script.target_audience = generation_params.get('target_audience', 'general_public')
|
||||
existing_script.custom_audience = generation_params.get('custom_audience')
|
||||
existing_script.language_complexity = generation_params.get('language_complexity', 'moderate')
|
||||
existing_script.speaking_pace = generation_params.get('speaking_pace', 'normal')
|
||||
existing_script.custom_style_prompt = generation_params.get('custom_style_prompt')
|
||||
existing_script.include_transitions = generation_params.get('include_transitions', True)
|
||||
existing_script.include_timing_notes = generation_params.get('include_timing_notes', False)
|
||||
existing_script.updated_at = time.time()
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(existing_script)
|
||||
return existing_script
|
||||
else:
|
||||
# 创建新记录
|
||||
speech_script = SpeechScript(
|
||||
project_id=project_id,
|
||||
slide_index=slide_index,
|
||||
slide_title=slide_title,
|
||||
script_content=script_content,
|
||||
estimated_duration=estimated_duration,
|
||||
speaker_notes=speaker_notes,
|
||||
generation_type=generation_params.get('generation_type', 'single'),
|
||||
tone=generation_params.get('tone', 'conversational'),
|
||||
target_audience=generation_params.get('target_audience', 'general_public'),
|
||||
custom_audience=generation_params.get('custom_audience'),
|
||||
language_complexity=generation_params.get('language_complexity', 'moderate'),
|
||||
speaking_pace=generation_params.get('speaking_pace', 'normal'),
|
||||
custom_style_prompt=generation_params.get('custom_style_prompt'),
|
||||
include_transitions=generation_params.get('include_transitions', True),
|
||||
include_timing_notes=generation_params.get('include_timing_notes', False),
|
||||
created_at=time.time(),
|
||||
updated_at=time.time()
|
||||
)
|
||||
|
||||
self.db.add(speech_script)
|
||||
self.db.commit()
|
||||
self.db.refresh(speech_script)
|
||||
return speech_script
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
if "database is locked" in str(e) and attempt < max_retries - 1:
|
||||
import asyncio
|
||||
await asyncio.sleep(retry_delay * (attempt + 1))
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
async def get_speech_scripts_by_project(
|
||||
self,
|
||||
project_id: str,
|
||||
limit: Optional[int] = None
|
||||
) -> List[SpeechScript]:
|
||||
"""获取项目的所有演讲稿"""
|
||||
|
||||
query = self.db.query(SpeechScript).filter(
|
||||
SpeechScript.project_id == project_id
|
||||
).order_by(
|
||||
SpeechScript.slide_index.asc(),
|
||||
SpeechScript.created_at.desc()
|
||||
)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
async def get_speech_scripts_by_slide(
|
||||
self,
|
||||
project_id: str,
|
||||
slide_index: int,
|
||||
limit: Optional[int] = None
|
||||
) -> List[SpeechScript]:
|
||||
"""获取特定幻灯片的演讲稿历史"""
|
||||
|
||||
query = self.db.query(SpeechScript).filter(
|
||||
and_(
|
||||
SpeechScript.project_id == project_id,
|
||||
SpeechScript.slide_index == slide_index
|
||||
)
|
||||
).order_by(desc(SpeechScript.created_at))
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
async def get_current_speech_scripts_by_project(
|
||||
self,
|
||||
project_id: str
|
||||
) -> List[SpeechScript]:
|
||||
"""获取项目每个幻灯片的当前演讲稿(每页只有一个)"""
|
||||
|
||||
return self.db.query(SpeechScript).filter(
|
||||
SpeechScript.project_id == project_id
|
||||
).order_by(SpeechScript.slide_index.asc()).all()
|
||||
|
||||
async def get_speech_script_by_slide(
|
||||
self,
|
||||
project_id: str,
|
||||
slide_index: int
|
||||
) -> Optional[SpeechScript]:
|
||||
"""获取特定幻灯片的演讲稿"""
|
||||
|
||||
return self.db.query(SpeechScript).filter(
|
||||
and_(
|
||||
SpeechScript.project_id == project_id,
|
||||
SpeechScript.slide_index == slide_index
|
||||
)
|
||||
).first()
|
||||
|
||||
async def get_speech_script_by_id(self, script_id: int) -> Optional[SpeechScript]:
|
||||
"""根据ID获取演讲稿"""
|
||||
return self.db.query(SpeechScript).filter(SpeechScript.id == script_id).first()
|
||||
|
||||
async def delete_speech_script(self, script_id: int) -> bool:
|
||||
"""删除演讲稿"""
|
||||
script = await self.get_speech_script_by_id(script_id)
|
||||
if script:
|
||||
self.db.delete(script)
|
||||
self.db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def delete_speech_scripts_by_project(self, project_id: str) -> int:
|
||||
"""删除项目的所有演讲稿"""
|
||||
count = self.db.query(SpeechScript).filter(
|
||||
SpeechScript.project_id == project_id
|
||||
).count()
|
||||
|
||||
self.db.query(SpeechScript).filter(
|
||||
SpeechScript.project_id == project_id
|
||||
).delete()
|
||||
|
||||
self.db.commit()
|
||||
return count
|
||||
|
||||
async def get_speech_scripts_grouped_by_slide(
|
||||
self,
|
||||
project_id: str
|
||||
) -> Dict[int, List[SpeechScript]]:
|
||||
"""获取按幻灯片分组的演讲稿"""
|
||||
|
||||
scripts = await self.get_speech_scripts_by_project(project_id)
|
||||
grouped = {}
|
||||
|
||||
for script in scripts:
|
||||
if script.slide_index not in grouped:
|
||||
grouped[script.slide_index] = []
|
||||
grouped[script.slide_index].append(script)
|
||||
|
||||
return grouped
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if self._should_close_db and self.db:
|
||||
self.db.close()
|
||||
Reference in New Issue
Block a user