Add File
This commit is contained in:
637
src/landppt/database/repositories.py
Normal file
637
src/landppt/database/repositories.py
Normal file
@@ -0,0 +1,637 @@
|
|||||||
|
"""
|
||||||
|
Repository classes for database operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, update, delete, and_
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from .models import Project, TodoBoard, TodoStage, ProjectVersion, SlideData, PPTTemplate, GlobalMasterTemplate
|
||||||
|
from ..api.models import PPTProject, TodoBoard as TodoBoardModel, TodoStage as TodoStageModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectRepository:
|
||||||
|
"""Repository for Project operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create(self, project_data: Dict[str, Any]) -> Project:
|
||||||
|
"""Create a new project"""
|
||||||
|
project = Project(**project_data)
|
||||||
|
self.session.add(project)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
async def get_by_id(self, project_id: str) -> Optional[Project]:
|
||||||
|
"""Get project by ID with all relationships"""
|
||||||
|
stmt = select(Project).where(Project.project_id == project_id).options(
|
||||||
|
selectinload(Project.todo_board).selectinload(TodoBoard.stages),
|
||||||
|
selectinload(Project.versions),
|
||||||
|
selectinload(Project.slides)
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def list_projects(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> List[Project]:
|
||||||
|
"""List projects with pagination"""
|
||||||
|
stmt = select(Project).options(
|
||||||
|
selectinload(Project.todo_board).selectinload(TodoBoard.stages),
|
||||||
|
selectinload(Project.versions),
|
||||||
|
selectinload(Project.slides)
|
||||||
|
)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
stmt = stmt.where(Project.status == status)
|
||||||
|
|
||||||
|
stmt = stmt.order_by(Project.updated_at.desc())
|
||||||
|
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||||
|
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def count_projects(self, status: Optional[str] = None) -> int:
|
||||||
|
"""Count total projects"""
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
stmt = select(func.count(Project.id))
|
||||||
|
if status:
|
||||||
|
stmt = stmt.where(Project.status == status)
|
||||||
|
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def update(self, project_id: str, update_data: Dict[str, Any]) -> Optional[Project]:
|
||||||
|
"""Update project"""
|
||||||
|
try:
|
||||||
|
# 首先获取项目
|
||||||
|
project = await self.get_by_id(project_id)
|
||||||
|
if not project:
|
||||||
|
logger.warning(f"No project found with ID {project_id} for update")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 更新项目属性
|
||||||
|
for key, value in update_data.items():
|
||||||
|
if hasattr(project, key):
|
||||||
|
setattr(project, key, value)
|
||||||
|
|
||||||
|
# 设置更新时间
|
||||||
|
project.updated_at = time.time()
|
||||||
|
|
||||||
|
# 提交更改
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(project)
|
||||||
|
|
||||||
|
logger.info(f"Successfully updated project {project_id}")
|
||||||
|
return project
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating project {project_id}: {e}")
|
||||||
|
await self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, project_id: str) -> bool:
|
||||||
|
"""Delete project"""
|
||||||
|
stmt = delete(Project).where(Project.project_id == project_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TodoBoardRepository:
|
||||||
|
"""Repository for TodoBoard operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create(self, board_data: Dict[str, Any]) -> TodoBoard:
|
||||||
|
"""Create a new todo board"""
|
||||||
|
board = TodoBoard(**board_data)
|
||||||
|
self.session.add(board)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(board)
|
||||||
|
return board
|
||||||
|
|
||||||
|
async def get_by_project_id(self, project_id: str) -> Optional[TodoBoard]:
|
||||||
|
"""Get todo board by project ID"""
|
||||||
|
stmt = select(TodoBoard).where(TodoBoard.project_id == project_id).options(
|
||||||
|
selectinload(TodoBoard.stages)
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def update(self, project_id: str, update_data: Dict[str, Any]) -> Optional[TodoBoard]:
|
||||||
|
"""Update todo board"""
|
||||||
|
update_data['updated_at'] = time.time()
|
||||||
|
stmt = update(TodoBoard).where(TodoBoard.project_id == project_id).values(**update_data)
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return await self.get_by_project_id(project_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TodoStageRepository:
|
||||||
|
"""Repository for TodoStage operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create_stages(self, stages_data: List[Dict[str, Any]]) -> List[TodoStage]:
|
||||||
|
"""Create multiple stages"""
|
||||||
|
stages = [TodoStage(**stage_data) for stage_data in stages_data]
|
||||||
|
self.session.add_all(stages)
|
||||||
|
await self.session.commit()
|
||||||
|
for stage in stages:
|
||||||
|
await self.session.refresh(stage)
|
||||||
|
return stages
|
||||||
|
|
||||||
|
async def update_stage(self, stage_id: str, update_data: Dict[str, Any]) -> bool:
|
||||||
|
"""Update a specific stage"""
|
||||||
|
update_data['updated_at'] = time.time()
|
||||||
|
stmt = update(TodoStage).where(TodoStage.stage_id == stage_id).values(**update_data)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def update_stage_by_project_and_stage(self, project_id: str, stage_id: str, update_data: Dict[str, Any]) -> bool:
|
||||||
|
"""Update a specific stage by project_id and stage_id for better performance"""
|
||||||
|
update_data['updated_at'] = time.time()
|
||||||
|
stmt = update(TodoStage).where(
|
||||||
|
TodoStage.project_id == project_id,
|
||||||
|
TodoStage.stage_id == stage_id
|
||||||
|
).values(**update_data)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def get_stages_by_board_id(self, board_id: int) -> List[TodoStage]:
|
||||||
|
"""Get all stages for a todo board"""
|
||||||
|
stmt = select(TodoStage).where(TodoStage.todo_board_id == board_id).order_by(TodoStage.stage_index)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectVersionRepository:
|
||||||
|
"""Repository for ProjectVersion operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create(self, version_data: Dict[str, Any]) -> ProjectVersion:
|
||||||
|
"""Create a new project version"""
|
||||||
|
version = ProjectVersion(**version_data)
|
||||||
|
self.session.add(version)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(version)
|
||||||
|
return version
|
||||||
|
|
||||||
|
async def get_versions_by_project_id(self, project_id: str) -> List[ProjectVersion]:
|
||||||
|
"""Get all versions for a project"""
|
||||||
|
stmt = select(ProjectVersion).where(ProjectVersion.project_id == project_id).order_by(ProjectVersion.version.desc())
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
class SlideDataRepository:
|
||||||
|
"""Repository for SlideData operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create_slides(self, slides_data: List[Dict[str, Any]]) -> List[SlideData]:
|
||||||
|
"""Create multiple slides"""
|
||||||
|
slides = [SlideData(**slide_data) for slide_data in slides_data]
|
||||||
|
self.session.add_all(slides)
|
||||||
|
await self.session.commit()
|
||||||
|
for slide in slides:
|
||||||
|
await self.session.refresh(slide)
|
||||||
|
return slides
|
||||||
|
|
||||||
|
async def create_single_slide(self, slide_data: Dict[str, Any]) -> SlideData:
|
||||||
|
"""Create a single slide"""
|
||||||
|
slide = SlideData(**slide_data)
|
||||||
|
self.session.add(slide)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(slide)
|
||||||
|
return slide
|
||||||
|
|
||||||
|
async def upsert_slide(self, project_id: str, slide_index: int, slide_data: Dict[str, Any]) -> SlideData:
|
||||||
|
"""Insert or update a single slide"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
logger.info(f"🔄 数据库仓库开始upsert幻灯片: 项目ID={project_id}, 索引={slide_index}")
|
||||||
|
|
||||||
|
# Check if slide already exists
|
||||||
|
stmt = select(SlideData).where(
|
||||||
|
SlideData.project_id == project_id,
|
||||||
|
SlideData.slide_index == slide_index
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
existing_slide = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing_slide:
|
||||||
|
# Update existing slide
|
||||||
|
logger.info(f"📝 更新现有幻灯片: 数据库ID={existing_slide.id}, 项目ID={project_id}, 索引={slide_index}")
|
||||||
|
slide_data['updated_at'] = time.time()
|
||||||
|
|
||||||
|
updated_fields = []
|
||||||
|
for key, value in slide_data.items():
|
||||||
|
if hasattr(existing_slide, key):
|
||||||
|
old_value = getattr(existing_slide, key)
|
||||||
|
if old_value != value:
|
||||||
|
setattr(existing_slide, key, value)
|
||||||
|
updated_fields.append(key)
|
||||||
|
|
||||||
|
logger.info(f"📊 更新的字段: {updated_fields}")
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(existing_slide)
|
||||||
|
logger.info(f"✅ 幻灯片更新成功: 数据库ID={existing_slide.id}")
|
||||||
|
return existing_slide
|
||||||
|
else:
|
||||||
|
# Create new slide
|
||||||
|
logger.info(f"➕ 创建新幻灯片: 项目ID={project_id}, 索引={slide_index}")
|
||||||
|
slide_data['created_at'] = time.time()
|
||||||
|
slide_data['updated_at'] = time.time()
|
||||||
|
new_slide = await self.create_single_slide(slide_data)
|
||||||
|
logger.info(f"✅ 新幻灯片创建成功: 数据库ID={new_slide.id}")
|
||||||
|
return new_slide
|
||||||
|
|
||||||
|
async def get_slides_by_project_id(self, project_id: str) -> List[SlideData]:
|
||||||
|
"""Get all slides for a project"""
|
||||||
|
stmt = select(SlideData).where(SlideData.project_id == project_id).order_by(SlideData.slide_index)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def update_slide(self, slide_id: str, update_data: Dict[str, Any]) -> bool:
|
||||||
|
"""Update a specific slide"""
|
||||||
|
update_data['updated_at'] = time.time()
|
||||||
|
stmt = update(SlideData).where(SlideData.slide_id == slide_id).values(**update_data)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def delete_slides_by_project_id(self, project_id: str) -> bool:
|
||||||
|
"""Delete all slides for a project"""
|
||||||
|
stmt = delete(SlideData).where(SlideData.project_id == project_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def delete_slides_after_index(self, project_id: str, start_index: int) -> int:
|
||||||
|
"""Delete slides with index >= start_index for a project"""
|
||||||
|
logger.debug(f"🗑️ 删除项目 {project_id} 中索引 >= {start_index} 的幻灯片")
|
||||||
|
stmt = delete(SlideData).where(
|
||||||
|
and_(
|
||||||
|
SlideData.project_id == project_id,
|
||||||
|
SlideData.slide_index >= start_index
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
logger.debug(f"✅ 成功删除 {deleted_count} 张多余的幻灯片")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
async def batch_upsert_slides(self, project_id: str, slides_data: List[Dict[str, Any]]) -> bool:
|
||||||
|
"""批量插入或更新幻灯片 - 优化版本"""
|
||||||
|
logger.debug(f"🔄 开始批量upsert幻灯片: 项目ID={project_id}, 数量={len(slides_data)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取现有幻灯片
|
||||||
|
existing_slides_stmt = select(SlideData).where(SlideData.project_id == project_id)
|
||||||
|
result = await self.session.execute(existing_slides_stmt)
|
||||||
|
existing_slides = {slide.slide_index: slide for slide in result.scalars().all()}
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
created_count = 0
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# 批量处理幻灯片
|
||||||
|
for i, slide_data in enumerate(slides_data):
|
||||||
|
slide_index = i
|
||||||
|
|
||||||
|
if slide_index in existing_slides:
|
||||||
|
# 更新现有幻灯片
|
||||||
|
existing_slide = existing_slides[slide_index]
|
||||||
|
slide_data['updated_at'] = current_time
|
||||||
|
|
||||||
|
# 只更新有变化的字段
|
||||||
|
has_changes = False
|
||||||
|
for key, value in slide_data.items():
|
||||||
|
if hasattr(existing_slide, key) and getattr(existing_slide, key) != value:
|
||||||
|
setattr(existing_slide, key, value)
|
||||||
|
has_changes = True
|
||||||
|
|
||||||
|
if has_changes:
|
||||||
|
updated_count += 1
|
||||||
|
else:
|
||||||
|
# 创建新幻灯片
|
||||||
|
slide_data.update({
|
||||||
|
'project_id': project_id,
|
||||||
|
'slide_index': slide_index,
|
||||||
|
'created_at': current_time,
|
||||||
|
'updated_at': current_time
|
||||||
|
})
|
||||||
|
new_slide = SlideData(**slide_data)
|
||||||
|
self.session.add(new_slide)
|
||||||
|
created_count += 1
|
||||||
|
|
||||||
|
# 一次性提交所有更改
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
logger.debug(f"✅ 批量upsert完成: 更新={updated_count}, 创建={created_count}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 批量upsert失败: {e}")
|
||||||
|
await self.session.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def update_slide_user_edited_status(self, project_id: str, slide_index: int, is_user_edited: bool = True) -> bool:
|
||||||
|
"""Update the user edited status for a specific slide"""
|
||||||
|
stmt = update(SlideData).where(
|
||||||
|
SlideData.project_id == project_id,
|
||||||
|
SlideData.slide_index == slide_index
|
||||||
|
).values(
|
||||||
|
is_user_edited=is_user_edited,
|
||||||
|
updated_at=time.time()
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
class PPTTemplateRepository:
|
||||||
|
"""Repository for PPT template operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create_template(self, template_data: Dict[str, Any]) -> PPTTemplate:
|
||||||
|
"""Create a new PPT template"""
|
||||||
|
template_data['created_at'] = time.time()
|
||||||
|
template_data['updated_at'] = time.time()
|
||||||
|
template = PPTTemplate(**template_data)
|
||||||
|
self.session.add(template)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(template)
|
||||||
|
return template
|
||||||
|
|
||||||
|
async def get_template_by_id(self, template_id: int) -> Optional[PPTTemplate]:
|
||||||
|
"""Get template by ID"""
|
||||||
|
stmt = select(PPTTemplate).where(PPTTemplate.id == template_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_templates_by_project_id(self, project_id: str) -> List[PPTTemplate]:
|
||||||
|
"""Get all templates for a project"""
|
||||||
|
stmt = select(PPTTemplate).where(PPTTemplate.project_id == project_id).order_by(PPTTemplate.created_at)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_templates_by_type(self, project_id: str, template_type: str) -> List[PPTTemplate]:
|
||||||
|
"""Get templates by type for a project"""
|
||||||
|
stmt = select(PPTTemplate).where(
|
||||||
|
PPTTemplate.project_id == project_id,
|
||||||
|
PPTTemplate.template_type == template_type
|
||||||
|
).order_by(PPTTemplate.created_at)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def update_template(self, template_id: int, update_data: Dict[str, Any]) -> bool:
|
||||||
|
"""Update a template"""
|
||||||
|
update_data['updated_at'] = time.time()
|
||||||
|
stmt = update(PPTTemplate).where(PPTTemplate.id == template_id).values(**update_data)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def increment_usage_count(self, template_id: int) -> bool:
|
||||||
|
"""Increment template usage count"""
|
||||||
|
stmt = update(PPTTemplate).where(PPTTemplate.id == template_id).values(
|
||||||
|
usage_count=PPTTemplate.usage_count + 1,
|
||||||
|
updated_at=time.time()
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def delete_template(self, template_id: int) -> bool:
|
||||||
|
"""Delete a template"""
|
||||||
|
stmt = delete(PPTTemplate).where(PPTTemplate.id == template_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalMasterTemplateRepository:
|
||||||
|
"""Repository for Global Master Template operations"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create_template(self, template_data: Dict[str, Any]) -> GlobalMasterTemplate:
|
||||||
|
"""Create a new global master template"""
|
||||||
|
template_data['created_at'] = time.time()
|
||||||
|
template_data['updated_at'] = time.time()
|
||||||
|
template = GlobalMasterTemplate(**template_data)
|
||||||
|
self.session.add(template)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(template)
|
||||||
|
return template
|
||||||
|
|
||||||
|
async def get_template_by_id(self, template_id: int) -> Optional[GlobalMasterTemplate]:
|
||||||
|
"""Get template by ID"""
|
||||||
|
stmt = select(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_template_by_name(self, template_name: str) -> Optional[GlobalMasterTemplate]:
|
||||||
|
"""Get template by name"""
|
||||||
|
stmt = select(GlobalMasterTemplate).where(GlobalMasterTemplate.template_name == template_name)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_all_templates(self, active_only: bool = True) -> List[GlobalMasterTemplate]:
|
||||||
|
"""Get all global master templates"""
|
||||||
|
stmt = select(GlobalMasterTemplate)
|
||||||
|
if active_only:
|
||||||
|
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||||
|
stmt = stmt.order_by(GlobalMasterTemplate.is_default.desc(), GlobalMasterTemplate.usage_count.desc())
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_templates_by_tags(self, tags: List[str], active_only: bool = True) -> List[GlobalMasterTemplate]:
|
||||||
|
"""Get templates by tags"""
|
||||||
|
stmt = select(GlobalMasterTemplate)
|
||||||
|
if active_only:
|
||||||
|
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||||
|
|
||||||
|
# Filter by tags (any tag matches)
|
||||||
|
for tag in tags:
|
||||||
|
stmt = stmt.where(GlobalMasterTemplate.tags.contains([tag]))
|
||||||
|
|
||||||
|
stmt = stmt.order_by(GlobalMasterTemplate.usage_count.desc())
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_templates_paginated(
|
||||||
|
self,
|
||||||
|
active_only: bool = True,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 6,
|
||||||
|
search: Optional[str] = None
|
||||||
|
) -> Tuple[List[GlobalMasterTemplate], int]:
|
||||||
|
"""Get templates with pagination"""
|
||||||
|
from sqlalchemy import func, or_
|
||||||
|
|
||||||
|
# Base query
|
||||||
|
stmt = select(GlobalMasterTemplate)
|
||||||
|
count_stmt = select(func.count(GlobalMasterTemplate.id))
|
||||||
|
|
||||||
|
if active_only:
|
||||||
|
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||||
|
count_stmt = count_stmt.where(GlobalMasterTemplate.is_active == True)
|
||||||
|
|
||||||
|
# Add search filter
|
||||||
|
if search and search.strip():
|
||||||
|
search_filter = or_(
|
||||||
|
GlobalMasterTemplate.template_name.ilike(f"%{search}%"),
|
||||||
|
GlobalMasterTemplate.description.ilike(f"%{search}%")
|
||||||
|
)
|
||||||
|
stmt = stmt.where(search_filter)
|
||||||
|
count_stmt = count_stmt.where(search_filter)
|
||||||
|
|
||||||
|
# Order and paginate
|
||||||
|
stmt = stmt.order_by(
|
||||||
|
GlobalMasterTemplate.is_default.desc(),
|
||||||
|
GlobalMasterTemplate.usage_count.desc()
|
||||||
|
).offset(offset).limit(limit)
|
||||||
|
|
||||||
|
# Execute queries
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
count_result = await self.session.execute(count_stmt)
|
||||||
|
|
||||||
|
templates = result.scalars().all()
|
||||||
|
total_count = count_result.scalar()
|
||||||
|
|
||||||
|
return templates, total_count
|
||||||
|
|
||||||
|
async def get_templates_by_tags_paginated(
|
||||||
|
self,
|
||||||
|
tags: List[str],
|
||||||
|
active_only: bool = True,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 6,
|
||||||
|
search: Optional[str] = None
|
||||||
|
) -> Tuple[List[GlobalMasterTemplate], int]:
|
||||||
|
"""Get templates by tags with pagination"""
|
||||||
|
from sqlalchemy import func, or_
|
||||||
|
|
||||||
|
# Base query
|
||||||
|
stmt = select(GlobalMasterTemplate)
|
||||||
|
count_stmt = select(func.count(GlobalMasterTemplate.id))
|
||||||
|
|
||||||
|
if active_only:
|
||||||
|
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||||
|
count_stmt = count_stmt.where(GlobalMasterTemplate.is_active == True)
|
||||||
|
|
||||||
|
# Filter by tags (any tag matches)
|
||||||
|
for tag in tags:
|
||||||
|
tag_filter = GlobalMasterTemplate.tags.contains([tag])
|
||||||
|
stmt = stmt.where(tag_filter)
|
||||||
|
count_stmt = count_stmt.where(tag_filter)
|
||||||
|
|
||||||
|
# Add search filter
|
||||||
|
if search and search.strip():
|
||||||
|
search_filter = or_(
|
||||||
|
GlobalMasterTemplate.template_name.ilike(f"%{search}%"),
|
||||||
|
GlobalMasterTemplate.description.ilike(f"%{search}%")
|
||||||
|
)
|
||||||
|
stmt = stmt.where(search_filter)
|
||||||
|
count_stmt = count_stmt.where(search_filter)
|
||||||
|
|
||||||
|
# Order and paginate
|
||||||
|
stmt = stmt.order_by(GlobalMasterTemplate.usage_count.desc()).offset(offset).limit(limit)
|
||||||
|
|
||||||
|
# Execute queries
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
count_result = await self.session.execute(count_stmt)
|
||||||
|
|
||||||
|
templates = result.scalars().all()
|
||||||
|
total_count = count_result.scalar()
|
||||||
|
|
||||||
|
return templates, total_count
|
||||||
|
|
||||||
|
async def update_template(self, template_id: int, update_data: Dict[str, Any]) -> bool:
|
||||||
|
"""Update a global master template"""
|
||||||
|
update_data['updated_at'] = time.time()
|
||||||
|
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(**update_data)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def delete_template(self, template_id: int) -> bool:
|
||||||
|
"""Delete a global master template"""
|
||||||
|
try:
|
||||||
|
stmt = delete(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
rows_affected = result.rowcount
|
||||||
|
logger.info(f"Delete operation for template {template_id}: {rows_affected} rows affected")
|
||||||
|
|
||||||
|
return rows_affected > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting template {template_id}: {e}")
|
||||||
|
await self.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def increment_usage_count(self, template_id: int) -> bool:
|
||||||
|
"""Increment template usage count"""
|
||||||
|
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(
|
||||||
|
usage_count=GlobalMasterTemplate.usage_count + 1,
|
||||||
|
updated_at=time.time()
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def set_default_template(self, template_id: int) -> bool:
|
||||||
|
"""Set a template as default (and unset others)"""
|
||||||
|
# First, unset all default templates
|
||||||
|
stmt = update(GlobalMasterTemplate).values(is_default=False, updated_at=time.time())
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
|
# Then set the specified template as default
|
||||||
|
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(
|
||||||
|
is_default=True,
|
||||||
|
updated_at=time.time()
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def get_default_template(self) -> Optional[GlobalMasterTemplate]:
|
||||||
|
"""Get the default template"""
|
||||||
|
stmt = select(GlobalMasterTemplate).where(
|
||||||
|
GlobalMasterTemplate.is_default == True,
|
||||||
|
GlobalMasterTemplate.is_active == True
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_templates_by_project_id(self, project_id: str) -> bool:
|
||||||
|
"""Delete all templates for a project"""
|
||||||
|
stmt = delete(PPTTemplate).where(PPTTemplate.project_id == project_id)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
Reference in New Issue
Block a user