Files
ProxyPool/app/core/execution/worker_pool.py

106 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""轻量级纯内存异步 Worker Pool"""
import asyncio
from typing import Callable, Coroutine, List, TypeVar
from app.core.log import logger
# Item type flowing through the pool: AsyncWorkerPool.submit() accepts List[T],
# and the handler is called with one T at a time.
T = TypeVar("T")
# Handler signature: a callable that takes one item and returns a coroutine
# resolving to None (e.g. an `async def handler(item) -> None`).
# Type hint only; nothing checks it at runtime.
Handler = Callable[[T], Coroutine[None, None, None]]
class AsyncWorkerPool:
    """Lightweight, purely in-memory asynchronous worker pool.

    Single responsibility: accept batches of items -> fan them out to N worker
    coroutines -> call ``handler`` on each item. It contains no persistence
    logic and keeps no elaborate state machine.

    Lifecycle rules:
      * ``start()`` spawns ``worker_count`` worker tasks; ``stop()`` stops them.
      * Stopping (and shrinking via ``resize()``) never interrupts a running
        ``handler`` call: a busy worker finishes its current item, then exits.
        Idle workers are cancelled while parked in ``queue.get()``, which takes
        nothing off the queue, so pending items stay queued and are processed
        after the next ``start()``.
      * ``None`` items are acknowledged (``task_done``) but never passed to the
        handler.
    """

    def __init__(
        self,
        worker_count: int,
        handler: Handler,
        queue_maxsize: int = 10000,
        name: str = "WorkerPool",
    ):
        """Create a stopped pool; call ``start()`` (or use ``async with``) to run it.

        Args:
            worker_count: Number of workers spawned by ``start()``; must be >= 0.
            handler: Async callable invoked once per submitted non-None item.
            queue_maxsize: Queue capacity; ``submit()`` blocks while the queue
                is full (``asyncio.Queue`` treats ``<= 0`` as unbounded).
            name: Prefix for log messages and worker task names.

        Raises:
            ValueError: If ``worker_count`` is negative.
        """
        if worker_count < 0:
            raise ValueError(f"worker_count must be >= 0, got {worker_count}")
        self.worker_count = worker_count
        self.handler = handler
        self.name = name
        # NOTE(review): on Python < 3.10 asyncio.Queue/Event bind to the event
        # loop current at construction time — construct the pool inside the
        # loop that will run it.
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
        self._workers: List[asyncio.Task] = []
        # Worker tasks currently executing handler(); these are never cancelled.
        self._busy: set = set()
        # Busy worker tasks asked to exit once their current item is done.
        self._retiring: set = set()
        self._running = False
        # Set by stop(), cleared by start(); not awaited anywhere in this class.
        self._shutdown_event = asyncio.Event()

    async def __aenter__(self):
        """Start the pool and return it."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the pool; exceptions from the ``async with`` body propagate."""
        await self.stop()

    async def start(self) -> None:
        """Spawn ``worker_count`` workers; no-op if the pool is already running."""
        if self._running:
            return
        self._running = True
        self._shutdown_event.clear()
        self._workers = [self._spawn_worker(i) for i in range(self.worker_count)]
        logger.info(f"{self.name} started with {self.worker_count} workers")

    async def stop(self) -> None:
        """Stop every worker and wait until all of them have exited.

        In-flight handler calls run to completion; items still in the queue are
        left there (no item is dequeued and dropped). No-op if not running.
        """
        if not self._running:
            return
        self._running = False
        self._shutdown_event.set()
        # Detach the worker list first so a start() issued while we wait below
        # gets a fresh list and never mixes with the workers being retired.
        workers, self._workers = self._workers, []
        self._retire(workers)
        if workers:
            await asyncio.gather(*workers, return_exceptions=True)
        logger.info(f"{self.name} stopped")

    async def submit(self, items: List[T]) -> None:
        """Enqueue ``items`` in order, blocking while the queue is full (natural backpressure)."""
        for item in items:
            await self._queue.put(item)

    async def drain(self) -> None:
        """Wait until every item put on the queue has been marked done.

        Uses ``queue.join()``; if items are pending while no workers run, this
        keeps waiting until the pool is started again.
        """
        await self._queue.join()

    async def resize(self, new_worker_count: int) -> None:
        """Change the number of workers without losing queued or in-flight items.

        Growing spawns the extra workers immediately. Shrinking retires the
        workers at the end of the list: idle ones are cancelled inside
        ``queue.get()`` (nothing is dequeued), busy ones finish their current
        item first; this call waits until every retired worker has exited.
        On a pool that is not running only ``worker_count`` is updated, and the
        next ``start()`` uses it.

        Raises:
            ValueError: If ``new_worker_count`` is negative.
        """
        if new_worker_count < 0:
            raise ValueError(f"new_worker_count must be >= 0, got {new_worker_count}")
        if new_worker_count == self.worker_count:
            return
        logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
        if not self._running:
            # Spawning now would create workers that start() later overwrites.
            self.worker_count = new_worker_count
            return
        if new_worker_count > self.worker_count:
            self._workers.extend(
                self._spawn_worker(i)
                for i in range(self.worker_count, new_worker_count)
            )
            self.worker_count = new_worker_count
            return
        excess = self._workers[new_worker_count:]
        del self._workers[new_worker_count:]
        # Keep worker_count in step with self._workers before yielding below.
        self.worker_count = new_worker_count
        self._retire(excess)
        if excess:
            await asyncio.gather(*excess, return_exceptions=True)

    def _spawn_worker(self, worker_id: int) -> asyncio.Task:
        """Create the task running ``_worker_loop(worker_id)``, named ``<name>-worker-<id>``."""
        return asyncio.create_task(
            self._worker_loop(worker_id), name=f"{self.name}-worker-{worker_id}"
        )

    def _retire(self, workers: List[asyncio.Task]) -> None:
        """Ask ``workers`` to exit without interrupting a running handler.

        A worker that is not inside handler() is parked in ``queue.get()`` (or
        has not started yet); cancelling it there removes nothing from the
        queue. A busy worker is flagged and leaves its loop after its item.
        """
        for task in workers:
            if task in self._busy:
                self._retiring.add(task)
            else:
                task.cancel()

    async def _worker_loop(self, worker_id: int) -> None:
        """Process queued items until the pool stops or this worker is retired.

        ``task_done()`` is called exactly once per dequeued item — including
        ``None`` and items whose handler raised — so ``drain()`` stays accurate.
        Handler exceptions are logged and do not stop the worker.
        """
        me = asyncio.current_task()
        try:
            while self._running and me not in self._retiring:
                # Idle point: a cancellation delivered here dequeues nothing.
                item = await self._queue.get()
                self._busy.add(me)
                try:
                    if item is not None:
                        await self.handler(item)
                except Exception as e:
                    logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True)
                finally:
                    self._busy.discard(me)
                    self._queue.task_done()
        finally:
            self._retiring.discard(me)