This commit is contained in:
2025-11-07 09:05:27 +08:00
parent f0e7ec06d7
commit ede0150136

View File

@@ -0,0 +1,637 @@
"""
Repository classes for database operations
"""
import time
import logging
from typing import List, Optional, Dict, Any, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, and_
from sqlalchemy.orm import selectinload
from .models import Project, TodoBoard, TodoStage, ProjectVersion, SlideData, PPTTemplate, GlobalMasterTemplate
from ..api.models import PPTProject, TodoBoard as TodoBoardModel, TodoStage as TodoStageModel
logger = logging.getLogger(__name__)
class ProjectRepository:
"""Repository for Project operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create(self, project_data: Dict[str, Any]) -> Project:
"""Create a new project"""
project = Project(**project_data)
self.session.add(project)
await self.session.commit()
await self.session.refresh(project)
return project
async def get_by_id(self, project_id: str) -> Optional[Project]:
"""Get project by ID with all relationships"""
stmt = select(Project).where(Project.project_id == project_id).options(
selectinload(Project.todo_board).selectinload(TodoBoard.stages),
selectinload(Project.versions),
selectinload(Project.slides)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def list_projects(self, page: int = 1, page_size: int = 10, status: Optional[str] = None) -> List[Project]:
"""List projects with pagination"""
stmt = select(Project).options(
selectinload(Project.todo_board).selectinload(TodoBoard.stages),
selectinload(Project.versions),
selectinload(Project.slides)
)
if status:
stmt = stmt.where(Project.status == status)
stmt = stmt.order_by(Project.updated_at.desc())
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
result = await self.session.execute(stmt)
return result.scalars().all()
async def count_projects(self, status: Optional[str] = None) -> int:
"""Count total projects"""
from sqlalchemy import func
stmt = select(func.count(Project.id))
if status:
stmt = stmt.where(Project.status == status)
result = await self.session.execute(stmt)
return result.scalar() or 0
async def update(self, project_id: str, update_data: Dict[str, Any]) -> Optional[Project]:
"""Update project"""
try:
# 首先获取项目
project = await self.get_by_id(project_id)
if not project:
logger.warning(f"No project found with ID {project_id} for update")
return None
# 更新项目属性
for key, value in update_data.items():
if hasattr(project, key):
setattr(project, key, value)
# 设置更新时间
project.updated_at = time.time()
# 提交更改
await self.session.commit()
await self.session.refresh(project)
logger.info(f"Successfully updated project {project_id}")
return project
except Exception as e:
logger.error(f"Error updating project {project_id}: {e}")
await self.session.rollback()
raise
async def delete(self, project_id: str) -> bool:
"""Delete project"""
stmt = delete(Project).where(Project.project_id == project_id)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
class TodoBoardRepository:
"""Repository for TodoBoard operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create(self, board_data: Dict[str, Any]) -> TodoBoard:
"""Create a new todo board"""
board = TodoBoard(**board_data)
self.session.add(board)
await self.session.commit()
await self.session.refresh(board)
return board
async def get_by_project_id(self, project_id: str) -> Optional[TodoBoard]:
"""Get todo board by project ID"""
stmt = select(TodoBoard).where(TodoBoard.project_id == project_id).options(
selectinload(TodoBoard.stages)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def update(self, project_id: str, update_data: Dict[str, Any]) -> Optional[TodoBoard]:
"""Update todo board"""
update_data['updated_at'] = time.time()
stmt = update(TodoBoard).where(TodoBoard.project_id == project_id).values(**update_data)
await self.session.execute(stmt)
await self.session.commit()
return await self.get_by_project_id(project_id)
class TodoStageRepository:
"""Repository for TodoStage operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_stages(self, stages_data: List[Dict[str, Any]]) -> List[TodoStage]:
"""Create multiple stages"""
stages = [TodoStage(**stage_data) for stage_data in stages_data]
self.session.add_all(stages)
await self.session.commit()
for stage in stages:
await self.session.refresh(stage)
return stages
async def update_stage(self, stage_id: str, update_data: Dict[str, Any]) -> bool:
"""Update a specific stage"""
update_data['updated_at'] = time.time()
stmt = update(TodoStage).where(TodoStage.stage_id == stage_id).values(**update_data)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def update_stage_by_project_and_stage(self, project_id: str, stage_id: str, update_data: Dict[str, Any]) -> bool:
"""Update a specific stage by project_id and stage_id for better performance"""
update_data['updated_at'] = time.time()
stmt = update(TodoStage).where(
TodoStage.project_id == project_id,
TodoStage.stage_id == stage_id
).values(**update_data)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def get_stages_by_board_id(self, board_id: int) -> List[TodoStage]:
"""Get all stages for a todo board"""
stmt = select(TodoStage).where(TodoStage.todo_board_id == board_id).order_by(TodoStage.stage_index)
result = await self.session.execute(stmt)
return result.scalars().all()
class ProjectVersionRepository:
"""Repository for ProjectVersion operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create(self, version_data: Dict[str, Any]) -> ProjectVersion:
"""Create a new project version"""
version = ProjectVersion(**version_data)
self.session.add(version)
await self.session.commit()
await self.session.refresh(version)
return version
async def get_versions_by_project_id(self, project_id: str) -> List[ProjectVersion]:
"""Get all versions for a project"""
stmt = select(ProjectVersion).where(ProjectVersion.project_id == project_id).order_by(ProjectVersion.version.desc())
result = await self.session.execute(stmt)
return result.scalars().all()
class SlideDataRepository:
"""Repository for SlideData operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_slides(self, slides_data: List[Dict[str, Any]]) -> List[SlideData]:
"""Create multiple slides"""
slides = [SlideData(**slide_data) for slide_data in slides_data]
self.session.add_all(slides)
await self.session.commit()
for slide in slides:
await self.session.refresh(slide)
return slides
async def create_single_slide(self, slide_data: Dict[str, Any]) -> SlideData:
"""Create a single slide"""
slide = SlideData(**slide_data)
self.session.add(slide)
await self.session.commit()
await self.session.refresh(slide)
return slide
async def upsert_slide(self, project_id: str, slide_index: int, slide_data: Dict[str, Any]) -> SlideData:
"""Insert or update a single slide"""
import logging
logger = logging.getLogger(__name__)
logger.info(f"🔄 数据库仓库开始upsert幻灯片: 项目ID={project_id}, 索引={slide_index}")
# Check if slide already exists
stmt = select(SlideData).where(
SlideData.project_id == project_id,
SlideData.slide_index == slide_index
)
result = await self.session.execute(stmt)
existing_slide = result.scalar_one_or_none()
if existing_slide:
# Update existing slide
logger.info(f"📝 更新现有幻灯片: 数据库ID={existing_slide.id}, 项目ID={project_id}, 索引={slide_index}")
slide_data['updated_at'] = time.time()
updated_fields = []
for key, value in slide_data.items():
if hasattr(existing_slide, key):
old_value = getattr(existing_slide, key)
if old_value != value:
setattr(existing_slide, key, value)
updated_fields.append(key)
logger.info(f"📊 更新的字段: {updated_fields}")
await self.session.commit()
await self.session.refresh(existing_slide)
logger.info(f"✅ 幻灯片更新成功: 数据库ID={existing_slide.id}")
return existing_slide
else:
# Create new slide
logger.info(f" 创建新幻灯片: 项目ID={project_id}, 索引={slide_index}")
slide_data['created_at'] = time.time()
slide_data['updated_at'] = time.time()
new_slide = await self.create_single_slide(slide_data)
logger.info(f"✅ 新幻灯片创建成功: 数据库ID={new_slide.id}")
return new_slide
async def get_slides_by_project_id(self, project_id: str) -> List[SlideData]:
"""Get all slides for a project"""
stmt = select(SlideData).where(SlideData.project_id == project_id).order_by(SlideData.slide_index)
result = await self.session.execute(stmt)
return result.scalars().all()
async def update_slide(self, slide_id: str, update_data: Dict[str, Any]) -> bool:
"""Update a specific slide"""
update_data['updated_at'] = time.time()
stmt = update(SlideData).where(SlideData.slide_id == slide_id).values(**update_data)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def delete_slides_by_project_id(self, project_id: str) -> bool:
"""Delete all slides for a project"""
stmt = delete(SlideData).where(SlideData.project_id == project_id)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def delete_slides_after_index(self, project_id: str, start_index: int) -> int:
"""Delete slides with index >= start_index for a project"""
logger.debug(f"🗑️ 删除项目 {project_id} 中索引 >= {start_index} 的幻灯片")
stmt = delete(SlideData).where(
and_(
SlideData.project_id == project_id,
SlideData.slide_index >= start_index
)
)
result = await self.session.execute(stmt)
await self.session.commit()
deleted_count = result.rowcount
logger.debug(f"✅ 成功删除 {deleted_count} 张多余的幻灯片")
return deleted_count
async def batch_upsert_slides(self, project_id: str, slides_data: List[Dict[str, Any]]) -> bool:
"""批量插入或更新幻灯片 - 优化版本"""
logger.debug(f"🔄 开始批量upsert幻灯片: 项目ID={project_id}, 数量={len(slides_data)}")
try:
# 获取现有幻灯片
existing_slides_stmt = select(SlideData).where(SlideData.project_id == project_id)
result = await self.session.execute(existing_slides_stmt)
existing_slides = {slide.slide_index: slide for slide in result.scalars().all()}
updated_count = 0
created_count = 0
current_time = time.time()
# 批量处理幻灯片
for i, slide_data in enumerate(slides_data):
slide_index = i
if slide_index in existing_slides:
# 更新现有幻灯片
existing_slide = existing_slides[slide_index]
slide_data['updated_at'] = current_time
# 只更新有变化的字段
has_changes = False
for key, value in slide_data.items():
if hasattr(existing_slide, key) and getattr(existing_slide, key) != value:
setattr(existing_slide, key, value)
has_changes = True
if has_changes:
updated_count += 1
else:
# 创建新幻灯片
slide_data.update({
'project_id': project_id,
'slide_index': slide_index,
'created_at': current_time,
'updated_at': current_time
})
new_slide = SlideData(**slide_data)
self.session.add(new_slide)
created_count += 1
# 一次性提交所有更改
await self.session.commit()
logger.debug(f"✅ 批量upsert完成: 更新={updated_count}, 创建={created_count}")
return True
except Exception as e:
logger.error(f"❌ 批量upsert失败: {e}")
await self.session.rollback()
return False
async def update_slide_user_edited_status(self, project_id: str, slide_index: int, is_user_edited: bool = True) -> bool:
"""Update the user edited status for a specific slide"""
stmt = update(SlideData).where(
SlideData.project_id == project_id,
SlideData.slide_index == slide_index
).values(
is_user_edited=is_user_edited,
updated_at=time.time()
)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
class PPTTemplateRepository:
"""Repository for PPT template operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_template(self, template_data: Dict[str, Any]) -> PPTTemplate:
"""Create a new PPT template"""
template_data['created_at'] = time.time()
template_data['updated_at'] = time.time()
template = PPTTemplate(**template_data)
self.session.add(template)
await self.session.commit()
await self.session.refresh(template)
return template
async def get_template_by_id(self, template_id: int) -> Optional[PPTTemplate]:
"""Get template by ID"""
stmt = select(PPTTemplate).where(PPTTemplate.id == template_id)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_templates_by_project_id(self, project_id: str) -> List[PPTTemplate]:
"""Get all templates for a project"""
stmt = select(PPTTemplate).where(PPTTemplate.project_id == project_id).order_by(PPTTemplate.created_at)
result = await self.session.execute(stmt)
return result.scalars().all()
async def get_templates_by_type(self, project_id: str, template_type: str) -> List[PPTTemplate]:
"""Get templates by type for a project"""
stmt = select(PPTTemplate).where(
PPTTemplate.project_id == project_id,
PPTTemplate.template_type == template_type
).order_by(PPTTemplate.created_at)
result = await self.session.execute(stmt)
return result.scalars().all()
async def update_template(self, template_id: int, update_data: Dict[str, Any]) -> bool:
"""Update a template"""
update_data['updated_at'] = time.time()
stmt = update(PPTTemplate).where(PPTTemplate.id == template_id).values(**update_data)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def increment_usage_count(self, template_id: int) -> bool:
"""Increment template usage count"""
stmt = update(PPTTemplate).where(PPTTemplate.id == template_id).values(
usage_count=PPTTemplate.usage_count + 1,
updated_at=time.time()
)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def delete_template(self, template_id: int) -> bool:
"""Delete a template"""
stmt = delete(PPTTemplate).where(PPTTemplate.id == template_id)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
class GlobalMasterTemplateRepository:
"""Repository for Global Master Template operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_template(self, template_data: Dict[str, Any]) -> GlobalMasterTemplate:
"""Create a new global master template"""
template_data['created_at'] = time.time()
template_data['updated_at'] = time.time()
template = GlobalMasterTemplate(**template_data)
self.session.add(template)
await self.session.commit()
await self.session.refresh(template)
return template
async def get_template_by_id(self, template_id: int) -> Optional[GlobalMasterTemplate]:
"""Get template by ID"""
stmt = select(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_template_by_name(self, template_name: str) -> Optional[GlobalMasterTemplate]:
"""Get template by name"""
stmt = select(GlobalMasterTemplate).where(GlobalMasterTemplate.template_name == template_name)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_all_templates(self, active_only: bool = True) -> List[GlobalMasterTemplate]:
"""Get all global master templates"""
stmt = select(GlobalMasterTemplate)
if active_only:
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
stmt = stmt.order_by(GlobalMasterTemplate.is_default.desc(), GlobalMasterTemplate.usage_count.desc())
result = await self.session.execute(stmt)
return result.scalars().all()
async def get_templates_by_tags(self, tags: List[str], active_only: bool = True) -> List[GlobalMasterTemplate]:
"""Get templates by tags"""
stmt = select(GlobalMasterTemplate)
if active_only:
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
# Filter by tags (any tag matches)
for tag in tags:
stmt = stmt.where(GlobalMasterTemplate.tags.contains([tag]))
stmt = stmt.order_by(GlobalMasterTemplate.usage_count.desc())
result = await self.session.execute(stmt)
return result.scalars().all()
async def get_templates_paginated(
self,
active_only: bool = True,
offset: int = 0,
limit: int = 6,
search: Optional[str] = None
) -> Tuple[List[GlobalMasterTemplate], int]:
"""Get templates with pagination"""
from sqlalchemy import func, or_
# Base query
stmt = select(GlobalMasterTemplate)
count_stmt = select(func.count(GlobalMasterTemplate.id))
if active_only:
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
count_stmt = count_stmt.where(GlobalMasterTemplate.is_active == True)
# Add search filter
if search and search.strip():
search_filter = or_(
GlobalMasterTemplate.template_name.ilike(f"%{search}%"),
GlobalMasterTemplate.description.ilike(f"%{search}%")
)
stmt = stmt.where(search_filter)
count_stmt = count_stmt.where(search_filter)
# Order and paginate
stmt = stmt.order_by(
GlobalMasterTemplate.is_default.desc(),
GlobalMasterTemplate.usage_count.desc()
).offset(offset).limit(limit)
# Execute queries
result = await self.session.execute(stmt)
count_result = await self.session.execute(count_stmt)
templates = result.scalars().all()
total_count = count_result.scalar()
return templates, total_count
async def get_templates_by_tags_paginated(
self,
tags: List[str],
active_only: bool = True,
offset: int = 0,
limit: int = 6,
search: Optional[str] = None
) -> Tuple[List[GlobalMasterTemplate], int]:
"""Get templates by tags with pagination"""
from sqlalchemy import func, or_
# Base query
stmt = select(GlobalMasterTemplate)
count_stmt = select(func.count(GlobalMasterTemplate.id))
if active_only:
stmt = stmt.where(GlobalMasterTemplate.is_active == True)
count_stmt = count_stmt.where(GlobalMasterTemplate.is_active == True)
# Filter by tags (any tag matches)
for tag in tags:
tag_filter = GlobalMasterTemplate.tags.contains([tag])
stmt = stmt.where(tag_filter)
count_stmt = count_stmt.where(tag_filter)
# Add search filter
if search and search.strip():
search_filter = or_(
GlobalMasterTemplate.template_name.ilike(f"%{search}%"),
GlobalMasterTemplate.description.ilike(f"%{search}%")
)
stmt = stmt.where(search_filter)
count_stmt = count_stmt.where(search_filter)
# Order and paginate
stmt = stmt.order_by(GlobalMasterTemplate.usage_count.desc()).offset(offset).limit(limit)
# Execute queries
result = await self.session.execute(stmt)
count_result = await self.session.execute(count_stmt)
templates = result.scalars().all()
total_count = count_result.scalar()
return templates, total_count
async def update_template(self, template_id: int, update_data: Dict[str, Any]) -> bool:
"""Update a global master template"""
update_data['updated_at'] = time.time()
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(**update_data)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def delete_template(self, template_id: int) -> bool:
"""Delete a global master template"""
try:
stmt = delete(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id)
result = await self.session.execute(stmt)
await self.session.commit()
rows_affected = result.rowcount
logger.info(f"Delete operation for template {template_id}: {rows_affected} rows affected")
return rows_affected > 0
except Exception as e:
logger.error(f"Error deleting template {template_id}: {e}")
await self.session.rollback()
raise
async def increment_usage_count(self, template_id: int) -> bool:
"""Increment template usage count"""
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(
usage_count=GlobalMasterTemplate.usage_count + 1,
updated_at=time.time()
)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def set_default_template(self, template_id: int) -> bool:
"""Set a template as default (and unset others)"""
# First, unset all default templates
stmt = update(GlobalMasterTemplate).values(is_default=False, updated_at=time.time())
await self.session.execute(stmt)
# Then set the specified template as default
stmt = update(GlobalMasterTemplate).where(GlobalMasterTemplate.id == template_id).values(
is_default=True,
updated_at=time.time()
)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def get_default_template(self) -> Optional[GlobalMasterTemplate]:
"""Get the default template"""
stmt = select(GlobalMasterTemplate).where(
GlobalMasterTemplate.is_default == True,
GlobalMasterTemplate.is_active == True
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def delete_templates_by_project_id(self, project_id: str) -> bool:
"""Delete all templates for a project"""
stmt = delete(PPTTemplate).where(PPTTemplate.project_id == project_id)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0