# Provenance: reconstructed from a git patch whose line breaks had been lost.
#   Commit:  aabe11fdbd079063e4f0cc468906db1d1a8a3c29
#   From:    13315423919 <13315423919@qq.com>
#   Date:    Fri, 7 Nov 2025 09:05:16 +0800
#   Subject: [PATCH] Add File
#   Target:  src/landppt/services/image/providers/dalle_provider.py (new file)
"""
DALL-E image generation provider.
"""

import asyncio
import json
import logging
import time
import uuid
from pathlib import Path
from typing import Dict, Any, Optional, List

import aiohttp

from .base import ImageGenerationProvider
from ..models import (
    ImageInfo, ImageGenerationRequest, ImageOperationResult,
    ImageProvider, ImageSourceType, ImageFormat, ImageMetadata, ImageTag
)

logger = logging.getLogger(__name__)

# Endpoint used when the configuration does not provide ``api_base``.
_DEFAULT_API_BASE = 'https://api.openai.com/v1'

# Directory (relative to the working directory) where downloads are stored.
_IMAGE_SAVE_DIR = Path("temp/images_cache/ai_generated/dalle")

# Timeouts, in seconds, for the three kinds of HTTP calls made here.
_GENERATION_TIMEOUT_SECONDS = 120
_DOWNLOAD_TIMEOUT_SECONDS = 60
_HEALTH_CHECK_TIMEOUT_SECONDS = 10

# Words ignored when deriving tags/keywords from a prompt.
_STOP_WORDS = frozenset({
    'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
    'of', 'with', 'by',
})

# Maximum number of prompt characters shown in the generated title.
_TITLE_PROMPT_CHARS = 50


