"""
Pollinations AI图片生成提供者 (Pollinations AI image generation provider).
"""

import asyncio
import logging
import tempfile
import time
import urllib.parse
import uuid
from pathlib import Path
from typing import Dict, Any, List, Optional

import aiohttp

from ..models import (
    ImageProvider, ImageGenerationRequest, ImageOperationResult,
    ImageInfo, ImageFormat, ImageLicense, ImageSourceType, ImageMetadata, ImageTag
)
from .base import ImageGenerationProvider

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)


class PollinationsProvider(ImageGenerationProvider):
    """Image generation provider backed by the Pollinations AI HTTP API.

    Images are requested with a GET to ``{api_base}/prompt/{prompt}?...``.
    Each result is written to ``<system temp dir>/pollinations_images/`` and
    described by an ``ImageInfo`` record. An in-process sliding-window rate
    limiter caps requests to ``rate_limit_requests`` per ``rate_limit_window``
    seconds.
    """

    # Values the original code treats as the API's own defaults ("flux is the
    # default model", "1024x1024 is the default size"). Parameters equal to
    # these are omitted from the URL.
    _API_DEFAULT_MODEL = 'flux'
    _API_DEFAULT_SIZE = 1024

    def __init__(self, config: Dict[str, Any]):
        """Create the provider from a configuration mapping.

        Args:
            config: Provider settings; every key is optional. Recognised keys:
                ``api_base``, ``api_token``, ``referrer``, ``model``,
                ``default_width``, ``default_height``, ``default_enhance``,
                ``default_safe``, ``default_nologo``, ``default_private``,
                ``default_transparent``, ``rate_limit_requests`` and
                ``rate_limit_window`` (seconds). The mapping is also passed
                to the base class.
        """
        super().__init__(ImageProvider.POLLINATIONS, config)

        # API configuration. A trailing '/' on api_base would yield "//prompt".
        self.api_base = config.get('api_base', 'https://image.pollinations.ai').rstrip('/')
        self.api_token = config.get('api_token', '')
        self.referrer = config.get('referrer', '')
        self.model = config.get('model', 'flux')
        self.default_width = config.get('default_width', 1024)
        self.default_height = config.get('default_height', 1024)
        self.default_enhance = config.get('default_enhance', False)
        self.default_safe = config.get('default_safe', False)
        self.default_nologo = config.get('default_nologo', False)
        self.default_private = config.get('default_private', False)
        self.default_transparent = config.get('default_transparent', False)

        # Rate limiting: at most `rate_limit_requests` per `rate_limit_window` seconds.
        self.rate_limit_requests = config.get('rate_limit_requests', 60)
        self.rate_limit_window = config.get('rate_limit_window', 60)

        # time.monotonic() timestamps of recent requests, consumed by _check_rate_limit.
        self._request_history: List[float] = []

        logger.debug(f"Pollinations provider initialized with model: {self.model}")

    def _auth_headers(self) -> Dict[str, str]:
        """Return request headers, with a Bearer token when one is configured."""
        headers: Dict[str, str] = {}
        if self.api_token:
            headers['Authorization'] = f'Bearer {self.api_token}'
        return headers

    async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
        """Generate one image for ``request`` and save it to a temp file.

        Args:
            request: Generation parameters. ``prompt``, ``width``, ``height``
                and ``style`` are read; ``seed`` is read if present.

        Returns:
            An ``ImageOperationResult``. On success, ``image_info`` describes
            the saved file. Failures are reported with ``success=False``
            instead of being raised. Covered failures are an empty prompt,
            the rate limit, a non-200 response, a timeout, and any other
            exception.
        """
        try:
            # An empty prompt would produce a malformed ".../prompt/" URL.
            if not request.prompt or not request.prompt.strip():
                return ImageOperationResult(
                    success=False,
                    message="Prompt must not be empty."
                )

            # Rate limit check
            if not self._check_rate_limit():
                return ImageOperationResult(
                    success=False,
                    message="Rate limit exceeded. Please try again later."
                )

            # Record this request on the same monotonic clock the limiter uses.
            self._request_history.append(time.monotonic())

            api_url = self._build_api_url(request)
            logger.debug(f"Generating image with Pollinations API: {api_url}")

            headers = self._auth_headers()

            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
                async with session.get(api_url, headers=headers) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        logger.error(f"Pollinations API error {response.status}: {error_text}")
                        return ImageOperationResult(
                            success=False,
                            message=f"API request failed with status {response.status}: {error_text}"
                        )
                    image_data = await response.read()

            image_info = await self._save_image(request, api_url, image_data)

            return ImageOperationResult(
                success=True,
                message="Image generated successfully",
                image_info=image_info
            )

        except asyncio.TimeoutError:
            logger.error("Pollinations API request timeout")
            return ImageOperationResult(
                success=False,
                message="Request timeout. Please try again."
            )
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(f"Pollinations generation error: {str(e)}")
            return ImageOperationResult(
                success=False,
                message=f"Generation failed: {str(e)}"
            )

    async def _save_image(self, request: ImageGenerationRequest, api_url: str,
                          image_data: bytes) -> ImageInfo:
        """Write ``image_data`` under the temp image directory and describe it."""
        # A second-resolution timestamp alone collides when several images are
        # generated within one second, and the later file overwrites the
        # earlier one. The random suffix keeps IDs and filenames unique.
        image_id = f"pollinations_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        # NOTE(review): the format is assumed to be PNG. The API may return a
        # different format; confirm against the response Content-Type.
        filename = f"{image_id}.png"

        temp_file_path = Path(tempfile.gettempdir()) / "pollinations_images" / filename

        # The file write blocks, so run it off the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._write_file, temp_file_path, image_data)

        metadata = ImageMetadata(
            width=request.width or self.default_width,
            height=request.height or self.default_height,
            format=ImageFormat.PNG,
            file_size=len(image_data)
        )

        tags = [
            ImageTag(name="ai-generated", category="type", source="system"),
            ImageTag(name="pollinations", category="provider", source="system"),
            ImageTag(name=self.model, category="model", source="system")
        ]

        # Append an ellipsis only when the prompt was actually truncated.
        prompt = request.prompt
        short_prompt = prompt if len(prompt) <= 50 else f"{prompt[:50]}..."

        return ImageInfo(
            image_id=image_id,
            source_type=ImageSourceType.AI_GENERATED,
            provider=self.provider,
            original_url=api_url,
            local_path=str(temp_file_path),
            filename=filename,
            title=f"Generated: {short_prompt}",
            description=f"Generated by Pollinations AI using prompt: {prompt}",
            metadata=metadata,
            tags=tags,
            license=ImageLicense.CUSTOM
        )

    @staticmethod
    def _write_file(path: Path, data: bytes) -> None:
        """Create ``path``'s parent directory if needed and write ``data`` to it."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(data)

    def _build_api_url(self, request: ImageGenerationRequest) -> str:
        """Build the GET URL for ``request``.

        The prompt is percent-encoded into the path. A query parameter is
        added only when it differs from the API default or is enabled in the
        configuration.
        """
        encoded_prompt = urllib.parse.quote(request.prompt, safe='')
        url = f"{self.api_base}/prompt/{encoded_prompt}"

        params: List[tuple] = []

        # Model (flux is the API default)
        if self.model != self._API_DEFAULT_MODEL:
            params.append(('model', self.model))

        # Size (1024x1024 is the API default)
        width = request.width or self.default_width
        height = request.height or self.default_height
        if width != self._API_DEFAULT_SIZE or height != self._API_DEFAULT_SIZE:
            params.append(('width', width))
            params.append(('height', height))

        # Seed. 0 is a valid seed, so test for presence rather than truthiness.
        seed = getattr(request, 'seed', None)
        if seed is not None:
            params.append(('seed', seed))

        # Prompt enhancement: an explicit style wins over the configured default.
        enhance = request.style == "enhanced" if request.style else self.default_enhance
        if enhance:
            params.append(('enhance', 'true'))

        # Safety filter
        if self.default_safe:
            params.append(('safe', 'true'))

        # nologo requires token or referrer authentication.
        if self.default_nologo and (self.api_token or self.referrer):
            params.append(('nologo', 'true'))

        # Private mode
        if self.default_private:
            params.append(('private', 'true'))

        # Transparent background (supported only by the gptimage model)
        if self.default_transparent and self.model == 'gptimage':
            params.append(('transparent', 'true'))

        # Referrer identifier (used for authentication)
        if self.referrer:
            params.append(('referrer', self.referrer))

        if params:
            # urlencode escapes every value (model and referrer included), so
            # '&', '=' or spaces cannot corrupt the query string. quote_via=quote
            # keeps spaces as %20 instead of '+'.
            url += "?" + urllib.parse.urlencode(params, quote_via=urllib.parse.quote)

        return url

    def _check_rate_limit(self) -> bool:
        """Return True if another request fits in the current rate-limit window.

        Timestamps older than ``rate_limit_window`` seconds are pruned first.
        This method only checks; the caller records the new request.
        """
        now = time.monotonic()

        # Drop requests that have left the sliding window.
        self._request_history = [
            req_time for req_time in self._request_history
            if now - req_time < self.rate_limit_window
        ]

        return len(self._request_history) < self.rate_limit_requests

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Return the static list of models this provider advertises."""
        return [
            {
                "id": "flux",
                "name": "Flux",
                "description": "High-quality image generation model (default)",
                "default": True
            },
            {
                "id": "turbo",
                "name": "Turbo",
                "description": "Fast image generation with good quality"
            },
            {
                "id": "gptimage",
                "name": "GPT Image",
                "description": "GPT-based image generation with transparency support"
            }
        ]

    async def get_available_styles(self) -> List[Dict[str, Any]]:
        """Return the static list of styles; "enhanced" turns on ``enhance=true``."""
        return [
            {
                "id": "natural",
                "name": "Natural",
                "description": "Natural style generation"
            },
            {
                "id": "enhanced",
                "name": "Enhanced",
                "description": "Enhanced prompt with more detail"
            }
        ]

    async def health_check(self) -> Dict[str, Any]:
        """Probe the API with a HEAD request to a tiny test prompt.

        Status 200 or 404 is treated as reachable ("healthy"). Any other
        status, and any exception, is reported as "unhealthy" rather than
        raised.
        """
        try:
            headers = self._auth_headers()

            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session:
                # A minimal 64x64 test prompt keeps the probe cheap.
                test_url = f"{self.api_base}/prompt/test?width=64&height=64"
                async with session.head(test_url, headers=headers) as response:
                    if response.status in [200, 404]:  # 404 still means the API is reachable
                        return {
                            "status": "healthy",
                            "message": "Pollinations API is accessible",
                            "model": self.model,
                            "authenticated": bool(self.api_token or self.referrer),
                            "nologo_enabled": self.default_nologo and bool(self.api_token or self.referrer),
                            "referrer": self.referrer if self.referrer else None,
                            "rate_limit": f"{self.rate_limit_requests}/{self.rate_limit_window}s"
                        }
                    return {
                        "status": "unhealthy",
                        "message": f"API returned status {response.status}",
                        "model": self.model
                    }
        except Exception as e:
            return {
                "status": "unhealthy",
                "message": f"Health check failed: {str(e)}",
                "model": self.model
            }
|