- Fix SQL injection risks in proxy_repo and task_repo
- Atomic acquire_pending with UPDATE ... RETURNING
- Reuse aiohttp ClientSession in ValidatorService
- Replace polling with asyncio.Event in SchedulerService
- Optimize ValidationQueue.drain with asyncio.Condition
- Concurrent plugin crawling with asyncio.gather
- Unify ProxyRaw model import path
- Fix test baseline and remove tracked __pycache__ files
158 lines · 5.7 KiB · Python
"""Validation task queue — decouples crawling from validation.

Supports back-pressure control and persistence to SQLite.
"""

import asyncio
from typing import Optional

from app.core.db import get_db
from app.core.log import logger
from app.models.domain import ProxyRaw
from app.repositories.task_repo import ValidationTaskRepository

class ValidationQueue:
    """Proxy validation queue (persisted to SQLite).

    Workflow:
        1. Crawlers ``submit()`` raw proxies (persisted to the DB, then
           workers are woken via an in-memory signal queue).
        2. The worker pool consumes tasks from the DB and validates them.
        3. Proxies that pass validation are written to the proxy repository.
        4. On service restart, unfinished ``pending`` tasks are recovered
           automatically.
    """

    def __init__(
        self,
        validator,
        proxy_repo,
        worker_count: int = 50,
        score_valid: int = 10,
        score_invalid: int = -5,
        score_min: int = 0,
        score_max: int = 100,
    ):
        self.validator = validator
        self.proxy_repo = proxy_repo
        self.task_repo = ValidationTaskRepository()
        self.worker_count = worker_count
        self.score_valid = score_valid
        self.score_invalid = score_invalid
        # NOTE(review): score_min/score_max are stored but never read inside
        # this class — presumably consumed by callers; confirm before removal.
        self.score_min = score_min
        self.score_max = score_max

        # One queued token == one "work may be available" wake-up.
        self._signal: asyncio.Queue[None] = asyncio.Queue()
        self._workers: list[asyncio.Task] = []
        self._running = False
        # Serializes all DB access across workers and submitters.
        self._db_lock = asyncio.Lock()
        # Count of not-yet-completed tasks; guarded by _condition (drain()).
        self._pending_count = 0
        self._condition = asyncio.Condition()

        # Statistics
        self.valid_count = 0
        self.invalid_count = 0

    async def start(self):
        """Recover interrupted tasks and launch the worker pool (idempotent)."""
        if self._running:
            return
        self._running = True

        # Reset tasks left in 'processing' by a previous crash back to pending.
        async with get_db() as db:
            recovered = await self.task_repo.reset_processing(db)
            pending = await self.task_repo.get_pending_count(db)
        if recovered:
            logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
        if pending:
            logger.info(f"ValidationQueue has {pending} pending tasks to process")
        async with self._condition:
            self._pending_count = pending

        for i in range(self.worker_count):
            self._workers.append(asyncio.create_task(self._worker_loop(i)))

        # Wake workers for the recovered backlog. worker_count signals are
        # sufficient even when pending > worker_count, because each woken
        # worker drains tasks until the pending table is empty.
        if pending:
            for _ in range(min(pending, self.worker_count)):
                self._signal.put_nowait(None)

        logger.info(f"ValidationQueue started with {self.worker_count} workers")

    async def stop(self):
        """Stop all workers and wait for them to exit (idempotent)."""
        if not self._running:
            return
        self._running = False
        for _ in self._workers:
            self._signal.put_nowait(None)  # sentinel
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("ValidationQueue stopped")

    async def submit(self, proxies: list[ProxyRaw]):
        """Persist proxies as validation tasks and wake up workers."""
        async with self._db_lock:
            async with get_db() as db:
                inserted = await self.task_repo.insert_batch(db, proxies)
        if inserted:
            async with self._condition:
                self._pending_count += inserted
                self._condition.notify_all()
            for _ in range(min(inserted, self.worker_count)):
                self._signal.put_nowait(None)

    async def submit_one(self, proxy: ProxyRaw):
        """Convenience wrapper: submit a single proxy."""
        await self.submit([proxy])

    async def drain(self):
        """Block until every currently pending task has been completed."""
        async with self._condition:
            if self._pending_count > 0:
                await self._condition.wait_for(lambda: self._pending_count == 0)

    async def _worker_loop(self, worker_id: int):
        """Worker: wait for a wake-up signal, then drain available tasks.

        BUGFIX: the previous version processed exactly ONE task per signal,
        but submit()/start() emit at most worker_count signals per batch, so
        any batch larger than worker_count left tasks stranded in 'pending'
        (and drain() hung until the next submit). Each wake-up now keeps
        pulling tasks until acquire_pending() finds nothing.
        """
        while True:
            await self._signal.get()
            self._signal.task_done()
            if not self._running:
                break  # sentinel from stop()
            while self._running:
                if not await self._process_one_task(worker_id):
                    break

    async def _process_one_task(self, worker_id: int) -> bool:
        """Acquire one pending task from the DB, validate it, record the result.

        Returns:
            True if a task was acquired and processed; False when the
            pending table is currently empty.
        """
        async with self._db_lock:
            async with get_db() as db:
                task = await self.task_repo.acquire_pending(db)
        if not task:
            return False

        proxy = ProxyRaw(task["ip"], task["port"], task["protocol"])
        try:
            is_valid, latency = await self.validator.validate(
                proxy.ip, proxy.port, proxy.protocol
            )
        except Exception as e:
            # A validator failure marks the proxy invalid instead of
            # crashing the worker loop.
            logger.error(f"Worker {worker_id} validation error: {e}")
            is_valid, latency = False, 0.0

        async with self._db_lock:
            async with get_db() as db:
                if is_valid:
                    await self.proxy_repo.insert_or_update(
                        db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid
                    )
                    if latency:
                        await self.proxy_repo.update_response_time(
                            db, proxy.ip, proxy.port, latency
                        )
                    await self.task_repo.complete_task(db, task["id"], True, latency)
                    self.valid_count += 1
                    logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}")
                else:
                    await self.task_repo.complete_task(db, task["id"], False, 0.0)
                    self.invalid_count += 1
                    logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")

        # Bookkeeping for drain(): wake waiters when the backlog hits zero.
        async with self._condition:
            self._pending_count = max(0, self._pending_count - 1)
            if self._pending_count == 0:
                self._condition.notify_all()
        return True

    def reset_stats(self):
        """Zero the valid/invalid counters."""
        self.valid_count = 0
        self.invalid_count = 0