refactor: 全面重构核心架构,消除反复修改的根因
- 删除 ValidationQueue 双轨持久化队列,替换为纯内存 AsyncWorkerPool - 引入统一后台任务框架 JobExecutor(Job/CrawlJob/ValidateAllJob) - 新增 PluginRunner 统一插件执行(超时、重试、健康检查、统计) - 重构 SchedulerService 职责收敛为仅定时触发 ValidateAllJob - 使用 AsyncExitStack 重构 lifespan,安全管理长生命周期资源 - 路由层瘦身 50%+,业务异常上抛由全局中间件统一处理 - 实现设置全热更新(WorkerPool 并发、Validator 超时即时生效) - 前端 Store 强制写后重新拉取,消除乐观更新数据不同步 - 删除 queue.py / task_repo.py / task_service.py - 新增 execution 单元测试,全部 85 个测试通过
This commit is contained in:
@@ -84,22 +84,7 @@ async def init_db():
|
||||
await db.execute("ALTER TABLE plugin_settings ADD COLUMN stats_json TEXT DEFAULT '{}'")
|
||||
logger.info("Migrated: added stats_json column to plugin_settings")
|
||||
|
||||
# 验证任务队列表
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS validation_tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ip TEXT NOT NULL,
|
||||
port INTEGER NOT NULL,
|
||||
protocol TEXT DEFAULT 'http',
|
||||
status TEXT DEFAULT 'pending',
|
||||
result TEXT,
|
||||
response_time_ms REAL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_status ON validation_tasks(status)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_created ON validation_tasks(created_at)")
|
||||
# 验证任务队列表已废弃,不再创建;旧表保留在数据库中不影响功能
|
||||
|
||||
# 系统设置表
|
||||
await db.execute("""
|
||||
@@ -125,3 +110,27 @@ async def get_db() -> AsyncIterator[aiosqlite.Connection]:
|
||||
yield db
|
||||
finally:
|
||||
await db.close()
|
||||
|
||||
|
||||
@asynccontextmanager
async def transaction() -> AsyncIterator[aiosqlite.Connection]:
    """Yield a database connection wrapped in an explicit transaction.

    Usage:
        async with transaction() as db:
            await repo.update(db, ...)
        # any exception inside the block triggers an automatic rollback

    The connection is always closed on exit, whether the transaction
    committed or rolled back.
    """
    ensure_db_dir()
    conn = await aiosqlite.connect(DB_PATH)
    try:
        # Same pragmas as get_db(): WAL for concurrent readers,
        # NORMAL sync as the usual WAL durability trade-off.
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA synchronous=NORMAL")
        await conn.execute("BEGIN")
        yield conn
        await conn.commit()
    except Exception:
        await conn.rollback()
        raise
    finally:
        await conn.close()
|
||||
|
||||
13
app/core/execution/__init__.py
Normal file
13
app/core/execution/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Unified background task execution framework."""
from .job import JobStatus, Job, CrawlJob, ValidateAllJob
from .worker_pool import AsyncWorkerPool
from .executor import JobExecutor

# Public API of the execution package.
__all__ = [
    "JobStatus",
    "Job",
    "CrawlJob",
    "ValidateAllJob",
    "AsyncWorkerPool",
    "JobExecutor",
]
|
||||
138
app/core/execution/executor.py
Normal file
138
app/core/execution/executor.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Job 执行器 - 统一管理所有后台 Job 的生命周期"""
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.execution.job import Job, JobStatus
|
||||
from app.core.execution.worker_pool import AsyncWorkerPool
|
||||
from app.core.log import logger
|
||||
|
||||
|
||||
class JobExecutor:
    """Background job executor.

    - Maintains an in-memory table of submitted Jobs
    - Caps the number of concurrently running Jobs with a semaphore
    - Periodically evicts finished Jobs past their retention window
    """

    def __init__(
        self,
        worker_pool: Optional[AsyncWorkerPool] = None,
        max_concurrent_jobs: int = 10,
        cleanup_interval_seconds: int = 300,
    ):
        self.worker_pool = worker_pool
        self.max_concurrent_jobs = max_concurrent_jobs
        self._jobs: Dict[str, Job] = {}
        self._running = False
        self._semaphore = asyncio.Semaphore(max_concurrent_jobs)
        self._cleanup_interval = cleanup_interval_seconds
        self._cleanup_task: Optional[asyncio.Task] = None
        # Strong references to per-job tasks: a bare create_task() result may
        # be garbage-collected mid-flight (see asyncio docs), and we need the
        # handles to actually cancel running jobs.
        self._tasks: Dict[str, asyncio.Task] = {}

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    async def start(self) -> None:
        """Start the background cleanup loop (idempotent)."""
        if self._running:
            return
        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("JobExecutor started")

    async def stop(self) -> None:
        """Cancel all jobs, wait for their tasks, then stop the cleanup loop."""
        if not self._running:
            return
        self._running = False
        await self.cancel_all()
        # Fix: wait for in-flight job tasks so stop() leaves no stray coroutines.
        pending = [t for t in self._tasks.values() if not t.done()]
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
        logger.info("JobExecutor stopped")

    def submit_job(self, job: Job) -> str:
        """Register *job* and schedule it for background execution; return its id."""
        self._jobs[job.id] = job
        # Fix: keep a strong reference to the task (GC hazard otherwise) and
        # drop it automatically once the job finishes.
        task = asyncio.create_task(self._run_job(job))
        self._tasks[job.id] = task
        task.add_done_callback(lambda _t, jid=job.id: self._tasks.pop(jid, None))
        return job.id

    async def _run_job(self, job: Job) -> None:
        """Run one job under the concurrency semaphore, recording its outcome."""
        try:
            async with self._semaphore:
                if job.is_cancelled:
                    logger.info(f"Job {job.id} was cancelled before running")
                    return
                result = await job.run()
                # If the subclass did not set a terminal status itself,
                # mark the job completed automatically.
                if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
                    job._set_completed(result)
                logger.info(f"Job {job.id} completed: {result}")
        except asyncio.CancelledError:
            # Task was cancelled (possibly while waiting on the semaphore).
            job.status = JobStatus.CANCELLED
            job._touch()
            logger.info(f"Job {job.id} cancelled during execution")
        except Exception as e:
            job._set_failed(str(e))
            logger.error(f"Job {job.id} failed: {e}", exc_info=True)

    def get_job(self, job_id: str) -> Optional[Job]:
        """Look up a job by id; None if unknown or already cleaned up."""
        return self._jobs.get(job_id)

    def list_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 100,
    ) -> List[Job]:
        """Return jobs (optionally filtered by status), newest first."""
        jobs = list(self._jobs.values())
        if status:
            jobs = [j for j in jobs if j.status == status]
        jobs.sort(key=lambda j: j.created_at, reverse=True)
        return jobs[:limit]

    async def cancel_job(self, job_id: str) -> bool:
        """Request cancellation of one job; returns False if the id is unknown."""
        job = self._jobs.get(job_id)
        if not job:
            return False
        job.cancel()
        # Fix: the original only flipped the Job flag, so an already-running
        # coroutine kept executing to completion; cancel its task as well.
        task = self._tasks.get(job_id)
        if task and not task.done():
            task.cancel()
        return True

    async def cancel_all(self) -> int:
        """Cancel every pending/running job; returns how many were cancelled."""
        cancelled = 0
        for job in list(self._jobs.values()):
            if job.status in (JobStatus.PENDING, JobStatus.RUNNING):
                job.cancel()
                task = self._tasks.get(job.id)
                if task and not task.done():
                    task.cancel()
                cancelled += 1
        return cancelled

    async def _cleanup_loop(self) -> None:
        """Periodically purge finished jobs that are past the retention window."""
        while self._running:
            try:
                await asyncio.sleep(self._cleanup_interval)
                self._cleanup_old_jobs()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"JobExecutor cleanup error: {e}")

    def _cleanup_old_jobs(self) -> int:
        """Drop terminal jobs not updated within the cleanup interval."""
        cutoff = datetime.now() - timedelta(seconds=self._cleanup_interval)
        to_remove = [
            job_id
            for job_id, job in self._jobs.items()
            if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and job.updated_at < cutoff
        ]
        for job_id in to_remove:
            del self._jobs[job_id]
        if to_remove:
            logger.info(f"JobExecutor cleaned up {len(to_remove)} old jobs")
        return len(to_remove)
|
||||
159
app/core/execution/job.py
Normal file
159
app/core/execution/job.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Job 定义 - 所有后台异步任务的统一抽象"""
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.models.domain import ProxyRaw
|
||||
|
||||
|
||||
class JobStatus(Enum):
    """Lifecycle states of a background Job."""

    PENDING = "pending"        # created / submitted, not yet started
    RUNNING = "running"        # run() is currently executing
    COMPLETED = "completed"    # terminal: finished successfully
    FAILED = "failed"          # terminal: run() raised; message in Job.error
    CANCELLED = "cancelled"    # terminal: cancellation was requested
|
||||
|
||||
|
||||
@dataclass
class Job(ABC):
    """Base class for all background jobs: a small in-memory state machine."""

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0  # percentage, 0-100
    result: Any = None
    error: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    _cancelled: bool = field(default=False, repr=False)

    def cancel(self) -> None:
        """Request cancellation; non-terminal jobs flip to CANCELLED."""
        self._cancelled = True
        if self.status not in (JobStatus.PENDING, JobStatus.RUNNING):
            return
        self.status = JobStatus.CANCELLED
        self._touch()

    def _touch(self) -> None:
        """Refresh the last-modified timestamp."""
        self.updated_at = datetime.now()

    def _set_running(self) -> None:
        """Transition to RUNNING."""
        self.status = JobStatus.RUNNING
        self._touch()

    def _set_completed(self, result: Any = None) -> None:
        """Transition to COMPLETED, storing the result and full progress."""
        self.result = result
        self.progress = 100.0
        self.status = JobStatus.COMPLETED
        self._touch()

    def _set_failed(self, error: str) -> None:
        """Transition to FAILED, storing the error message."""
        self.error = error
        self.status = JobStatus.FAILED
        self._touch()

    @property
    def is_cancelled(self) -> bool:
        """True once cancel() has been requested."""
        return self._cancelled

    @abstractmethod
    async def run(self) -> Any:
        """Core task logic; subclasses must implement this."""
        raise NotImplementedError

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for API responses (datetimes as ISO-8601 strings)."""
        return {
            "id": self.id,
            "status": self.status.value,
            "progress": round(self.progress, 2),
            "result": self.result,
            "error": self.error,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
|
||||
|
||||
|
||||
@dataclass
class CrawlJob(Job):
    """Plugin crawl job: run one plugin, then hand its proxies to the validator pool."""

    plugin_id: str = ""
    plugin_runner: Any = field(repr=False, default=None)
    # NOTE(review): this field is named proxy_service but is used below as a
    # PluginService fallback — confirm the intended name/role against callers.
    proxy_service: Any = field(repr=False, default=None)
    validator_pool: Any = field(repr=False, default=None)

    async def run(self) -> Dict[str, Any]:
        # Local imports: presumably to avoid circular imports at module load — verify.
        from app.services.plugin_service import PluginService
        from app.core.log import logger

        self._set_running()
        if not self.plugin_runner:
            raise RuntimeError("plugin_runner is not set")

        # Falls back to a fresh PluginService when proxy_service is unset;
        # the object must expose get_plugin().
        plugin_service = self.proxy_service or PluginService()
        plugin = plugin_service.get_plugin(self.plugin_id)
        if not plugin:
            raise ValueError(f"Plugin '{self.plugin_id}' not found")

        result = await self.plugin_runner.run(plugin)
        proxies: List[ProxyRaw] = result.proxies if result else []

        # Hand every crawled proxy to the validator pool (if one is attached).
        if proxies and self.validator_pool:
            await self.validator_pool.submit(proxies)
            logger.info(f"CrawlJob {self.id}: submitted {len(proxies)} proxies for validation")

        payload = {
            "plugin_id": self.plugin_id,
            "proxy_count": len(proxies),
        }
        if result:
            payload["success_count"] = result.success_count
            payload["failure_count"] = result.failure_count
        self._set_completed(payload)
        return payload
|
||||
|
||||
|
||||
@dataclass
class ValidateAllJob(Job):
    """Full revalidation job - streams every stored proxy into the validator pool."""

    proxy_repo: Any = field(repr=False, default=None)
    validator_pool: Any = field(repr=False, default=None)
    batch_size: int = 100

    async def run(self) -> Dict[str, Any]:
        # Local imports: avoid import cycles at module load time.
        from app.repositories.proxy_repo import ProxyRepository
        from app.core.db import get_db
        from app.core.log import logger

        self._set_running()
        repo = self.proxy_repo or ProxyRepository()

        async with get_db() as db:
            proxies = await repo.list_all(db)

        # Empty database: nothing to do.
        if not proxies:
            self._set_completed({"total": 0, "submitted": 0})
            return self.result

        total = len(proxies)
        submitted = 0
        for start in range(0, total, self.batch_size):
            # Cooperative cancellation between batches.
            if self.is_cancelled:
                logger.info(f"ValidateAllJob {self.id}: cancelled")
                break
            batch = proxies[start : start + self.batch_size]
            raws = [ProxyRaw(p.ip, p.port, p.protocol) for p in batch]
            if self.validator_pool:
                await self.validator_pool.submit(raws)
            submitted += len(raws)
            self.progress = min(100.0, (submitted / total) * 100)
            self._touch()

        payload = {"total": total, "submitted": submitted}
        if self.is_cancelled:
            # Bug fix: the original unconditionally called _set_completed here,
            # clobbering the CANCELLED status that cancel() had already set.
            # Record the partial result but keep the job marked cancelled.
            self.result = payload
            self._touch()
        else:
            self._set_completed(payload)
        logger.info(f"ValidateAllJob {self.id}: submitted {submitted}/{total} proxies")
        return payload
|
||||
97
app/core/execution/worker_pool.py
Normal file
97
app/core/execution/worker_pool.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""轻量级纯内存异步 Worker Pool"""
|
||||
import asyncio
|
||||
from typing import Callable, Coroutine, List, TypeVar
|
||||
|
||||
from app.core.log import logger
|
||||
|
||||
T = TypeVar("T")
|
||||
Handler = Callable[[T], Coroutine[None, None, None]]
|
||||
|
||||
|
||||
class AsyncWorkerPool:
    """Lightweight, purely in-memory async worker pool.

    Single responsibility: accept a batch of items -> fan them out to N worker
    coroutines -> call the handler on each item. No persistence logic and no
    complex state machine.
    """

    def __init__(
        self,
        worker_count: int,
        handler: Handler,
        queue_maxsize: int = 10000,
        name: str = "WorkerPool",
    ):
        self.worker_count = worker_count
        self.handler = handler
        self.name = name
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
        self._workers: List[asyncio.Task] = []
        self._running = False
        self._shutdown_event = asyncio.Event()

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    async def start(self) -> None:
        """Spawn worker_count consumer coroutines (idempotent)."""
        if self._running:
            return
        self._running = True
        self._shutdown_event.clear()
        self._workers = [
            asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}")
            for i in range(self.worker_count)
        ]
        logger.info(f"{self.name} started with {self.worker_count} workers")

    async def stop(self) -> None:
        """Stop all workers by flooding sentinels and waiting for them to exit."""
        if not self._running:
            return
        self._running = False
        self._shutdown_event.set()
        # Enqueue enough sentinels that every worker is guaranteed to see one,
        # even with real items still queued ahead of them.
        for _ in range(len(self._workers) + self._queue.qsize()):
            try:
                self._queue.put_nowait(None)
            except asyncio.QueueFull:
                break
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info(f"{self.name} stopped")

    async def submit(self, items: List[T]) -> None:
        """Enqueue a batch of items (blocks when full - natural backpressure)."""
        for item in items:
            await self._queue.put(item)

    async def drain(self) -> None:
        """Wait until every queued item has been fully processed."""
        await self._queue.join()

    async def resize(self, new_worker_count: int) -> None:
        """Adjust the worker count: stop everything, then restart at the new size."""
        if new_worker_count == self.worker_count:
            return
        logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
        # Safe approach: full stop then start, so old and new workers never
        # race for sentinels (which could deadlock).
        await self.stop()
        self.worker_count = new_worker_count
        await self.start()

    async def _worker_loop(self, worker_id: int) -> None:
        """Consume items until a sentinel (None) or shutdown is observed."""
        while True:
            item = await self._queue.get()
            try:
                if item is None or not self._running:
                    # Bug fix: the original also called task_done() right here,
                    # then called it AGAIN in finally - the double count raises
                    # ValueError and corrupts the queue's unfinished-task
                    # counter, breaking drain()/join(). task_done() now runs
                    # exactly once per get(), in finally.
                    break
                await self.handler(item)
            except Exception as e:
                logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True)
            finally:
                self._queue.task_done()
