Add File
This commit is contained in:
288
src/landppt/services/image/providers/pollinations_provider.py
Normal file
288
src/landppt/services/image/providers/pollinations_provider.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Pollinations AI图片生成提供者
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
import tempfile
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
import aiohttp
|
||||
|
||||
from ..models import (
|
||||
ImageProvider, ImageGenerationRequest, ImageOperationResult,
|
||||
ImageInfo, ImageFormat, ImageLicense, ImageSourceType, ImageMetadata, ImageTag
|
||||
)
|
||||
from .base import ImageGenerationProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PollinationsProvider(ImageGenerationProvider):
|
||||
"""Pollinations AI图片生成提供者"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(ImageProvider.POLLINATIONS, config)
|
||||
|
||||
# API配置
|
||||
self.api_base = config.get('api_base', 'https://image.pollinations.ai')
|
||||
self.api_token = config.get('api_token', '')
|
||||
self.referrer = config.get('referrer', '')
|
||||
self.model = config.get('model', 'flux')
|
||||
self.default_width = config.get('default_width', 1024)
|
||||
self.default_height = config.get('default_height', 1024)
|
||||
self.default_enhance = config.get('default_enhance', False)
|
||||
self.default_safe = config.get('default_safe', False)
|
||||
self.default_nologo = config.get('default_nologo', False)
|
||||
self.default_private = config.get('default_private', False)
|
||||
self.default_transparent = config.get('default_transparent', False)
|
||||
|
||||
# 速率限制
|
||||
self.rate_limit_requests = config.get('rate_limit_requests', 60)
|
||||
self.rate_limit_window = config.get('rate_limit_window', 60)
|
||||
|
||||
# 请求历史(用于速率限制)
|
||||
self._request_history = []
|
||||
|
||||
logger.debug(f"Pollinations provider initialized with model: {self.model}")
|
||||
|
||||
async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult:
|
||||
"""生成图片"""
|
||||
try:
|
||||
# 检查速率限制
|
||||
if not self._check_rate_limit():
|
||||
return ImageOperationResult(
|
||||
success=False,
|
||||
message="Rate limit exceeded. Please try again later."
|
||||
)
|
||||
|
||||
# 记录请求时间
|
||||
self._request_history.append(time.time())
|
||||
|
||||
# 准备API请求
|
||||
api_url = self._build_api_url(request)
|
||||
|
||||
logger.debug(f"Generating image with Pollinations API: {api_url}")
|
||||
|
||||
# 准备请求头
|
||||
headers = {}
|
||||
if self.api_token:
|
||||
headers['Authorization'] = f'Bearer {self.api_token}'
|
||||
|
||||
# 发送请求
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
||||
async with session.get(api_url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
# 读取图片数据
|
||||
image_data = await response.read()
|
||||
|
||||
# 创建图片信息
|
||||
image_id = f"pollinations_{int(time.time())}"
|
||||
filename = f"{image_id}.png"
|
||||
|
||||
# 创建元数据
|
||||
metadata = ImageMetadata(
|
||||
width=request.width or self.default_width,
|
||||
height=request.height or self.default_height,
|
||||
format=ImageFormat.PNG,
|
||||
file_size=len(image_data)
|
||||
)
|
||||
|
||||
# 创建标签
|
||||
tags = [
|
||||
ImageTag(name="ai-generated", category="type", source="system"),
|
||||
ImageTag(name="pollinations", category="provider", source="system"),
|
||||
ImageTag(name=self.model, category="model", source="system")
|
||||
]
|
||||
|
||||
# 保存图片到临时文件
|
||||
temp_dir = Path(tempfile.gettempdir()) / "pollinations_images"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
temp_file_path = temp_dir / filename
|
||||
|
||||
with open(temp_file_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
image_info = ImageInfo(
|
||||
image_id=image_id,
|
||||
source_type=ImageSourceType.AI_GENERATED,
|
||||
provider=self.provider,
|
||||
original_url=api_url,
|
||||
local_path=str(temp_file_path),
|
||||
filename=filename,
|
||||
title=f"Generated: {request.prompt[:50]}...",
|
||||
description=f"Generated by Pollinations AI using prompt: {request.prompt}",
|
||||
metadata=metadata,
|
||||
tags=tags,
|
||||
license=ImageLicense.CUSTOM
|
||||
)
|
||||
|
||||
return ImageOperationResult(
|
||||
success=True,
|
||||
message="Image generated successfully",
|
||||
image_info=image_info
|
||||
)
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Pollinations API error {response.status}: {error_text}")
|
||||
return ImageOperationResult(
|
||||
success=False,
|
||||
message=f"API request failed with status {response.status}: {error_text}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Pollinations API request timeout")
|
||||
return ImageOperationResult(
|
||||
success=False,
|
||||
message="Request timeout. Please try again."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Pollinations generation error: {str(e)}")
|
||||
return ImageOperationResult(
|
||||
success=False,
|
||||
message=f"Generation failed: {str(e)}"
|
||||
)
|
||||
|
||||
def _build_api_url(self, request: ImageGenerationRequest) -> str:
|
||||
"""构建API请求URL"""
|
||||
# URL编码提示词
|
||||
encoded_prompt = urllib.parse.quote(request.prompt, safe='')
|
||||
|
||||
# 构建基础URL
|
||||
url = f"{self.api_base}/prompt/{encoded_prompt}"
|
||||
|
||||
# 添加参数
|
||||
params = []
|
||||
|
||||
# 模型参数
|
||||
if self.model != 'flux': # flux是默认模型
|
||||
params.append(f"model={self.model}")
|
||||
|
||||
# 尺寸参数
|
||||
width = request.width or self.default_width
|
||||
height = request.height or self.default_height
|
||||
if width != 1024 or height != 1024: # 1024x1024是默认尺寸
|
||||
params.append(f"width={width}")
|
||||
params.append(f"height={height}")
|
||||
|
||||
# 种子参数(如果提供)
|
||||
if hasattr(request, 'seed') and request.seed:
|
||||
params.append(f"seed={request.seed}")
|
||||
|
||||
# 增强参数
|
||||
enhance = request.style == "enhanced" if request.style else self.default_enhance
|
||||
if enhance:
|
||||
params.append("enhance=true")
|
||||
|
||||
# 安全过滤
|
||||
if self.default_safe:
|
||||
params.append("safe=true")
|
||||
|
||||
# 无logo(需要token或referrer认证)
|
||||
if self.default_nologo and (self.api_token or self.referrer):
|
||||
params.append("nologo=true")
|
||||
|
||||
# 私有模式
|
||||
if self.default_private:
|
||||
params.append("private=true")
|
||||
|
||||
# 透明背景(仅gptimage模型支持)
|
||||
if self.default_transparent and self.model == 'gptimage':
|
||||
params.append("transparent=true")
|
||||
|
||||
# 推荐人标识符(用于认证)
|
||||
if self.referrer:
|
||||
params.append(f"referrer={urllib.parse.quote(self.referrer)}")
|
||||
|
||||
# 添加参数到URL
|
||||
if params:
|
||||
url += "?" + "&".join(params)
|
||||
|
||||
return url
|
||||
|
||||
def _check_rate_limit(self) -> bool:
|
||||
"""检查速率限制"""
|
||||
current_time = time.time()
|
||||
|
||||
# 清理过期的请求记录
|
||||
self._request_history = [
|
||||
req_time for req_time in self._request_history
|
||||
if current_time - req_time < self.rate_limit_window
|
||||
]
|
||||
|
||||
# 检查是否超过限制
|
||||
return len(self._request_history) < self.rate_limit_requests
|
||||
|
||||
async def get_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""获取可用模型列表"""
|
||||
return [
|
||||
{
|
||||
"id": "flux",
|
||||
"name": "Flux",
|
||||
"description": "High-quality image generation model (default)",
|
||||
"default": True
|
||||
},
|
||||
{
|
||||
"id": "turbo",
|
||||
"name": "Turbo",
|
||||
"description": "Fast image generation with good quality"
|
||||
},
|
||||
{
|
||||
"id": "gptimage",
|
||||
"name": "GPT Image",
|
||||
"description": "GPT-based image generation with transparency support"
|
||||
}
|
||||
]
|
||||
|
||||
async def get_available_styles(self) -> List[Dict[str, Any]]:
|
||||
"""获取可用样式列表"""
|
||||
return [
|
||||
{
|
||||
"id": "natural",
|
||||
"name": "Natural",
|
||||
"description": "Natural style generation"
|
||||
},
|
||||
{
|
||||
"id": "enhanced",
|
||||
"name": "Enhanced",
|
||||
"description": "Enhanced prompt with more detail"
|
||||
}
|
||||
]
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""健康检查"""
|
||||
try:
|
||||
# 准备请求头
|
||||
headers = {}
|
||||
if self.api_token:
|
||||
headers['Authorization'] = f'Bearer {self.api_token}'
|
||||
|
||||
# 简单的健康检查 - 尝试访问API基础URL
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session:
|
||||
# 使用一个简单的测试提示词
|
||||
test_url = f"{self.api_base}/prompt/test?width=64&height=64"
|
||||
async with session.head(test_url, headers=headers) as response:
|
||||
if response.status in [200, 404]: # 404也表示API可访问
|
||||
return {
|
||||
"status": "healthy",
|
||||
"message": "Pollinations API is accessible",
|
||||
"model": self.model,
|
||||
"authenticated": bool(self.api_token or self.referrer),
|
||||
"nologo_enabled": self.default_nologo and bool(self.api_token or self.referrer),
|
||||
"referrer": self.referrer if self.referrer else None,
|
||||
"rate_limit": f"{self.rate_limit_requests}/{self.rate_limit_window}s"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"API returned status {response.status}",
|
||||
"model": self.model
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}",
|
||||
"model": self.model
|
||||
}
|
||||
Reference in New Issue
Block a user