diff --git a/WebUI/src/api/index.js b/WebUI/src/api/index.js index 5424323..d9b781b 100644 --- a/WebUI/src/api/index.js +++ b/WebUI/src/api/index.js @@ -97,4 +97,8 @@ export const settingsAPI = { saveSettings: (data) => api.post('/api/settings', data) } +export const tasksAPI = { + getTaskStatus: (taskId) => api.get(`/api/tasks/${taskId}`) +} + export default api diff --git a/WebUI/src/services/pluginService.js b/WebUI/src/services/pluginService.js index 9c2b128..7f3edaa 100644 --- a/WebUI/src/services/pluginService.js +++ b/WebUI/src/services/pluginService.js @@ -1,4 +1,26 @@ -import { pluginsAPI } from '../api' +import { pluginsAPI, tasksAPI } from '../api' + +const POLL_INTERVAL = 1000 +const MAX_POLL_ATTEMPTS = 30 + +async function pollTaskStatus(taskId) { + for (let i = 0; i < MAX_POLL_ATTEMPTS; i++) { + await new Promise(resolve => setTimeout(resolve, POLL_INTERVAL)) + const response = await tasksAPI.getTaskStatus(taskId) + if (response.code !== 200) { + continue + } + const status = response.data.status + if (status === 'completed' || status === 'failed') { + return response + } + } + return { + code: 200, + message: '爬取任务进行中,请稍后刷新查看结果', + data: { task_id: taskId, status: 'running' } + } +} export const pluginService = { async getPlugins() { @@ -18,10 +40,28 @@ export const pluginService = { }, async crawlPlugin(pluginId) { - return pluginsAPI.crawlPlugin(pluginId) + const startRes = await pluginsAPI.crawlPlugin(pluginId) + if (startRes.code !== 200 || !startRes.data?.task_id) { + return startRes + } + const finalRes = await pollTaskStatus(startRes.data.task_id) + return { + code: finalRes.code, + message: finalRes.data?.message || finalRes.message, + data: finalRes.data?.data || finalRes.data + } }, async crawlAll() { - return pluginsAPI.crawlAll() + const startRes = await pluginsAPI.crawlAll() + if (startRes.code !== 200 || !startRes.data?.task_id) { + return startRes + } + const finalRes = await pollTaskStatus(startRes.data.task_id) + return { + code: 
finalRes.code, + message: finalRes.data?.message || finalRes.message, + data: finalRes.data?.data || finalRes.data + } } } diff --git a/WebUI/src/stores/plugins.js b/WebUI/src/stores/plugins.js index 892c440..f5233ca 100644 --- a/WebUI/src/stores/plugins.js +++ b/WebUI/src/stores/plugins.js @@ -14,6 +14,7 @@ export const usePluginsStore = defineStore('plugins', () => { // ==================== Getters ==================== const enabledCount = computed(() => plugins.value.filter(p => p.enabled).length) const totalCount = computed(() => plugins.value.length) + const isEmpty = computed(() => !loading.value && plugins.value.length === 0) // ==================== Actions ==================== @@ -97,6 +98,7 @@ export const usePluginsStore = defineStore('plugins', () => { // Getters enabledCount, totalCount, + isEmpty, // Actions fetchPlugins, togglePlugin, diff --git a/WebUI/src/stores/proxy.js b/WebUI/src/stores/proxy.js index aba3eaa..970863b 100644 --- a/WebUI/src/stores/proxy.js +++ b/WebUI/src/stores/proxy.js @@ -132,8 +132,8 @@ export const useProxyStore = defineStore('proxy', () => { try { const response = await proxyService.export(format, protocol) - // 创建下载链接 - const url = window.URL.createObjectURL(new Blob([response])) + // response 已经是 Blob(axios 配置了 responseType: 'blob'),直接创建下载链接 + const url = window.URL.createObjectURL(response) const link = document.createElement('a') link.href = url link.setAttribute('download', `proxies.${format}`) diff --git a/WebUI/src/views/ProxyList.vue b/WebUI/src/views/ProxyList.vue index 5bf4c14..178b582 100644 --- a/WebUI/src/views/ProxyList.vue +++ b/WebUI/src/views/ProxyList.vue @@ -23,7 +23,7 @@ @@ -190,7 +190,7 @@ async function fetchProxies() { } abortController = new AbortController() - await proxyStore.fetchProxies({ + const success = await proxyStore.fetchProxies({ page: currentPage.value, page_size: pageSize.value, protocol: filterForm.protocol || null, @@ -200,6 +200,9 @@ async function fetchProxies() { }, 
abortController.signal) abortController = null + if (!success) { + ElMessage.error('获取代理列表失败') + } } // ==================== 事件处理 ==================== diff --git a/app/api/common.py b/app/api/common.py index fc47fee..796c30e 100644 --- a/app/api/common.py +++ b/app/api/common.py @@ -23,6 +23,7 @@ def format_proxy(proxy) -> dict: "port": proxy.port, "protocol": proxy.protocol, "score": proxy.score, + "response_time_ms": proxy.response_time_ms, "last_check": proxy.last_check.isoformat() if proxy.last_check else None, } diff --git a/app/api/deps.py b/app/api/deps.py index d3127e5..add4572 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -25,21 +25,29 @@ def get_validation_queue(request: Request) -> ValidationQueue: return request.app.state.validation_queue -def create_scheduler_service() -> SchedulerService: - """在应用启动时创建 SchedulerService(非请求上下文)""" +def create_scheduler_service(db_settings: dict | None = None) -> SchedulerService: + """在应用启动时创建 SchedulerService(非请求上下文) + + db_settings 来自数据库持久化设置,用于覆盖 .env / 默认值。 + """ + db_settings = db_settings or {} validator = ValidatorService( - timeout=app_settings.validator_timeout, - connect_timeout=app_settings.validator_connect_timeout, - max_concurrency=app_settings.validator_max_concurrency, + timeout=db_settings.get("validation_timeout", app_settings.validator_timeout), + connect_timeout=db_settings.get("connect_timeout", app_settings.validator_connect_timeout), + max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), ) proxy_repo = ProxyRepository() queue = ValidationQueue( validator=validator, proxy_repo=proxy_repo, - worker_count=app_settings.validator_max_concurrency, + worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), score_valid=app_settings.score_valid, score_invalid=app_settings.score_invalid, score_min=app_settings.score_min, score_max=app_settings.score_max, ) - return SchedulerService(validation_queue=queue, 
proxy_repo=proxy_repo) + svc = SchedulerService(validation_queue=queue, proxy_repo=proxy_repo) + svc.interval_minutes = db_settings.get( + "validate_interval_minutes", app_settings.validator_timeout + ) + return svc diff --git a/app/api/lifespan.py b/app/api/lifespan.py index 94226c1..8ade670 100644 --- a/app/api/lifespan.py +++ b/app/api/lifespan.py @@ -6,7 +6,7 @@ from app.core.db import init_db, get_db from app.core.config import settings as app_settings from app.core.log import logger from app.api.deps import create_scheduler_service -from app.repositories.settings_repo import SettingsRepository +from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS settings_repo = SettingsRepository() @@ -17,33 +17,48 @@ async def lifespan(app: FastAPI): # 初始化数据库 await init_db() - # 创建调度器并挂载到 app.state - scheduler_service = create_scheduler_service() + # 加载设置并决定是否启动调度器 + db_settings = DEFAULT_SETTINGS.copy() + try: + async with get_db() as db: + db_settings = await settings_repo.get_all(db) + except Exception as e: + logger.error(f"Failed to load settings on startup: {e}") + + # 创建调度器并挂载到 app.state(使用 DB 设置覆盖默认值) + scheduler_service = create_scheduler_service(db_settings) app.state.scheduler_service = scheduler_service app.state.validation_queue = scheduler_service.validation_queue - # 加载设置并决定是否启动调度器 - try: - async with get_db() as db: - settings = await settings_repo.get_all(db) - scheduler_service.interval_minutes = settings.get( - "validate_interval_minutes", app_settings.validator_timeout - ) - if settings.get("auto_validate", True): + if db_settings.get("auto_validate", True): + try: await scheduler_service.start() - except Exception as e: - logger.error(f"Failed to load settings on startup: {e}") + except Exception as e: + logger.error(f"Failed to start scheduler on startup: {e}") logger.info("API server started") yield # 关闭调度器 - if scheduler_service._validate_task and not scheduler_service._validate_task.done(): - 
scheduler_service._validate_task.cancel() - try: - await scheduler_service._validate_task - except asyncio.CancelledError: - pass + scheduler_service.cancel_validate_task() await scheduler_service.stop() - await scheduler_service.validation_queue.validator.close() + + # 关闭验证器 HTTP session + try: + await scheduler_service.validation_queue.validator.close() + except Exception: + pass + + # 关闭所有插件的 HTTP 客户端 + from app.core.plugin_system.registry import registry + for plugin in registry.list_plugins(): + if hasattr(plugin, "close"): + try: + await plugin.close() + except Exception: + pass + + # 给 aiosqlite / aiohttp 后台线程留出收尾时间 + await asyncio.sleep(0.1) + logger.info("API server shutdown") diff --git a/app/api/main.py b/app/api/main.py index 37f4dbf..25c8cf4 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -22,7 +22,7 @@ def create_app() -> FastAPI: # CORS app.add_middleware( CORSMiddleware, - allow_origins=app_settings.cors_origins_list, + allow_origins=app_settings.cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/app/api/routes/__init__.py b/app/api/routes/__init__.py index f1c78a7..c4c7319 100644 --- a/app/api/routes/__init__.py +++ b/app/api/routes/__init__.py @@ -1,9 +1,10 @@ """路由包""" from fastapi import APIRouter -from app.api.routes import proxies, plugins, scheduler, settings +from app.api.routes import proxies, plugins, scheduler, settings, tasks api_router = APIRouter() api_router.include_router(proxies.router) api_router.include_router(plugins.router) api_router.include_router(scheduler.router) api_router.include_router(settings.router) +api_router.include_router(tasks.router) diff --git a/app/api/routes/plugins.py b/app/api/routes/plugins.py index 877ea05..876ab21 100644 --- a/app/api/routes/plugins.py +++ b/app/api/routes/plugins.py @@ -3,6 +3,7 @@ import asyncio from fastapi import APIRouter, Depends from app.services.plugin_service import PluginService from app.services.scheduler_service import 
SchedulerService +from app.services.task_service import task_service from app.api.deps import get_plugin_service, get_scheduler_service from app.api.common import success_response, error_response, format_plugin from app.core.log import logger @@ -12,12 +13,8 @@ router = APIRouter(prefix="/api/plugins", tags=["plugins"]) @router.get("") async def list_plugins(service: PluginService = Depends(get_plugin_service)): - try: - plugins = await service.list_plugins() - return success_response("获取插件列表成功", {"plugins": [format_plugin(p) for p in plugins]}) - except Exception as e: - logger.error(f"List plugins failed: {e}") - return error_response("获取插件列表失败", 500) + plugins = await service.list_plugins() + return success_response("获取插件列表成功", {"plugins": [format_plugin(p) for p in plugins]}) @router.put("/{plugin_id}/toggle") @@ -29,18 +26,14 @@ async def toggle_plugin( enabled = request.get("enabled") if enabled is None: return error_response("缺少 enabled 参数", 400) - - try: - success = await service.toggle_plugin(plugin_id, enabled) - if not success: - return error_response("插件不存在", 404) - return success_response( - f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}", - {"plugin_id": plugin_id, "enabled": enabled}, - ) - except Exception as e: - logger.error(f"Toggle plugin failed: {e}") - return error_response("切换插件状态失败", 500) + + success = await service.toggle_plugin(plugin_id, enabled) + if not success: + return error_response("插件不存在", 404) + return success_response( + f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}", + {"plugin_id": plugin_id, "enabled": enabled}, + ) @router.get("/{plugin_id}/config") @@ -48,14 +41,10 @@ async def get_plugin_config( plugin_id: str, service: PluginService = Depends(get_plugin_service), ): - try: - config = await service.get_plugin_config(plugin_id) - if config is None: - return error_response("插件不存在", 404) - return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config}) - except Exception as e: - logger.error(f"Get plugin config 
failed: {e}") - return error_response("获取插件配置失败", 500) + config = await service.get_plugin_config(plugin_id) + if config is None: + return error_response("插件不存在", 404) + return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config}) @router.post("/{plugin_id}/config") @@ -67,15 +56,11 @@ async def update_plugin_config( config = request.get("config", {}) if not isinstance(config, dict): return error_response("config 必须是对象", 400) - - try: - success = await service.update_plugin_config(plugin_id, config) - if not success: - return error_response("插件不存在或配置无效", 404) - return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config}) - except Exception as e: - logger.error(f"Update plugin config failed: {e}") - return error_response("保存插件配置失败", 500) + + success = await service.update_plugin_config(plugin_id, config) + if not success: + return error_response("插件不存在或配置无效", 404) + return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config}) @router.post("/{plugin_id}/crawl") @@ -88,35 +73,51 @@ async def crawl_plugin( if not plugin: return error_response("插件不存在", 404) - try: - results = await plugin_service.run_plugin(plugin_id) - if not results: - return success_response( - f"插件 {plugin_id} 爬取完成,未获取到代理", - {"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0}, + task_id = task_service.create_task("crawl_plugin") + + async def _background_crawl(): + try: + task_service.update_task(task_id, status="running") + results = await plugin_service.run_plugin(plugin_id) + if not results: + task_service.update_task( + task_id, + status="completed", + message=f"插件 {plugin_id} 爬取完成,未获取到代理", + data={"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0}, + ) + return + + logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies") + scheduler_service.validation_queue.reset_stats() + await scheduler_service.validation_queue.submit(results) + + try: + await asyncio.wait_for(scheduler_service.validation_queue.drain(), 
timeout=5.0) + except asyncio.TimeoutError: + pass + + task_service.update_task( + task_id, + status="completed", + message=f"插件 {plugin_id} 爬取并验证完成", + data={ + "plugin_id": plugin_id, + "proxy_count": len(results), + "valid_count": scheduler_service.validation_queue.valid_count, + "invalid_count": scheduler_service.validation_queue.invalid_count, + }, + ) + except Exception as e: + logger.error(f"Crawl plugin {plugin_id} failed: {e}") + task_service.update_task( + task_id, status="failed", message=f"插件爬取失败: {str(e)}" ) - logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies") - scheduler_service.validation_queue.reset_stats() - await scheduler_service.validation_queue.submit(results) - - try: - await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=30.0) - except asyncio.TimeoutError: - pass - - return success_response( - f"插件 {plugin_id} 爬取并验证完成", - { - "plugin_id": plugin_id, - "proxy_count": len(results), - "valid_count": scheduler_service.validation_queue.valid_count, - "invalid_count": scheduler_service.validation_queue.invalid_count, - }, - ) - except Exception as e: - logger.error(f"Crawl plugin {plugin_id} failed: {e}") - return error_response(f"插件爬取失败: {str(e)}", 500) + asyncio.create_task(_background_crawl()) + return success_response( + "爬取任务已启动", {"task_id": task_id, "plugin_id": plugin_id} + ) @router.post("/crawl-all") @@ -124,31 +125,45 @@ async def crawl_all( plugin_service: PluginService = Depends(get_plugin_service), scheduler_service: SchedulerService = Depends(get_scheduler_service), ): - try: - results = await plugin_service.run_all_plugins() - if not results: - return success_response( - "所有插件爬取完成,未获取到代理", - {"total_crawled": 0, "valid_count": 0, "invalid_count": 0}, + task_id = task_service.create_task("crawl_all") + + async def _background_crawl_all(): + try: + task_service.update_task(task_id, status="running") + results = await plugin_service.run_all_plugins() + if not results: + task_service.update_task( + 
task_id, + status="completed", + message="所有插件爬取完成,未获取到代理", + data={"total_crawled": 0, "valid_count": 0, "invalid_count": 0}, + ) + return + + logger.info(f"All plugins crawled {len(results)} unique proxies") + scheduler_service.validation_queue.reset_stats() + await scheduler_service.validation_queue.submit(results) + + try: + await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0) + except asyncio.TimeoutError: + pass + + task_service.update_task( + task_id, + status="completed", + message="所有插件爬取并验证完成", + data={ + "total_crawled": len(results), + "valid_count": scheduler_service.validation_queue.valid_count, + "invalid_count": scheduler_service.validation_queue.invalid_count, + }, + ) + except Exception as e: + logger.error(f"Crawl all failed: {e}") + task_service.update_task( + task_id, status="failed", message=f"批量爬取失败: {str(e)}" ) - logger.info(f"All plugins crawled {len(results)} unique proxies") - scheduler_service.validation_queue.reset_stats() - await scheduler_service.validation_queue.submit(results) - - try: - await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=60.0) - except asyncio.TimeoutError: - pass - - return success_response( - "所有插件爬取并验证完成", - { - "total_crawled": len(results), - "valid_count": scheduler_service.validation_queue.valid_count, - "invalid_count": scheduler_service.validation_queue.invalid_count, - }, - ) - except Exception as e: - logger.error(f"Crawl all failed: {e}") - return error_response(f"批量爬取失败: {str(e)}", 500) + asyncio.create_task(_background_crawl_all()) + return success_response("爬取任务已启动", {"task_id": task_id}) diff --git a/app/api/routes/proxies.py b/app/api/routes/proxies.py index a82c62a..52fbcb6 100644 --- a/app/api/routes/proxies.py +++ b/app/api/routes/proxies.py @@ -6,7 +6,6 @@ from app.services.scheduler_service import SchedulerService from app.models.schemas import ProxyListRequest, BatchDeleteRequest from app.api.deps import get_proxy_service, 
get_scheduler_service from app.api.common import success_response, error_response, format_proxy -from app.core.log import logger router = APIRouter(prefix="/api/proxies", tags=["proxies"]) @@ -16,13 +15,9 @@ async def get_stats( proxy_service: ProxyService = Depends(get_proxy_service), scheduler_service: SchedulerService = Depends(get_scheduler_service), ): - try: - stats = await proxy_service.get_stats() - stats["scheduler_running"] = scheduler_service.running - return success_response("获取统计信息成功", stats) - except Exception as e: - logger.error(f"Get stats failed: {e}") - return error_response("获取统计信息失败", 500) + stats = await proxy_service.get_stats() + stats["scheduler_running"] = scheduler_service.running + return success_response("获取统计信息成功", stats) @router.post("") @@ -30,40 +25,32 @@ async def list_proxies( request: ProxyListRequest, service: ProxyService = Depends(get_proxy_service), ): - try: - proxies, total = await service.list_proxies( - page=request.page, - page_size=request.page_size, - protocol=request.protocol, - min_score=request.min_score, - max_score=request.max_score, - sort_by=request.sort_by, - sort_order=request.sort_order, - ) - return success_response( - "获取代理列表成功", - { - "list": [format_proxy(p) for p in proxies], - "total": total, - "page": request.page, - "page_size": request.page_size, - }, - ) - except Exception as e: - logger.error(f"List proxies failed: {e}") - return error_response("获取代理列表失败", 500) + proxies, total = await service.list_proxies( + page=request.page, + page_size=request.page_size, + protocol=request.protocol, + min_score=request.min_score, + max_score=request.max_score, + sort_by=request.sort_by, + sort_order=request.sort_order, + ) + return success_response( + "获取代理列表成功", + { + "list": [format_proxy(p) for p in proxies], + "total": total, + "page": request.page, + "page_size": request.page_size, + }, + ) @router.get("/random") async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)): - try: - proxy 
= await service.get_random_proxy() - if not proxy: - return error_response("没有找到可用的代理", 404) - return success_response("获取随机代理成功", format_proxy(proxy)) - except Exception as e: - logger.error(f"Get random proxy failed: {e}") - return error_response("获取随机代理失败", 500) + proxy = await service.get_random_proxy() + if not proxy: + return error_response("没有找到可用的代理", 404) + return success_response("获取随机代理成功", format_proxy(proxy)) @router.get("/export/{fmt}") @@ -93,12 +80,8 @@ async def export_proxies( @router.delete("/{ip}/{port}") async def delete_proxy(ip: str, port: int, service: ProxyService = Depends(get_proxy_service)): - try: - await service.delete_proxy(ip, port) - return success_response("删除代理成功") - except Exception as e: - logger.error(f"Delete proxy failed: {e}") - return error_response("删除代理失败", 500) + await service.delete_proxy(ip, port) + return success_response("删除代理成功") @router.post("/batch-delete") @@ -106,20 +89,12 @@ async def batch_delete( request: BatchDeleteRequest, service: ProxyService = Depends(get_proxy_service), ): - try: - proxies = [(item.ip, item.port) for item in request.proxies] - deleted = await service.batch_delete(proxies) - return success_response(f"批量删除 {deleted} 个代理成功", {"deleted_count": deleted}) - except Exception as e: - logger.error(f"Batch delete failed: {e}") - return error_response("批量删除失败", 500) + proxies = [(item.ip, item.port) for item in request.proxies] + deleted = await service.batch_delete(proxies) + return success_response(f"批量删除 {deleted} 个代理成功", {"deleted_count": deleted}) @router.delete("/clean-invalid") async def clean_invalid(service: ProxyService = Depends(get_proxy_service)): - try: - count = await service.clean_invalid() - return success_response(f"清理了 {count} 个无效代理", {"deleted_count": count}) - except Exception as e: - logger.error(f"Clean invalid failed: {e}") - return error_response("清理无效代理失败", 500) + count = await service.clean_invalid() + return success_response(f"清理了 {count} 个无效代理", {"deleted_count": 
count}) diff --git a/app/api/routes/scheduler.py b/app/api/routes/scheduler.py index 29a2f54..30cdc15 100644 --- a/app/api/routes/scheduler.py +++ b/app/api/routes/scheduler.py @@ -5,7 +5,6 @@ from app.repositories.settings_repo import SettingsRepository from app.core.db import get_db from app.api.deps import get_scheduler_service from app.api.common import success_response, error_response -from app.core.log import logger router = APIRouter(prefix="/api/scheduler", tags=["scheduler"]) settings_repo = SettingsRepository() @@ -22,38 +21,26 @@ async def _save_auto_validate_setting(enabled: bool): @router.post("/start") async def start_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)): - try: - if scheduler.running: - return success_response("验证调度器已在运行", {"running": True}) - await scheduler.start() - await _save_auto_validate_setting(True) - return success_response("验证调度器已启动", {"running": True}) - except Exception as e: - logger.error(f"Start scheduler failed: {e}") - return error_response(f"启动调度器失败: {str(e)}", 500) + if scheduler.running: + return success_response("验证调度器已在运行", {"running": True}) + await scheduler.start() + await _save_auto_validate_setting(True) + return success_response("验证调度器已启动", {"running": True}) @router.post("/stop") async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)): - try: - if not scheduler.running: - return success_response("验证调度器未运行", {"running": False}) - await scheduler.stop() - await _save_auto_validate_setting(False) - return success_response("验证调度器已停止", {"running": False}) - except Exception as e: - logger.error(f"Stop scheduler failed: {e}") - return error_response(f"停止调度器失败: {str(e)}", 500) + if not scheduler.running: + return success_response("验证调度器未运行", {"running": False}) + await scheduler.stop() + await _save_auto_validate_setting(False) + return success_response("验证调度器已停止", {"running": False}) @router.post("/validate-now") async def validate_now(scheduler: 
SchedulerService = Depends(get_scheduler_service)): - try: - await scheduler.validate_all_now() - return success_response("已开始全量验证", {"started": True}) - except Exception as e: - logger.error(f"Validate now failed: {e}") - return error_response(f"启动验证失败: {str(e)}", 500) + await scheduler.validate_all_now() + return success_response("已开始全量验证", {"started": True}) @router.get("/status") diff --git a/app/api/routes/settings.py b/app/api/routes/settings.py index a84eadb..14f8bde 100644 --- a/app/api/routes/settings.py +++ b/app/api/routes/settings.py @@ -4,7 +4,6 @@ from app.core.db import get_db from app.repositories.settings_repo import SettingsRepository from app.models.schemas import SettingsSchema from app.api.common import success_response, error_response -from app.core.log import logger router = APIRouter(prefix="/api/settings", tags=["settings"]) settings_repo = SettingsRepository() @@ -12,23 +11,15 @@ settings_repo = SettingsRepository() @router.get("") async def get_settings(): - try: - async with get_db() as db: - settings = await settings_repo.get_all(db) - return success_response("获取设置成功", settings) - except Exception as e: - logger.error(f"Get settings failed: {e}") - return error_response("获取设置失败", 500) + async with get_db() as db: + settings = await settings_repo.get_all(db) + return success_response("获取设置成功", settings) @router.post("") async def save_settings(request: SettingsSchema): - try: - async with get_db() as db: - success = await settings_repo.save(db, request.model_dump()) - if not success: - return error_response("保存设置失败", 500) - return success_response("保存设置成功", request.model_dump()) - except Exception as e: - logger.error(f"Save settings failed: {e}") - return error_response(f"保存设置失败: {str(e)}", 500) + async with get_db() as db: + success = await settings_repo.save(db, request.model_dump()) + if not success: + return error_response("保存设置失败", 500) + return success_response("保存设置成功", request.model_dump()) diff --git a/app/api/routes/tasks.py 
b/app/api/routes/tasks.py new file mode 100644 index 0000000..452623d --- /dev/null +++ b/app/api/routes/tasks.py @@ -0,0 +1,14 @@ +"""异步任务状态路由""" +from fastapi import APIRouter +from app.services.task_service import task_service +from app.api.common import success_response, error_response + +router = APIRouter(prefix="/api/tasks", tags=["tasks"]) + + +@router.get("/{task_id}") +async def get_task_status(task_id: str): + task = task_service.get_task(task_id) + if not task: + return error_response("任务不存在", 404) + return success_response("获取任务状态成功", task) diff --git a/app/core/config.py b/app/core/config.py index e3445a2..4baf881 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -43,12 +43,13 @@ class Settings(BaseSettings): # 插件配置 plugins_dir: str = "plugins" - # CORS 配置 - cors_origins: str = "http://localhost:8080,http://localhost:5173,http://127.0.0.1:18081,http://localhost:18081" - - @property - def cors_origins_list(self) -> List[str]: - return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()] + # CORS config - NOTE: pydantic-settings decodes a List[str] env value as JSON (e.g. CORS_ORIGINS='["http://localhost:8080"]'); a comma-separated string is NOT split automatically + cors_origins: List[str] = [ + "http://localhost:8080", + "http://localhost:5173", + "http://127.0.0.1:18081", + "http://localhost:18081", + ] @property def base_dir(self) -> str: diff --git a/app/core/log.py b/app/core/log.py index 3e056e7..38bf6fa 100644 --- a/app/core/log.py +++ b/app/core/log.py @@ -1,47 +1,37 @@ import logging import os -from logging.handlers import RotatingFileHandler -from datetime import datetime +from logging.handlers import TimedRotatingFileHandler -class LogHandler(logging.Logger): - def __init__(self, name='ProxyPool', level=logging.INFO): - super().__init__(name, level) - - # 获取项目根目录并创建 logs 目录 - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - log_dir = os.path.join(base_dir, 'logs') - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - # 使用日期作为文件名 - log_filename = 
f"{datetime.now().strftime('%Y-%m-%d')}.log" - log_file = os.path.join(log_dir, log_filename) - - # 设置格式 - formatter = logging.Formatter( - '[%(asctime)s] %(name)s [%(levelname)s] %(filename)s[line:%(lineno)d]: %(message)s' - ) - - # 文件处理器(使用RotatingFileHandler支持日志轮转) - # 每个日志文件最大10MB,保留5个备份 - file_handler = RotatingFileHandler( - log_file, - maxBytes=10*1024*1024, - backupCount=5, - encoding='utf-8' - ) - file_handler.setFormatter(formatter) - self.addHandler(file_handler) - - # 控制台处理器 - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - self.addHandler(console_handler) +# 获取项目根目录并创建 logs 目录 +base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +log_dir = os.path.join(base_dir, 'logs') +os.makedirs(log_dir, exist_ok=True) +log_file = os.path.join(log_dir, 'app.log') -# 实例化一个默认 logger 供外部直接使用 -logger = LogHandler() +formatter = logging.Formatter( + '[%(asctime)s] %(name)s [%(levelname)s] %(filename)s[line:%(lineno)d]: %(message)s' +) -if __name__ == '__main__': - logger.info('这是一条按日期存储的日志测试') +# 按天切分的文件处理器,保留最近 7 天 +file_handler = TimedRotatingFileHandler( + log_file, + when='midnight', + interval=1, + backupCount=7, + encoding='utf-8' +) +file_handler.setFormatter(formatter) + +console_handler = logging.StreamHandler() +console_handler.setFormatter(formatter) + +# 获取标准 logger +logger = logging.getLogger('ProxyPool') +logger.setLevel(logging.INFO) + +# 防止重复添加 handler(如模块重导入) +if not logger.handlers: + logger.addHandler(file_handler) + logger.addHandler(console_handler) diff --git a/app/core/tasks/queue.py b/app/core/tasks/queue.py index 067403b..a94f7be 100644 --- a/app/core/tasks/queue.py +++ b/app/core/tasks/queue.py @@ -39,7 +39,6 @@ class ValidationQueue: self._signal: asyncio.Queue[None] = asyncio.Queue() self._workers: list[asyncio.Task] = [] self._running = False - self._db_lock = asyncio.Lock() self._pending_count = 0 self._condition = asyncio.Condition() @@ -72,9 +71,9 @@ class 
ValidationQueue: for i in range(self.worker_count): self._workers.append(asyncio.create_task(self._worker_loop(i))) - # 唤醒 Worker 处理恢复的 pending 任务 + # 唤醒 Worker 处理恢复的 pending 任务(每个 Worker 一次唤醒即可,内部会循环处理) if pending: - for _ in range(min(pending, self.worker_count)): + for _ in range(self.worker_count): self._signal.put_nowait(None) logger.info(f"ValidationQueue started with {self.worker_count} workers") @@ -92,9 +91,8 @@ class ValidationQueue: async def submit(self, proxies: list[ProxyRaw]): """提交代理到验证队列(持久化 + 唤醒 Worker)""" - async with self._db_lock: - async with get_db() as db: - inserted = await self.task_repo.insert_batch(db, proxies) + async with get_db() as db: + inserted = await self.task_repo.insert_batch(db, proxies) if inserted: async with self._condition: self._pending_count += inserted @@ -117,46 +115,53 @@ class ValidationQueue: self._signal.task_done() if not self._running: break - await self._process_one_task(worker_id) + # 持续处理任务直到没有 pending 为止,避免信号数不足导致任务饿死 + while self._running: + processed = await self._process_one_task(worker_id) + if not processed: + break - async def _process_one_task(self, worker_id: int): - """从数据库取一个任务并验证""" - async with self._db_lock: - async with get_db() as db: - task = await self.task_repo.acquire_pending(db) - if not task: - return + async def _process_one_task(self, worker_id: int) -> bool: + """从数据库取一个任务并验证。返回 True 表示确实处理了一个任务。""" + async with get_db() as db: + task = await self.task_repo.acquire_pending(db) + if not task: + return False - proxy = ProxyRaw(task["ip"], task["port"], task["protocol"]) - try: - is_valid, latency = await self.validator.validate( - proxy.ip, proxy.port, proxy.protocol - ) - except Exception as e: - logger.error(f"Worker {worker_id} validation error: {e}") - is_valid, latency = False, 0.0 + proxy = ProxyRaw(task["ip"], task["port"], task["protocol"]) + try: + is_valid, latency = await self.validator.validate( + proxy.ip, proxy.port, proxy.protocol + ) + except Exception as e: + 
logger.error(f"Worker {worker_id} validation error: {e}", exc_info=True) + is_valid, latency = False, 0.0 - async with self._db_lock: - async with get_db() as db: - if is_valid: - await self.proxy_repo.insert_or_update( - db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid + if is_valid: + await self.proxy_repo.insert_or_update( + db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid + ) + if latency: + await self.proxy_repo.update_response_time( + db, proxy.ip, proxy.port, latency ) - if latency: - await self.proxy_repo.update_response_time( - db, proxy.ip, proxy.port, latency - ) - await self.task_repo.complete_task(db, task["id"], True, latency) - self.valid_count += 1 - logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}") - else: - await self.task_repo.complete_task(db, task["id"], False, 0.0) - self.invalid_count += 1 - logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}") + await self.task_repo.complete_task(db, task["id"], True, latency) + self.valid_count += 1 + logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}") + else: + # 对已有代理扣分,分数<=0时自动删除 + await self.proxy_repo.update_score( + db, proxy.ip, proxy.port, self.score_invalid, + self.score_min, self.score_max + ) + await self.task_repo.complete_task(db, task["id"], False, 0.0) + self.invalid_count += 1 + logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}") async with self._condition: self._pending_count = max(0, self._pending_count - 1) if self._pending_count == 0: self._condition.notify_all() + return True def reset_stats(self): self.valid_count = 0 diff --git a/app/models/domain.py b/app/models/domain.py index b390b7e..1cc82cd 100644 --- a/app/models/domain.py +++ b/app/models/domain.py @@ -15,6 +15,8 @@ class ProxyRaw: self.protocol = self.protocol.lower().strip() if self.protocol not in ("http", "https", "socks4", "socks5"): self.protocol = "http" + if not isinstance(self.port, int) or not (1 <= self.port <= 65535): + raise 
ValueError(f"port must be between 1 and 65535, got {self.port}") @dataclass diff --git a/app/plugins/base.py b/app/plugins/base.py index 9e3ca41..72e8858 100644 --- a/app/plugins/base.py +++ b/app/plugins/base.py @@ -2,7 +2,7 @@ import random import asyncio import httpx -from typing import List +from typing import List, Optional from app.core.plugin_system import BaseCrawlerPlugin @@ -19,6 +19,7 @@ class BaseHTTPPlugin(BaseCrawlerPlugin): ] self.urls: List[str] = [] self.current_url: str = "" + self._client: Optional[httpx.AsyncClient] = None def get_headers(self) -> dict: return { @@ -28,30 +29,38 @@ class BaseHTTPPlugin(BaseCrawlerPlugin): "Connection": "keep-alive", } + def _get_client(self) -> httpx.AsyncClient: + """获取或创建复用的 AsyncClient""" + if self._client is None or self._client.is_closed: + transport = httpx.AsyncHTTPTransport(retries=0) + self._client = httpx.AsyncClient( + transport=transport, + follow_redirects=True, + ) + return self._client + async def fetch(self, url: str, timeout: float = 15.0, retries: int = 2) -> str: """异步抓取指定 URL 的 HTML 内容""" from app.core.log import logger - headers = self.get_headers() - transport = httpx.AsyncHTTPTransport(retries=0) + client = self._get_client() for attempt in range(retries): - async with httpx.AsyncClient(headers=headers, transport=transport, follow_redirects=True) as client: - try: - response = await client.get(url, timeout=timeout) - if response.status_code == 200: - content = response.content - encoding = response.encoding - if encoding == "utf-8" or not encoding: - try: - return content.decode("utf-8") - except UnicodeDecodeError: - return content.decode("gbk", errors="ignore") - return content.decode(encoding, errors="ignore") - else: - logger.warning(f"Fetch {url} returned status {response.status_code}") - except Exception as e: - logger.warning(f"Fetch {url} failed (attempt {attempt + 1}/{retries}): {e}") - if attempt < retries - 1: - await asyncio.sleep(random.uniform(1, 3)) + try: + response = await 
client.get(url, headers=self.get_headers(), timeout=timeout) + if response.status_code == 200: + content = response.content + encoding = response.encoding + if encoding == "utf-8" or not encoding: + try: + return content.decode("utf-8") + except UnicodeDecodeError: + return content.decode("gbk", errors="ignore") + return content.decode(encoding, errors="ignore") + else: + logger.warning(f"Fetch {url} returned status {response.status_code}") + except Exception as e: + logger.warning(f"Fetch {url} failed (attempt {attempt + 1}/{retries}): {e}") + if attempt < retries - 1: + await asyncio.sleep(random.uniform(1, 3)) return "" async def fetch_all(self, urls: List[str], timeout: float = 15.0) -> List[str]: @@ -64,3 +73,9 @@ class BaseHTTPPlugin(BaseCrawlerPlugin): tasks = [_fetch_limited(url) for url in urls] return await asyncio.gather(*tasks) + + async def close(self): + """关闭复用的 HTTP 客户端""" + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None diff --git a/app/repositories/proxy_repo.py b/app/repositories/proxy_repo.py index f6adffa..1a9bd6b 100644 --- a/app/repositories/proxy_repo.py +++ b/app/repositories/proxy_repo.py @@ -63,7 +63,7 @@ class ProxyRepository: await db.commit() return True except Exception as e: - logger.error(f"insert_or_update proxy failed: {e}") + logger.error(f"insert_or_update proxy failed: {e}", exc_info=True) return False @staticmethod @@ -89,11 +89,12 @@ class ProxyRepository: (new_score, ip, port), ) if new_score <= 0: - await db.execute("DELETE FROM proxies WHERE score <= 0") + # 只删除当前代理,避免误删其他无效代理 + await db.execute("DELETE FROM proxies WHERE ip = ? 
AND port = ?", (ip, port)) await db.commit() return True except Exception as e: - logger.error(f"update_score failed: {e}") + logger.error(f"update_score failed: {e}", exc_info=True) return False @staticmethod @@ -111,7 +112,7 @@ class ProxyRepository: await db.commit() return True except Exception as e: - logger.error(f"update_response_time failed: {e}") + logger.error(f"update_response_time failed: {e}", exc_info=True) return False @staticmethod @@ -255,12 +256,12 @@ class ProxyRepository: async def get_today_new_count(db: aiosqlite.Connection) -> int: try: async with db.execute( - "SELECT COUNT(*) FROM proxies WHERE DATE(last_check) = DATE('now', 'localtime')" + "SELECT COUNT(*) FROM proxies WHERE DATE(created_at) = DATE('now', 'localtime')" ) as cursor: row = await cursor.fetchone() return row[0] if row else 0 except Exception as e: - logger.error(f"get_today_new_count failed: {e}") + logger.error(f"get_today_new_count failed: {e}", exc_info=True) return 0 @staticmethod @@ -279,5 +280,5 @@ class ProxyRepository: await db.commit() return db.total_changes except Exception as e: - logger.error(f"clean_expired failed: {e}") + logger.error(f"clean_expired failed: {e}", exc_info=True) return 0 diff --git a/app/services/scheduler_service.py b/app/services/scheduler_service.py index 4d27e0b..1a70516 100644 --- a/app/services/scheduler_service.py +++ b/app/services/scheduler_service.py @@ -3,9 +3,11 @@ import asyncio from datetime import datetime from app.core.db import get_db from app.repositories.proxy_repo import ProxyRepository +from app.repositories.task_repo import ValidationTaskRepository from app.core.tasks.queue import ValidationQueue from app.core.config import settings as app_settings from app.core.log import logger +from app.models.domain import ProxyRaw class SchedulerService: @@ -47,6 +49,11 @@ class SchedulerService: await self.validation_queue.stop() logger.info("Scheduler stopped") + def cancel_validate_task(self): + """取消正在执行的全量验证后台任务""" + if 
self._validate_task and not self._validate_task.done(): + self._validate_task.cancel() + async def validate_all_now(self): """立即执行一次全量验证(后台运行,不阻塞)""" if self._validate_task and not self._validate_task.done(): @@ -57,9 +64,12 @@ class SchedulerService: """定时循环""" while self.running: try: + # 清理过期任务,防止内存无限增长 + from app.services.task_service import task_service + task_service.cleanup_old_tasks() await self._do_validate_all() except Exception as e: - logger.error(f"Scheduler loop error: {e}") + logger.error(f"Scheduler loop error: {e}", exc_info=True) # 等待下一次 try: await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60) @@ -71,16 +81,19 @@ class SchedulerService: try: logger.info("Starting scheduled validation for all proxies") async with get_db() as db: + # 清理 7 天前的验证任务记录,防止表无限增长 + cleaned = await ValidationTaskRepository.cleanup_old(db, days=7) + if cleaned: + logger.info(f"Cleaned up {cleaned} old validation tasks") proxies = await self.proxy_repo.list_all(db) if not proxies: logger.info("No proxies to validate") return logger.info(f"Validating {len(proxies)} proxies from database") - from app.models.domain import ProxyRaw - - # 批量提交到验证队列 + # 批量提交到验证队列,不再阻塞等待 drain batch_size = 100 + total_batches = (len(proxies) - 1) // batch_size + 1 for i in range(0, len(proxies), batch_size): if not self.running: break @@ -88,10 +101,8 @@ class SchedulerService: await self.validation_queue.submit([ ProxyRaw(p.ip, p.port, p.protocol) for p in batch ]) - # 等待当前批次处理完 - await self.validation_queue.drain() - logger.info(f"Validated batch {i//batch_size + 1}/{(len(proxies)-1)//batch_size + 1}") + logger.info(f"Submitted batch {i // batch_size + 1}/{total_batches}") - logger.info("Scheduled validation completed") + logger.info("Scheduled validation batches submitted") except Exception as e: - logger.error(f"Scheduled validation error: {e}") + logger.error(f"Scheduled validation error: {e}", exc_info=True) diff --git a/app/services/task_service.py 
b/app/services/task_service.py new file mode 100644 index 0000000..7aabfc8 --- /dev/null +++ b/app/services/task_service.py @@ -0,0 +1,62 @@ +"""轻量级异步任务服务 - 内存存储,用于跟踪后台爬取任务""" +import uuid +from datetime import datetime, timedelta +from typing import Optional, Dict, Any +from app.core.log import logger + + +class TaskService: + """异步任务状态管理器""" + + def __init__(self, max_age_seconds: int = 3600): + self._tasks: Dict[str, dict] = {} + self._max_age_seconds = max_age_seconds + + def create_task(self, task_type: str) -> str: + task_id = str(uuid.uuid4()) + now = datetime.now() + self._tasks[task_id] = { + "id": task_id, + "type": task_type, + "status": "pending", + "message": "", + "data": {}, + "created_at": now, + "updated_at": now, + } + return task_id + + def update_task( + self, + task_id: str, + status: Optional[str] = None, + message: Optional[str] = None, + data: Optional[dict] = None, + ) -> bool: + task = self._tasks.get(task_id) + if not task: + return False + if status is not None: + task["status"] = status + if message is not None: + task["message"] = message + if data is not None: + task["data"].update(data) + task["updated_at"] = datetime.now() + return True + + def get_task(self, task_id: str) -> Optional[dict]: + return self._tasks.get(task_id) + + def cleanup_old_tasks(self) -> int: + cutoff = datetime.now() - timedelta(seconds=self._max_age_seconds) + to_remove = [tid for tid, task in self._tasks.items() if task["created_at"] < cutoff] + for tid in to_remove: + del self._tasks[tid] + if to_remove: + logger.info(f"TaskService cleaned up {len(to_remove)} old tasks") + return len(to_remove) + + +# 全局任务服务实例 +task_service = TaskService() diff --git a/app/services/validator_service.py b/app/services/validator_service.py index 55537ee..3383a9e 100644 --- a/app/services/validator_service.py +++ b/app/services/validator_service.py @@ -104,8 +104,10 @@ class ValidatorService: return True, latency return False, 0.0 finally: - await connector.close() + # 
ClientSession 的 async with 退出时会自动关闭 connector,无需手动重复关闭 + pass async def close(self): """关闭共享的 HTTP ClientSession""" - await self._http_session.close() + if self._http_session and not self._http_session.closed: + await self._http_session.close() diff --git a/tests/conftest.py b/tests/conftest.py index 5f30404..4e0e620 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """pytest 配置文件和 fixtures""" +import asyncio import pytest import pytest_asyncio from typing import AsyncGenerator @@ -7,16 +8,30 @@ from httpx import AsyncClient, ASGITransport from app.api import create_app from app.core.db import init_db, get_db from app.repositories.proxy_repo import ProxyRepository +from app.models.domain import ProxyRaw @pytest_asyncio.fixture(scope="function") async def app(): """创建应用实例""" - # 初始化测试数据库 + # 初始化测试数据库并清空历史数据,避免任务残留或设置状态导致 drain() 卡住 await init_db() - app = create_app() - async with app.router.lifespan_context(app): - yield app + async with get_db() as db: + await db.execute("DELETE FROM validation_tasks") + await db.execute("DELETE FROM proxies") + await db.execute("DELETE FROM settings") + await db.commit() + + # 清理全局内存状态,防止跨测试污染 + from app.services.task_service import task_service + task_service._tasks.clear() + + test_app = create_app() + async with test_app.router.lifespan_context(test_app): + yield test_app + + # 给 aiosqlite / aiohttp 后台线程留出收尾时间,降低 Event loop closed 警告概率 + await asyncio.sleep(0.1) @pytest_asyncio.fixture @@ -47,3 +62,27 @@ async def sample_proxy(db, proxy_repo): yield {"ip": "192.168.1.1", "port": 8080, "protocol": "http", "score": 50} # 清理 await proxy_repo.delete(db, "192.168.1.1", 8080) + + +@pytest_asyncio.fixture(autouse=True) +async def mock_external_requests(monkeypatch): + """ + 自动在所有测试中 mock 外部网络请求: + 1. 插件爬取返回固定测试代理,避免真实 HTTP 请求 + 2. 
代理验证瞬间成功,避免连接超时等待 + """ + from app.services.plugin_service import PluginService + from app.services.validator_service import ValidatorService + + async def _mock_run_plugin(self, plugin_id: str): + return [ProxyRaw("192.168.100.10", 8080, "http")] + + async def _mock_run_all_plugins(self): + return [ProxyRaw("192.168.100.10", 8080, "http")] + + async def _mock_validate(self, ip: str, port: int, protocol: str = "http"): + return True, 1.23 + + monkeypatch.setattr(PluginService, "run_plugin", _mock_run_plugin) + monkeypatch.setattr(PluginService, "run_all_plugins", _mock_run_all_plugins) + monkeypatch.setattr(ValidatorService, "validate", _mock_validate) diff --git a/tests/integration/test_plugins_api.py b/tests/integration/test_plugins_api.py index fa0dde7..ed3d562 100644 --- a/tests/integration/test_plugins_api.py +++ b/tests/integration/test_plugins_api.py @@ -117,19 +117,33 @@ class TestPluginsAPI: @pytest.mark.asyncio async def test_crawl_plugin(self, client): - """测试 POST /api/plugins/{id}/crawl""" + """测试 POST /api/plugins/{id}/crawl - 异步任务模式""" + import asyncio response = await client.get("/api/plugins") plugins = response.json()["data"]["plugins"] if not plugins: pytest.skip("没有可用的插件") plugin_id = plugins[0]["id"] - # 这个测试可能需要较长时间,设置较短的超时 response = await client.post(f"/api/plugins/{plugin_id}/crawl") assert response.status_code == 200 data = response.json() assert data["code"] == 200 - assert "proxy_count" in data["data"] + assert "task_id" in data["data"] + + task_id = data["data"]["task_id"] + # 轮询任务状态 + task_data = None + for _ in range(10): + await asyncio.sleep(0.3) + res = await client.get(f"/api/tasks/{task_id}") + assert res.status_code == 200 + task_data = res.json()["data"] + if task_data["status"] in ("completed", "failed"): + break + + assert task_data is not None + assert task_data["status"] == "completed" @pytest.mark.asyncio async def test_crawl_nonexistent_plugin(self, client): @@ -139,9 +153,24 @@ class TestPluginsAPI: @pytest.mark.asyncio 
async def test_crawl_all_plugins(self, client): - """测试 POST /api/plugins/crawl-all""" + """测试 POST /api/plugins/crawl-all - 异步任务模式""" + import asyncio response = await client.post("/api/plugins/crawl-all") assert response.status_code == 200 data = response.json() assert data["code"] == 200 - assert "total_crawled" in data["data"] + assert "task_id" in data["data"] + + task_id = data["data"]["task_id"] + # 轮询任务状态 + task_data = None + for _ in range(10): + await asyncio.sleep(0.3) + res = await client.get(f"/api/tasks/{task_id}") + assert res.status_code == 200 + task_data = res.json()["data"] + if task_data["status"] in ("completed", "failed"): + break + + assert task_data is not None + assert task_data["status"] == "completed"