Round 3 fixes: cancelled polling, aggregator invalid_count, filter state, scheduler atomicity, HTTP exception handler, tests

祀梦
2026-04-05 10:20:23 +08:00
parent 49e440cb41
commit dc5f050683
32 changed files with 321 additions and 163 deletions

View File

@@ -1,11 +1,10 @@
 import { ref } from 'vue'
 import { schedulerService } from '../services/schedulerService'
-const schedulerRunning = ref(false)
-const schedulerLoading = ref(false)
-const validating = ref(false)
 export function useScheduler() {
+  const schedulerRunning = ref(false)
+  const schedulerLoading = ref(false)
+  const validating = ref(false)
   async function fetchStatus() {
     try {
       const response = await schedulerService.getStatus()

View File

@@ -11,13 +11,18 @@ const MAX_POLL_ATTEMPTS = 30
 export async function pollTaskStatus(taskId) {
   for (let i = 0; i < MAX_POLL_ATTEMPTS; i++) {
     await new Promise(resolve => setTimeout(resolve, POLL_INTERVAL))
-    const response = await tasksAPI.getTaskStatus(taskId)
-    if (response.code !== 200) {
-      continue
-    }
-    const status = response.data.status
-    if (status === 'completed' || status === 'failed') {
-      return response
+    try {
+      const response = await tasksAPI.getTaskStatus(taskId)
+      if (response.code !== 200) {
+        continue
+      }
+      const status = response.data.status
+      if (status === 'completed' || status === 'failed' || status === 'cancelled') {
+        return response
+      }
+    } catch (error) {
+      // 网络异常时继续轮询,不中断
+      console.warn('轮询任务状态失败:', error)
     }
   }
   return {

View File

@@ -26,25 +26,21 @@ export const pluginService = {
     const finalRes = await pollTaskStatus(startRes.data.task_id)
     return {
       code: finalRes.code,
-      message: finalRes.data?.message || finalRes.message,
-      data: finalRes.data?.data || finalRes.data
+      message: finalRes.message,
+      data: finalRes.data?.result
     }
   },
   async crawlAll() {
     const startRes = await pluginsAPI.crawlAll()
-    if (startRes.code !== 200 || !startRes.data?.task_ids?.length) {
+    if (startRes.code !== 200 || !startRes.data?.task_id) {
       return startRes
     }
-    // 批量轮询所有任务,取最后一个完成的结果
-    const results = await Promise.all(
-      startRes.data.task_ids.map(tid => pollTaskStatus(tid))
-    )
-    const last = results[results.length - 1]
+    const finalRes = await pollTaskStatus(startRes.data.task_id)
     return {
-      code: last.code,
-      message: last.data?.message || last.message,
-      data: last.data?.data || last.data
+      code: finalRes.code,
+      message: finalRes.message,
+      data: finalRes.data?.result
     }
   }
 }

View File

@@ -77,17 +77,17 @@ export const useProxyStore = defineStore('proxy', () => {
    * @param {number|string} port
    * @returns {Promise<boolean>}
    */
-  async function deleteProxy(ip, port) {
+  async function deleteProxy(ip, port, page = 1, pageSize = 20, filters = {}) {
     try {
       const response = await proxyService.deleteProxy(ip, port)
       if (response.code === 200) {
-        await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表
+        await fetchProxies({ page, page_size: pageSize, ...filters }) // 刷新列表
         return true
       }
     } catch (error) {
       console.error('删除代理失败:', error)
-      return false
     }
+    return false
   }
   /**
@@ -95,13 +95,13 @@ export const useProxyStore = defineStore('proxy', () => {
    * @param {Array<{ip: string, port: number}>} proxyList
    * @returns {Promise<number>} 实际删除的数量
    */
-  async function batchDeleteProxies(proxyList) {
+  async function batchDeleteProxies(proxyList, page = 1, pageSize = 20, filters = {}) {
     if (!proxyList?.length) return 0
     try {
       const response = await proxyService.batchDelete(proxyList)
       if (response.code === 200) {
-        await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表
+        await fetchProxies({ page, page_size: pageSize, ...filters }) // 刷新列表
         return response.data.deleted_count
       }
     } catch (error) {
@@ -114,17 +114,17 @@ export const useProxyStore = defineStore('proxy', () => {
    * 清理无效代理
    * @returns {Promise<number>} 删除的数量
    */
-  async function cleanInvalidProxies() {
+  async function cleanInvalidProxies(page = 1, pageSize = 20, filters = {}) {
     try {
       const response = await proxyService.cleanInvalid()
       if (response.code === 200) {
-        await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表
+        await fetchProxies({ page, page_size: pageSize, ...filters }) // 刷新列表
         return response.data.deleted_count
       }
     } catch (error) {
       console.error('清理无效代理失败:', error)
     }
-    return 0
+    return -1
   }
   /**

View File

@@ -149,9 +149,13 @@ async function handleClean() {
     )
     const deletedCount = await proxyStore.cleanInvalidProxies()
-    if (deletedCount >= 0) {
+    if (deletedCount > 0) {
       ElMessage.success(`已清理 ${deletedCount} 个无效代理`)
       await proxyStore.fetchStats()
+    } else if (deletedCount === 0) {
+      ElMessage.info('没有需要清理的无效代理')
+    } else if (deletedCount === -1) {
+      ElMessage.error('清理无效代理失败')
     }
   } catch {
     // 用户取消

View File

@@ -101,11 +101,11 @@
           <el-icon v-if="crawlResults[row.id].type === 'success'" class="result-icon success"><CircleCheck /></el-icon>
           <el-icon v-else class="result-icon failed"><CircleClose /></el-icon>
           <span class="result-text">{{ crawlResults[row.id].message }}</span>
-          <span v-if="crawlResults[row.id].data?.valid_count !== undefined" class="result-count valid">
-            有效 {{ crawlResults[row.id].data.valid_count }}
+          <span v-if="crawlResults[row.id].data?.success_count !== undefined" class="result-count valid">
+            有效 {{ crawlResults[row.id].data.success_count }}
           </span>
-          <span v-if="crawlResults[row.id].data?.invalid_count !== undefined" class="result-count invalid">
-            无效 {{ crawlResults[row.id].data.invalid_count }}
+          <span v-if="crawlResults[row.id].data?.failure_count !== undefined" class="result-count invalid">
+            无效 {{ crawlResults[row.id].data.failure_count }}
           </span>
           <el-icon class="result-close" @click="clearCrawlResult(row.id)"><Close /></el-icon>
         </div>
@@ -134,9 +134,7 @@
         <span v-if="allCrawlResult.data.total_crawled !== undefined">
           爬取: {{ allCrawlResult.data.total_crawled }}
         </span>
-        <span v-if="allCrawlResult.data.proxy_count !== undefined">
-          爬取: {{ allCrawlResult.data.proxy_count }}
-        </span>
         <span v-if="allCrawlResult.data.valid_count !== undefined" class="valid-count">
           有效: {{ allCrawlResult.data.valid_count }}
         </span>
@@ -235,14 +233,18 @@ async function handleToggle(pluginId, enabled) {
 }
 async function handleOpenConfig(row) {
-  currentPlugin.value = row
-  const response = await pluginService.getPluginConfig(row.id)
-  if (response.code === 200) {
-    Object.keys(configForm).forEach(key => delete configForm[key])
-    Object.assign(configForm, response.data.config || {})
-    configDialogVisible.value = true
-  } else {
-    ElMessage.error('获取插件配置失败')
+  try {
+    currentPlugin.value = row
+    const response = await pluginService.getPluginConfig(row.id)
+    if (response.code === 200) {
+      Object.keys(configForm).forEach(key => delete configForm[key])
+      Object.assign(configForm, response.data.config || {})
+      configDialogVisible.value = true
+    } else {
+      ElMessage.error('获取插件配置失败')
+    }
+  } catch (error) {
+    ElMessage.error('获取插件配置出错')
   }
 }
@@ -321,11 +323,13 @@ async function handleCrawlAll() {
     if (response.code === 200) {
       allCrawlResult.value = {
-        type: 'success',
+        type: response.data?.cancelled ? 'info' : 'success',
         message: response.message,
         data: response.data
       }
-      ElMessage.success('批量爬取完成')
+      if (!response.data?.cancelled) {
+        ElMessage.success('批量爬取完成')
+      }
     } else {
       allCrawlResult.value = {
         type: 'error',

View File

@@ -190,18 +190,27 @@ async function fetchProxies() {
   }
   abortController = new AbortController()
-  const success = await proxyStore.fetchProxies({
-    page: currentPage.value,
-    page_size: pageSize.value,
-    protocol: filterForm.protocol || null,
-    min_score: filterForm.minScore,
-    sort_by: filterForm.sortBy,
-    sort_order: filterForm.sortOrder
-  }, abortController.signal)
-  abortController = null
-  if (!success) {
-    ElMessage.error('获取代理列表失败')
+  try {
+    const success = await proxyStore.fetchProxies({
+      page: currentPage.value,
+      page_size: pageSize.value,
+      protocol: filterForm.protocol || null,
+      min_score: filterForm.minScore,
+      sort_by: filterForm.sortBy,
+      sort_order: filterForm.sortOrder
+    }, abortController.signal)
+    if (!success) {
+      ElMessage.error('获取代理列表失败')
+    }
+  } catch (error) {
+    if (error.name === 'AbortError') {
+      // 用户主动取消,不提示错误
+      return
+    }
+    throw error
+  } finally {
+    abortController = null
   }
 }
@@ -223,10 +232,15 @@ async function handleDelete(proxy) {
   const confirmed = await confirmDelete(`代理 ${proxy.ip}:${proxy.port}`)
   if (!confirmed) return
-  const success = await proxyStore.deleteProxy(proxy.ip, proxy.port)
+  const filters = {
+    protocol: filterForm.protocol || null,
+    min_score: filterForm.minScore,
+    sort_by: filterForm.sortBy,
+    sort_order: filterForm.sortOrder
+  }
+  const success = await proxyStore.deleteProxy(proxy.ip, proxy.port, currentPage.value, pageSize.value, filters)
   if (success) {
     ElMessage.success('删除成功')
-    fetchProxies()
   }
 }
@@ -237,11 +251,16 @@ async function handleBatchDelete() {
   const confirmed = await confirmBatchDelete(count, '代理')
   if (!confirmed) return
-  const deletedCount = await proxyStore.batchDeleteProxies(selectedProxies.value)
+  const filters = {
+    protocol: filterForm.protocol || null,
+    min_score: filterForm.minScore,
+    sort_by: filterForm.sortBy,
+    sort_order: filterForm.sortOrder
+  }
+  const deletedCount = await proxyStore.batchDeleteProxies(selectedProxies.value, currentPage.value, pageSize.value, filters)
   if (deletedCount > 0) {
     ElMessage.success(`已删除 ${deletedCount} 个代理`)
     selectedProxies.value = []
-    fetchProxies()
   }
 }

View File

@@ -5,6 +5,7 @@ from app.services.plugin_service import PluginService
 from app.services.scheduler_service import SchedulerService
 from app.services.plugin_runner import PluginRunner
 from app.core.execution import JobExecutor, AsyncWorkerPool
+from app.repositories.settings_repo import SettingsRepository
 def get_proxy_service() -> ProxyService:
@@ -15,6 +16,10 @@ def get_plugin_service() -> PluginService:
     return PluginService()
+
+def get_settings_repo() -> SettingsRepository:
+    return SettingsRepository()
+
 def get_scheduler_service(request: Request) -> SchedulerService:
     return request.app.state.scheduler

View File

@@ -1,6 +1,7 @@
 """统一异常处理"""
 from fastapi import Request
 from fastapi.responses import JSONResponse
+from starlette.exceptions import HTTPException as StarletteHTTPException
 from pydantic import ValidationError
 from app.core.exceptions import ProxyPoolException
 from app.core.log import logger
@@ -13,6 +14,13 @@ async def proxy_pool_exception_handler(request: Request, exc: ProxyPoolException
     )
+
+async def http_exception_handler(request: Request, exc: StarletteHTTPException):
+    return JSONResponse(
+        status_code=exc.status_code,
+        content={"code": exc.status_code, "message": exc.detail, "data": None},
+    )
+
 async def pydantic_validation_handler(request: Request, exc: ValidationError):
     logger.error(f"Validation error: {exc}")
     return JSONResponse(

View File

@@ -80,6 +80,7 @@ async def lifespan(app: FastAPI):
     # 调度器
     scheduler = SchedulerService(
         executor=executor,
+        worker_pool=worker_pool,
         interval_minutes=db_settings.get("validate_interval_minutes", 30),
     )

View File

@@ -3,9 +3,10 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from app.api.lifespan import lifespan
 from app.api.routes import api_router
-from app.api.errors import proxy_pool_exception_handler, pydantic_validation_handler, general_exception_handler
+from app.api.errors import proxy_pool_exception_handler, http_exception_handler, pydantic_validation_handler, general_exception_handler
 from app.core.exceptions import ProxyPoolException
 from pydantic import ValidationError
+from starlette.exceptions import HTTPException as StarletteHTTPException
 from app.core.config import settings as app_settings

 # 导入并注册所有插件(显式注册模式)
@@ -30,6 +31,7 @@ def create_app() -> FastAPI:
     # 异常处理
     app.add_exception_handler(ProxyPoolException, proxy_pool_exception_handler)
+    app.add_exception_handler(StarletteHTTPException, http_exception_handler)
     app.add_exception_handler(ValidationError, pydantic_validation_handler)
     app.add_exception_handler(Exception, general_exception_handler)

View File

@@ -7,7 +7,7 @@ from app.services.plugin_runner import PluginRunner
 from app.core.execution import JobExecutor, CrawlJob
 from app.core.exceptions import PluginNotFoundException
 from app.api.deps import get_plugin_service, get_plugin_runner, get_executor
-from app.api.common import success_response
+from app.api.common import success_response, format_plugin

 router = APIRouter(prefix="/api/plugins", tags=["plugins"])
@@ -114,8 +114,11 @@ def _create_crawl_all_aggregator(job_ids, executor):
     class CrawlAllAggregator(Job):
         async def run(self):
+            self._set_running()
             # 等待所有子 job 完成(最多等 30 秒)
             for _ in range(300):
+                if self.is_cancelled:
+                    break
                 all_done = all(
                     executor.get_job(jid) and executor.get_job(jid).status.value in ("completed", "failed", "cancelled")
                     for jid in job_ids
@@ -125,24 +128,18 @@ def _create_crawl_all_aggregator(job_ids, executor):
                 await asyncio.sleep(0.1)
             total = 0
             valid = 0
+            invalid = 0
             for jid in job_ids:
                 job = executor.get_job(jid)
                 if job and job.result:
                     total += job.result.get("proxy_count", 0)
                     valid += job.result.get("success_count", 0)
-            return {"total_crawled": total, "valid_count": valid, "invalid_count": 0}
+                    invalid += job.result.get("failure_count", 0)
+            result = {"total_crawled": total, "valid_count": valid, "invalid_count": invalid}
+            if self.is_cancelled:
+                result["cancelled"] = True
+            return result
     return CrawlAllAggregator()
-
-def format_plugin(plugin) -> dict:
-    return {
-        "id": plugin.id,
-        "name": plugin.display_name,
-        "display_name": plugin.display_name,
-        "description": plugin.description,
-        "enabled": plugin.enabled,
-        "last_run": plugin.last_run.isoformat() if plugin.last_run else None,
-        "success_count": plugin.success_count,
-        "failure_count": plugin.failure_count,
-    }

View File

@@ -52,7 +52,7 @@ async def list_proxies(
 async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
     proxy = await service.get_random_proxy()
     if not proxy:
-        raise ProxyNotFoundException("", 0)
+        raise ProxyPoolException("暂无可用代理", 404)
     return success_response("获取随机代理成功", format_proxy(proxy))
@@ -68,12 +68,8 @@ async def export_proxies(
     media_types = {"csv": "text/csv", "txt": "text/plain", "json": "application/json"}
-    async def generate():
-        async for chunk in service.export_proxies(fmt, protocol, limit):
-            yield chunk
     return StreamingResponse(
-        generate(),
+        service.export_proxies(fmt, protocol, limit),
         media_type=media_types[fmt],
         headers={"Content-Disposition": f"attachment; filename=proxies.{fmt}"},
     )

View File

@@ -4,37 +4,45 @@ from fastapi import APIRouter, Depends
 from app.services.scheduler_service import SchedulerService
 from app.repositories.settings_repo import SettingsRepository
 from app.core.db import get_db
-from app.api.deps import get_scheduler_service
+from app.api.deps import get_scheduler_service, get_settings_repo
 from app.api.common import success_response

 router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
-settings_repo = SettingsRepository()

-async def _save_auto_validate_setting(enabled: bool):
+async def _save_auto_validate_setting(enabled: bool, settings_repo: SettingsRepository):
     """保存自动验证设置"""
     async with get_db() as db:
-        settings = await settings_repo.get_all(db)
-        settings["auto_validate"] = enabled
-        from app.models.schemas import SettingsSchema
-        await settings_repo.save(db, SettingsSchema(**settings).model_dump())
+        await settings_repo.save(db, {"auto_validate": enabled})

 @router.post("/start")
-async def start_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
+async def start_scheduler(
+    scheduler: SchedulerService = Depends(get_scheduler_service),
+    settings_repo: SettingsRepository = Depends(get_settings_repo),
+):
     if scheduler.running:
         return success_response("验证调度器已在运行", {"running": True})
+    try:
+        await _save_auto_validate_setting(True, settings_repo)
+    except Exception:
+        return success_response("启动调度器失败(设置保存异常)", {"running": False})
     await scheduler.start()
-    await _save_auto_validate_setting(True)
     return success_response("验证调度器已启动", {"running": True})

 @router.post("/stop")
-async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
+async def stop_scheduler(
+    scheduler: SchedulerService = Depends(get_scheduler_service),
+    settings_repo: SettingsRepository = Depends(get_settings_repo),
+):
     if not scheduler.running:
         return success_response("验证调度器未运行", {"running": False})
+    try:
+        await _save_auto_validate_setting(False, settings_repo)
+    except Exception:
+        return success_response("停止调度器失败(设置保存异常)", {"running": True})
     await scheduler.stop()
-    await _save_auto_validate_setting(False)
     return success_response("验证调度器已停止", {"running": False})

View File

@@ -1,28 +1,33 @@
 """设置相关路由"""
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, Depends
 from app.core.db import get_db
 from app.repositories.settings_repo import SettingsRepository
 from app.models.schemas import SettingsSchema
 from app.api.common import success_response
+from app.api.deps import get_settings_repo
+from app.core.exceptions import ProxyPoolException
 from app.core.log import logger

 router = APIRouter(prefix="/api/settings", tags=["settings"])
-settings_repo = SettingsRepository()

 @router.get("")
-async def get_settings():
+async def get_settings(settings_repo: SettingsRepository = Depends(get_settings_repo)):
     async with get_db() as db:
         settings = await settings_repo.get_all(db)
         return success_response("获取设置成功", settings)

 @router.post("")
-async def save_settings(request: SettingsSchema, http_request: Request):
+async def save_settings(
+    request: SettingsSchema,
+    http_request: Request,
+    settings_repo: SettingsRepository = Depends(get_settings_repo),
+):
     async with get_db() as db:
         success = await settings_repo.save(db, request.model_dump())
         if not success:
-            raise RuntimeError("保存设置失败")
+            raise ProxyPoolException("保存设置失败", 500)

     # 热更新运行中调度器的间隔时间
     scheduler = getattr(http_request.app.state, "scheduler", None)
@@ -47,9 +52,9 @@ async def save_settings(request: SettingsSchema, http_request: Request):
         validator._init_max_concurrency = request.default_concurrency
         if request.validation_targets:
             validator.update_test_urls(request.validation_targets)
-        # 重新创建 semaphore 和 session
-        validator._semaphore = None
+        # 先关闭现有 session,再重置 semaphore,避免竞态窗口
         await validator.close()
+        validator._semaphore = None
         logger.info(f"Validator config updated: timeout={request.validation_timeout}, concurrency={request.default_concurrency}, targets={request.validation_targets}")

     return success_response("保存设置成功", request.model_dump())

View File

@@ -26,6 +26,7 @@ class JobExecutor:
         self.worker_pool = worker_pool
         self.max_concurrent_jobs = max_concurrent_jobs
         self._jobs: Dict[str, Job] = {}
+        self._tasks: Dict[str, asyncio.Task] = {}
         self._running = False
         self._semaphore = asyncio.Semaphore(max_concurrent_jobs)
         self._cleanup_interval = cleanup_interval_seconds
@@ -61,27 +62,31 @@
     def submit_job(self, job: Job) -> str:
         """提交一个 Job 到后台执行"""
         self._jobs[job.id] = job
-        asyncio.create_task(self._run_job(job))
+        task = asyncio.create_task(self._run_job(job))
+        self._tasks[job.id] = task
         return job.id

     async def _run_job(self, job: Job) -> None:
-        async with self._semaphore:
-            try:
-                if job.is_cancelled:
-                    logger.info(f"Job {job.id} was cancelled before running")
-                    return
-                result = await job.run()
-                # 如果子类没有显式设置完成状态,自动设为 completed
-                if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
-                    job._set_completed(result)
-                logger.info(f"Job {job.id} completed: {result}")
-            except asyncio.CancelledError:
-                job.status = JobStatus.CANCELLED
-                job._touch()
-                logger.info(f"Job {job.id} cancelled during execution")
-            except Exception as e:
-                job._set_failed(str(e))
-                logger.error(f"Job {job.id} failed: {e}", exc_info=True)
+        try:
+            async with self._semaphore:
+                try:
+                    if job.is_cancelled:
+                        logger.info(f"Job {job.id} was cancelled before running")
+                        return
+                    result = await job.run()
+                    # 如果子类没有显式设置完成状态,自动设为 completed
+                    if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
+                        job._set_completed(result)
+                    logger.info(f"Job {job.id} completed: {result}")
+                except asyncio.CancelledError:
+                    job.status = JobStatus.CANCELLED
+                    job._touch()
+                    logger.info(f"Job {job.id} cancelled during execution")
+                except Exception as e:
+                    job._set_failed(str(e))
+                    logger.error(f"Job {job.id} failed: {e}", exc_info=True)
+        finally:
+            self._tasks.pop(job.id, None)

     def get_job(self, job_id: str) -> Optional[Job]:
         return self._jobs.get(job_id)
@@ -101,6 +106,9 @@
         job = self._jobs.get(job_id)
         if not job:
             return False
+        task = self._tasks.get(job_id)
+        if task and not task.done():
+            task.cancel()
         job.cancel()
         return True
@@ -108,6 +116,9 @@
         cancelled = 0
         for job in list(self._jobs.values()):
             if job.status in (JobStatus.PENDING, JobStatus.RUNNING):
+                task = self._tasks.get(job.id)
+                if task and not task.done():
+                    task.cancel()
                 job.cancel()
                 cancelled += 1
         return cancelled

View File

@@ -154,6 +154,10 @@ class ValidateAllJob(Job):
             self._touch()
         payload = {"total": total, "submitted": submitted}
-        self._set_completed(payload)
+        if self.is_cancelled:
+            self.status = JobStatus.CANCELLED
+            self._touch()
+        else:
+            self._set_completed(payload)
         logger.info(f"ValidateAllJob {self.id}: submitted {submitted}/{total} proxies")
         return payload

View File

@@ -74,21 +74,36 @@ class AsyncWorkerPool:
         await self._queue.join()

     async def resize(self, new_worker_count: int) -> None:
-        """动态调整 Worker 数量:先全部停止,再按新数量启动"""
+        """动态调整 Worker 数量,不丢失队列中的任务"""
         if new_worker_count == self.worker_count:
             return
         logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
-        # 安全做法:先 stop 再 start,避免新旧 Worker 竞争 sentinel 导致死锁
-        await self.stop()
+        if new_worker_count > self.worker_count:
+            for i in range(self.worker_count, new_worker_count):
+                self._workers.append(
+                    asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}")
+                )
+        elif new_worker_count < self.worker_count:
+            for _ in range(self.worker_count - new_worker_count):
+                await self._queue.put(None)
+            await asyncio.sleep(0)
+            still_running = []
+            for w in self._workers:
+                if w.done():
+                    try:
+                        await w
+                    except asyncio.CancelledError:
+                        pass
+                else:
+                    still_running.append(w)
+            self._workers = still_running
         self.worker_count = new_worker_count
-        await self.start()

     async def _worker_loop(self, worker_id: int) -> None:
         while True:
             item = await self._queue.get()
             try:
                 if item is None or not self._running:
+                    self._queue.task_done()
                     break
                 await self.handler(item)
             except Exception as e:

View File

@@ -5,7 +5,7 @@ from .schemas import (
     ProxyResponse,
     PluginResponse,
     SettingsSchema,
-    CrawlResult,
+    CrawlSummarySchema,
     ProxyListRequest,
     ProxyDeleteItem,
     BatchDeleteRequest,
@@ -21,7 +21,7 @@ __all__ = [
     "ProxyResponse",
     "PluginResponse",
     "SettingsSchema",
-    "CrawlResult",
+    "CrawlSummarySchema",
     "ProxyListRequest",
     "ProxyDeleteItem",
     "BatchDeleteRequest",

View File

@@ -59,7 +59,7 @@ class SettingsSchema(BaseModel):
     )

-class CrawlResult(BaseModel):
+class CrawlSummarySchema(BaseModel):
     plugin_id: str
     proxy_count: int
     valid_count: int

View File

@@ -124,9 +124,10 @@ class ProxyRepository:
     async def batch_delete(db: aiosqlite.Connection, proxies: List[Tuple[str, int]]) -> int:
         if not proxies:
             return 0
+        changes_before = db.total_changes
         await db.executemany("DELETE FROM proxies WHERE ip = ? AND port = ?", proxies)
         await db.commit()
-        return len(proxies)
+        return db.total_changes - changes_before

     @staticmethod
     async def get_by_ip_port(

View File

@@ -19,7 +19,7 @@ class PluginRunner:
     """
     def __init__(self, timeout: Optional[float] = None):
-        self.timeout = timeout or getattr(app_settings, "crawler_timeout", 30)
+        self.timeout = timeout if timeout is not None else getattr(app_settings, "crawler_timeout", 30)

     async def run(self, plugin: BaseCrawlerPlugin) -> CrawlResult:
         """执行单个插件爬取"""
@@ -37,6 +37,10 @@
                 return result
         except Exception as e:
             logger.warning(f"Plugin {plugin.name} health check error: {e}")
+            result.error = f"health check error: {e}"
+            result.failure_count = 1
+            await self._save_stats(plugin, result)
+            return result

         # 执行爬取
         try:
@@ -66,7 +70,7 @@
         seen = set()
         unique = []
         for p in proxies:
-            key = (p.ip, p.port)
+            key = (p.ip, p.port, p.protocol)
             if key not in seen:
                 seen.add(key)
                 unique.append(p)

View File

@@ -132,7 +132,7 @@ class PluginService:
         seen = set()
         unique = []
         for p in all_results:
-            key = (p.ip, p.port)
+            key = (p.ip, p.port, p.protocol)
             if key not in seen:
                 seen.add(key)
                 unique.append(p)

View File

@@ -5,7 +5,7 @@ import io
 from datetime import datetime
 from typing import List, Optional, Tuple, AsyncIterator
-from app.core.db import get_db, transaction
+from app.core.db import get_db
 from app.repositories.proxy_repo import ProxyRepository
 from app.models.domain import Proxy
 from app.core.log import logger

View File

@@ -1,6 +1,6 @@
 """调度器服务 - 定时触发全量验证"""
 import asyncio
-from typing import Optional
+from typing import Optional, Any
 from app.core.execution.executor import JobExecutor
 from app.core.execution.job import ValidateAllJob
@@ -17,9 +17,11 @@ class SchedulerService:
     def __init__(
         self,
         executor: JobExecutor,
+        worker_pool: Optional[Any] = None,
         interval_minutes: int = 30,
     ):
         self.executor = executor
+        self.worker_pool = worker_pool
         self.interval_minutes = interval_minutes
         self.running = False
         self._stop_event = asyncio.Event()
@@ -50,7 +52,7 @@ class SchedulerService:
     def validate_all_now(self) -> str:
         """立即执行一次全量验证,返回 Job ID"""
-        job_id = self.executor.submit_job(ValidateAllJob())
+        job_id = self.executor.submit_job(ValidateAllJob(validator_pool=self.worker_pool))
         logger.info(f"ValidateAllJob submitted: {job_id}")
         return job_id
@@ -58,7 +60,7 @@ class SchedulerService:
         """定时循环"""
         while self.running:
             try:
-                self.executor.submit_job(ValidateAllJob())
+                self.executor.submit_job(ValidateAllJob(validator_pool=self.worker_pool))
             except Exception as e:
                 logger.error(f"Scheduler loop error: {e}", exc_info=True)
             # 等待下一次

View File

@@ -39,9 +39,9 @@ class ValidatorService:
         max_concurrency: Optional[int] = None,
     ):
         # 初始化时使用传入值或默认值,但运行期会动态读取 settings
-        self._init_timeout = timeout or app_settings.validator_timeout
-        self._init_connect_timeout = connect_timeout or app_settings.validator_connect_timeout
-        self._init_max_concurrency = max_concurrency or app_settings.validator_max_concurrency
+        self._init_timeout = timeout if timeout is not None else app_settings.validator_timeout
+        self._init_connect_timeout = connect_timeout if connect_timeout is not None else app_settings.validator_connect_timeout
+        self._init_max_concurrency = max_concurrency if max_concurrency is not None else app_settings.validator_max_concurrency
         self._http_connector: Optional[aiohttp.TCPConnector] = None
         self._http_session: Optional[aiohttp.ClientSession] = None
@@ -61,23 +61,26 @@
     def max_concurrency(self) -> int:
         return int(self._init_max_concurrency)

-    def _ensure_session(self) -> aiohttp.ClientSession:
+    async def _ensure_session(self) -> aiohttp.ClientSession:
         """懒加载共享 HTTP session"""
         if self._http_session is None or self._http_session.closed:
-            connector = aiohttp.TCPConnector(
-                ssl=False,
-                limit=self.max_concurrency,
-                limit_per_host=self.max_concurrency,
-                force_close=False,
-            )
-            timeout = aiohttp.ClientTimeout(
-                total=self.timeout, connect=self.connect_timeout
-            )
-            self._http_connector = connector
-            self._http_session = aiohttp.ClientSession(
-                connector=connector,
-                timeout=timeout,
-            )
+            async with self._lock:
+                # 双重检查,避免多个协程在获取锁后重复创建
+                if self._http_session is None or self._http_session.closed:
+                    connector = aiohttp.TCPConnector(
+                        ssl=False,
+                        limit=self.max_concurrency,
+                        limit_per_host=self.max_concurrency,
+                        force_close=False,
+                    )
+                    timeout = aiohttp.ClientTimeout(
+                        total=self.timeout, connect=self.connect_timeout
+                    )
+                    self._http_connector = connector
+                    self._http_session = aiohttp.ClientSession(
+                        connector=connector,
+                        timeout=timeout,
+                    )
         return self._http_session

     def _ensure_semaphore(self) -> asyncio.Semaphore:
@@ -119,9 +122,9 @@
             return False, 0.0

     async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
-        proxy_url = f"http://{ip}:{port}"
+        proxy_url = f"{protocol}://{ip}:{port}"
         test_url = self._get_test_url(protocol)
-        session = self._ensure_session()
+        session = await self._ensure_session()
         async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
             if response.status in (200, 301, 302):

View File

@@ -85,12 +85,14 @@ async def sample_proxy(db, proxy_repo):
 @pytest_asyncio.fixture(autouse=True)
-async def mock_external_requests(monkeypatch):
+async def mock_external_requests(monkeypatch, request):
     """
-    自动在所有测试中 mock 外部网络请求:
+    自动在集成/E2E 测试中 mock 外部网络请求:
     1. 插件爬取返回固定测试代理,避免真实 HTTP 请求
     2. 代理验证瞬间成功,避免连接超时等待
     """
+    if "/unit/" in request.node.nodeid:
+        return
     from app.services.plugin_runner import PluginRunner
     from app.services.validator_service import ValidatorService

View File

@@ -45,3 +45,13 @@ class TestHealthAPI:
         assert isinstance(data["database"], str)
         assert isinstance(data["scheduler"], str)
         assert isinstance(data["version"], str)
+
+    @pytest.mark.asyncio
+    async def test_404_not_found_unified_format(self, client):
+        """测试 404 返回统一格式"""
+        response = await client.get("/api/not-exist")
+        assert response.status_code == 404
+        data = response.json()
+        assert data["code"] == 404
+        assert "message" in data
+        assert data["data"] is None

View File

@@ -140,11 +140,11 @@ class TestPluginsAPI:
             res = await client.get(f"/api/tasks/{task_id}")
             assert res.status_code == 200
             task_data = res.json()["data"]
-            if task_data["status"] in ("completed", "failed"):
+            if task_data["status"] in ("completed", "failed", "cancelled"):
                 break
         assert task_data is not None
-        assert task_data["status"] == "completed"
+        assert task_data["status"] in ("completed", "cancelled")

     @pytest.mark.asyncio
     async def test_crawl_nonexistent_plugin(self, client):
@@ -170,8 +170,8 @@
             res = await client.get(f"/api/tasks/{task_id}")
             assert res.status_code == 200
             task_data = res.json()["data"]
-            if task_data["status"] in ("completed", "failed"):
+            if task_data["status"] in ("completed", "failed", "cancelled"):
                 break
         assert task_data is not None
-        assert task_data["status"] == "completed"
+        assert task_data["status"] in ("completed", "cancelled")

View File

@@ -137,6 +137,15 @@ class TestProxiesAPI:
         assert response.status_code == 200
         assert response.headers["content-type"] == "application/json"
+
+    @pytest.mark.asyncio
+    async def test_export_proxies_json_empty_database(self, client):
+        """测试 GET /api/proxies/export/json - 空数据库"""
+        response = await client.get("/api/proxies/export/json")
+        assert response.status_code == 200
+        assert response.headers["content-type"] == "application/json"
+        # 空数据库应返回空列表 JSON
+        assert response.content.strip() == b"[]"

     @pytest.mark.asyncio
     async def test_export_proxies_invalid_format(self, client):
         """测试 GET /api/proxies/export/invalid - 无效格式"""

View File

@@ -73,6 +73,25 @@ class TestSchedulerAPI:
         assert data["code"] == 200
         assert data["data"]["started"] is True
+
+    @pytest.mark.asyncio
+    async def test_validate_now_returns_valid_job(self, client):
+        """测试 POST /api/scheduler/validate-now 返回有效 job_id"""
+        await client.post("/api/scheduler/start")
+        response = await client.post("/api/scheduler/validate-now")
+        assert response.status_code == 200
+        data = response.json()
+        assert data["code"] == 200
+        job_id = data["data"]["job_id"]
+        assert isinstance(job_id, str) and len(job_id) > 0
+        # 通过应用状态验证 job 已被提交到 executor
+        from app.api.main import create_app
+        # 使用 client 的 app 实例
+        app = client._transport.app
+        executor = app.state.executor
+        job = executor.get_job(job_id)
+        assert job is not None

     @pytest.mark.asyncio
     async def test_scheduler_full_workflow(self, client):
         """测试调度器完整工作流"""

View File

@@ -135,3 +135,32 @@ class TestSettingsAPI:
         # 验证一致性
         for key, value in test_settings.items():
             assert saved_settings[key] == value, f"设置项 {key} 不一致"
+
+    @pytest.mark.asyncio
+    async def test_settings_roundtrip_with_validation_targets(self, client):
+        """测试设置读写一致性 - 包含数组类型的 validation_targets"""
+        test_settings = {
+            "crawl_timeout": 30,
+            "validation_timeout": 10,
+            "max_retries": 3,
+            "default_concurrency": 50,
+            "min_proxy_score": 0,
+            "proxy_expiry_days": 7,
+            "auto_validate": True,
+            "validate_interval_minutes": 30,
+            "validation_targets": [
+                "http://example.com/1",
+                "https://example.com/2",
+            ],
+        }
+        # 写入设置
+        response = await client.post("/api/settings", json=test_settings)
+        assert response.status_code == 200
+        data = response.json()
+        assert data["data"]["validation_targets"] == test_settings["validation_targets"]
+        # 读取设置
+        response = await client.get("/api/settings")
+        saved_settings = response.json()["data"]
+        assert saved_settings["validation_targets"] == test_settings["validation_targets"]