Round 3 fixes: cancelled polling, aggregator invalid_count, filter state, scheduler atomicity, HTTP exception handler, tests

祀梦
2026-04-05 10:20:23 +08:00
parent 49e440cb41
commit dc5f050683
32 changed files with 321 additions and 163 deletions

View File

@@ -5,6 +5,7 @@ from app.services.plugin_service import PluginService
from app.services.scheduler_service import SchedulerService
from app.services.plugin_runner import PluginRunner
from app.core.execution import JobExecutor, AsyncWorkerPool
+ from app.repositories.settings_repo import SettingsRepository
def get_proxy_service() -> ProxyService:
@@ -15,6 +16,10 @@ def get_plugin_service() -> PluginService:
return PluginService()
+ def get_settings_repo() -> SettingsRepository:
+     return SettingsRepository()
def get_scheduler_service(request: Request) -> SchedulerService:
return request.app.state.scheduler

View File

@@ -1,6 +1,7 @@
"""统一异常处理"""
from fastapi import Request
from fastapi.responses import JSONResponse
+ from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import ValidationError
from app.core.exceptions import ProxyPoolException
from app.core.log import logger
@@ -13,6 +14,13 @@ async def proxy_pool_exception_handler(request: Request, exc: ProxyPoolException
)
+ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
+     return JSONResponse(
+         status_code=exc.status_code,
+         content={"code": exc.status_code, "message": exc.detail, "data": None},
+     )
async def pydantic_validation_handler(request: Request, exc: ValidationError):
logger.error(f"Validation error: {exc}")
return JSONResponse(

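For context on this handler: FastAPI's built-in handler answers HTTP errors with a bare `{"detail": ...}` body, so without a registered `StarletteHTTPException` handler, 404s and other HTTP errors would bypass the project's `{code, message, data}` envelope. A minimal self-contained sketch of the same idea (a hypothetical app, not this repo's wiring):

```python
# Hypothetical minimal app showing why the handler is registered:
# without it, FastAPI answers HTTP errors with {"detail": ...} and
# clients expecting the unified envelope would break.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

app = FastAPI()

@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    # Same envelope as successful responses: {code, message, data}
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.status_code, "message": exc.detail, "data": None},
    )

@app.get("/missing")
async def missing():
    raise StarletteHTTPException(status_code=404, detail="not found")
```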
View File

@@ -80,6 +80,7 @@ async def lifespan(app: FastAPI):
# Scheduler
scheduler = SchedulerService(
executor=executor,
+ worker_pool=worker_pool,
interval_minutes=db_settings.get("validate_interval_minutes", 30),
)

View File

@@ -3,9 +3,10 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.lifespan import lifespan
from app.api.routes import api_router
- from app.api.errors import proxy_pool_exception_handler, pydantic_validation_handler, general_exception_handler
+ from app.api.errors import proxy_pool_exception_handler, http_exception_handler, pydantic_validation_handler, general_exception_handler
from app.core.exceptions import ProxyPoolException
from pydantic import ValidationError
+ from starlette.exceptions import HTTPException as StarletteHTTPException
from app.core.config import settings as app_settings
# Import and register all plugins (explicit registration mode)
@@ -30,6 +31,7 @@ def create_app() -> FastAPI:
# 异常处理
app.add_exception_handler(ProxyPoolException, proxy_pool_exception_handler)
+ app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(ValidationError, pydantic_validation_handler)
app.add_exception_handler(Exception, general_exception_handler)

View File

@@ -7,7 +7,7 @@ from app.services.plugin_runner import PluginRunner
from app.core.execution import JobExecutor, CrawlJob
from app.core.exceptions import PluginNotFoundException
from app.api.deps import get_plugin_service, get_plugin_runner, get_executor
- from app.api.common import success_response
+ from app.api.common import success_response, format_plugin
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
@@ -114,8 +114,11 @@ def _create_crawl_all_aggregator(job_ids, executor):
class CrawlAllAggregator(Job):
async def run(self):
self._set_running()
# Wait for all child jobs to finish (up to ~30 seconds)
for _ in range(300):
+     if self.is_cancelled:
+         break
all_done = all(
executor.get_job(jid) and executor.get_job(jid).status.value in ("completed", "failed", "cancelled")
for jid in job_ids
@@ -125,24 +128,18 @@ def _create_crawl_all_aggregator(job_ids, executor):
await asyncio.sleep(0.1)
total = 0
valid = 0
+ invalid = 0
for jid in job_ids:
job = executor.get_job(jid)
if job and job.result:
total += job.result.get("proxy_count", 0)
valid += job.result.get("success_count", 0)
return {"total_crawled": total, "valid_count": valid, "invalid_count": 0}
invalid += job.result.get("failure_count", 0)
result = {"total_crawled": total, "valid_count": valid, "invalid_count": invalid}
if self.is_cancelled:
result["cancelled"] = True
return result
return CrawlAllAggregator()
- def format_plugin(plugin) -> dict:
-     return {
-         "id": plugin.id,
-         "name": plugin.display_name,
-         "display_name": plugin.display_name,
-         "description": plugin.description,
-         "enabled": plugin.enabled,
-         "last_run": plugin.last_run.isoformat() if plugin.last_run else None,
-         "success_count": plugin.success_count,
-         "failure_count": plugin.failure_count,
-     }

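The aggregator hunk applies a common pattern: poll child jobs in a bounded loop, bail out early on cancellation, and aggregate real failure counts instead of hard-coding `invalid_count` to 0. A standalone sketch of that pattern, with plain dicts as simplified stand-ins for the repo's Job objects:

```python
# Standalone sketch; plain dicts stand in for the repo's Job objects.
import asyncio

TERMINAL = {"completed", "failed", "cancelled"}

async def aggregate(jobs: dict, is_cancelled, poll_seconds=0.1, max_polls=300):
    # Poll until all children are terminal, the ~30s budget runs out,
    # or the aggregator itself is cancelled.
    for _ in range(max_polls):
        if is_cancelled():
            break
        if all(j["status"] in TERMINAL for j in jobs.values()):
            break
        await asyncio.sleep(poll_seconds)
    result = {
        "total_crawled": sum(j.get("proxy_count", 0) for j in jobs.values()),
        "valid_count": sum(j.get("success_count", 0) for j in jobs.values()),
        "invalid_count": sum(j.get("failure_count", 0) for j in jobs.values()),
    }
    if is_cancelled():
        result["cancelled"] = True  # caller sees these are partial numbers
    return result
```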
View File

@@ -52,7 +52,7 @@ async def list_proxies(
async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
proxy = await service.get_random_proxy()
if not proxy:
- raise ProxyNotFoundException("", 0)
+ raise ProxyPoolException("暂无可用代理", 404)
return success_response("获取随机代理成功", format_proxy(proxy))
@@ -68,12 +68,8 @@ async def export_proxies(
media_types = {"csv": "text/csv", "txt": "text/plain", "json": "application/json"}
- async def generate():
-     async for chunk in service.export_proxies(fmt, protocol, limit):
-         yield chunk
return StreamingResponse(
-     generate(),
+     service.export_proxies(fmt, protocol, limit),
media_type=media_types[fmt],
headers={"Content-Disposition": f"attachment; filename=proxies.{fmt}"},
)

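The export hunk drops a pass-through wrapper: Starlette's `StreamingResponse` accepts any async iterable directly, so handing it the service's async generator is equivalent and shorter. A minimal sketch (`export_rows` is a hypothetical stand-in for `service.export_proxies`):

```python
# export_rows() is a hypothetical stand-in for service.export_proxies().
from fastapi.responses import StreamingResponse

async def export_rows():
    for i in range(3):
        yield f"row-{i}\n"

def make_response() -> StreamingResponse:
    return StreamingResponse(
        export_rows(),  # async generators are valid content, no wrapper needed
        media_type="text/plain",
        headers={"Content-Disposition": "attachment; filename=proxies.txt"},
    )
```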
View File

@@ -4,37 +4,45 @@ from fastapi import APIRouter, Depends
from app.services.scheduler_service import SchedulerService
from app.repositories.settings_repo import SettingsRepository
from app.core.db import get_db
- from app.api.deps import get_scheduler_service
+ from app.api.deps import get_scheduler_service, get_settings_repo
from app.api.common import success_response
router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
- settings_repo = SettingsRepository()
- async def _save_auto_validate_setting(enabled: bool):
+ async def _save_auto_validate_setting(enabled: bool, settings_repo: SettingsRepository):
"""保存自动验证设置"""
async with get_db() as db:
- settings = await settings_repo.get_all(db)
- settings["auto_validate"] = enabled
- from app.models.schemas import SettingsSchema
- await settings_repo.save(db, SettingsSchema(**settings).model_dump())
+ await settings_repo.save(db, {"auto_validate": enabled})
@router.post("/start")
- async def start_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
+ async def start_scheduler(
+     scheduler: SchedulerService = Depends(get_scheduler_service),
+     settings_repo: SettingsRepository = Depends(get_settings_repo),
+ ):
if scheduler.running:
return success_response("验证调度器已在运行", {"running": True})
+ try:
+     await _save_auto_validate_setting(True, settings_repo)
+ except Exception:
+     return success_response("启动调度器失败(设置保存异常)", {"running": False})
await scheduler.start()
- await _save_auto_validate_setting(True)
return success_response("验证调度器已启动", {"running": True})
@router.post("/stop")
- async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
+ async def stop_scheduler(
+     scheduler: SchedulerService = Depends(get_scheduler_service),
+     settings_repo: SettingsRepository = Depends(get_settings_repo),
+ ):
if not scheduler.running:
return success_response("验证调度器未运行", {"running": False})
+ try:
+     await _save_auto_validate_setting(False, settings_repo)
+ except Exception:
+     return success_response("停止调度器失败(设置保存异常)", {"running": True})
await scheduler.stop()
- await _save_auto_validate_setting(False)
return success_response("验证调度器已停止", {"running": False})

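The reordering in both endpoints is a persist-first pattern: the setting is written before the scheduler is started or stopped, so a failed save leaves the in-memory scheduler untouched and the stored flag can't diverge from reality. Sketched with hypothetical names:

```python
# Hypothetical names; mirrors the ordering the endpoints now use.
async def start_scheduler(scheduler, save_setting):
    if scheduler.running:
        return {"running": True}
    try:
        await save_setting(True)      # persist intent first
    except Exception:
        return {"running": False}     # nothing was mutated on failure
    await scheduler.start()           # mutate only after the write succeeded
    return {"running": True}
```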
View File

@@ -1,28 +1,33 @@
"""设置相关路由"""
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Depends
from app.core.db import get_db
from app.repositories.settings_repo import SettingsRepository
from app.models.schemas import SettingsSchema
from app.api.common import success_response
+ from app.api.deps import get_settings_repo
+ from app.core.exceptions import ProxyPoolException
from app.core.log import logger
router = APIRouter(prefix="/api/settings", tags=["settings"])
- settings_repo = SettingsRepository()
@router.get("")
- async def get_settings():
+ async def get_settings(settings_repo: SettingsRepository = Depends(get_settings_repo)):
async with get_db() as db:
settings = await settings_repo.get_all(db)
return success_response("获取设置成功", settings)
@router.post("")
- async def save_settings(request: SettingsSchema, http_request: Request):
+ async def save_settings(
+     request: SettingsSchema,
+     http_request: Request,
+     settings_repo: SettingsRepository = Depends(get_settings_repo),
+ ):
async with get_db() as db:
success = await settings_repo.save(db, request.model_dump())
if not success:
- raise RuntimeError("保存设置失败")
+ raise ProxyPoolException("保存设置失败", 500)
# Hot-reload the interval on the running scheduler
scheduler = getattr(http_request.app.state, "scheduler", None)
@@ -47,9 +52,9 @@ async def save_settings(request: SettingsSchema, http_request: Request):
validator._init_max_concurrency = request.default_concurrency
if request.validation_targets:
validator.update_test_urls(request.validation_targets)
- # Recreate the semaphore and session
- validator._semaphore = None
+ # Close the existing session first, then reset the semaphore, to avoid a race window
+ await validator.close()
+ validator._semaphore = None
logger.info(f"Validator config updated: timeout={request.validation_timeout}, concurrency={request.default_concurrency}, targets={request.validation_targets}")
return success_response("保存设置成功", request.model_dump())

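The last hunk's ordering matters: closing the live session before nulling the semaphore means a concurrent validation can't grab a fresh semaphore slot while the old session is mid-teardown. A simplified sketch of the update sequence, reusing the attribute names visible in this diff:

```python
# Simplified sketch; attribute names are the ones visible in this diff.
async def apply_validator_config(validator, timeout, concurrency, targets):
    validator._init_timeout = timeout
    validator._init_max_concurrency = concurrency
    if targets:
        validator.update_test_urls(targets)
    await validator.close()       # tear down the old session/connector first
    validator._semaphore = None   # then let the next caller rebuild lazily
```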
View File

@@ -26,6 +26,7 @@ class JobExecutor:
self.worker_pool = worker_pool
self.max_concurrent_jobs = max_concurrent_jobs
self._jobs: Dict[str, Job] = {}
+ self._tasks: Dict[str, asyncio.Task] = {}
self._running = False
self._semaphore = asyncio.Semaphore(max_concurrent_jobs)
self._cleanup_interval = cleanup_interval_seconds
@@ -61,27 +62,31 @@ class JobExecutor:
def submit_job(self, job: Job) -> str:
"""提交一个 Job 到后台执行"""
self._jobs[job.id] = job
- asyncio.create_task(self._run_job(job))
+ task = asyncio.create_task(self._run_job(job))
+ self._tasks[job.id] = task
return job.id
async def _run_job(self, job: Job) -> None:
- async with self._semaphore:
-     try:
-         if job.is_cancelled:
-             logger.info(f"Job {job.id} was cancelled before running")
-             return
-         result = await job.run()
-         # If the subclass didn't explicitly set a terminal status, auto-set completed
-         if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
-             job._set_completed(result)
-         logger.info(f"Job {job.id} completed: {result}")
-     except asyncio.CancelledError:
-         job.status = JobStatus.CANCELLED
-         job._touch()
-         logger.info(f"Job {job.id} cancelled during execution")
-     except Exception as e:
-         job._set_failed(str(e))
-         logger.error(f"Job {job.id} failed: {e}", exc_info=True)
+ try:
+     async with self._semaphore:
+         try:
+             if job.is_cancelled:
+                 logger.info(f"Job {job.id} was cancelled before running")
+                 return
+             result = await job.run()
+             # If the subclass didn't explicitly set a terminal status, auto-set completed
+             if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
+                 job._set_completed(result)
+             logger.info(f"Job {job.id} completed: {result}")
+         except asyncio.CancelledError:
+             job.status = JobStatus.CANCELLED
+             job._touch()
+             logger.info(f"Job {job.id} cancelled during execution")
+         except Exception as e:
+             job._set_failed(str(e))
+             logger.error(f"Job {job.id} failed: {e}", exc_info=True)
+ finally:
+     self._tasks.pop(job.id, None)
def get_job(self, job_id: str) -> Optional[Job]:
return self._jobs.get(job_id)
@@ -101,6 +106,9 @@ class JobExecutor:
job = self._jobs.get(job_id)
if not job:
return False
+ task = self._tasks.get(job_id)
+ if task and not task.done():
+     task.cancel()
job.cancel()
return True
@@ -108,6 +116,9 @@ class JobExecutor:
cancelled = 0
for job in list(self._jobs.values()):
if job.status in (JobStatus.PENDING, JobStatus.RUNNING):
+ task = self._tasks.get(job.id)
+ if task and not task.done():
+     task.cancel()
job.cancel()
cancelled += 1
return cancelled

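The executor change is the standard way to make job cancellation real in asyncio: keep the `asyncio.Task` per job id so `cancel()` interrupts in-flight awaits instead of merely flipping a flag, and clear the map when the task ends. A standalone sketch (using a done-callback where the repo uses a `finally` block; same effect):

```python
# Standalone sketch, not the repo's JobExecutor.
import asyncio
from typing import Dict

class MiniExecutor:
    def __init__(self) -> None:
        self._tasks: Dict[str, asyncio.Task] = {}

    def submit(self, job_id: str, coro) -> str:
        task = asyncio.create_task(coro)
        self._tasks[job_id] = task
        # Drop the entry whenever the task finishes, even on error.
        task.add_done_callback(lambda _t: self._tasks.pop(job_id, None))
        return job_id

    def cancel(self, job_id: str) -> bool:
        task = self._tasks.get(job_id)
        if task and not task.done():
            task.cancel()  # raises CancelledError inside the job's awaits
            return True
        return False
```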
View File

@@ -154,6 +154,10 @@ class ValidateAllJob(Job):
self._touch()
payload = {"total": total, "submitted": submitted}
- self._set_completed(payload)
+ if self.is_cancelled:
+     self.status = JobStatus.CANCELLED
+     self._touch()
+ else:
+     self._set_completed(payload)
logger.info(f"ValidateAllJob {self.id}: submitted {submitted}/{total} proxies")
return payload

View File

@@ -74,21 +74,36 @@ class AsyncWorkerPool:
await self._queue.join()
async def resize(self, new_worker_count: int) -> None:
"""动态调整 Worker 数量:先全部停止,再按新数量启动"""
"""动态调整 Worker 数量,不丢失队列中的任务"""
if new_worker_count == self.worker_count:
return
logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
- # Safe approach: stop first, then start, so old and new workers don't race for the sentinel and deadlock
- await self.stop()
+ if new_worker_count > self.worker_count:
+     for i in range(self.worker_count, new_worker_count):
+         self._workers.append(
+             asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}")
+         )
+ elif new_worker_count < self.worker_count:
+     for _ in range(self.worker_count - new_worker_count):
+         await self._queue.put(None)
+     await asyncio.sleep(0)
+     still_running = []
+     for w in self._workers:
+         if w.done():
+             try:
+                 await w
+             except asyncio.CancelledError:
+                 pass
+         else:
+             still_running.append(w)
+     self._workers = still_running
self.worker_count = new_worker_count
- await self.start()
async def _worker_loop(self, worker_id: int) -> None:
while True:
item = await self._queue.get()
try:
if item is None or not self._running:
self._queue.task_done()
break
await self.handler(item)
except Exception as e:

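The resize rewrite grows the pool by spawning extra worker loops and shrinks it by enqueueing one `None` sentinel per surplus worker, so items already queued are never dropped (the old stop/start cycle drained workers wholesale). A standalone sketch of the sentinel approach, assuming it runs inside a live event loop:

```python
# Standalone sketch, assuming construction inside a running event loop.
import asyncio

class MiniPool:
    def __init__(self, handler, workers: int):
        self.handler = handler
        self.queue: asyncio.Queue = asyncio.Queue()
        self.workers = [asyncio.create_task(self._loop()) for _ in range(workers)]

    async def _loop(self) -> None:
        while True:
            item = await self.queue.get()
            try:
                if item is None:   # sentinel: exactly one worker exits
                    return
                await self.handler(item)
            except Exception:
                pass               # a failed item must not kill the worker
            finally:
                self.queue.task_done()

    async def resize(self, n: int) -> None:
        if n > len(self.workers):          # grow: spawn extra loops
            for _ in range(n - len(self.workers)):
                self.workers.append(asyncio.create_task(self._loop()))
        elif n < len(self.workers):        # shrink: one sentinel per surplus worker
            for _ in range(len(self.workers) - n):
                await self.queue.put(None)
            await asyncio.sleep(0)         # give workers a chance to pick them up
            self.workers = [w for w in self.workers if not w.done()]
```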
View File

@@ -5,7 +5,7 @@ from .schemas import (
ProxyResponse,
PluginResponse,
SettingsSchema,
- CrawlResult,
+ CrawlSummarySchema,
ProxyListRequest,
ProxyDeleteItem,
BatchDeleteRequest,
@@ -21,7 +21,7 @@ __all__ = [
"ProxyResponse",
"PluginResponse",
"SettingsSchema",
"CrawlResult",
"CrawlSummarySchema",
"ProxyListRequest",
"ProxyDeleteItem",
"BatchDeleteRequest",

View File

@@ -59,7 +59,7 @@ class SettingsSchema(BaseModel):
)
- class CrawlResult(BaseModel):
+ class CrawlSummarySchema(BaseModel):
plugin_id: str
proxy_count: int
valid_count: int

View File

@@ -124,9 +124,10 @@ class ProxyRepository:
async def batch_delete(db: aiosqlite.Connection, proxies: List[Tuple[str, int]]) -> int:
if not proxies:
return 0
+ changes_before = db.total_changes
await db.executemany("DELETE FROM proxies WHERE ip = ? AND port = ?", proxies)
await db.commit()
- return len(proxies)
+ return db.total_changes - changes_before
@staticmethod
async def get_by_ip_port(

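Returning `len(proxies)` reported the request size, not the rows actually removed. sqlite3's `Connection.total_changes` (mirrored by aiosqlite) counts rows really affected, so the before/after delta is the true deletion count. A runnable stdlib illustration:

```python
# Runnable stdlib example; aiosqlite exposes the same total_changes.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE proxies (ip TEXT, port INTEGER)")
conn.execute("INSERT INTO proxies VALUES ('1.1.1.1', 80)")

targets = [("1.1.1.1", 80), ("2.2.2.2", 8080)]   # second row doesn't exist
before = conn.total_changes
conn.executemany("DELETE FROM proxies WHERE ip = ? AND port = ?", targets)
conn.commit()
print(conn.total_changes - before)   # 1 -- not len(targets) == 2
```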
View File

@@ -19,7 +19,7 @@ class PluginRunner:
"""
def __init__(self, timeout: Optional[float] = None):
- self.timeout = timeout or getattr(app_settings, "crawler_timeout", 30)
+ self.timeout = timeout if timeout is not None else getattr(app_settings, "crawler_timeout", 30)
async def run(self, plugin: BaseCrawlerPlugin) -> CrawlResult:
"""执行单个插件爬取"""
@@ -37,6 +37,10 @@ class PluginRunner:
return result
except Exception as e:
logger.warning(f"Plugin {plugin.name} health check error: {e}")
result.error = f"health check error: {e}"
result.failure_count = 1
await self._save_stats(plugin, result)
return result
# Run the crawl
try:
@@ -66,7 +70,7 @@ class PluginRunner:
seen = set()
unique = []
for p in proxies:
- key = (p.ip, p.port)
+ key = (p.ip, p.port, p.protocol)
if key not in seen:
seen.add(key)
unique.append(p)

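The timeout fix here (and the identical one in ValidatorService below) avoids a classic falsy-zero bug: `timeout or default` silently replaces an explicit `0` with the default, while the `is not None` form preserves it. A minimal demonstration with hypothetical values:

```python
# Hypothetical default; shows the behavioral difference only.
DEFAULT_TIMEOUT = 30

def old_style(timeout=None):
    return timeout or DEFAULT_TIMEOUT                            # 0 -> 30 (bug)

def fixed_style(timeout=None):
    return timeout if timeout is not None else DEFAULT_TIMEOUT  # 0 -> 0

assert old_style(0) == 30        # an explicit zero is silently lost
assert fixed_style(0) == 0       # the explicit zero survives
assert fixed_style(None) == 30   # None still falls back to the default
```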
View File

@@ -132,7 +132,7 @@ class PluginService:
seen = set()
unique = []
for p in all_results:
- key = (p.ip, p.port)
+ key = (p.ip, p.port, p.protocol)
if key not in seen:
seen.add(key)
unique.append(p)

View File

@@ -5,7 +5,7 @@ import io
from datetime import datetime
from typing import List, Optional, Tuple, AsyncIterator
- from app.core.db import get_db, transaction
+ from app.core.db import get_db
from app.repositories.proxy_repo import ProxyRepository
from app.models.domain import Proxy
from app.core.log import logger

View File

@@ -1,6 +1,6 @@
"""调度器服务 - 定时触发全量验证"""
import asyncio
from typing import Optional
from typing import Optional, Any
from app.core.execution.executor import JobExecutor
from app.core.execution.job import ValidateAllJob
@@ -17,9 +17,11 @@ class SchedulerService:
def __init__(
self,
executor: JobExecutor,
+ worker_pool: Optional[Any] = None,
interval_minutes: int = 30,
):
self.executor = executor
+ self.worker_pool = worker_pool
self.interval_minutes = interval_minutes
self.running = False
self._stop_event = asyncio.Event()
@@ -50,7 +52,7 @@ class SchedulerService:
def validate_all_now(self) -> str:
"""立即执行一次全量验证,返回 Job ID"""
job_id = self.executor.submit_job(ValidateAllJob())
job_id = self.executor.submit_job(ValidateAllJob(validator_pool=self.worker_pool))
logger.info(f"ValidateAllJob submitted: {job_id}")
return job_id
@@ -58,7 +60,7 @@ class SchedulerService:
"""定时循环"""
while self.running:
try:
self.executor.submit_job(ValidateAllJob())
self.executor.submit_job(ValidateAllJob(validator_pool=self.worker_pool))
except Exception as e:
logger.error(f"Scheduler loop error: {e}", exc_info=True)
# Wait until the next run

View File

@@ -39,9 +39,9 @@ class ValidatorService:
max_concurrency: Optional[int] = None,
):
# Use passed-in values (or defaults) at init; settings are re-read dynamically at runtime
- self._init_timeout = timeout or app_settings.validator_timeout
- self._init_connect_timeout = connect_timeout or app_settings.validator_connect_timeout
- self._init_max_concurrency = max_concurrency or app_settings.validator_max_concurrency
+ self._init_timeout = timeout if timeout is not None else app_settings.validator_timeout
+ self._init_connect_timeout = connect_timeout if connect_timeout is not None else app_settings.validator_connect_timeout
+ self._init_max_concurrency = max_concurrency if max_concurrency is not None else app_settings.validator_max_concurrency
self._http_connector: Optional[aiohttp.TCPConnector] = None
self._http_session: Optional[aiohttp.ClientSession] = None
@@ -61,23 +61,26 @@ class ValidatorService:
def max_concurrency(self) -> int:
return int(self._init_max_concurrency)
- def _ensure_session(self) -> aiohttp.ClientSession:
+ async def _ensure_session(self) -> aiohttp.ClientSession:
"""Lazily create the shared HTTP session"""
-     if self._http_session is None or self._http_session.closed:
-         connector = aiohttp.TCPConnector(
-             ssl=False,
-             limit=self.max_concurrency,
-             limit_per_host=self.max_concurrency,
-             force_close=False,
-         )
-         timeout = aiohttp.ClientTimeout(
-             total=self.timeout, connect=self.connect_timeout
-         )
-         self._http_connector = connector
-         self._http_session = aiohttp.ClientSession(
-             connector=connector,
-             timeout=timeout,
-         )
+     async with self._lock:
+         # Double-check so coroutines that raced for the lock don't re-create it
+         if self._http_session is None or self._http_session.closed:
+             connector = aiohttp.TCPConnector(
+                 ssl=False,
+                 limit=self.max_concurrency,
+                 limit_per_host=self.max_concurrency,
+                 force_close=False,
+             )
+             timeout = aiohttp.ClientTimeout(
+                 total=self.timeout, connect=self.connect_timeout
+             )
+             self._http_connector = connector
+             self._http_session = aiohttp.ClientSession(
+                 connector=connector,
+                 timeout=timeout,
+             )
return self._http_session
def _ensure_semaphore(self) -> asyncio.Semaphore:
@@ -119,9 +122,9 @@ class ValidatorService:
return False, 0.0
async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
proxy_url = f"http://{ip}:{port}"
proxy_url = f"{protocol}://{ip}:{port}"
test_url = self._get_test_url(protocol)
- session = self._ensure_session()
+ session = await self._ensure_session()
async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
if response.status in (200, 301, 302):
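
The `_ensure_session` rewrite is async double-checked creation: re-checking the session under the lock guarantees a single shared aiohttp session even when many coroutines race into the method at once. A self-contained sketch of the pattern (simplified config, not the repo's full connector setup):

```python
# Simplified sketch of the pattern; not the repo's full connector setup.
import asyncio
from typing import Optional

import aiohttp

class SessionHolder:
    def __init__(self) -> None:
        self._session: Optional[aiohttp.ClientSession] = None
        self._lock = asyncio.Lock()

    async def get(self) -> aiohttp.ClientSession:
        if self._session is None or self._session.closed:   # lock-free fast path
            async with self._lock:
                # Re-check: another coroutine may have created it while we waited.
                if self._session is None or self._session.closed:
                    self._session = aiohttp.ClientSession(
                        timeout=aiohttp.ClientTimeout(total=10, connect=3)
                    )
        return self._session
```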