refactor: 全面重构核心架构,消除反复修改的根因
- 删除 ValidationQueue 双轨持久化队列,替换为纯内存 AsyncWorkerPool
- 引入统一后台任务框架 JobExecutor(Job/CrawlJob/ValidateAllJob)
- 新增 PluginRunner 统一插件执行(超时、重试、健康检查、统计)
- 重构 SchedulerService 职责收敛为仅定时触发 ValidateAllJob
- 使用 AsyncExitStack 重构 lifespan,安全管理长生命周期资源
- 路由层瘦身 50%+,业务异常上抛由全局中间件统一处理
- 实现设置全热更新(WorkerPool 并发、Validator 超时即时生效)
- 前端 Store 强制写后重新拉取,消除乐观更新数据不同步
- 删除 queue.py / task_repo.py / task_service.py
- 新增 execution 单元测试,全部 85 个测试通过
This commit is contained in:
@@ -1,21 +1,21 @@
|
||||
"""插件业务服务"""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.db import get_db
|
||||
from app.core.plugin_system.registry import registry
|
||||
from app.core.plugin_system.base import BaseCrawlerPlugin
|
||||
from app.core.exceptions import PluginNotFoundException
|
||||
from app.repositories.settings_repo import PluginSettingsRepository
|
||||
from app.models.domain import PluginInfo, ProxyRaw
|
||||
from app.models.domain import PluginInfo, ProxyRaw, CrawlResult
|
||||
from app.core.log import logger
|
||||
|
||||
|
||||
class PluginService:
|
||||
"""插件业务服务:管理插件生命周期、执行爬取、配置管理"""
|
||||
"""插件业务服务:管理插件生命周期、配置管理"""
|
||||
|
||||
def __init__(self):
    """Set up the service's per-instance state."""
    # In-memory run statistics; starts empty and is keyed by plugin name
    # wherever this class reads or writes it.
    self._stats: dict[str, dict] = {}
    # Repository used for persisted plugin settings (enabled flag, config, stats).
    self.plugin_settings_repo = PluginSettingsRepository()
|
||||
|
||||
async def list_plugins(self) -> List[PluginInfo]:
|
||||
"""获取所有插件信息(合并持久化状态和配置)"""
|
||||
@@ -24,42 +24,40 @@ class PluginService:
|
||||
|
||||
result = []
|
||||
for plugin in registry.list_plugins():
|
||||
# 合并持久化状态(不修改全局实例,避免并发竞争)
|
||||
state = db_states.get(plugin.name, {})
|
||||
enabled = state.get("enabled", plugin.enabled)
|
||||
if "config" in state and isinstance(state["config"], dict):
|
||||
plugin.update_config(state["config"])
|
||||
|
||||
# 合并数据库统计与内存统计(内存优先)
|
||||
db_stat = state.get("stats", {})
|
||||
stat = {
|
||||
"success_count": db_stat.get("success_count", 0),
|
||||
"failure_count": db_stat.get("failure_count", 0),
|
||||
"last_run": datetime.fromisoformat(db_stat["last_run"]) if db_stat.get("last_run") else None,
|
||||
}
|
||||
mem_stat = self._stats.get(plugin.name, {})
|
||||
if mem_stat:
|
||||
stat["success_count"] = mem_stat.get("success_count", stat["success_count"])
|
||||
stat["failure_count"] = mem_stat.get("failure_count", stat["failure_count"])
|
||||
if mem_stat.get("last_run"):
|
||||
stat["last_run"] = mem_stat["last_run"]
|
||||
|
||||
result.append(PluginInfo(
|
||||
id=plugin.name,
|
||||
name=plugin.name,
|
||||
display_name=plugin.display_name or plugin.name,
|
||||
description=plugin.description or f"从 {plugin.name} 爬取代理",
|
||||
enabled=enabled,
|
||||
last_run=stat.get("last_run"),
|
||||
success_count=stat.get("success_count", 0),
|
||||
failure_count=stat.get("failure_count", 0),
|
||||
last_run=self._parse_datetime(db_stat.get("last_run")),
|
||||
success_count=db_stat.get("success_count", 0),
|
||||
failure_count=db_stat.get("failure_count", 0),
|
||||
))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _parse_datetime(value) -> Optional:
|
||||
from datetime import datetime
|
||||
if not value:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
|
||||
plugin = registry.get(plugin_id)
|
||||
if not plugin:
|
||||
return False
|
||||
raise PluginNotFoundException(plugin_id)
|
||||
async with get_db() as db:
|
||||
success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
|
||||
if success:
|
||||
@@ -68,10 +66,9 @@ class PluginService:
|
||||
return success
|
||||
|
||||
async def get_plugin_config(self, plugin_id: str) -> Optional[dict]:
|
||||
"""获取插件当前配置(合并默认值和持久化值)"""
|
||||
plugin = registry.get(plugin_id)
|
||||
if not plugin:
|
||||
return None
|
||||
raise PluginNotFoundException(plugin_id)
|
||||
async with get_db() as db:
|
||||
saved = await self.plugin_settings_repo.get_config(db, plugin_id)
|
||||
config = dict(plugin.default_config)
|
||||
@@ -80,11 +77,9 @@ class PluginService:
|
||||
return config
|
||||
|
||||
async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
|
||||
"""更新插件配置(只保存已存在于 default_config 中的键)"""
|
||||
plugin = registry.get(plugin_id)
|
||||
if not plugin:
|
||||
return False
|
||||
# 过滤非法键
|
||||
raise PluginNotFoundException(plugin_id)
|
||||
safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
|
||||
if not safe_config:
|
||||
return False
|
||||
@@ -95,44 +90,45 @@ class PluginService:
|
||||
def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
    """Look up ``plugin_id`` in the plugin registry.

    Returns whatever ``registry.get`` yields for that id; no error is
    raised here when the lookup finds nothing.
    """
    found = registry.get(plugin_id)
    return found
