Files
ProxyPool/app/services/plugin_service.py
祀梦 b972b64616 refactor: 全面重构核心架构,消除反复修改的根因
- 删除 ValidationQueue 双轨持久化队列,替换为纯内存 AsyncWorkerPool
- 引入统一后台任务框架 JobExecutor(Job/CrawlJob/ValidateAllJob)
- 新增 PluginRunner 统一插件执行(超时、重试、健康检查、统计)
- 重构 SchedulerService 职责收敛为仅定时触发 ValidateAllJob
- 使用 AsyncExitStack 重构 lifespan,安全管理长生命周期资源
- 路由层瘦身 50%+,业务异常上抛由全局中间件统一处理
- 实现设置全热更新(WorkerPool 并发、Validator 超时即时生效)
- 前端 Store 强制写后重新拉取,消除乐观更新数据不同步
- 删除 queue.py / task_repo.py / task_service.py
- 新增 execution 单元测试,全部 85 个测试通过
2026-04-04 22:36:57 +08:00

140 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""插件业务服务"""
from __future__ import annotations

import asyncio
from datetime import datetime
from typing import List, Optional

from app.core.db import get_db
from app.core.exceptions import PluginNotFoundException
from app.core.log import logger
from app.core.plugin_system.base import BaseCrawlerPlugin
from app.core.plugin_system.registry import registry
from app.models.domain import PluginInfo, ProxyRaw, CrawlResult
from app.repositories.settings_repo import PluginSettingsRepository
class PluginService:
    """Plugin business service: plugin lifecycle and configuration management.

    Merges the in-memory plugin instances held by the registry with
    persisted per-plugin state (enabled flag, config overrides, run
    statistics) stored through PluginSettingsRepository.
    """

    # Max plugins crawling concurrently in run_all_plugins; caps pressure on
    # target sites to avoid triggering their anti-crawl measures.
    _CRAWL_CONCURRENCY = 5

    def __init__(self):
        self.plugin_settings_repo = PluginSettingsRepository()

    async def list_plugins(self) -> List[PluginInfo]:
        """Return info for every registered plugin.

        Persisted state overrides in-memory defaults. NOTE: a persisted
        config is also applied onto the live plugin instance as a side
        effect, keeping the runtime config in sync with the DB.
        """
        async with get_db() as db:
            db_states = await self.plugin_settings_repo.list_all(db)
            result = []
            for plugin in registry.list_plugins():
                state = db_states.get(plugin.name, {})
                # Persisted enabled flag wins over the in-memory default.
                enabled = state.get("enabled", plugin.enabled)
                if "config" in state and isinstance(state["config"], dict):
                    plugin.update_config(state["config"])
                db_stat = state.get("stats", {})
                result.append(PluginInfo(
                    id=plugin.name,
                    name=plugin.name,
                    display_name=plugin.display_name or plugin.name,
                    description=plugin.description or f"{plugin.name} 爬取代理",
                    enabled=enabled,
                    last_run=self._parse_datetime(db_stat.get("last_run")),
                    success_count=db_stat.get("success_count", 0),
                    failure_count=db_stat.get("failure_count", 0),
                ))
            return result

    @staticmethod
    def _parse_datetime(value) -> Optional[datetime]:
        """Parse a persisted timestamp into a datetime.

        Accepts a datetime (returned as-is) or an ISO-8601 string; returns
        None for falsy or unparseable input.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        try:
            return datetime.fromisoformat(value)
        except (TypeError, ValueError):
            # Not a valid ISO string (or not a string at all): treat as "never ran".
            return None

    async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
        """Persist a plugin's enabled flag and mirror it on the live instance.

        Raises:
            PluginNotFoundException: if the plugin is not registered.
        """
        plugin = registry.get(plugin_id)
        if not plugin:
            raise PluginNotFoundException(plugin_id)
        async with get_db() as db:
            success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
        # Only mutate in-memory state after the DB write succeeded.
        if success:
            plugin.enabled = enabled
            logger.info(f"Plugin {plugin_id} toggled to {enabled}")
        return success

    async def get_plugin_config(self, plugin_id: str) -> Optional[dict]:
        """Return the plugin's effective config: defaults overlaid with saved values.

        Raises:
            PluginNotFoundException: if the plugin is not registered.
        """
        plugin = registry.get(plugin_id)
        if not plugin:
            raise PluginNotFoundException(plugin_id)
        async with get_db() as db:
            saved = await self.plugin_settings_repo.get_config(db, plugin_id)
        # Copy the defaults so the merge never mutates plugin.default_config.
        config = dict(plugin.default_config)
        if saved:
            config.update(saved)
        return config

    async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
        """Apply and persist a config update, keeping only known keys.

        Keys absent from the plugin's default_config are silently dropped
        (whitelist filtering); returns False if nothing remains to apply.

        Raises:
            PluginNotFoundException: if the plugin is not registered.
        """
        plugin = registry.get(plugin_id)
        if not plugin:
            raise PluginNotFoundException(plugin_id)
        safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
        if not safe_config:
            return False
        plugin.update_config(safe_config)
        async with get_db() as db:
            # Persist the merged runtime config, not just the delta.
            return await self.plugin_settings_repo.set_config(db, plugin_id, plugin.config)

    def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
        """Look up a plugin by id; returns None if not registered."""
        return registry.get(plugin_id)

    def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin:
        """Look up a plugin by id.

        Raises:
            PluginNotFoundException: if the plugin is not registered.
        """
        plugin = self.get_plugin(plugin_id)
        if not plugin:
            raise PluginNotFoundException(plugin_id)
        return plugin

    async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult:
        """Run a single plugin crawl (through the PluginRunner).

        A disabled plugin is skipped and yields an empty CrawlResult.

        Raises:
            PluginNotFoundException: if the plugin is not registered.
        """
        plugin = self.get_plugin_or_raise(plugin_id)
        if not plugin.enabled:
            logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
            return CrawlResult(plugin_name=plugin_id)
        return await plugin_runner.run(plugin)

    async def run_all_plugins(self, plugin_runner) -> List[ProxyRaw]:
        """Crawl with every enabled plugin and return deduplicated proxies.

        Concurrency is capped at _CRAWL_CONCURRENCY to avoid hammering
        target sites; per-plugin failures are logged and skipped rather
        than aborting the whole run.
        """
        semaphore = asyncio.Semaphore(self._CRAWL_CONCURRENCY)

        async def _run_with_limit(plugin: BaseCrawlerPlugin) -> CrawlResult:
            # Re-check enabled: a plugin may have been toggled off between
            # task creation and execution.
            if not plugin.enabled:
                return CrawlResult(plugin_name=plugin.name)
            async with semaphore:
                return await plugin_runner.run(plugin)

        tasks = [
            _run_with_limit(plugin)
            for plugin in registry.list_plugins()
            if plugin.enabled
        ]
        results_list = await asyncio.gather(*tasks, return_exceptions=True)

        all_results: List[ProxyRaw] = []
        for result in results_list:
            if isinstance(result, Exception):
                logger.error(f"Run all plugins error: {result}")
                continue
            if isinstance(result, CrawlResult):
                all_results.extend(result.proxies)

        # Deduplicate by (ip, port), preserving first occurrence order.
        seen = set()
        unique: List[ProxyRaw] = []
        for p in all_results:
            key = (p.ip, p.port)
            if key not in seen:
                seen.add(key)
                unique.append(p)
        return unique