113 lines
3.9 KiB
Python
113 lines
3.9 KiB
Python
"""轻量级纯内存异步 Worker Pool"""
|
||
import asyncio
|
||
from typing import Callable, Coroutine, List, TypeVar
|
||
|
||
from app.core.log import logger
|
||
|
||
T = TypeVar("T")

# Coroutine function awaited once per submitted item; its result is ignored.
Handler = Callable[[T], Coroutine[None, None, None]]


class AsyncWorkerPool:
    """Purely in-memory async worker pool.

    Single responsibility: accept items, fan them out to N worker coroutines
    and await ``handler`` on each one. No persistence and no complex state
    machine.

    Lifecycle:
      * ``start()``  spawns ``worker_count`` worker tasks.
      * ``submit()`` enqueues items, waiting while the queue is full
        (natural back-pressure).
      * ``drain()``  waits until every enqueued item has been marked done.
      * ``stop()``   lets in-flight handler calls finish, DISCARDS items still
        waiting in the queue and leaves the queue empty, so the pool can be
        started again. Call ``drain()`` first if pending items must be handled.
      * ``resize()`` changes the worker count without dropping queued items.

    Usable as an async context manager: ``start()`` on enter, ``stop()`` on exit.
    """

    def __init__(
        self,
        worker_count: int,
        handler: Handler,
        queue_maxsize: int = 10000,
        name: str = "WorkerPool",
    ):
        """Configure the pool; no worker task is created until ``start()``.

        Args:
            worker_count: number of concurrent worker coroutines (must be >= 0).
            handler: coroutine function awaited once per submitted item.
            queue_maxsize: capacity of the internal ``asyncio.Queue``; ``submit``
                waits while it is full. ``<= 0`` means unbounded.
            name: prefix for worker task names and log messages.

        Raises:
            ValueError: if ``worker_count`` is negative.
        """
        if worker_count < 0:
            raise ValueError(f"worker_count must be >= 0, got {worker_count}")
        self.worker_count = worker_count
        self.handler = handler
        self.name = name
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
        # Live worker tasks only: each task removes itself when it finishes
        # (see _spawn_worker), so len(self._workers) is the live-worker count.
        self._workers: List[asyncio.Task] = []
        # Source of worker ids; keeps task names unique across resizes.
        self._next_worker_id = 0
        self._running = False
        self._shutdown_event = asyncio.Event()

    async def __aenter__(self):
        """Start the pool and return it."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the pool; items still queued are discarded (see ``stop``)."""
        await self.stop()

    async def start(self) -> None:
        """Spawn ``worker_count`` workers. No-op if the pool is already running."""
        if self._running:
            return
        self._running = True
        self._shutdown_event.clear()
        # stop() awaits every worker, so no task from a previous run is alive
        # here and ids can restart from 0 (same names as a first start).
        self._next_worker_id = 0
        for _ in range(self.worker_count):
            self._spawn_worker()
        logger.info(f"{self.name} started with {self.worker_count} workers")

    async def stop(self) -> None:
        """Stop every worker and leave the queue empty. No-op if not running.

        In-flight ``handler`` calls are awaited to completion. Items still
        waiting in the queue are discarded and marked done, so ``drain()`` does
        not hang and nothing stale leaks into a later ``start()``.
        """
        if not self._running:
            return
        self._running = False
        self._shutdown_event.set()

        # Workers would drop these anyway (they exit on any entry once
        # _running is False). Purging first also frees capacity for sentinels.
        discarded = self._discard_pending()
        if discarded:
            logger.warning(f"{self.name} stopping: discarded {discarded} pending items")

        live = [w for w in self._workers if not w.done()]
        # Exactly one sentinel per live worker: each worker exits after taking
        # one entry. put() only waits when queue_maxsize < len(live), and then
        # every exiting worker frees a slot.
        for _ in live:
            await self._queue.put(None)
        if live:
            await asyncio.gather(*live, return_exceptions=True)
        self._workers.clear()
        # Entries submitted while stopping must not survive into the next start().
        self._discard_pending()
        logger.info(f"{self.name} stopped")

    async def submit(self, items: List[T]) -> None:
        """Enqueue ``items`` in order, waiting while the queue is full (back-pressure)."""
        for item in items:
            await self._queue.put(item)

    async def drain(self) -> None:
        """Wait until every item enqueued so far has been marked done.

        Only returns once workers (or ``stop()``'s purge) have accounted for
        every entry; with no running workers it waits until the pool is started.
        """
        await self._queue.join()

    async def resize(self, new_worker_count: int) -> None:
        """Change the number of workers without losing queued items.

        While stopped, only ``worker_count`` is updated and takes effect on the
        next ``start()``. While running, extra workers are spawned immediately.
        Surplus workers are retired by enqueuing one sentinel each, so they exit
        only after the items already queued ahead of those sentinels.

        Raises:
            ValueError: if ``new_worker_count`` is negative.
        """
        if new_worker_count < 0:
            raise ValueError(f"new_worker_count must be >= 0, got {new_worker_count}")
        if new_worker_count == self.worker_count:
            return

        logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
        old_count = self.worker_count
        self.worker_count = new_worker_count
        if not self._running:
            # Spawning now would create workers that drop items (_running is
            # False); enqueuing sentinels now would kill workers of the next start().
            return

        if new_worker_count > old_count:
            for _ in range(new_worker_count - old_count):
                self._spawn_worker()
        else:
            for _ in range(old_count - new_worker_count):
                await self._queue.put(None)

    def _spawn_worker(self) -> asyncio.Task:
        """Create one worker task, track it, and untrack it once it finishes."""
        worker_id = self._next_worker_id
        self._next_worker_id += 1
        task = asyncio.create_task(
            self._worker_loop(worker_id), name=f"{self.name}-worker-{worker_id}"
        )
        self._workers.append(task)
        task.add_done_callback(self._forget_worker)
        return task

    def _forget_worker(self, task: asyncio.Task) -> None:
        """Done-callback: remove a finished worker from the live list if still present."""
        try:
            self._workers.remove(task)
        except ValueError:
            # Already gone, e.g. stop() cleared the list first.
            pass

    def _discard_pending(self) -> int:
        """Empty the queue without handling anything; return the number of real items dropped."""
        dropped = 0
        while True:
            try:
                item = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                return dropped
            # Keep join()'s unfinished-task counter balanced for dropped entries.
            self._queue.task_done()
            if item is not None:
                dropped += 1

    async def _worker_loop(self, worker_id: int) -> None:
        """Handle queue entries until a ``None`` sentinel arrives or the pool stops.

        Handler exceptions are logged and do not end the worker. Every entry
        taken from the queue is marked done exactly once, including the one
        that makes the worker exit.
        """
        while True:
            item = await self._queue.get()
            try:
                if item is None or not self._running:
                    break
                await self.handler(item)
            except Exception as e:
                logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True)
            finally:
                self._queue.task_done()
|