|
||||
@@ -45,6 +45,11 @@ class PluginRegistry:
|
||||
def get_plugin_names(self) -> List[str]:
    """Return the names of all registered plugins, in registration order."""
    return [*self._plugins]
|
||||
|
||||
def clear(self) -> None:
    """Remove every registered plugin class and cached instance (mainly for tests)."""
    for store in (self._plugins, self._instances):
        store.clear()
|
||||
|
||||
def auto_discover(self, package_name: str):
|
||||
"""自动扫描指定包下的所有模块并注册其中的插件类。
|
||||
注意:为了类型安全和可控性,推荐显式注册。auto_discover 仅作为兼容。"""
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
"""验证任务队列 - 解耦爬取与验证,支持背压控制和持久化"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from app.models.domain import ProxyRaw
|
||||
from app.repositories.task_repo import ValidationTaskRepository
|
||||
from app.core.db import get_db
|
||||
from app.core.log import logger
|
||||
|
||||
|
||||
class ValidationQueue:
    """Proxy validation queue (persisted to SQLite).

    Workflow:
        1. Crawlers submit() raw proxies to the queue (DB write + in-memory signal)
        2. The worker pool consumes tasks from the DB and validates them
        3. Proxies that pass validation are written to the proxy table
        4. On service restart, unfinished 'pending' tasks are recovered automatically
    """

    def __init__(
        self,
        validator,
        proxy_repo,
        worker_count: int = 50,
        score_valid: int = 10,
        score_invalid: int = -5,
        score_min: int = 0,
        score_max: int = 100,
    ):
        self.validator = validator
        self.proxy_repo = proxy_repo
        self.task_repo = ValidationTaskRepository()
        self.worker_count = worker_count
        self.score_valid = score_valid
        self.score_invalid = score_invalid
        self.score_min = score_min
        self.score_max = score_max

        # In-memory wake-up signals only; the task payloads live in SQLite.
        self._signal: asyncio.Queue[None] = asyncio.Queue()
        self._workers: list[asyncio.Task] = []
        self._running = False
        # Count of tasks believed pending; guarded by self._condition.
        self._pending_count = 0
        self._condition = asyncio.Condition()

        # Statistics counters.
        self.valid_count = 0
        self.invalid_count = 0

    async def start(self):
        """Recover interrupted tasks, spawn workers, and wake them if work is pending."""
        if self._running:
            return
        self._running = True

        # Recover tasks left in 'processing' state by a previous run.
        async with get_db() as db:
            recovered = await self.task_repo.reset_processing(db)
            pending = await self.task_repo.get_pending_count(db)
            # Safety valve: an oversized backlog is dropped rather than replayed.
            if pending > 1000:
                logger.warning(f"ValidationQueue has {pending} pending tasks, cleaning up all pending tasks...")
                await db.execute("DELETE FROM validation_tasks WHERE status = 'pending'")
                await db.commit()
                pending = await self.task_repo.get_pending_count(db)
                logger.info(f"ValidationQueue cleaned up pending tasks, remaining: {pending}")
        if recovered:
            logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
        if pending:
            logger.info(f"ValidationQueue has {pending} pending tasks to process")
        async with self._condition:
            self._pending_count = pending

        for i in range(self.worker_count):
            self._workers.append(asyncio.create_task(self._worker_loop(i)))

        # Wake workers to process recovered pending tasks (one wake-up per worker
        # suffices - each worker then loops internally until no tasks remain).
        if pending:
            for _ in range(self.worker_count):
                self._signal.put_nowait(None)

        logger.info(f"ValidationQueue started with {self.worker_count} workers")

    async def stop(self):
        """Stop all workers by sending one sentinel per worker and waiting."""
        if not self._running:
            return
        self._running = False
        for _ in self._workers:
            self._signal.put_nowait(None)  # sentinel
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("ValidationQueue stopped")

    async def submit(self, proxies: list[ProxyRaw]):
        """Submit proxies to the validation queue (persist + wake workers)."""
        async with get_db() as db:
            inserted = await self.task_repo.insert_batch(db, proxies)
        if inserted:
            async with self._condition:
                self._pending_count += inserted
                self._condition.notify_all()
            # Wake at most worker_count workers; each drains tasks in a loop.
            for _ in range(min(inserted, self.worker_count)):
                self._signal.put_nowait(None)

    async def submit_one(self, proxy: ProxyRaw):
        """Convenience wrapper: submit a single proxy."""
        await self.submit([proxy])

    async def drain(self):
        """Block until all currently-pending tasks have been processed."""
        async with self._condition:
            if self._pending_count > 0:
                await self._condition.wait_for(lambda: self._pending_count == 0)

    async def _worker_loop(self, worker_id: int):
        """Worker coroutine: wait for a signal, then drain pending tasks."""
        while True:
            await self._signal.get()
            self._signal.task_done()
            if not self._running:
                break
            # Keep processing until no pending tasks remain, so tasks cannot
            # starve when fewer signals than tasks were emitted.
            while self._running:
                processed = await self._process_one_task(worker_id)
                if not processed:
                    break

    async def _process_one_task(self, worker_id: int) -> bool:
        """Acquire one task from the DB and validate it. Returns True if a task was processed."""
        async with get_db() as db:
            task = await self.task_repo.acquire_pending(db)
            if not task:
                return False

            proxy = ProxyRaw(task["ip"], task["port"], task["protocol"])
            try:
                is_valid, latency = await self.validator.validate(
                    proxy.ip, proxy.port, proxy.protocol
                )
            except Exception as e:
                # A validator crash is treated as an invalid proxy, not a worker death.
                logger.error(f"Worker {worker_id} validation error: {e}", exc_info=True)
                is_valid, latency = False, 0.0

            if is_valid:
                await self.proxy_repo.insert_or_update(
                    db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid
                )
                if latency:
                    await self.proxy_repo.update_response_time(
                        db, proxy.ip, proxy.port, latency
                    )
                await self.task_repo.complete_task(db, task["id"], True, latency)
                self.valid_count += 1
                logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}")
            else:
                # Penalize the existing proxy; it is auto-deleted once score <= 0.
                await self.proxy_repo.update_score(
                    db, proxy.ip, proxy.port, self.score_invalid,
                    self.score_min, self.score_max
                )
                await self.task_repo.complete_task(db, task["id"], False, 0.0)
                self.invalid_count += 1
                logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
        # Decrement the pending counter and wake any drain() waiters at zero.
        async with self._condition:
            self._pending_count = max(0, self._pending_count - 1)
            if self._pending_count == 0:
                self._condition.notify_all()
        return True

    def reset_stats(self):
        """Reset the validity statistics counters."""
        self.valid_count = 0
        self.invalid_count = 0
|
||||
Reference in New Issue
Block a user