fix: 修复设置系统脱节、队列计数漂移、资源泄露等全量问题
- 统一设置系统:create_scheduler_service 读取 DB 设置覆盖默认值 - 修复 ProxyRepository.update_score 误删所有无效代理的 SQL - ValidationQueue:修复 Worker 计数漂移与启动恢复任务饿死 - SchedulerService:移除 drain() 阻塞,主循环可正常响应 stop - TaskService:在调度器周期内自动清理过期任务,防止内存泄漏 - lifespan/conftest:规范关闭顺序,消除 Event loop closed 警告 - Repository:异常日志增加 exc_info,今日新增按 created_at 统计 - ValidatorService:防止 HTTP session 重复关闭,移除 SOCKS 多余 close - 前端:补全 pluginsStore.isEmpty,ProxyList 最低分数上限改为 100 - 删除 config.py 中冗余的 cors_origins_list property
This commit is contained in:
@@ -97,4 +97,8 @@ export const settingsAPI = {
|
|||||||
saveSettings: (data) => api.post('/api/settings', data)
|
saveSettings: (data) => api.post('/api/settings', data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const tasksAPI = {
|
||||||
|
getTaskStatus: (taskId) => api.get(`/api/tasks/${taskId}`)
|
||||||
|
}
|
||||||
|
|
||||||
export default api
|
export default api
|
||||||
|
|||||||
@@ -1,4 +1,26 @@
|
|||||||
import { pluginsAPI } from '../api'
|
import { pluginsAPI, tasksAPI } from '../api'
|
||||||
|
|
||||||
|
const POLL_INTERVAL = 1000
|
||||||
|
const MAX_POLL_ATTEMPTS = 30
|
||||||
|
|
||||||
|
async function pollTaskStatus(taskId) {
|
||||||
|
for (let i = 0; i < MAX_POLL_ATTEMPTS; i++) {
|
||||||
|
await new Promise(resolve => setTimeout(resolve, POLL_INTERVAL))
|
||||||
|
const response = await tasksAPI.getTaskStatus(taskId)
|
||||||
|
if (response.code !== 200) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const status = response.data.status
|
||||||
|
if (status === 'completed' || status === 'failed') {
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
code: 200,
|
||||||
|
message: '爬取任务进行中,请稍后刷新查看结果',
|
||||||
|
data: { task_id: taskId, status: 'running' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export const pluginService = {
|
export const pluginService = {
|
||||||
async getPlugins() {
|
async getPlugins() {
|
||||||
@@ -18,10 +40,28 @@ export const pluginService = {
|
|||||||
},
|
},
|
||||||
|
|
||||||
async crawlPlugin(pluginId) {
|
async crawlPlugin(pluginId) {
|
||||||
return pluginsAPI.crawlPlugin(pluginId)
|
const startRes = await pluginsAPI.crawlPlugin(pluginId)
|
||||||
|
if (startRes.code !== 200 || !startRes.data?.task_id) {
|
||||||
|
return startRes
|
||||||
|
}
|
||||||
|
const finalRes = await pollTaskStatus(startRes.data.task_id)
|
||||||
|
return {
|
||||||
|
code: finalRes.code,
|
||||||
|
message: finalRes.data?.message || finalRes.message,
|
||||||
|
data: finalRes.data?.data || finalRes.data
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
async crawlAll() {
|
async crawlAll() {
|
||||||
return pluginsAPI.crawlAll()
|
const startRes = await pluginsAPI.crawlAll()
|
||||||
|
if (startRes.code !== 200 || !startRes.data?.task_id) {
|
||||||
|
return startRes
|
||||||
|
}
|
||||||
|
const finalRes = await pollTaskStatus(startRes.data.task_id)
|
||||||
|
return {
|
||||||
|
code: finalRes.code,
|
||||||
|
message: finalRes.data?.message || finalRes.message,
|
||||||
|
data: finalRes.data?.data || finalRes.data
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ export const usePluginsStore = defineStore('plugins', () => {
|
|||||||
// ==================== Getters ====================
|
// ==================== Getters ====================
|
||||||
const enabledCount = computed(() => plugins.value.filter(p => p.enabled).length)
|
const enabledCount = computed(() => plugins.value.filter(p => p.enabled).length)
|
||||||
const totalCount = computed(() => plugins.value.length)
|
const totalCount = computed(() => plugins.value.length)
|
||||||
|
const isEmpty = computed(() => !loading.value && plugins.value.length === 0)
|
||||||
|
|
||||||
// ==================== Actions ====================
|
// ==================== Actions ====================
|
||||||
|
|
||||||
@@ -97,6 +98,7 @@ export const usePluginsStore = defineStore('plugins', () => {
|
|||||||
// Getters
|
// Getters
|
||||||
enabledCount,
|
enabledCount,
|
||||||
totalCount,
|
totalCount,
|
||||||
|
isEmpty,
|
||||||
// Actions
|
// Actions
|
||||||
fetchPlugins,
|
fetchPlugins,
|
||||||
togglePlugin,
|
togglePlugin,
|
||||||
|
|||||||
@@ -132,8 +132,8 @@ export const useProxyStore = defineStore('proxy', () => {
|
|||||||
try {
|
try {
|
||||||
const response = await proxyService.export(format, protocol)
|
const response = await proxyService.export(format, protocol)
|
||||||
|
|
||||||
// 创建下载链接
|
// response 已经是 Blob(axios 配置了 responseType: 'blob'),直接创建下载链接
|
||||||
const url = window.URL.createObjectURL(new Blob([response]))
|
const url = window.URL.createObjectURL(response)
|
||||||
const link = document.createElement('a')
|
const link = document.createElement('a')
|
||||||
link.href = url
|
link.href = url
|
||||||
link.setAttribute('download', `proxies.${format}`)
|
link.setAttribute('download', `proxies.${format}`)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
<el-input-number
|
<el-input-number
|
||||||
v-model="filterForm.minScore"
|
v-model="filterForm.minScore"
|
||||||
:min="0"
|
:min="0"
|
||||||
:max="10"
|
:max="100"
|
||||||
style="width: 120px"
|
style="width: 120px"
|
||||||
@change="handleSearch"
|
@change="handleSearch"
|
||||||
/>
|
/>
|
||||||
@@ -190,7 +190,7 @@ async function fetchProxies() {
|
|||||||
}
|
}
|
||||||
abortController = new AbortController()
|
abortController = new AbortController()
|
||||||
|
|
||||||
await proxyStore.fetchProxies({
|
const success = await proxyStore.fetchProxies({
|
||||||
page: currentPage.value,
|
page: currentPage.value,
|
||||||
page_size: pageSize.value,
|
page_size: pageSize.value,
|
||||||
protocol: filterForm.protocol || null,
|
protocol: filterForm.protocol || null,
|
||||||
@@ -200,6 +200,9 @@ async function fetchProxies() {
|
|||||||
}, abortController.signal)
|
}, abortController.signal)
|
||||||
|
|
||||||
abortController = null
|
abortController = null
|
||||||
|
if (!success) {
|
||||||
|
ElMessage.error('获取代理列表失败')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== 事件处理 ====================
|
// ==================== 事件处理 ====================
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def format_proxy(proxy) -> dict:
|
|||||||
"port": proxy.port,
|
"port": proxy.port,
|
||||||
"protocol": proxy.protocol,
|
"protocol": proxy.protocol,
|
||||||
"score": proxy.score,
|
"score": proxy.score,
|
||||||
|
"response_time_ms": proxy.response_time_ms,
|
||||||
"last_check": proxy.last_check.isoformat() if proxy.last_check else None,
|
"last_check": proxy.last_check.isoformat() if proxy.last_check else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,21 +25,29 @@ def get_validation_queue(request: Request) -> ValidationQueue:
|
|||||||
return request.app.state.validation_queue
|
return request.app.state.validation_queue
|
||||||
|
|
||||||
|
|
||||||
def create_scheduler_service() -> SchedulerService:
|
def create_scheduler_service(db_settings: dict | None = None) -> SchedulerService:
|
||||||
"""在应用启动时创建 SchedulerService(非请求上下文)"""
|
"""在应用启动时创建 SchedulerService(非请求上下文)
|
||||||
|
|
||||||
|
db_settings 来自数据库持久化设置,用于覆盖 .env / 默认值。
|
||||||
|
"""
|
||||||
|
db_settings = db_settings or {}
|
||||||
validator = ValidatorService(
|
validator = ValidatorService(
|
||||||
timeout=app_settings.validator_timeout,
|
timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
|
||||||
connect_timeout=app_settings.validator_connect_timeout,
|
connect_timeout=db_settings.get("connect_timeout", app_settings.validator_connect_timeout),
|
||||||
max_concurrency=app_settings.validator_max_concurrency,
|
max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
|
||||||
)
|
)
|
||||||
proxy_repo = ProxyRepository()
|
proxy_repo = ProxyRepository()
|
||||||
queue = ValidationQueue(
|
queue = ValidationQueue(
|
||||||
validator=validator,
|
validator=validator,
|
||||||
proxy_repo=proxy_repo,
|
proxy_repo=proxy_repo,
|
||||||
worker_count=app_settings.validator_max_concurrency,
|
worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
|
||||||
score_valid=app_settings.score_valid,
|
score_valid=app_settings.score_valid,
|
||||||
score_invalid=app_settings.score_invalid,
|
score_invalid=app_settings.score_invalid,
|
||||||
score_min=app_settings.score_min,
|
score_min=app_settings.score_min,
|
||||||
score_max=app_settings.score_max,
|
score_max=app_settings.score_max,
|
||||||
)
|
)
|
||||||
return SchedulerService(validation_queue=queue, proxy_repo=proxy_repo)
|
svc = SchedulerService(validation_queue=queue, proxy_repo=proxy_repo)
|
||||||
|
svc.interval_minutes = db_settings.get(
|
||||||
|
"validate_interval_minutes", app_settings.validator_timeout
|
||||||
|
)
|
||||||
|
return svc
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from app.core.db import init_db, get_db
|
|||||||
from app.core.config import settings as app_settings
|
from app.core.config import settings as app_settings
|
||||||
from app.core.log import logger
|
from app.core.log import logger
|
||||||
from app.api.deps import create_scheduler_service
|
from app.api.deps import create_scheduler_service
|
||||||
from app.repositories.settings_repo import SettingsRepository
|
from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS
|
||||||
|
|
||||||
settings_repo = SettingsRepository()
|
settings_repo = SettingsRepository()
|
||||||
|
|
||||||
@@ -17,33 +17,48 @@ async def lifespan(app: FastAPI):
|
|||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
await init_db()
|
await init_db()
|
||||||
|
|
||||||
# 创建调度器并挂载到 app.state
|
# 加载设置并决定是否启动调度器
|
||||||
scheduler_service = create_scheduler_service()
|
db_settings = DEFAULT_SETTINGS.copy()
|
||||||
|
try:
|
||||||
|
async with get_db() as db:
|
||||||
|
db_settings = await settings_repo.get_all(db)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load settings on startup: {e}")
|
||||||
|
|
||||||
|
# 创建调度器并挂载到 app.state(使用 DB 设置覆盖默认值)
|
||||||
|
scheduler_service = create_scheduler_service(db_settings)
|
||||||
app.state.scheduler_service = scheduler_service
|
app.state.scheduler_service = scheduler_service
|
||||||
app.state.validation_queue = scheduler_service.validation_queue
|
app.state.validation_queue = scheduler_service.validation_queue
|
||||||
|
|
||||||
# 加载设置并决定是否启动调度器
|
if db_settings.get("auto_validate", True):
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
|
||||||
settings = await settings_repo.get_all(db)
|
|
||||||
scheduler_service.interval_minutes = settings.get(
|
|
||||||
"validate_interval_minutes", app_settings.validator_timeout
|
|
||||||
)
|
|
||||||
if settings.get("auto_validate", True):
|
|
||||||
await scheduler_service.start()
|
await scheduler_service.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load settings on startup: {e}")
|
logger.error(f"Failed to start scheduler on startup: {e}")
|
||||||
|
|
||||||
logger.info("API server started")
|
logger.info("API server started")
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 关闭调度器
|
# 关闭调度器
|
||||||
if scheduler_service._validate_task and not scheduler_service._validate_task.done():
|
scheduler_service.cancel_validate_task()
|
||||||
scheduler_service._validate_task.cancel()
|
|
||||||
try:
|
|
||||||
await scheduler_service._validate_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
await scheduler_service.stop()
|
await scheduler_service.stop()
|
||||||
|
|
||||||
|
# 关闭验证器 HTTP session
|
||||||
|
try:
|
||||||
await scheduler_service.validation_queue.validator.close()
|
await scheduler_service.validation_queue.validator.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 关闭所有插件的 HTTP 客户端
|
||||||
|
from app.core.plugin_system.registry import registry
|
||||||
|
for plugin in registry.list_plugins():
|
||||||
|
if hasattr(plugin, "close"):
|
||||||
|
try:
|
||||||
|
await plugin.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 给 aiosqlite / aiohttp 后台线程留出收尾时间
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
logger.info("API server shutdown")
|
logger.info("API server shutdown")
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def create_app() -> FastAPI:
|
|||||||
# CORS
|
# CORS
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=app_settings.cors_origins_list,
|
allow_origins=app_settings.cors_origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""路由包"""
|
"""路由包"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from app.api.routes import proxies, plugins, scheduler, settings
|
from app.api.routes import proxies, plugins, scheduler, settings, tasks
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(proxies.router)
|
api_router.include_router(proxies.router)
|
||||||
api_router.include_router(plugins.router)
|
api_router.include_router(plugins.router)
|
||||||
api_router.include_router(scheduler.router)
|
api_router.include_router(scheduler.router)
|
||||||
api_router.include_router(settings.router)
|
api_router.include_router(settings.router)
|
||||||
|
api_router.include_router(tasks.router)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import asyncio
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from app.services.plugin_service import PluginService
|
from app.services.plugin_service import PluginService
|
||||||
from app.services.scheduler_service import SchedulerService
|
from app.services.scheduler_service import SchedulerService
|
||||||
|
from app.services.task_service import task_service
|
||||||
from app.api.deps import get_plugin_service, get_scheduler_service
|
from app.api.deps import get_plugin_service, get_scheduler_service
|
||||||
from app.api.common import success_response, error_response, format_plugin
|
from app.api.common import success_response, error_response, format_plugin
|
||||||
from app.core.log import logger
|
from app.core.log import logger
|
||||||
@@ -12,12 +13,8 @@ router = APIRouter(prefix="/api/plugins", tags=["plugins"])
|
|||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def list_plugins(service: PluginService = Depends(get_plugin_service)):
|
async def list_plugins(service: PluginService = Depends(get_plugin_service)):
|
||||||
try:
|
|
||||||
plugins = await service.list_plugins()
|
plugins = await service.list_plugins()
|
||||||
return success_response("获取插件列表成功", {"plugins": [format_plugin(p) for p in plugins]})
|
return success_response("获取插件列表成功", {"plugins": [format_plugin(p) for p in plugins]})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"List plugins failed: {e}")
|
|
||||||
return error_response("获取插件列表失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{plugin_id}/toggle")
|
@router.put("/{plugin_id}/toggle")
|
||||||
@@ -30,7 +27,6 @@ async def toggle_plugin(
|
|||||||
if enabled is None:
|
if enabled is None:
|
||||||
return error_response("缺少 enabled 参数", 400)
|
return error_response("缺少 enabled 参数", 400)
|
||||||
|
|
||||||
try:
|
|
||||||
success = await service.toggle_plugin(plugin_id, enabled)
|
success = await service.toggle_plugin(plugin_id, enabled)
|
||||||
if not success:
|
if not success:
|
||||||
return error_response("插件不存在", 404)
|
return error_response("插件不存在", 404)
|
||||||
@@ -38,9 +34,6 @@ async def toggle_plugin(
|
|||||||
f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}",
|
f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}",
|
||||||
{"plugin_id": plugin_id, "enabled": enabled},
|
{"plugin_id": plugin_id, "enabled": enabled},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Toggle plugin failed: {e}")
|
|
||||||
return error_response("切换插件状态失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{plugin_id}/config")
|
@router.get("/{plugin_id}/config")
|
||||||
@@ -48,14 +41,10 @@ async def get_plugin_config(
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
service: PluginService = Depends(get_plugin_service),
|
service: PluginService = Depends(get_plugin_service),
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
config = await service.get_plugin_config(plugin_id)
|
config = await service.get_plugin_config(plugin_id)
|
||||||
if config is None:
|
if config is None:
|
||||||
return error_response("插件不存在", 404)
|
return error_response("插件不存在", 404)
|
||||||
return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config})
|
return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Get plugin config failed: {e}")
|
|
||||||
return error_response("获取插件配置失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{plugin_id}/config")
|
@router.post("/{plugin_id}/config")
|
||||||
@@ -68,14 +57,10 @@ async def update_plugin_config(
|
|||||||
if not isinstance(config, dict):
|
if not isinstance(config, dict):
|
||||||
return error_response("config 必须是对象", 400)
|
return error_response("config 必须是对象", 400)
|
||||||
|
|
||||||
try:
|
|
||||||
success = await service.update_plugin_config(plugin_id, config)
|
success = await service.update_plugin_config(plugin_id, config)
|
||||||
if not success:
|
if not success:
|
||||||
return error_response("插件不存在或配置无效", 404)
|
return error_response("插件不存在或配置无效", 404)
|
||||||
return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config})
|
return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Update plugin config failed: {e}")
|
|
||||||
return error_response("保存插件配置失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{plugin_id}/crawl")
|
@router.post("/{plugin_id}/crawl")
|
||||||
@@ -88,26 +73,35 @@ async def crawl_plugin(
|
|||||||
if not plugin:
|
if not plugin:
|
||||||
return error_response("插件不存在", 404)
|
return error_response("插件不存在", 404)
|
||||||
|
|
||||||
|
task_id = task_service.create_task("crawl_plugin")
|
||||||
|
|
||||||
|
async def _background_crawl():
|
||||||
try:
|
try:
|
||||||
|
task_service.update_task(task_id, status="running")
|
||||||
results = await plugin_service.run_plugin(plugin_id)
|
results = await plugin_service.run_plugin(plugin_id)
|
||||||
if not results:
|
if not results:
|
||||||
return success_response(
|
task_service.update_task(
|
||||||
f"插件 {plugin_id} 爬取完成,未获取到代理",
|
task_id,
|
||||||
{"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0},
|
status="completed",
|
||||||
|
message=f"插件 {plugin_id} 爬取完成,未获取到代理",
|
||||||
|
data={"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0},
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
|
logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
|
||||||
scheduler_service.validation_queue.reset_stats()
|
scheduler_service.validation_queue.reset_stats()
|
||||||
await scheduler_service.validation_queue.submit(results)
|
await scheduler_service.validation_queue.submit(results)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=30.0)
|
await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return success_response(
|
task_service.update_task(
|
||||||
f"插件 {plugin_id} 爬取并验证完成",
|
task_id,
|
||||||
{
|
status="completed",
|
||||||
|
message=f"插件 {plugin_id} 爬取并验证完成",
|
||||||
|
data={
|
||||||
"plugin_id": plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"proxy_count": len(results),
|
"proxy_count": len(results),
|
||||||
"valid_count": scheduler_service.validation_queue.valid_count,
|
"valid_count": scheduler_service.validation_queue.valid_count,
|
||||||
@@ -116,7 +110,14 @@ async def crawl_plugin(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Crawl plugin {plugin_id} failed: {e}")
|
logger.error(f"Crawl plugin {plugin_id} failed: {e}")
|
||||||
return error_response(f"插件爬取失败: {str(e)}", 500)
|
task_service.update_task(
|
||||||
|
task_id, status="failed", message=f"插件爬取失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(_background_crawl())
|
||||||
|
return success_response(
|
||||||
|
"爬取任务已启动", {"task_id": task_id, "plugin_id": plugin_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/crawl-all")
|
@router.post("/crawl-all")
|
||||||
@@ -124,26 +125,35 @@ async def crawl_all(
|
|||||||
plugin_service: PluginService = Depends(get_plugin_service),
|
plugin_service: PluginService = Depends(get_plugin_service),
|
||||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
):
|
):
|
||||||
|
task_id = task_service.create_task("crawl_all")
|
||||||
|
|
||||||
|
async def _background_crawl_all():
|
||||||
try:
|
try:
|
||||||
|
task_service.update_task(task_id, status="running")
|
||||||
results = await plugin_service.run_all_plugins()
|
results = await plugin_service.run_all_plugins()
|
||||||
if not results:
|
if not results:
|
||||||
return success_response(
|
task_service.update_task(
|
||||||
"所有插件爬取完成,未获取到代理",
|
task_id,
|
||||||
{"total_crawled": 0, "valid_count": 0, "invalid_count": 0},
|
status="completed",
|
||||||
|
message="所有插件爬取完成,未获取到代理",
|
||||||
|
data={"total_crawled": 0, "valid_count": 0, "invalid_count": 0},
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"All plugins crawled {len(results)} unique proxies")
|
logger.info(f"All plugins crawled {len(results)} unique proxies")
|
||||||
scheduler_service.validation_queue.reset_stats()
|
scheduler_service.validation_queue.reset_stats()
|
||||||
await scheduler_service.validation_queue.submit(results)
|
await scheduler_service.validation_queue.submit(results)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=60.0)
|
await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return success_response(
|
task_service.update_task(
|
||||||
"所有插件爬取并验证完成",
|
task_id,
|
||||||
{
|
status="completed",
|
||||||
|
message="所有插件爬取并验证完成",
|
||||||
|
data={
|
||||||
"total_crawled": len(results),
|
"total_crawled": len(results),
|
||||||
"valid_count": scheduler_service.validation_queue.valid_count,
|
"valid_count": scheduler_service.validation_queue.valid_count,
|
||||||
"invalid_count": scheduler_service.validation_queue.invalid_count,
|
"invalid_count": scheduler_service.validation_queue.invalid_count,
|
||||||
@@ -151,4 +161,9 @@ async def crawl_all(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Crawl all failed: {e}")
|
logger.error(f"Crawl all failed: {e}")
|
||||||
return error_response(f"批量爬取失败: {str(e)}", 500)
|
task_service.update_task(
|
||||||
|
task_id, status="failed", message=f"批量爬取失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(_background_crawl_all())
|
||||||
|
return success_response("爬取任务已启动", {"task_id": task_id})
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from app.services.scheduler_service import SchedulerService
|
|||||||
from app.models.schemas import ProxyListRequest, BatchDeleteRequest
|
from app.models.schemas import ProxyListRequest, BatchDeleteRequest
|
||||||
from app.api.deps import get_proxy_service, get_scheduler_service
|
from app.api.deps import get_proxy_service, get_scheduler_service
|
||||||
from app.api.common import success_response, error_response, format_proxy
|
from app.api.common import success_response, error_response, format_proxy
|
||||||
from app.core.log import logger
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/proxies", tags=["proxies"])
|
router = APIRouter(prefix="/api/proxies", tags=["proxies"])
|
||||||
|
|
||||||
@@ -16,13 +15,9 @@ async def get_stats(
|
|||||||
proxy_service: ProxyService = Depends(get_proxy_service),
|
proxy_service: ProxyService = Depends(get_proxy_service),
|
||||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
stats = await proxy_service.get_stats()
|
stats = await proxy_service.get_stats()
|
||||||
stats["scheduler_running"] = scheduler_service.running
|
stats["scheduler_running"] = scheduler_service.running
|
||||||
return success_response("获取统计信息成功", stats)
|
return success_response("获取统计信息成功", stats)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Get stats failed: {e}")
|
|
||||||
return error_response("获取统计信息失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
@@ -30,7 +25,6 @@ async def list_proxies(
|
|||||||
request: ProxyListRequest,
|
request: ProxyListRequest,
|
||||||
service: ProxyService = Depends(get_proxy_service),
|
service: ProxyService = Depends(get_proxy_service),
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
proxies, total = await service.list_proxies(
|
proxies, total = await service.list_proxies(
|
||||||
page=request.page,
|
page=request.page,
|
||||||
page_size=request.page_size,
|
page_size=request.page_size,
|
||||||
@@ -49,21 +43,14 @@ async def list_proxies(
|
|||||||
"page_size": request.page_size,
|
"page_size": request.page_size,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"List proxies failed: {e}")
|
|
||||||
return error_response("获取代理列表失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/random")
|
@router.get("/random")
|
||||||
async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
|
async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
|
||||||
try:
|
|
||||||
proxy = await service.get_random_proxy()
|
proxy = await service.get_random_proxy()
|
||||||
if not proxy:
|
if not proxy:
|
||||||
return error_response("没有找到可用的代理", 404)
|
return error_response("没有找到可用的代理", 404)
|
||||||
return success_response("获取随机代理成功", format_proxy(proxy))
|
return success_response("获取随机代理成功", format_proxy(proxy))
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Get random proxy failed: {e}")
|
|
||||||
return error_response("获取随机代理失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/export/{fmt}")
|
@router.get("/export/{fmt}")
|
||||||
@@ -93,12 +80,8 @@ async def export_proxies(
|
|||||||
|
|
||||||
@router.delete("/{ip}/{port}")
|
@router.delete("/{ip}/{port}")
|
||||||
async def delete_proxy(ip: str, port: int, service: ProxyService = Depends(get_proxy_service)):
|
async def delete_proxy(ip: str, port: int, service: ProxyService = Depends(get_proxy_service)):
|
||||||
try:
|
|
||||||
await service.delete_proxy(ip, port)
|
await service.delete_proxy(ip, port)
|
||||||
return success_response("删除代理成功")
|
return success_response("删除代理成功")
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Delete proxy failed: {e}")
|
|
||||||
return error_response("删除代理失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/batch-delete")
|
@router.post("/batch-delete")
|
||||||
@@ -106,20 +89,12 @@ async def batch_delete(
|
|||||||
request: BatchDeleteRequest,
|
request: BatchDeleteRequest,
|
||||||
service: ProxyService = Depends(get_proxy_service),
|
service: ProxyService = Depends(get_proxy_service),
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
proxies = [(item.ip, item.port) for item in request.proxies]
|
proxies = [(item.ip, item.port) for item in request.proxies]
|
||||||
deleted = await service.batch_delete(proxies)
|
deleted = await service.batch_delete(proxies)
|
||||||
return success_response(f"批量删除 {deleted} 个代理成功", {"deleted_count": deleted})
|
return success_response(f"批量删除 {deleted} 个代理成功", {"deleted_count": deleted})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Batch delete failed: {e}")
|
|
||||||
return error_response("批量删除失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/clean-invalid")
|
@router.delete("/clean-invalid")
|
||||||
async def clean_invalid(service: ProxyService = Depends(get_proxy_service)):
|
async def clean_invalid(service: ProxyService = Depends(get_proxy_service)):
|
||||||
try:
|
|
||||||
count = await service.clean_invalid()
|
count = await service.clean_invalid()
|
||||||
return success_response(f"清理了 {count} 个无效代理", {"deleted_count": count})
|
return success_response(f"清理了 {count} 个无效代理", {"deleted_count": count})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Clean invalid failed: {e}")
|
|
||||||
return error_response("清理无效代理失败", 500)
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from app.repositories.settings_repo import SettingsRepository
|
|||||||
from app.core.db import get_db
|
from app.core.db import get_db
|
||||||
from app.api.deps import get_scheduler_service
|
from app.api.deps import get_scheduler_service
|
||||||
from app.api.common import success_response, error_response
|
from app.api.common import success_response, error_response
|
||||||
from app.core.log import logger
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
|
router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
|
||||||
settings_repo = SettingsRepository()
|
settings_repo = SettingsRepository()
|
||||||
@@ -22,38 +21,26 @@ async def _save_auto_validate_setting(enabled: bool):
|
|||||||
|
|
||||||
@router.post("/start")
|
@router.post("/start")
|
||||||
async def start_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
async def start_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
||||||
try:
|
|
||||||
if scheduler.running:
|
if scheduler.running:
|
||||||
return success_response("验证调度器已在运行", {"running": True})
|
return success_response("验证调度器已在运行", {"running": True})
|
||||||
await scheduler.start()
|
await scheduler.start()
|
||||||
await _save_auto_validate_setting(True)
|
await _save_auto_validate_setting(True)
|
||||||
return success_response("验证调度器已启动", {"running": True})
|
return success_response("验证调度器已启动", {"running": True})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Start scheduler failed: {e}")
|
|
||||||
return error_response(f"启动调度器失败: {str(e)}", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stop")
|
@router.post("/stop")
|
||||||
async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
||||||
try:
|
|
||||||
if not scheduler.running:
|
if not scheduler.running:
|
||||||
return success_response("验证调度器未运行", {"running": False})
|
return success_response("验证调度器未运行", {"running": False})
|
||||||
await scheduler.stop()
|
await scheduler.stop()
|
||||||
await _save_auto_validate_setting(False)
|
await _save_auto_validate_setting(False)
|
||||||
return success_response("验证调度器已停止", {"running": False})
|
return success_response("验证调度器已停止", {"running": False})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Stop scheduler failed: {e}")
|
|
||||||
return error_response(f"停止调度器失败: {str(e)}", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/validate-now")
|
@router.post("/validate-now")
|
||||||
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
||||||
try:
|
|
||||||
await scheduler.validate_all_now()
|
await scheduler.validate_all_now()
|
||||||
return success_response("已开始全量验证", {"started": True})
|
return success_response("已开始全量验证", {"started": True})
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Validate now failed: {e}")
|
|
||||||
return error_response(f"启动验证失败: {str(e)}", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/status")
|
@router.get("/status")
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from app.core.db import get_db
|
|||||||
from app.repositories.settings_repo import SettingsRepository
|
from app.repositories.settings_repo import SettingsRepository
|
||||||
from app.models.schemas import SettingsSchema
|
from app.models.schemas import SettingsSchema
|
||||||
from app.api.common import success_response, error_response
|
from app.api.common import success_response, error_response
|
||||||
from app.core.log import logger
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||||
settings_repo = SettingsRepository()
|
settings_repo = SettingsRepository()
|
||||||
@@ -12,23 +11,15 @@ settings_repo = SettingsRepository()
|
|||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def get_settings():
|
async def get_settings():
|
||||||
try:
|
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
settings = await settings_repo.get_all(db)
|
settings = await settings_repo.get_all(db)
|
||||||
return success_response("获取设置成功", settings)
|
return success_response("获取设置成功", settings)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Get settings failed: {e}")
|
|
||||||
return error_response("获取设置失败", 500)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def save_settings(request: SettingsSchema):
|
async def save_settings(request: SettingsSchema):
|
||||||
try:
|
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
success = await settings_repo.save(db, request.model_dump())
|
success = await settings_repo.save(db, request.model_dump())
|
||||||
if not success:
|
if not success:
|
||||||
return error_response("保存设置失败", 500)
|
return error_response("保存设置失败", 500)
|
||||||
return success_response("保存设置成功", request.model_dump())
|
return success_response("保存设置成功", request.model_dump())
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Save settings failed: {e}")
|
|
||||||
return error_response(f"保存设置失败: {str(e)}", 500)
|
|
||||||
|
|||||||
14
app/api/routes/tasks.py
Normal file
14
app/api/routes/tasks.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""异步任务状态路由"""
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from app.services.task_service import task_service
|
||||||
|
from app.api.common import success_response, error_response
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{task_id}")
async def get_task_status(task_id: str):
    """Return the tracked state of a background task.

    The task is looked up in the in-memory task service by ``task_id``. A known
    task is returned inside the standard success envelope; an unknown id yields
    an error envelope carrying code 404.
    """
    record = task_service.get_task(task_id)
    if record:
        return success_response("获取任务状态成功", record)
    return error_response("任务不存在", 404)
|
||||||
@@ -43,12 +43,13 @@ class Settings(BaseSettings):
|
|||||||
# 插件配置
|
# 插件配置
|
||||||
plugins_dir: str = "plugins"
|
plugins_dir: str = "plugins"
|
||||||
|
|
||||||
# CORS 配置
|
# CORS 配置 - Pydantic v2 会自动将逗号分隔的字符串解析为 List[str]
|
||||||
cors_origins: str = "http://localhost:8080,http://localhost:5173,http://127.0.0.1:18081,http://localhost:18081"
|
cors_origins: List[str] = [
|
||||||
|
"http://localhost:8080",
|
||||||
@property
|
"http://localhost:5173",
|
||||||
def cors_origins_list(self) -> List[str]:
|
"http://127.0.0.1:18081",
|
||||||
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
|
"http://localhost:18081",
|
||||||
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_dir(self) -> str:
|
def base_dir(self) -> str:
|
||||||
|
|||||||
@@ -1,47 +1,37 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import TimedRotatingFileHandler
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class LogHandler(logging.Logger):
|
|
||||||
def __init__(self, name='ProxyPool', level=logging.INFO):
|
|
||||||
super().__init__(name, level)
|
|
||||||
|
|
||||||
# 获取项目根目录并创建 logs 目录
|
# 获取项目根目录并创建 logs 目录
|
||||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
log_dir = os.path.join(base_dir, 'logs')
|
log_dir = os.path.join(base_dir, 'logs')
|
||||||
if not os.path.exists(log_dir):
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir)
|
|
||||||
|
|
||||||
# 使用日期作为文件名
|
log_file = os.path.join(log_dir, 'app.log')
|
||||||
log_filename = f"{datetime.now().strftime('%Y-%m-%d')}.log"
|
|
||||||
log_file = os.path.join(log_dir, log_filename)
|
|
||||||
|
|
||||||
# 设置格式
|
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'[%(asctime)s] %(name)s [%(levelname)s] %(filename)s[line:%(lineno)d]: %(message)s'
|
'[%(asctime)s] %(name)s [%(levelname)s] %(filename)s[line:%(lineno)d]: %(message)s'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 文件处理器(使用RotatingFileHandler支持日志轮转)
|
# 按天切分的文件处理器,保留最近 7 天
|
||||||
# 每个日志文件最大10MB,保留5个备份
|
file_handler = TimedRotatingFileHandler(
|
||||||
file_handler = RotatingFileHandler(
|
|
||||||
log_file,
|
log_file,
|
||||||
maxBytes=10*1024*1024,
|
when='midnight',
|
||||||
backupCount=5,
|
interval=1,
|
||||||
|
backupCount=7,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
self.addHandler(file_handler)
|
|
||||||
|
|
||||||
# 控制台处理器
|
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
self.addHandler(console_handler)
|
|
||||||
|
|
||||||
|
# 获取标准 logger
|
||||||
|
logger = logging.getLogger('ProxyPool')
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
# 实例化一个默认 logger 供外部直接使用
|
# 防止重复添加 handler(如模块重导入)
|
||||||
logger = LogHandler()
|
if not logger.handlers:
|
||||||
|
logger.addHandler(file_handler)
|
||||||
if __name__ == '__main__':
|
logger.addHandler(console_handler)
|
||||||
logger.info('这是一条按日期存储的日志测试')
|
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ class ValidationQueue:
|
|||||||
self._signal: asyncio.Queue[None] = asyncio.Queue()
|
self._signal: asyncio.Queue[None] = asyncio.Queue()
|
||||||
self._workers: list[asyncio.Task] = []
|
self._workers: list[asyncio.Task] = []
|
||||||
self._running = False
|
self._running = False
|
||||||
self._db_lock = asyncio.Lock()
|
|
||||||
self._pending_count = 0
|
self._pending_count = 0
|
||||||
self._condition = asyncio.Condition()
|
self._condition = asyncio.Condition()
|
||||||
|
|
||||||
@@ -72,9 +71,9 @@ class ValidationQueue:
|
|||||||
for i in range(self.worker_count):
|
for i in range(self.worker_count):
|
||||||
self._workers.append(asyncio.create_task(self._worker_loop(i)))
|
self._workers.append(asyncio.create_task(self._worker_loop(i)))
|
||||||
|
|
||||||
# 唤醒 Worker 处理恢复的 pending 任务
|
# 唤醒 Worker 处理恢复的 pending 任务(每个 Worker 一次唤醒即可,内部会循环处理)
|
||||||
if pending:
|
if pending:
|
||||||
for _ in range(min(pending, self.worker_count)):
|
for _ in range(self.worker_count):
|
||||||
self._signal.put_nowait(None)
|
self._signal.put_nowait(None)
|
||||||
|
|
||||||
logger.info(f"ValidationQueue started with {self.worker_count} workers")
|
logger.info(f"ValidationQueue started with {self.worker_count} workers")
|
||||||
@@ -92,7 +91,6 @@ class ValidationQueue:
|
|||||||
|
|
||||||
async def submit(self, proxies: list[ProxyRaw]):
|
async def submit(self, proxies: list[ProxyRaw]):
|
||||||
"""提交代理到验证队列(持久化 + 唤醒 Worker)"""
|
"""提交代理到验证队列(持久化 + 唤醒 Worker)"""
|
||||||
async with self._db_lock:
|
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
inserted = await self.task_repo.insert_batch(db, proxies)
|
inserted = await self.task_repo.insert_batch(db, proxies)
|
||||||
if inserted:
|
if inserted:
|
||||||
@@ -117,15 +115,18 @@ class ValidationQueue:
|
|||||||
self._signal.task_done()
|
self._signal.task_done()
|
||||||
if not self._running:
|
if not self._running:
|
||||||
break
|
break
|
||||||
await self._process_one_task(worker_id)
|
# 持续处理任务直到没有 pending 为止,避免信号数不足导致任务饿死
|
||||||
|
while self._running:
|
||||||
|
processed = await self._process_one_task(worker_id)
|
||||||
|
if not processed:
|
||||||
|
break
|
||||||
|
|
||||||
async def _process_one_task(self, worker_id: int):
|
async def _process_one_task(self, worker_id: int) -> bool:
|
||||||
"""从数据库取一个任务并验证"""
|
"""从数据库取一个任务并验证。返回 True 表示确实处理了一个任务。"""
|
||||||
async with self._db_lock:
|
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
task = await self.task_repo.acquire_pending(db)
|
task = await self.task_repo.acquire_pending(db)
|
||||||
if not task:
|
if not task:
|
||||||
return
|
return False
|
||||||
|
|
||||||
proxy = ProxyRaw(task["ip"], task["port"], task["protocol"])
|
proxy = ProxyRaw(task["ip"], task["port"], task["protocol"])
|
||||||
try:
|
try:
|
||||||
@@ -133,11 +134,9 @@ class ValidationQueue:
|
|||||||
proxy.ip, proxy.port, proxy.protocol
|
proxy.ip, proxy.port, proxy.protocol
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Worker {worker_id} validation error: {e}")
|
logger.error(f"Worker {worker_id} validation error: {e}", exc_info=True)
|
||||||
is_valid, latency = False, 0.0
|
is_valid, latency = False, 0.0
|
||||||
|
|
||||||
async with self._db_lock:
|
|
||||||
async with get_db() as db:
|
|
||||||
if is_valid:
|
if is_valid:
|
||||||
await self.proxy_repo.insert_or_update(
|
await self.proxy_repo.insert_or_update(
|
||||||
db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid
|
db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid
|
||||||
@@ -150,6 +149,11 @@ class ValidationQueue:
|
|||||||
self.valid_count += 1
|
self.valid_count += 1
|
||||||
logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}")
|
logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}")
|
||||||
else:
|
else:
|
||||||
|
# 对已有代理扣分,分数<=0时自动删除
|
||||||
|
await self.proxy_repo.update_score(
|
||||||
|
db, proxy.ip, proxy.port, self.score_invalid,
|
||||||
|
self.score_min, self.score_max
|
||||||
|
)
|
||||||
await self.task_repo.complete_task(db, task["id"], False, 0.0)
|
await self.task_repo.complete_task(db, task["id"], False, 0.0)
|
||||||
self.invalid_count += 1
|
self.invalid_count += 1
|
||||||
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
|
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
|
||||||
@@ -157,6 +161,7 @@ class ValidationQueue:
|
|||||||
self._pending_count = max(0, self._pending_count - 1)
|
self._pending_count = max(0, self._pending_count - 1)
|
||||||
if self._pending_count == 0:
|
if self._pending_count == 0:
|
||||||
self._condition.notify_all()
|
self._condition.notify_all()
|
||||||
|
return True
|
||||||
|
|
||||||
def reset_stats(self):
|
def reset_stats(self):
|
||||||
self.valid_count = 0
|
self.valid_count = 0
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ class ProxyRaw:
|
|||||||
self.protocol = self.protocol.lower().strip()
|
self.protocol = self.protocol.lower().strip()
|
||||||
if self.protocol not in ("http", "https", "socks4", "socks5"):
|
if self.protocol not in ("http", "https", "socks4", "socks5"):
|
||||||
self.protocol = "http"
|
self.protocol = "http"
|
||||||
|
if not isinstance(self.port, int) or not (1 <= self.port <= 65535):
|
||||||
|
raise ValueError(f"port must be between 1 and 65535, got {self.port}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import httpx
|
import httpx
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from app.core.plugin_system import BaseCrawlerPlugin
|
from app.core.plugin_system import BaseCrawlerPlugin
|
||||||
|
|
||||||
|
|
||||||
@@ -19,6 +19,7 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
|
|||||||
]
|
]
|
||||||
self.urls: List[str] = []
|
self.urls: List[str] = []
|
||||||
self.current_url: str = ""
|
self.current_url: str = ""
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
def get_headers(self) -> dict:
|
def get_headers(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@@ -28,15 +29,23 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
|
|||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_client(self) -> httpx.AsyncClient:
    """Return the shared AsyncClient, creating a fresh one when none is usable.

    A new client is built only when no client has been created yet or the
    previous one has been closed; otherwise the existing instance is reused.
    """
    current = self._client
    if current is not None and not current.is_closed:
        return current

    # Transport-level retries are disabled for the shared client.
    self._client = httpx.AsyncClient(
        transport=httpx.AsyncHTTPTransport(retries=0),
        follow_redirects=True,
    )
    return self._client
|
||||||
|
|
||||||
async def fetch(self, url: str, timeout: float = 15.0, retries: int = 2) -> str:
|
async def fetch(self, url: str, timeout: float = 15.0, retries: int = 2) -> str:
|
||||||
"""异步抓取指定 URL 的 HTML 内容"""
|
"""异步抓取指定 URL 的 HTML 内容"""
|
||||||
from app.core.log import logger
|
from app.core.log import logger
|
||||||
headers = self.get_headers()
|
client = self._get_client()
|
||||||
transport = httpx.AsyncHTTPTransport(retries=0)
|
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
async with httpx.AsyncClient(headers=headers, transport=transport, follow_redirects=True) as client:
|
|
||||||
try:
|
try:
|
||||||
response = await client.get(url, timeout=timeout)
|
response = await client.get(url, headers=self.get_headers(), timeout=timeout)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
content = response.content
|
content = response.content
|
||||||
encoding = response.encoding
|
encoding = response.encoding
|
||||||
@@ -64,3 +73,9 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
|
|||||||
|
|
||||||
tasks = [_fetch_limited(url) for url in urls]
|
tasks = [_fetch_limited(url) for url in urls]
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
async def close(self):
    """Close the reusable HTTP client if one is currently open."""
    client = self._client
    if not client or client.is_closed:
        # Nothing to release: no client was created, or it is already closed.
        return
    await client.aclose()
    self._client = None
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class ProxyRepository:
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"insert_or_update proxy failed: {e}")
|
logger.error(f"insert_or_update proxy failed: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -89,11 +89,12 @@ class ProxyRepository:
|
|||||||
(new_score, ip, port),
|
(new_score, ip, port),
|
||||||
)
|
)
|
||||||
if new_score <= 0:
|
if new_score <= 0:
|
||||||
await db.execute("DELETE FROM proxies WHERE score <= 0")
|
# 只删除当前代理,避免误删其他无效代理
|
||||||
|
await db.execute("DELETE FROM proxies WHERE ip = ? AND port = ?", (ip, port))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"update_score failed: {e}")
|
logger.error(f"update_score failed: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -111,7 +112,7 @@ class ProxyRepository:
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"update_response_time failed: {e}")
|
logger.error(f"update_response_time failed: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -255,12 +256,12 @@ class ProxyRepository:
|
|||||||
async def get_today_new_count(db: aiosqlite.Connection) -> int:
|
async def get_today_new_count(db: aiosqlite.Connection) -> int:
|
||||||
try:
|
try:
|
||||||
async with db.execute(
|
async with db.execute(
|
||||||
"SELECT COUNT(*) FROM proxies WHERE DATE(last_check) = DATE('now', 'localtime')"
|
"SELECT COUNT(*) FROM proxies WHERE DATE(created_at) = DATE('now', 'localtime')"
|
||||||
) as cursor:
|
) as cursor:
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
return row[0] if row else 0
|
return row[0] if row else 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"get_today_new_count failed: {e}")
|
logger.error(f"get_today_new_count failed: {e}", exc_info=True)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -279,5 +280,5 @@ class ProxyRepository:
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
return db.total_changes
|
return db.total_changes
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"clean_expired failed: {e}")
|
logger.error(f"clean_expired failed: {e}", exc_info=True)
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -3,9 +3,11 @@ import asyncio
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from app.core.db import get_db
|
from app.core.db import get_db
|
||||||
from app.repositories.proxy_repo import ProxyRepository
|
from app.repositories.proxy_repo import ProxyRepository
|
||||||
|
from app.repositories.task_repo import ValidationTaskRepository
|
||||||
from app.core.tasks.queue import ValidationQueue
|
from app.core.tasks.queue import ValidationQueue
|
||||||
from app.core.config import settings as app_settings
|
from app.core.config import settings as app_settings
|
||||||
from app.core.log import logger
|
from app.core.log import logger
|
||||||
|
from app.models.domain import ProxyRaw
|
||||||
|
|
||||||
|
|
||||||
class SchedulerService:
|
class SchedulerService:
|
||||||
@@ -47,6 +49,11 @@ class SchedulerService:
|
|||||||
await self.validation_queue.stop()
|
await self.validation_queue.stop()
|
||||||
logger.info("Scheduler stopped")
|
logger.info("Scheduler stopped")
|
||||||
|
|
||||||
|
def cancel_validate_task(self):
    """Cancel the background full-validation task if it has not finished yet."""
    pending = self._validate_task
    if not pending or pending.done():
        # No task was started, or it already finished: nothing to cancel.
        return
    pending.cancel()
|
||||||
|
|
||||||
async def validate_all_now(self):
|
async def validate_all_now(self):
|
||||||
"""立即执行一次全量验证(后台运行,不阻塞)"""
|
"""立即执行一次全量验证(后台运行,不阻塞)"""
|
||||||
if self._validate_task and not self._validate_task.done():
|
if self._validate_task and not self._validate_task.done():
|
||||||
@@ -57,9 +64,12 @@ class SchedulerService:
|
|||||||
"""定时循环"""
|
"""定时循环"""
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
|
# 清理过期任务,防止内存无限增长
|
||||||
|
from app.services.task_service import task_service
|
||||||
|
task_service.cleanup_old_tasks()
|
||||||
await self._do_validate_all()
|
await self._do_validate_all()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Scheduler loop error: {e}")
|
logger.error(f"Scheduler loop error: {e}", exc_info=True)
|
||||||
# 等待下一次
|
# 等待下一次
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
|
await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
|
||||||
@@ -71,16 +81,19 @@ class SchedulerService:
|
|||||||
try:
|
try:
|
||||||
logger.info("Starting scheduled validation for all proxies")
|
logger.info("Starting scheduled validation for all proxies")
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
|
# 清理 7 天前的验证任务记录,防止表无限增长
|
||||||
|
cleaned = await ValidationTaskRepository.cleanup_old(db, days=7)
|
||||||
|
if cleaned:
|
||||||
|
logger.info(f"Cleaned up {cleaned} old validation tasks")
|
||||||
proxies = await self.proxy_repo.list_all(db)
|
proxies = await self.proxy_repo.list_all(db)
|
||||||
if not proxies:
|
if not proxies:
|
||||||
logger.info("No proxies to validate")
|
logger.info("No proxies to validate")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Validating {len(proxies)} proxies from database")
|
logger.info(f"Validating {len(proxies)} proxies from database")
|
||||||
from app.models.domain import ProxyRaw
|
# 批量提交到验证队列,不再阻塞等待 drain
|
||||||
|
|
||||||
# 批量提交到验证队列
|
|
||||||
batch_size = 100
|
batch_size = 100
|
||||||
|
total_batches = (len(proxies) - 1) // batch_size + 1
|
||||||
for i in range(0, len(proxies), batch_size):
|
for i in range(0, len(proxies), batch_size):
|
||||||
if not self.running:
|
if not self.running:
|
||||||
break
|
break
|
||||||
@@ -88,10 +101,8 @@ class SchedulerService:
|
|||||||
await self.validation_queue.submit([
|
await self.validation_queue.submit([
|
||||||
ProxyRaw(p.ip, p.port, p.protocol) for p in batch
|
ProxyRaw(p.ip, p.port, p.protocol) for p in batch
|
||||||
])
|
])
|
||||||
# 等待当前批次处理完
|
logger.info(f"Submitted batch {i // batch_size + 1}/{total_batches}")
|
||||||
await self.validation_queue.drain()
|
|
||||||
logger.info(f"Validated batch {i//batch_size + 1}/{(len(proxies)-1)//batch_size + 1}")
|
|
||||||
|
|
||||||
logger.info("Scheduled validation completed")
|
logger.info("Scheduled validation batches submitted")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Scheduled validation error: {e}")
|
logger.error(f"Scheduled validation error: {e}", exc_info=True)
|
||||||
|
|||||||
62
app/services/task_service.py
Normal file
62
app/services/task_service.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""轻量级异步任务服务 - 内存存储,用于跟踪后台爬取任务"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from app.core.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class TaskService:
    """In-memory registry that tracks the state of background tasks."""

    def __init__(self, max_age_seconds: int = 3600):
        # Maps task id -> task record (a plain dict, see create_task).
        self._tasks: Dict[str, dict] = {}
        # Records created longer ago than this are dropped by cleanup_old_tasks().
        self._max_age_seconds = max_age_seconds

    def create_task(self, task_type: str) -> str:
        """Register a new task in the "pending" state and return its generated id."""
        new_id = str(uuid.uuid4())
        stamp = datetime.now()
        record = {
            "id": new_id,
            "type": task_type,
            "status": "pending",
            "message": "",
            "data": {},
            "created_at": stamp,
            "updated_at": stamp,
        }
        self._tasks[new_id] = record
        return new_id

    def update_task(
        self,
        task_id: str,
        status: Optional[str] = None,
        message: Optional[str] = None,
        data: Optional[dict] = None,
    ) -> bool:
        """Apply the given changes to a task.

        Arguments left as ``None`` are not touched. ``data`` is merged into the
        task's existing data dict rather than replacing it. ``updated_at`` is
        refreshed on every successful call.

        Returns:
            ``False`` when ``task_id`` is unknown, ``True`` otherwise.
        """
        record = self._tasks.get(task_id)
        if not record:
            return False

        for field, value in (("status", status), ("message", message)):
            if value is not None:
                record[field] = value
        if data is not None:
            record["data"].update(data)

        record["updated_at"] = datetime.now()
        return True

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the stored record for ``task_id``, or ``None`` if it is unknown."""
        return self._tasks.get(task_id)

    def cleanup_old_tasks(self) -> int:
        """Drop every task created more than ``max_age_seconds`` ago.

        Returns:
            The number of records that were removed.
        """
        threshold = datetime.now() - timedelta(seconds=self._max_age_seconds)

        # Collect the ids first so the dict is not mutated while being iterated.
        expired = []
        for key, record in self._tasks.items():
            if record["created_at"] < threshold:
                expired.append(key)

        removed = len(expired)
        for key in expired:
            del self._tasks[key]

        if removed:
            logger.info(f"TaskService cleaned up {removed} old tasks")
        return removed
|
||||||
|
|
||||||
|
|
||||||
|
# 全局任务服务实例
|
||||||
|
task_service = TaskService()
|
||||||
@@ -104,8 +104,10 @@ class ValidatorService:
|
|||||||
return True, latency
|
return True, latency
|
||||||
return False, 0.0
|
return False, 0.0
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
# ClientSession 的 async with 退出时会自动关闭 connector,无需手动重复关闭
|
||||||
|
pass
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""关闭共享的 HTTP ClientSession"""
|
"""关闭共享的 HTTP ClientSession"""
|
||||||
|
if self._http_session and not self._http_session.closed:
|
||||||
await self._http_session.close()
|
await self._http_session.close()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest 配置文件和 fixtures"""
|
"""pytest 配置文件和 fixtures"""
|
||||||
|
import asyncio
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
@@ -7,16 +8,30 @@ from httpx import AsyncClient, ASGITransport
|
|||||||
from app.api import create_app
|
from app.api import create_app
|
||||||
from app.core.db import init_db, get_db
|
from app.core.db import init_db, get_db
|
||||||
from app.repositories.proxy_repo import ProxyRepository
|
from app.repositories.proxy_repo import ProxyRepository
|
||||||
|
from app.models.domain import ProxyRaw
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
@pytest_asyncio.fixture(scope="function")
|
||||||
async def app():
|
async def app():
|
||||||
"""创建应用实例"""
|
"""创建应用实例"""
|
||||||
# 初始化测试数据库
|
# 初始化测试数据库并清空历史数据,避免任务残留或设置状态导致 drain() 卡住
|
||||||
await init_db()
|
await init_db()
|
||||||
app = create_app()
|
async with get_db() as db:
|
||||||
async with app.router.lifespan_context(app):
|
await db.execute("DELETE FROM validation_tasks")
|
||||||
yield app
|
await db.execute("DELETE FROM proxies")
|
||||||
|
await db.execute("DELETE FROM settings")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 清理全局内存状态,防止跨测试污染
|
||||||
|
from app.services.task_service import task_service
|
||||||
|
task_service._tasks.clear()
|
||||||
|
|
||||||
|
test_app = create_app()
|
||||||
|
async with test_app.router.lifespan_context(test_app):
|
||||||
|
yield test_app
|
||||||
|
|
||||||
|
# 给 aiosqlite / aiohttp 后台线程留出收尾时间,降低 Event loop closed 警告概率
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
@@ -47,3 +62,27 @@ async def sample_proxy(db, proxy_repo):
|
|||||||
yield {"ip": "192.168.1.1", "port": 8080, "protocol": "http", "score": 50}
|
yield {"ip": "192.168.1.1", "port": 8080, "protocol": "http", "score": 50}
|
||||||
# 清理
|
# 清理
|
||||||
await proxy_repo.delete(db, "192.168.1.1", 8080)
|
await proxy_repo.delete(db, "192.168.1.1", 8080)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
async def mock_external_requests(monkeypatch):
    """
    Autouse fixture that stubs out external network access in every test:
    1. Plugin crawls return one fixed test proxy instead of issuing real HTTP requests.
    2. Proxy validation succeeds immediately, so no connection timeout is waited on.
    """
    from app.services.plugin_service import PluginService
    from app.services.validator_service import ValidatorService

    async def _mock_run_plugin(self, plugin_id: str):
        return [ProxyRaw("192.168.100.10", 8080, "http")]

    async def _mock_run_all_plugins(self):
        return [ProxyRaw("192.168.100.10", 8080, "http")]

    async def _mock_validate(self, ip: str, port: int, protocol: str = "http"):
        return True, 1.23

    # (owner class, attribute name, replacement) applied in this order.
    replacements = (
        (PluginService, "run_plugin", _mock_run_plugin),
        (PluginService, "run_all_plugins", _mock_run_all_plugins),
        (ValidatorService, "validate", _mock_validate),
    )
    for owner, attribute, stub in replacements:
        monkeypatch.setattr(owner, attribute, stub)
|
||||||
|
|||||||
@@ -117,19 +117,33 @@ class TestPluginsAPI:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_crawl_plugin(self, client):
|
async def test_crawl_plugin(self, client):
|
||||||
"""测试 POST /api/plugins/{id}/crawl"""
|
"""测试 POST /api/plugins/{id}/crawl - 异步任务模式"""
|
||||||
|
import asyncio
|
||||||
response = await client.get("/api/plugins")
|
response = await client.get("/api/plugins")
|
||||||
plugins = response.json()["data"]["plugins"]
|
plugins = response.json()["data"]["plugins"]
|
||||||
if not plugins:
|
if not plugins:
|
||||||
pytest.skip("没有可用的插件")
|
pytest.skip("没有可用的插件")
|
||||||
|
|
||||||
plugin_id = plugins[0]["id"]
|
plugin_id = plugins[0]["id"]
|
||||||
# 这个测试可能需要较长时间,设置较短的超时
|
|
||||||
response = await client.post(f"/api/plugins/{plugin_id}/crawl")
|
response = await client.post(f"/api/plugins/{plugin_id}/crawl")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["code"] == 200
|
assert data["code"] == 200
|
||||||
assert "proxy_count" in data["data"]
|
assert "task_id" in data["data"]
|
||||||
|
|
||||||
|
task_id = data["data"]["task_id"]
|
||||||
|
# 轮询任务状态
|
||||||
|
task_data = None
|
||||||
|
for _ in range(10):
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
res = await client.get(f"/api/tasks/{task_id}")
|
||||||
|
assert res.status_code == 200
|
||||||
|
task_data = res.json()["data"]
|
||||||
|
if task_data["status"] in ("completed", "failed"):
|
||||||
|
break
|
||||||
|
|
||||||
|
assert task_data is not None
|
||||||
|
assert task_data["status"] == "completed"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_crawl_nonexistent_plugin(self, client):
|
async def test_crawl_nonexistent_plugin(self, client):
|
||||||
@@ -139,9 +153,24 @@ class TestPluginsAPI:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_crawl_all_plugins(self, client):
|
async def test_crawl_all_plugins(self, client):
|
||||||
"""测试 POST /api/plugins/crawl-all"""
|
"""测试 POST /api/plugins/crawl-all - 异步任务模式"""
|
||||||
|
import asyncio
|
||||||
response = await client.post("/api/plugins/crawl-all")
|
response = await client.post("/api/plugins/crawl-all")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["code"] == 200
|
assert data["code"] == 200
|
||||||
assert "total_crawled" in data["data"]
|
assert "task_id" in data["data"]
|
||||||
|
|
||||||
|
task_id = data["data"]["task_id"]
|
||||||
|
# 轮询任务状态
|
||||||
|
task_data = None
|
||||||
|
for _ in range(10):
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
res = await client.get(f"/api/tasks/{task_id}")
|
||||||
|
assert res.status_code == 200
|
||||||
|
task_data = res.json()["data"]
|
||||||
|
if task_data["status"] in ("completed", "failed"):
|
||||||
|
break
|
||||||
|
|
||||||
|
assert task_data is not None
|
||||||
|
assert task_data["status"] == "completed"
|
||||||
|
|||||||
Reference in New Issue
Block a user