This commit is contained in:
2025-11-07 09:05:16 +08:00
parent 9e49737f7c
commit aabe11fdbd

View File

@@ -0,0 +1,348 @@
"""
DALL-E 图片生成提供者
"""
import asyncio
import logging
import time
from typing import Dict, Any, Optional, List
from pathlib import Path
import aiohttp
import json
from .base import ImageGenerationProvider
from ..models import (
ImageInfo, ImageGenerationRequest, ImageOperationResult,
ImageProvider, ImageSourceType, ImageFormat, ImageMetadata, ImageTag
)
logger = logging.getLogger(__name__)
class DalleProvider(ImageGenerationProvider):
    """Image generation provider backed by the DALL-E images API.

    A generation request is POSTed to ``{api_base}/images/generations``; the
    image URL found in the response is downloaded into a local cache
    directory and described by an ``ImageInfo`` record.  A simple in-process
    sliding-window counter limits how many generation requests are issued.
    """

    # Words ignored when deriving tags and keywords from a prompt.
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
        'for', 'of', 'with', 'by',
    })

    def __init__(self, config: Dict[str, Any]):
        """Read provider settings from *config*.

        Recognised keys (all optional): ``api_key``, ``api_base``, ``model``,
        ``default_size``, ``default_quality``, ``default_style``,
        ``rate_limit_requests``, ``rate_limit_window`` (seconds) and
        ``request_timeout`` (seconds, used for generation and download).
        A missing API key only logs a warning here; ``generate`` and
        ``health_check`` report it as an error when called.
        """
        super().__init__(ImageProvider.DALLE, config)

        # API settings.  Trailing slashes are stripped so that endpoint URLs
        # built with f"{api_base}/..." never contain a double slash.
        self.api_key = config.get('api_key')
        self.api_base = (
            config.get('api_base') or 'https://api.openai.com/v1'
        ).rstrip('/')
        self.model = config.get('model', 'dall-e-3')
        self.default_size = config.get('default_size', '1024x1024')
        self.default_quality = config.get('default_quality', 'standard')
        self.default_style = config.get('default_style', 'vivid')
        self.request_timeout = config.get('request_timeout', 120)

        # Rate limiting: at most `rate_limit_requests` per `rate_limit_window`.
        self.rate_limit_requests = config.get('rate_limit_requests', 50)
        self.rate_limit_window = config.get('rate_limit_window', 60)

        # Timestamps (time.time()) of recently accepted requests.
        self._request_history: List[float] = []

        if not self.api_key:
            logger.warning("DALL-E API key not configured")

    async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
        """Generate one image for *request*.

        Never raises: every failure is reported as an
        ``ImageOperationResult`` with ``success=False`` and one of the error
        codes ``api_key_missing``, ``rate_limit_exceeded``, ``api_error``,
        ``timeout`` or ``generation_error`` (or a code produced by
        ``_process_api_response``).
        """
        if not self.api_key:
            return ImageOperationResult(
                success=False,
                message="DALL-E API key not configured",
                error_code="api_key_missing"
            )

        try:
            if not await self._check_rate_limit():
                return ImageOperationResult(
                    success=False,
                    message="Rate limit exceeded",
                    error_code="rate_limit_exceeded"
                )

            api_request = self._prepare_api_request(request)

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.api_base}/images/generations",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json=api_request,
                    timeout=aiohttp.ClientTimeout(total=self.request_timeout)
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        logger.error(
                            "DALL-E API error %s: %s", response.status, error_text
                        )
                        return ImageOperationResult(
                            success=False,
                            message=f"DALL-E API error: {response.status}",
                            error_code="api_error"
                        )
                    result_data = await response.json()

            # The download opens its own session, so the API connection is
            # released before the response is processed.
            return await self._process_api_response(result_data, request)

        except asyncio.TimeoutError:
            logger.error("DALL-E API request timeout")
            return ImageOperationResult(
                success=False,
                message="Request timeout",
                error_code="timeout"
            )
        except Exception as e:
            # Boundary handler: keep the traceback in the log, return a result.
            logger.exception("DALL-E generation failed: %s", e)
            return ImageOperationResult(
                success=False,
                message=f"Generation failed: {str(e)}",
                error_code="generation_error"
            )

    def _resolve_dimensions(self, request: ImageGenerationRequest) -> tuple[int, int]:
        """Return ``(width, height)`` for *request*.

        Uses the request's own dimensions when both are set; otherwise parses
        the configured ``default_size`` (``"<width>x<height>"``).

        Raises:
            ValueError: if the fallback ``default_size`` cannot be parsed.
        """
        if request.width and request.height:
            return request.width, request.height
        width_text, _, height_text = str(self.default_size).lower().partition('x')
        return int(width_text), int(height_text)

    def _prepare_api_request(self, request: ImageGenerationRequest) -> Dict[str, Any]:
        """Build the JSON payload for the image generation endpoint."""
        width, height = self._resolve_dimensions(request)
        api_request = {
            "model": self.model,
            "prompt": request.prompt,
            "n": 1,  # DALL-E 3 only supports generating a single image
            "size": f"{width}x{height}",
            "quality": request.quality or self.default_quality,
            "response_format": "url"
        }
        # The style parameter is only sent for DALL-E 3.
        if self.model == "dall-e-3":
            api_request["style"] = request.style or self.default_style
        return api_request

    async def _process_api_response(self,
                                    response_data: Dict[str, Any],
                                    request: ImageGenerationRequest) -> ImageOperationResult:
        """Turn a decoded API response into an operation result.

        Takes the first entry of ``response_data['data']``, downloads its
        ``url`` and builds the ``ImageInfo``.  Failures are returned with the
        error codes ``no_data``, ``no_url`` or ``response_processing_error``.
        """
        try:
            if 'data' not in response_data or not response_data['data']:
                return ImageOperationResult(
                    success=False,
                    message="No image data in response",
                    error_code="no_data"
                )

            image_data = response_data['data'][0]
            image_url = image_data.get('url')
            # Fall back to the caller's prompt when no revised one is returned.
            revised_prompt = image_data.get('revised_prompt', request.prompt)

            if not image_url:
                return ImageOperationResult(
                    success=False,
                    message="No image URL in response",
                    error_code="no_url"
                )

            image_path, image_size = await self._download_image(image_url, request)
            image_info = self._create_image_info(
                image_path, image_size, request, revised_prompt
            )
            return ImageOperationResult(
                success=True,
                message="Image generated successfully",
                image_info=image_info
            )
        except Exception as e:
            logger.exception("Failed to process DALL-E response: %s", e)
            return ImageOperationResult(
                success=False,
                message=f"Failed to process response: {str(e)}",
                error_code="response_processing_error"
            )

    async def _download_image(self,
                              image_url: str,
                              request: ImageGenerationRequest) -> tuple[Path, int]:
        """Download *image_url* into the local cache directory.

        Returns:
            The path of the saved file and its size in bytes.

        Raises:
            RuntimeError: if the download does not answer with HTTP 200.
        """
        # NOTE: hash() of a str varies between interpreter runs unless
        # PYTHONHASHSEED is fixed, so the numeric suffix is not reproducible.
        timestamp = int(time.time())
        filename = f"dalle_{timestamp}_{hash(request.prompt) % 10000}.png"

        save_dir = Path("temp/images_cache/ai_generated/dalle")
        save_dir.mkdir(parents=True, exist_ok=True)
        image_path = save_dir / filename

        async with aiohttp.ClientSession() as session:
            async with session.get(
                image_url,
                timeout=aiohttp.ClientTimeout(total=self.request_timeout)
            ) as response:
                if response.status != 200:
                    raise RuntimeError(f"Failed to download image: {response.status}")
                image_data = await response.read()

        # Write in a worker thread so disk I/O does not block the event loop.
        await asyncio.to_thread(image_path.write_bytes, image_data)
        return image_path, len(image_data)

    def _create_image_info(self,
                           image_path: Path,
                           image_size: int,
                           request: ImageGenerationRequest,
                           revised_prompt: str) -> ImageInfo:
        """Build the ``ImageInfo`` record for a downloaded image.

        Args:
            image_path: Where the image was saved.
            image_size: Size of the saved file in bytes.
            request: The originating generation request.
            revised_prompt: Prompt text reported back by the API.
        """
        # Reuse the file stem so the ID always matches the saved filename
        # (both follow the "dalle_<timestamp>_<hash>" pattern).
        image_id = image_path.stem

        # Dimensions are taken from the request, not read from the file.
        width, height = self._resolve_dimensions(request)
        metadata = ImageMetadata(
            width=width,
            height=height,
            format=ImageFormat.PNG,
            file_size=image_size,
            color_mode="RGB",
            # RGB has no alpha channel, so the image cannot be transparent.
            has_transparency=False
        )

        prompt = request.prompt
        # Only add an ellipsis when the prompt was actually truncated.
        short_prompt = prompt if len(prompt) <= 50 else f"{prompt[:50]}..."
        now = time.time()

        return ImageInfo(
            image_id=image_id,
            filename=image_path.name,
            title=f"AI Generated: {short_prompt}",
            description=f"Generated by DALL-E with prompt: {revised_prompt}",
            alt_text=prompt,
            source_type=ImageSourceType.AI_GENERATED,
            provider=ImageProvider.DALLE,
            original_url="",
            local_path=str(image_path),
            metadata=metadata,
            tags=self._generate_tags_from_prompt(prompt),
            keywords=self._extract_keywords_from_prompt(prompt),
            usage_count=0,
            created_at=now,
            updated_at=now
        )

    def _significant_words(self, prompt: str) -> List[str]:
        """Lower-case *prompt*, split on whitespace and drop stop words and
        words of two characters or fewer."""
        return [
            word for word in prompt.lower().split()
            if word not in self._STOP_WORDS and len(word) > 2
        ]

    def _generate_tags_from_prompt(self, prompt: str) -> List[ImageTag]:
        """Create up to 10 tags from the leading significant prompt words.

        Confidence starts at 1.0 and drops by 0.1 per position, with a floor
        of 0.5.
        """
        return [
            ImageTag(
                name=keyword,
                confidence=max(0.5, 1.0 - i * 0.1),
                category="ai_generated"
            )
            for i, keyword in enumerate(self._significant_words(prompt)[:10])
        ]

    def _extract_keywords_from_prompt(self, prompt: str) -> List[str]:
        """Return up to 20 significant words from *prompt*."""
        return self._significant_words(prompt)[:20]

    def _prune_request_history(self, now: float) -> None:
        """Drop recorded requests that fall outside the rate-limit window."""
        self._request_history = [
            req_time for req_time in self._request_history
            if now - req_time < self.rate_limit_window
        ]

    async def _check_rate_limit(self) -> bool:
        """Return True and record the request if the window has capacity.

        Returns False, without recording anything, when
        ``rate_limit_requests`` requests were already accepted within the
        last ``rate_limit_window`` seconds.
        """
        now = time.time()
        self._prune_request_history(now)
        if len(self._request_history) >= self.rate_limit_requests:
            return False
        self._request_history.append(now)
        return True

    async def health_check(self) -> Dict[str, Any]:
        """Probe ``{api_base}/models`` and report provider health.

        Returns a dict with ``status`` (``'healthy'`` or ``'unhealthy'``),
        ``message`` and ``provider``; a healthy result also carries ``model``
        and ``rate_limit_remaining``.  Never raises.
        """
        if not self.api_key:
            return {
                'status': 'unhealthy',
                'message': 'API key not configured',
                'provider': self.provider.value
            }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.api_base}/models",
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    if response.status != 200:
                        return {
                            'status': 'unhealthy',
                            'message': f'API error: {response.status}',
                            'provider': self.provider.value
                        }

            # Expire old entries first so the remaining quota is current.
            self._prune_request_history(time.time())
            remaining = self.rate_limit_requests - len(self._request_history)
            return {
                'status': 'healthy',
                'message': 'API accessible',
                'provider': self.provider.value,
                'model': self.model,
                'rate_limit_remaining': max(0, remaining)
            }
        except Exception as e:
            return {
                'status': 'unhealthy',
                'message': f'Health check failed: {str(e)}',
                'provider': self.provider.value
            }

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Return a static description of the models this provider knows.

        The list is hard-coded; no API call is made.
        """
        # NOTE(review): 'supported_styles' is listed for DALL-E 2 although
        # this class only sends the style parameter for "dall-e-3" - confirm.
        return [
            {
                'id': 'dall-e-3',
                'name': 'DALL-E 3',
                'description': 'Latest DALL-E model with improved quality and understanding',
                'max_resolution': '1792x1024',
                'supported_styles': ['natural', 'vivid']
            },
            {
                'id': 'dall-e-2',
                'name': 'DALL-E 2',
                'description': 'Previous generation DALL-E model',
                'max_resolution': '1024x1024',
                'supported_styles': ['natural']
            }
        ]