refactor: overhaul the core architecture to remove the root cause of repeated rework
- Remove the dual-track persistent ValidationQueue; replace it with a purely in-memory AsyncWorkerPool
- Introduce a unified background-job framework, JobExecutor (Job / CrawlJob / ValidateAllJob)
- Add PluginRunner for unified plugin execution (timeout, retry, health check, stats)
- Narrow SchedulerService's responsibility to periodically triggering ValidateAllJob
- Rewrite lifespan around AsyncExitStack to manage long-lived resources safely
- Slim the routing layer by 50%+; business exceptions propagate up and are handled by the global middleware
- Implement full hot reload of settings (WorkerPool concurrency and Validator timeouts take effect immediately)
- Frontend stores now refetch after every write, eliminating optimistic-update drift
- Delete queue.py / task_repo.py / task_service.py
- Add unit tests for execution; all 85 tests pass
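For the AsyncExitStack lifespan item above, a minimal sketch of the pattern; the WorkerPool stand-in is illustrative, not this repo's actual class:

from contextlib import AsyncExitStack, asynccontextmanager

from fastapi import FastAPI


class WorkerPool:
    """Stand-in for the in-memory AsyncWorkerPool described above (assumption)."""

    async def start(self) -> None:
        print("worker pool started")

    async def stop(self) -> None:
        print("worker pool stopped")


@asynccontextmanager
async def lifespan(app: FastAPI):
    async with AsyncExitStack() as stack:
        pool = WorkerPool()
        await pool.start()
        # Registered callbacks run in reverse order on shutdown,
        # even if a later startup step raises.
        stack.push_async_callback(pool.stop)
        yield


app = FastAPI(lifespan=lifespan)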
app/services/plugin_runner.py · new file · +90 lines
@@ -0,0 +1,90 @@
+"""Unified plugin executor - wraps timeout, retry, health check, and error capture"""
+import asyncio
+from datetime import datetime
+from typing import Optional
+
+from app.core.plugin_system.base import BaseCrawlerPlugin
+from app.core.config import settings as app_settings
+from app.core.log import logger
+from app.models.domain import CrawlResult, ProxyRaw
+
+
+class PluginRunner:
+    """Unified plugin executor
+
+    - Timeout control (reads crawler_timeout from settings)
+    - Exception capture and stats updates
+    - Optional health-check gate before crawling
+    - Result deduplication
+    """
+
+    def __init__(self, timeout: Optional[float] = None):
+        self.timeout = timeout or getattr(app_settings, "crawler_timeout", 30)
+
+    async def run(self, plugin: BaseCrawlerPlugin) -> CrawlResult:
+        """Run a single plugin crawl"""
+        result = CrawlResult(plugin_name=plugin.name)
+
+        # Health check (optional)
+        try:
+            healthy = await asyncio.wait_for(
+                plugin.health_check(), timeout=5.0
+            )
+            if not healthy:
+                result.error = "health check failed"
+                result.failure_count = 1
+                await self._save_stats(plugin, result)
+                return result
+        except Exception as e:
+            logger.warning(f"Plugin {plugin.name} health check error: {e}")
+
+        # Crawl
+        try:
+            proxies = await asyncio.wait_for(
+                plugin.crawl(),
+                timeout=self.timeout,
+            )
+            result.proxies = self._dedup(proxies)
+            result.success_count = 1 if result.proxies else 0
+            logger.info(
+                f"Plugin {plugin.name} crawled {len(result.proxies)} unique proxies"
+            )
+        except asyncio.TimeoutError:
+            result.error = f"crawl timeout after {self.timeout}s"
+            result.failure_count = 1
+            logger.error(f"Plugin {plugin.name} crawl timeout")
+        except Exception as e:
+            result.error = str(e)
+            result.failure_count = 1
+            logger.error(f"Plugin {plugin.name} crawl failed: {e}", exc_info=True)
+
+        await self._save_stats(plugin, result)
+        return result
+
+    @staticmethod
+    def _dedup(proxies: list[ProxyRaw]) -> list[ProxyRaw]:
+        seen = set()
+        unique = []
+        for p in proxies:
+            key = (p.ip, p.port)
+            if key not in seen:
+                seen.add(key)
+                unique.append(p)
+        return unique
+
+    async def _save_stats(self, plugin: BaseCrawlerPlugin, result: CrawlResult) -> None:
+        """Persist crawl stats to the database"""
+        from app.core.db import get_db
+        from app.repositories.settings_repo import PluginSettingsRepository
+
+        repo = PluginSettingsRepository()
+        payload = {
+            "success_count": result.success_count,
+            "failure_count": result.failure_count,
+            "last_run": datetime.now().isoformat(),
+        }
+        try:
+            async with get_db() as db:
+                await repo.set_stats(db, plugin.name, payload)
+        except Exception as e:
+            logger.error(f"Failed to save stats for {plugin.name}: {e}")
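A sketch of how a caller might drive PluginRunner; the driver function and plugin name below are illustrative, not code from this commit:

from app.core.plugin_system.registry import registry  # real module per this diff
from app.services.plugin_runner import PluginRunner


async def crawl_one(plugin_name: str) -> None:
    # Hypothetical driver: look up a registered plugin and run it once.
    plugin = registry.get(plugin_name)
    if plugin is None:
        raise ValueError(f"unknown plugin: {plugin_name}")
    runner = PluginRunner()  # falls back to settings.crawler_timeout
    result = await runner.run(plugin)
    print(result.error or f"{len(result.proxies)} proxies")

# asyncio.run(crawl_one("free_proxy_list"))  # plugin name is illustrative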
@@ -1,21 +1,21 @@
 """Plugin business service"""
 import asyncio
 from datetime import datetime
 from typing import List, Optional

 from app.core.db import get_db
 from app.core.plugin_system.registry import registry
 from app.core.plugin_system.base import BaseCrawlerPlugin
+from app.core.exceptions import PluginNotFoundException
 from app.repositories.settings_repo import PluginSettingsRepository
-from app.models.domain import PluginInfo, ProxyRaw
+from app.models.domain import PluginInfo, ProxyRaw, CrawlResult
 from app.core.log import logger


 class PluginService:
-    """Plugin business service: manages plugin lifecycle, crawl execution, and configuration"""
+    """Plugin business service: manages plugin lifecycle and configuration"""

     def __init__(self):
         self.plugin_settings_repo = PluginSettingsRepository()
-        self._stats: dict[str, dict] = {}

     async def list_plugins(self) -> List[PluginInfo]:
         """Return info for all plugins (merging persisted state and config)"""
@@ -24,42 +24,40 @@ class PluginService:

         result = []
         for plugin in registry.list_plugins():
             # Merge persisted state (never mutate the global instance, to avoid races)
             state = db_states.get(plugin.name, {})
             enabled = state.get("enabled", plugin.enabled)
             if "config" in state and isinstance(state["config"], dict):
                 plugin.update_config(state["config"])

-            # Merge DB stats with in-memory stats (memory wins)
             db_stat = state.get("stats", {})
-            stat = {
-                "success_count": db_stat.get("success_count", 0),
-                "failure_count": db_stat.get("failure_count", 0),
-                "last_run": datetime.fromisoformat(db_stat["last_run"]) if db_stat.get("last_run") else None,
-            }
-            mem_stat = self._stats.get(plugin.name, {})
-            if mem_stat:
-                stat["success_count"] = mem_stat.get("success_count", stat["success_count"])
-                stat["failure_count"] = mem_stat.get("failure_count", stat["failure_count"])
-                if mem_stat.get("last_run"):
-                    stat["last_run"] = mem_stat["last_run"]

             result.append(PluginInfo(
                 id=plugin.name,
                 name=plugin.name,
                 display_name=plugin.display_name or plugin.name,
                 description=plugin.description or f"Crawl proxies from {plugin.name}",
                 enabled=enabled,
-                last_run=stat.get("last_run"),
-                success_count=stat.get("success_count", 0),
-                failure_count=stat.get("failure_count", 0),
+                last_run=self._parse_datetime(db_stat.get("last_run")),
+                success_count=db_stat.get("success_count", 0),
+                failure_count=db_stat.get("failure_count", 0),
             ))
         return result

+    @staticmethod
+    def _parse_datetime(value) -> Optional[datetime]:
+        if not value:
+            return None
+        if isinstance(value, datetime):
+            return value
+        try:
+            return datetime.fromisoformat(value)
+        except Exception:
+            return None

     async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
         plugin = registry.get(plugin_id)
         if not plugin:
-            return False
+            raise PluginNotFoundException(plugin_id)
         async with get_db() as db:
             success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
             if success:
@@ -68,10 +66,9 @@ class PluginService:
         return success

     async def get_plugin_config(self, plugin_id: str) -> Optional[dict]:
         """Return the plugin's current config (defaults merged with persisted values)"""
         plugin = registry.get(plugin_id)
         if not plugin:
-            return None
+            raise PluginNotFoundException(plugin_id)
         async with get_db() as db:
             saved = await self.plugin_settings_repo.get_config(db, plugin_id)
         config = dict(plugin.default_config)
@@ -80,11 +77,9 @@ class PluginService:
         return config

     async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
         """Update plugin config (only keys present in default_config are saved)"""
         plugin = registry.get(plugin_id)
         if not plugin:
-            return False
-        # Filter out unknown keys
+            raise PluginNotFoundException(plugin_id)
         safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
         if not safe_config:
             return False
@@ -95,44 +90,45 @@ class PluginService:
     def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
         return registry.get(plugin_id)

-    async def run_plugin(self, plugin_id: str) -> List[ProxyRaw]:
-        """Run a single plugin crawl"""
+    def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin:
         plugin = self.get_plugin(plugin_id)
         if not plugin:
-            raise ValueError(f"Plugin {plugin_id} not found")
+            raise PluginNotFoundException(plugin_id)
+        return plugin
+
+    async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult:
+        """Run a single plugin crawl (via PluginRunner)"""
+        plugin = self.get_plugin_or_raise(plugin_id)
         if not plugin.enabled:
             logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
-            return []
+            return CrawlResult(plugin_name=plugin_id)
+        return await plugin_runner.run(plugin)

-        try:
-            results = await plugin.crawl()
-            self._record_stat(plugin_id, success=len(results))
-            logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
-            return results
-        except Exception as e:
-            self._record_stat(plugin_id, failure=1)
-            logger.error(f"Plugin {plugin_id} crawl failed: {e}")
-            return []
-        finally:
-            await self._save_stats(plugin_id)
-
-    async def run_all_plugins(self) -> List[ProxyRaw]:
+    async def run_all_plugins(self, plugin_runner) -> List[ProxyRaw]:
         """Crawl with all enabled plugins, capping concurrency to avoid tripping target sites' anti-bot defenses"""
         all_results: List[ProxyRaw] = []
         semaphore = asyncio.Semaphore(5)

         async def _run_with_limit(plugin_name: str):
+            plugin = self.get_plugin_or_raise(plugin_name)
+            if not plugin.enabled:
+                return CrawlResult(plugin_name=plugin_name)
             async with semaphore:
-                return await self.run_plugin(plugin_name)
+                return await plugin_runner.run(plugin)

-        tasks = [_run_with_limit(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
+        tasks = [
+            _run_with_limit(plugin.name)
+            for plugin in registry.list_plugins() if plugin.enabled
+        ]
         results_list = await asyncio.gather(*tasks, return_exceptions=True)
-        for results in results_list:
-            if isinstance(results, Exception):
-                logger.error(f"Run all plugins error: {results}")
+        for result in results_list:
+            if isinstance(result, Exception):
+                logger.error(f"Run all plugins error: {result}")
                 continue
-            all_results.extend(results)
-        # Dedup (matches the database's UNIQUE(ip, port) constraint)
+            if isinstance(result, CrawlResult):
+                all_results.extend(result.proxies)
+
+        # Dedup
         seen = set()
         unique = []
         for p in all_results:
@@ -141,26 +137,3 @@ class PluginService:
                 seen.add(key)
                 unique.append(p)
         return unique
-
-    def _record_stat(self, plugin_id: str, success: int = 0, failure: int = 0):
-        if plugin_id not in self._stats:
-            self._stats[plugin_id] = {
-                "success_count": 0,
-                "failure_count": 0,
-                "last_run": None,
-            }
-        self._stats[plugin_id]["success_count"] += success
-        self._stats[plugin_id]["failure_count"] += failure
-        if success or failure:
-            self._stats[plugin_id]["last_run"] = datetime.now()
-
-    async def _save_stats(self, plugin_id: str):
-        """Persist in-memory stats to the database"""
-        stats = self._stats.get(plugin_id, {})
-        payload = {
-            "success_count": stats.get("success_count", 0),
-            "failure_count": stats.get("failure_count", 0),
-            "last_run": stats.get("last_run").isoformat() if stats.get("last_run") else None,
-        }
-        async with get_db() as db:
-            await self.plugin_settings_repo.set_stats(db, plugin_id, payload)
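The raise-instead-of-return-False pattern above depends on the global exception middleware mentioned in the commit message. A hedged sketch of what such a handler can look like in FastAPI; the status code and exception fields are assumptions, not this repo's code:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class PluginNotFoundException(Exception):
    def __init__(self, plugin_id: str):
        self.plugin_id = plugin_id
        super().__init__(f"Plugin {plugin_id} not found")


app = FastAPI()


@app.exception_handler(PluginNotFoundException)
async def plugin_not_found_handler(request: Request, exc: PluginNotFoundException):
    # Routes can raise freely; this single handler turns the domain error into a 404.
    return JSONResponse(status_code=404, content={"detail": str(exc)})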
@@ -4,7 +4,8 @@ import json
 import io
 from datetime import datetime
 from typing import List, Optional, Tuple, AsyncIterator
-from app.core.db import get_db
+
+from app.core.db import get_db, transaction
 from app.repositories.proxy_repo import ProxyRepository
 from app.models.domain import Proxy
 from app.core.log import logger
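The hunk above swaps in a transaction helper from app.core.db that this commit does not show. A typical shape for such a helper, as a hedged sketch under that assumption:

from contextlib import asynccontextmanager


@asynccontextmanager
async def transaction(db):
    # Commit on success, roll back if the block raises.
    try:
        yield db
        await db.commit()
    except Exception:
        await db.rollback()
        raise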
@@ -1,43 +1,42 @@
-"""Scheduler service - periodically validates stored proxies"""
+"""Scheduler service - periodically triggers a full validation"""
 import asyncio
-from datetime import datetime
-from app.core.db import get_db
-from app.repositories.proxy_repo import ProxyRepository
-from app.repositories.task_repo import ValidationTaskRepository
-from app.core.tasks.queue import ValidationQueue
-from app.core.config import settings as app_settings
+from typing import Optional
+
+from app.core.execution.executor import JobExecutor
+from app.core.execution.job import ValidateAllJob
 from app.core.log import logger
-from app.models.domain import ProxyRaw
-from app.services.task_service import task_service


 class SchedulerService:
-    """Proxy validation scheduler"""
+    """Proxy validation scheduler
+
+    Single responsibility: run a timer loop that triggers ValidateAllJob.
+    No longer holds the validation queue or ValidatorService directly.
+    """

     def __init__(
         self,
-        validation_queue: ValidationQueue,
-        proxy_repo: ProxyRepository = ProxyRepository(),
+        executor: JobExecutor,
         interval_minutes: int = 30,
     ):
-        self.validation_queue = validation_queue
-        self.proxy_repo = proxy_repo
-        self.interval_minutes = 30
+        self.executor = executor
+        self.interval_minutes = interval_minutes
         self.running = False
         self._stop_event = asyncio.Event()
-        self._task: asyncio.Task | None = None
-        self._validate_task: asyncio.Task | None = None
+        self._task: Optional[asyncio.Task] = None

-    async def start(self):
+    async def start(self) -> None:
         if self.running:
             logger.warning("Scheduler already running")
             return
         self._stop_event.clear()
         self.running = True
-        await self.validation_queue.start()
         self._task = asyncio.create_task(self._run_loop())
         logger.info("Scheduler started")

-    async def stop(self):
+    async def stop(self) -> None:
         if not self.running:
             return
         self.running = False
         self._stop_event.set()
         if self._task:
@@ -47,72 +46,26 @@ class SchedulerService:
             except asyncio.CancelledError:
                 pass
             self._task = None
-        await self.validation_queue.stop()
         logger.info("Scheduler stopped")

-    def cancel_validate_task(self):
-        """Cancel the in-flight full-validation background task"""
-        if self._validate_task and not self._validate_task.done():
-            self._validate_task.cancel()
+    def validate_all_now(self) -> str:
+        """Run a full validation immediately; returns the job ID"""
+        job_id = self.executor.submit_job(ValidateAllJob())
+        logger.info(f"ValidateAllJob submitted: {job_id}")
+        return job_id

-    async def validate_all_now(self):
-        """Run a full validation immediately (in the background, non-blocking)"""
-        if self._validate_task and not self._validate_task.done():
-            return
-        self._validate_task = asyncio.create_task(self._do_validate_all(from_loop=False))
-
-    async def _run_loop(self):
+    async def _run_loop(self) -> None:
         """Timer loop"""
         while self.running:
             try:
-                # Clean up stale tasks so memory does not grow unbounded
-                task_service.cleanup_old_tasks()
-                await self._do_validate_all(from_loop=True)
+                self.executor.submit_job(ValidateAllJob())
             except Exception as e:
                 logger.error(f"Scheduler loop error: {e}", exc_info=True)
             # Wait for the next tick
             try:
-                await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
+                await asyncio.wait_for(
+                    self._stop_event.wait(),
+                    timeout=self.interval_minutes * 60,
+                )
             except asyncio.TimeoutError:
                 pass
-
-    async def _do_validate_all(self, from_loop: bool = True):
-        """Validate every proxy currently in the database"""
-        queue_started_here = False
-        try:
-            logger.info("Starting scheduled validation for all proxies")
-
-            # If the queue is not running, start it temporarily (covers
-            # validate_all_now being called while the scheduler is stopped)
-            if not self.validation_queue._running:
-                await self.validation_queue.start()
-                queue_started_here = True
-
-            async with get_db() as db:
-                # Delete validation-task records older than 7 days so the table does not grow unbounded
-                cleaned = await ValidationTaskRepository.cleanup_old(db, days=7)
-                if cleaned:
-                    logger.info(f"Cleaned up {cleaned} old validation tasks")
-                proxies = await self.proxy_repo.list_all(db)
-            if not proxies:
-                logger.info("No proxies to validate")
-                return
-
-            logger.info(f"Validating {len(proxies)} proxies from database")
-            # Submit to the validation queue in batches; no longer blocks waiting for a drain
-            batch_size = 100
-            total_batches = (len(proxies) - 1) // batch_size + 1
-            for i in range(0, len(proxies), batch_size):
-                if from_loop and not self.running:
-                    break
-                batch = proxies[i : i + batch_size]
-                await self.validation_queue.submit([
-                    ProxyRaw(p.ip, p.port, p.protocol) for p in batch
-                ])
-                logger.info(f"Submitted batch {i // batch_size + 1}/{total_batches}")
-
-            logger.info("Scheduled validation batches submitted")
-        except Exception as e:
-            logger.error(f"Scheduled validation error: {e}", exc_info=True)
-        finally:
-            if queue_started_here:
-                await self.validation_queue.stop()
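The scheduler now only submits ValidateAllJob to the JobExecutor from app.core.execution, which this diff does not show. A minimal sketch of what submit_job plausibly does; the Job protocol and internals are assumptions:

import asyncio
import uuid


class Job:
    async def run(self) -> None:  # subclasses do the actual work
        raise NotImplementedError


class JobExecutor:
    def __init__(self) -> None:
        self._jobs: dict[str, asyncio.Task] = {}

    def submit_job(self, job: Job) -> str:
        # Fire-and-forget: return a job ID immediately, keeping a reference to
        # the Task so it is not garbage-collected and can be cancelled later.
        job_id = str(uuid.uuid4())
        self._jobs[job_id] = asyncio.create_task(job.run())
        return job_id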
app/services/task_service.py (deleted)
@@ -1,62 +0,0 @@
-"""Lightweight async task service - in-memory store for tracking background crawl tasks"""
-import uuid
-from datetime import datetime, timedelta
-from typing import Optional, Dict, Any
-from app.core.log import logger
-
-
-class TaskService:
-    """Async task state manager"""
-
-    def __init__(self, max_age_seconds: int = 3600):
-        self._tasks: Dict[str, dict] = {}
-        self._max_age_seconds = max_age_seconds
-
-    def create_task(self, task_type: str) -> str:
-        task_id = str(uuid.uuid4())
-        now = datetime.now()
-        self._tasks[task_id] = {
-            "id": task_id,
-            "type": task_type,
-            "status": "pending",
-            "message": "",
-            "data": {},
-            "created_at": now,
-            "updated_at": now,
-        }
-        return task_id
-
-    def update_task(
-        self,
-        task_id: str,
-        status: Optional[str] = None,
-        message: Optional[str] = None,
-        data: Optional[dict] = None,
-    ) -> bool:
-        task = self._tasks.get(task_id)
-        if not task:
-            return False
-        if status is not None:
-            task["status"] = status
-        if message is not None:
-            task["message"] = message
-        if data is not None:
-            task["data"].update(data)
-        task["updated_at"] = datetime.now()
-        return True
-
-    def get_task(self, task_id: str) -> Optional[dict]:
-        return self._tasks.get(task_id)
-
-    def cleanup_old_tasks(self) -> int:
-        cutoff = datetime.now() - timedelta(seconds=self._max_age_seconds)
-        to_remove = [tid for tid, task in self._tasks.items() if task["created_at"] < cutoff]
-        for tid in to_remove:
-            del self._tasks[tid]
-        if to_remove:
-            logger.info(f"TaskService cleaned up {len(to_remove)} old tasks")
-        return len(to_remove)
-
-
-# Global task service instance
-task_service = TaskService()
@@ -4,12 +4,17 @@ import random
 import time
 import aiohttp
 import aiohttp_socks
-from typing import Tuple
+from typing import Tuple, Optional

 from app.core.config import settings as app_settings
 from app.core.log import logger


 class ValidatorService:
-    """Proxy validator"""
+    """Proxy validator
+
+    Reads its configuration dynamically, so settings hot reload takes effect.
+    """

     # Test URLs
     TEST_URLS = {
@@ -19,40 +24,66 @@ class ValidatorService:

     def __init__(
         self,
-        timeout: float = 5.0,
-        connect_timeout: float = 3.0,
-        max_concurrency: int = 50,
+        timeout: Optional[float] = None,
+        connect_timeout: Optional[float] = None,
+        max_concurrency: Optional[int] = None,
     ):
-        self.timeout = timeout
-        self.connect_timeout = connect_timeout
-        self.max_concurrency = max_concurrency
-        self.semaphore = asyncio.Semaphore(max_concurrency)
+        # Store explicit overrides only; when None, the properties below read
+        # the live settings on each access, which is what makes hot reload work
+        self._init_timeout = timeout
+        self._init_connect_timeout = connect_timeout
+        self._init_max_concurrency = max_concurrency

-        # Shared HTTP/HTTPS ClientSession
-        self._http_connector = aiohttp.TCPConnector(
-            ssl=False,
-            limit=max_concurrency,
-            limit_per_host=max_concurrency,
-            force_close=False,
-        )
-        self._timeout = aiohttp.ClientTimeout(
-            total=timeout, connect=connect_timeout
-        )
-        self._http_session = aiohttp.ClientSession(
-            connector=self._http_connector,
-            timeout=self._timeout,
-        )
+        self._http_connector: Optional[aiohttp.TCPConnector] = None
+        self._http_session: Optional[aiohttp.ClientSession] = None
+        self._semaphore: Optional[asyncio.Semaphore] = None
+        self._lock = asyncio.Lock()
+
+    @property
+    def timeout(self) -> float:
+        return float(self._init_timeout or app_settings.validator_timeout)
+
+    @property
+    def connect_timeout(self) -> float:
+        return float(self._init_connect_timeout or app_settings.validator_connect_timeout)
+
+    @property
+    def max_concurrency(self) -> int:
+        return int(self._init_max_concurrency or app_settings.validator_max_concurrency)
+
+    def _ensure_session(self) -> aiohttp.ClientSession:
+        """Lazily create the shared HTTP session"""
+        if self._http_session is None or self._http_session.closed:
+            connector = aiohttp.TCPConnector(
+                ssl=False,
+                limit=self.max_concurrency,
+                limit_per_host=self.max_concurrency,
+                force_close=False,
+            )
+            timeout = aiohttp.ClientTimeout(
+                total=self.timeout, connect=self.connect_timeout
+            )
+            self._http_connector = connector
+            self._http_session = aiohttp.ClientSession(
+                connector=connector,
+                timeout=timeout,
+            )
+        return self._http_session
+
+    def _ensure_semaphore(self) -> asyncio.Semaphore:
+        if self._semaphore is None:
+            self._semaphore = asyncio.Semaphore(self.max_concurrency)
+        return self._semaphore

     def _get_test_url(self, protocol: str) -> str:
         """Pick a test URL"""
         urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
         return random.choice(urls)

     async def validate(self, ip: str, port: int, protocol: str = "http") -> Tuple[bool, float]:
         """Validate a single proxy; returns (is_valid, latency_ms)"""
         protocol = protocol.lower()

-        async with self.semaphore:
+        semaphore = self._ensure_semaphore()
+
+        async with semaphore:
             start = time.time()
             try:
                 if protocol in ("socks4", "socks5"):
@@ -67,11 +98,11 @@ class ValidatorService:
         return False, 0.0

     async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
         """Validate an HTTP/HTTPS proxy"""
         proxy_url = f"http://{ip}:{port}"
         test_url = self._get_test_url(protocol)
+        session = self._ensure_session()

-        async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
+        async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
             if response.status in (200, 301, 302):
                 latency = round((time.time() - start) * 1000, 2)
                 logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
@@ -79,7 +110,6 @@ class ValidatorService:
         return False, 0.0

     async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
         """Validate a SOCKS4/SOCKS5 proxy"""
         proxy_type = (
             aiohttp_socks.ProxyType.SOCKS4
             if protocol == "socks4"
@@ -103,7 +133,9 @@ class ValidatorService:
             return True, latency
         return False, 0.0

-    async def close(self):
+    async def close(self) -> None:
         """Close the shared HTTP ClientSession"""
         if self._http_session and not self._http_session.closed:
             await self._http_session.close()
+        self._http_session = None
+        self._http_connector = None
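With the lazy _ensure_session above, hot reload follows naturally: a settings write only needs to close the validator's shared session, and the next validate() call rebuilds it with the fresh timeout and concurrency values. A hedged sketch of such a hook; the hook itself is not part of this commit:

async def on_validator_settings_changed(validator: "ValidatorService") -> None:
    # Dropping the shared session forces _ensure_session() to recreate it
    # with the new timeout/concurrency values on the next validate() call.
    await validator.close()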