Files
LandPPT/src/landppt/services/image/providers/pollinations_provider.py
2025-11-07 09:05:16 +08:00

289 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Pollinations AI image generation provider.
"""
import asyncio
import logging
import tempfile
import time
import urllib.parse
import uuid
from pathlib import Path
from typing import Dict, Any, List, Optional

import aiohttp

from ..models import (
    ImageProvider, ImageGenerationRequest, ImageOperationResult,
    ImageInfo, ImageFormat, ImageLicense, ImageSourceType, ImageMetadata, ImageTag
)
from .base import ImageGenerationProvider
# Module-level logger named after this module's dotted import path.
logger = logging.getLogger(__name__)
class PollinationsProvider(ImageGenerationProvider):
    """Image generation provider backed by the Pollinations AI HTTP API.

    Each image is fetched with one GET request to
    ``{api_base}/prompt/{url-encoded prompt}?{options}``. The result is written
    to a ``pollinations_images`` directory under the system temp dir and
    described by an :class:`ImageInfo`. An in-process sliding-window counter
    limits how many ``generate`` calls are made per ``rate_limit_window``
    seconds.
    """

    # Endpoint used when the config does not provide ``api_base``.
    _DEFAULT_API_BASE = 'https://image.pollinations.ai'
    # Longest prefix of an error response body that is logged and echoed back
    # in result messages (error pages can be large HTML documents).
    _MAX_ERROR_TEXT = 500
    # Number of prompt characters shown in a generated image's title.
    _TITLE_PROMPT_CHARS = 50

    def __init__(self, config: Dict[str, Any]):
        """Create the provider from a configuration mapping.

        Args:
            config: Provider settings. Every key is optional: ``api_base``,
                ``api_token``, ``referrer``, ``model``, ``default_width``,
                ``default_height``, ``default_enhance``, ``default_safe``,
                ``default_nologo``, ``default_private``,
                ``default_transparent``, ``rate_limit_requests`` and
                ``rate_limit_window`` (seconds). The mapping is also passed
                to the base class.
        """
        super().__init__(ImageProvider.POLLINATIONS, config)

        # API configuration. A trailing slash on api_base is dropped so the
        # built URL never contains "//prompt".
        self.api_base = (config.get('api_base') or self._DEFAULT_API_BASE).rstrip('/')
        self.api_token = config.get('api_token', '')
        self.referrer = config.get('referrer', '')
        self.model = config.get('model', 'flux')
        self.default_width = config.get('default_width', 1024)
        self.default_height = config.get('default_height', 1024)
        self.default_enhance = config.get('default_enhance', False)
        self.default_safe = config.get('default_safe', False)
        self.default_nologo = config.get('default_nologo', False)
        self.default_private = config.get('default_private', False)
        self.default_transparent = config.get('default_transparent', False)

        # Rate limiting: at most rate_limit_requests generate() calls per
        # rate_limit_window seconds.
        self.rate_limit_requests = config.get('rate_limit_requests', 60)
        self.rate_limit_window = config.get('rate_limit_window', 60)

        # time.time() stamps of recent generate() calls, pruned by
        # _check_rate_limit.
        self._request_history: List[float] = []

        logger.debug(f"Pollinations provider initialized with model: {self.model}")

    async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
        """Generate a single image for ``request``.

        Args:
            request: The generation request. ``prompt``, ``width``, ``height``
                and ``style`` are used, and ``seed`` is used if present.

        Returns:
            A successful ImageOperationResult carrying ``image_info`` when the
            API answers HTTP 200 with a non-empty body. Otherwise, a failed
            result whose ``message`` explains why: rate limit, HTTP error,
            empty body, timeout, or any other ``Exception``. Ordinary errors
            are reported through the result rather than raised.
        """
        try:
            if not self._check_rate_limit():
                return ImageOperationResult(
                    success=False,
                    message="Rate limit exceeded. Please try again later."
                )
            # Record the attempt before the first await, so concurrent calls
            # on the same event loop see it immediately.
            self._request_history.append(time.time())

            api_url = self._build_api_url(request)
            logger.debug(f"Generating image with Pollinations API: {api_url}")

            # NOTE(review): self.timeout is not set in this class; presumably
            # the base class provides it — confirm.
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(api_url, headers=self._build_headers()) as response:
                    if response.status != 200:
                        # Decode leniently, so a non-UTF-8 error page does not
                        # turn an HTTP error into an unrelated UnicodeDecodeError.
                        raw_error = await response.text(errors='replace')
                        error_text = raw_error[:self._MAX_ERROR_TEXT]
                        logger.error(f"Pollinations API error {response.status}: {error_text}")
                        return ImageOperationResult(
                            success=False,
                            message=f"API request failed with status {response.status}: {error_text}"
                        )
                    image_data = await response.read()

            # A 200 with no body is not a usable image; don't report success
            # with a zero-byte file.
            if not image_data:
                logger.error("Pollinations API returned an empty response body")
                return ImageOperationResult(
                    success=False,
                    message="API returned an empty image"
                )

            image_info = self._save_image(request, api_url, image_data)
            return ImageOperationResult(
                success=True,
                message="Image generated successfully",
                image_info=image_info
            )
        except asyncio.TimeoutError:
            logger.error("Pollinations API request timeout")
            return ImageOperationResult(
                success=False,
                message="Request timeout. Please try again."
            )
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(f"Pollinations generation error: {str(e)}")
            return ImageOperationResult(
                success=False,
                message=f"Generation failed: {str(e)}"
            )

    def _save_image(self, request: ImageGenerationRequest, api_url: str,
                    image_data: bytes) -> ImageInfo:
        """Write ``image_data`` to the temp directory and describe it.

        Args:
            request: The originating request (prompt and size are recorded).
            api_url: The URL the image was fetched from.
            image_data: Raw response body.

        Returns:
            An ImageInfo whose ``local_path`` points at the written file.

        Raises:
            OSError: If the directory or the file cannot be written.
        """
        # The random suffix keeps ids unique, and therefore temp file names too,
        # when several images are generated within the same second. A bare
        # second-resolution timestamp made them collide and overwrite each other.
        image_id = f"pollinations_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        # NOTE(review): the image is always labelled PNG. The API may return
        # another encoding — confirm, and sniff the bytes/Content-Type if so.
        filename = f"{image_id}.png"

        temp_dir = Path(tempfile.gettempdir()) / "pollinations_images"
        temp_dir.mkdir(parents=True, exist_ok=True)
        temp_file_path = temp_dir / filename
        temp_file_path.write_bytes(image_data)

        metadata = ImageMetadata(
            width=request.width or self.default_width,
            height=request.height or self.default_height,
            format=ImageFormat.PNG,
            file_size=len(image_data)
        )
        tags = [
            ImageTag(name="ai-generated", category="type", source="system"),
            ImageTag(name="pollinations", category="provider", source="system"),
            ImageTag(name=self.model, category="model", source="system")
        ]

        prompt = request.prompt
        # Only add an ellipsis when the prompt was actually truncated.
        if len(prompt) > self._TITLE_PROMPT_CHARS:
            title = f"Generated: {prompt[:self._TITLE_PROMPT_CHARS]}..."
        else:
            title = f"Generated: {prompt}"

        return ImageInfo(
            image_id=image_id,
            source_type=ImageSourceType.AI_GENERATED,
            provider=self.provider,
            original_url=api_url,
            local_path=str(temp_file_path),
            filename=filename,
            title=title,
            description=f"Generated by Pollinations AI using prompt: {prompt}",
            metadata=metadata,
            tags=tags,
            license=ImageLicense.CUSTOM
        )

    def _build_headers(self) -> Dict[str, str]:
        """Return HTTP headers, with a Bearer token when ``api_token`` is set."""
        if self.api_token:
            return {'Authorization': f'Bearer {self.api_token}'}
        return {}

    def _build_api_url(self, request: ImageGenerationRequest) -> str:
        """Build the GET URL for ``request``.

        A query parameter is added only when it is switched on, or when it
        differs from the API defaults assumed here (model ``flux``, size
        1024x1024).
        """
        # The prompt is a single path segment, so "/" must be encoded too.
        encoded_prompt = urllib.parse.quote(request.prompt, safe='')
        url = f"{self.api_base}/prompt/{encoded_prompt}"

        params: List[str] = []

        # flux is treated as the server-side default model.
        if self.model != 'flux':
            params.append(f"model={self.model}")

        # 1024x1024 is treated as the server-side default size. If either
        # side differs, both are sent.
        width = request.width or self.default_width
        height = request.height or self.default_height
        if width != 1024 or height != 1024:
            params.append(f"width={width}")
            params.append(f"height={height}")

        # Seed is sent only when the attribute exists and is truthy.
        # NOTE(review): a seed of 0 is therefore dropped — confirm whether 0
        # means "no seed" in ImageGenerationRequest.
        seed = getattr(request, 'seed', None)
        if seed:
            params.append(f"seed={seed}")

        # An explicit style decides enhancement ("enhanced" turns it on, any
        # other style turns it off). Without a style, the configured default
        # applies.
        enhance = request.style == "enhanced" if request.style else self.default_enhance
        if enhance:
            params.append("enhance=true")

        if self.default_safe:
            params.append("safe=true")
        # nologo is only requested when a token or referrer is configured.
        if self.default_nologo and (self.api_token or self.referrer):
            params.append("nologo=true")
        if self.default_private:
            params.append("private=true")
        # Transparent backgrounds are only requested for the gptimage model.
        if self.default_transparent and self.model == 'gptimage':
            params.append("transparent=true")
        # The referrer identifies the caller to the API (used for authentication).
        if self.referrer:
            params.append(f"referrer={urllib.parse.quote(self.referrer)}")

        if params:
            url += "?" + "&".join(params)
        return url

    def _check_rate_limit(self) -> bool:
        """Prune expired timestamps; return True if another call is allowed."""
        current_time = time.time()
        # Keep only calls made within the last rate_limit_window seconds.
        self._request_history = [
            req_time for req_time in self._request_history
            if current_time - req_time < self.rate_limit_window
        ]
        return len(self._request_history) < self.rate_limit_requests

    async def get_available_models(self) -> List[Dict[str, Any]]:
        """Return the static list of models this provider advertises."""
        return [
            {
                "id": "flux",
                "name": "Flux",
                "description": "High-quality image generation model (default)",
                "default": True
            },
            {
                "id": "turbo",
                "name": "Turbo",
                "description": "Fast image generation with good quality"
            },
            {
                "id": "gptimage",
                "name": "GPT Image",
                "description": "GPT-based image generation with transparency support"
            }
        ]

    async def get_available_styles(self) -> List[Dict[str, Any]]:
        """Return the static list of styles. "enhanced" turns on ``enhance=true``."""
        return [
            {
                "id": "natural",
                "name": "Natural",
                "description": "Natural style generation"
            },
            {
                "id": "enhanced",
                "name": "Enhanced",
                "description": "Enhanced prompt with more detail"
            }
        ]

    async def health_check(self) -> Dict[str, Any]:
        """Probe the API with a small HEAD request and report its status.

        Returns a dict with ``status`` set to "healthy" or "unhealthy", plus
        details. Any exception is reported as "unhealthy" rather than raised.
        """
        try:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session:
                # A tiny test prompt keeps the probe cheap.
                test_url = f"{self.api_base}/prompt/test?width=64&height=64"
                async with session.head(test_url, headers=self._build_headers()) as response:
                    # A 404 is still treated as "API reachable".
                    if response.status in [200, 404]:
                        return {
                            "status": "healthy",
                            "message": "Pollinations API is accessible",
                            "model": self.model,
                            "authenticated": bool(self.api_token or self.referrer),
                            "nologo_enabled": self.default_nologo and bool(self.api_token or self.referrer),
                            "referrer": self.referrer if self.referrer else None,
                            "rate_limit": f"{self.rate_limit_requests}/{self.rate_limit_window}s"
                        }
                    return {
                        "status": "unhealthy",
                        "message": f"API returned status {response.status}",
                        "model": self.model
                    }
        except Exception as e:
            return {
                "status": "unhealthy",
                "message": f"Health check failed: {str(e)}",
                "model": self.model
            }