diff --git a/api_server.py b/api_server.py index 9883a6b..579fe27 100644 --- a/api_server.py +++ b/api_server.py @@ -393,8 +393,7 @@ async def start_crawler(request: CrawlerRequest, _permission: str = Depends(requ tasks_manager.set_callbacks(progress_callback, status_callback) db = SQLiteManager() - async with ProxyValidator(max_concurrency=200) as validator: - asyncio.create_task(tasks_manager.start_task(db, validator, request.num_validators)) + asyncio.create_task(tasks_manager.start_task(db, request.num_validators)) return {"code": 200, "message": "爬虫任务开始啦~", "data": None} except Exception as e: diff --git a/tasks_manager.py b/tasks_manager.py index 0929dcd..c4f7c34 100644 --- a/tasks_manager.py +++ b/tasks_manager.py @@ -111,7 +111,7 @@ class TasksManager: await self._notify_status('validating_done', f'验证完成啦,入库 {verified_count} 个代理~') logger.info(f"验证协程完成,入库 {verified_count} 个代理。") - async def start_task(self, db: SQLiteManager, validator: ProxyValidator, num_validators: int = 50): + async def start_task(self, db: SQLiteManager, num_validators: int = 50): if self.is_running: await self._notify_status('error', '任务正在运行中呢~') return False @@ -128,16 +128,17 @@ class TasksManager: await self._notify_status('running', '任务开始啦~') - crawler_task = asyncio.create_task(self.run_crawler()) - self.validator_tasks = [asyncio.create_task(self.run_validator(db, validator)) for _ in range(num_validators)] - - await crawler_task - - for _ in range(num_validators): - await self.proxy_queue.put(None) + async with ProxyValidator(max_concurrency=200) as validator: + crawler_task = asyncio.create_task(self.run_crawler()) + self.validator_tasks = [asyncio.create_task(self.run_validator(db, validator)) for _ in range(num_validators)] - await self.proxy_queue.join() - await asyncio.gather(*self.validator_tasks, return_exceptions=True) + await crawler_task + + for _ in range(num_validators): + await self.proxy_queue.put(None) + + await self.proxy_queue.join() + await asyncio.gather(*self.validator_tasks, return_exceptions=True) total = await db.count_proxies() self.is_running = False