diff --git a/src/landppt/services/image/providers/base.py b/src/landppt/services/image/providers/base.py new file mode 100644 index 0000000..94a6f23 --- /dev/null +++ b/src/landppt/services/image/providers/base.py @@ -0,0 +1,302 @@ +""" +图片提供者抽象基类 +""" + +from abc import ABC, abstractmethod +from typing import List, Optional, Dict, Any, AsyncGenerator +import logging +import asyncio +from pathlib import Path + +from ..models import ( + ImageInfo, ImageSearchRequest, ImageGenerationRequest, + ImageUploadRequest, ImageSearchResult, ImageOperationResult, + ImageProvider, ImageSourceType +) + +logger = logging.getLogger(__name__) + + +class BaseImageProvider(ABC): + """图片提供者抽象基类""" + + def __init__(self, provider: ImageProvider, config: Dict[str, Any]): + self.provider = provider + self.config = config + self.enabled = config.get('enabled', True) + self.rate_limit = config.get('rate_limit', 60) # 每分钟请求限制 + self.timeout = config.get('timeout', 30) # 请求超时时间 + + # 请求计数器(简单的速率限制) + self._request_count = 0 + self._last_reset = asyncio.get_event_loop().time() + + @property + @abstractmethod + def source_type(self) -> ImageSourceType: + """返回图片来源类型""" + pass + + @property + @abstractmethod + def supported_operations(self) -> List[str]: + """返回支持的操作列表""" + pass + + async def _check_rate_limit(self) -> bool: + """检查速率限制""" + current_time = asyncio.get_event_loop().time() + + # 每分钟重置计数器 + if current_time - self._last_reset > 60: + self._request_count = 0 + self._last_reset = current_time + + if self._request_count >= self.rate_limit: + logger.warning(f"Rate limit exceeded for {self.provider}") + return False + + self._request_count += 1 + return True + + async def _validate_config(self) -> bool: + """验证配置""" + return True + + async def health_check(self) -> Dict[str, Any]: + """健康检查""" + try: + config_valid = await self._validate_config() + rate_limit_ok = await self._check_rate_limit() + + return { + 'provider': self.provider, + 'enabled': self.enabled, + 'config_valid': config_valid, + 'rate_limit_ok': rate_limit_ok, + 'status': 'healthy' if (self.enabled and config_valid) else 'unhealthy' + } + except Exception as e: + logger.error(f"Health check failed for {self.provider}: {e}") + return { + 'provider': self.provider, + 'enabled': False, + 'status': 'error', + 'error': str(e) + } + + +class ImageSearchProvider(BaseImageProvider): + """图片搜索提供者基类""" + + @property + def source_type(self) -> ImageSourceType: + return ImageSourceType.WEB_SEARCH + + @property + def supported_operations(self) -> List[str]: + return ['search', 'get_details', 'download'] + + @abstractmethod + async def search(self, request: ImageSearchRequest) -> ImageSearchResult: + """搜索图片""" + pass + + @abstractmethod + async def get_image_details(self, image_id: str) -> Optional[ImageInfo]: + """获取图片详细信息""" + pass + + @abstractmethod + async def download_image(self, image_info: ImageInfo, save_path: Path) -> ImageOperationResult: + """下载图片到本地""" + pass + + async def batch_search(self, requests: List[ImageSearchRequest]) -> List[ImageSearchResult]: + """批量搜索图片""" + results = [] + for request in requests: + try: + if not await self._check_rate_limit(): + break + result = await self.search(request) + results.append(result) + # 添加延迟避免过快请求 + await asyncio.sleep(0.1) + except Exception as e: + logger.error(f"Batch search failed for request {request.query}: {e}") + # 创建空结果 + results.append(ImageSearchResult( + images=[], total_count=0, page=request.page, + per_page=request.per_page, has_next=False, has_prev=False, + search_time=0.0 + )) + return results + + +class ImageGenerationProvider(BaseImageProvider): + """AI图片生成提供者基类""" + + @property + def source_type(self) -> ImageSourceType: + return ImageSourceType.AI_GENERATED + + @property + def supported_operations(self) -> List[str]: + return ['generate', 'get_models', 'get_styles'] + + @abstractmethod + async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult: + """生成图片""" + pass + + @abstractmethod + async def get_available_models(self) -> List[Dict[str, Any]]: + """获取可用模型列表""" + pass + + async def get_available_styles(self) -> List[Dict[str, Any]]: + """获取可用样式列表""" + return [] + + async def batch_generate(self, requests: List[ImageGenerationRequest]) -> List[ImageOperationResult]: + """批量生成图片""" + results = [] + for request in requests: + try: + if not await self._check_rate_limit(): + break + result = await self.generate(request) + results.append(result) + # AI生成通常需要更长时间,添加更长延迟 + await asyncio.sleep(1.0) + except Exception as e: + logger.error(f"Batch generation failed for prompt {request.prompt}: {e}") + results.append(ImageOperationResult( + success=False, + message=f"Generation failed: {str(e)}", + error_code="generation_error" + )) + return results + + +class LocalStorageProvider(BaseImageProvider): + """本地存储提供者基类""" + + @property + def source_type(self) -> ImageSourceType: + return ImageSourceType.LOCAL_STORAGE + + @property + def supported_operations(self) -> List[str]: + return ['upload', 'list', 'get', 'delete', 'update'] + + @abstractmethod + async def upload(self, request: ImageUploadRequest, file_data: bytes) -> ImageOperationResult: + """上传图片""" + pass + + @abstractmethod + async def list_images(self, + page: int = 1, + per_page: int = 20, + category: Optional[str] = None, + tags: Optional[List[str]] = None) -> ImageSearchResult: + """列出图片""" + pass + + @abstractmethod + async def get_image(self, image_id: str) -> Optional[ImageInfo]: + """获取图片信息""" + pass + + @abstractmethod + async def delete_image(self, image_id: str) -> ImageOperationResult: + """删除图片""" + pass + + @abstractmethod + async def update_image(self, image_id: str, updates: Dict[str, Any]) -> ImageOperationResult: + """更新图片信息""" + pass + + async def search_local(self, query: str, **kwargs) -> ImageSearchResult: + """在本地图片中搜索""" + # 默认实现:通过标签和关键词搜索 + tags = kwargs.get('tags', []) + if query: + tags.append(query) + + return await self.list_images( + page=kwargs.get('page', 1), + per_page=kwargs.get('per_page', 20), + tags=tags + ) + + +class ProviderRegistry: + """提供者注册表""" + + def __init__(self): + self._providers: Dict[ImageProvider, BaseImageProvider] = {} + self._search_providers: List[ImageSearchProvider] = [] + self._generation_providers: List[ImageGenerationProvider] = [] + self._storage_providers: List[LocalStorageProvider] = [] + + def register(self, provider: BaseImageProvider): + """注册提供者""" + # 检查是否已经注册了相同的提供者 + if provider.provider in self._providers: + logger.debug(f"Provider {provider.provider} already registered, skipping") + return + + self._providers[provider.provider] = provider + + if isinstance(provider, ImageSearchProvider): + # 检查是否已经存在相同类型的搜索提供者 + existing = [p for p in self._search_providers if p.provider == provider.provider] + if not existing: + self._search_providers.append(provider) + elif isinstance(provider, ImageGenerationProvider): + # 检查是否已经存在相同类型的生成提供者 + existing = [p for p in self._generation_providers if p.provider == provider.provider] + if not existing: + self._generation_providers.append(provider) + elif isinstance(provider, LocalStorageProvider): + # 检查是否已经存在相同类型的存储提供者 + existing = [p for p in self._storage_providers if p.provider == provider.provider] + if not existing: + self._storage_providers.append(provider) + + def get_provider(self, provider: ImageProvider) -> Optional[BaseImageProvider]: + """获取指定提供者""" + return self._providers.get(provider) + + def get_search_providers(self, enabled_only: bool = True) -> List[ImageSearchProvider]: + """获取搜索提供者""" + if enabled_only: + return [p for p in self._search_providers if p.enabled] + return self._search_providers + + def get_generation_providers(self, enabled_only: bool = True) -> List[ImageGenerationProvider]: + """获取生成提供者""" + if enabled_only: + return [p for p in self._generation_providers if p.enabled] + return self._generation_providers + + def get_storage_providers(self, enabled_only: bool = True) -> List[LocalStorageProvider]: + """获取存储提供者""" + if enabled_only: + return [p for p in self._storage_providers if p.enabled] + return self._storage_providers + + async def health_check_all(self) -> Dict[str, Any]: + """检查所有提供者健康状态""" + results = {} + for provider_name, provider in self._providers.items(): + results[provider_name] = await provider.health_check() + return results + + +# 全局提供者注册表 +provider_registry = ProviderRegistry()