|
||||
|
||||
async def run_plugin(self, plugin_id: str) -> List[ProxyRaw]:
|
||||
"""执行单个插件爬取"""
|
||||
def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin:
    """Return the plugin registered under ``plugin_id`` or fail loudly.

    Args:
        plugin_id: Identifier passed to ``get_plugin``.

    Returns:
        The plugin instance found by ``get_plugin``.

    Raises:
        PluginNotFoundException: If ``get_plugin`` returns a falsy value.
            This is the same exception the other lookup paths in this
            service raise, rather than a generic ``ValueError``.
    """
    plugin = self.get_plugin(plugin_id)
    if not plugin:
        raise PluginNotFoundException(plugin_id)
    return plugin
|
||||
|
||||
async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult:
    """Run a single plugin's crawl through the supplied runner.

    Args:
        plugin_id: Identifier of the plugin to run.
        plugin_runner: Object exposing an awaitable ``run(plugin)``; the
            crawl is fully delegated to it.

    Returns:
        The value produced by ``plugin_runner.run(plugin)``, or an empty
        ``CrawlResult`` tagged with ``plugin_id`` when the plugin is
        disabled (so the declared return type always holds).

    Raises:
        PluginNotFoundException: Propagated from ``get_plugin_or_raise``
            when the plugin is not registered.
    """
    plugin = self.get_plugin_or_raise(plugin_id)
    if not plugin.enabled:
        logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
        # Return an empty result object instead of a bare list to honour
        # the CrawlResult return annotation.
        return CrawlResult(plugin_name=plugin_id)
    return await plugin_runner.run(plugin)
|
||||
|
||||
async def run_all_plugins(self) -> List[ProxyRaw]:
|
||||
async def run_all_plugins(self, plugin_runner) -> List[ProxyRaw]:
|
||||
"""执行所有启用插件的爬取,限制并发数以避免触发目标站反爬"""
|
||||
all_results: List[ProxyRaw] = []
|
||||
semaphore = asyncio.Semaphore(5)
|
||||
|
||||
async def _run_with_limit(plugin_name: str):
|
||||
plugin = self.get_plugin_or_raise(plugin_name)
|
||||
if not plugin.enabled:
|
||||
return CrawlResult(plugin_name=plugin_name)
|
||||
async with semaphore:
|
||||
return await self.run_plugin(plugin_name)
|
||||
return await plugin_runner.run(plugin)
|
||||
|
||||
tasks = [_run_with_limit(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
|
||||
tasks = [
|
||||
_run_with_limit(plugin.name)
|
||||
for plugin in registry.list_plugins() if plugin.enabled
|
||||
]
|
||||
results_list = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for results in results_list:
|
||||
if isinstance(results, Exception):
|
||||
logger.error(f"Run all plugins error: {results}")
|
||||
for result in results_list:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Run all plugins error: {result}")
|
||||
continue
|
||||
all_results.extend(results)
|
||||
# 去重(与数据库 UNIQUE(ip, port) 约束保持一致)
|
||||
if isinstance(result, CrawlResult):
|
||||
all_results.extend(result.proxies)
|
||||
|
||||
# 去重
|
||||
seen = set()
|
||||
unique = []
|
||||
for p in all_results:
|
||||
@@ -141,26 +137,3 @@ class PluginService:
|
||||
seen.add(key)
|
||||
unique.append(p)
|
||||
return unique
|
||||
|
||||
def _record_stat(self, plugin_id: str, success: int = 0, failure: int = 0):
    """Accumulate in-memory run counters for ``plugin_id``.

    Creates a zeroed entry on first use, adds ``success`` and ``failure``
    to the running totals, and stamps ``last_run`` with the current local
    time only when at least one of the two increments is non-zero.
    """
    entry = self._stats.setdefault(
        plugin_id,
        {
            "success_count": 0,
            "failure_count": 0,
            "last_run": None,
        },
    )
    entry["success_count"] += success
    entry["failure_count"] += failure
    if not (success or failure):
        # Nothing was recorded, so leave the previous timestamp untouched.
        return
    entry["last_run"] = datetime.now()
|
||||
|
||||
async def _save_stats(self, plugin_id: str):
    """Persist the in-memory statistics for ``plugin_id`` to the database.

    Missing counters default to 0; ``last_run`` is stored as an ISO-8601
    string, or ``None`` when no run has been recorded.
    """
    current = self._stats.get(plugin_id, {})
    last_run = current.get("last_run")
    payload = {
        "success_count": current.get("success_count", 0),
        "failure_count": current.get("failure_count", 0),
        "last_run": last_run.isoformat() if last_run else None,
    }
    async with get_db() as db:
        await self.plugin_settings_repo.set_stats(db, plugin_id, payload)
|
||||
|
||||
Reference in New Issue
Block a user