"""Pollinations AI image generation provider."""

import asyncio
import logging
import time
import urllib.parse
import tempfile
import uuid
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path

import aiohttp

from ..models import (
    ImageProvider, ImageGenerationRequest, ImageOperationResult,
    ImageInfo, ImageFormat, ImageLicense, ImageSourceType, ImageMetadata, ImageTag
)
from .base import ImageGenerationProvider

logger = logging.getLogger(__name__)

# Parameter values the Pollinations API uses when the parameter is omitted
# ("flux" is the default model, 1024x1024 the default size). Matching values
# are left out of the query string.
_API_DEFAULT_MODEL = 'flux'
_API_DEFAULT_SIZE = 1024


class PollinationsProvider(ImageGenerationProvider):
    """Image generation provider backed by the Pollinations AI HTTP API."""

    def __init__(self, config: Dict[str, Any]):
        """Create the provider from a configuration mapping.

        Args:
            config: Provider configuration. Recognised keys: ``api_base``,
                ``api_token``, ``referrer``, ``model``, ``default_width``,
                ``default_height``, ``default_enhance``, ``default_safe``,
                ``default_nologo``, ``default_private``,
                ``default_transparent``, ``rate_limit_requests`` and
                ``rate_limit_window``. Missing keys fall back to the defaults
                below.
        """
        super().__init__(ImageProvider.POLLINATIONS, config)

        # API configuration. A trailing slash on api_base is stripped so the
        # URL joins below never produce "//prompt".
        self.api_base = (config.get('api_base') or 'https://image.pollinations.ai').rstrip('/')
        self.api_token = config.get('api_token', '')
        self.referrer = config.get('referrer', '')
        self.model = config.get('model', 'flux')
        self.default_width = config.get('default_width', 1024)
        self.default_height = config.get('default_height', 1024)
        self.default_enhance = config.get('default_enhance', False)
        self.default_safe = config.get('default_safe', False)
        self.default_nologo = config.get('default_nologo', False)
        self.default_private = config.get('default_private', False)
        self.default_transparent = config.get('default_transparent', False)

        # Rate limiting: at most rate_limit_requests requests within a sliding
        # window of rate_limit_window seconds.
        self.rate_limit_requests = config.get('rate_limit_requests', 60)
        self.rate_limit_window = config.get('rate_limit_window', 60)

        # Timestamps of recent requests, used by _check_rate_limit.
        self._request_history = []

        logger.debug(f"Pollinations provider initialized with model: {self.model}")

    # ------------------------------------------------------------------ helpers

    def _auth_headers(self) -> Dict[str, str]:
        """Return the HTTP headers carrying the bearer token, if one is configured."""
        if self.api_token:
            return {'Authorization': f'Bearer {self.api_token}'}
        return {}

    def _is_authenticated(self) -> bool:
        """Return True when a token or a referrer is configured (needed for nologo)."""
        return bool(self.api_token or self.referrer)

    def _resolve_size(self, request: ImageGenerationRequest) -> Tuple[int, int]:
        """Return the (width, height) to request, falling back to configured defaults."""
        return (request.width or self.default_width,
                request.height or self.default_height)

    @staticmethod
    def _save_image(filename: str, image_data: bytes) -> Path:
        """Write image bytes to <tempdir>/pollinations_images/<filename> (blocking)."""
        temp_dir = Path(tempfile.gettempdir()) / "pollinations_images"
        temp_dir.mkdir(exist_ok=True)
        temp_file_path = temp_dir / filename
        temp_file_path.write_bytes(image_data)
        return temp_file_path

    # --------------------------------------------------------------- public API

    async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
        """Generate an image for ``request`` via the Pollinations API.

        The downloaded image is written to ``<tempdir>/pollinations_images``.
        Its path is reported in the returned ``ImageInfo.local_path``.

        Args:
            request: Generation request. ``prompt`` is required; ``width``,
                ``height``, ``style`` and (if present) ``seed`` are honoured.

        Returns:
            An ``ImageOperationResult``. Failures (empty prompt, rate limit,
            non-200 response, timeout, any other exception) are reported with
            ``success=False`` rather than raised.
        """
        try:
            # Reject empty prompts before they consume a rate-limit slot.
            if not request.prompt or not request.prompt.strip():
                return ImageOperationResult(
                    success=False,
                    message="Prompt must not be empty."
                )

            if not self._check_rate_limit():
                return ImageOperationResult(
                    success=False,
                    message="Rate limit exceeded. Please try again later."
                )

            # Record the request. There is no await between the check above and
            # this append, so concurrent coroutines cannot slip in between.
            self._request_history.append(time.time())

            api_url = self._build_api_url(request)
            logger.debug(f"Generating image with Pollinations API: {api_url}")

            # NOTE(review): self.timeout is presumably provided by
            # ImageGenerationProvider - confirm in .base.
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
                async with session.get(api_url, headers=self._auth_headers()) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        logger.error(f"Pollinations API error {response.status}: {error_text}")
                        return ImageOperationResult(
                            success=False,
                            message=f"API request failed with status {response.status}: {error_text}"
                        )
                    image_data = await response.read()

            return await self._build_success_result(request, api_url, image_data)

        except asyncio.TimeoutError:
            logger.error("Pollinations API request timeout")
            return ImageOperationResult(
                success=False,
                message="Request timeout. Please try again."
            )
        except Exception as e:
            # Top-level boundary: keep the traceback in the log.
            logger.exception(f"Pollinations generation error: {str(e)}")
            return ImageOperationResult(
                success=False,
                message=f"Generation failed: {str(e)}"
            )

    async def _build_success_result(self, request: ImageGenerationRequest,
                                    api_url: str, image_data: bytes) -> ImageOperationResult:
        """Persist downloaded image bytes and wrap them in a successful result."""
        # A second-resolution timestamp alone collides when several images are
        # generated within the same second. The later file would then overwrite
        # the earlier one, so a random suffix keeps ids and filenames unique.
        image_id = f"pollinations_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        filename = f"{image_id}.png"

        width, height = self._resolve_size(request)
        # NOTE(review): width/height are the requested values and the format
        # is assumed to be PNG. Neither the image bytes nor the response
        # Content-Type are inspected - confirm against the API.
        metadata = ImageMetadata(
            width=width,
            height=height,
            format=ImageFormat.PNG,
            file_size=len(image_data)
        )

        tags = [
            ImageTag(name="ai-generated", category="type", source="system"),
            ImageTag(name="pollinations", category="provider", source="system"),
            ImageTag(name=self.model, category="model", source="system")
        ]

        # Disk I/O is offloaded to the default executor so a multi-megabyte
        # write does not block the event loop.
        loop = asyncio.get_running_loop()
        temp_file_path = await loop.run_in_executor(None, self._save_image, filename, image_data)

        prompt = request.prompt
        short_prompt = prompt if len(prompt) <= 50 else f"{prompt[:50]}..."

        image_info = ImageInfo(
            image_id=image_id,
            source_type=ImageSourceType.AI_GENERATED,
            provider=self.provider,
            original_url=api_url,
            local_path=str(temp_file_path),
            filename=filename,
            title=f"Generated: {short_prompt}",
            description=f"Generated by Pollinations AI using prompt: {prompt}",
            metadata=metadata,
            tags=tags,
            license=ImageLicense.CUSTOM
        )

        return ImageOperationResult(
            success=True,
            message="Image generated successfully",
            image_info=image_info
        )

    def _build_api_url(self, request: ImageGenerationRequest) -> str:
        """Build the GET URL for ``request``.

        The prompt is percent-encoded into the path. Only parameters that
        differ from the API defaults, or that are enabled in the configuration,
        are appended as query parameters.
        """
        encoded_prompt = urllib.parse.quote(request.prompt, safe='')
        url = f"{self.api_base}/prompt/{encoded_prompt}"

        params: List[Tuple[str, Any]] = []

        if self.model != _API_DEFAULT_MODEL:
            params.append(('model', self.model))

        width, height = self._resolve_size(request)
        if width != _API_DEFAULT_SIZE or height != _API_DEFAULT_SIZE:
            params.append(('width', width))
            params.append(('height', height))

        # Explicit None check: a seed of 0 is a valid, reproducible seed and
        # must not be dropped by a truthiness test.
        seed = getattr(request, 'seed', None)
        if seed is not None:
            params.append(('seed', seed))

        # An explicit style decides enhancement; otherwise use the configured default.
        enhance = request.style == "enhanced" if request.style else self.default_enhance
        if enhance:
            params.append(('enhance', 'true'))

        if self.default_safe:
            params.append(('safe', 'true'))

        # nologo is only honoured for authenticated (token or referrer) callers.
        if self.default_nologo and self._is_authenticated():
            params.append(('nologo', 'true'))

        if self.default_private:
            params.append(('private', 'true'))

        # Transparent backgrounds are only supported by the gptimage model.
        if self.default_transparent and self.model == 'gptimage':
            params.append(('transparent', 'true'))

        # The referrer identifier doubles as lightweight authentication.
        if self.referrer:
            params.append(('referrer', self.referrer))

        if params:
            # quote (not quote_plus) with safe='/' matches the previous
            # hand-rolled encoding of the referrer. It also escapes reserved
            # characters in every other value (e.g. the model name).
            url += "?" + urllib.parse.urlencode(params, safe='/', quote_via=urllib.parse.quote)

        return url

    def _check_rate_limit(self) -> bool:
        """Drop expired timestamps and report whether another request is allowed."""
        current_time = time.time()

        # Keep only requests that are still inside the sliding window.
        self._request_history = [
            req_time for req_time in self._request_history
            if current_time - req_time < self.rate_limit_window
        ]

        return len(self._request_history) < self.rate_limit_requests

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Return the static list of models this provider advertises."""
        return [
            {
                "id": "flux",
                "name": "Flux",
                "description": "High-quality image generation model (default)",
                "default": True
            },
            {
                "id": "turbo",
                "name": "Turbo",
                "description": "Fast image generation with good quality"
            },
            {
                "id": "gptimage",
                "name": "GPT Image",
                "description": "GPT-based image generation with transparency support"
            }
        ]

    async def get_available_styles(self) -> List[Dict[str, Any]]:
        """Return the static list of styles; "enhanced" turns on prompt enhancement."""
        return [
            {
                "id": "natural",
                "name": "Natural",
                "description": "Natural style generation"
            },
            {
                "id": "enhanced",
                "name": "Enhanced",
                "description": "Enhanced prompt with more detail"
            }
        ]

    async def health_check(self) -> Dict[str, Any]:
        """Probe the API with a small HEAD request and report its status.

        Returns:
            A dict with ``status`` ("healthy"/"unhealthy"), ``message`` and
            ``model``. Healthy results also include authentication, nologo,
            referrer and rate-limit details. Exceptions are logged and
            reported as unhealthy, never raised.
        """
        try:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session:
                # A tiny test prompt keeps the probe cheap.
                test_url = f"{self.api_base}/prompt/test?width=64&height=64"
                async with session.head(test_url, headers=self._auth_headers()) as response:
                    status = response.status
        except Exception as e:
            logger.warning(f"Pollinations health check failed: {e}")
            return {
                "status": "unhealthy",
                "message": f"Health check failed: {str(e)}",
                "model": self.model
            }

        # A 404 still proves the API host is reachable.
        if status in (200, 404):
            authenticated = self._is_authenticated()
            return {
                "status": "healthy",
                "message": "Pollinations API is accessible",
                "model": self.model,
                "authenticated": authenticated,
                "nologo_enabled": self.default_nologo and authenticated,
                "referrer": self.referrer or None,
                "rate_limit": f"{self.rate_limit_requests}/{self.rate_limit_window}s"
            }

        return {
            "status": "unhealthy",
            "message": f"API returned status {status}",
            "model": self.model
        }