This commit is contained in:
2025-11-07 09:05:16 +08:00
parent 49894e4913
commit 9e49737f7c

View File

@@ -0,0 +1,288 @@
"""
Pollinations AI图片生成提供者
"""
import asyncio
import logging
import time
import urllib.parse
import tempfile
from typing import Dict, Any, List, Optional
from pathlib import Path
import aiohttp
from ..models import (
ImageProvider, ImageGenerationRequest, ImageOperationResult,
ImageInfo, ImageFormat, ImageLicense, ImageSourceType, ImageMetadata, ImageTag
)
from .base import ImageGenerationProvider
logger = logging.getLogger(__name__)
class PollinationsProvider(ImageGenerationProvider):
"""Pollinations AI图片生成提供者"""
def __init__(self, config: Dict[str, Any]):
super().__init__(ImageProvider.POLLINATIONS, config)
# API配置
self.api_base = config.get('api_base', 'https://image.pollinations.ai')
self.api_token = config.get('api_token', '')
self.referrer = config.get('referrer', '')
self.model = config.get('model', 'flux')
self.default_width = config.get('default_width', 1024)
self.default_height = config.get('default_height', 1024)
self.default_enhance = config.get('default_enhance', False)
self.default_safe = config.get('default_safe', False)
self.default_nologo = config.get('default_nologo', False)
self.default_private = config.get('default_private', False)
self.default_transparent = config.get('default_transparent', False)
# 速率限制
self.rate_limit_requests = config.get('rate_limit_requests', 60)
self.rate_limit_window = config.get('rate_limit_window', 60)
# 请求历史(用于速率限制)
self._request_history = []
logger.debug(f"Pollinations provider initialized with model: {self.model}")
async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
"""生成图片"""
try:
# 检查速率限制
if not self._check_rate_limit():
return ImageOperationResult(
success=False,
message="Rate limit exceeded. Please try again later."
)
# 记录请求时间
self._request_history.append(time.time())
# 准备API请求
api_url = self._build_api_url(request)
logger.debug(f"Generating image with Pollinations API: {api_url}")
# 准备请求头
headers = {}
if self.api_token:
headers['Authorization'] = f'Bearer {self.api_token}'
# 发送请求
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
async with session.get(api_url, headers=headers) as response:
if response.status == 200:
# 读取图片数据
image_data = await response.read()
# 创建图片信息
image_id = f"pollinations_{int(time.time())}"
filename = f"{image_id}.png"
# 创建元数据
metadata = ImageMetadata(
width=request.width or self.default_width,
height=request.height or self.default_height,
format=ImageFormat.PNG,
file_size=len(image_data)
)
# 创建标签
tags = [
ImageTag(name="ai-generated", category="type", source="system"),
ImageTag(name="pollinations", category="provider", source="system"),
ImageTag(name=self.model, category="model", source="system")
]
# 保存图片到临时文件
temp_dir = Path(tempfile.gettempdir()) / "pollinations_images"
temp_dir.mkdir(exist_ok=True)
temp_file_path = temp_dir / filename
with open(temp_file_path, 'wb') as f:
f.write(image_data)
image_info = ImageInfo(
image_id=image_id,
source_type=ImageSourceType.AI_GENERATED,
provider=self.provider,
original_url=api_url,
local_path=str(temp_file_path),
filename=filename,
title=f"Generated: {request.prompt[:50]}...",
description=f"Generated by Pollinations AI using prompt: {request.prompt}",
metadata=metadata,
tags=tags,
license=ImageLicense.CUSTOM
)
return ImageOperationResult(
success=True,
message="Image generated successfully",
image_info=image_info
)
else:
error_text = await response.text()
logger.error(f"Pollinations API error {response.status}: {error_text}")
return ImageOperationResult(
success=False,
message=f"API request failed with status {response.status}: {error_text}"
)
except asyncio.TimeoutError:
logger.error("Pollinations API request timeout")
return ImageOperationResult(
success=False,
message="Request timeout. Please try again."
)
except Exception as e:
logger.error(f"Pollinations generation error: {str(e)}")
return ImageOperationResult(
success=False,
message=f"Generation failed: {str(e)}"
)
def _build_api_url(self, request: ImageGenerationRequest) -> str:
"""构建API请求URL"""
# URL编码提示词
encoded_prompt = urllib.parse.quote(request.prompt, safe='')
# 构建基础URL
url = f"{self.api_base}/prompt/{encoded_prompt}"
# 添加参数
params = []
# 模型参数
if self.model != 'flux': # flux是默认模型
params.append(f"model={self.model}")
# 尺寸参数
width = request.width or self.default_width
height = request.height or self.default_height
if width != 1024 or height != 1024: # 1024x1024是默认尺寸
params.append(f"width={width}")
params.append(f"height={height}")
# 种子参数(如果提供)
if hasattr(request, 'seed') and request.seed:
params.append(f"seed={request.seed}")
# 增强参数
enhance = request.style == "enhanced" if request.style else self.default_enhance
if enhance:
params.append("enhance=true")
# 安全过滤
if self.default_safe:
params.append("safe=true")
# 无logo需要token或referrer认证
if self.default_nologo and (self.api_token or self.referrer):
params.append("nologo=true")
# 私有模式
if self.default_private:
params.append("private=true")
# 透明背景仅gptimage模型支持
if self.default_transparent and self.model == 'gptimage':
params.append("transparent=true")
# 推荐人标识符(用于认证)
if self.referrer:
params.append(f"referrer={urllib.parse.quote(self.referrer)}")
# 添加参数到URL
if params:
url += "?" + "&".join(params)
return url
def _check_rate_limit(self) -> bool:
"""检查速率限制"""
current_time = time.time()
# 清理过期的请求记录
self._request_history = [
req_time for req_time in self._request_history
if current_time - req_time < self.rate_limit_window
]
# 检查是否超过限制
return len(self._request_history) < self.rate_limit_requests
async def get_available_models(self) -> List[Dict[str, Any]]:
"""获取可用模型列表"""
return [
{
"id": "flux",
"name": "Flux",
"description": "High-quality image generation model (default)",
"default": True
},
{
"id": "turbo",
"name": "Turbo",
"description": "Fast image generation with good quality"
},
{
"id": "gptimage",
"name": "GPT Image",
"description": "GPT-based image generation with transparency support"
}
]
async def get_available_styles(self) -> List[Dict[str, Any]]:
"""获取可用样式列表"""
return [
{
"id": "natural",
"name": "Natural",
"description": "Natural style generation"
},
{
"id": "enhanced",
"name": "Enhanced",
"description": "Enhanced prompt with more detail"
}
]
async def health_check(self) -> Dict[str, Any]:
"""健康检查"""
try:
# 准备请求头
headers = {}
if self.api_token:
headers['Authorization'] = f'Bearer {self.api_token}'
# 简单的健康检查 - 尝试访问API基础URL
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session:
# 使用一个简单的测试提示词
test_url = f"{self.api_base}/prompt/test?width=64&height=64"
async with session.head(test_url, headers=headers) as response:
if response.status in [200, 404]: # 404也表示API可访问
return {
"status": "healthy",
"message": "Pollinations API is accessible",
"model": self.model,
"authenticated": bool(self.api_token or self.referrer),
"nologo_enabled": self.default_nologo and bool(self.api_token or self.referrer),
"referrer": self.referrer if self.referrer else None,
"rate_limit": f"{self.rate_limit_requests}/{self.rate_limit_window}s"
}
else:
return {
"status": "unhealthy",
"message": f"API returned status {response.status}",
"model": self.model
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}",
"model": self.model
}