Add File
This commit is contained in:
146
src/landppt/utils/thread_pool.py
Normal file
146
src/landppt/utils/thread_pool.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
线程池工具类,用于将阻塞操作放入线程池中执行
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, TypeVar, Coroutine, Optional, Dict
|
||||
|
||||
# 类型变量
|
||||
T = TypeVar('T')
|
||||
R = TypeVar('R')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ThreadPoolManager:
|
||||
"""线程池管理器,提供全局线程池和辅助方法"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ThreadPoolManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, max_workers: Optional[int] = None):
|
||||
"""初始化线程池管理器
|
||||
|
||||
Args:
|
||||
max_workers: 最大工作线程数,默认为 None(使用 CPU 核心数 * 5)
|
||||
"""
|
||||
if not self._initialized:
|
||||
# 如果未指定最大工作线程数,则使用 CPU 核心数 * 5
|
||||
if max_workers is None:
|
||||
max_workers = os.cpu_count() * 5 if os.cpu_count() else 20
|
||||
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers,
|
||||
thread_name_prefix="landppt_worker"
|
||||
)
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_tasks": 0,
|
||||
"completed_tasks": 0,
|
||||
"failed_tasks": 0,
|
||||
"active_tasks": 0
|
||||
}
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"线程池初始化完成,最大工作线程数: {max_workers}")
|
||||
|
||||
async def run_in_thread(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""在线程池中运行同步函数
|
||||
|
||||
Args:
|
||||
func: 要运行的同步函数
|
||||
*args: 传递给函数的位置参数
|
||||
**kwargs: 传递给函数的关键字参数
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
Raises:
|
||||
Exception: 如果函数执行失败,则抛出异常
|
||||
"""
|
||||
self.stats["total_tasks"] += 1
|
||||
self.stats["active_tasks"] += 1
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
self.executor,
|
||||
functools.partial(func, *args, **kwargs)
|
||||
)
|
||||
|
||||
self.stats["completed_tasks"] += 1
|
||||
return result
|
||||
except Exception as e:
|
||||
self.stats["failed_tasks"] += 1
|
||||
logger.error(f"线程池任务执行失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
self.stats["active_tasks"] -= 1
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取线程池统计信息"""
|
||||
return self.stats.copy()
|
||||
|
||||
def shutdown(self, wait: bool = True):
|
||||
"""关闭线程池
|
||||
|
||||
Args:
|
||||
wait: 是否等待所有线程完成
|
||||
"""
|
||||
if self._initialized:
|
||||
self.executor.shutdown(wait=wait)
|
||||
self._initialized = False
|
||||
logger.info("线程池已关闭")
|
||||
|
||||
|
||||
# 全局线程池实例
|
||||
thread_pool = ThreadPoolManager()
|
||||
|
||||
|
||||
async def run_blocking_io(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""运行阻塞的 I/O 操作
|
||||
|
||||
这是一个便捷函数,用于将阻塞的 I/O 操作放入线程池中执行
|
||||
|
||||
Args:
|
||||
func: 要运行的阻塞函数
|
||||
*args: 传递给函数的位置参数
|
||||
**kwargs: 传递给函数的关键字参数
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
"""
|
||||
return await thread_pool.run_in_thread(func, *args, **kwargs)
|
||||
|
||||
|
||||
def to_thread(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||
"""装饰器:将同步函数转换为在线程池中运行的异步函数
|
||||
|
||||
用法:
|
||||
@to_thread
|
||||
def blocking_function(arg1, arg2):
|
||||
# 执行阻塞操作
|
||||
return result
|
||||
|
||||
# 调用
|
||||
result = await blocking_function(arg1, arg2)
|
||||
|
||||
Args:
|
||||
func: 要装饰的同步函数
|
||||
|
||||
Returns:
|
||||
异步包装函数
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await thread_pool.run_in_thread(func, *args, **kwargs)
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user