refactor: 全面重构核心架构,消除反复修改的根因

- 删除 ValidationQueue 双轨持久化队列,替换为纯内存 AsyncWorkerPool
- 引入统一后台任务框架 JobExecutor(Job/CrawlJob/ValidateAllJob)
- 新增 PluginRunner 统一插件执行(超时、重试、健康检查、统计)
- 重构 SchedulerService 职责收敛为仅定时触发 ValidateAllJob
- 使用 AsyncExitStack 重构 lifespan,安全管理长生命周期资源
- 路由层瘦身 50%+,业务异常上抛由全局中间件统一处理
- 实现设置全热更新(WorkerPool 并发、Validator 超时即时生效)
- 前端 Store 强制写后重新拉取,消除乐观更新数据不同步
- 删除 queue.py / task_repo.py / task_service.py
- 新增 execution 单元测试,全部 85 个测试通过
This commit is contained in:
祀梦
2026-04-04 22:36:57 +08:00
parent 4ef7931941
commit b972b64616
33 changed files with 1168 additions and 864 deletions

View File

@@ -0,0 +1,73 @@
import { ref } from 'vue'
import { schedulerService } from '../services/schedulerService'
// Module-level state: shared by every component that calls useScheduler()
// (singleton composable — all consumers see the same scheduler flags).
const schedulerRunning = ref(false)
const schedulerLoading = ref(false)
const validating = ref(false)

/**
 * Composable exposing scheduler status and controls.
 *
 * Returns reactive flags plus actions. Each action accepts an optional
 * `onSuccess(message)` callback invoked with a human-readable message
 * when the backend reports success (code === 200). Errors are logged to
 * the console and swallowed, matching the rest of the service layer.
 */
export function useScheduler() {
  /** Refresh `schedulerRunning` from the backend; silent on failure. */
  async function fetchStatus() {
    try {
      const response = await schedulerService.getStatus()
      if (response.code === 200) {
        schedulerRunning.value = response.data.running
      }
    } catch (error) {
      console.error('获取调度器状态失败:', error)
    }
  }

  /**
   * Shared implementation for start/stop: the two only differ in the
   * service call, the resulting running flag and the success message.
   * @param {() => Promise<object>} action   backend call to perform
   * @param {boolean} nextRunning            value for schedulerRunning on success
   * @param {string} message                 success message passed to onSuccess
   * @param {(msg: string) => void} [onSuccess]
   */
  async function toggleScheduler(action, nextRunning, message, onSuccess) {
    schedulerLoading.value = true
    try {
      const response = await action()
      if (response.code === 200) {
        schedulerRunning.value = nextRunning
        onSuccess?.(message)
      }
    } catch (error) {
      // Preserve the distinct log prefixes of the original start/stop paths.
      console.error(nextRunning ? '启动调度器失败:' : '停止调度器失败:', error)
    } finally {
      schedulerLoading.value = false
    }
  }

  /** Start the auto-validation scheduler. */
  async function startScheduler(onSuccess) {
    await toggleScheduler(() => schedulerService.start(), true, '自动验证已启动', onSuccess)
  }

  /** Stop the auto-validation scheduler. */
  async function stopScheduler(onSuccess) {
    await toggleScheduler(() => schedulerService.stop(), false, '自动验证已停止', onSuccess)
  }

  /** Trigger an immediate full validation pass. */
  async function validateNow(onSuccess) {
    validating.value = true
    try {
      const response = await schedulerService.validateNow()
      if (response.code === 200) {
        onSuccess?.('全量验证已启动')
      }
    } catch (error) {
      console.error('启动验证失败:', error)
    } finally {
      validating.value = false
    }
  }

  return {
    schedulerRunning,
    schedulerLoading,
    validating,
    fetchStatus,
    startScheduler,
    stopScheduler,
    validateNow,
  }
}

View File

@@ -0,0 +1,28 @@
import { tasksAPI } from '../api'
const POLL_INTERVAL = 1000      // ms between status polls
const MAX_POLL_ATTEMPTS = 30    // ~30 s total before giving up

/**
 * Poll a task's status until it completes or fails.
 *
 * Waits one interval before the first poll, then checks up to
 * `maxAttempts` times. Non-200 responses are treated as transient and
 * polling continues. If the task is still running after all attempts, a
 * synthetic "still running" response is returned instead of throwing.
 *
 * The options parameter is new and optional — defaults reproduce the
 * original fixed behavior, so existing callers are unaffected.
 *
 * @param {string} taskId
 * @param {{interval?: number, maxAttempts?: number, api?: object}} [options]
 *   interval: ms between polls; maxAttempts: poll count;
 *   api: object with getTaskStatus(taskId) — injectable for testing.
 * @returns {Promise<object>} final task response, or a synthetic running response on timeout
 */
export async function pollTaskStatus(taskId, options = {}) {
  const {
    interval = POLL_INTERVAL,
    maxAttempts = MAX_POLL_ATTEMPTS,
    api = tasksAPI,
  } = options
  for (let i = 0; i < maxAttempts; i++) {
    await new Promise(resolve => setTimeout(resolve, interval))
    const response = await api.getTaskStatus(taskId)
    if (response.code !== 200) {
      continue // transient backend error: keep polling
    }
    const status = response.data.status
    if (status === 'completed' || status === 'failed') {
      return response
    }
  }
  // Timed out: report as still running rather than raising.
  return {
    code: 200,
    message: '任务进行中,请稍后刷新查看结果',
    data: { task_id: taskId, status: 'running' }
  }
}

View File

