Add File
This commit is contained in:
637
src/landppt/database/repositories.py
Normal file
637
src/landppt/database/repositories.py
Normal file
@@ -0,0 +1,637 @@
|
||||
"""
|
||||
Repository classes for database operations
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete, and_
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from .models import Project, TodoBoard, TodoStage, ProjectVersion, SlideData, PPTTemplate, GlobalMasterTemplate
|
||||
from ..api.models import PPTProject, TodoBoard as TodoBoardModel, TodoStage as TodoStageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProjectRepository:
|
||||
"""Repository for Project operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, project_data: Dict[str, Any]) -> Project:
|
||||
"""Create a new project"""
|
||||
project = Project(**project_data)
|
||||
self.session.add(project)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(project)
|
||||
return project
|
||||
|
||||
async def get_by_id(self, project_id: str) -> Optional[Project]:
|
||||
"""Get project by ID with all relationships"""
|
||||
stmt = select(Project).where(Project.project_id == project_id).options(
|
||||
selectinload(Project.todo_board).selectinload(TodoBoard.stages),
|
||||
selectinload(Project.versions),
|
||||
selectinload(Project.slides)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_projects(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> List[Project]:
|
||||
"""List projects with pagination"""
|
||||
stmt = select(Project).options(
|
||||
selectinload(Project.todo_board).selectinload(TodoBoard.stages),
|
||||
selectinload(Project.versions),
|
||||
selectinload(Project.slides)
|
||||
)
|
||||
|
||||
if status:
|
||||
stmt = stmt.where(Project.status == status)
|
||||
|
||||
stmt = stmt.order_by(Project.updated_at.desc())
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def count_projects(self, status: Optional[str] = None) -> int:
|
||||
"""Count total projects"""
|
||||
from sqlalchemy import func
|
||||
|
||||
stmt = select(func.count(Project.id))
|
||||
if status:
|
||||
stmt = stmt.where(Project.status == status)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def update(self, project_id: str, update_data: Dict[str, Any]) -> Optional[Project]:
|
||||
"""Update project"""
|
||||
try:
|
||||
# 首先获取项目
|
||||
project = await self.get_by_id(project_id)
|
||||
if not project:
|
||||
logger.warning(f"No project found with ID {project_id} for update")
|
||||
return None
|
||||
|
||||
# 更新项目属性
|
||||
for key, value in update_data.items():
|
||||
if hasattr(project, key):
|
||||
setattr(project, key, value)
|
||||
|
||||
# 设置更新时间
|
||||
project.updated_at = time.time()
|
||||
|
||||
# 提交更改
|
||||
await self.session.commit()
|
||||
await self.session.refresh(project)
|
||||
|
||||
logger.info(f"Successfully updated project {project_id}")
|
||||
return project
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating project {project_id}: {e}")
|
||||
await self.session.rollback()
|
||||
raise
|
||||
|
||||
async def delete(self, project_id: str) -> bool:
|
||||
"""Delete project"""
|
||||
stmt = delete(Project).where(Project.project_id == project_id)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
class TodoBoardRepository:
|
||||
"""Repository for TodoBoard operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, board_data: Dict[str, Any]) -> TodoBoard:
|
||||
"""Create a new todo board"""
|
||||
board = TodoBoard(**board_data)
|
||||
self.session.add(board)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(board)
|
||||
return board
|
||||
|
||||
async def get_by_project_id(self, project_id: str) -> Optional[TodoBoard]:
|
||||
"""Get todo board by project ID"""
|
||||
stmt = select(TodoBoard).where(TodoBoard.project_id == project_id).options(
|
||||
selectinload(TodoBoard.stages)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update(self, project_id: str, update_data: Dict[str, Any]) -> Optional[TodoBoard]:
|
||||
"""Update todo board"""
|
||||
update_data['updated_at'] = time.time()
|
||||
stmt = update(TodoBoard).where(TodoBoard.project_id == project_id).values(**update_data)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return await self.get_by_project_id(project_id)
|
||||
|
||||
|
||||
class TodoStageRepository:
|
||||
"""Repository for TodoStage operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_stages(self, stages_data: List[Dict[str, Any]]) -> List[TodoStage]:
|
||||
"""Create multiple stages"""
|
||||
stages = [TodoStage(**stage_data) for stage_data in stages_data]
|
||||
self.session.add_all(stages)
|
||||
await self.session.commit()
|
||||
for stage in stages:
|
||||
await self.session.refresh(stage)
|
||||
return stages
|
||||
|
||||
async def update_stage(self, stage_id: str, update_data: Dict[str, Any]) -> bool:
|
||||
"""Update a specific stage"""
|
||||
update_data['updated_at'] = time.time()
|
||||
stmt = update(TodoStage).where(TodoStage.stage_id == stage_id).values(**update_data)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def update_stage_by_project_and_stage(self, project_id: str, stage_id: str, update_data: Dict[str, Any]) -> bool:
|
||||
"""Update a specific stage by project_id and stage_id for better performance"""
|
||||
update_data['updated_at'] = time.time()
|
||||
stmt = update(TodoStage).where(
|
||||
TodoStage.project_id == project_id,
|
||||
TodoStage.stage_id == stage_id
|
||||
).values(**update_data)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def get_stages_by_board_id(self, board_id: int) -> List[TodoStage]:
|
||||
"""Get all stages for a todo board"""
|
||||
stmt = select(TodoStage).where(TodoStage.todo_board_id == board_id).order_by(TodoStage.stage_index)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
class ProjectVersionRepository:
|
||||
"""Repository for ProjectVersion operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, version_data: Dict[str, Any]) -> ProjectVersion:
|
||||
"""Create a new project version"""
|
||||
version = ProjectVersion(**version_data)
|
||||
self.session.add(version)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(version)
|
||||
return version
|
||||
|
||||
async def get_versions_by_project_id(self, project_id: str) -> List[ProjectVersion]:
|
||||
"""Get all versions for a project"""
|
||||
stmt = select(ProjectVersion).where(ProjectVersion.project_id == project_id).order_by(ProjectVersion.version.desc())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
class SlideDataRepository:
|
||||
"""Repository for SlideData operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_slides(self, slides_data: List[Dict[str, Any]]) -> List[SlideData]:
|
||||
"""Create multiple slides"""
|
||||
slides = [SlideData(**slide_data) for slide_data in slides_data]
|
||||
self.session.add_all(slides)
|
||||
await self.session.commit()
|
||||
for slide in slides:
|
||||
await self.session.refresh(slide)
|
||||
return slides
|
||||
|
||||
async def create_single_slide(self, slide_data: Dict[str, Any]) -> SlideData:
|
||||
"""Create a single slide"""
|
||||
slide = SlideData(**slide_data)
|
||||
self.session.add(slide)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(slide)
|
||||
return slide
|
||||
|
||||
async def upsert_slide(self, project_id: str, slide_index: int, slide_data: Dict[str, Any]) -> SlideData:
|
||||
"""Insert or update a single slide"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"🔄 数据库仓库开始upsert幻灯片: 项目ID={project_id}, 索引={slide_index}")
|
||||
|
||||
# Check if slide already exists
|
||||
stmt = select(SlideData).where(
|
||||
SlideData.project_id == project_id,
|
||||
SlideData.slide_index == slide_index
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
existing_slide = result.scalar_one_or_none()
|
||||
|
||||
if existing_slide:
|
||||
# Update existing slide
|
||||
logger.info(f"📝 更新现有幻灯片: 数据库ID={existing_slide.id}, 项目ID={project_id}, 索引={slide_index}")
|
||||
slide_data['updated_at'] = time.time()
|
||||
|
||||
updated_fields = []
|
||||
for key, value in slide_data.items():
|
||||
if hasattr(existing_slide, key):
|
||||
old_value = getattr(existing_slide, key)
|
||||
if old_value != value:
|
||||
setattr(existing_slide, key, value)
|
||||
updated_fields.append(key)
|
||||
|
||||
logger.info(f"📊 更新的字段: {updated_fields}")
|
||||
await self.session.commit()
|
||||
await self.session.refresh(existing_slide)
|
||||
logger.info(f"✅ 幻灯片更新成功: 数据库ID={existing_slide.id}")
|
||||
return existing_slide
|
||||
else:
|
||||
# Create new slide
|
||||
logger.info(f"➕ 创建新幻灯片: 项目ID={project_id}, 索引={slide_index}")
|
||||
slide_data['created_at'] = time.time()
|
||||
slide_data['updated_at'] = time.time()
|
||||
new_slide = await self.create_single_slide(slide_data)
|
||||
logger.info(f"✅ 新幻灯片创建成功: 数据库ID={new_slide.id}")
|
||||
return new_slide
|
||||
|
||||
async def get_slides_by_project_id(self, project_id: str) -> List[SlideData]:
|
||||
"""Get all slides for a project"""
|
||||
stmt = select(SlideData).where(SlideData.project_id == project_id).order_by(SlideData.slide_index)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_slide(self, slide_id: str, update_data: Dict[str, Any]) -> bool:
|
||||
"""Update a specific slide"""
|
||||
update_data['updated_at'] = time.time()
|
||||
stmt = update(SlideData).where(SlideData.slide_id == slide_id).values(**update_data)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_slides_by_project_id(self, project_id: str) -> bool:
|
||||
"""Delete all slides for a project"""
|
||||
stmt = delete(SlideData).where(SlideData.project_id == project_id)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_slides_after_index(self, project_id: str, start_index: int) -> int:
|
||||
"""Delete slides with index >= start_index for a project"""
|
||||
logger.debug(f"🗑️ 删除项目 {project_id} 中索引 >= {start_index} 的幻灯片")
|
||||
stmt = delete(SlideData).where(
|
||||
and_(
|
||||
SlideData.project_id == project_id,
|
||||
SlideData.slide_index >= start_index
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
deleted_count = result.rowcount
|
||||
logger.debug(f"✅ 成功删除 {deleted_count} 张多余的幻灯片")
|
||||
return deleted_count
|
||||
|
||||
async def batch_upsert_slides(self, project_id: str, slides_data: List[Dict[str, Any]]) -> bool:
|
||||
"""批量插入或更新幻灯片 - 优化版本"""
|
||||
logger.debug(f"🔄 开始批量upsert幻灯片: 项目ID={project_id}, 数量={len(slides_data)}")
|
||||
|
||||
try:
|
||||
# 获取现有幻灯片
|
||||
existing_slides_stmt = select(SlideData).where(SlideData.project_id == project_id)
|
||||
result = await self.session.execute(existing_slides_stmt)
|
||||
existing_slides = {slide.slide_index: slide for slide in result.scalars().all()}
|
||||
|
||||
updated_count = 0
|
||||
created_count = 0
|
||||
current_time = time.time()
|
||||
|
||||
# 批量处理幻灯片
|
||||
for i, slide_data in enumerate(slides_data):
|
||||
slide_index = i
|
||||
|
||||
if slide_index in existing_slides:
|
||||
# 更新现有幻灯片
|
||||
existing_slide = existing_slides[slide_index]
|
||||
slide_data['updated_at'] = current_time
|
||||
|
||||
# 只更新有变化的字段
|
||||
has_changes = False
|
||||
for key, value in slide_data.items():
|
||||
if hasattr(existing_slide, key) and getattr(existing_slide, key) != value:
|
||||
setattr(existing_slide, key, value)
|
||||
has_changes = True
|
||||
|
||||
if has_changes:
|
||||
updated_count += 1
|
||||
else:
|
||||
# 创建新幻灯片
|
||||
slide_data.update({
|
||||
'project_id': project_id,
|
||||
'slide_index': slide_index,
|
||||
'created_at': current_time,
|
||||
'updated_at': current_time
|
||||
})
|
||||
new_slide = SlideData(**slide_data)
|
||||
self.session.add(new_slide)
|
||||
created_count += 1
|
||||
|
||||
# 一次性提交所有更改
|
||||
await self.session.commit()
|
||||
|
||||
logger.debug(f"✅ 批量upsert完成: 更新={updated_count}, 创建={created_count}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量upsert失败: {e}")
|
||||
await self.session.rollback()
|
||||
return False
|
||||
|
||||
async def update_slide_user_edited_status(self, project_id: str, slide_index: int, is_user_edited: bool = True) -> bool:
|
||||
"""Update the user edited status for a specific slide"""
|
||||
stmt = update(SlideData).where(
|
||||
SlideData.project_id == project_id,
|
||||
SlideData.slide_index == slide_index
|
||||
).values(
|
||||
is_user_edited=is_user_edited,
|
||||
updated_at=time.time()
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
class PPTTemplateRepository:
|
||||
"""Repository for PPT template operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_template(self, template_data: Dict[str, Any]) -> PPTTemplate:
|
||||
"""Create a new PPT template"""
|
||||
template_data['created_at'] = time.time()
|
||||
template_data['updated_at'] = time.time()
|
||||
template = PPTTemplate(**template_data)
|
||||
self.session.add(template)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(template)
|
||||
return template
|
||||
|
||||
async def get_template_by_id(self, template_id: int) -> Optional[PPTTemplate]:
|
||||
"""Get template by ID"""
|
||||
stmt = select(PPTTemplate).where(PPTTemplate.id == template_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_templates_by_project_id(self, project_id: str) -> List[PPTTemplate]:
|
||||
"""Get all templates for a project"""
|
||||
stmt = select(PPTTemplate).where(PPTTemplate.project_id == project_id).order_by(PPTTemplate.created_at)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_templates_by_type(self, project_id: str, template_type: str) -> List[PPTTemplate]:
|
||||
"""Get templates by type for a project"""
|
||||
stmt = select(PPTTemplate).where(
|
||||
PPTTemplate.project_id == project_id,
|
||||
PPTTemplate.template_type == template_type
|
||||
).order_by(PPTTemplate.created_at)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_template(self, template_id: int, update_data: Dict[str, Any]) -> bool:
|
||||
"""Update a template"""
|
||||
update_data['updated_at'] = time.time()
|
||||
stmt = update(PPTTemplate).where(PPTTemplate.id == template_id).values(**update_data)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def increment_usage_count(self, template_id: int) -> bool:
|
||||
"""Increment template usage count"""
|
||||
stmt = update(PPTTemplate).where(PPTTemplate.id == template_id).values(
|
||||
usage_count=PPTTemplate.usage_count + 1,
|
||||
updated_at=time.time()
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_template(self, template_id: int) -> bool:
|
||||
"""Delete a template"""
|
||||
stmt = delete(PPTTemplate).where(PPTTemplate.id == template_id)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
class GlobalMasterTemplateRepository:
|
||||
"""Repository for Global Master Template operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_template(self, template_data: Dict[str, Any]) -> GlobalMasterTemplate:
|
||||
"""Create a new global master template"""
|
||||
template_data['created_at'] = time.time()
|
||||
template_data['updated_at'] = time.time()
|
||||
template = GlobalMasterTemplate(**template_data)
|
||||
self.session.add(template)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(template)
|
||||
return template
|
||||
|
||||
async def get_template_by_id(self, template_id: int) -> Optional[GlobalMasterTemplate]:
|
||||
"""Get template by ID"""
|
||||
stmt = select(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_template_by_name(self, template_name: str) -> Optional[GlobalMasterTemplate]:
|
||||
"""Get template by name"""
|
||||
stmt = select(GlobalMasterTemplate).where(GlobalMasterTemplate.template_name == template_name)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_all_templates(self, active_only: bool = True) -> List[GlobalMasterTemplate]:
|
||||
"""Get all global master templates"""
|
||||
stmt = select(GlobalMasterTemplate)
|
||||
if active_only:
|
||||
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||
stmt = stmt.order_by(GlobalMasterTemplate.is_default.desc(), GlobalMasterTemplate.usage_count.desc())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_templates_by_tags(self, tags: List[str], active_only: bool = True) -> List[GlobalMasterTemplate]:
|
||||
"""Get templates by tags"""
|
||||
stmt = select(GlobalMasterTemplate)
|
||||
if active_only:
|
||||
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||
|
||||
# Filter by tags (any tag matches)
|
||||
for tag in tags:
|
||||
stmt = stmt.where(GlobalMasterTemplate.tags.contains([tag]))
|
||||
|
||||
stmt = stmt.order_by(GlobalMasterTemplate.usage_count.desc())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_templates_paginated(
|
||||
self,
|
||||
active_only: bool = True,
|
||||
offset: int = 0,
|
||||
limit: int = 6,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[GlobalMasterTemplate], int]:
|
||||
"""Get templates with pagination"""
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
# Base query
|
||||
stmt = select(GlobalMasterTemplate)
|
||||
count_stmt = select(func.count(GlobalMasterTemplate.id))
|
||||
|
||||
if active_only:
|
||||
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||
count_stmt = count_stmt.where(GlobalMasterTemplate.is_active == True)
|
||||
|
||||
# Add search filter
|
||||
if search and search.strip():
|
||||
search_filter = or_(
|
||||
GlobalMasterTemplate.template_name.ilike(f"%{search}%"),
|
||||
GlobalMasterTemplate.description.ilike(f"%{search}%")
|
||||
)
|
||||
stmt = stmt.where(search_filter)
|
||||
count_stmt = count_stmt.where(search_filter)
|
||||
|
||||
# Order and paginate
|
||||
stmt = stmt.order_by(
|
||||
GlobalMasterTemplate.is_default.desc(),
|
||||
GlobalMasterTemplate.usage_count.desc()
|
||||
).offset(offset).limit(limit)
|
||||
|
||||
# Execute queries
|
||||
result = await self.session.execute(stmt)
|
||||
count_result = await self.session.execute(count_stmt)
|
||||
|
||||
templates = result.scalars().all()
|
||||
total_count = count_result.scalar()
|
||||
|
||||
return templates, total_count
|
||||
|
||||
async def get_templates_by_tags_paginated(
|
||||
self,
|
||||
tags: List[str],
|
||||
active_only: bool = True,
|
||||
offset: int = 0,
|
||||
limit: int = 6,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[GlobalMasterTemplate], int]:
|
||||
"""Get templates by tags with pagination"""
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
# Base query
|
||||
stmt = select(GlobalMasterTemplate)
|
||||
count_stmt = select(func.count(GlobalMasterTemplate.id))
|
||||
|
||||
if active_only:
|
||||
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
|
||||
count_stmt = count_stmt.where(GlobalMasterTemplate.is_active == True)
|
||||
|
||||
# Filter by tags (any tag matches)
|
||||
for tag in tags:
|
||||
tag_filter = GlobalMasterTemplate.tags.contains([tag])
|
||||
stmt = stmt.where(tag_filter)
|
||||
count_stmt = count_stmt.where(tag_filter)
|
||||
|
||||
# Add search filter
|
||||
if search and search.strip():
|
||||
search_filter = or_(
|
||||
GlobalMasterTemplate.template_name.ilike(f"%{search}%"),
|
||||
GlobalMasterTemplate.description.ilike(f"%{search}%")
|
||||
)
|
||||
stmt = stmt.where(search_filter)
|
||||
count_stmt = count_stmt.where(search_filter)
|
||||
|
||||
# Order and paginate
|
||||
stmt = stmt.order_by(GlobalMasterTemplate.usage_count.desc()).offset(offset).limit(limit)
|
||||
|
||||
# Execute queries
|
||||
result = await self.session.execute(stmt)
|
||||
count_result = await self.session.execute(count_stmt)
|
||||
|
||||
templates = result.scalars().all()
|
||||
total_count = count_result.scalar()
|
||||
|
||||
return templates, total_count
|
||||
|
||||
async def update_template(self, template_id: int, update_data: Dict[str, Any]) -> bool:
|
||||
"""Update a global master template"""
|
||||
update_data['updated_at'] = time.time()
|
||||
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(**update_data)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_template(self, template_id: int) -> bool:
|
||||
"""Delete a global master template"""
|
||||
try:
|
||||
stmt = delete(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
rows_affected = result.rowcount
|
||||
logger.info(f"Delete operation for template {template_id}: {rows_affected} rows affected")
|
||||
|
||||
return rows_affected > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting template {template_id}: {e}")
|
||||
await self.session.rollback()
|
||||
raise
|
||||
|
||||
async def increment_usage_count(self, template_id: int) -> bool:
|
||||
"""Increment template usage count"""
|
||||
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(
|
||||
usage_count=GlobalMasterTemplate.usage_count + 1,
|
||||
updated_at=time.time()
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def set_default_template(self, template_id: int) -> bool:
|
||||
"""Set a template as default (and unset others)"""
|
||||
# First, unset all default templates
|
||||
stmt = update(GlobalMasterTemplate).values(is_default=False, updated_at=time.time())
|
||||
await self.session.execute(stmt)
|
||||
|
||||
# Then set the specified template as default
|
||||
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(
|
||||
is_default=True,
|
||||
updated_at=time.time()
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def get_default_template(self) -> Optional[GlobalMasterTemplate]:
|
||||
"""Get the default template"""
|
||||
stmt = select(GlobalMasterTemplate).where(
|
||||
GlobalMasterTemplate.is_default == True,
|
||||
GlobalMasterTemplate.is_active == True
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
|
||||
async def delete_templates_by_project_id(self, project_id: str) -> bool:
|
||||
"""Delete all templates for a project"""
|
||||
stmt = delete(PPTTemplate).where(PPTTemplate.project_id == project_id)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
Reference in New Issue
Block a user