"""Plugin business service."""
import asyncio
from datetime import datetime
from typing import List, Optional

from app.core.db import get_db
from app.core.plugin_system.registry import registry
from app.core.plugin_system.base import BaseCrawlerPlugin
from app.core.exceptions import PluginNotFoundException, ValidationException
from app.repositories.settings_repo import PluginSettingsRepository
from app.models.domain import PluginInfo, ProxyRaw, CrawlResult
from app.core.log import logger


class PluginService:
    """Plugin business service: plugin lifecycle and configuration management."""

    def __init__(self):
        self.plugin_settings_repo = PluginSettingsRepository()

    async def list_plugins(self) -> List[PluginInfo]:
        """Return info for every registered plugin, merged with persisted state.

        For each plugin in the registry, the persisted ``enabled`` flag (if
        any) overrides the in-memory one, and a persisted ``config`` dict is
        applied to the plugin via ``update_config``. Run statistics are read
        from the persisted ``stats`` entry.

        Returns:
            One ``PluginInfo`` per registered plugin.
        """
        async with get_db() as db:
            db_states = await self.plugin_settings_repo.list_all(db)

        result = []
        for plugin in registry.list_plugins():
            # ``or {}`` also covers a stored ``None`` so the lookups below
            # cannot fail on a half-populated row.
            state = db_states.get(plugin.name) or {}
            enabled = state.get("enabled", plugin.enabled)

            saved_config = state.get("config")
            if isinstance(saved_config, dict):
                # NOTE: listing has the side effect of syncing the in-memory
                # plugin config with the persisted one.
                plugin.update_config(saved_config)

            db_stat = state.get("stats") or {}
            result.append(PluginInfo(
                id=plugin.name,
                name=plugin.name,
                display_name=plugin.display_name or plugin.name,
                description=plugin.description or f"从 {plugin.name} 爬取代理",
                enabled=enabled,
                last_run=self._parse_datetime(db_stat.get("last_run")),
                success_count=db_stat.get("success_count", 0),
                failure_count=db_stat.get("failure_count", 0),
            ))
        return result

    @staticmethod
    def _parse_datetime(value) -> Optional[datetime]:
        """Best-effort conversion of a stored timestamp to ``datetime``.

        Args:
            value: A ``datetime``, an ISO-8601 string, or a falsy value.

        Returns:
            The ``datetime`` itself, the parsed ISO string, or ``None`` when
            the value is falsy or cannot be parsed.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        try:
            return datetime.fromisoformat(value)
        except (TypeError, ValueError):
            # TypeError: not a string; ValueError: not a valid ISO format.
            return None

    async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
        """Persist the enabled flag of a plugin and mirror it in memory.

        Args:
            plugin_id: Name of the plugin in the registry.
            enabled: Desired enabled state.

        Returns:
            The success flag returned by the settings repository.

        Raises:
            PluginNotFoundException: If no plugin with that id is registered.
        """
        plugin = self.get_plugin_or_raise(plugin_id)
        async with get_db() as db:
            success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
        if success:
            # Only touch the in-memory flag once the change is persisted.
            plugin.enabled = enabled
            logger.info(f"Plugin {plugin_id} toggled to {enabled}")
        return success

    async def get_plugin_config(self, plugin_id: str) -> dict:
        """Return the plugin's default config overlaid with the saved config.

        Args:
            plugin_id: Name of the plugin in the registry.

        Returns:
            A new dict; saved values take precedence over defaults.

        Raises:
            PluginNotFoundException: If no plugin with that id is registered.
        """
        plugin = self.get_plugin_or_raise(plugin_id)
        async with get_db() as db:
            saved = await self.plugin_settings_repo.get_config(db, plugin_id)
        config = dict(plugin.default_config)
        if saved:
            config.update(saved)
        return config

    async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
        """Apply and persist a config update for a plugin.

        Only keys present in the plugin's ``default_config`` are accepted;
        unknown keys are silently dropped.

        Args:
            plugin_id: Name of the plugin in the registry.
            config: Requested config values.

        Returns:
            ``True`` when the config was saved.

        Raises:
            PluginNotFoundException: If no plugin with that id is registered.
            ValidationException: If no recognised key remains after filtering.
            ProxyPoolException: If the repository reports a failed save.
        """
        plugin = self.get_plugin_or_raise(plugin_id)
        safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
        if not safe_config:
            raise ValidationException("配置项无效或为空")

        # NOTE(review): the in-memory config is updated before the save, so a
        # failed save leaves memory and storage out of sync until the next
        # list_plugins() call re-applies the persisted config.
        plugin.update_config(safe_config)
        async with get_db() as db:
            success = await self.plugin_settings_repo.set_config(db, plugin_id, plugin.config)
        if not success:
            # The name was previously used without being imported (NameError).
            # NOTE(review): presumably defined in app.core.exceptions next to
            # the other service exceptions - confirm, then hoist to the
            # module-level import.
            from app.core.exceptions import ProxyPoolException
            raise ProxyPoolException("保存插件配置失败", 500)
        return True

    def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
        """Return the registry's entry for ``plugin_id`` (falsy when absent)."""
        return registry.get(plugin_id)

    def get_plugin_or_raise(self, plugin_id: str) -> BaseCrawlerPlugin:
        """Return the plugin registered under ``plugin_id``.

        Raises:
            PluginNotFoundException: If the registry has no such plugin.
        """
        plugin = self.get_plugin(plugin_id)
        if not plugin:
            raise PluginNotFoundException(plugin_id)
        return plugin

    async def run_plugin(self, plugin_id: str, plugin_runner) -> CrawlResult:
        """Run the crawl of a single plugin through the given runner.

        Args:
            plugin_id: Name of the plugin in the registry.
            plugin_runner: Object exposing an awaitable ``run(plugin)``.

        Returns:
            The runner's result, or a ``CrawlResult`` carrying only the plugin
            name when the plugin is disabled.

        Raises:
            PluginNotFoundException: If no plugin with that id is registered.
        """
        plugin = self.get_plugin_or_raise(plugin_id)
        if not plugin.enabled:
            logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
            return CrawlResult(plugin_name=plugin_id)
        return await plugin_runner.run(plugin)

    async def run_all_plugins(self, plugin_runner, *, max_concurrency: int = 5) -> List[ProxyRaw]:
        """Run every enabled plugin with bounded concurrency.

        Concurrency is capped to avoid triggering anti-crawling measures on
        the target sites. A failure in one plugin is logged and does not
        abort the others.

        Args:
            plugin_runner: Object exposing an awaitable ``run(plugin)``.
            max_concurrency: Maximum number of plugins crawling at once.

        Returns:
            Proxies from all successful runs, de-duplicated on
            ``(ip, port, protocol)`` and keeping the first occurrence.

        Raises:
            ValueError: If ``max_concurrency`` is less than 1.
        """
        if max_concurrency < 1:
            # A zero-valued semaphore would block every task forever.
            raise ValueError("max_concurrency must be at least 1")
        semaphore = asyncio.Semaphore(max_concurrency)

        async def _run_with_limit(plugin):
            async with semaphore:
                # Re-check after acquiring a slot: the plugin may have been
                # disabled while this task was waiting in the queue.
                if not plugin.enabled:
                    return CrawlResult(plugin_name=plugin.name)
                return await plugin_runner.run(plugin)

        tasks = [
            _run_with_limit(plugin)
            for plugin in registry.list_plugins()
            if plugin.enabled
        ]
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        # De-duplicate; a dict keeps insertion order, so the first proxy seen
        # for each key wins.
        unique = {}
        for outcome in outcomes:
            # BaseException so that a cancelled child (CancelledError) is
            # logged as well instead of being silently dropped.
            if isinstance(outcome, BaseException):
                logger.error(f"Run all plugins error: {outcome}")
                continue
            if not isinstance(outcome, CrawlResult):
                continue
            for proxy in outcome.proxies:
                unique.setdefault((proxy.ip, proxy.port, proxy.protocol), proxy)
        return list(unique.values())