@@ -1,26 +1,5 @@
import { pluginsAPI, tasksAPI } from '../api'
const POLL_INTERVAL = 1000
const MAX_POLL_ATTEMPTS = 30
async function pollTaskStatus(taskId) {
for (let i = 0; i < MAX_POLL_ATTEMPTS; i++) {
await new Promise(resolve => setTimeout(resolve, POLL_INTERVAL))
const response = await tasksAPI.getTaskStatus(taskId)
if (response.code !== 200) {
continue
}
const status = response.data.status
if (status === 'completed' || status === 'failed') {
return response
}
}
return {
code: 200,
message: '爬取任务进行中,请稍后刷新查看结果',
data: { task_id: taskId, status: 'running' }
}
}
import { pluginsAPI } from '../api'
import { pollTaskStatus } from '../composables/useTaskPolling'
export const pluginService = {
async getPlugins() {
@@ -54,14 +33,18 @@ export const pluginService = {
async crawlAll() {
const startRes = await pluginsAPI.crawlAll()
if (startRes.code !== 200 || !startRes.data?.task_id) {
if (startRes.code !== 200 || !startRes.data?.task_ids?.length) {
return startRes
}
const finalRes = await pollTaskStatus(startRes.data.task_id)
// 批量轮询所有任务,取最后一个完成的结果
const results = await Promise.all(
startRes.data.task_ids.map(tid => pollTaskStatus(tid))
)
const last = results[results.length - 1]
return {
code: finalRes.code,
message: finalRes.data?.message || finalRes.message,
data: finalRes.data?.data || finalRes.data
code: last.code,
message: last.data?.message || last.message,
data: last.data?.data || last.data
}
}
}

View File

@@ -40,7 +40,7 @@ export const usePluginsStore = defineStore('plugins', () => {
/**
* 切换插件启用状态
* @param {string|number} pluginId
* @param {string} pluginId
* @param {boolean} enabled
* @returns {Promise<boolean>}
*/
@@ -48,10 +48,7 @@ export const usePluginsStore = defineStore('plugins', () => {
try {
const response = await pluginService.togglePlugin(pluginId, enabled)
if (response.code === 200) {
const plugin = plugins.value.find(p => p.id === pluginId)
if (plugin) {
plugin.enabled = enabled
}
await fetchPlugins() // 强制重新拉取最新状态
return true
}
} catch (error) {
@@ -62,7 +59,7 @@ export const usePluginsStore = defineStore('plugins', () => {
/**
* 触发插件爬取
* @param {string|number} pluginId
* @param {string} pluginId
* @returns {Promise<boolean>}
*/
async function crawlPlugin(pluginId) {
@@ -77,7 +74,7 @@ export const usePluginsStore = defineStore('plugins', () => {
/**
* 根据 ID 获取插件
* @param {string|number} id
* @param {string} id
* @returns {object|undefined}
*/
function getPluginById(id) {

View File

@@ -80,7 +80,10 @@ export const useProxyStore = defineStore('proxy', () => {
async function deleteProxy(ip, port) {
try {
const response = await proxyService.deleteProxy(ip, port)
return response.code === 200
if (response.code === 200) {
await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表
return true
}
} catch (error) {
console.error('删除代理失败:', error)
return false
@@ -98,6 +101,7 @@ export const useProxyStore = defineStore('proxy', () => {
try {
const response = await proxyService.batchDelete(proxyList)
if (response.code === 200) {
await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表
return response.data.deleted_count
}
} catch (error) {
@@ -114,6 +118,7 @@ export const useProxyStore = defineStore('proxy', () => {
try {
const response = await proxyService.cleanInvalid()
if (response.code === 200) {
await fetchProxies({ page: 1, page_size: 20 }) // 刷新列表
return response.data.deleted_count
}
} catch (error) {

View File

@@ -191,18 +191,25 @@ import {
Refresh
} from '@element-plus/icons-vue'
import { settingService } from '../services/settingService'
import { schedulerService } from '../services/schedulerService'
import { useScheduler } from '../composables/useScheduler'
import PageHeader from '../components/PageHeader.vue'
// ==================== Composables ====================
const {
schedulerRunning,
schedulerLoading,
validating,
fetchStatus,
startScheduler,
stopScheduler,
validateNow
} = useScheduler()
// ==================== 状态 ====================
const loading = ref(false)
const saving = ref(false)
const formRef = ref(null)
const schedulerRunning = ref(false)
const schedulerLoading = ref(false)
const validating = ref(false)
const settings = reactive({
crawl_timeout: 30,
validation_timeout: 10,
@@ -250,52 +257,13 @@ async function fetchSettings() {
}
}
async function fetchSchedulerStatus() {
try {
const response = await schedulerService.getStatus()
if (response.code === 200) {
schedulerRunning.value = response.data.running
}
} catch (error) {
console.error('获取调度器状态失败:', error)
}
}
// ==================== 调度器控制 ====================
async function handleStartScheduler() {
schedulerLoading.value = true
try {
const response = await schedulerService.start()
if (response.code === 200) {
schedulerRunning.value = true
ElMessage.success('自动验证已启动')
} else {
ElMessage.error('启动失败')
}
} catch (error) {
console.error('启动调度器失败:', error)
ElMessage.error('启动失败')
} finally {
schedulerLoading.value = false
}
await startScheduler((msg) => ElMessage.success(msg))
}
async function handleStopScheduler() {
schedulerLoading.value = true
try {
const response = await schedulerService.stop()
if (response.code === 200) {
schedulerRunning.value = false
ElMessage.success('自动验证已停止')
} else {
ElMessage.error('停止失败')
}
} catch (error) {
console.error('停止调度器失败:', error)
ElMessage.error('停止失败')
} finally {
schedulerLoading.value = false
}
await stopScheduler((msg) => ElMessage.success(msg))
}
async function handleValidateNow() {
@@ -309,21 +277,12 @@ async function handleValidateNow() {
type: 'info'
}
)
validating.value = true
const response = await schedulerService.validateNow()
if (response.code === 200) {
ElMessage.success('全量验证已启动,请在日志中查看进度')
} else {
ElMessage.error('启动验证失败')
}
await validateNow((msg) => ElMessage.success(msg))
} catch (error) {
if (error !== 'cancel') {
console.error('启动验证失败:', error)
ElMessage.error('启动验证失败')
}
} finally {
validating.value = false
}
}
@@ -338,8 +297,7 @@ async function handleSave() {
if (response.code === 200) {
ElMessage.success('配置保存成功')
// 刷新调度器状态
await fetchSchedulerStatus()
await fetchStatus()
} else {
ElMessage.error('配置保存失败')
}
@@ -354,7 +312,7 @@ async function handleSave() {
// ==================== 生命周期 ====================
onMounted(() => {
fetchSettings()
fetchSchedulerStatus()
fetchStatus()
})
</script>

View File

@@ -3,10 +3,8 @@ from fastapi import Request
from app.services.proxy_service import ProxyService
from app.services.plugin_service import PluginService
from app.services.scheduler_service import SchedulerService
from app.services.validator_service import ValidatorService
from app.repositories.proxy_repo import ProxyRepository
from app.core.tasks.queue import ValidationQueue
from app.core.config import settings as app_settings
from app.services.plugin_runner import PluginRunner
from app.core.execution import JobExecutor, AsyncWorkerPool
def get_proxy_service() -> ProxyService:
@@ -18,36 +16,16 @@ def get_plugin_service() -> PluginService:
def get_scheduler_service(request: Request) -> SchedulerService:
return request.app.state.scheduler_service
return request.app.state.scheduler
def get_validation_queue(request: Request) -> ValidationQueue:
return request.app.state.validation_queue
def get_executor(request: Request) -> JobExecutor:
return request.app.state.executor
def create_scheduler_service(db_settings: dict | None = None) -> SchedulerService:
"""在应用启动时创建 SchedulerService非请求上下文
def get_plugin_runner(request: Request) -> PluginRunner:
return request.app.state.plugin_runner
db_settings 来自数据库持久化设置,用于覆盖 .env / 默认值。
"""
db_settings = db_settings or {}
validator = ValidatorService(
timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
connect_timeout=db_settings.get("connect_timeout", app_settings.validator_connect_timeout),
max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
)
proxy_repo = ProxyRepository()
queue = ValidationQueue(
validator=validator,
proxy_repo=proxy_repo,
worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
score_valid=app_settings.score_valid,
score_invalid=app_settings.score_invalid,
score_min=app_settings.score_min,
score_max=app_settings.score_max,
)
svc = SchedulerService(validation_queue=queue, proxy_repo=proxy_repo)
svc.interval_minutes = db_settings.get(
"validate_interval_minutes", 30
)
return svc
def get_worker_pool(request: Request) -> AsyncWorkerPool:
return request.app.state.worker_pool

View File

@@ -1,56 +1,118 @@
"""应用生命周期管理"""
import asyncio
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI
from app.core.db import init_db, get_db
from app.core.config import settings as app_settings
from app.core.log import logger
from app.api.deps import create_scheduler_service
from app.core.execution import AsyncWorkerPool, JobExecutor
from app.core.plugin_system.registry import registry
from app.repositories.proxy_repo import ProxyRepository
from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS
from app.services.validator_service import ValidatorService
from app.services.plugin_runner import PluginRunner
from app.services.scheduler_service import SchedulerService
settings_repo = SettingsRepository()
proxy_repo = ProxyRepository()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用启动和关闭时的生命周期管理"""
# 初始化数据库
await init_db()
# 加载设置并决定是否启动调度器
async def _load_settings() -> dict:
db_settings = DEFAULT_SETTINGS.copy()
try:
async with get_db() as db:
db_settings = await settings_repo.get_all(db)
except Exception as e:
logger.error(f"Failed to load settings on startup: {e}")
return db_settings
# 创建调度器并挂载到 app.state使用 DB 设置覆盖默认值)
scheduler_service = create_scheduler_service(db_settings)
app.state.scheduler_service = scheduler_service
app.state.validation_queue = scheduler_service.validation_queue
if db_settings.get("auto_validate", True):
try:
await scheduler_service.start()
except Exception as e:
logger.error(f"Failed to start scheduler on startup: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用启动和关闭时的生命周期管理"""
await init_db()
db_settings = await _load_settings()
logger.info("API server started")
yield
async with AsyncExitStack() as stack:
# 验证器
validator = ValidatorService(
timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
connect_timeout=app_settings.validator_connect_timeout,
max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
)
# 关闭调度器
scheduler_service.cancel_validate_task()
await scheduler_service.stop()
# 验证 WorkerPool
async def validation_handler(proxy):
from app.models.domain import ProxyRaw
is_valid, latency = await validator.validate(
proxy.ip, proxy.port, proxy.protocol
)
async with get_db() as db:
if is_valid:
await proxy_repo.insert_or_update(
db, proxy.ip, proxy.port, proxy.protocol, score=app_settings.score_valid
)
if latency:
await proxy_repo.update_response_time(db, proxy.ip, proxy.port, latency)
else:
await proxy_repo.update_score(
db, proxy.ip, proxy.port, app_settings.score_invalid,
app_settings.score_min, app_settings.score_max
)
# 关闭验证器 HTTP session
worker_pool = AsyncWorkerPool(
worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
handler=validation_handler,
name="ValidationPool",
)
await stack.enter_async_context(worker_pool)
# Job 执行器
executor = JobExecutor(worker_pool=worker_pool, max_concurrent_jobs=10)
await stack.enter_async_context(executor)
# 插件运行器
plugin_runner = PluginRunner(timeout=db_settings.get("crawl_timeout", 30))
# 调度器
scheduler = SchedulerService(
executor=executor,
interval_minutes=db_settings.get("validate_interval_minutes", 30),
)
# 挂载到 app.state
app.state.validator = validator
app.state.worker_pool = worker_pool
app.state.executor = executor
app.state.plugin_runner = plugin_runner
app.state.scheduler = scheduler
# 启动调度器
if db_settings.get("auto_validate", True):
try:
await scheduler.start()
except Exception as e:
logger.error(f"Failed to start scheduler on startup: {e}")
logger.info("API server started")
yield
# 停止调度器
await scheduler.stop()
# 取消所有运行中的 Job
await executor.cancel_all()
# AsyncExitStack 会自动关闭 executor 和 worker_pool
# 关闭验证器 session
try:
await scheduler_service.validation_queue.validator.close()
await validator.close()
except Exception:
pass
# 关闭所有插件的 HTTP 客户端
from app.core.plugin_system.registry import registry
for plugin in registry.list_plugins():
if hasattr(plugin, "close"):
try:

View File

@@ -43,7 +43,7 @@ def create_app() -> FastAPI:
@app.get("/health")
async def health_check():
from datetime import datetime
scheduler = app.state.scheduler_service
scheduler = app.state.scheduler
return {
"status": "healthy",
"timestamp": datetime.now().isoformat(),

View File

@@ -1,16 +1,25 @@
"""插件相关路由"""
import asyncio
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.services.plugin_service import PluginService
from app.services.scheduler_service import SchedulerService
from app.services.task_service import task_service
from app.api.deps import get_plugin_service, get_scheduler_service
from app.api.common import success_response, error_response, format_plugin
from app.core.log import logger
from app.services.plugin_runner import PluginRunner
from app.core.execution import JobExecutor, CrawlJob
from app.core.exceptions import PluginNotFoundException
from app.api.deps import get_plugin_service, get_plugin_runner, get_executor
from app.api.common import success_response
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
class ToggleRequest(BaseModel):
enabled: bool
class ConfigRequest(BaseModel):
config: dict
@router.get("")
async def list_plugins(service: PluginService = Depends(get_plugin_service)):
plugins = await service.list_plugins()
@@ -20,19 +29,13 @@ async def list_plugins(service: PluginService = Depends(get_plugin_service)):
@router.put("/{plugin_id}/toggle")
async def toggle_plugin(
plugin_id: str,
request: dict,
request: ToggleRequest,
service: PluginService = Depends(get_plugin_service),
):
enabled = request.get("enabled")
if enabled is None:
return error_response("缺少 enabled 参数", 400)
success = await service.toggle_plugin(plugin_id, enabled)
if not success:
return error_response("插件不存在", 404)
await service.toggle_plugin(plugin_id, request.enabled)
return success_response(
f"插件 {plugin_id}{'启用' if enabled else '禁用'}",
{"plugin_id": plugin_id, "enabled": enabled},
f"插件 {plugin_id}{'启用' if request.enabled else '禁用'}",
{"plugin_id": plugin_id, "enabled": request.enabled},
)
@@ -42,128 +45,104 @@ async def get_plugin_config(
service: PluginService = Depends(get_plugin_service),
):
config = await service.get_plugin_config(plugin_id)
if config is None:
return error_response("插件不存在", 404)
return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config})
@router.post("/{plugin_id}/config")
async def update_plugin_config(
plugin_id: str,
request: dict,
request: ConfigRequest,
service: PluginService = Depends(get_plugin_service),
):
config = request.get("config", {})
if not isinstance(config, dict):
return error_response("config 必须是对象", 400)
success = await service.update_plugin_config(plugin_id, config)
success = await service.update_plugin_config(plugin_id, request.config)
if not success:
return error_response("插件不存在或配置无效", 404)
return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config})
raise PluginNotFoundException(plugin_id)
return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": request.config})
@router.post("/{plugin_id}/crawl")
async def crawl_plugin(
plugin_id: str,
plugin_service: PluginService = Depends(get_plugin_service),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
plugin_runner: PluginRunner = Depends(get_plugin_runner),
executor: JobExecutor = Depends(get_executor),
):
plugin = plugin_service.get_plugin(plugin_id)
if not plugin:
return error_response("插件不存在", 404)
task_id = task_service.create_task("crawl_plugin")
async def _background_crawl():
try:
task_service.update_task(task_id, status="running")
results = await plugin_service.run_plugin(plugin_id)
if not results:
task_service.update_task(
task_id,
status="completed",
message=f"插件 {plugin_id} 爬取完成,未获取到代理",
data={"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0},
)
return
logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
scheduler_service.validation_queue.reset_stats()
await scheduler_service.validation_queue.submit(results)
try:
await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
except asyncio.TimeoutError:
pass
task_service.update_task(
task_id,
status="completed",
message=f"插件 {plugin_id} 爬取并验证完成",
data={
"plugin_id": plugin_id,
"proxy_count": len(results),
"valid_count": scheduler_service.validation_queue.valid_count,
"invalid_count": scheduler_service.validation_queue.invalid_count,
},
)
except Exception as e:
logger.error(f"Crawl plugin {plugin_id} failed: {e}")
task_service.update_task(
task_id, status="failed", message=f"插件爬取失败: {str(e)}"
)
asyncio.create_task(_background_crawl())
plugin = plugin_service.get_plugin_or_raise(plugin_id)
job = CrawlJob(
plugin_id=plugin_id,
plugin_runner=plugin_runner,
proxy_service=plugin_service,
validator_pool=executor.worker_pool,
)
job_id = executor.submit_job(job)
return success_response(
"爬取任务已启动", {"task_id": task_id, "plugin_id": plugin_id}
"爬取任务已启动", {"task_id": job_id, "plugin_id": plugin_id}
)
@router.post("/crawl-all")
async def crawl_all(
plugin_service: PluginService = Depends(get_plugin_service),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
plugin_runner: PluginRunner = Depends(get_plugin_runner),
executor: JobExecutor = Depends(get_executor),
):
task_id = task_service.create_task("crawl_all")
"""为所有启用插件创建 CrawlJob返回一个聚合任务 ID 保持 API 兼容"""
from app.core.plugin_system.registry import registry
job_ids = []
for plugin in registry.list_plugins():
if not plugin.enabled:
continue
job = CrawlJob(
plugin_id=plugin.name,
plugin_runner=plugin_runner,
proxy_service=plugin_service,
validator_pool=executor.worker_pool,
)
job_ids.append(executor.submit_job(job))
async def _background_crawl_all():
try:
task_service.update_task(task_id, status="running")
results = await plugin_service.run_all_plugins()
if not results:
task_service.update_task(
task_id,
status="completed",
message="所有插件爬取完成,未获取到代理",
data={"total_crawled": 0, "valid_count": 0, "invalid_count": 0},
# 为了保持 API 兼容(前端/测试期望单个 task_id
# 创建一个虚拟的聚合 Job 用于状态查询,它内部会等待所有 crawl job 完成
aggregator = _create_crawl_all_aggregator(job_ids, executor)
agg_id = executor.submit_job(aggregator)
return success_response("爬取任务已启动", {"task_id": agg_id})
def _create_crawl_all_aggregator(job_ids, executor):
"""创建一个简单的聚合 Job查询所有子 Job 的状态汇总"""
from app.core.execution.job import Job
import asyncio
class CrawlAllAggregator(Job):
async def run(self):
# 等待所有子 job 完成(最多等 30 秒)
for _ in range(300):
all_done = all(
executor.get_job(jid) and executor.get_job(jid).status.value in ("completed", "failed", "cancelled")
for jid in job_ids
)
return
if all_done:
break
await asyncio.sleep(0.1)
total = 0
valid = 0
for jid in job_ids:
job = executor.get_job(jid)
if job and job.result:
total += job.result.get("proxy_count", 0)
valid += job.result.get("success_count", 0)
return {"total_crawled": total, "valid_count": valid, "invalid_count": 0}
logger.info(f"All plugins crawled {len(results)} unique proxies")
scheduler_service.validation_queue.reset_stats()
await scheduler_service.validation_queue.submit(results)
return CrawlAllAggregator()
try:
await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
except asyncio.TimeoutError:
pass
task_service.update_task(
task_id,
status="completed",
message="所有插件爬取并验证完成",
data={
"total_crawled": len(results),
"valid_count": scheduler_service.validation_queue.valid_count,
"invalid_count": scheduler_service.validation_queue.invalid_count,
},
)
except Exception as e:
logger.error(f"Crawl all failed: {e}")
task_service.update_task(
task_id, status="failed", message=f"批量爬取失败: {str(e)}"
)
asyncio.create_task(_background_crawl_all())
return success_response("爬取任务已启动", {"task_id": task_id})
def format_plugin(plugin) -> dict:
return {
"id": plugin.id,
"name": plugin.display_name,
"display_name": plugin.display_name,
"description": plugin.description,
"enabled": plugin.enabled,
"last_run": plugin.last_run.isoformat() if plugin.last_run else None,
"success_count": plugin.success_count,
"failure_count": plugin.failure_count,
}

View File

@@ -1,11 +1,14 @@
"""代理相关路由(含统计信息)"""
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from app.services.proxy_service import ProxyService
from app.services.scheduler_service import SchedulerService
from app.models.schemas import ProxyListRequest, BatchDeleteRequest
from app.api.deps import get_proxy_service, get_scheduler_service
from app.api.common import success_response, error_response, format_proxy
from app.api.common import success_response, format_proxy
from app.core.exceptions import ProxyPoolException, ProxyNotFoundException
router = APIRouter(prefix="/api/proxies", tags=["proxies"])
@@ -49,7 +52,7 @@ async def list_proxies(
async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
proxy = await service.get_random_proxy()
if not proxy:
return error_response("没有找到可用的代理", 404)
raise ProxyNotFoundException("", 0)
return success_response("获取随机代理成功", format_proxy(proxy))
@@ -61,9 +64,7 @@ async def export_proxies(
service: ProxyService = Depends(get_proxy_service),
):
if fmt not in ("csv", "txt", "json"):
return error_response("不支持的导出格式", 400)
from fastapi.responses import StreamingResponse
raise ProxyPoolException("不支持的导出格式", 400)
media_types = {"csv": "text/csv", "txt": "text/plain", "json": "application/json"}

View File

@@ -1,10 +1,11 @@
"""调度器相关路由"""
from fastapi import APIRouter, Depends
from app.services.scheduler_service import SchedulerService
from app.repositories.settings_repo import SettingsRepository
from app.core.db import get_db
from app.api.deps import get_scheduler_service
from app.api.common import success_response, error_response
from app.api.common import success_response
router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
settings_repo = SettingsRepository()
@@ -39,8 +40,8 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_ser
@router.post("/validate-now")
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
await scheduler.validate_all_now()
return success_response("已开始全量验证", {"started": True})
job_id = scheduler.validate_all_now()
return success_response("已开始全量验证", {"started": True, "job_id": job_id})
@router.get("/status")

View File

@@ -3,7 +3,7 @@ from fastapi import APIRouter, Request
from app.core.db import get_db
from app.repositories.settings_repo import SettingsRepository
from app.models.schemas import SettingsSchema
from app.api.common import success_response, error_response
from app.api.common import success_response
from app.core.log import logger
router = APIRouter(prefix="/api/settings", tags=["settings"])
@@ -22,14 +22,32 @@ async def save_settings(request: SettingsSchema, http_request: Request):
async with get_db() as db:
success = await settings_repo.save(db, request.model_dump())
if not success:
return error_response("保存设置失败", 500)
raise RuntimeError("保存设置失败")
# 热更新运行中调度器的间隔时间
scheduler = getattr(http_request.app.state, "scheduler_service", None)
if scheduler and scheduler.running:
scheduler = getattr(http_request.app.state, "scheduler", None)
worker_pool = getattr(http_request.app.state, "worker_pool", None)
validator = getattr(http_request.app.state, "validator", None)
if scheduler:
new_interval = request.validate_interval_minutes
if scheduler.interval_minutes != new_interval:
scheduler.interval_minutes = new_interval
logger.info(f"Scheduler interval updated to {new_interval} minutes")
# 热更新 Worker 池大小
if worker_pool and worker_pool.worker_count != request.default_concurrency:
await worker_pool.resize(request.default_concurrency)
logger.info(f"Worker pool resized to {request.default_concurrency}")
# 热更新验证器超时和并发(下次验证时生效)
if validator:
validator._init_timeout = request.validation_timeout
validator._init_connect_timeout = request.validation_timeout
validator._init_max_concurrency = request.default_concurrency
# 重新创建 semaphore 和 session
validator._semaphore = None
await validator.close()
logger.info(f"Validator config updated: timeout={request.validation_timeout}, concurrency={request.default_concurrency}")
return success_response("保存设置成功", request.model_dump())

View File

@@ -1,14 +1,16 @@
"""异步任务状态路由"""
from fastapi import APIRouter
from app.services.task_service import task_service
from fastapi import APIRouter, Depends
from app.core.execution import JobExecutor
from app.api.deps import get_executor
from app.api.common import success_response, error_response
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
@router.get("/{task_id}")
async def get_task_status(task_id: str):
task = task_service.get_task(task_id)
if not task:
async def get_task_status(task_id: str, executor: JobExecutor = Depends(get_executor)):
job = executor.get_job(task_id)
if not job:
return error_response("任务不存在", 404)
return success_response("获取任务状态成功", task)
return success_response("获取任务状态成功", job.to_dict())

View File

@@ -84,22 +84,7 @@ async def init_db():
await db.execute("ALTER TABLE plugin_settings ADD COLUMN stats_json TEXT DEFAULT '{}'")
logger.info("Migrated: added stats_json column to plugin_settings")
# 验证任务队列表
await db.execute("""
CREATE TABLE IF NOT EXISTS validation_tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
port INTEGER NOT NULL,
protocol TEXT DEFAULT 'http',
status TEXT DEFAULT 'pending',
result TEXT,
response_time_ms REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_status ON validation_tasks(status)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_validation_created ON validation_tasks(created_at)")
# 验证任务队列表已废弃,不再创建;旧表保留在数据库中不影响功能
# 系统设置表
await db.execute("""
@@ -125,3 +110,27 @@ async def get_db() -> AsyncIterator[aiosqlite.Connection]:
yield db
finally:
await db.close()
@asynccontextmanager
async def transaction() -> AsyncIterator[aiosqlite.Connection]:
    """Yield a database connection wrapped in an explicit transaction.

    Commits on normal exit of the ``async with`` body; rolls back and
    re-raises if the body raises. The connection is always closed.

    Usage::

        async with transaction() as db:
            await repo.update(db, ...)
        # any exception inside the block triggers an automatic rollback
    """
    ensure_db_dir()
    db = await aiosqlite.connect(DB_PATH)
    try:
        # PRAGMAs must run before BEGIN: journal_mode cannot be changed
        # inside an open transaction.
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA synchronous=NORMAL")
        # Explicit BEGIN so every statement in the block is in one transaction.
        await db.execute("BEGIN")
        yield db
        await db.commit()
    except Exception:
        await db.rollback()
        raise
    finally:
        await db.close()

View File

@@ -0,0 +1,13 @@
"""Unified background job execution framework.

Re-exports the public surface of the execution package:

- ``Job`` / ``JobStatus`` plus the concrete ``CrawlJob`` / ``ValidateAllJob``
- ``AsyncWorkerPool``: in-memory worker pool for fine-grained work items
- ``JobExecutor``: lifecycle manager for background jobs
"""
from .job import JobStatus, Job, CrawlJob, ValidateAllJob
from .worker_pool import AsyncWorkerPool
from .executor import JobExecutor

__all__ = [
    "JobStatus",
    "Job",
    "CrawlJob",
    "ValidateAllJob",
    "AsyncWorkerPool",
    "JobExecutor",
]

View File

@@ -0,0 +1,138 @@
"""Job 执行器 - 统一管理所有后台 Job 的生命周期"""
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from app.core.execution.job import Job, JobStatus
from app.core.execution.worker_pool import AsyncWorkerPool
from app.core.log import logger
class JobExecutor:
    """Background job executor.

    Responsibilities:

    - Keep an in-memory table of submitted :class:`Job` objects, queryable by id.
    - Cap the number of jobs running concurrently (via a semaphore).
    - Periodically evict finished jobs so the table does not grow unbounded.

    Usable as an async context manager; ``__aenter__``/``__aexit__`` delegate
    to :meth:`start` / :meth:`stop`.
    """

    def __init__(
        self,
        worker_pool: Optional[AsyncWorkerPool] = None,
        max_concurrent_jobs: int = 10,
        cleanup_interval_seconds: int = 300,
    ):
        # Optional shared pool that jobs may use for fine-grained work items.
        self.worker_pool = worker_pool
        self.max_concurrent_jobs = max_concurrent_jobs
        self._jobs: Dict[str, Job] = {}
        self._running = False
        self._semaphore = asyncio.Semaphore(max_concurrent_jobs)
        self._cleanup_interval = cleanup_interval_seconds
        self._cleanup_task: Optional[asyncio.Task] = None
        # Strong references to in-flight job tasks. A bare
        # asyncio.create_task() result may be garbage-collected mid-run if
        # nothing holds a reference to it (documented asyncio pitfall).
        self._job_tasks: set = set()

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    async def start(self) -> None:
        """Start the background cleanup loop. Idempotent."""
        if self._running:
            return
        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("JobExecutor started")

    async def stop(self) -> None:
        """Stop the executor: flag-cancel jobs, cancel in-flight tasks, stop cleanup."""
        if not self._running:
            return
        self._running = False
        await self.cancel_all()
        # Actually cancel the asyncio tasks driving the jobs; previously
        # nothing ever cancelled them, so the CancelledError branch in
        # _run_job was unreachable and running jobs outlived the executor.
        for task in list(self._job_tasks):
            task.cancel()
        if self._job_tasks:
            await asyncio.gather(*self._job_tasks, return_exceptions=True)
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
        logger.info("JobExecutor stopped")

    def submit_job(self, job: Job) -> str:
        """Register *job* and schedule it for background execution.

        Returns the job id immediately; poll :meth:`get_job` for status.
        """
        self._jobs[job.id] = job
        task = asyncio.create_task(self._run_job(job))
        # Keep a strong reference until the task finishes, then drop it.
        self._job_tasks.add(task)
        task.add_done_callback(self._job_tasks.discard)
        return job.id

    async def _run_job(self, job: Job) -> None:
        """Drive a single job under the concurrency semaphore, recording its outcome."""
        async with self._semaphore:
            try:
                if job.is_cancelled:
                    logger.info(f"Job {job.id} was cancelled before running")
                    return
                result = await job.run()
                # If the subclass did not set a terminal status itself,
                # mark the job completed with its return value.
                if job.status not in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
                    job._set_completed(result)
                logger.info(f"Job {job.id} completed: {result}")
            except asyncio.CancelledError:
                job.status = JobStatus.CANCELLED
                job._touch()
                logger.info(f"Job {job.id} cancelled during execution")
            except Exception as e:
                job._set_failed(str(e))
                logger.error(f"Job {job.id} failed: {e}", exc_info=True)

    def get_job(self, job_id: str) -> Optional[Job]:
        """Look up a job by id; None if unknown or already cleaned up."""
        return self._jobs.get(job_id)

    def list_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 100,
    ) -> List[Job]:
        """Return up to *limit* jobs, optionally filtered by status, newest first."""
        jobs = list(self._jobs.values())
        if status:
            jobs = [j for j in jobs if j.status == status]
        jobs.sort(key=lambda j: j.created_at, reverse=True)
        return jobs[:limit]

    async def cancel_job(self, job_id: str) -> bool:
        """Request cooperative cancellation of one job; False if id unknown."""
        job = self._jobs.get(job_id)
        if not job:
            return False
        job.cancel()
        return True

    async def cancel_all(self) -> int:
        """Flag every pending/running job as cancelled; returns the count."""
        cancelled = 0
        for job in list(self._jobs.values()):
            if job.status in (JobStatus.PENDING, JobStatus.RUNNING):
                job.cancel()
                cancelled += 1
        return cancelled

    async def _cleanup_loop(self) -> None:
        """Periodically evict finished jobs older than the retention window."""
        while self._running:
            try:
                await asyncio.sleep(self._cleanup_interval)
                self._cleanup_old_jobs()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"JobExecutor cleanup error: {e}")

    def _cleanup_old_jobs(self) -> int:
        # NOTE(review): the retention window reuses the cleanup interval, so a
        # finished job is kept for at least one interval before eviction.
        cutoff = datetime.now() - timedelta(seconds=self._cleanup_interval)
        to_remove = [
            job_id
            for job_id, job in self._jobs.items()
            if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and job.updated_at < cutoff
        ]
        for job_id in to_remove:
            del self._jobs[job_id]
        if to_remove:
            logger.info(f"JobExecutor cleaned up {len(to_remove)} old jobs")
        return len(to_remove)

159
app/core/execution/job.py Normal file
View File

@@ -0,0 +1,159 @@
"""Job 定义 - 所有后台异步任务的统一抽象"""
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from app.models.domain import ProxyRaw
class JobStatus(Enum):
    """Lifecycle states of a background job."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class Job(ABC):
    """Abstract base class for every background task.

    Holds lifecycle state, progress and timestamps; concrete jobs implement
    the actual work in :meth:`run`.
    """

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0  # percentage in the range 0-100
    result: Any = None
    error: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    _cancelled: bool = field(default=False, repr=False)

    def cancel(self) -> None:
        """Request cancellation; a still-active job switches to CANCELLED."""
        self._cancelled = True
        if self.status is JobStatus.PENDING or self.status is JobStatus.RUNNING:
            self.status = JobStatus.CANCELLED
            self._touch()

    def _touch(self) -> None:
        """Refresh the last-modified timestamp."""
        self.updated_at = datetime.now()

    def _set_running(self) -> None:
        """Transition into the RUNNING state."""
        self.status = JobStatus.RUNNING
        self._touch()

    def _set_completed(self, result: Any = None) -> None:
        """Transition into COMPLETED, storing the result and full progress."""
        self.result = result
        self.progress = 100.0
        self.status = JobStatus.COMPLETED
        self._touch()

    def _set_failed(self, error: str) -> None:
        """Transition into FAILED, recording the error message."""
        self.error = error
        self.status = JobStatus.FAILED
        self._touch()

    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation has been requested for this job."""
        return self._cancelled

    @abstractmethod
    async def run(self) -> Any:
        """Execute the job's core logic; concrete subclasses must implement."""
        raise NotImplementedError

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the job state into a JSON-friendly dict."""
        snapshot: Dict[str, Any] = {
            "id": self.id,
            "status": self.status.value,
            "progress": round(self.progress, 2),
            "result": self.result,
            "error": self.error,
        }
        snapshot["created_at"] = self.created_at.isoformat()
        snapshot["updated_at"] = self.updated_at.isoformat()
        return snapshot
@dataclass
class CrawlJob(Job):
    """Plugin crawl job: run one crawler plugin, then hand results to validation.

    Collaborators are injected as plain ``Any`` attributes (excluded from
    ``repr``) to avoid import cycles at module load time.
    """

    plugin_id: str = ""
    # PluginRunner instance; required — run() raises RuntimeError without it.
    plugin_runner: Any = field(repr=False, default=None)
    # NOTE(review): despite the name, this attribute is used as a
    # *PluginService* below (falling back to a fresh PluginService when
    # unset) — confirm intent and consider renaming.
    proxy_service: Any = field(repr=False, default=None)
    # Worker pool that receives crawled proxies for validation; optional.
    validator_pool: Any = field(repr=False, default=None)

    async def run(self) -> Dict[str, Any]:
        """Crawl via the plugin runner and submit results for validation.

        Returns:
            Payload dict with ``plugin_id`` and ``proxy_count``, plus
            ``success_count``/``failure_count`` when the runner produced a result.

        Raises:
            RuntimeError: if no ``plugin_runner`` was injected.
            ValueError: if ``plugin_id`` does not resolve to a plugin.
        """
        # Local imports avoid circular dependencies between job and service layers.
        from app.services.plugin_service import PluginService
        from app.core.log import logger

        self._set_running()
        if not self.plugin_runner:
            raise RuntimeError("plugin_runner is not set")
        plugin_service = self.proxy_service or PluginService()
        plugin = plugin_service.get_plugin(self.plugin_id)
        if not plugin:
            raise ValueError(f"Plugin '{self.plugin_id}' not found")
        result = await self.plugin_runner.run(plugin)
        proxies: List[ProxyRaw] = result.proxies if result else []
        if proxies and self.validator_pool:
            # Fire into the validation pool; validation happens asynchronously.
            await self.validator_pool.submit(proxies)
            logger.info(f"CrawlJob {self.id}: submitted {len(proxies)} proxies for validation")
        payload = {
            "plugin_id": self.plugin_id,
            "proxy_count": len(proxies),
        }
        if result:
            payload["success_count"] = result.success_count
            payload["failure_count"] = result.failure_count
        self._set_completed(payload)
        return payload
@dataclass
class ValidateAllJob(Job):
    """Full revalidation job: submit every stored proxy to the validator pool.

    Proxies are loaded once and then submitted in batches so that progress
    can be reported and cancellation honoured between batches.
    """

    # Optional injected ProxyRepository; a fresh one is created when unset.
    proxy_repo: Any = field(repr=False, default=None)
    # Worker pool that performs the actual validation; optional.
    validator_pool: Any = field(repr=False, default=None)
    batch_size: int = 100

    async def run(self) -> Dict[str, Any]:
        """Load all proxies and push them batch-by-batch to the validator pool.

        Returns:
            ``{"total": <count>, "submitted": <count>}``.
        """
        # Local imports avoid circular dependencies between job and data layers.
        from app.repositories.proxy_repo import ProxyRepository
        from app.core.db import get_db
        from app.core.log import logger

        self._set_running()
        repo = self.proxy_repo or ProxyRepository()
        async with get_db() as db:
            proxies = await repo.list_all(db)
        if not proxies:
            self._set_completed({"total": 0, "submitted": 0})
            return self.result
        total = len(proxies)
        submitted = 0
        for i in range(0, total, self.batch_size):
            if self.is_cancelled:
                logger.info(f"ValidateAllJob {self.id}: cancelled")
                break
            batch = proxies[i : i + self.batch_size]
            raws = [ProxyRaw(p.ip, p.port, p.protocol) for p in batch]
            if self.validator_pool:
                await self.validator_pool.submit(raws)
            submitted += len(raws)
            self.progress = min(100.0, (submitted / total) * 100)
            self._touch()
        payload = {"total": total, "submitted": submitted}
        if self.is_cancelled:
            # Bug fix: the original unconditionally called _set_completed here,
            # flipping a CANCELLED job back to COMPLETED with 100% progress.
            # Preserve the cancelled state and only record the partial result.
            self.result = payload
            self._touch()
        else:
            self._set_completed(payload)
        logger.info(f"ValidateAllJob {self.id}: submitted {submitted}/{total} proxies")
        return payload

View File

@@ -0,0 +1,97 @@
"""轻量级纯内存异步 Worker Pool"""
import asyncio
from typing import Callable, Coroutine, List, TypeVar
from app.core.log import logger
T = TypeVar("T")
Handler = Callable[[T], Coroutine[None, None, None]]
class AsyncWorkerPool:
    """Lightweight, pure in-memory async worker pool.

    Single responsibility: accept batches of items, fan them out to N worker
    coroutines, and call ``handler`` on each. No persistence, no state machine.

    Note: ``None`` is reserved as the internal shutdown sentinel, so callers
    must never submit ``None`` items.
    """

    def __init__(
        self,
        worker_count: int,
        handler: Handler,
        queue_maxsize: int = 10000,
        name: str = "WorkerPool",
    ):
        self.worker_count = worker_count
        self.handler = handler
        self.name = name
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
        self._workers: List[asyncio.Task] = []
        self._running = False
        self._shutdown_event = asyncio.Event()

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    async def start(self) -> None:
        """Spawn the worker tasks; idempotent."""
        if self._running:
            return
        self._running = True
        self._shutdown_event.clear()
        self._workers = [
            asyncio.create_task(self._worker_loop(i), name=f"{self.name}-worker-{i}")
            for i in range(self.worker_count)
        ]
        logger.info(f"{self.name} started with {self.worker_count} workers")

    async def stop(self) -> None:
        """Stop all workers, then purge stale sentinels; idempotent."""
        if not self._running:
            return
        self._running = False
        self._shutdown_event.set()
        # Enqueue enough sentinels that every worker is guaranteed to wake up
        # even if real items are still queued ahead of the sentinels.
        for _ in range(len(self._workers) + self._queue.qsize()):
            try:
                self._queue.put_nowait(None)
            except asyncio.QueueFull:
                break
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        # Bug fix: remove sentinels the workers did not consume. Without this,
        # a later start()/resize() had fresh workers pop a stale ``None`` and
        # exit immediately, silently shrinking the pool.
        leftovers: List = []
        while not self._queue.empty():
            item = self._queue.get_nowait()
            self._queue.task_done()
            if item is not None:
                leftovers.append(item)
        for item in leftovers:
            self._queue.put_nowait(item)
        logger.info(f"{self.name} stopped")

    async def submit(self, items: List[T]) -> None:
        """Enqueue a batch; blocks when the queue is full (natural backpressure)."""
        for item in items:
            await self._queue.put(item)

    async def drain(self) -> None:
        """Wait until every queued item has been processed."""
        await self._queue.join()

    async def resize(self, new_worker_count: int) -> None:
        """Change the worker count: stop all workers, restart with the new count."""
        if new_worker_count == self.worker_count:
            return
        logger.info(f"{self.name} resizing from {self.worker_count} to {new_worker_count}")
        # Full stop/start avoids old and new workers racing for sentinels.
        await self.stop()
        self.worker_count = new_worker_count
        await self.start()

    async def _worker_loop(self, worker_id: int) -> None:
        """Consume items until a sentinel arrives or the pool is stopped."""
        while True:
            item = await self._queue.get()
            if item is None or not self._running:
                # Bug fix: mark this slot done exactly once and exit. The
                # original called task_done() here AND again in ``finally``,
                # raising ValueError on every shutdown.
                self._queue.task_done()
                break
            try:
                await self.handler(item)
            except Exception as e:
                logger.error(f"{self.name} worker {worker_id} handler error: {e}", exc_info=True)
            finally:
                self._queue.task_done()

View File

@@ -45,6 +45,11 @@ class PluginRegistry:
def get_plugin_names(self) -> List[str]:
    """Return the names of all registered plugins."""
    return list(self._plugins.keys())
def clear(self) -> None:
    """Remove every registered plugin class and cached instance (mainly for tests)."""
    self._plugins.clear()
    self._instances.clear()
def auto_discover(self, package_name: str):
"""自动扫描指定包下的所有模块并注册其中的插件类。
注意为了类型安全和可控性推荐显式注册。auto_discover 仅作为兼容。"""

View File

@@ -1,168 +0,0 @@
"""验证任务队列 - 解耦爬取与验证,支持背压控制和持久化"""
import asyncio
from typing import Optional
from app.models.domain import ProxyRaw
from app.repositories.task_repo import ValidationTaskRepository
from app.core.db import get_db
from app.core.log import logger
class ValidationQueue:
"""代理验证队列(支持持久化到 SQLite
工作流程:
1. 爬虫将原始代理 submit() 到队列(写入数据库 + 内存信号)
2. Worker 池从数据库消费并验证
3. 验证通过的代理写入数据库
4. 服务重启时自动恢复未完成的 pending 任务
"""
def __init__(
self,
validator,
proxy_repo,
worker_count: int = 50,
score_valid: int = 10,
score_invalid: int = -5,
score_min: int = 0,
score_max: int = 100,
):
self.validator = validator
self.proxy_repo = proxy_repo
self.task_repo = ValidationTaskRepository()
self.worker_count = worker_count
self.score_valid = score_valid
self.score_invalid = score_invalid
self.score_min = score_min
self.score_max = score_max
self._signal: asyncio.Queue[None] = asyncio.Queue()
self._workers: list[asyncio.Task] = []
self._running = False
self._pending_count = 0
self._condition = asyncio.Condition()
# 统计
self.valid_count = 0
self.invalid_count = 0
async def start(self):
if self._running:
return
self._running = True
# 恢复之前中断的 processing 任务
async with get_db() as db:
recovered = await self.task_repo.reset_processing(db)
pending = await self.task_repo.get_pending_count(db)
if pending > 1000:
logger.warning(f"ValidationQueue has {pending} pending tasks, cleaning up all pending tasks...")
await db.execute("DELETE FROM validation_tasks WHERE status = 'pending'")
await db.commit()
pending = await self.task_repo.get_pending_count(db)
logger.info(f"ValidationQueue cleaned up pending tasks, remaining: {pending}")
if recovered:
logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
if pending:
logger.info(f"ValidationQueue has {pending} pending tasks to process")
async with self._condition:
self._pending_count = pending
for i in range(self.worker_count):
self._workers.append(asyncio.create_task(self._worker_loop(i)))
# 唤醒 Worker 处理恢复的 pending 任务(每个 Worker 一次唤醒即可,内部会循环处理)
if pending:
for _ in range(self.worker_count):
self._signal.put_nowait(None)
logger.info(f"ValidationQueue started with {self.worker_count} workers")
async def stop(self):
if not self._running:
return
self._running = False
for _ in self._workers:
self._signal.put_nowait(None) # sentinel
if self._workers:
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
logger.info("ValidationQueue stopped")
async def submit(self, proxies: list[ProxyRaw]):
"""提交代理到验证队列(持久化 + 唤醒 Worker"""
async with get_db() as db:
inserted = await self.task_repo.insert_batch(db, proxies)
if inserted:
async with self._condition:
self._pending_count += inserted
self._condition.notify_all()
for _ in range(min(inserted, self.worker_count)):
self._signal.put_nowait(None)
async def submit_one(self, proxy: ProxyRaw):
await self.submit([proxy])
async def drain(self):
"""等待队列中当前所有 pending 任务处理完毕"""
async with self._condition:
if self._pending_count > 0:
await self._condition.wait_for(lambda: self._pending_count == 0)
async def _worker_loop(self, worker_id: int):
while True:
await self._signal.get()
self._signal.task_done()
if not self._running:
break
# 持续处理任务直到没有 pending 为止,避免信号数不足导致任务饿死
while self._running:
processed = await self._process_one_task(worker_id)
if not processed:
break
async def _process_one_task(self, worker_id: int) -> bool:
"""从数据库取一个任务并验证。返回 True 表示确实处理了一个任务。"""
async with get_db() as db:
task = await self.task_repo.acquire_pending(db)
if not task:
return False
proxy = ProxyRaw(task["ip"], task["port"], task["protocol"])
try:
is_valid, latency = await self.validator.validate(
proxy.ip, proxy.port, proxy.protocol
)
except Exception as e:
logger.error(f"Worker {worker_id} validation error: {e}", exc_info=True)
is_valid, latency = False, 0.0
if is_valid:
await self.proxy_repo.insert_or_update(
db, proxy.ip, proxy.port, proxy.protocol, score=self.score_valid
)
if latency:
await self.proxy_repo.update_response_time(
db, proxy.ip, proxy.port, latency
)
await self.task_repo.complete_task(db, task["id"], True, latency)
self.valid_count += 1
logger.debug(f"ValidationQueue: valid {proxy.ip}:{proxy.port}")
else:
# 对已有代理扣分,分数<=0时自动删除
await self.proxy_repo.update_score(
db, proxy.ip, proxy.port, self.score_invalid,
self.score_min, self.score_max
)
await self.task_repo.complete_task(db, task["id"], False, 0.0)
self.invalid_count += 1
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
async with self._condition:
self._pending_count = max(0, self._pending_count - 1)
if self._pending_count == 0:
self._condition.notify_all()
return True
def reset_stats(self):
self.valid_count = 0
self.invalid_count = 0

View File

@@ -1,7 +1,7 @@
"""领域模型 - 纯数据结构,不依赖任何框架"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
from typing import List, Optional
@dataclass
@@ -42,3 +42,13 @@ class PluginInfo:
last_run: Optional[datetime] = None
success_count: int = 0
failure_count: int = 0
@dataclass
class CrawlResult:
    """Result of a single plugin crawl run."""

    plugin_name: str
    # Proxies produced by the crawl (de-duplicated by the runner).
    proxies: List[ProxyRaw] = field(default_factory=list)
    # NOTE(review): set as 0/1 flags per run by PluginRunner, not cumulative
    # totals — confirm downstream consumers expect flags.
    success_count: int = 0
    failure_count: int = 0
    # Error description when the crawl failed or was rejected; None on success.
    error: Optional[str] = None

View File

@@ -1,9 +1,15 @@
"""通用 HTTP 爬虫基类 - 为基于 HTTP 请求的插件提供封装"""
import re
import random
import asyncio
import httpx
from typing import List, Optional
from bs4 import BeautifulSoup
from app.core.plugin_system import BaseCrawlerPlugin
from app.models.domain import ProxyRaw
VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
class BaseHTTPPlugin(BaseCrawlerPlugin):
@@ -20,6 +26,7 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
self.urls: List[str] = []
self.current_url: str = ""
self._client: Optional[httpx.AsyncClient] = None
self.max_concurrency: int = 3
def get_headers(self) -> dict:
return {
@@ -39,13 +46,21 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
)
return self._client
async def fetch(self, url: str, timeout: float = 15.0, retries: int = 2) -> str:
async def fetch(
self,
url: str,
timeout: float = 15.0,
retries: int = 2,
raise_for_status: bool = False,
) -> str:
"""异步抓取指定 URL 的 HTML 内容"""
from app.core.log import logger
client = self._get_client()
for attempt in range(retries):
try:
response = await client.get(url, headers=self.get_headers(), timeout=timeout)
if raise_for_status:
response.raise_for_status()
if response.status_code == 200:
content = response.content
encoding = response.encoding
@@ -64,8 +79,8 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
return ""
async def fetch_all(self, urls: List[str], timeout: float = 15.0) -> List[str]:
"""并发抓取多个 URL限制单个插件内部并发为 3"""
semaphore = asyncio.Semaphore(3)
"""并发抓取多个 URL限制单个插件内部并发"""
semaphore = asyncio.Semaphore(self.max_concurrency)
async def _fetch_limited(url: str):
async with semaphore:
@@ -74,6 +89,70 @@ class BaseHTTPPlugin(BaseCrawlerPlugin):
tasks = [_fetch_limited(url) for url in urls]
return await asyncio.gather(*tasks)
def parse_text_proxies(self, text: str, protocol: str = "http") -> List[ProxyRaw]:
    """Parse an ``ip:port``-per-line text block into ProxyRaw objects.

    Normalizes CRLF/CR line endings to LF and skips blank or malformed lines.
    ``rpartition(":")`` takes the part after the *last* colon as the port.
    Only ports in 1..65535 are accepted.

    Args:
        text: Raw text, one proxy per line.
        protocol: Protocol assigned to every parsed proxy.
    """
    results = []
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    for line in text.split("\n"):
        line = line.strip()
        if not line or ":" not in line:
            continue
        ip, _, port = line.rpartition(":")
        ip = ip.strip()
        port = port.strip()
        if ip and port.isdigit() and 1 <= int(port) <= 65535:
            try:
                results.append(ProxyRaw(ip, int(port), protocol))
            except ValueError:
                # ProxyRaw may validate its fields; skip entries it rejects.
                continue
    return results
def parse_html_table(
    self,
    html: str,
    column_map: dict,
    protocol: str = "http",
) -> List[ProxyRaw]:
    """Generic HTML table parser: extract proxies from the first ``<table>``.

    Args:
        html: HTML text to parse.
        column_map: Column-name-to-index mapping, e.g.
            ``{"ip": 0, "port": 1, "protocol": 4}``.
        protocol: Default protocol used when the table has no protocol column
            or the cell value is unrecognized.

    Returns:
        ProxyRaw entries for rows with a dotted-quad IPv4 address and a port
        in 1..65535; other rows are skipped.
    """
    results = []
    soup = BeautifulSoup(html, "lxml")
    table = soup.find("table")
    if not table:
        return results
    ip_idx = column_map.get("ip", 0)
    port_idx = column_map.get("port", 1)
    # -1 means "no protocol column in this table".
    protocol_idx = column_map.get("protocol", -1)
    for row in table.find_all("tr"):
        tds = row.find_all("td")
        if len(tds) <= max(ip_idx, port_idx):
            # Header rows (<th> only) and malformed rows land here.
            continue
        ip = tds[ip_idx].get_text(strip=True)
        port = tds[port_idx].get_text(strip=True)
        if protocol_idx >= 0 and len(tds) > protocol_idx:
            proto = tds[protocol_idx].get_text(strip=True).lower()
            if proto not in VALID_PROTOCOLS:
                # Unrecognized cell value: fall back to the default protocol.
                proto = protocol
        else:
            proto = protocol
        if re.match(r"^\d+\.\d+\.\d+\.\d+$", ip) and port.isdigit() and 1 <= int(port) <= 65535:
            try:
                results.append(ProxyRaw(ip, int(port), proto))
            except ValueError:
                continue
    return results
async def close(self):
"""关闭复用的 HTTP 客户端"""
if self._client and not self._client.is_closed:

View File

@@ -1,11 +1,9 @@
"""数据访问层包"""
from .proxy_repo import ProxyRepository
from .settings_repo import SettingsRepository, PluginSettingsRepository
from .task_repo import ValidationTaskRepository
__all__ = [
"ProxyRepository",
"SettingsRepository",
"PluginSettingsRepository",
"ValidationTaskRepository",
]

View File

@@ -1,130 +0,0 @@
"""验证任务队列持久化层"""
import aiosqlite
from typing import List, Optional
from app.models.domain import ProxyRaw
from app.core.log import logger
class ValidationTaskRepository:
"""验证任务 Repository —— 支持队列持久化"""
@staticmethod
async def insert_batch(db: aiosqlite.Connection, proxies: List[ProxyRaw]) -> int:
if not proxies:
return 0
try:
rows = [(p.ip, p.port, p.protocol) for p in proxies]
await db.executemany(
"""
INSERT INTO validation_tasks (ip, port, protocol, status, created_at)
VALUES (?, ?, ?, 'pending', CURRENT_TIMESTAMP)
""",
rows,
)
await db.commit()
return len(rows)
except Exception as e:
logger.error(f"insert_batch validation tasks failed: {e}")
return 0
@staticmethod
async def acquire_pending(db: aiosqlite.Connection) -> Optional[dict]:
"""原子性地获取一个 pending 任务并将其标记为 processing"""
try:
async with db.execute(
"""
UPDATE validation_tasks
SET status = 'processing', updated_at = CURRENT_TIMESTAMP
WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1)
RETURNING id, ip, port, protocol
"""
) as cursor:
row = await cursor.fetchone()
if not row:
return None
return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]}
except Exception as e:
logger.error(f"acquire_pending failed: {e}")
return None
@staticmethod
async def complete_task(
db: aiosqlite.Connection,
task_id: int,
is_valid: bool,
response_time_ms: Optional[float] = None,
) -> bool:
try:
await db.execute(
"""
UPDATE validation_tasks
SET status = 'completed',
result = ?,
response_time_ms = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
("valid" if is_valid else "invalid", response_time_ms, task_id),
)
await db.commit()
return True
except Exception as e:
logger.error(f"complete_task failed: {e}")
return False
@staticmethod
async def fail_task(db: aiosqlite.Connection, task_id: int) -> bool:
try:
await db.execute(
"""
UPDATE validation_tasks
SET status = 'failed',
result = 'invalid',
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(task_id,),
)
await db.commit()
return True
except Exception as e:
logger.error(f"fail_task failed: {e}")
return False
@staticmethod
async def get_pending_count(db: aiosqlite.Connection) -> int:
async with db.execute(
"SELECT COUNT(*) FROM validation_tasks WHERE status = 'pending'"
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
@staticmethod
async def reset_processing(db: aiosqlite.Connection) -> int:
"""将异常中断的 processing 任务重置为 pending用于启动恢复"""
try:
await db.execute(
"""
UPDATE validation_tasks
SET status = 'pending', updated_at = CURRENT_TIMESTAMP
WHERE status = 'processing'
"""
)
await db.commit()
return db.total_changes
except Exception as e:
logger.error(f"reset_processing failed: {e}")
return 0
@staticmethod
async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
try:
await db.execute(
"DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')",
(days,),
)
await db.commit()
return db.total_changes
except Exception as e:
logger.error(f"cleanup_old tasks failed: {e}")
return 0

View File

@@ -0,0 +1,90 @@
"""插件统一执行器 - 封装超时、重试、健康检查、错误捕获"""
import asyncio
from datetime import datetime
from typing import Optional
from app.core.plugin_system.base import BaseCrawlerPlugin
from app.core.config import settings as app_settings
from app.core.log import logger
from app.models.domain import CrawlResult, ProxyRaw
class PluginRunner:
    """Unified plugin executor.

    - Timeout control (falls back to the settings default when not given)
    - Exception capture and per-plugin stats persistence
    - Optional health-check gate before crawling
    - Result de-duplication by (ip, port)
    """

    def __init__(self, timeout: Optional[float] = None):
        # Bug fix: use an explicit None check so an intentional timeout of 0
        # is not silently replaced by the settings default ("or" treats 0 as
        # falsy).
        if timeout is not None:
            self.timeout = timeout
        else:
            # NOTE(review): looked up as "crawler_timeout" while the commit
            # message mentions "crawl_timeout" — confirm the settings key.
            self.timeout = getattr(app_settings, "crawler_timeout", 30)

    async def run(self, plugin: BaseCrawlerPlugin) -> CrawlResult:
        """Run one plugin crawl; always returns a CrawlResult, never raises.

        A *negative* health check aborts the crawl, but an *error* during the
        check (including its 5s timeout) is only logged and crawling proceeds.
        """
        result = CrawlResult(plugin_name=plugin.name)
        # Optional health-check gate
        try:
            healthy = await asyncio.wait_for(
                plugin.health_check(), timeout=5.0
            )
            if not healthy:
                result.error = "health check failed"
                result.failure_count = 1
                await self._save_stats(plugin, result)
                return result
        except Exception as e:
            logger.warning(f"Plugin {plugin.name} health check error: {e}")
        # Crawl with timeout
        try:
            proxies = await asyncio.wait_for(
                plugin.crawl(),
                timeout=self.timeout,
            )
            result.proxies = self._dedup(proxies)
            # success_count is a 0/1 flag per run, not a proxy total.
            result.success_count = 1 if result.proxies else 0
            logger.info(
                f"Plugin {plugin.name} crawled {len(result.proxies)} unique proxies"
            )
        except asyncio.TimeoutError:
            result.error = f"crawl timeout after {self.timeout}s"
            result.failure_count = 1
            logger.error(f"Plugin {plugin.name} crawl timeout")
        except Exception as e:
            result.error = str(e)
            result.failure_count = 1
            logger.error(f"Plugin {plugin.name} crawl failed: {e}", exc_info=True)
        await self._save_stats(plugin, result)
        return result

    @staticmethod
    def _dedup(proxies: list[ProxyRaw]) -> list[ProxyRaw]:
        """Drop duplicates, keeping the first occurrence of each (ip, port)."""
        seen = set()
        unique = []
        for p in proxies:
            key = (p.ip, p.port)
            if key not in seen:
                seen.add(key)
                unique.append(p)
        return unique

    async def _save_stats(self, plugin: BaseCrawlerPlugin, result: CrawlResult) -> None:
        """Persist crawl statistics for the plugin; failures are only logged."""
        # Local imports avoid circular dependencies at module load time.
        from app.core.db import get_db
        from app.repositories.settings_repo import PluginSettingsRepository

        repo = PluginSettingsRepository()
        payload = {
            "success_count": result.success_count,
            "failure_count": result.failure_count,
            "last_run": datetime.now().isoformat(),
        }
        try:
            async with get_db() as db:
                await repo.set_stats(db, plugin.name, payload)
        except Exception as e:
            logger.error(f"Failed to save stats for {plugin.name}: {e}")

View File

@@ -1,21 +1,21 @@
"""插件业务服务"""
import asyncio
from datetime import datetime
from typing import List, Optional
from app.core.db import get_db
from app.core.plugin_system.registry import registry
from app.core.plugin_system.base import BaseCrawlerPlugin
from app.core.exceptions import PluginNotFoundException
from app.repositories.settings_repo import PluginSettingsRepository
from app.models.domain import PluginInfo, ProxyRaw
from app.models.domain import PluginInfo, ProxyRaw, CrawlResult
from app.core.log import logger
class PluginService:
"""插件业务服务:管理插件生命周期、执行爬取、配置管理"""
"""插件业务服务:管理插件生命周期、配置管理"""
def __init__(self):
self.plugin_settings_repo = PluginSettingsRepository()
self._stats: dict[str, dict] = {}
async def list_plugins(self) -> List[PluginInfo]:
"""获取所有插件信息(合并持久化状态和配置)"""
@@ -24,42 +24,40 @@ class PluginService:
result = []
for plugin in registry.list_plugins():
# 合并持久化状态(不修改全局实例,避免并发竞争)
state = db_states.get(plugin.name, {})
enabled = state.get("enabled", plugin.enabled)
if "config" in state and isinstance(state["config"], dict):
plugin.update_config(state["config"])
# 合并数据库统计与内存统计(内存优先)
db_stat = state.get("stats", {})
stat = {
"success_count": db_stat.get("success_count", 0),
"failure_count": db_stat.get("failure_count", 0),
"last_run": datetime.fromisoformat(db_stat["last_run"]) if db_stat.get("last_run") else None,
}
mem_stat = self._stats.get(plugin.name, {})
if mem_stat:
stat["success_count"] = mem_stat.get("success_count", stat["success_count"])
stat["failure_count"] = mem_stat.get("failure_count", stat["failure_count"])
if mem_stat.get("last_run"):
stat["last_run"] = mem_stat["last_run"]
result.append(PluginInfo(
id=plugin.name,
name=plugin.name,
display_name=plugin.display_name or plugin.name,
description=plugin.description or f"{plugin.name} 爬取代理",
enabled=enabled,
last_run=stat.get("last_run"),
success_count=stat.get("success_count", 0),
failure_count=stat.get("failure_count", 0),
last_run=self._parse_datetime(db_stat.get("last_run")),
success_count=db_stat.get("success_count", 0),
failure_count=db_stat.get("failure_count", 0),
))
return result
@staticmethod
def _parse_datetime(value) -> Optional:
from datetime import datetime
if not value:
return None
if isinstance(value, datetime):
return value
try:
return datetime.fromisoformat(value)
except Exception:
return None
async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
plugin = registry.get(plugin_id)
if not plugin:
return False
raise PluginNotFoundException(plugin_id)
async with get_db() as db:
success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
if success:
@@ -68,10 +66,9 @@ class PluginService:
return success
async def get_plugin_config(self, plugin_id: str) -> Optional[dict]:
"""获取插件当前配置(合并默认值和持久化值)"""
plugin = registry.get(plugin_id)
if not plugin:
return None
raise PluginNotFoundException(plugin_id)
async with get_db() as db:
saved = await self.plugin_settings_repo.get_config(db, plugin_id)
config = dict(plugin.default_config)
@@ -80,11 +77,9 @@ class PluginService:
return config
async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
"""更新插件配置(只保存已存在于 default_config 中的键)"""
plugin = registry.get(plugin_id)
if not plugin:
return False
# 过滤非法键
raise PluginNotFoundException(plugin_id)
safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
if not safe_config:
return False
@@ -95,44 +90,45 @@ class PluginService:
def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
return registry.get(plugin_id)
async def run_plugin(self, plugin_id: str) -> List[ProxyRaw]:
"""执行单个插件爬取"""
def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin:
plugin = self.get_plugin(plugin_id)
if not plugin:
raise ValueError(f"Plugin {plugin_id} not found")
raise PluginNotFoundException(plugin_id)
return plugin
async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult:
"""执行单个插件爬取(通过 PluginRunner"""
plugin = self.get_plugin_or_raise(plugin_id)
if not plugin.enabled:
logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
return []
return CrawlResult(plugin_name=plugin_id)
return await plugin_runner.run(plugin)
try:
results = await plugin.crawl()
self._record_stat(plugin_id, success=len(results))
logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
return results
except Exception as e:
self._record_stat(plugin_id, failure=1)
logger.error(f"Plugin {plugin_id} crawl failed: {e}")
return []
finally:
await self._save_stats(plugin_id)
async def run_all_plugins(self) -> List[ProxyRaw]:
async def run_all_plugins(self, plugin_runner) -> List[ProxyRaw]:
"""执行所有启用插件的爬取,限制并发数以避免触发目标站反爬"""
all_results: List[ProxyRaw] = []
semaphore = asyncio.Semaphore(5)
async def _run_with_limit(plugin_name: str):
plugin = self.get_plugin_or_raise(plugin_name)
if not plugin.enabled:
return CrawlResult(plugin_name=plugin_name)
async with semaphore:
return await self.run_plugin(plugin_name)
return await plugin_runner.run(plugin)
tasks = [_run_with_limit(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
tasks = [
_run_with_limit(plugin.name)
for plugin in registry.list_plugins() if plugin.enabled
]
results_list = await asyncio.gather(*tasks, return_exceptions=True)
for results in results_list:
if isinstance(results, Exception):
logger.error(f"Run all plugins error: {results}")
for result in results_list:
if isinstance(result, Exception):
logger.error(f"Run all plugins error: {result}")
continue
all_results.extend(results)
# 去重(与数据库 UNIQUE(ip, port) 约束保持一致)
if isinstance(result, CrawlResult):
all_results.extend(result.proxies)
# 去重
seen = set()
unique = []
for p in all_results:
@@ -141,26 +137,3 @@ class PluginService:
seen.add(key)
unique.append(p)
return unique
def _record_stat(self, plugin_id: str, success: int = 0, failure: int = 0):
if plugin_id not in self._stats:
self._stats[plugin_id] = {
"success_count": 0,
"failure_count": 0,
"last_run": None,
}
self._stats[plugin_id]["success_count"] += success
self._stats[plugin_id]["failure_count"] += failure
if success or failure:
self._stats[plugin_id]["last_run"] = datetime.now()
async def _save_stats(self, plugin_id: str):
"""将内存中的统计持久化到数据库"""
stats = self._stats.get(plugin_id, {})
payload = {
"success_count": stats.get("success_count", 0),
"failure_count": stats.get("failure_count", 0),
"last_run": stats.get("last_run").isoformat() if stats.get("last_run") else None,
}
async with get_db() as db:
await self.plugin_settings_repo.set_stats(db, plugin_id, payload)

View File

@@ -4,7 +4,8 @@ import json
import io
from datetime import datetime
from typing import List, Optional, Tuple, AsyncIterator
from app.core.db import get_db
from app.core.db import get_db, transaction
from app.repositories.proxy_repo import ProxyRepository
from app.models.domain import Proxy
from app.core.log import logger

View File

@@ -1,43 +1,42 @@
"""调度器服务 - 定时验证存量代理"""
"""调度器服务 - 定时触发全量验证"""
import asyncio
from datetime import datetime
from app.core.db import get_db
from app.repositories.proxy_repo import ProxyRepository
from app.repositories.task_repo import ValidationTaskRepository
from app.core.tasks.queue import ValidationQueue
from app.core.config import settings as app_settings
from typing import Optional
from app.core.execution.executor import JobExecutor
from app.core.execution.job import ValidateAllJob
from app.core.log import logger
from app.models.domain import ProxyRaw
from app.services.task_service import task_service
class SchedulerService:
"""代理验证调度器"""
"""代理验证调度器
职责单一:定时循环,触发 ValidateAllJob。
不再直接持有验证队列或 ValidatorService。
"""
def __init__(
self,
validation_queue: ValidationQueue,
proxy_repo: ProxyRepository = ProxyRepository(),
executor: JobExecutor,
interval_minutes: int = 30,
):
self.validation_queue = validation_queue
self.proxy_repo = proxy_repo
self.interval_minutes = 30
self.executor = executor
self.interval_minutes = interval_minutes
self.running = False
self._stop_event = asyncio.Event()
self._task: asyncio.Task | None = None
self._validate_task: asyncio.Task | None = None
self._task: Optional[asyncio.Task] = None
async def start(self):
async def start(self) -> None:
if self.running:
logger.warning("Scheduler already running")
return
self._stop_event.clear()
self.running = True
await self.validation_queue.start()
self._task = asyncio.create_task(self._run_loop())
logger.info("Scheduler started")
async def stop(self):
async def stop(self) -> None:
if not self.running:
return
self.running = False
self._stop_event.set()
if self._task:
@@ -47,72 +46,26 @@ class SchedulerService:
except asyncio.CancelledError:
pass
self._task = None
await self.validation_queue.stop()
logger.info("Scheduler stopped")
def cancel_validate_task(self):
"""取消正在执行的全量验证后台任务"""
if self._validate_task and not self._validate_task.done():
self._validate_task.cancel()
def validate_all_now(self) -> str:
"""立即执行一次全量验证,返回 Job ID"""
job_id = self.executor.submit_job(ValidateAllJob())
logger.info(f"ValidateAllJob submitted: {job_id}")
return job_id
async def validate_all_now(self):
"""立即执行一次全量验证(后台运行,不阻塞)"""
if self._validate_task and not self._validate_task.done():
return
self._validate_task = asyncio.create_task(self._do_validate_all(from_loop=False))
async def _run_loop(self):
async def _run_loop(self) -> None:
"""定时循环"""
while self.running:
try:
# 清理过期任务,防止内存无限增长
task_service.cleanup_old_tasks()
await self._do_validate_all(from_loop=True)
self.executor.submit_job(ValidateAllJob())
except Exception as e:
logger.error(f"Scheduler loop error: {e}", exc_info=True)
# 等待下一次
try:
await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
await asyncio.wait_for(
self._stop_event.wait(),
timeout=self.interval_minutes * 60,
)
except asyncio.TimeoutError:
pass
async def _do_validate_all(self, from_loop: bool = True):
"""验证数据库中所有存量代理"""
queue_started_here = False
try:
logger.info("Starting scheduled validation for all proxies")
# 如果队列未运行,临时启动它(适用于 validate_all_now 在调度器停止时调用)
if not self.validation_queue._running:
await self.validation_queue.start()
queue_started_here = True
async with get_db() as db:
# 清理 7 天前的验证任务记录,防止表无限增长
cleaned = await ValidationTaskRepository.cleanup_old(db, days=7)
if cleaned:
logger.info(f"Cleaned up {cleaned} old validation tasks")
proxies = await self.proxy_repo.list_all(db)
if not proxies:
logger.info("No proxies to validate")
return
logger.info(f"Validating {len(proxies)} proxies from database")
# 批量提交到验证队列,不再阻塞等待 drain
batch_size = 100
total_batches = (len(proxies) - 1) // batch_size + 1
for i in range(0, len(proxies), batch_size):
if from_loop and not self.running:
break
batch = proxies[i : i + batch_size]
await self.validation_queue.submit([
ProxyRaw(p.ip, p.port, p.protocol) for p in batch
])
logger.info(f"Submitted batch {i // batch_size + 1}/{total_batches}")
logger.info("Scheduled validation batches submitted")
except Exception as e:
logger.error(f"Scheduled validation error: {e}", exc_info=True)
finally:
if queue_started_here:
await self.validation_queue.stop()

View File

@@ -1,62 +0,0 @@
"""轻量级异步任务服务 - 内存存储,用于跟踪后台爬取任务"""
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from app.core.log import logger
class TaskService:
"""异步任务状态管理器"""
def __init__(self, max_age_seconds: int = 3600):
self._tasks: Dict[str, dict] = {}
self._max_age_seconds = max_age_seconds
def create_task(self, task_type: str) -> str:
task_id = str(uuid.uuid4())
now = datetime.now()
self._tasks[task_id] = {
"id": task_id,
"type": task_type,
"status": "pending",
"message": "",
"data": {},
"created_at": now,
"updated_at": now,
}
return task_id
def update_task(
self,
task_id: str,
status: Optional[str] = None,
message: Optional[str] = None,
data: Optional[dict] = None,
) -> bool:
task = self._tasks.get(task_id)
if not task:
return False
if status is not None:
task["status"] = status
if message is not None:
task["message"] = message
if data is not None:
task["data"].update(data)
task["updated_at"] = datetime.now()
return True
def get_task(self, task_id: str) -> Optional[dict]:
return self._tasks.get(task_id)
def cleanup_old_tasks(self) -> int:
cutoff = datetime.now() - timedelta(seconds=self._max_age_seconds)
to_remove = [tid for tid, task in self._tasks.items() if task["created_at"] < cutoff]
for tid in to_remove:
del self._tasks[tid]
if to_remove:
logger.info(f"TaskService cleaned up {len(to_remove)} old tasks")
return len(to_remove)
# 全局任务服务实例
task_service = TaskService()

View File

@@ -4,12 +4,17 @@ import random
import time
import aiohttp
import aiohttp_socks
from typing import Tuple
from typing import Tuple, Optional
from app.core.config import settings as app_settings
from app.core.log import logger
class ValidatorService:
"""代理验证器"""
"""代理验证器
支持动态读取配置,实现设置热更新。
"""
# 测试 URL
TEST_URLS = {
@@ -19,40 +24,66 @@ class ValidatorService:
def __init__(
self,
timeout: float = 5.0,
connect_timeout: float = 3.0,
max_concurrency: int = 50,
timeout: Optional[float] = None,
connect_timeout: Optional[float] = None,
max_concurrency: Optional[int] = None,
):
self.timeout = timeout
self.connect_timeout = connect_timeout
self.max_concurrency = max_concurrency
self.semaphore = asyncio.Semaphore(max_concurrency)
# 初始化时使用传入值或默认值,但运行期会动态读取 settings
self._init_timeout = timeout or app_settings.validator_timeout
self._init_connect_timeout = connect_timeout or app_settings.validator_connect_timeout
self._init_max_concurrency = max_concurrency or app_settings.validator_max_concurrency
# 共享 HTTP/HTTPS ClientSession
self._http_connector = aiohttp.TCPConnector(
ssl=False,
limit=max_concurrency,
limit_per_host=max_concurrency,
force_close=False,
)
self._timeout = aiohttp.ClientTimeout(
total=timeout, connect=connect_timeout
)
self._http_session = aiohttp.ClientSession(
connector=self._http_connector,
timeout=self._timeout,
)
self._http_connector: Optional[aiohttp.TCPConnector] = None
self._http_session: Optional[aiohttp.ClientSession] = None
self._semaphore: Optional[asyncio.Semaphore] = None
self._lock = asyncio.Lock()
@property
def timeout(self) -> float:
return float(self._init_timeout)
@property
def connect_timeout(self) -> float:
return float(self._init_connect_timeout)
@property
def max_concurrency(self) -> int:
return int(self._init_max_concurrency)
def _ensure_session(self) -> aiohttp.ClientSession:
"""懒加载共享 HTTP session"""
if self._http_session is None or self._http_session.closed:
connector = aiohttp.TCPConnector(
ssl=False,
limit=self.max_concurrency,
limit_per_host=self.max_concurrency,
force_close=False,
)
timeout = aiohttp.ClientTimeout(
total=self.timeout, connect=self.connect_timeout
)
self._http_connector = connector
self._http_session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
)
return self._http_session
def _ensure_semaphore(self) -> asyncio.Semaphore:
if self._semaphore is None:
self._semaphore = asyncio.Semaphore(self.max_concurrency)
return self._semaphore
def _get_test_url(self, protocol: str) -> str:
"""获取测试 URL"""
urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
return random.choice(urls)
async def validate(self, ip: str, port: int, protocol: str = "http") -> Tuple[bool, float]:
"""验证单个代理,返回 (是否有效, 延迟毫秒)"""
protocol = protocol.lower()
semaphore = self._ensure_semaphore()
async with self.semaphore:
async with semaphore:
start = time.time()
try:
if protocol in ("socks4", "socks5"):
@@ -67,11 +98,11 @@ class ValidatorService:
return False, 0.0
async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
"""验证 HTTP/HTTPS 代理"""
proxy_url = f"http://{ip}:{port}"
test_url = self._get_test_url(protocol)
session = self._ensure_session()
async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
if response.status in (200, 301, 302):
latency = round((time.time() - start) * 1000, 2)
logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
@@ -79,7 +110,6 @@ class ValidatorService:
return False, 0.0
async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
"""验证 SOCKS4/SOCKS5 代理"""
proxy_type = (
aiohttp_socks.ProxyType.SOCKS4
if protocol == "socks4"
@@ -103,7 +133,9 @@ class ValidatorService:
return True, latency
return False, 0.0
async def close(self):
async def close(self) -> None:
"""关闭共享的 HTTP ClientSession"""
if self._http_session and not self._http_session.closed:
await self._http_session.close()
self._http_session = None
self._http_connector = None

View File

@@ -7,6 +7,17 @@ from httpx import AsyncClient, ASGITransport
from app.api import create_app
from app.core.db import init_db, get_db
from app.core.plugin_system.registry import registry
from app.plugins import (
Fate0Plugin,
ProxyListDownloadPlugin,
Ip3366Plugin,
Ip89Plugin,
KuaiDaiLiPlugin,
SpeedXPlugin,
YunDaiLiPlugin,
ProxyScrapePlugin,
)
from app.repositories.proxy_repo import ProxyRepository
from app.models.domain import ProxyRaw
@@ -14,23 +25,32 @@ from app.models.domain import ProxyRaw
@pytest_asyncio.fixture(scope="function")
async def app():
"""创建应用实例"""
# 初始化测试数据库并清空历史数据,避免任务残留或设置状态导致 drain() 卡住
# 初始化测试数据库并清空历史数据
await init_db()
async with get_db() as db:
await db.execute("DELETE FROM validation_tasks")
await db.execute("DELETE FROM proxies")
await db.execute("DELETE FROM settings")
await db.commit()
# 清理全局内存状态,防止跨测试污染
from app.services.task_service import task_service
task_service._tasks.clear()
# 清理并重新注册插件,防止跨测试污染
registry.clear()
for plugin_cls in [
Fate0Plugin,
ProxyListDownloadPlugin,
Ip3366Plugin,
Ip89Plugin,
KuaiDaiLiPlugin,
SpeedXPlugin,
YunDaiLiPlugin,
ProxyScrapePlugin,
]:
registry.register(plugin_cls)
test_app = create_app()
async with test_app.router.lifespan_context(test_app):
yield test_app
# 给 aiosqlite / aiohttp 后台线程留出收尾时间,降低 Event loop closed 警告概率
# 给 aiosqlite / aiohttp 后台线程留出收尾时间
await asyncio.sleep(0.1)
@@ -71,18 +91,19 @@ async def mock_external_requests(monkeypatch):
1. 插件爬取返回固定测试代理,避免真实 HTTP 请求
2. 代理验证瞬间成功,避免连接超时等待
"""
from app.services.plugin_service import PluginService
from app.services.plugin_runner import PluginRunner
from app.services.validator_service import ValidatorService
async def _mock_run_plugin(self, plugin_id: str):
return [ProxyRaw("192.168.100.10", 8080, "http")]
async def _mock_run_all_plugins(self):
return [ProxyRaw("192.168.100.10", 8080, "http")]
async def _mock_run(self, plugin):
from app.models.domain import CrawlResult
return CrawlResult(
plugin_name=plugin.name,
proxies=[ProxyRaw("192.168.100.10", 8080, "http")],
success_count=1,
)
async def _mock_validate(self, ip: str, port: int, protocol: str = "http"):
return True, 1.23
monkeypatch.setattr(PluginService, "run_plugin", _mock_run_plugin)
monkeypatch.setattr(PluginService, "run_all_plugins", _mock_run_all_plugins)
monkeypatch.setattr(PluginRunner, "run", _mock_run)
monkeypatch.setattr(ValidatorService, "validate", _mock_validate)

View File

@@ -69,7 +69,8 @@ class TestPluginsAPI:
plugin_id = plugins[0]["id"]
response = await client.put(f"/api/plugins/{plugin_id}/toggle", json={})
assert response.status_code == 400
# Pydantic 验证缺失必填字段返回 422
assert response.status_code == 422
@pytest.mark.asyncio
async def test_toggle_nonexistent_plugin(self, client):