diff --git a/WebUI/src/composables/useScheduler.js b/WebUI/src/composables/useScheduler.js
new file mode 100644
index 0000000..dc56e6f
--- /dev/null
+++ b/WebUI/src/composables/useScheduler.js
@@ -0,0 +1,118 @@
+import { ref } from 'vue'
+import { schedulerService } from '../services/schedulerService'
+
+// Module-level refs: every useScheduler() caller shares the same state.
+const schedulerRunning = ref(false)
+const schedulerLoading = ref(false)
+const validating = ref(false)
+
+/**
+ * Run one scheduler request while `busy` is true and report the outcome.
+ *
+ * @param {object} action
+ * @param {() => Promise<{code: number}>} action.request - API call to await
+ * @param {{value: boolean}} action.busy - ref held true for the duration of the call
+ * @param {string} action.logLabel - console.error prefix used when the call throws
+ * @param {string} action.successMessage - text passed to onSuccess
+ * @param {string} action.failureMessage - text passed to onError
+ * @param {() => void} [action.onOk] - state update applied on a 200 reply
+ * @param {(message: string) => void} [onSuccess] - invoked on a 200 reply
+ * @param {(message: string) => void} [onError] - invoked on a non-200 reply or a thrown error
+ * @returns {Promise<boolean>} true when the backend replied with code 200
+ */
+async function runSchedulerAction(action, onSuccess, onError) {
+  const { request, busy, logLabel, successMessage, failureMessage, onOk } = action
+  let ok = false
+  busy.value = true
+  try {
+    const response = await request()
+    ok = response.code === 200
+    if (ok) {
+      onOk?.()
+      onSuccess?.(successMessage)
+    }
+  } catch (error) {
+    console.error(logLabel, error)
+  } finally {
+    busy.value = false
+  }
+  // Non-200 replies and thrown errors used to vanish silently; report them.
+  if (!ok) {
+    onError?.(failureMessage)
+  }
+  return ok
+}
+
+export function useScheduler() {
+  /** Refresh `schedulerRunning` from the backend; a failure is only logged. */
+  async function fetchStatus() {
+    try {
+      const response = await schedulerService.getStatus()
+      if (response.code === 200) {
+        schedulerRunning.value = response.data.running
+      }
+    } catch (error) {
+      console.error('获取调度器状态失败:', error)
+    }
+  }
+
+  /**
+   * Start the scheduler.
+   * @param {(message: string) => void} [onSuccess] - receives the success message
+   * @param {(message: string) => void} [onError] - receives the failure message
+   * @returns {Promise<boolean>} true when the backend confirmed the start
+   */
+  function startScheduler(onSuccess, onError) {
+    return runSchedulerAction({
+      request: () => schedulerService.start(),
+      busy: schedulerLoading,
+      logLabel: '启动调度器失败:',
+      successMessage: '自动验证已启动',
+      failureMessage: '启动失败',
+      onOk: () => { schedulerRunning.value = true },
+    }, onSuccess, onError)
+  }
+
+  /**
+   * Stop the scheduler.
+   * @param {(message: string) => void} [onSuccess] - receives the success message
+   * @param {(message: string) => void} [onError] - receives the failure message
+   * @returns {Promise<boolean>} true when the backend confirmed the stop
+   */
+  function stopScheduler(onSuccess, onError) {
+    return runSchedulerAction({
+      request: () => schedulerService.stop(),
+      busy: schedulerLoading,
+      logLabel: '停止调度器失败:',
+      successMessage: '自动验证已停止',
+      failureMessage: '停止失败',
+      onOk: () => { schedulerRunning.value = false },
+    }, onSuccess, onError)
+  }
+
+  /**
+   * Trigger a full validation run.
+   * @param {(message: string) => void} [onSuccess] - receives the success message
+   * @param {(message: string) => void} [onError] - receives the failure message
+   * @returns {Promise<boolean>} true when the backend accepted the request
+   */
+  function validateNow(onSuccess, onError) {
+    return runSchedulerAction({
+      request: () => schedulerService.validateNow(),
+      busy: validating,
+      logLabel: '启动验证失败:',
+      successMessage: '全量验证已启动',
+      failureMessage: '启动验证失败',
+    }, onSuccess, onError)
+  }
+
+  return {
+    schedulerRunning,
+    schedulerLoading,
+    validating,
+    fetchStatus,
+    startScheduler,
+    stopScheduler,
+    validateNow,
+  }
+}
diff --git a/WebUI/src/composables/useTaskPolling.js b/WebUI/src/composables/useTaskPolling.js
new file mode 100644
index 0000000..74c5e8f
--- /dev/null
+++ 
b/WebUI/src/composables/useTaskPolling.js
@@ -0,0 +1,31 @@
+import { tasksAPI } from '../api'
+
+const POLL_INTERVAL = 1000
+const MAX_POLL_ATTEMPTS = 30
+// Statuses after which a task no longer changes; `cancelled` must end the poll too.
+const TERMINAL_STATUSES = new Set(['completed', 'failed', 'cancelled'])
+
+/**
+ * Poll a task's status once per POLL_INTERVAL until it reaches a terminal
+ * status or MAX_POLL_ATTEMPTS polls have been made.
+ * @param {string} taskId
+ * @returns {Promise<object>} the final status response, or a synthetic
+ *   `running` response when the task has not finished within the poll budget
+ */
+export async function pollTaskStatus(taskId) {
+  for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt++) {
+    await new Promise(resolve => setTimeout(resolve, POLL_INTERVAL))
+    const response = await tasksAPI.getTaskStatus(taskId)
+    if (response.code !== 200) {
+      continue
+    }
+    if (TERMINAL_STATUSES.has(response.data.status)) {
+      return response
+    }
+  }
+  return {
+    code: 200,
+    message: '任务进行中,请稍后刷新查看结果',
+    data: { task_id: taskId, status: 'running' }
+  }
+}
diff --git a/WebUI/src/services/pluginService.js b/WebUI/src/services/pluginService.js
index 7f3edaa..9906274 100644
--- a/WebUI/src/services/pluginService.js
+++ b/WebUI/src/services/pluginService.js
@@ -1,26 +1,5 @@
-import { pluginsAPI, tasksAPI } from '../api'
-
-const POLL_INTERVAL = 1000
-const MAX_POLL_ATTEMPTS = 30
-
-async function pollTaskStatus(taskId) {
-  for (let i = 0; i < MAX_POLL_ATTEMPTS; i++) {
-    await new Promise(resolve => setTimeout(resolve, POLL_INTERVAL))
-    const response = await tasksAPI.getTaskStatus(taskId)
-    if (response.code !== 200) {
-      continue
-    }
-    const status = response.data.status
-    if (status === 'completed' || status === 'failed') {
-      return response
-    }
-  }
-  return {
-    code: 200,
-    message: '爬取任务进行中,请稍后刷新查看结果',
-    data: { task_id: taskId, status: 'running' }
-  }
-}
+import { pluginsAPI } from '../api'
+import { pollTaskStatus } from '../composables/useTaskPolling'

 export const pluginService = {
   async getPlugins() {
@@ -54,14 +33,21 @@ export const pluginService = {

   async crawlAll() {
     const startRes = await pluginsAPI.crawlAll()
-    if (startRes.code !== 200 || !startRes.data?.task_id) {
+    // The backend replies with one aggregate `task_id`; a `task_ids` list is accepted as well.
+    const { task_id: taskId, task_ids: taskIdList } = startRes.data ?? {}
+    const taskIds = Array.isArray(taskIdList) && taskIdList.length > 0
+      ? taskIdList
+      : (taskId ? [taskId] : [])
+    if (startRes.code !== 200 || taskIds.length === 0) {
       return startRes
     }
-    const finalRes = await pollTaskStatus(startRes.data.task_id)
+    // Poll every task in parallel; the response of the last task in the list is reported.
+    const results = await Promise.all(taskIds.map(tid => pollTaskStatus(tid)))
+    const last = results[results.length - 1]
     return {
-      code: finalRes.code,
-      message: finalRes.data?.message || finalRes.message,
-      data: finalRes.data?.data || finalRes.data
+      code: last.code,
+      message: last.data?.message || last.message,
+      data: last.data?.data || last.data
     }
   }
 }
diff --git a/WebUI/src/stores/plugins.js b/WebUI/src/stores/plugins.js
index f5233ca..f5b7882 100644
--- a/WebUI/src/stores/plugins.js
+++ b/WebUI/src/stores/plugins.js
@@ -40,7 +40,7 @@ export const usePluginsStore = defineStore('plugins', () => {

   /**
    * 切换插件启用状态
-   * @param {string|number} pluginId
+   * @param {string} pluginId
    * @param {boolean} enabled
    * @returns {Promise}
    */
@@ -48,10 +48,7 @@ export const usePluginsStore = defineStore('plugins', () => {
     try {
       const response = await pluginService.togglePlugin(pluginId, enabled)
       if (response.code === 200) {
-        const plugin = plugins.value.find(p => p.id === pluginId)
-        if (plugin) {
-          plugin.enabled = enabled
-        }
+        await fetchPlugins() // re-fetch so the list reflects the server-side state
         return true
       }
     } catch (error) {
@@ -62,7 +59,7 @@ export const usePluginsStore = defineStore('plugins', () => {

   /**
    * 触发插件爬取
-   * @param {string|number} pluginId
+   * @param {string} pluginId
    * @returns {Promise}
    */
   async function crawlPlugin(pluginId) {
@@ -77,7 +74,7 @@ export const usePluginsStore = defineStore('plugins', () => {

   /**
    * 根据 ID 获取插件
-   * @param {string|number} id
+   * @param {string} id
    * @returns {object|undefined}
    */
   function getPluginById(id) {
diff --git a/WebUI/src/stores/proxy.js b/WebUI/src/stores/proxy.js
index 476b67a..0215a6a 100644
--- a/WebUI/src/stores/proxy.js
+++ b/WebUI/src/stores/proxy.js
@@ -80,7 +80,10 @@ export const useProxyStore = defineStore('proxy', () => {
   async function deleteProxy(ip, port) {
     try {
       const response = await proxyService.deleteProxy(ip, port)
-      return response.code === 200
+      if 
(response.code === 200) { + await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表 + return true + } } catch (error) { console.error('删除代理失败:', error) return false @@ -98,6 +101,7 @@ export const useProxyStore = defineStore('proxy', () => { try { const response = await proxyService.batchDelete(proxyList) if (response.code === 200) { + await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表 return response.data.deleted_count } } catch (error) { @@ -114,6 +118,7 @@ export const useProxyStore = defineStore('proxy', () => { try { const response = await proxyService.cleanInvalid() if (response.code === 200) { + await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表 return response.data.deleted_count } } catch (error) { diff --git a/WebUI/src/views/Settings.vue b/WebUI/src/views/Settings.vue index 3b33987..2dd39f4 100644 --- a/WebUI/src/views/Settings.vue +++ b/WebUI/src/views/Settings.vue @@ -191,18 +191,25 @@ import { Refresh } from '@element-plus/icons-vue' import { settingService } from '../services/settingService' -import { schedulerService } from '../services/schedulerService' +import { useScheduler } from '../composables/useScheduler' import PageHeader from '../components/PageHeader.vue' +// ==================== Composables ==================== +const { + schedulerRunning, + schedulerLoading, + validating, + fetchStatus, + startScheduler, + stopScheduler, + validateNow +} = useScheduler() + // ==================== 状态 ==================== const loading = ref(false) const saving = ref(false) const formRef = ref(null) -const schedulerRunning = ref(false) -const schedulerLoading = ref(false) -const validating = ref(false) - const settings = reactive({ crawl_timeout: 30, validation_timeout: 10, @@ -250,52 +257,13 @@ async function fetchSettings() { } } -async function fetchSchedulerStatus() { - try { - const response = await schedulerService.getStatus() - if (response.code === 200) { - schedulerRunning.value = response.data.running - } - } catch (error) { - 
console.error('获取调度器状态失败:', error) - } -} - // ==================== 调度器控制 ==================== async function handleStartScheduler() { - schedulerLoading.value = true - try { - const response = await schedulerService.start() - if (response.code === 200) { - schedulerRunning.value = true - ElMessage.success('自动验证已启动') - } else { - ElMessage.error('启动失败') - } - } catch (error) { - console.error('启动调度器失败:', error) - ElMessage.error('启动失败') - } finally { - schedulerLoading.value = false - } + await startScheduler((msg) => ElMessage.success(msg)) } async function handleStopScheduler() { - schedulerLoading.value = true - try { - const response = await schedulerService.stop() - if (response.code === 200) { - schedulerRunning.value = false - ElMessage.success('自动验证已停止') - } else { - ElMessage.error('停止失败') - } - } catch (error) { - console.error('停止调度器失败:', error) - ElMessage.error('停止失败') - } finally { - schedulerLoading.value = false - } + await stopScheduler((msg) => ElMessage.success(msg)) } async function handleValidateNow() { @@ -309,21 +277,12 @@ async function handleValidateNow() { type: 'info' } ) - - validating.value = true - const response = await schedulerService.validateNow() - if (response.code === 200) { - ElMessage.success('全量验证已启动,请在日志中查看进度') - } else { - ElMessage.error('启动验证失败') - } + await validateNow((msg) => ElMessage.success(msg)) } catch (error) { if (error !== 'cancel') { console.error('启动验证失败:', error) ElMessage.error('启动验证失败') } - } finally { - validating.value = false } } @@ -338,8 +297,7 @@ async function handleSave() { if (response.code === 200) { ElMessage.success('配置保存成功') - // 刷新调度器状态 - await fetchSchedulerStatus() + await fetchStatus() } else { ElMessage.error('配置保存失败') } @@ -354,7 +312,7 @@ async function handleSave() { // ==================== 生命周期 ==================== onMounted(() => { fetchSettings() - fetchSchedulerStatus() + fetchStatus() }) diff --git a/app/api/deps.py b/app/api/deps.py index 43ff92c..0f28c4d 100644 --- 
a/app/api/deps.py +++ b/app/api/deps.py @@ -3,10 +3,8 @@ from fastapi import Request from app.services.proxy_service import ProxyService from app.services.plugin_service import PluginService from app.services.scheduler_service import SchedulerService -from app.services.validator_service import ValidatorService -from app.repositories.proxy_repo import ProxyRepository -from app.core.tasks.queue import ValidationQueue -from app.core.config import settings as app_settings +from app.services.plugin_runner import PluginRunner +from app.core.execution import JobExecutor, AsyncWorkerPool def get_proxy_service() -> ProxyService: @@ -18,36 +16,16 @@ def get_plugin_service() -> PluginService: def get_scheduler_service(request: Request) -> SchedulerService: - return request.app.state.scheduler_service + return request.app.state.scheduler -def get_validation_queue(request: Request) -> ValidationQueue: - return request.app.state.validation_queue +def get_executor(request: Request) -> JobExecutor: + return request.app.state.executor -def create_scheduler_service(db_settings: dict | None = None) -> SchedulerService: - """在应用启动时创建 SchedulerService(非请求上下文) - - db_settings 来自数据库持久化设置,用于覆盖 .env / 默认值。 - """ - db_settings = db_settings or {} - validator = ValidatorService( - timeout=db_settings.get("validation_timeout", app_settings.validator_timeout), - connect_timeout=db_settings.get("connect_timeout", app_settings.validator_connect_timeout), - max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), - ) - proxy_repo = ProxyRepository() - queue = ValidationQueue( - validator=validator, - proxy_repo=proxy_repo, - worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), - score_valid=app_settings.score_valid, - score_invalid=app_settings.score_invalid, - score_min=app_settings.score_min, - score_max=app_settings.score_max, - ) - svc = SchedulerService(validation_queue=queue, proxy_repo=proxy_repo) - 
svc.interval_minutes = db_settings.get( - "validate_interval_minutes", 30 - ) - return svc +def get_plugin_runner(request: Request) -> PluginRunner: + return request.app.state.plugin_runner + + +def get_worker_pool(request: Request) -> AsyncWorkerPool: + return request.app.state.worker_pool diff --git a/app/api/lifespan.py b/app/api/lifespan.py index 8ade670..4d8bca1 100644 --- a/app/api/lifespan.py +++ b/app/api/lifespan.py @@ -1,56 +1,118 @@ """应用生命周期管理""" import asyncio -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from fastapi import FastAPI + from app.core.db import init_db, get_db from app.core.config import settings as app_settings from app.core.log import logger -from app.api.deps import create_scheduler_service +from app.core.execution import AsyncWorkerPool, JobExecutor +from app.core.plugin_system.registry import registry +from app.repositories.proxy_repo import ProxyRepository from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS +from app.services.validator_service import ValidatorService +from app.services.plugin_runner import PluginRunner +from app.services.scheduler_service import SchedulerService settings_repo = SettingsRepository() +proxy_repo = ProxyRepository() -@asynccontextmanager -async def lifespan(app: FastAPI): - """应用启动和关闭时的生命周期管理""" - # 初始化数据库 - await init_db() - - # 加载设置并决定是否启动调度器 +async def _load_settings() -> dict: db_settings = DEFAULT_SETTINGS.copy() try: async with get_db() as db: db_settings = await settings_repo.get_all(db) except Exception as e: logger.error(f"Failed to load settings on startup: {e}") + return db_settings - # 创建调度器并挂载到 app.state(使用 DB 设置覆盖默认值) - scheduler_service = create_scheduler_service(db_settings) - app.state.scheduler_service = scheduler_service - app.state.validation_queue = scheduler_service.validation_queue - if db_settings.get("auto_validate", True): - try: - await scheduler_service.start() - except Exception as e: - 
logger.error(f"Failed to start scheduler on startup: {e}") +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用启动和关闭时的生命周期管理""" + await init_db() + db_settings = await _load_settings() - logger.info("API server started") - yield + async with AsyncExitStack() as stack: + # 验证器 + validator = ValidatorService( + timeout=db_settings.get("validation_timeout", app_settings.validator_timeout), + connect_timeout=app_settings.validator_connect_timeout, + max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), + ) - # 关闭调度器 - scheduler_service.cancel_validate_task() - await scheduler_service.stop() + # 验证 WorkerPool + async def validation_handler(proxy): + from app.models.domain import ProxyRaw + is_valid, latency = await validator.validate( + proxy.ip, proxy.port, proxy.protocol + ) + async with get_db() as db: + if is_valid: + await proxy_repo.insert_or_update( + db, proxy.ip, proxy.port, proxy.protocol, score=app_settings.score_valid + ) + if latency: + await proxy_repo.update_response_time(db, proxy.ip, proxy.port, latency) + else: + await proxy_repo.update_score( + db, proxy.ip, proxy.port, app_settings.score_invalid, + app_settings.score_min, app_settings.score_max + ) - # 关闭验证器 HTTP session + worker_pool = AsyncWorkerPool( + worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), + handler=validation_handler, + name="ValidationPool", + ) + await stack.enter_async_context(worker_pool) + + # Job 执行器 + executor = JobExecutor(worker_pool=worker_pool, max_concurrent_jobs=10) + await stack.enter_async_context(executor) + + # 插件运行器 + plugin_runner = PluginRunner(timeout=db_settings.get("crawl_timeout", 30)) + + # 调度器 + scheduler = SchedulerService( + executor=executor, + interval_minutes=db_settings.get("validate_interval_minutes", 30), + ) + + # 挂载到 app.state + app.state.validator = validator + app.state.worker_pool = worker_pool + app.state.executor = executor + app.state.plugin_runner 
= plugin_runner + app.state.scheduler = scheduler + + # 启动调度器 + if db_settings.get("auto_validate", True): + try: + await scheduler.start() + except Exception as e: + logger.error(f"Failed to start scheduler on startup: {e}") + + logger.info("API server started") + yield + + # 停止调度器 + await scheduler.stop() + + # 取消所有运行中的 Job + await executor.cancel_all() + + # AsyncExitStack 会自动关闭 executor 和 worker_pool + + # 关闭验证器 session try: - await scheduler_service.validation_queue.validator.close() + await validator.close() except Exception: pass # 关闭所有插件的 HTTP 客户端 - from app.core.plugin_system.registry import registry for plugin in registry.list_plugins(): if hasattr(plugin, "close"): try: diff --git a/app/api/main.py b/app/api/main.py index f30753c..3a91000 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -43,7 +43,7 @@ def create_app() -> FastAPI: @app.get("/health") async def health_check(): from datetime import datetime - scheduler = app.state.scheduler_service + scheduler = app.state.scheduler return { "status": "healthy", "timestamp": datetime.now().isoformat(), diff --git a/app/api/routes/plugins.py b/app/api/routes/plugins.py index 876ab21..d6f7cff 100644 --- a/app/api/routes/plugins.py +++ b/app/api/routes/plugins.py @@ -1,16 +1,25 @@ """插件相关路由""" -import asyncio from fastapi import APIRouter, Depends +from pydantic import BaseModel + from app.services.plugin_service import PluginService -from app.services.scheduler_service import SchedulerService -from app.services.task_service import task_service -from app.api.deps import get_plugin_service, get_scheduler_service -from app.api.common import success_response, error_response, format_plugin -from app.core.log import logger +from app.services.plugin_runner import PluginRunner +from app.core.execution import JobExecutor, CrawlJob +from app.core.exceptions import PluginNotFoundException +from app.api.deps import get_plugin_service, get_plugin_runner, get_executor +from app.api.common import success_response router 
= APIRouter(prefix="/api/plugins", tags=["plugins"]) +class ToggleRequest(BaseModel): + enabled: bool + + +class ConfigRequest(BaseModel): + config: dict + + @router.get("") async def list_plugins(service: PluginService = Depends(get_plugin_service)): plugins = await service.list_plugins() @@ -20,19 +29,13 @@ async def list_plugins(service: PluginService = Depends(get_plugin_service)): @router.put("/{plugin_id}/toggle") async def toggle_plugin( plugin_id: str, - request: dict, + request: ToggleRequest, service: PluginService = Depends(get_plugin_service), ): - enabled = request.get("enabled") - if enabled is None: - return error_response("缺少 enabled 参数", 400) - - success = await service.toggle_plugin(plugin_id, enabled) - if not success: - return error_response("插件不存在", 404) + await service.toggle_plugin(plugin_id, request.enabled) return success_response( - f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}", - {"plugin_id": plugin_id, "enabled": enabled}, + f"插件 {plugin_id} 已{'启用' if request.enabled else '禁用'}", + {"plugin_id": plugin_id, "enabled": request.enabled}, ) @@ -42,128 +45,104 @@ async def get_plugin_config( service: PluginService = Depends(get_plugin_service), ): config = await service.get_plugin_config(plugin_id) - if config is None: - return error_response("插件不存在", 404) return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config}) @router.post("/{plugin_id}/config") async def update_plugin_config( plugin_id: str, - request: dict, + request: ConfigRequest, service: PluginService = Depends(get_plugin_service), ): - config = request.get("config", {}) - if not isinstance(config, dict): - return error_response("config 必须是对象", 400) - - success = await service.update_plugin_config(plugin_id, config) + success = await service.update_plugin_config(plugin_id, request.config) if not success: - return error_response("插件不存在或配置无效", 404) - return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config}) + raise 
PluginNotFoundException(plugin_id) + return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": request.config}) @router.post("/{plugin_id}/crawl") async def crawl_plugin( plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service), - scheduler_service: SchedulerService = Depends(get_scheduler_service), + plugin_runner: PluginRunner = Depends(get_plugin_runner), + executor: JobExecutor = Depends(get_executor), ): - plugin = plugin_service.get_plugin(plugin_id) - if not plugin: - return error_response("插件不存在", 404) - - task_id = task_service.create_task("crawl_plugin") - - async def _background_crawl(): - try: - task_service.update_task(task_id, status="running") - results = await plugin_service.run_plugin(plugin_id) - if not results: - task_service.update_task( - task_id, - status="completed", - message=f"插件 {plugin_id} 爬取完成,未获取到代理", - data={"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0}, - ) - return - - logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies") - scheduler_service.validation_queue.reset_stats() - await scheduler_service.validation_queue.submit(results) - - try: - await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0) - except asyncio.TimeoutError: - pass - - task_service.update_task( - task_id, - status="completed", - message=f"插件 {plugin_id} 爬取并验证完成", - data={ - "plugin_id": plugin_id, - "proxy_count": len(results), - "valid_count": scheduler_service.validation_queue.valid_count, - "invalid_count": scheduler_service.validation_queue.invalid_count, - }, - ) - except Exception as e: - logger.error(f"Crawl plugin {plugin_id} failed: {e}") - task_service.update_task( - task_id, status="failed", message=f"插件爬取失败: {str(e)}" - ) - - asyncio.create_task(_background_crawl()) + plugin = plugin_service.get_plugin_or_raise(plugin_id) + job = CrawlJob( + plugin_id=plugin_id, + plugin_runner=plugin_runner, + proxy_service=plugin_service, + validator_pool=executor.worker_pool, + ) + 
job_id = executor.submit_job(job) return success_response( - "爬取任务已启动", {"task_id": task_id, "plugin_id": plugin_id} + "爬取任务已启动", {"task_id": job_id, "plugin_id": plugin_id} ) @router.post("/crawl-all") async def crawl_all( plugin_service: PluginService = Depends(get_plugin_service), - scheduler_service: SchedulerService = Depends(get_scheduler_service), + plugin_runner: PluginRunner = Depends(get_plugin_runner), + executor: JobExecutor = Depends(get_executor), ): - task_id = task_service.create_task("crawl_all") + """为所有启用插件创建 CrawlJob,返回一个聚合任务 ID 保持 API 兼容""" + from app.core.plugin_system.registry import registry + job_ids = [] + for plugin in registry.list_plugins(): + if not plugin.enabled: + continue + job = CrawlJob( + plugin_id=plugin.name, + plugin_runner=plugin_runner, + proxy_service=plugin_service, + validator_pool=executor.worker_pool, + ) + job_ids.append(executor.submit_job(job)) + + # 为了保持 API 兼容(前端/测试期望单个 task_id), + # 创建一个虚拟的聚合 Job 用于状态查询,它内部会等待所有 crawl job 完成 + aggregator = _create_crawl_all_aggregator(job_ids, executor) + agg_id = executor.submit_job(aggregator) + return success_response("爬取任务已启动", {"task_id": agg_id}) - async def _background_crawl_all(): - try: - task_service.update_task(task_id, status="running") - results = await plugin_service.run_all_plugins() - if not results: - task_service.update_task( - task_id, - status="completed", - message="所有插件爬取完成,未获取到代理", - data={"total_crawled": 0, "valid_count": 0, "invalid_count": 0}, + +def _create_crawl_all_aggregator(job_ids, executor): + """创建一个简单的聚合 Job,查询所有子 Job 的状态汇总""" + from app.core.execution.job import Job + import asyncio + + class CrawlAllAggregator(Job): + async def run(self): + # 等待所有子 job 完成(最多等 30 秒) + for _ in range(300): + all_done = all( + executor.get_job(jid) and executor.get_job(jid).status.value in ("completed", "failed", "cancelled") + for jid in job_ids ) - return + if all_done: + break + await asyncio.sleep(0.1) + total = 0 + valid = 0 + for jid in job_ids: + job = 
executor.get_job(jid) + if job and job.result: + total += job.result.get("proxy_count", 0) + valid += job.result.get("success_count", 0) + return {"total_crawled": total, "valid_count": valid, "invalid_count": 0} - logger.info(f"All plugins crawled {len(results)} unique proxies") - scheduler_service.validation_queue.reset_stats() - await scheduler_service.validation_queue.submit(results) + return CrawlAllAggregator() - try: - await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0) - except asyncio.TimeoutError: - pass - task_service.update_task( - task_id, - status="completed", - message="所有插件爬取并验证完成", - data={ - "total_crawled": len(results), - "valid_count": scheduler_service.validation_queue.valid_count, - "invalid_count": scheduler_service.validation_queue.invalid_count, - }, - ) - except Exception as e: - logger.error(f"Crawl all failed: {e}") - task_service.update_task( - task_id, status="failed", message=f"批量爬取失败: {str(e)}" - ) - - asyncio.create_task(_background_crawl_all()) - return success_response("爬取任务已启动", {"task_id": task_id}) +def format_plugin(plugin) -> dict: + return { + "id": plugin.id, + "name": plugin.display_name, + "display_name": plugin.display_name, + "description": plugin.description, + "enabled": plugin.enabled, + "last_run": plugin.last_run.isoformat() if plugin.last_run else None, + "success_count": plugin.success_count, + "failure_count": plugin.failure_count, + } diff --git a/app/api/routes/proxies.py b/app/api/routes/proxies.py index 52fbcb6..74094ba 100644 --- a/app/api/routes/proxies.py +++ b/app/api/routes/proxies.py @@ -1,11 +1,14 @@ """代理相关路由(含统计信息)""" from typing import Optional from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse + from app.services.proxy_service import ProxyService from app.services.scheduler_service import SchedulerService from app.models.schemas import ProxyListRequest, BatchDeleteRequest from app.api.deps import get_proxy_service, 
get_scheduler_service -from app.api.common import success_response, error_response, format_proxy +from app.api.common import success_response, format_proxy +from app.core.exceptions import ProxyPoolException, ProxyNotFoundException router = APIRouter(prefix="/api/proxies", tags=["proxies"]) @@ -49,7 +52,7 @@ async def list_proxies( async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)): proxy = await service.get_random_proxy() if not proxy: - return error_response("没有找到可用的代理", 404) + raise ProxyNotFoundException("", 0) return success_response("获取随机代理成功", format_proxy(proxy)) @@ -61,9 +64,7 @@ async def export_proxies( service: ProxyService = Depends(get_proxy_service), ): if fmt not in ("csv", "txt", "json"): - return error_response("不支持的导出格式", 400) - - from fastapi.responses import StreamingResponse + raise ProxyPoolException("不支持的导出格式", 400) media_types = {"csv": "text/csv", "txt": "text/plain", "json": "application/json"} diff --git a/app/api/routes/scheduler.py b/app/api/routes/scheduler.py index 30cdc15..a4cac7d 100644 --- a/app/api/routes/scheduler.py +++ b/app/api/routes/scheduler.py @@ -1,10 +1,11 @@ """调度器相关路由""" from fastapi import APIRouter, Depends + from app.services.scheduler_service import SchedulerService from app.repositories.settings_repo import SettingsRepository from app.core.db import get_db from app.api.deps import get_scheduler_service -from app.api.common import success_response, error_response +from app.api.common import success_response router = APIRouter(prefix="/api/scheduler", tags=["scheduler"]) settings_repo = SettingsRepository() @@ -39,8 +40,8 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_ser @router.post("/validate-now") async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)): - await scheduler.validate_all_now() - return success_response("已开始全量验证", {"started": True}) + job_id = scheduler.validate_all_now() + return success_response("已开始全量验证", 
{"started": True, "job_id": job_id}) @router.get("/status") diff --git a/app/api/routes/settings.py b/app/api/routes/settings.py index d4576b9..20b7960 100644 --- a/app/api/routes/settings.py +++ b/app/api/routes/settings.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Request from app.core.db import get_db from app.repositories.settings_repo import SettingsRepository from app.models.schemas import SettingsSchema -from app.api.common import success_response, error_response +from app.api.common import success_response from app.core.log import logger router = APIRouter(prefix="/api/settings", tags=["settings"]) @@ -22,14 +22,32 @@ async def save_settings(request: SettingsSchema, http_request: Request): async with get_db() as db: success = await settings_repo.save(db, request.model_dump()) if not success: - return error_response("保存设置失败", 500) + raise RuntimeError("保存设置失败") # 热更新运行中调度器的间隔时间 - scheduler = getattr(http_request.app.state, "scheduler_service", None) - if scheduler and scheduler.running: + scheduler = getattr(http_request.app.state, "scheduler", None) + worker_pool = getattr(http_request.app.state, "worker_pool", None) + validator = getattr(http_request.app.state, "validator", None) + + if scheduler: new_interval = request.validate_interval_minutes if scheduler.interval_minutes != new_interval: scheduler.interval_minutes = new_interval logger.info(f"Scheduler interval updated to {new_interval} minutes") + # 热更新 Worker 池大小 + if worker_pool and worker_pool.worker_count != request.default_concurrency: + await worker_pool.resize(request.default_concurrency) + logger.info(f"Worker pool resized to {request.default_concurrency}") + + # 热更新验证器超时和并发(下次验证时生效) + if validator: + validator._init_timeout = request.validation_timeout + validator._init_connect_timeout = request.validation_timeout + validator._init_max_concurrency = request.default_concurrency + # 重新创建 semaphore 和 session + validator._semaphore = None + await validator.close() + logger.info(f"Validator 
config updated: timeout={request.validation_timeout}, concurrency={request.default_concurrency}") + return success_response("保存设置成功", request.model_dump()) diff --git a/app/api/routes/tasks.py b/app/api/routes/tasks.py index 452623d..b64a0ec 100644 --- a/app/api/routes/tasks.py +++ b/app/api/routes/tasks.py @@ -1,14 +1,16 @@ """异步任务状态路由""" -from fastapi import APIRouter -from app.services.task_service import task_service +from fastapi import APIRouter, Depends + +from app.core.execution import JobExecutor +from app.api.deps import get_executor from app.api.common import success_response, error_response router = APIRouter(prefix="/api/tasks", tags=["tasks"]) @router.get("/{task_id}") -async def get_task_status(task_id: str): - task = task_service.get_task(task_id) - if not task: +async def get_task_status(task_id: str, executor: JobExecutor = Depends(get_executor)): + job = executor.get_job(task_id) + if not job: return error_response("任务不存在", 404) - return success_response("获取任务状态成功", task) + return success_response("获取任务状态成功", job.to_dict()) diff --git a/app/core/db.py b/app/core/db.py index 888ef28..3d17108 100644 --- a/app/core/db.py +++ b/app/core/db.py @@ -84,22 +84,7 @@ async def init_db(): await db.execute("ALTER TABLE plugin_settings ADD COLUMN stats_json TEXT DEFAULT '{}'") logger.info("Migrated: added stats_json column to plugin_settings") - # 验证任务队列表 - await db.execute(""" - CREATE TABLE IF NOT EXISTS validation_tasks ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ip TEXT NOT NULL, - port INTEGER NOT NULL, - protocol TEXT DEFAULT 'http', - status TEXT DEFAULT 'pending', - result TEXT, - response_time_ms REAL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_status ON validation_tasks(status)") - await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_created ON validation_tasks(created_at)") + # 验证任务队列表已废弃,不再创建;旧表保留在数据库中不影响功能 # 
系统设置表 await db.execute(""" @@ -125,3 +110,27 @@ async def get_db() -> AsyncIterator[aiosqlite.Connection]: yield db finally: await db.close() + + +@asynccontextmanager +async def transaction() -> AsyncIterator[aiosqlite.Connection]: + """获取带有显式事务控制的数据库连接 + + 用法: + async with transaction() as db: + await repo.update(db, ...) + # 如果抛出异常,自动 rollback + """ + ensure_db_dir() + db = await aiosqlite.connect(DB_PATH) + try: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA synchronous=NORMAL") + await db.execute("BEGIN") + yield db + await db.commit() + except Exception: + await db.rollback() + raise + finally: + await db.close() diff --git a/app/core/execution/__init__.py b/app/core/execution/__init__.py new file mode 100644 index 0000000..4033132 --- /dev/null +++ b/app/core/execution/__init__.py @@ -0,0 +1,13 @@ +"""统一后台任务执行框架""" +from .job import JobStatus, Job, CrawlJob, ValidateAllJob +from .worker_pool import AsyncWorkerPool +from .executor import JobExecutor + +__all__ = [ + "JobStatus", + "Job", + "CrawlJob", + "ValidateAllJob", + "AsyncWorkerPool", + "JobExecutor", +] diff --git a/app/core/execution/executor.py b/app/core/execution/executor.py new file mode 100644 index 0000000..03f7745 --- /dev/null +++ b/app/core/execution/executor.py @@ -0,0 +1,138 @@ +"""Job 执行器 - 统一管理所有后台 Job 的生命周期""" +import asyncio +from contextlib import asynccontextmanager +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional + +from app.core.execution.job import Job, JobStatus +from app.core.execution.worker_pool import AsyncWorkerPool +from app.core.log import logger + + +class JobExecutor: + """后台任务执行器 + + - 维护内存中的 Job 状态表 + - 限制最大并发 Job 数 + - 自动清理过期 Job + """ + + def __init__( + self, + worker_pool: Optional[AsyncWorkerPool] = None, + max_concurrent_jobs: int = 10, + cleanup_interval_seconds: int = 300, + ): + self.worker_pool = worker_pool + self.max_concurrent_jobs = max_concurrent_jobs + self._jobs: Dict[str, Job] = {} + 
self._running = False + self._semaphore = asyncio.Semaphore(max_concurrent_jobs) + self._cleanup_interval = cleanup_interval_seconds + self._cleanup_task: Optional[asyncio.Task] = None + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.stop() + + async def start(self) -> None: + if self._running: + return + self._running = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("JobExecutor started") + + async def stop(self) -> None: + if not self._running: + return + self._running = False + await self.cancel_all() + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + logger.info("JobExecutor stopped") + + def submit_job(self, job: Job) -> str: + """提交一个 Job 到后台执行""" + self._jobs[job.id] = job + asyncio.create_task(self._run_job(job)) + return job.id + + async def _run_job(self, job: Job) -> None: + async with self._semaphore: + try: + if job.is_cancelled: + logger.info(f"Job {job.id} was cancelled before running") + return + result = await job.run() + # 如果子类没有显式设置完成状态,自动设为 completed + if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED): + job._set_completed(result) + logger.info(f"Job {job.id} completed: {result}") + except asyncio.CancelledError: + job.status = JobStatus.CANCELLED + job._touch() + logger.info(f"Job {job.id} cancelled during execution") + except Exception as e: + job._set_failed(str(e)) + logger.error(f"Job {job.id} failed: {e}", exc_info=True) + + def get_job(self, job_id: str) -> Optional[Job]: + return self._jobs.get(job_id) + + def list_jobs( + self, + status: Optional[JobStatus] = None, + limit: int = 100, + ) -> List[Job]: + jobs = list(self._jobs.values()) + if status: + jobs = [j for j in jobs if j.status == status] + jobs.sort(key=lambda j: j.created_at, reverse=True) + return jobs[:limit] + + async def cancel_job(self, 
job_id: str) -> bool: + job = self._jobs.get(job_id) + if not job: + return False + job.cancel() + return True + + async def cancel_all(self) -> int: + cancelled = 0 + for job in list(self._jobs.values()): + if job.status in (JobStatus.PENDING, JobStatus.RUNNING): + job.cancel() + cancelled += 1 + return cancelled + + async def _cleanup_loop(self) -> None: + """定期清理已完成且过期的 Job""" + while self._running: + try: + await asyncio.sleep(self._cleanup_interval) + self._cleanup_old_jobs() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"JobExecutor cleanup error: {e}") + + def _cleanup_old_jobs(self) -> int: + cutoff = datetime.now() - timedelta(seconds=self._cleanup_interval) + to_remove = [ + job_id + for job_id, job in self._jobs.items() + if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) + and job.updated_at < cutoff + ] + for job_id in to_remove: + del self._jobs[job_id] + if to_remove: + logger.info(f"JobExecutor cleaned up {len(to_remove)} old jobs") + return len(to_remove) diff --git a/app/core/execution/job.py b/app/core/execution/job.py new file mode 100644 index 0000000..de93485 --- /dev/null +++ b/app/core/execution/job.py @@ -0,0 +1,159 @@ +"""Job 定义 - 所有后台异步任务的统一抽象""" +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from app.models.domain import ProxyRaw + + +class JobStatus(Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class Job(ABC): + """后台任务基类""" + + id: str = field(default_factory=lambda: str(uuid.uuid4())) + status: JobStatus = JobStatus.PENDING + progress: float = 0.0 # 0-100 + result: Any = None + error: Optional[str] = None + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + 
_cancelled: bool = field(default=False, repr=False) + + def cancel(self) -> None: + """请求取消任务""" + self._cancelled = True + if self.status in (JobStatus.PENDING, JobStatus.RUNNING): + self.status = JobStatus.CANCELLED + self._touch() + + def _touch(self) -> None: + self.updated_at = datetime.now() + + def _set_running(self) -> None: + self.status = JobStatus.RUNNING + self._touch() + + def _set_completed(self, result: Any = None) -> None: + self.status = JobStatus.COMPLETED + self.result = result + self.progress = 100.0 + self._touch() + + def _set_failed(self, error: str) -> None: + self.status = JobStatus.FAILED + self.error = error + self._touch() + + @property + def is_cancelled(self) -> bool: + return self._cancelled + + @abstractmethod + async def run(self) -> Any: + """执行任务的核心逻辑,子类必须实现""" + raise NotImplementedError + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "status": self.status.value, + "progress": round(self.progress, 2), + "result": self.result, + "error": self.error, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + +@dataclass +class CrawlJob(Job): + """插件爬取任务""" + + plugin_id: str = "" + plugin_runner: Any = field(repr=False, default=None) + proxy_service: Any = field(repr=False, default=None) + validator_pool: Any = field(repr=False, default=None) + + async def run(self) -> Dict[str, Any]: + from app.services.plugin_service import PluginService + from app.core.log import logger + + self._set_running() + if not self.plugin_runner: + raise RuntimeError("plugin_runner is not set") + + plugin_service = self.proxy_service or PluginService() + plugin = plugin_service.get_plugin(self.plugin_id) + if not plugin: + raise ValueError(f"Plugin '{self.plugin_id}' not found") + + result = await self.plugin_runner.run(plugin) + proxies: List[ProxyRaw] = result.proxies if result else [] + + if proxies and self.validator_pool: + await self.validator_pool.submit(proxies) + 
logger.info(f"CrawlJob {self.id}: submitted {len(proxies)} proxies for validation") + + payload = { + "plugin_id": self.plugin_id, + "proxy_count": len(proxies), + } + if result: + payload["success_count"] = result.success_count + payload["failure_count"] = result.failure_count + self._set_completed(payload) + return payload + + +@dataclass +class ValidateAllJob(Job): + """全量验证任务 - 验证数据库中所有存量代理""" + + proxy_repo: Any = field(repr=False, default=None) + validator_pool: Any = field(repr=False, default=None) + batch_size: int = 100 + + async def run(self) -> Dict[str, Any]: + from app.repositories.proxy_repo import ProxyRepository + from app.core.db import get_db + from app.core.log import logger + + self._set_running() + repo = self.proxy_repo or ProxyRepository() + + async with get_db() as db: + proxies = await repo.list_all(db) + + if not proxies: + self._set_completed({"total": 0, "submitted": 0}) + return self.result + + total = len(proxies) + submitted = 0 + for i in range(0, total, self.batch_size): + if self.is_cancelled: + logger.info(f"ValidateAllJob {self.id}: cancelled") + break + batch = proxies[i : i + self.batch_size] + raws = [ProxyRaw(p.ip, p.port, p.protocol) for p in batch] + if self.validator_pool: + await self.validator_pool.submit(raws) + submitted += len(raws) + self.progress = min(100.0, (submitted / total) * 100) + self._touch() + + payload = {"total": total, "submitted": submitted} + self._set_completed(payload) + logger.info(f"ValidateAllJob {self.id}: submitted {submitted}/{total} proxies") + return payload diff --git a/app/core/execution/worker_pool.py b/app/core/execution/worker_pool.py new file mode 100644 index 0000000..10a2844 --- /dev/null +++ b/app/core/execution/worker_pool.py @@ -0,0 +1,97 @@ +"""轻量级纯内存异步 Worker Pool""" +import asyncio +from typing import Callable, Coroutine, List, TypeVar + +from app.core.log import logger + +T = TypeVar("T") +Handler = Callable[[T], Coroutine[None, None, None]] + + +class AsyncWorkerPool: + 
"""纯内存异步工作池 + + 职责单一:接收任务列表 -> 分发给 N 个 Worker 协程 -> 调用 handler 处理 + 不包含任何持久化逻辑,也不维护复杂的状态机。 + """ + + def __init__( + self, + worker_count: int, + handler: Handler, + queue_maxsize: int = 10000, + name: str = "WorkerPool", + ): + self.worker_count = worker_count + self.handler = handler + self.name = name + self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize) + self._workers: List[asyncio.Task] = [] + self._running = False + self._shutdown_event = asyncio.Event() + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.stop() + + async def start(self) -> None: + if self._running: + return + self._running = True + self._shutdown_event.clear() + self._workers = [ + asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}") + for i in range(self.worker_count) + ] + logger.info(f"{self.name} started with {self.worker_count} workers") + + async def stop(self) -> None: + if not self._running: + return + self._running = False + self._shutdown_event.set() + # 发送足够数量的 sentinel,确保所有 Worker 都能收到 + for _ in range(len(self._workers) + self._queue.qsize()): + try: + self._queue.put_nowait(None) + except asyncio.QueueFull: + break + if self._workers: + await asyncio.gather(*self._workers, return_exceptions=True) + self._workers.clear() + logger.info(f"{self.name} stopped") + + async def submit(self, items: List[T]) -> None: + """提交一批任务到队列(阻塞直到有空位,天然背压)""" + for item in items: + await self._queue.put(item) + + async def drain(self) -> None: + """等待队列中所有任务被消费完毕""" + await self._queue.join() + + async def resize(self, new_worker_count: int) -> None: + """动态调整 Worker 数量:先全部停止,再按新数量启动""" + if new_worker_count == self.worker_count: + return + logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}") + # 安全做法:先 stop 再 start,避免新旧 Worker 竞争 sentinel 导致死锁 + await self.stop() + self.worker_count = new_worker_count + await self.start() + + async def 
_worker_loop(self, worker_id: int) -> None: + while True: + item = await self._queue.get() + try: + if item is None or not self._running: + self._queue.task_done() + break + await self.handler(item) + except Exception as e: + logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True) + finally: + self._queue.task_done() diff --git a/app/core/plugin_system/registry.py b/app/core/plugin_system/registry.py index 9494b83..4d5b1c8 100644 --- a/app/core/plugin_system/registry.py +++ b/app/core/plugin_system/registry.py @@ -45,6 +45,11 @@ class PluginRegistry: def get_plugin_names(self) -> List[str]: return list(self._plugins.keys()) + def clear(self) -> None: + """清空所有已注册插件(主要用于测试)""" + self._plugins.clear() + self._instances.clear() + def auto_discover(self, package_name: str): """自动扫描指定包下的所有模块并注册其中的插件类。 注意:为了类型安全和可控性,推荐显式注册。auto_discover 仅作为兼容。""" diff --git a/app/core/tasks/queue.py b/app/core/tasks/queue.py deleted file mode 100644 index a94f7be..0000000 --- a/app/core/tasks/queue.py +++ /dev/null @@ -1,168 +0,0 @@ -"""验证任务队列 - 解耦爬取与验证,支持背压控制和持久化""" -import asyncio -from typing import Optional -from app.models.domain import ProxyRaw -from app.repositories.task_repo import ValidationTaskRepository -from app.core.db import get_db -from app.core.log import logger - - -class ValidationQueue: - """代理验证队列(支持持久化到 SQLite) - - 工作流程: - 1. 爬虫将原始代理 submit() 到队列(写入数据库 + 内存信号) - 2. Worker 池从数据库消费并验证 - 3. 验证通过的代理写入数据库 - 4. 
服务重启时自动恢复未完成的 pending 任务 - """ - - def __init__( - self, - validator, - proxy_repo, - worker_count: int = 50, - score_valid: int = 10, - score_invalid: int = -5, - score_min: int = 0, - score_max: int = 100, - ): - self.validator = validator - self.proxy_repo = proxy_repo - self.task_repo = ValidationTaskRepository() - self.worker_count = worker_count - self.score_valid = score_valid - self.score_invalid = score_invalid - self.score_min = score_min - self.score_max = score_max - - self._signal: asyncio.Queue[None] = asyncio.Queue() - self._workers: list[asyncio.Task] = [] - self._running = False - self._pending_count = 0 - self._condition = asyncio.Condition() - - # 统计 - self.valid_count = 0 - self.invalid_count = 0 - - async def start(self): - if self._running: - return - self._running = True - - # 恢复之前中断的 processing 任务 - async with get_db() as db: - recovered = await self.task_repo.reset_processing(db) - pending = await self.task_repo.get_pending_count(db) - if pending > 1000: - logger.warning(f"ValidationQueue has {pending} pending tasks, cleaning up all pending tasks...") - await db.execute("DELETE FROM validation_tasks WHERE status = 'pending'") - await db.commit() - pending = await self.task_repo.get_pending_count(db) - logger.info(f"ValidationQueue cleaned up pending tasks, remaining: {pending}") - if recovered: - logger.info(f"ValidationQueue recovered {recovered} interrupted tasks") - if pending: - logger.info(f"ValidationQueue has {pending} pending tasks to process") - async with self._condition: - self._pending_count = pending - - for i in range(self.worker_count): - self._workers.append(asyncio.create_task(self._worker_loop(i))) - - # 唤醒 Worker 处理恢复的 pending 任务(每个 Worker 一次唤醒即可,内部会循环处理) - if pending: - for _ in range(self.worker_count): - self._signal.put_nowait(None) - - logger.info(f"ValidationQueue started with {self.worker_count} workers") - - async def stop(self): - if not self._running: - return - self._running = False - for _ in self._workers: - 
self._signal.put_nowait(None) # sentinel - if self._workers: - await asyncio.gather(*self._workers, return_exceptions=True) - self._workers.clear() - logger.info("ValidationQueue stopped") - - async def submit(self, proxies: list[ProxyRaw]): - """提交代理到验证队列(持久化 + 唤醒 Worker)""" - async with get_db() as db: - inserted = await self.task_repo.insert_batch(db, proxies) - if inserted: - async with self._condition: - self._pending_count += inserted - self._condition.notify_all() - for _ in range(min(inserted, self.worker_count)): - self._signal.put_nowait(None) - - async def submit_one(self, proxy: ProxyRaw): - await self.submit([proxy]) - - async def drain(self): - """等待队列中当前所有 pending 任务处理完毕""" - async with self._condition: - if self._pending_count > 0: - await self._condition.wait_for(lambda: self._pending_count == 0) - - async def _worker_loop(self, worker_id: int): - while True: - await self._signal.get() - self._signal.task_done() - if not self._running: - break - # 持续处理任务直到没有 pending 为止,避免信号数不足导致任务饿死 - while self._running: - processed = await self._process_one_task(worker_id) - if not processed: - break - - async def _process_one_task(self, worker_id: int) -> bool: - """从数据库取一个任务并验证。返回 True 表示确实处理了一个任务。""" - async with get_db() as db: - task = await self.task_repo.acquire_pending(db) - if not task: - return False - - proxy = ProxyRaw(task["ip"], task["port"], task["protocol"]) - try: - is_valid, latency = await self.validator.validate( - proxy.ip, proxy.port, proxy.protocol - ) - except Exception as e: - logger.error(f"Worker {worker_id} validation error: {e}", exc_info=True) - is_valid, latency = False, 0.0 - - if is_valid: - await self.proxy_repo.insert_or_update( - db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid - ) - if latency: - await self.proxy_repo.update_response_time( - db, proxy.ip, proxy.port, latency - ) - await self.task_repo.complete_task(db, task["id"], True, latency) - self.valid_count += 1 - logger.debug(f"ValidationQueue: valid 
{proxy.ip}:{proxy.port}") - else: - # 对已有代理扣分,分数<=0时自动删除 - await self.proxy_repo.update_score( - db, proxy.ip, proxy.port, self.score_invalid, - self.score_min, self.score_max - ) - await self.task_repo.complete_task(db, task["id"], False, 0.0) - self.invalid_count += 1 - logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}") - async with self._condition: - self._pending_count = max(0, self._pending_count - 1) - if self._pending_count == 0: - self._condition.notify_all() - return True - - def reset_stats(self): - self.valid_count = 0 - self.invalid_count = 0 diff --git a/app/models/domain.py b/app/models/domain.py index 1cc82cd..53ac928 100644 --- a/app/models/domain.py +++ b/app/models/domain.py @@ -1,7 +1,7 @@ """领域模型 - 纯数据结构,不依赖任何框架""" from dataclasses import dataclass, field from datetime import datetime -from typing import Optional +from typing import List, Optional @dataclass @@ -42,3 +42,13 @@ class PluginInfo: last_run: Optional[datetime] = None success_count: int = 0 failure_count: int = 0 + + +@dataclass +class CrawlResult: + """插件爬取结果""" + plugin_name: str + proxies: List[ProxyRaw] = field(default_factory=list) + success_count: int = 0 + failure_count: int = 0 + error: Optional[str] = None diff --git a/app/plugins/base.py b/app/plugins/base.py index 72e8858..6b9a565 100644 --- a/app/plugins/base.py +++ b/app/plugins/base.py @@ -1,9 +1,15 @@ """通用 HTTP 爬虫基类 - 为基于 HTTP 请求的插件提供封装""" +import re import random import asyncio import httpx from typing import List, Optional +from bs4 import BeautifulSoup from app.core.plugin_system import BaseCrawlerPlugin +from app.models.domain import ProxyRaw + + +VALID_PROTOCOLS = ("http", "https", "socks4", "socks5") class BaseHTTPPlugin(BaseCrawlerPlugin): @@ -20,6 +26,7 @@ class BaseHTTPPlugin(BaseCrawlerPlugin): self.urls: List[str] = [] self.current_url: str = "" self._client: Optional[httpx.AsyncClient] = None + self.max_concurrency: int = 3 def get_headers(self) -> dict: return { @@ -39,13 +46,21 @@ class 
BaseHTTPPlugin(BaseCrawlerPlugin): ) return self._client - async def fetch(self, url: str, timeout: float = 15.0, retries: int = 2) -> str: + async def fetch( + self, + url: str, + timeout: float = 15.0, + retries: int = 2, + raise_for_status: bool = False, + ) -> str: """异步抓取指定 URL 的 HTML 内容""" from app.core.log import logger client = self._get_client() for attempt in range(retries): try: response = await client.get(url, headers=self.get_headers(), timeout=timeout) + if raise_for_status: + response.raise_for_status() if response.status_code == 200: content = response.content encoding = response.encoding @@ -64,8 +79,8 @@ class BaseHTTPPlugin(BaseCrawlerPlugin): return "" async def fetch_all(self, urls: List[str], timeout: float = 15.0) -> List[str]: - """并发抓取多个 URL,限制单个插件内部并发为 3""" - semaphore = asyncio.Semaphore(3) + """并发抓取多个 URL,限制单个插件内部并发""" + semaphore = asyncio.Semaphore(self.max_concurrency) async def _fetch_limited(url: str): async with semaphore: @@ -74,6 +89,70 @@ class BaseHTTPPlugin(BaseCrawlerPlugin): tasks = [_fetch_limited(url) for url in urls] return await asyncio.gather(*tasks) + def parse_text_proxies(self, text: str, protocol: str = "http") -> List[ProxyRaw]: + """解析 ip:port 格式的文本代理列表 + + 统一处理 \r\n、\n 两种换行以及可能存在的空行。 + """ + results = [] + text = text.replace("\r\n", "\n").replace("\r", "\n") + for line in text.split("\n"): + line = line.strip() + if not line or ":" not in line: + continue + ip, _, port = line.rpartition(":") + ip = ip.strip() + port = port.strip() + if ip and port.isdigit() and 1 <= int(port) <= 65535: + try: + results.append(ProxyRaw(ip, int(port), protocol)) + except ValueError: + continue + return results + + def parse_html_table( + self, + html: str, + column_map: dict, + protocol: str = "http", + ) -> List[ProxyRaw]: + """通用 HTML 表格解析器 + + Args: + html: HTML 文本 + column_map: 列名到索引的映射,如 {"ip": 0, "port": 1, "protocol": 4} + protocol: 默认协议,如果表格中没有协议列则使用此值 + """ + results = [] + soup = BeautifulSoup(html, "lxml") + table = 
soup.find("table") + if not table: + return results + + ip_idx = column_map.get("ip", 0) + port_idx = column_map.get("port", 1) + protocol_idx = column_map.get("protocol", -1) + + for row in table.find_all("tr"): + tds = row.find_all("td") + if len(tds) <= max(ip_idx, port_idx): + continue + ip = tds[ip_idx].get_text(strip=True) + port = tds[port_idx].get_text(strip=True) + if protocol_idx >= 0 and len(tds) > protocol_idx: + proto = tds[protocol_idx].get_text(strip=True).lower() + if proto not in VALID_PROTOCOLS: + proto = protocol + else: + proto = protocol + + if re.match(r"^\d+\.\d+\.\d+\.\d+$", ip) and port.isdigit() and 1 <= int(port) <= 65535: + try: + results.append(ProxyRaw(ip, int(port), proto)) + except ValueError: + continue + return results + async def close(self): """关闭复用的 HTTP 客户端""" if self._client and not self._client.is_closed: diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py index 774d2d5..9e2408f 100644 --- a/app/repositories/__init__.py +++ b/app/repositories/__init__.py @@ -1,11 +1,9 @@ """数据访问层包""" from .proxy_repo import ProxyRepository from .settings_repo import SettingsRepository, PluginSettingsRepository -from .task_repo import ValidationTaskRepository __all__ = [ "ProxyRepository", "SettingsRepository", "PluginSettingsRepository", - "ValidationTaskRepository", ] diff --git a/app/repositories/task_repo.py b/app/repositories/task_repo.py deleted file mode 100644 index 1a89052..0000000 --- a/app/repositories/task_repo.py +++ /dev/null @@ -1,130 +0,0 @@ -"""验证任务队列持久化层""" -import aiosqlite -from typing import List, Optional -from app.models.domain import ProxyRaw -from app.core.log import logger - - -class ValidationTaskRepository: - """验证任务 Repository —— 支持队列持久化""" - - @staticmethod - async def insert_batch(db: aiosqlite.Connection, proxies: List[ProxyRaw]) -> int: - if not proxies: - return 0 - try: - rows = [(p.ip, p.port, p.protocol) for p in proxies] - await db.executemany( - """ - INSERT INTO validation_tasks (ip, 
port, protocol, status, created_at) - VALUES (?, ?, ?, 'pending', CURRENT_TIMESTAMP) - """, - rows, - ) - await db.commit() - return len(rows) - except Exception as e: - logger.error(f"insert_batch validation tasks failed: {e}") - return 0 - - @staticmethod - async def acquire_pending(db: aiosqlite.Connection) -> Optional[dict]: - """原子性地获取一个 pending 任务并将其标记为 processing""" - try: - async with db.execute( - """ - UPDATE validation_tasks - SET status = 'processing', updated_at = CURRENT_TIMESTAMP - WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1) - RETURNING id, ip, port, protocol - """ - ) as cursor: - row = await cursor.fetchone() - if not row: - return None - return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]} - except Exception as e: - logger.error(f"acquire_pending failed: {e}") - return None - - @staticmethod - async def complete_task( - db: aiosqlite.Connection, - task_id: int, - is_valid: bool, - response_time_ms: Optional[float] = None, - ) -> bool: - try: - await db.execute( - """ - UPDATE validation_tasks - SET status = 'completed', - result = ?, - response_time_ms = ?, - updated_at = CURRENT_TIMESTAMP - WHERE id = ? - """, - ("valid" if is_valid else "invalid", response_time_ms, task_id), - ) - await db.commit() - return True - except Exception as e: - logger.error(f"complete_task failed: {e}") - return False - - @staticmethod - async def fail_task(db: aiosqlite.Connection, task_id: int) -> bool: - try: - await db.execute( - """ - UPDATE validation_tasks - SET status = 'failed', - result = 'invalid', - updated_at = CURRENT_TIMESTAMP - WHERE id = ? 
- """, - (task_id,), - ) - await db.commit() - return True - except Exception as e: - logger.error(f"fail_task failed: {e}") - return False - - @staticmethod - async def get_pending_count(db: aiosqlite.Connection) -> int: - async with db.execute( - "SELECT COUNT(*) FROM validation_tasks WHERE status = 'pending'" - ) as cursor: - row = await cursor.fetchone() - return row[0] if row else 0 - - @staticmethod - async def reset_processing(db: aiosqlite.Connection) -> int: - """将异常中断的 processing 任务重置为 pending,用于启动恢复""" - try: - await db.execute( - """ - UPDATE validation_tasks - SET status = 'pending', updated_at = CURRENT_TIMESTAMP - WHERE status = 'processing' - """ - ) - await db.commit() - return db.total_changes - except Exception as e: - logger.error(f"reset_processing failed: {e}") - return 0 - - @staticmethod - async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int: - try: - await db.execute( - "DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')", - (days,), - ) - await db.commit() - return db.total_changes - except Exception as e: - logger.error(f"cleanup_old tasks failed: {e}") - return 0 diff --git a/app/services/plugin_runner.py b/app/services/plugin_runner.py new file mode 100644 index 0000000..ab5d6b9 --- /dev/null +++ b/app/services/plugin_runner.py @@ -0,0 +1,90 @@ +"""插件统一执行器 - 封装超时、重试、健康检查、错误捕获""" +import asyncio +from datetime import datetime +from typing import Optional + +from app.core.plugin_system.base import BaseCrawlerPlugin +from app.core.config import settings as app_settings +from app.core.log import logger +from app.models.domain import CrawlResult, ProxyRaw + + +class PluginRunner: + """统一插件执行器 + + - 超时控制(从 settings 读取 crawl_timeout) + - 异常捕获和统计更新 + - 可选的健康检查前置 + - 结果去重 + """ + + def __init__(self, timeout: Optional[float] = None): + self.timeout = timeout or getattr(app_settings, "crawler_timeout", 30) + + async def run(self, plugin: BaseCrawlerPlugin) -> CrawlResult: + """执行单个插件爬取""" + 
result = CrawlResult(plugin_name=plugin.name) + + # 健康检查(可选) + try: + healthy = await asyncio.wait_for( + plugin.health_check(), timeout=5.0 + ) + if not healthy: + result.error = "health check failed" + result.failure_count = 1 + await self._save_stats(plugin, result) + return result + except Exception as e: + logger.warning(f"Plugin {plugin.name} health check error: {e}") + + # 执行爬取 + try: + proxies = await asyncio.wait_for( + plugin.crawl(), + timeout=self.timeout, + ) + result.proxies = self._dedup(proxies) + result.success_count = 1 if result.proxies else 0 + logger.info( + f"Plugin {plugin.name} crawled {len(result.proxies)} unique proxies" + ) + except asyncio.TimeoutError: + result.error = f"crawl timeout after {self.timeout}s" + result.failure_count = 1 + logger.error(f"Plugin {plugin.name} crawl timeout") + except Exception as e: + result.error = str(e) + result.failure_count = 1 + logger.error(f"Plugin {plugin.name} crawl failed: {e}", exc_info=True) + + await self._save_stats(plugin, result) + return result + + @staticmethod + def _dedup(proxies: list[ProxyRaw]) -> list[ProxyRaw]: + seen = set() + unique = [] + for p in proxies: + key = (p.ip, p.port) + if key not in seen: + seen.add(key) + unique.append(p) + return unique + + async def _save_stats(self, plugin: BaseCrawlerPlugin, result: CrawlResult) -> None: + """将爬取统计持久化到数据库""" + from app.core.db import get_db + from app.repositories.settings_repo import PluginSettingsRepository + + repo = PluginSettingsRepository() + payload = { + "success_count": result.success_count, + "failure_count": result.failure_count, + "last_run": datetime.now().isoformat(), + } + try: + async with get_db() as db: + await repo.set_stats(db, plugin.name, payload) + except Exception as e: + logger.error(f"Failed to save stats for {plugin.name}: {e}") diff --git a/app/services/plugin_service.py b/app/services/plugin_service.py index 02c40b8..954c296 100644 --- a/app/services/plugin_service.py +++ 
b/app/services/plugin_service.py @@ -1,21 +1,21 @@ """插件业务服务""" import asyncio -from datetime import datetime from typing import List, Optional + from app.core.db import get_db from app.core.plugin_system.registry import registry from app.core.plugin_system.base import BaseCrawlerPlugin +from app.core.exceptions import PluginNotFoundException from app.repositories.settings_repo import PluginSettingsRepository -from app.models.domain import PluginInfo, ProxyRaw +from app.models.domain import PluginInfo, ProxyRaw, CrawlResult from app.core.log import logger class PluginService: - """插件业务服务:管理插件生命周期、执行爬取、配置管理""" + """插件业务服务:管理插件生命周期、配置管理""" def __init__(self): self.plugin_settings_repo = PluginSettingsRepository() - self._stats: dict[str, dict] = {} async def list_plugins(self) -> List[PluginInfo]: """获取所有插件信息(合并持久化状态和配置)""" @@ -24,42 +24,40 @@ class PluginService: result = [] for plugin in registry.list_plugins(): - # 合并持久化状态(不修改全局实例,避免并发竞争) state = db_states.get(plugin.name, {}) enabled = state.get("enabled", plugin.enabled) if "config" in state and isinstance(state["config"], dict): plugin.update_config(state["config"]) - # 合并数据库统计与内存统计(内存优先) db_stat = state.get("stats", {}) - stat = { - "success_count": db_stat.get("success_count", 0), - "failure_count": db_stat.get("failure_count", 0), - "last_run": datetime.fromisoformat(db_stat["last_run"]) if db_stat.get("last_run") else None, - } - mem_stat = self._stats.get(plugin.name, {}) - if mem_stat: - stat["success_count"] = mem_stat.get("success_count", stat["success_count"]) - stat["failure_count"] = mem_stat.get("failure_count", stat["failure_count"]) - if mem_stat.get("last_run"): - stat["last_run"] = mem_stat["last_run"] - result.append(PluginInfo( id=plugin.name, name=plugin.name, display_name=plugin.display_name or plugin.name, description=plugin.description or f"从 {plugin.name} 爬取代理", enabled=enabled, - last_run=stat.get("last_run"), - success_count=stat.get("success_count", 0), - 
failure_count=stat.get("failure_count", 0), + last_run=self._parse_datetime(db_stat.get("last_run")), + success_count=db_stat.get("success_count", 0), + failure_count=db_stat.get("failure_count", 0), )) return result + @staticmethod + def _parse_datetime(value) -> Optional: + from datetime import datetime + if not value: + return None + if isinstance(value, datetime): + return value + try: + return datetime.fromisoformat(value) + except Exception: + return None + async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool: plugin = registry.get(plugin_id) if not plugin: - return False + raise PluginNotFoundException(plugin_id) async with get_db() as db: success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled) if success: @@ -68,10 +66,9 @@ class PluginService: return success async def get_plugin_config(self, plugin_id: str) -> Optional[dict]: - """获取插件当前配置(合并默认值和持久化值)""" plugin = registry.get(plugin_id) if not plugin: - return None + raise PluginNotFoundException(plugin_id) async with get_db() as db: saved = await self.plugin_settings_repo.get_config(db, plugin_id) config = dict(plugin.default_config) @@ -80,11 +77,9 @@ class PluginService: return config async def update_plugin_config(self, plugin_id: str, config: dict) -> bool: - """更新插件配置(只保存已存在于 default_config 中的键)""" plugin = registry.get(plugin_id) if not plugin: - return False - # 过滤非法键 + raise PluginNotFoundException(plugin_id) safe_config = {k: v for k, v in config.items() if k in plugin.default_config} if not safe_config: return False @@ -95,44 +90,45 @@ class PluginService: def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]: return registry.get(plugin_id) - async def run_plugin(self, plugin_id: str) -> List[ProxyRaw]: - """执行单个插件爬取""" + def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin: plugin = self.get_plugin(plugin_id) if not plugin: - raise ValueError(f"Plugin {plugin_id} not found") + raise PluginNotFoundException(plugin_id) + return 
plugin + + async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult: + """执行单个插件爬取(通过 PluginRunner)""" + plugin = self.get_plugin_or_raise(plugin_id) if not plugin.enabled: logger.warning(f"Plugin {plugin_id} is disabled, skip crawl") - return [] + return CrawlResult(plugin_name=plugin_id) + return await plugin_runner.run(plugin) - try: - results = await plugin.crawl() - self._record_stat(plugin_id, success=len(results)) - logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies") - return results - except Exception as e: - self._record_stat(plugin_id, failure=1) - logger.error(f"Plugin {plugin_id} crawl failed: {e}") - return [] - finally: - await self._save_stats(plugin_id) - - async def run_all_plugins(self) -> List[ProxyRaw]: + async def run_all_plugins(self, plugin_runner) -> List[ProxyRaw]: """执行所有启用插件的爬取,限制并发数以避免触发目标站反爬""" all_results: List[ProxyRaw] = [] semaphore = asyncio.Semaphore(5) async def _run_with_limit(plugin_name: str): + plugin = self.get_plugin_or_raise(plugin_name) + if not plugin.enabled: + return CrawlResult(plugin_name=plugin_name) async with semaphore: - return await self.run_plugin(plugin_name) + return await plugin_runner.run(plugin) - tasks = [_run_with_limit(plugin.name) for plugin in registry.list_plugins() if plugin.enabled] + tasks = [ + _run_with_limit(plugin.name) + for plugin in registry.list_plugins() if plugin.enabled + ] results_list = await asyncio.gather(*tasks, return_exceptions=True) - for results in results_list: - if isinstance(results, Exception): - logger.error(f"Run all plugins error: {results}") + for result in results_list: + if isinstance(result, Exception): + logger.error(f"Run all plugins error: {result}") continue - all_results.extend(results) - # 去重(与数据库 UNIQUE(ip, port) 约束保持一致) + if isinstance(result, CrawlResult): + all_results.extend(result.proxies) + + # 去重 seen = set() unique = [] for p in all_results: @@ -141,26 +137,3 @@ class PluginService: seen.add(key) unique.append(p) return 
unique - - def _record_stat(self, plugin_id: str, success: int = 0, failure: int = 0): - if plugin_id not in self._stats: - self._stats[plugin_id] = { - "success_count": 0, - "failure_count": 0, - "last_run": None, - } - self._stats[plugin_id]["success_count"] += success - self._stats[plugin_id]["failure_count"] += failure - if success or failure: - self._stats[plugin_id]["last_run"] = datetime.now() - - async def _save_stats(self, plugin_id: str): - """将内存中的统计持久化到数据库""" - stats = self._stats.get(plugin_id, {}) - payload = { - "success_count": stats.get("success_count", 0), - "failure_count": stats.get("failure_count", 0), - "last_run": stats.get("last_run").isoformat() if stats.get("last_run") else None, - } - async with get_db() as db: - await self.plugin_settings_repo.set_stats(db, plugin_id, payload) diff --git a/app/services/proxy_service.py b/app/services/proxy_service.py index 8209805..8fb05dd 100644 --- a/app/services/proxy_service.py +++ b/app/services/proxy_service.py @@ -4,7 +4,8 @@ import json import io from datetime import datetime from typing import List, Optional, Tuple, AsyncIterator -from app.core.db import get_db + +from app.core.db import get_db, transaction from app.repositories.proxy_repo import ProxyRepository from app.models.domain import Proxy from app.core.log import logger diff --git a/app/services/scheduler_service.py b/app/services/scheduler_service.py index 3996319..5cc446b 100644 --- a/app/services/scheduler_service.py +++ b/app/services/scheduler_service.py @@ -1,43 +1,42 @@ -"""调度器服务 - 定时验证存量代理""" +"""调度器服务 - 定时触发全量验证""" import asyncio -from datetime import datetime -from app.core.db import get_db -from app.repositories.proxy_repo import ProxyRepository -from app.repositories.task_repo import ValidationTaskRepository -from app.core.tasks.queue import ValidationQueue -from app.core.config import settings as app_settings +from typing import Optional + +from app.core.execution.executor import JobExecutor +from app.core.execution.job 
import ValidateAllJob from app.core.log import logger -from app.models.domain import ProxyRaw -from app.services.task_service import task_service class SchedulerService: - """代理验证调度器""" + """代理验证调度器 + + 职责单一:定时循环,触发 ValidateAllJob。 + 不再直接持有验证队列或 ValidatorService。 + """ def __init__( self, - validation_queue: ValidationQueue, - proxy_repo: ProxyRepository = ProxyRepository(), + executor: JobExecutor, + interval_minutes: int = 30, ): - self.validation_queue = validation_queue - self.proxy_repo = proxy_repo - self.interval_minutes = 30 + self.executor = executor + self.interval_minutes = interval_minutes self.running = False self._stop_event = asyncio.Event() - self._task: asyncio.Task | None = None - self._validate_task: asyncio.Task | None = None + self._task: Optional[asyncio.Task] = None - async def start(self): + async def start(self) -> None: if self.running: logger.warning("Scheduler already running") return self._stop_event.clear() self.running = True - await self.validation_queue.start() self._task = asyncio.create_task(self._run_loop()) logger.info("Scheduler started") - async def stop(self): + async def stop(self) -> None: + if not self.running: + return self.running = False self._stop_event.set() if self._task: @@ -47,72 +46,26 @@ class SchedulerService: except asyncio.CancelledError: pass self._task = None - await self.validation_queue.stop() logger.info("Scheduler stopped") - def cancel_validate_task(self): - """取消正在执行的全量验证后台任务""" - if self._validate_task and not self._validate_task.done(): - self._validate_task.cancel() + def validate_all_now(self) -> str: + """立即执行一次全量验证,返回 Job ID""" + job_id = self.executor.submit_job(ValidateAllJob()) + logger.info(f"ValidateAllJob submitted: {job_id}") + return job_id - async def validate_all_now(self): - """立即执行一次全量验证(后台运行,不阻塞)""" - if self._validate_task and not self._validate_task.done(): - return - self._validate_task = asyncio.create_task(self._do_validate_all(from_loop=False)) - - async def _run_loop(self): + 
async def _run_loop(self) -> None: """定时循环""" while self.running: try: - # 清理过期任务,防止内存无限增长 - task_service.cleanup_old_tasks() - await self._do_validate_all(from_loop=True) + self.executor.submit_job(ValidateAllJob()) except Exception as e: logger.error(f"Scheduler loop error: {e}", exc_info=True) # 等待下一次 try: - await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60) + await asyncio.wait_for( + self._stop_event.wait(), + timeout=self.interval_minutes * 60, + ) except asyncio.TimeoutError: pass - - async def _do_validate_all(self, from_loop: bool = True): - """验证数据库中所有存量代理""" - queue_started_here = False - try: - logger.info("Starting scheduled validation for all proxies") - - # 如果队列未运行,临时启动它(适用于 validate_all_now 在调度器停止时调用) - if not self.validation_queue._running: - await self.validation_queue.start() - queue_started_here = True - - async with get_db() as db: - # 清理 7 天前的验证任务记录,防止表无限增长 - cleaned = await ValidationTaskRepository.cleanup_old(db, days=7) - if cleaned: - logger.info(f"Cleaned up {cleaned} old validation tasks") - proxies = await self.proxy_repo.list_all(db) - if not proxies: - logger.info("No proxies to validate") - return - - logger.info(f"Validating {len(proxies)} proxies from database") - # 批量提交到验证队列,不再阻塞等待 drain - batch_size = 100 - total_batches = (len(proxies) - 1) // batch_size + 1 - for i in range(0, len(proxies), batch_size): - if from_loop and not self.running: - break - batch = proxies[i : i + batch_size] - await self.validation_queue.submit([ - ProxyRaw(p.ip, p.port, p.protocol) for p in batch - ]) - logger.info(f"Submitted batch {i // batch_size + 1}/{total_batches}") - - logger.info("Scheduled validation batches submitted") - except Exception as e: - logger.error(f"Scheduled validation error: {e}", exc_info=True) - finally: - if queue_started_here: - await self.validation_queue.stop() diff --git a/app/services/task_service.py b/app/services/task_service.py deleted file mode 100644 index 7aabfc8..0000000 --- 
a/app/services/task_service.py +++ /dev/null @@ -1,62 +0,0 @@ -"""轻量级异步任务服务 - 内存存储,用于跟踪后台爬取任务""" -import uuid -from datetime import datetime, timedelta -from typing import Optional, Dict, Any -from app.core.log import logger - - -class TaskService: - """异步任务状态管理器""" - - def __init__(self, max_age_seconds: int = 3600): - self._tasks: Dict[str, dict] = {} - self._max_age_seconds = max_age_seconds - - def create_task(self, task_type: str) -> str: - task_id = str(uuid.uuid4()) - now = datetime.now() - self._tasks[task_id] = { - "id": task_id, - "type": task_type, - "status": "pending", - "message": "", - "data": {}, - "created_at": now, - "updated_at": now, - } - return task_id - - def update_task( - self, - task_id: str, - status: Optional[str] = None, - message: Optional[str] = None, - data: Optional[dict] = None, - ) -> bool: - task = self._tasks.get(task_id) - if not task: - return False - if status is not None: - task["status"] = status - if message is not None: - task["message"] = message - if data is not None: - task["data"].update(data) - task["updated_at"] = datetime.now() - return True - - def get_task(self, task_id: str) -> Optional[dict]: - return self._tasks.get(task_id) - - def cleanup_old_tasks(self) -> int: - cutoff = datetime.now() - timedelta(seconds=self._max_age_seconds) - to_remove = [tid for tid, task in self._tasks.items() if task["created_at"] < cutoff] - for tid in to_remove: - del self._tasks[tid] - if to_remove: - logger.info(f"TaskService cleaned up {len(to_remove)} old tasks") - return len(to_remove) - - -# 全局任务服务实例 -task_service = TaskService() diff --git a/app/services/validator_service.py b/app/services/validator_service.py index 4b3cec2..b41f894 100644 --- a/app/services/validator_service.py +++ b/app/services/validator_service.py @@ -4,12 +4,17 @@ import random import time import aiohttp import aiohttp_socks -from typing import Tuple +from typing import Tuple, Optional + +from app.core.config import settings as app_settings from 
app.core.log import logger class ValidatorService: - """代理验证器""" + """代理验证器 + + 支持动态读取配置,实现设置热更新。 + """ # 测试 URL TEST_URLS = { @@ -19,40 +24,66 @@ class ValidatorService: def __init__( self, - timeout: float = 5.0, - connect_timeout: float = 3.0, - max_concurrency: int = 50, + timeout: Optional[float] = None, + connect_timeout: Optional[float] = None, + max_concurrency: Optional[int] = None, ): - self.timeout = timeout - self.connect_timeout = connect_timeout - self.max_concurrency = max_concurrency - self.semaphore = asyncio.Semaphore(max_concurrency) + # 初始化时使用传入值或默认值,但运行期会动态读取 settings + self._init_timeout = timeout or app_settings.validator_timeout + self._init_connect_timeout = connect_timeout or app_settings.validator_connect_timeout + self._init_max_concurrency = max_concurrency or app_settings.validator_max_concurrency - # 共享 HTTP/HTTPS ClientSession - self._http_connector = aiohttp.TCPConnector( - ssl=False, - limit=max_concurrency, - limit_per_host=max_concurrency, - force_close=False, - ) - self._timeout = aiohttp.ClientTimeout( - total=timeout, connect=connect_timeout - ) - self._http_session = aiohttp.ClientSession( - connector=self._http_connector, - timeout=self._timeout, - ) + self._http_connector: Optional[aiohttp.TCPConnector] = None + self._http_session: Optional[aiohttp.ClientSession] = None + self._semaphore: Optional[asyncio.Semaphore] = None + self._lock = asyncio.Lock() + + @property + def timeout(self) -> float: + return float(self._init_timeout) + + @property + def connect_timeout(self) -> float: + return float(self._init_connect_timeout) + + @property + def max_concurrency(self) -> int: + return int(self._init_max_concurrency) + + def _ensure_session(self) -> aiohttp.ClientSession: + """懒加载共享 HTTP session""" + if self._http_session is None or self._http_session.closed: + connector = aiohttp.TCPConnector( + ssl=False, + limit=self.max_concurrency, + limit_per_host=self.max_concurrency, + force_close=False, + ) + timeout = 
aiohttp.ClientTimeout( + total=self.timeout, connect=self.connect_timeout + ) + self._http_connector = connector + self._http_session = aiohttp.ClientSession( + connector=connector, + timeout=timeout, + ) + return self._http_session + + def _ensure_semaphore(self) -> asyncio.Semaphore: + if self._semaphore is None: + self._semaphore = asyncio.Semaphore(self.max_concurrency) + return self._semaphore def _get_test_url(self, protocol: str) -> str: - """获取测试 URL""" urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"]) return random.choice(urls) async def validate(self, ip: str, port: int, protocol: str = "http") -> Tuple[bool, float]: """验证单个代理,返回 (是否有效, 延迟毫秒)""" protocol = protocol.lower() - - async with self.semaphore: + semaphore = self._ensure_semaphore() + + async with semaphore: start = time.time() try: if protocol in ("socks4", "socks5"): @@ -67,11 +98,11 @@ class ValidatorService: return False, 0.0 async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]: - """验证 HTTP/HTTPS 代理""" proxy_url = f"http://{ip}:{port}" test_url = self._get_test_url(protocol) + session = self._ensure_session() - async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response: + async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response: if response.status in (200, 301, 302): latency = round((time.time() - start) * 1000, 2) logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms") @@ -79,7 +110,6 @@ class ValidatorService: return False, 0.0 async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]: - """验证 SOCKS4/SOCKS5 代理""" proxy_type = ( aiohttp_socks.ProxyType.SOCKS4 if protocol == "socks4" @@ -103,7 +133,9 @@ class ValidatorService: return True, latency return False, 0.0 - async def close(self): + async def close(self) -> None: """关闭共享的 HTTP ClientSession""" if self._http_session and not self._http_session.closed: 
await self._http_session.close() + self._http_session = None + self._http_connector = None diff --git a/tests/conftest.py b/tests/conftest.py index 4e0e620..eebd464 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,17 @@ from httpx import AsyncClient, ASGITransport from app.api import create_app from app.core.db import init_db, get_db +from app.core.plugin_system.registry import registry +from app.plugins import ( + Fate0Plugin, + ProxyListDownloadPlugin, + Ip3366Plugin, + Ip89Plugin, + KuaiDaiLiPlugin, + SpeedXPlugin, + YunDaiLiPlugin, + ProxyScrapePlugin, +) from app.repositories.proxy_repo import ProxyRepository from app.models.domain import ProxyRaw @@ -14,23 +25,32 @@ from app.models.domain import ProxyRaw @pytest_asyncio.fixture(scope="function") async def app(): """创建应用实例""" - # 初始化测试数据库并清空历史数据,避免任务残留或设置状态导致 drain() 卡住 + # 初始化测试数据库并清空历史数据 await init_db() async with get_db() as db: - await db.execute("DELETE FROM validation_tasks") await db.execute("DELETE FROM proxies") await db.execute("DELETE FROM settings") await db.commit() - # 清理全局内存状态,防止跨测试污染 - from app.services.task_service import task_service - task_service._tasks.clear() + # 清理并重新注册插件,防止跨测试污染 + registry.clear() + for plugin_cls in [ + Fate0Plugin, + ProxyListDownloadPlugin, + Ip3366Plugin, + Ip89Plugin, + KuaiDaiLiPlugin, + SpeedXPlugin, + YunDaiLiPlugin, + ProxyScrapePlugin, + ]: + registry.register(plugin_cls) test_app = create_app() async with test_app.router.lifespan_context(test_app): yield test_app - # 给 aiosqlite / aiohttp 后台线程留出收尾时间,降低 Event loop closed 警告概率 + # 给 aiosqlite / aiohttp 后台线程留出收尾时间 await asyncio.sleep(0.1) @@ -71,18 +91,19 @@ async def mock_external_requests(monkeypatch): 1. 插件爬取返回固定测试代理,避免真实 HTTP 请求 2. 
代理验证瞬间成功,避免连接超时等待 """ - from app.services.plugin_service import PluginService + from app.services.plugin_runner import PluginRunner from app.services.validator_service import ValidatorService - async def _mock_run_plugin(self, plugin_id: str): - return [ProxyRaw("192.168.100.10", 8080, "http")] - - async def _mock_run_all_plugins(self): - return [ProxyRaw("192.168.100.10", 8080, "http")] + async def _mock_run(self, plugin): + from app.models.domain import CrawlResult + return CrawlResult( + plugin_name=plugin.name, + proxies=[ProxyRaw("192.168.100.10", 8080, "http")], + success_count=1, + ) async def _mock_validate(self, ip: str, port: int, protocol: str = "http"): return True, 1.23 - monkeypatch.setattr(PluginService, "run_plugin", _mock_run_plugin) - monkeypatch.setattr(PluginService, "run_all_plugins", _mock_run_all_plugins) + monkeypatch.setattr(PluginRunner, "run", _mock_run) monkeypatch.setattr(ValidatorService, "validate", _mock_validate) diff --git a/tests/integration/test_plugins_api.py b/tests/integration/test_plugins_api.py index ed3d562..ee7855d 100644 --- a/tests/integration/test_plugins_api.py +++ b/tests/integration/test_plugins_api.py @@ -69,7 +69,8 @@ class TestPluginsAPI: plugin_id = plugins[0]["id"] response = await client.put(f"/api/plugins/{plugin_id}/toggle", json={}) - assert response.status_code == 400 + # Pydantic 验证缺失必填字段返回 422 + assert response.status_code == 422 @pytest.mark.asyncio async def test_toggle_nonexistent_plugin(self, client):