"""轻量级纯内存异步 Worker Pool""" import asyncio from typing import Callable, Coroutine, List, TypeVar from app.core.log import logger T = TypeVar("T") Handler = Callable[[T], Coroutine[None, None, None]] class AsyncWorkerPool: """纯内存异步工作池 职责单一:接收任务列表 -> 分发给 N 个 Worker 协程 -> 调用 handler 处理 不包含任何持久化逻辑,也不维护复杂的状态机。 """ def __init__( self, worker_count: int, handler: Handler, queue_maxsize: int = 10000, name: str = "WorkerPool", ): self.worker_count = worker_count self.handler = handler self.name = name self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize) self._workers: List[asyncio.Task] = [] self._running = False self._shutdown_event = asyncio.Event() async def __aenter__(self): await self.start() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop() async def start(self) -> None: if self._running: return self._running = True self._shutdown_event.clear() self._workers = [ asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}") for i in range(self.worker_count) ] logger.info(f"{self.name} started with {self.worker_count} workers") async def stop(self) -> None: if not self._running: return self._running = False self._shutdown_event.set() # 发送足够数量的 sentinel,确保所有 Worker 都能收到 for _ in range(len(self._workers) + self._queue.qsize()): try: self._queue.put_nowait(None) except asyncio.QueueFull: break if self._workers: await asyncio.gather(*self._workers, return_exceptions=True) self._workers.clear() logger.info(f"{self.name} stopped") async def submit(self, items: List[T]) -> None: """提交一批任务到队列(优先 put_nowait,队列满时再 await put)""" for item in items: try: self._queue.put_nowait(item) except asyncio.QueueFull: await self._queue.put(item) async def drain(self) -> None: """等待队列中所有任务被消费完毕""" await self._queue.join() async def resize(self, new_worker_count: int) -> None: """动态调整 Worker 数量,不丢失队列中的任务""" if new_worker_count == self.worker_count: return logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}") if new_worker_count > self.worker_count: for i in range(self.worker_count, new_worker_count): self._workers.append( asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}") ) elif new_worker_count < self.worker_count: excess_workers = self._workers[new_worker_count:] self._workers = self._workers[:new_worker_count] for w in excess_workers: w.cancel() if excess_workers: await asyncio.gather(*excess_workers, return_exceptions=True) self.worker_count = new_worker_count async def _worker_loop(self, worker_id: int) -> None: while True: item = await self._queue.get() try: if item is None or not self._running: break await self.handler(item) except Exception as e: logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True) finally: self._queue.task_done()