refactor: overhaul the core architecture to eliminate the root cause of repeated rework

- Remove the dual-track persistent ValidationQueue; replace it with a purely in-memory AsyncWorkerPool
- Introduce a unified background-task framework, JobExecutor (Job/CrawlJob/ValidateAllJob)
- Add PluginRunner to unify plugin execution (timeouts, retries, health checks, stats)
- Narrow SchedulerService's responsibility to periodically triggering ValidateAllJob
- Rework lifespan with AsyncExitStack to manage long-lived resources safely
- Slim the route layer by 50%+; business exceptions propagate to a global middleware for uniform handling
- Implement full hot-reload of settings (WorkerPool concurrency and Validator timeouts take effect immediately)
- Force the frontend store to re-fetch after writes, eliminating optimistic-update drift
- Delete queue.py / task_repo.py / task_service.py
- Add unit tests for execution; all 85 tests pass
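A sketch of how the new pieces compose, mirroring the lifespan wiring in the diff below (names follow this diff; the handler body is a stub):

# Sketch only: one worker pool validates proxies, one executor tracks jobs.
import asyncio
from app.core.execution import AsyncWorkerPool, JobExecutor

async def main() -> None:
    async def handler(proxy) -> None:
        ...  # the real handler validates the proxy and updates its score in the DB

    async with AsyncWorkerPool(worker_count=50, handler=handler, name="ValidationPool") as pool:
        async with JobExecutor(worker_pool=pool, max_concurrent_jobs=10) as executor:
            # Routes create CrawlJob / ValidateAllJob instances, call
            # executor.submit_job(job), and return the job id for polling.
            pass

asyncio.run(main())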
@@ -3,10 +3,8 @@ from fastapi import Request
from app.services.proxy_service import ProxyService
from app.services.plugin_service import PluginService
from app.services.scheduler_service import SchedulerService
from app.services.validator_service import ValidatorService
from app.repositories.proxy_repo import ProxyRepository
from app.core.tasks.queue import ValidationQueue
from app.core.config import settings as app_settings
from app.services.plugin_runner import PluginRunner
from app.core.execution import JobExecutor, AsyncWorkerPool


def get_proxy_service() -> ProxyService:
@@ -18,36 +16,16 @@ def get_plugin_service() -> PluginService:


def get_scheduler_service(request: Request) -> SchedulerService:
    return request.app.state.scheduler_service
    return request.app.state.scheduler


def get_validation_queue(request: Request) -> ValidationQueue:
    return request.app.state.validation_queue
def get_executor(request: Request) -> JobExecutor:
    return request.app.state.executor


def create_scheduler_service(db_settings: dict | None = None) -> SchedulerService:
    """Create SchedulerService at application startup (outside a request context).

    db_settings comes from persisted database settings and overrides .env / default values.
    """
    db_settings = db_settings or {}
    validator = ValidatorService(
        timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
        connect_timeout=db_settings.get("connect_timeout", app_settings.validator_connect_timeout),
        max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
    )
    proxy_repo = ProxyRepository()
    queue = ValidationQueue(
        validator=validator,
        proxy_repo=proxy_repo,
        worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
        score_valid=app_settings.score_valid,
        score_invalid=app_settings.score_invalid,
        score_min=app_settings.score_min,
        score_max=app_settings.score_max,
    )
    svc = SchedulerService(validation_queue=queue, proxy_repo=proxy_repo)
    svc.interval_minutes = db_settings.get(
        "validate_interval_minutes", 30
    )
    return svc
def get_plugin_runner(request: Request) -> PluginRunner:
    return request.app.state.plugin_runner


def get_worker_pool(request: Request) -> AsyncWorkerPool:
    return request.app.state.worker_pool

@@ -1,56 +1,118 @@
"""Application lifecycle management"""
import asyncio
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI

from app.core.db import init_db, get_db
from app.core.config import settings as app_settings
from app.core.log import logger
from app.api.deps import create_scheduler_service
from app.core.execution import AsyncWorkerPool, JobExecutor
from app.core.plugin_system.registry import registry
from app.repositories.proxy_repo import ProxyRepository
from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS
from app.services.validator_service import ValidatorService
from app.services.plugin_runner import PluginRunner
from app.services.scheduler_service import SchedulerService

settings_repo = SettingsRepository()
proxy_repo = ProxyRepository()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for application startup and shutdown"""
    # Initialize the database
    await init_db()

    # Load settings and decide whether to start the scheduler
async def _load_settings() -> dict:
    db_settings = DEFAULT_SETTINGS.copy()
    try:
        async with get_db() as db:
            db_settings = await settings_repo.get_all(db)
    except Exception as e:
        logger.error(f"Failed to load settings on startup: {e}")
    return db_settings

    # Create the scheduler and attach it to app.state (DB settings override defaults)
    scheduler_service = create_scheduler_service(db_settings)
    app.state.scheduler_service = scheduler_service
    app.state.validation_queue = scheduler_service.validation_queue

    if db_settings.get("auto_validate", True):
        try:
            await scheduler_service.start()
        except Exception as e:
            logger.error(f"Failed to start scheduler on startup: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for application startup and shutdown"""
    await init_db()
    db_settings = await _load_settings()

    logger.info("API server started")
    yield
    async with AsyncExitStack() as stack:
        # Validator
        validator = ValidatorService(
            timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
            connect_timeout=app_settings.validator_connect_timeout,
            max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
        )

    # Shut down the scheduler
    scheduler_service.cancel_validate_task()
    await scheduler_service.stop()
        # Validation worker pool
        async def validation_handler(proxy):
            from app.models.domain import ProxyRaw
            is_valid, latency = await validator.validate(
                proxy.ip, proxy.port, proxy.protocol
            )
            async with get_db() as db:
                if is_valid:
                    await proxy_repo.insert_or_update(
                        db, proxy.ip, proxy.port, proxy.protocol, score=app_settings.score_valid
                    )
                    if latency:
                        await proxy_repo.update_response_time(db, proxy.ip, proxy.port, latency)
                else:
                    await proxy_repo.update_score(
                        db, proxy.ip, proxy.port, app_settings.score_invalid,
                        app_settings.score_min, app_settings.score_max
                    )

    # Close the validator's HTTP session
        worker_pool = AsyncWorkerPool(
            worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
            handler=validation_handler,
            name="ValidationPool",
        )
        await stack.enter_async_context(worker_pool)

        # Job executor
        executor = JobExecutor(worker_pool=worker_pool, max_concurrent_jobs=10)
        await stack.enter_async_context(executor)

        # Plugin runner
        plugin_runner = PluginRunner(timeout=db_settings.get("crawl_timeout", 30))

        # Scheduler
        scheduler = SchedulerService(
            executor=executor,
            interval_minutes=db_settings.get("validate_interval_minutes", 30),
        )

        # Attach to app.state
        app.state.validator = validator
        app.state.worker_pool = worker_pool
        app.state.executor = executor
        app.state.plugin_runner = plugin_runner
        app.state.scheduler = scheduler

        # Start the scheduler
        if db_settings.get("auto_validate", True):
            try:
                await scheduler.start()
            except Exception as e:
                logger.error(f"Failed to start scheduler on startup: {e}")

        logger.info("API server started")
        yield

        # Stop the scheduler
        await scheduler.stop()

        # Cancel all running jobs
        await executor.cancel_all()

        # AsyncExitStack closes the executor and the worker pool automatically

        # Close the validator session
        try:
            await scheduler_service.validation_queue.validator.close()
            await validator.close()
        except Exception:
            pass

        # Close every plugin's HTTP client
        from app.core.plugin_system.registry import registry
        for plugin in registry.list_plugins():
            if hasattr(plugin, "close"):
                try:

@@ -43,7 +43,7 @@ def create_app() -> FastAPI:
    @app.get("/health")
    async def health_check():
        from datetime import datetime
        scheduler = app.state.scheduler_service
        scheduler = app.state.scheduler
        return {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
@@ -1,16 +1,25 @@
"""Plugin routes"""
import asyncio
from fastapi import APIRouter, Depends
from pydantic import BaseModel

from app.services.plugin_service import PluginService
from app.services.scheduler_service import SchedulerService
from app.services.task_service import task_service
from app.api.deps import get_plugin_service, get_scheduler_service
from app.api.common import success_response, error_response, format_plugin
from app.core.log import logger
from app.services.plugin_runner import PluginRunner
from app.core.execution import JobExecutor, CrawlJob
from app.core.exceptions import PluginNotFoundException
from app.api.deps import get_plugin_service, get_plugin_runner, get_executor
from app.api.common import success_response

router = APIRouter(prefix="/api/plugins", tags=["plugins"])


class ToggleRequest(BaseModel):
    enabled: bool


class ConfigRequest(BaseModel):
    config: dict


@router.get("")
async def list_plugins(service: PluginService = Depends(get_plugin_service)):
    plugins = await service.list_plugins()
@@ -20,19 +29,13 @@ async def list_plugins(service: PluginService = Depends(get_plugin_service)):
@router.put("/{plugin_id}/toggle")
async def toggle_plugin(
    plugin_id: str,
    request: dict,
    request: ToggleRequest,
    service: PluginService = Depends(get_plugin_service),
):
    enabled = request.get("enabled")
    if enabled is None:
        return error_response("缺少 enabled 参数", 400)

    success = await service.toggle_plugin(plugin_id, enabled)
    if not success:
        return error_response("插件不存在", 404)
    await service.toggle_plugin(plugin_id, request.enabled)
    return success_response(
        f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}",
        {"plugin_id": plugin_id, "enabled": enabled},
        f"插件 {plugin_id} 已{'启用' if request.enabled else '禁用'}",
        {"plugin_id": plugin_id, "enabled": request.enabled},
    )


@@ -42,128 +45,104 @@ async def get_plugin_config(
    service: PluginService = Depends(get_plugin_service),
):
    config = await service.get_plugin_config(plugin_id)
    if config is None:
        return error_response("插件不存在", 404)
    return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config})


@router.post("/{plugin_id}/config")
async def update_plugin_config(
    plugin_id: str,
    request: dict,
    request: ConfigRequest,
    service: PluginService = Depends(get_plugin_service),
):
    config = request.get("config", {})
    if not isinstance(config, dict):
        return error_response("config 必须是对象", 400)

    success = await service.update_plugin_config(plugin_id, config)
    success = await service.update_plugin_config(plugin_id, request.config)
    if not success:
        return error_response("插件不存在或配置无效", 404)
    return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config})
        raise PluginNotFoundException(plugin_id)
    return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": request.config})


@router.post("/{plugin_id}/crawl")
async def crawl_plugin(
    plugin_id: str,
    plugin_service: PluginService = Depends(get_plugin_service),
    scheduler_service: SchedulerService = Depends(get_scheduler_service),
    plugin_runner: PluginRunner = Depends(get_plugin_runner),
    executor: JobExecutor = Depends(get_executor),
):
    plugin = plugin_service.get_plugin(plugin_id)
    if not plugin:
        return error_response("插件不存在", 404)

    task_id = task_service.create_task("crawl_plugin")

    async def _background_crawl():
        try:
            task_service.update_task(task_id, status="running")
            results = await plugin_service.run_plugin(plugin_id)
            if not results:
                task_service.update_task(
                    task_id,
                    status="completed",
                    message=f"插件 {plugin_id} 爬取完成,未获取到代理",
                    data={"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0},
                )
                return

            logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
            scheduler_service.validation_queue.reset_stats()
            await scheduler_service.validation_queue.submit(results)

            try:
                await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
            except asyncio.TimeoutError:
                pass

            task_service.update_task(
                task_id,
                status="completed",
                message=f"插件 {plugin_id} 爬取并验证完成",
                data={
                    "plugin_id": plugin_id,
                    "proxy_count": len(results),
                    "valid_count": scheduler_service.validation_queue.valid_count,
                    "invalid_count": scheduler_service.validation_queue.invalid_count,
                },
            )
        except Exception as e:
            logger.error(f"Crawl plugin {plugin_id} failed: {e}")
            task_service.update_task(
                task_id, status="failed", message=f"插件爬取失败: {str(e)}"
            )

    asyncio.create_task(_background_crawl())
    plugin = plugin_service.get_plugin_or_raise(plugin_id)
    job = CrawlJob(
        plugin_id=plugin_id,
        plugin_runner=plugin_runner,
        proxy_service=plugin_service,
        validator_pool=executor.worker_pool,
    )
    job_id = executor.submit_job(job)
    return success_response(
        "爬取任务已启动", {"task_id": task_id, "plugin_id": plugin_id}
        "爬取任务已启动", {"task_id": job_id, "plugin_id": plugin_id}
    )


@router.post("/crawl-all")
async def crawl_all(
    plugin_service: PluginService = Depends(get_plugin_service),
    scheduler_service: SchedulerService = Depends(get_scheduler_service),
    plugin_runner: PluginRunner = Depends(get_plugin_runner),
    executor: JobExecutor = Depends(get_executor),
):
    task_id = task_service.create_task("crawl_all")
    """Create a CrawlJob for every enabled plugin; return one aggregate task id to keep the API compatible"""
    from app.core.plugin_system.registry import registry
    job_ids = []
    for plugin in registry.list_plugins():
        if not plugin.enabled:
            continue
        job = CrawlJob(
            plugin_id=plugin.name,
            plugin_runner=plugin_runner,
            proxy_service=plugin_service,
            validator_pool=executor.worker_pool,
        )
        job_ids.append(executor.submit_job(job))

    # To keep the API compatible (the frontend/tests expect a single task_id),
    # create a virtual aggregate Job for status queries; it waits for all crawl jobs to finish
    aggregator = _create_crawl_all_aggregator(job_ids, executor)
    agg_id = executor.submit_job(aggregator)
    return success_response("爬取任务已启动", {"task_id": agg_id})

    async def _background_crawl_all():
        try:
            task_service.update_task(task_id, status="running")
            results = await plugin_service.run_all_plugins()
            if not results:
                task_service.update_task(
                    task_id,
                    status="completed",
                    message="所有插件爬取完成,未获取到代理",
                    data={"total_crawled": 0, "valid_count": 0, "invalid_count": 0},

def _create_crawl_all_aggregator(job_ids, executor):
    """Create a simple aggregate Job that summarizes the status of all child jobs"""
    from app.core.execution.job import Job
    import asyncio

    class CrawlAllAggregator(Job):
        async def run(self):
            # Wait for all child jobs to finish (at most 30 seconds)
            for _ in range(300):
                all_done = all(
                    executor.get_job(jid) and executor.get_job(jid).status.value in ("completed", "failed", "cancelled")
                    for jid in job_ids
                )
                return
                if all_done:
                    break
                await asyncio.sleep(0.1)
            total = 0
            valid = 0
            for jid in job_ids:
                job = executor.get_job(jid)
                if job and job.result:
                    total += job.result.get("proxy_count", 0)
                    valid += job.result.get("success_count", 0)
            return {"total_crawled": total, "valid_count": valid, "invalid_count": 0}

            logger.info(f"All plugins crawled {len(results)} unique proxies")
            scheduler_service.validation_queue.reset_stats()
            await scheduler_service.validation_queue.submit(results)
    return CrawlAllAggregator()

            try:
                await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
            except asyncio.TimeoutError:
                pass

            task_service.update_task(
                task_id,
                status="completed",
                message="所有插件爬取并验证完成",
                data={
                    "total_crawled": len(results),
                    "valid_count": scheduler_service.validation_queue.valid_count,
                    "invalid_count": scheduler_service.validation_queue.invalid_count,
                },
            )
        except Exception as e:
            logger.error(f"Crawl all failed: {e}")
            task_service.update_task(
                task_id, status="failed", message=f"批量爬取失败: {str(e)}"
            )

    asyncio.create_task(_background_crawl_all())
    return success_response("爬取任务已启动", {"task_id": task_id})
def format_plugin(plugin) -> dict:
    return {
        "id": plugin.id,
        "name": plugin.display_name,
        "display_name": plugin.display_name,
        "description": plugin.description,
        "enabled": plugin.enabled,
        "last_run": plugin.last_run.isoformat() if plugin.last_run else None,
        "success_count": plugin.success_count,
        "failure_count": plugin.failure_count,
    }

@@ -1,11 +1,14 @@
"""Proxy routes (including statistics)"""
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse

from app.services.proxy_service import ProxyService
from app.services.scheduler_service import SchedulerService
from app.models.schemas import ProxyListRequest, BatchDeleteRequest
from app.api.deps import get_proxy_service, get_scheduler_service
from app.api.common import success_response, error_response, format_proxy
from app.api.common import success_response, format_proxy
from app.core.exceptions import ProxyPoolException, ProxyNotFoundException

router = APIRouter(prefix="/api/proxies", tags=["proxies"])

@@ -49,7 +52,7 @@ async def list_proxies(
async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
    proxy = await service.get_random_proxy()
    if not proxy:
        return error_response("没有找到可用的代理", 404)
        raise ProxyNotFoundException("", 0)
    return success_response("获取随机代理成功", format_proxy(proxy))


@@ -61,9 +64,7 @@ async def export_proxies(
    service: ProxyService = Depends(get_proxy_service),
):
    if fmt not in ("csv", "txt", "json"):
        return error_response("不支持的导出格式", 400)

    from fastapi.responses import StreamingResponse
        raise ProxyPoolException("不支持的导出格式", 400)

    media_types = {"csv": "text/csv", "txt": "text/plain", "json": "application/json"}

@@ -1,10 +1,11 @@
"""Scheduler routes"""
from fastapi import APIRouter, Depends

from app.services.scheduler_service import SchedulerService
from app.repositories.settings_repo import SettingsRepository
from app.core.db import get_db
from app.api.deps import get_scheduler_service
from app.api.common import success_response, error_response
from app.api.common import success_response

router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
settings_repo = SettingsRepository()
@@ -39,8 +40,8 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_ser

@router.post("/validate-now")
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
    await scheduler.validate_all_now()
    return success_response("已开始全量验证", {"started": True})
    job_id = scheduler.validate_all_now()
    return success_response("已开始全量验证", {"started": True, "job_id": job_id})


@router.get("/status")

@@ -3,7 +3,7 @@ from fastapi import APIRouter, Request
from app.core.db import get_db
from app.repositories.settings_repo import SettingsRepository
from app.models.schemas import SettingsSchema
from app.api.common import success_response, error_response
from app.api.common import success_response
from app.core.log import logger

router = APIRouter(prefix="/api/settings", tags=["settings"])
@@ -22,14 +22,32 @@ async def save_settings(request: SettingsSchema, http_request: Request):
    async with get_db() as db:
        success = await settings_repo.save(db, request.model_dump())
    if not success:
        return error_response("保存设置失败", 500)
        raise RuntimeError("保存设置失败")

    # Hot-update the running scheduler's interval
    scheduler = getattr(http_request.app.state, "scheduler_service", None)
    if scheduler and scheduler.running:
    scheduler = getattr(http_request.app.state, "scheduler", None)
    worker_pool = getattr(http_request.app.state, "worker_pool", None)
    validator = getattr(http_request.app.state, "validator", None)

    if scheduler:
        new_interval = request.validate_interval_minutes
        if scheduler.interval_minutes != new_interval:
            scheduler.interval_minutes = new_interval
            logger.info(f"Scheduler interval updated to {new_interval} minutes")

    # Hot-update the worker-pool size
    if worker_pool and worker_pool.worker_count != request.default_concurrency:
        await worker_pool.resize(request.default_concurrency)
        logger.info(f"Worker pool resized to {request.default_concurrency}")

    # Hot-update validator timeout and concurrency (takes effect on the next validation)
    if validator:
        validator._init_timeout = request.validation_timeout
        validator._init_connect_timeout = request.validation_timeout
        validator._init_max_concurrency = request.default_concurrency
        # Recreate the semaphore and session
        validator._semaphore = None
        await validator.close()
        logger.info(f"Validator config updated: timeout={request.validation_timeout}, concurrency={request.default_concurrency}")

    return success_response("保存设置成功", request.model_dump())

@@ -1,14 +1,16 @@
"""Async task status routes"""
from fastapi import APIRouter
from app.services.task_service import task_service
from fastapi import APIRouter, Depends

from app.core.execution import JobExecutor
from app.api.deps import get_executor
from app.api.common import success_response, error_response

router = APIRouter(prefix="/api/tasks", tags=["tasks"])


@router.get("/{task_id}")
async def get_task_status(task_id: str):
    task = task_service.get_task(task_id)
    if not task:
async def get_task_status(task_id: str, executor: JobExecutor = Depends(get_executor)):
    job = executor.get_job(task_id)
    if not job:
        return error_response("任务不存在", 404)
    return success_response("获取任务状态成功", task)
    return success_response("获取任务状态成功", job.to_dict())

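A sketch of how a client might poll this endpoint; the host/port and the "data" envelope key are assumptions (the payload itself follows Job.to_dict() in job.py below):

# Sketch: poll /api/tasks/{task_id} until the job reaches a terminal state.
import asyncio
import httpx

async def wait_for_task(task_id: str) -> dict:
    # Assumes the server runs on localhost:8000 and success_response nests
    # its payload under "data"; adjust both to your deployment.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        while True:
            resp = await client.get(f"/api/tasks/{task_id}")
            data = resp.json()["data"]
            if data["status"] in ("completed", "failed", "cancelled"):
                return data
            await asyncio.sleep(0.5)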
@@ -84,22 +84,7 @@ async def init_db():
        await db.execute("ALTER TABLE plugin_settings ADD COLUMN stats_json TEXT DEFAULT '{}'")
        logger.info("Migrated: added stats_json column to plugin_settings")

        # Validation task queue table
        await db.execute("""
            CREATE TABLE IF NOT EXISTS validation_tasks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ip TEXT NOT NULL,
                port INTEGER NOT NULL,
                protocol TEXT DEFAULT 'http',
                status TEXT DEFAULT 'pending',
                result TEXT,
                response_time_ms REAL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_status ON validation_tasks(status)")
        await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_created ON validation_tasks(created_at)")
        # The validation task queue table is deprecated and no longer created; an old table left in the database is harmless

        # System settings table
        await db.execute("""
@@ -125,3 +110,27 @@ async def get_db() -> AsyncIterator[aiosqlite.Connection]:
        yield db
    finally:
        await db.close()


@asynccontextmanager
async def transaction() -> AsyncIterator[aiosqlite.Connection]:
    """Get a database connection with explicit transaction control.

    Usage:
        async with transaction() as db:
            await repo.update(db, ...)
            # rolls back automatically if an exception is raised
    """
    ensure_db_dir()
    db = await aiosqlite.connect(DB_PATH)
    try:
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA synchronous=NORMAL")
        await db.execute("BEGIN")
        yield db
        await db.commit()
    except Exception:
        await db.rollback()
        raise
    finally:
        await db.close()

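A usage sketch for the new transaction() helper; the repository calls are illustrative (their signatures follow the ones visible elsewhere in this diff), not part of the commit:

# Sketch: two writes commit or roll back together.
from app.core.db import transaction

async def promote_proxy(repo, ip: str, port: int) -> None:
    async with transaction() as db:
        # Both statements run inside one BEGIN ... COMMIT; an exception in
        # either triggers the rollback in transaction().
        await repo.update_score(db, ip, port, 10, 0, 100)      # illustrative args
        await repo.update_response_time(db, ip, port, 120.0)   # illustrative args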
app/core/execution/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Unified background task execution framework"""
from .job import JobStatus, Job, CrawlJob, ValidateAllJob
from .worker_pool import AsyncWorkerPool
from .executor import JobExecutor

__all__ = [
    "JobStatus",
    "Job",
    "CrawlJob",
    "ValidateAllJob",
    "AsyncWorkerPool",
    "JobExecutor",
]
app/core/execution/executor.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""Job executor - unified lifecycle management for all background jobs"""
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from app.core.execution.job import Job, JobStatus
from app.core.execution.worker_pool import AsyncWorkerPool
from app.core.log import logger


class JobExecutor:
    """Background job executor.

    - Keeps an in-memory table of job states
    - Caps the number of concurrently running jobs
    - Cleans up expired jobs automatically
    """

    def __init__(
        self,
        worker_pool: Optional[AsyncWorkerPool] = None,
        max_concurrent_jobs: int = 10,
        cleanup_interval_seconds: int = 300,
    ):
        self.worker_pool = worker_pool
        self.max_concurrent_jobs = max_concurrent_jobs
        self._jobs: Dict[str, Job] = {}
        self._running = False
        self._semaphore = asyncio.Semaphore(max_concurrent_jobs)
        self._cleanup_interval = cleanup_interval_seconds
        self._cleanup_task: Optional[asyncio.Task] = None

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    async def start(self) -> None:
        if self._running:
            return
        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("JobExecutor started")

    async def stop(self) -> None:
        if not self._running:
            return
        self._running = False
        await self.cancel_all()
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
        logger.info("JobExecutor stopped")

    def submit_job(self, job: Job) -> str:
        """Submit a job for background execution"""
        self._jobs[job.id] = job
        asyncio.create_task(self._run_job(job))
        return job.id

    async def _run_job(self, job: Job) -> None:
        async with self._semaphore:
            try:
                if job.is_cancelled:
                    logger.info(f"Job {job.id} was cancelled before running")
                    return
                result = await job.run()
                # If the subclass did not set a terminal state explicitly, mark it completed
                if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
                    job._set_completed(result)
                logger.info(f"Job {job.id} completed: {result}")
            except asyncio.CancelledError:
                job.status = JobStatus.CANCELLED
                job._touch()
                logger.info(f"Job {job.id} cancelled during execution")
            except Exception as e:
                job._set_failed(str(e))
                logger.error(f"Job {job.id} failed: {e}", exc_info=True)

    def get_job(self, job_id: str) -> Optional[Job]:
        return self._jobs.get(job_id)

    def list_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 100,
    ) -> List[Job]:
        jobs = list(self._jobs.values())
        if status:
            jobs = [j for j in jobs if j.status == status]
        jobs.sort(key=lambda j: j.created_at, reverse=True)
        return jobs[:limit]

    async def cancel_job(self, job_id: str) -> bool:
        job = self._jobs.get(job_id)
        if not job:
            return False
        job.cancel()
        return True

    async def cancel_all(self) -> int:
        cancelled = 0
        for job in list(self._jobs.values()):
            if job.status in (JobStatus.PENDING, JobStatus.RUNNING):
                job.cancel()
                cancelled += 1
        return cancelled

    async def _cleanup_loop(self) -> None:
        """Periodically clean up finished, expired jobs"""
        while self._running:
            try:
                await asyncio.sleep(self._cleanup_interval)
                self._cleanup_old_jobs()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"JobExecutor cleanup error: {e}")

    def _cleanup_old_jobs(self) -> int:
        cutoff = datetime.now() - timedelta(seconds=self._cleanup_interval)
        to_remove = [
            job_id
            for job_id, job in self._jobs.items()
            if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and job.updated_at < cutoff
        ]
        for job_id in to_remove:
            del self._jobs[job_id]
        if to_remove:
            logger.info(f"JobExecutor cleaned up {len(to_remove)} old jobs")
        return len(to_remove)
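A minimal sketch of driving the executor directly; ValidateAllJob is defined in job.py below (note it touches the application database, so this is illustrative only):

# Sketch: inspect and cancel jobs through the executor's in-memory table.
import asyncio
from app.core.execution import JobExecutor, JobStatus, ValidateAllJob

async def main() -> None:
    async with JobExecutor(max_concurrent_jobs=5) as executor:
        job_id = executor.submit_job(ValidateAllJob())
        await asyncio.sleep(0)                    # let the job task get scheduled
        running = executor.list_jobs(status=JobStatus.RUNNING)
        print(len(running), executor.get_job(job_id).to_dict())
        await executor.cancel_job(job_id)         # cooperative: sets the cancel flag

asyncio.run(main())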
app/core/execution/job.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""Job definitions - the unified abstraction for all background async tasks"""
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from app.models.domain import ProxyRaw


class JobStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class Job(ABC):
    """Base class for background jobs"""

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0  # 0-100
    result: Any = None
    error: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    _cancelled: bool = field(default=False, repr=False)

    def cancel(self) -> None:
        """Request cancellation of the job"""
        self._cancelled = True
        if self.status in (JobStatus.PENDING, JobStatus.RUNNING):
            self.status = JobStatus.CANCELLED
            self._touch()

    def _touch(self) -> None:
        self.updated_at = datetime.now()

    def _set_running(self) -> None:
        self.status = JobStatus.RUNNING
        self._touch()

    def _set_completed(self, result: Any = None) -> None:
        self.status = JobStatus.COMPLETED
        self.result = result
        self.progress = 100.0
        self._touch()

    def _set_failed(self, error: str) -> None:
        self.status = JobStatus.FAILED
        self.error = error
        self._touch()

    @property
    def is_cancelled(self) -> bool:
        return self._cancelled

    @abstractmethod
    async def run(self) -> Any:
        """Core task logic; subclasses must implement this"""
        raise NotImplementedError

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "status": self.status.value,
            "progress": round(self.progress, 2),
            "result": self.result,
            "error": self.error,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }


@dataclass
class CrawlJob(Job):
    """Plugin crawl job"""

    plugin_id: str = ""
    plugin_runner: Any = field(repr=False, default=None)
    proxy_service: Any = field(repr=False, default=None)
    validator_pool: Any = field(repr=False, default=None)

    async def run(self) -> Dict[str, Any]:
        from app.services.plugin_service import PluginService
        from app.core.log import logger

        self._set_running()
        if not self.plugin_runner:
            raise RuntimeError("plugin_runner is not set")

        plugin_service = self.proxy_service or PluginService()
        plugin = plugin_service.get_plugin(self.plugin_id)
        if not plugin:
            raise ValueError(f"Plugin '{self.plugin_id}' not found")

        result = await self.plugin_runner.run(plugin)
        proxies: List[ProxyRaw] = result.proxies if result else []

        if proxies and self.validator_pool:
            await self.validator_pool.submit(proxies)
            logger.info(f"CrawlJob {self.id}: submitted {len(proxies)} proxies for validation")

        payload = {
            "plugin_id": self.plugin_id,
            "proxy_count": len(proxies),
        }
        if result:
            payload["success_count"] = result.success_count
            payload["failure_count"] = result.failure_count
        self._set_completed(payload)
        return payload


@dataclass
class ValidateAllJob(Job):
    """Full validation job - validates every proxy already in the database"""

    proxy_repo: Any = field(repr=False, default=None)
    validator_pool: Any = field(repr=False, default=None)
    batch_size: int = 100

    async def run(self) -> Dict[str, Any]:
        from app.repositories.proxy_repo import ProxyRepository
        from app.core.db import get_db
        from app.core.log import logger

        self._set_running()
        repo = self.proxy_repo or ProxyRepository()

        async with get_db() as db:
            proxies = await repo.list_all(db)

        if not proxies:
            self._set_completed({"total": 0, "submitted": 0})
            return self.result

        total = len(proxies)
        submitted = 0
        for i in range(0, total, self.batch_size):
            if self.is_cancelled:
                logger.info(f"ValidateAllJob {self.id}: cancelled")
                break
            batch = proxies[i : i + self.batch_size]
            raws = [ProxyRaw(p.ip, p.port, p.protocol) for p in batch]
            if self.validator_pool:
                await self.validator_pool.submit(raws)
            submitted += len(raws)
            self.progress = min(100.0, (submitted / total) * 100)
            self._touch()

        payload = {"total": total, "submitted": submitted}
        self._set_completed(payload)
        logger.info(f"ValidateAllJob {self.id}: submitted {submitted}/{total} proxies")
        return payload
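Defining a new job type is just subclassing Job; a minimal sketch (the CleanupJob name and its body are hypothetical, following the ValidateAllJob pattern above):

# Sketch: a hypothetical job that reports progress and honors cancellation.
from dataclasses import dataclass
from typing import Any, Dict

from app.core.execution import Job

@dataclass
class CleanupJob(Job):
    items: int = 10

    async def run(self) -> Dict[str, Any]:
        self._set_running()
        done = 0
        for _ in range(self.items):
            if self.is_cancelled:          # cooperative cancellation, as in ValidateAllJob
                break
            done += 1
            self.progress = done / self.items * 100
            self._touch()
        payload = {"cleaned": done}
        self._set_completed(payload)       # the executor would also auto-complete if omitted
        return payload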
app/core/execution/worker_pool.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""Lightweight, purely in-memory async worker pool"""
import asyncio
from typing import Callable, Coroutine, List, TypeVar

from app.core.log import logger

T = TypeVar("T")
Handler = Callable[[T], Coroutine[None, None, None]]


class AsyncWorkerPool:
    """Purely in-memory async worker pool.

    Single responsibility: accept a list of items -> fan them out to N worker
    coroutines -> invoke the handler on each. No persistence logic and no
    complex state machine.
    """

    def __init__(
        self,
        worker_count: int,
        handler: Handler,
        queue_maxsize: int = 10000,
        name: str = "WorkerPool",
    ):
        self.worker_count = worker_count
        self.handler = handler
        self.name = name
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
        self._workers: List[asyncio.Task] = []
        self._running = False
        self._shutdown_event = asyncio.Event()

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    async def start(self) -> None:
        if self._running:
            return
        self._running = True
        self._shutdown_event.clear()
        self._workers = [
            asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}")
            for i in range(self.worker_count)
        ]
        logger.info(f"{self.name} started with {self.worker_count} workers")

    async def stop(self) -> None:
        if not self._running:
            return
        self._running = False
        self._shutdown_event.set()
        # Send enough sentinels that every worker is guaranteed to receive one
        for _ in range(len(self._workers) + self._queue.qsize()):
            try:
                self._queue.put_nowait(None)
            except asyncio.QueueFull:
                break
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info(f"{self.name} stopped")

    async def submit(self, items: List[T]) -> None:
        """Submit a batch of items to the queue (blocks when full - natural backpressure)"""
        for item in items:
            await self._queue.put(item)

    async def drain(self) -> None:
        """Wait until every queued item has been consumed"""
        await self._queue.join()

    async def resize(self, new_worker_count: int) -> None:
        """Adjust the worker count dynamically: stop everything, then restart at the new size"""
        if new_worker_count == self.worker_count:
            return
        logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
        # Safe approach: stop then start, so old and new workers never race for sentinels and deadlock
        await self.stop()
        self.worker_count = new_worker_count
        await self.start()

    async def _worker_loop(self, worker_id: int) -> None:
        while True:
            item = await self._queue.get()
            try:
                if item is None or not self._running:
                    break  # the finally block marks this item done exactly once
                await self.handler(item)
            except Exception as e:
                logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True)
            finally:
                self._queue.task_done()
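A self-contained usage sketch of the pool; the handler and items are illustrative stand-ins for the proxy-validation handler wired up in lifespan:

# Sketch: fan ten items out to three workers, then wait for the queue to empty.
import asyncio

from app.core.execution import AsyncWorkerPool

async def main() -> None:
    async def handle(item: int) -> None:
        await asyncio.sleep(0.01)      # stand-in for validator.validate(...)
        print("handled", item)

    async with AsyncWorkerPool(worker_count=3, handler=handle, name="DemoPool") as pool:
        await pool.submit(list(range(10)))
        await pool.drain()             # returns once all ten items are processed

asyncio.run(main())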
@@ -45,6 +45,11 @@ class PluginRegistry:
    def get_plugin_names(self) -> List[str]:
        return list(self._plugins.keys())

    def clear(self) -> None:
        """Clear all registered plugins (mainly for tests)"""
        self._plugins.clear()
        self._instances.clear()

    def auto_discover(self, package_name: str):
        """Auto-scan every module in the given package and register the plugin classes it contains.
        Note: explicit registration is preferred for type safety and control; auto_discover is kept only for compatibility."""

@@ -1,168 +0,0 @@
"""Validation task queue - decouples crawling from validation, with backpressure control and persistence"""
import asyncio
from typing import Optional
from app.models.domain import ProxyRaw
from app.repositories.task_repo import ValidationTaskRepository
from app.core.db import get_db
from app.core.log import logger


class ValidationQueue:
    """Proxy validation queue (persisted to SQLite).

    Workflow:
    1. Crawlers submit() raw proxies to the queue (written to the database + in-memory signal)
    2. The worker pool consumes from the database and validates
    3. Proxies that pass validation are written to the database
    4. On service restart, unfinished pending tasks are recovered automatically
    """

    def __init__(
        self,
        validator,
        proxy_repo,
        worker_count: int = 50,
        score_valid: int = 10,
        score_invalid: int = -5,
        score_min: int = 0,
        score_max: int = 100,
    ):
        self.validator = validator
        self.proxy_repo = proxy_repo
        self.task_repo = ValidationTaskRepository()
        self.worker_count = worker_count
        self.score_valid = score_valid
        self.score_invalid = score_invalid
        self.score_min = score_min
        self.score_max = score_max

        self._signal: asyncio.Queue[None] = asyncio.Queue()
        self._workers: list[asyncio.Task] = []
        self._running = False
        self._pending_count = 0
        self._condition = asyncio.Condition()

        # Statistics
        self.valid_count = 0
        self.invalid_count = 0

    async def start(self):
        if self._running:
            return
        self._running = True

        # Recover tasks previously interrupted while processing
        async with get_db() as db:
            recovered = await self.task_repo.reset_processing(db)
            pending = await self.task_repo.get_pending_count(db)
            if pending > 1000:
                logger.warning(f"ValidationQueue has {pending} pending tasks, cleaning up all pending tasks...")
                await db.execute("DELETE FROM validation_tasks WHERE status = 'pending'")
                await db.commit()
                pending = await self.task_repo.get_pending_count(db)
                logger.info(f"ValidationQueue cleaned up pending tasks, remaining: {pending}")
        if recovered:
            logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
        if pending:
            logger.info(f"ValidationQueue has {pending} pending tasks to process")
        async with self._condition:
            self._pending_count = pending

        for i in range(self.worker_count):
            self._workers.append(asyncio.create_task(self._worker_loop(i)))

        # Wake the workers to handle recovered pending tasks (one wake-up per worker is enough; each loops internally)
        if pending:
            for _ in range(self.worker_count):
                self._signal.put_nowait(None)

        logger.info(f"ValidationQueue started with {self.worker_count} workers")

    async def stop(self):
        if not self._running:
            return
        self._running = False
        for _ in self._workers:
            self._signal.put_nowait(None)  # sentinel
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("ValidationQueue stopped")

    async def submit(self, proxies: list[ProxyRaw]):
        """Submit proxies to the validation queue (persist + wake workers)"""
        async with get_db() as db:
            inserted = await self.task_repo.insert_batch(db, proxies)
        if inserted:
            async with self._condition:
                self._pending_count += inserted
                self._condition.notify_all()
            for _ in range(min(inserted, self.worker_count)):
                self._signal.put_nowait(None)

    async def submit_one(self, proxy: ProxyRaw):
        await self.submit([proxy])

    async def drain(self):
        """Wait until every currently pending task has been processed"""
        async with self._condition:
            if self._pending_count > 0:
                await self._condition.wait_for(lambda: self._pending_count == 0)

    async def _worker_loop(self, worker_id: int):
        while True:
            await self._signal.get()
            self._signal.task_done()
            if not self._running:
                break
            # Keep processing tasks until none are pending, so a shortage of signals cannot starve tasks
            while self._running:
                processed = await self._process_one_task(worker_id)
                if not processed:
                    break

    async def _process_one_task(self, worker_id: int) -> bool:
        """Take one task from the database and validate it. Returns True if a task was actually processed."""
        async with get_db() as db:
            task = await self.task_repo.acquire_pending(db)
            if not task:
                return False

            proxy = ProxyRaw(task["ip"], task["port"], task["protocol"])
            try:
                is_valid, latency = await self.validator.validate(
                    proxy.ip, proxy.port, proxy.protocol
                )
            except Exception as e:
                logger.error(f"Worker {worker_id} validation error: {e}", exc_info=True)
                is_valid, latency = False, 0.0

            if is_valid:
                await self.proxy_repo.insert_or_update(
                    db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid
                )
                if latency:
                    await self.proxy_repo.update_response_time(
                        db, proxy.ip, proxy.port, latency
                    )
                await self.task_repo.complete_task(db, task["id"], True, latency)
                self.valid_count += 1
                logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}")
            else:
                # Deduct points from the existing proxy; it is deleted automatically once its score reaches <= 0
                await self.proxy_repo.update_score(
                    db, proxy.ip, proxy.port, self.score_invalid,
                    self.score_min, self.score_max
                )
                await self.task_repo.complete_task(db, task["id"], False, 0.0)
                self.invalid_count += 1
                logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
        async with self._condition:
            self._pending_count = max(0, self._pending_count - 1)
            if self._pending_count == 0:
                self._condition.notify_all()
        return True

    def reset_stats(self):
        self.valid_count = 0
        self.invalid_count = 0
@@ -1,7 +1,7 @@
"""Domain models - pure data structures with no framework dependencies"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
from typing import List, Optional


@dataclass
@@ -42,3 +42,13 @@ class PluginInfo:
    last_run: Optional[datetime] = None
    success_count: int = 0
    failure_count: int = 0


@dataclass
class CrawlResult:
    """Plugin crawl result"""
    plugin_name: str
    proxies: List[ProxyRaw] = field(default_factory=list)
    success_count: int = 0
    failure_count: int = 0
    error: Optional[str] = None

@@ -1,9 +1,15 @@
"""Generic HTTP crawler base class - shared plumbing for HTTP-request-based plugins"""
import re
import random
import asyncio
import httpx
from typing import List, Optional
from bs4 import BeautifulSoup
from app.core.plugin_system import BaseCrawlerPlugin
from app.models.domain import ProxyRaw


VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")


class BaseHTTPPlugin(BaseCrawlerPlugin):
@@ -20,6 +26,7 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
        self.urls: List[str] = []
        self.current_url: str = ""
        self._client: Optional[httpx.AsyncClient] = None
        self.max_concurrency: int = 3

    def get_headers(self) -> dict:
        return {
@@ -39,13 +46,21 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
        )
        return self._client

    async def fetch(self, url: str, timeout: float = 15.0, retries: int = 2) -> str:
    async def fetch(
        self,
        url: str,
        timeout: float = 15.0,
        retries: int = 2,
        raise_for_status: bool = False,
    ) -> str:
        """Fetch the HTML content of the given URL asynchronously"""
        from app.core.log import logger
        client = self._get_client()
        for attempt in range(retries):
            try:
                response = await client.get(url, headers=self.get_headers(), timeout=timeout)
                if raise_for_status:
                    response.raise_for_status()
                if response.status_code == 200:
                    content = response.content
                    encoding = response.encoding
@@ -64,8 +79,8 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
        return ""

    async def fetch_all(self, urls: List[str], timeout: float = 15.0) -> List[str]:
        """Fetch multiple URLs concurrently, capping a single plugin's internal concurrency at 3"""
        semaphore = asyncio.Semaphore(3)
        """Fetch multiple URLs concurrently, capping a single plugin's internal concurrency"""
        semaphore = asyncio.Semaphore(self.max_concurrency)

        async def _fetch_limited(url: str):
            async with semaphore:
@@ -74,6 +89,70 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
        tasks = [_fetch_limited(url) for url in urls]
        return await asyncio.gather(*tasks)

    def parse_text_proxies(self, text: str, protocol: str = "http") -> List[ProxyRaw]:
        """Parse a plain-text proxy list in ip:port format.

        Normalizes both \r\n and \n line endings and skips any blank lines.
        """
        results = []
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        for line in text.split("\n"):
            line = line.strip()
            if not line or ":" not in line:
                continue
            ip, _, port = line.rpartition(":")
            ip = ip.strip()
            port = port.strip()
            if ip and port.isdigit() and 1 <= int(port) <= 65535:
                try:
                    results.append(ProxyRaw(ip, int(port), protocol))
                except ValueError:
                    continue
        return results

    def parse_html_table(
        self,
        html: str,
        column_map: dict,
        protocol: str = "http",
    ) -> List[ProxyRaw]:
        """Generic HTML table parser.

        Args:
            html: HTML text
            column_map: mapping of column names to indexes, e.g. {"ip": 0, "port": 1, "protocol": 4}
            protocol: default protocol, used when the table has no protocol column
        """
        results = []
        soup = BeautifulSoup(html, "lxml")
        table = soup.find("table")
        if not table:
            return results

        ip_idx = column_map.get("ip", 0)
        port_idx = column_map.get("port", 1)
        protocol_idx = column_map.get("protocol", -1)

        for row in table.find_all("tr"):
            tds = row.find_all("td")
            if len(tds) <= max(ip_idx, port_idx):
                continue
            ip = tds[ip_idx].get_text(strip=True)
            port = tds[port_idx].get_text(strip=True)
            if protocol_idx >= 0 and len(tds) > protocol_idx:
                proto = tds[protocol_idx].get_text(strip=True).lower()
                if proto not in VALID_PROTOCOLS:
                    proto = protocol
            else:
                proto = protocol

            if re.match(r"^\d+\.\d+\.\d+\.\d+$", ip) and port.isdigit() and 1 <= int(port) <= 65535:
                try:
                    results.append(ProxyRaw(ip, int(port), proto))
                except ValueError:
                    continue
        return results

    async def close(self):
        """Close the shared HTTP client"""
        if self._client and not self._client.is_closed:

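A usage sketch for the new table parser; the HTML snippet and the SomeHTTPPlugin subclass are made up for illustration:

# Sketch: parse a two-column proxy table; protocol falls back to "http".
html = """
<table>
  <tr><td>1.2.3.4</td><td>8080</td></tr>
  <tr><td>5.6.7.8</td><td>3128</td></tr>
</table>
"""
plugin = SomeHTTPPlugin()  # hypothetical BaseHTTPPlugin subclass
proxies = plugin.parse_html_table(html, column_map={"ip": 0, "port": 1})
# -> [ProxyRaw("1.2.3.4", 8080, "http"), ProxyRaw("5.6.7.8", 3128, "http")]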
@@ -1,11 +1,9 @@
"""Data access layer package"""
from .proxy_repo import ProxyRepository
from .settings_repo import SettingsRepository, PluginSettingsRepository
from .task_repo import ValidationTaskRepository

__all__ = [
    "ProxyRepository",
    "SettingsRepository",
    "PluginSettingsRepository",
    "ValidationTaskRepository",
]

@@ -1,130 +0,0 @@
"""Persistence layer for the validation task queue"""
import aiosqlite
from typing import List, Optional
from app.models.domain import ProxyRaw
from app.core.log import logger


class ValidationTaskRepository:
    """Validation task repository - backs queue persistence"""

    @staticmethod
    async def insert_batch(db: aiosqlite.Connection, proxies: List[ProxyRaw]) -> int:
        if not proxies:
            return 0
        try:
            rows = [(p.ip, p.port, p.protocol) for p in proxies]
            await db.executemany(
                """
                INSERT INTO validation_tasks (ip, port, protocol, status, created_at)
                VALUES (?, ?, ?, 'pending', CURRENT_TIMESTAMP)
                """,
                rows,
            )
            await db.commit()
            return len(rows)
        except Exception as e:
            logger.error(f"insert_batch validation tasks failed: {e}")
            return 0

    @staticmethod
    async def acquire_pending(db: aiosqlite.Connection) -> Optional[dict]:
        """Atomically grab one pending task and mark it as processing"""
        try:
            async with db.execute(
                """
                UPDATE validation_tasks
                SET status = 'processing', updated_at = CURRENT_TIMESTAMP
                WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1)
                RETURNING id, ip, port, protocol
                """
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None
                return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]}
        except Exception as e:
            logger.error(f"acquire_pending failed: {e}")
            return None

    @staticmethod
    async def complete_task(
        db: aiosqlite.Connection,
        task_id: int,
        is_valid: bool,
        response_time_ms: Optional[float] = None,
    ) -> bool:
        try:
            await db.execute(
                """
                UPDATE validation_tasks
                SET status = 'completed',
                    result = ?,
                    response_time_ms = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = ?
                """,
                ("valid" if is_valid else "invalid", response_time_ms, task_id),
            )
            await db.commit()
            return True
        except Exception as e:
            logger.error(f"complete_task failed: {e}")
            return False

    @staticmethod
    async def fail_task(db: aiosqlite.Connection, task_id: int) -> bool:
        try:
            await db.execute(
                """
                UPDATE validation_tasks
                SET status = 'failed',
                    result = 'invalid',
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = ?
                """,
                (task_id,),
            )
            await db.commit()
            return True
        except Exception as e:
            logger.error(f"fail_task failed: {e}")
            return False

    @staticmethod
    async def get_pending_count(db: aiosqlite.Connection) -> int:
        async with db.execute(
            "SELECT COUNT(*) FROM validation_tasks WHERE status = 'pending'"
        ) as cursor:
            row = await cursor.fetchone()
            return row[0] if row else 0

    @staticmethod
    async def reset_processing(db: aiosqlite.Connection) -> int:
        """Reset tasks stuck in processing back to pending, for startup recovery"""
        try:
            await db.execute(
                """
                UPDATE validation_tasks
                SET status = 'pending', updated_at = CURRENT_TIMESTAMP
                WHERE status = 'processing'
                """
            )
            await db.commit()
            return db.total_changes
        except Exception as e:
            logger.error(f"reset_processing failed: {e}")
            return 0

    @staticmethod
    async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
        try:
            await db.execute(
                "DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')",
                (days,),
            )
            await db.commit()
            return db.total_changes
        except Exception as e:
            logger.error(f"cleanup_old tasks failed: {e}")
            return 0
app/services/plugin_runner.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""Unified plugin executor - encapsulates timeout, retry, health check, and error capture"""
import asyncio
from datetime import datetime
from typing import Optional

from app.core.plugin_system.base import BaseCrawlerPlugin
from app.core.config import settings as app_settings
from app.core.log import logger
from app.models.domain import CrawlResult, ProxyRaw


class PluginRunner:
    """Unified plugin executor

    - Timeout control (reads crawler_timeout from settings)
    - Exception capture and stats updates
    - Optional health-check preflight
    - Result deduplication
    """

    def __init__(self, timeout: Optional[float] = None):
        self.timeout = timeout or getattr(app_settings, "crawler_timeout", 30)

    async def run(self, plugin: BaseCrawlerPlugin) -> CrawlResult:
        """Run a single plugin crawl"""
        result = CrawlResult(plugin_name=plugin.name)

        # Health check (optional)
        try:
            healthy = await asyncio.wait_for(
                plugin.health_check(), timeout=5.0
            )
            if not healthy:
                result.error = "health check failed"
                result.failure_count = 1
                await self._save_stats(plugin, result)
                return result
        except Exception as e:
            logger.warning(f"Plugin {plugin.name} health check error: {e}")

        # Run the crawl
        try:
            proxies = await asyncio.wait_for(
                plugin.crawl(),
                timeout=self.timeout,
            )
            result.proxies = self._dedup(proxies)
            result.success_count = 1 if result.proxies else 0
            logger.info(
                f"Plugin {plugin.name} crawled {len(result.proxies)} unique proxies"
            )
        except asyncio.TimeoutError:
            result.error = f"crawl timeout after {self.timeout}s"
            result.failure_count = 1
            logger.error(f"Plugin {plugin.name} crawl timeout")
        except Exception as e:
            result.error = str(e)
            result.failure_count = 1
            logger.error(f"Plugin {plugin.name} crawl failed: {e}", exc_info=True)

        await self._save_stats(plugin, result)
        return result

    @staticmethod
    def _dedup(proxies: list[ProxyRaw]) -> list[ProxyRaw]:
        seen = set()
        unique = []
        for p in proxies:
            key = (p.ip, p.port)
            if key not in seen:
                seen.add(key)
                unique.append(p)
        return unique

    async def _save_stats(self, plugin: BaseCrawlerPlugin, result: CrawlResult) -> None:
        """Persist crawl stats to the database"""
        from app.core.db import get_db
        from app.repositories.settings_repo import PluginSettingsRepository

        repo = PluginSettingsRepository()
        payload = {
            "success_count": result.success_count,
            "failure_count": result.failure_count,
            "last_run": datetime.now().isoformat(),
        }
        try:
            async with get_db() as db:
                await repo.set_stats(db, plugin.name, payload)
        except Exception as e:
            logger.error(f"Failed to save stats for {plugin.name}: {e}")
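A hedged usage sketch of the runner above; the plugin instance would normally come from the registry, and the timeout override is optional:

# Sketch only: `plugin` must be a BaseCrawlerPlugin with health_check() and crawl().
from app.services.plugin_runner import PluginRunner

async def crawl_once(plugin) -> None:
    runner = PluginRunner(timeout=10.0)  # omit to fall back to settings.crawler_timeout
    result = await runner.run(plugin)
    if result.error:
        print(f"{result.plugin_name} failed: {result.error}")
    else:
        print(f"{result.plugin_name} -> {len(result.proxies)} unique proxies")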
@@ -1,21 +1,21 @@
"""Plugin business service"""
import asyncio
from datetime import datetime
from typing import List, Optional

from app.core.db import get_db
from app.core.plugin_system.registry import registry
from app.core.plugin_system.base import BaseCrawlerPlugin
from app.core.exceptions import PluginNotFoundException
from app.repositories.settings_repo import PluginSettingsRepository
from app.models.domain import PluginInfo, ProxyRaw
from app.models.domain import PluginInfo, ProxyRaw, CrawlResult
from app.core.log import logger


class PluginService:
    """Plugin business service: manages plugin lifecycle, crawl execution, and configuration"""
    """Plugin business service: manages plugin lifecycle and configuration"""

    def __init__(self):
        self.plugin_settings_repo = PluginSettingsRepository()
        self._stats: dict[str, dict] = {}

    async def list_plugins(self) -> List[PluginInfo]:
        """Return info for all plugins (merging persisted state and config)"""
@@ -24,42 +24,40 @@ class PluginService:

        result = []
        for plugin in registry.list_plugins():
            # Merge persisted state (do not mutate the global instance, to avoid concurrent races)
            state = db_states.get(plugin.name, {})
            enabled = state.get("enabled", plugin.enabled)
            if "config" in state and isinstance(state["config"], dict):
                plugin.update_config(state["config"])

            # Merge database stats with in-memory stats (memory wins)
            db_stat = state.get("stats", {})
            stat = {
                "success_count": db_stat.get("success_count", 0),
                "failure_count": db_stat.get("failure_count", 0),
                "last_run": datetime.fromisoformat(db_stat["last_run"]) if db_stat.get("last_run") else None,
            }
            mem_stat = self._stats.get(plugin.name, {})
            if mem_stat:
                stat["success_count"] = mem_stat.get("success_count", stat["success_count"])
                stat["failure_count"] = mem_stat.get("failure_count", stat["failure_count"])
                if mem_stat.get("last_run"):
                    stat["last_run"] = mem_stat["last_run"]

            result.append(PluginInfo(
                id=plugin.name,
                name=plugin.name,
                display_name=plugin.display_name or plugin.name,
                description=plugin.description or f"Crawls proxies from {plugin.name}",
                enabled=enabled,
                last_run=stat.get("last_run"),
                success_count=stat.get("success_count", 0),
                failure_count=stat.get("failure_count", 0),
                last_run=self._parse_datetime(db_stat.get("last_run")),
                success_count=db_stat.get("success_count", 0),
                failure_count=db_stat.get("failure_count", 0),
            ))
        return result

    @staticmethod
    def _parse_datetime(value) -> Optional[datetime]:
        from datetime import datetime
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        try:
            return datetime.fromisoformat(value)
        except Exception:
            return None

    async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
        plugin = registry.get(plugin_id)
        if not plugin:
            return False
            raise PluginNotFoundException(plugin_id)
        async with get_db() as db:
            success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
            if success:
@@ -68,10 +66,9 @@ class PluginService:
        return success

    async def get_plugin_config(self, plugin_id: str) -> Optional[dict]:
        """Return the plugin's current config (defaults merged with persisted values)"""
        plugin = registry.get(plugin_id)
        if not plugin:
            return None
            raise PluginNotFoundException(plugin_id)
        async with get_db() as db:
            saved = await self.plugin_settings_repo.get_config(db, plugin_id)
        config = dict(plugin.default_config)
@@ -80,11 +77,9 @@ class PluginService:
        return config

    async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
        """Update plugin config (only keys already present in default_config are saved)"""
        plugin = registry.get(plugin_id)
        if not plugin:
            return False
        # Filter out invalid keys
            raise PluginNotFoundException(plugin_id)
        safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
        if not safe_config:
            return False
@@ -95,44 +90,45 @@ class PluginService:
    def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
        return registry.get(plugin_id)

    async def run_plugin(self, plugin_id: str) -> List[ProxyRaw]:
        """Run a single plugin crawl"""
    def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin:
        plugin = self.get_plugin(plugin_id)
        if not plugin:
            raise ValueError(f"Plugin {plugin_id} not found")
            raise PluginNotFoundException(plugin_id)
        return plugin

    async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult:
        """Run a single plugin crawl (via PluginRunner)"""
        plugin = self.get_plugin_or_raise(plugin_id)
        if not plugin.enabled:
            logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
            return []
            return CrawlResult(plugin_name=plugin_id)
        return await plugin_runner.run(plugin)

        try:
            results = await plugin.crawl()
            self._record_stat(plugin_id, success=len(results))
            logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
            return results
        except Exception as e:
            self._record_stat(plugin_id, failure=1)
            logger.error(f"Plugin {plugin_id} crawl failed: {e}")
            return []
        finally:
            await self._save_stats(plugin_id)

    async def run_all_plugins(self) -> List[ProxyRaw]:
    async def run_all_plugins(self, plugin_runner) -> List[ProxyRaw]:
        """Crawl with all enabled plugins, capping concurrency to avoid tripping target-site anti-bot defenses"""
        all_results: List[ProxyRaw] = []
        semaphore = asyncio.Semaphore(5)

        async def _run_with_limit(plugin_name: str):
            plugin = self.get_plugin_or_raise(plugin_name)
            if not plugin.enabled:
                return CrawlResult(plugin_name=plugin_name)
            async with semaphore:
                return await self.run_plugin(plugin_name)
                return await plugin_runner.run(plugin)

        tasks = [_run_with_limit(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
        tasks = [
            _run_with_limit(plugin.name)
            for plugin in registry.list_plugins() if plugin.enabled
        ]
        results_list = await asyncio.gather(*tasks, return_exceptions=True)
        for results in results_list:
            if isinstance(results, Exception):
                logger.error(f"Run all plugins error: {results}")
        for result in results_list:
            if isinstance(result, Exception):
                logger.error(f"Run all plugins error: {result}")
                continue
            all_results.extend(results)
        # Deduplicate (consistent with the database UNIQUE(ip, port) constraint)
            if isinstance(result, CrawlResult):
                all_results.extend(result.proxies)

        # Deduplicate
        seen = set()
        unique = []
        for p in all_results:
@@ -141,26 +137,3 @@ class PluginService:
            seen.add(key)
            unique.append(p)
        return unique

    def _record_stat(self, plugin_id: str, success: int = 0, failure: int = 0):
        if plugin_id not in self._stats:
            self._stats[plugin_id] = {
                "success_count": 0,
                "failure_count": 0,
                "last_run": None,
            }
        self._stats[plugin_id]["success_count"] += success
        self._stats[plugin_id]["failure_count"] += failure
        if success or failure:
            self._stats[plugin_id]["last_run"] = datetime.now()

    async def _save_stats(self, plugin_id: str):
        """Persist in-memory stats to the database"""
        stats = self._stats.get(plugin_id, {})
        payload = {
            "success_count": stats.get("success_count", 0),
            "failure_count": stats.get("failure_count", 0),
            "last_run": stats.get("last_run").isoformat() if stats.get("last_run") else None,
        }
        async with get_db() as db:
            await self.plugin_settings_repo.set_stats(db, plugin_id, payload)
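To see how the new signatures compose, here is a hypothetical route; the path and the app.state attribute name are assumptions, only run_plugin(plugin_id, plugin_runner) comes from the diff above:

# Hypothetical FastAPI route wiring PluginService and PluginRunner together.
from fastapi import APIRouter, Request
from app.services.plugin_service import PluginService

router = APIRouter()
service = PluginService()

@router.post("/plugins/{plugin_id}/crawl")
async def crawl_plugin(plugin_id: str, request: Request):
    runner = request.app.state.plugin_runner  # assumed lifespan wiring
    result = await service.run_plugin(plugin_id, runner)
    return {"plugin": result.plugin_name, "proxies": len(result.proxies)}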
@@ -4,7 +4,8 @@ import json
import io
from datetime import datetime
from typing import List, Optional, Tuple, AsyncIterator
from app.core.db import get_db

from app.core.db import get_db, transaction
from app.repositories.proxy_repo import ProxyRepository
from app.models.domain import Proxy
from app.core.log import logger
@@ -1,43 +1,42 @@
"""Scheduler service - validates existing proxies on a timer"""
"""Scheduler service - triggers full validation on a timer"""
import asyncio
from datetime import datetime
from app.core.db import get_db
from app.repositories.proxy_repo import ProxyRepository
from app.repositories.task_repo import ValidationTaskRepository
from app.core.tasks.queue import ValidationQueue
from app.core.config import settings as app_settings
from typing import Optional

from app.core.execution.executor import JobExecutor
from app.core.execution.job import ValidateAllJob
from app.core.log import logger
from app.models.domain import ProxyRaw
from app.services.task_service import task_service


class SchedulerService:
    """Proxy validation scheduler"""
    """Proxy validation scheduler

    Single responsibility: a timed loop that triggers ValidateAllJob.
    It no longer holds the validation queue or ValidatorService directly.
    """

    def __init__(
        self,
        validation_queue: ValidationQueue,
        proxy_repo: ProxyRepository = ProxyRepository(),
        executor: JobExecutor,
        interval_minutes: int = 30,
    ):
        self.validation_queue = validation_queue
        self.proxy_repo = proxy_repo
        self.interval_minutes = 30
        self.executor = executor
        self.interval_minutes = interval_minutes
        self.running = False
        self._stop_event = asyncio.Event()
        self._task: asyncio.Task | None = None
        self._validate_task: asyncio.Task | None = None
        self._task: Optional[asyncio.Task] = None

    async def start(self):
    async def start(self) -> None:
        if self.running:
            logger.warning("Scheduler already running")
            return
        self._stop_event.clear()
        self.running = True
        await self.validation_queue.start()
        self._task = asyncio.create_task(self._run_loop())
        logger.info("Scheduler started")

    async def stop(self):
    async def stop(self) -> None:
        if not self.running:
            return
        self.running = False
        self._stop_event.set()
        if self._task:
@@ -47,72 +46,26 @@ class SchedulerService:
            except asyncio.CancelledError:
                pass
            self._task = None
        await self.validation_queue.stop()
        logger.info("Scheduler stopped")

    def cancel_validate_task(self):
        """Cancel a full validation background task that is still running"""
        if self._validate_task and not self._validate_task.done():
            self._validate_task.cancel()
    def validate_all_now(self) -> str:
        """Run one full validation immediately; returns the Job ID"""
        job_id = self.executor.submit_job(ValidateAllJob())
        logger.info(f"ValidateAllJob submitted: {job_id}")
        return job_id

    async def validate_all_now(self):
        """Run one full validation immediately (in the background, non-blocking)"""
        if self._validate_task and not self._validate_task.done():
            return
        self._validate_task = asyncio.create_task(self._do_validate_all(from_loop=False))

    async def _run_loop(self):
    async def _run_loop(self) -> None:
        """Timer loop"""
        while self.running:
            try:
                # Clean up stale tasks to keep memory bounded
                task_service.cleanup_old_tasks()
                await self._do_validate_all(from_loop=True)
                self.executor.submit_job(ValidateAllJob())
            except Exception as e:
                logger.error(f"Scheduler loop error: {e}", exc_info=True)
            # Wait for the next round
            try:
                await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
                await asyncio.wait_for(
                    self._stop_event.wait(),
                    timeout=self.interval_minutes * 60,
                )
            except asyncio.TimeoutError:
                pass

    async def _do_validate_all(self, from_loop: bool = True):
        """Validate every proxy currently in the database"""
        queue_started_here = False
        try:
            logger.info("Starting scheduled validation for all proxies")

            # If the queue is not running, start it temporarily (covers validate_all_now being called while the scheduler is stopped)
            if not self.validation_queue._running:
                await self.validation_queue.start()
                queue_started_here = True

            async with get_db() as db:
                # Purge validation task records older than 7 days so the table cannot grow without bound
                cleaned = await ValidationTaskRepository.cleanup_old(db, days=7)
                if cleaned:
                    logger.info(f"Cleaned up {cleaned} old validation tasks")
                proxies = await self.proxy_repo.list_all(db)
            if not proxies:
                logger.info("No proxies to validate")
                return

            logger.info(f"Validating {len(proxies)} proxies from database")
            # Submit to the validation queue in batches; no longer blocks waiting for a drain
            batch_size = 100
            total_batches = (len(proxies) - 1) // batch_size + 1
            for i in range(0, len(proxies), batch_size):
                if from_loop and not self.running:
                    break
                batch = proxies[i : i + batch_size]
                await self.validation_queue.submit([
                    ProxyRaw(p.ip, p.port, p.protocol) for p in batch
                ])
                logger.info(f"Submitted batch {i // batch_size + 1}/{total_batches}")

            logger.info("Scheduled validation batches submitted")
        except Exception as e:
            logger.error(f"Scheduled validation error: {e}", exc_info=True)
        finally:
            if queue_started_here:
                await self.validation_queue.stop()
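A lifecycle sketch under the new contract; start/stop/validate_all_now and the executor argument come from the diff above, while the JobExecutor construction is an assumption:

# Sketch: JobExecutor() default construction is an assumption.
import asyncio
from app.core.execution import JobExecutor
from app.services.scheduler_service import SchedulerService

async def main() -> None:
    scheduler = SchedulerService(executor=JobExecutor(), interval_minutes=30)
    await scheduler.start()                 # background loop submits ValidateAllJob
    job_id = scheduler.validate_all_now()   # manual trigger, returns the Job ID
    print("submitted", job_id)
    await scheduler.stop()

asyncio.run(main())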
@@ -1,62 +0,0 @@
"""Lightweight async task service - in-memory store for tracking background crawl tasks"""
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from app.core.log import logger


class TaskService:
    """Async task state manager"""

    def __init__(self, max_age_seconds: int = 3600):
        self._tasks: Dict[str, dict] = {}
        self._max_age_seconds = max_age_seconds

    def create_task(self, task_type: str) -> str:
        task_id = str(uuid.uuid4())
        now = datetime.now()
        self._tasks[task_id] = {
            "id": task_id,
            "type": task_type,
            "status": "pending",
            "message": "",
            "data": {},
            "created_at": now,
            "updated_at": now,
        }
        return task_id

    def update_task(
        self,
        task_id: str,
        status: Optional[str] = None,
        message: Optional[str] = None,
        data: Optional[dict] = None,
    ) -> bool:
        task = self._tasks.get(task_id)
        if not task:
            return False
        if status is not None:
            task["status"] = status
        if message is not None:
            task["message"] = message
        if data is not None:
            task["data"].update(data)
        task["updated_at"] = datetime.now()
        return True

    def get_task(self, task_id: str) -> Optional[dict]:
        return self._tasks.get(task_id)

    def cleanup_old_tasks(self) -> int:
        cutoff = datetime.now() - timedelta(seconds=self._max_age_seconds)
        to_remove = [tid for tid, task in self._tasks.items() if task["created_at"] < cutoff]
        for tid in to_remove:
            del self._tasks[tid]
        if to_remove:
            logger.info(f"TaskService cleaned up {len(to_remove)} old tasks")
        return len(to_remove)


# Global task service instance
task_service = TaskService()
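For reference, the deleted service's lifecycle was driven entirely from memory; a sketch of the old API (superseded by JobExecutor in this commit):

# Usage of the removed TaskService (reference only).
svc = TaskService(max_age_seconds=3600)
task_id = svc.create_task("crawl")
svc.update_task(task_id, status="running", message="crawling sources")
svc.update_task(task_id, status="done", data={"found": 42})
assert svc.get_task(task_id)["status"] == "done"
svc.cleanup_old_tasks()  # drops entries older than max_age_seconds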
@@ -4,12 +4,17 @@ import random
import time
import aiohttp
import aiohttp_socks
from typing import Tuple
from typing import Tuple, Optional

from app.core.config import settings as app_settings
from app.core.log import logger


class ValidatorService:
    """Proxy validator"""
    """Proxy validator

    Reads its configuration dynamically, so settings hot updates take effect.
    """

    # Test URLs
    TEST_URLS = {
@@ -19,40 +24,66 @@ class ValidatorService:

    def __init__(
        self,
        timeout: float = 5.0,
        connect_timeout: float = 3.0,
        max_concurrency: int = 50,
        timeout: Optional[float] = None,
        connect_timeout: Optional[float] = None,
        max_concurrency: Optional[int] = None,
    ):
        self.timeout = timeout
        self.connect_timeout = connect_timeout
        self.max_concurrency = max_concurrency
        self.semaphore = asyncio.Semaphore(max_concurrency)
        # Use the passed-in or default values at init time; settings are re-read dynamically at runtime
        self._init_timeout = timeout or app_settings.validator_timeout
        self._init_connect_timeout = connect_timeout or app_settings.validator_connect_timeout
        self._init_max_concurrency = max_concurrency or app_settings.validator_max_concurrency

        # Shared HTTP/HTTPS ClientSession
        self._http_connector = aiohttp.TCPConnector(
            ssl=False,
            limit=max_concurrency,
            limit_per_host=max_concurrency,
            force_close=False,
        )
        self._timeout = aiohttp.ClientTimeout(
            total=timeout, connect=connect_timeout
        )
        self._http_session = aiohttp.ClientSession(
            connector=self._http_connector,
            timeout=self._timeout,
        )
        self._http_connector: Optional[aiohttp.TCPConnector] = None
        self._http_session: Optional[aiohttp.ClientSession] = None
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._lock = asyncio.Lock()

    @property
    def timeout(self) -> float:
        return float(self._init_timeout)

    @property
    def connect_timeout(self) -> float:
        return float(self._init_connect_timeout)

    @property
    def max_concurrency(self) -> int:
        return int(self._init_max_concurrency)

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Lazily create the shared HTTP session"""
        if self._http_session is None or self._http_session.closed:
            connector = aiohttp.TCPConnector(
                ssl=False,
                limit=self.max_concurrency,
                limit_per_host=self.max_concurrency,
                force_close=False,
            )
            timeout = aiohttp.ClientTimeout(
                total=self.timeout, connect=self.connect_timeout
            )
            self._http_connector = connector
            self._http_session = aiohttp.ClientSession(
                connector=connector,
                timeout=timeout,
            )
        return self._http_session

    def _ensure_semaphore(self) -> asyncio.Semaphore:
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self.max_concurrency)
        return self._semaphore

    def _get_test_url(self, protocol: str) -> str:
        """Pick a test URL"""
        urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
        return random.choice(urls)

    async def validate(self, ip: str, port: int, protocol: str = "http") -> Tuple[bool, float]:
        """Validate a single proxy; returns (is_valid, latency_ms)"""
        protocol = protocol.lower()

        async with self.semaphore:
        semaphore = self._ensure_semaphore()

        async with semaphore:
            start = time.time()
            try:
                if protocol in ("socks4", "socks5"):
@@ -67,11 +98,11 @@ class ValidatorService:
        return False, 0.0

    async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
        """Validate an HTTP/HTTPS proxy"""
        proxy_url = f"http://{ip}:{port}"
        test_url = self._get_test_url(protocol)
        session = self._ensure_session()

        async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
        async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
            if response.status in (200, 301, 302):
                latency = round((time.time() - start) * 1000, 2)
                logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
@@ -79,7 +110,6 @@ class ValidatorService:
        return False, 0.0

    async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
        """Validate a SOCKS4/SOCKS5 proxy"""
        proxy_type = (
            aiohttp_socks.ProxyType.SOCKS4
            if protocol == "socks4"
@@ -103,7 +133,9 @@ class ValidatorService:
            return True, latency
        return False, 0.0

    async def close(self):
    async def close(self) -> None:
        """Close the shared HTTP ClientSession"""
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()
        self._http_session = None
        self._http_connector = None
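How the lazy session enables hot updates, as a sketch; the write path into the _init_* attributes is an assumption (a settings layer would normally do this), while the close-then-recreate behavior follows from _ensure_session above:

# Sketch: applying new timeouts at runtime (attribute names taken from the code above,
# but writing them directly from outside is an assumption, not a documented API).
async def apply_validator_settings(validator, timeout: float, connect_timeout: float) -> None:
    validator._init_timeout = timeout
    validator._init_connect_timeout = connect_timeout
    await validator.close()  # drop the old shared session
    # The next validate() call rebuilds the session via _ensure_session()
    # with a fresh aiohttp.ClientTimeout built from the new values.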