class DalleProvider(ImageGenerationProvider):
    """Image generation provider backed by the DALL-E HTTP API.

    Configuration keys read in ``__init__``: ``api_key``, ``api_base``,
    ``model``, ``default_size``, ``default_quality``, ``default_style``,
    ``rate_limit_requests`` and ``rate_limit_window``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read the provider settings from *config*.

        A missing API key is only logged as a warning here; ``generate`` and
        ``health_check`` report it as a failure when they are called.
        """
        super().__init__(ImageProvider.DALLE, config)

        # API settings. ``api_base`` is normalised so that a missing/None
        # value falls back to the default and a trailing slash cannot produce
        # "//" when endpoint paths are appended.
        self.api_key = config.get('api_key')
        self.api_base = (config.get('api_base') or _DEFAULT_API_BASE).rstrip('/')
        self.model = config.get('model', 'dall-e-3')
        self.default_size = config.get('default_size', '1024x1024')
        self.default_quality = config.get('default_quality', 'standard')
        self.default_style = config.get('default_style', 'vivid')

        # Rate limiting: at most ``rate_limit_requests`` calls per
        # ``rate_limit_window`` seconds.
        self.rate_limit_requests = config.get('rate_limit_requests', 50)
        self.rate_limit_window = config.get('rate_limit_window', 60)

        # Timestamps (``time.time()``) of recently accepted requests.
        self._request_history: List[float] = []

        if not self.api_key:
            logger.warning("DALL-E API key not configured")

    async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
        """Generate one image for *request* and store it locally.

        Returns:
            An ``ImageOperationResult``. Failures are reported through
            ``success=False`` and an ``error_code`` (``api_key_missing``,
            ``rate_limit_exceeded``, ``api_error``, ``timeout``,
            ``generation_error`` or a code from response processing); no
            ``Exception`` subclass propagates to the caller.
        """
        if not self.api_key:
            return ImageOperationResult(
                success=False,
                message="DALL-E API key not configured",
                error_code="api_key_missing"
            )

        try:
            if not await self._check_rate_limit():
                return ImageOperationResult(
                    success=False,
                    message="Rate limit exceeded",
                    error_code="rate_limit_exceeded"
                )

            api_request = self._prepare_api_request(request)

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.api_base}/images/generations",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json=api_request,
                    timeout=aiohttp.ClientTimeout(total=_GENERATION_TIMEOUT_SECONDS)
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        logger.error(
                            "DALL-E API error %s: %s", response.status, error_text
                        )
                        return ImageOperationResult(
                            success=False,
                            message=f"DALL-E API error: {response.status}",
                            error_code="api_error"
                        )

                    result_data = await response.json()

            # The payload is fully read, so the API connection is released
            # before the (potentially slow) image download starts.
            return await self._process_api_response(result_data, request)

        except asyncio.TimeoutError:
            logger.error("DALL-E API request timeout")
            return ImageOperationResult(
                success=False,
                message="Request timeout",
                error_code="timeout"
            )
        except Exception as e:
            # Provider boundary: report the failure as a result object, but
            # keep the traceback in the log.
            logger.exception("DALL-E generation failed: %s", e)
            return ImageOperationResult(
                success=False,
                message=f"Generation failed: {str(e)}",
                error_code="generation_error"
            )

    def _resolve_dimensions(self, request: ImageGenerationRequest) -> tuple[int, int]:
        """Return ``(width, height)`` for *request*.

        The request's own dimensions are used when both are set; otherwise
        the configured ``default_size`` ("<width>x<height>") is parsed.

        Raises:
            ValueError: if ``default_size`` is needed but is not of the form
                "<int>x<int>".
        """
        if request.width and request.height:
            return request.width, request.height

        width_text, _, height_text = str(self.default_size).lower().partition('x')
        return int(width_text), int(height_text)

    def _prepare_api_request(self, request: ImageGenerationRequest) -> Dict[str, Any]:
        """Build the JSON payload for the image generation endpoint."""
        # The API takes the dimensions as a single "<width>x<height>" string.
        width, height = self._resolve_dimensions(request)

        api_request = {
            "model": self.model,
            "prompt": request.prompt,
            "n": 1,  # DALL-E 3 only supports generating one image per call
            "size": f"{width}x{height}",
            "quality": request.quality or self.default_quality,
            "response_format": "url"
        }

        # The style parameter is only sent for DALL-E 3.
        if self.model == "dall-e-3":
            api_request["style"] = request.style or self.default_style

        return api_request

    async def _process_api_response(self,
                                    response_data: Dict[str, Any],
                                    request: ImageGenerationRequest) -> ImageOperationResult:
        """Turn the API payload into a result, downloading the image.

        Returns a failed result with ``no_data`` / ``no_url`` when the payload
        lacks the image entry or its URL, and ``response_processing_error``
        when downloading or building the image info raises.
        """
        try:
            entries = response_data.get('data')
            if not entries:
                return ImageOperationResult(
                    success=False,
                    message="No image data in response",
                    error_code="no_data"
                )

            image_data = entries[0]
            image_url = image_data.get('url')
            # Fall back to the caller's prompt when the API did not send a
            # revised one.
            revised_prompt = image_data.get('revised_prompt', request.prompt)

            if not image_url:
                return ImageOperationResult(
                    success=False,
                    message="No image URL in response",
                    error_code="no_url"
                )

            image_path, image_size = await self._download_image(image_url, request)

            image_info = self._create_image_info(
                image_path, image_size, request, revised_prompt
            )

            return ImageOperationResult(
                success=True,
                message="Image generated successfully",
                image_info=image_info
            )

        except Exception as e:
            logger.exception("Failed to process DALL-E response: %s", e)
            return ImageOperationResult(
                success=False,
                message=f"Failed to process response: {str(e)}",
                error_code="response_processing_error"
            )

    @staticmethod
    def _write_image_file(image_path: Path, image_data: bytes) -> None:
        """Create the parent directory if needed and write *image_data*."""
        image_path.parent.mkdir(parents=True, exist_ok=True)
        image_path.write_bytes(image_data)

    async def _download_image(self,
                              image_url: str,
                              request: ImageGenerationRequest) -> tuple[Path, int]:
        """Download the generated image and save it as a PNG file.

        Returns:
            ``(path, size_in_bytes)`` of the saved file.

        Raises:
            Exception: if the download does not answer with HTTP 200.
        """
        # A random suffix keeps file names unique even when the same prompt
        # is generated several times within one second.
        filename = f"dalle_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
        image_path = _IMAGE_SAVE_DIR / filename

        async with aiohttp.ClientSession() as session:
            async with session.get(
                image_url,
                timeout=aiohttp.ClientTimeout(total=_DOWNLOAD_TIMEOUT_SECONDS)
            ) as response:
                if response.status != 200:
                    raise Exception(f"Failed to download image: {response.status}")

                image_data = await response.read()

        # File I/O is blocking; run it in a worker thread so the event loop
        # stays responsive.
        await asyncio.to_thread(self._write_image_file, image_path, image_data)

        return image_path, len(image_data)

    def _create_image_info(self,
                           image_path: Path,
                           image_size: int,
                           request: ImageGenerationRequest,
                           revised_prompt: str) -> ImageInfo:
        """Build the ``ImageInfo`` record for a downloaded image."""
        prompt = request.prompt

        # The id is the file name without extension, so the two always match.
        image_id = image_path.stem

        # Dimensions are taken from the request (or the configured default),
        # not read from the downloaded file.
        width, height = self._resolve_dimensions(request)

        # NOTE(review): color_mode "RGB" together with has_transparency=True
        # looks inconsistent — confirm against actual DALL-E output.
        metadata = ImageMetadata(
            width=width,
            height=height,
            format=ImageFormat.PNG,
            file_size=image_size,
            color_mode="RGB",
            has_transparency=True
        )

        # Only mark the title as truncated when the prompt really was cut.
        ellipsis = "..." if len(prompt) > _TITLE_PROMPT_CHARS else ""
        title = f"AI Generated: {prompt[:_TITLE_PROMPT_CHARS]}{ellipsis}"

        # One clock read so created_at and updated_at are identical.
        now = time.time()

        return ImageInfo(
            image_id=image_id,
            filename=image_path.name,
            title=title,
            description=f"Generated by DALL-E with prompt: {revised_prompt}",
            alt_text=prompt,
            source_type=ImageSourceType.AI_GENERATED,
            provider=ImageProvider.DALLE,
            original_url="",
            local_path=str(image_path),
            metadata=metadata,
            tags=self._generate_tags_from_prompt(prompt),
            keywords=self._extract_keywords_from_prompt(prompt),
            usage_count=0,
            created_at=now,
            updated_at=now
        )

    @staticmethod
    def _significant_words(prompt: str) -> List[str]:
        """Lower-case *prompt*, split on whitespace and drop stop words and
        words of two characters or fewer."""
        return [
            word for word in prompt.lower().split()
            if word not in _STOP_WORDS and len(word) > 2
        ]

    def _generate_tags_from_prompt(self, prompt: str) -> List[ImageTag]:
        """Create up to 10 tags from the significant words of *prompt*.

        Confidence starts at 1.0 for the first word and drops by 0.1 per
        position, never going below 0.5.
        """
        return [
            ImageTag(
                name=keyword,
                confidence=max(0.5, 1.0 - position * 0.1),
                category="ai_generated"
            )
            for position, keyword in enumerate(self._significant_words(prompt)[:10])
        ]

    def _extract_keywords_from_prompt(self, prompt: str) -> List[str]:
        """Return up to 20 significant words of *prompt*, in prompt order."""
        return self._significant_words(prompt)[:20]

    def _prune_request_history(self, current_time: float) -> None:
        """Drop request timestamps that fell out of the rate-limit window."""
        self._request_history = [
            req_time for req_time in self._request_history
            if current_time - req_time < self.rate_limit_window
        ]

    async def _check_rate_limit(self) -> bool:
        """Reserve a slot in the rate-limit window.

        Returns:
            ``True`` (and records the current time) when fewer than
            ``rate_limit_requests`` requests lie within the last
            ``rate_limit_window`` seconds, ``False`` otherwise.
        """
        current_time = time.time()
        self._prune_request_history(current_time)

        if len(self._request_history) >= self.rate_limit_requests:
            return False

        self._request_history.append(current_time)
        return True

    async def health_check(self) -> Dict[str, Any]:
        """Check that the API is reachable with the configured key.

        Returns:
            A dict with ``status`` ('healthy' / 'unhealthy'), ``message`` and
            ``provider``; healthy results also carry ``model`` and
            ``rate_limit_remaining``.
        """
        if not self.api_key:
            return {
                'status': 'unhealthy',
                'message': 'API key not configured',
                'provider': self.provider.value
            }

        try:
            # Simple connectivity probe against the model listing endpoint.
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.api_base}/models",
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    timeout=aiohttp.ClientTimeout(total=_HEALTH_CHECK_TIMEOUT_SECONDS)
                ) as response:
                    if response.status != 200:
                        return {
                            'status': 'unhealthy',
                            'message': f'API error: {response.status}',
                            'provider': self.provider.value
                        }

            # Expired entries must not count against the remaining budget.
            self._prune_request_history(time.time())
            remaining = max(0, self.rate_limit_requests - len(self._request_history))

            return {
                'status': 'healthy',
                'message': 'API accessible',
                'provider': self.provider.value,
                'model': self.model,
                'rate_limit_remaining': remaining
            }

        except Exception as e:
            return {
                'status': 'unhealthy',
                'message': f'Health check failed: {str(e)}',
                'provider': self.provider.value
            }

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Return the static list of models this provider knows about."""
        return [
            {
                'id': 'dall-e-3',
                'name': 'DALL-E 3',
                'description': 'Latest DALL-E model with improved quality and understanding',
                'max_resolution': '1792x1024',
                'supported_styles': ['natural', 'vivid']
            },
            {
                'id': 'dall-e-2',
                'name': 'DALL-E 2',
                'description': 'Previous generation DALL-E model',
                'max_resolution': '1024x1024',
                'supported_styles': ['natural']
            }
        ]