"""Plugin business service."""
import asyncio
from datetime import datetime
from typing import List, Optional

from app.core.db import get_db
from app.core.plugin_system.registry import registry
from app.core.plugin_system.base import BaseCrawlerPlugin
from app.repositories.settings_repo import PluginSettingsRepository
from app.models.domain import PluginInfo, ProxyRaw
from app.core.log import logger


class PluginService:
    """Plugin business service: plugin lifecycle, crawl execution and configuration."""

    def __init__(self):
        self.plugin_settings_repo = PluginSettingsRepository()
        # In-memory run counters keyed by plugin id; each value holds
        # "success_count", "failure_count" and "last_run" (datetime or None).
        self._stats: dict[str, dict] = {}

    @staticmethod
    def _parse_last_run(value) -> Optional[datetime]:
        """Turn a persisted ``last_run`` value into a ``datetime``.

        Returns ``None`` when the value is empty or cannot be parsed, so that a
        single malformed record cannot break the whole plugin listing.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        try:
            return datetime.fromisoformat(value)
        except (TypeError, ValueError):
            logger.warning(f"Ignoring unparsable last_run value: {value!r}")
            return None

    async def list_plugins(self) -> List[PluginInfo]:
        """Return info for every registered plugin, merged with persisted state.

        Side effect: the persisted ``enabled`` flag and ``config`` are applied to
        the live plugin objects in the registry while building the list.

        Returns:
            One ``PluginInfo`` per plugin returned by ``registry.list_plugins()``.
        """
        async with get_db() as db:
            db_states = await self.plugin_settings_repo.list_all(db)

        result = []
        for plugin in registry.list_plugins():
            # Merge persisted state; tolerate a missing or null record.
            state = db_states.get(plugin.name) or {}
            if "enabled" in state:
                plugin.enabled = state["enabled"]
            if "config" in state and isinstance(state["config"], dict):
                plugin.update_config(state["config"])

            # Merge database stats with in-memory stats (memory wins).
            db_stat = state.get("stats")
            if not isinstance(db_stat, dict):
                db_stat = {}
            stat = {
                "success_count": db_stat.get("success_count", 0),
                "failure_count": db_stat.get("failure_count", 0),
                "last_run": self._parse_last_run(db_stat.get("last_run")),
            }
            # NOTE(review): in-memory counters start from zero in a new process
            # and are written back by _save_stats, so persisted totals appear to
            # be replaced after the first run following a restart - confirm
            # whether that is intended.
            mem_stat = self._stats.get(plugin.name, {})
            if mem_stat:
                stat["success_count"] = mem_stat.get("success_count", stat["success_count"])
                stat["failure_count"] = mem_stat.get("failure_count", stat["failure_count"])
                if mem_stat.get("last_run"):
                    stat["last_run"] = mem_stat["last_run"]

            result.append(PluginInfo(
                id=plugin.name,
                name=plugin.name,
                display_name=plugin.display_name or plugin.name,
                description=plugin.description or f"从 {plugin.name} 爬取代理",
                enabled=plugin.enabled,
                last_run=stat.get("last_run"),
                success_count=stat.get("success_count", 0),
                failure_count=stat.get("failure_count", 0),
            ))
        return result

    async def toggle_plugin(self, plugin_id: str, enabled: bool) -> bool:
        """Persist a plugin's enabled flag and mirror it onto the live plugin.

        Args:
            plugin_id: Registry id of the plugin.
            enabled: Desired enabled state.

        Returns:
            ``False`` when the plugin is not registered; otherwise the result
            reported by the settings repository. The in-memory flag is only
            changed when that result is truthy.
        """
        plugin = registry.get(plugin_id)
        if not plugin:
            return False
        async with get_db() as db:
            success = await self.plugin_settings_repo.set_enabled(db, plugin_id, enabled)
        if success:
            plugin.enabled = enabled
            logger.info(f"Plugin {plugin_id} toggled to {enabled}")
        return success

    async def get_plugin_config(self, plugin_id: str) -> Optional[dict]:
        """Return the plugin's effective config (defaults overlaid with saved values).

        Args:
            plugin_id: Registry id of the plugin.

        Returns:
            A new dict, or ``None`` when the plugin is not registered.
        """
        plugin = registry.get(plugin_id)
        if not plugin:
            return None
        async with get_db() as db:
            saved = await self.plugin_settings_repo.get_config(db, plugin_id)
        config = dict(plugin.default_config)
        if saved:
            config.update(saved)
        return config

    async def update_plugin_config(self, plugin_id: str, config: dict) -> bool:
        """Update a plugin's config, keeping only keys present in ``default_config``.

        Args:
            plugin_id: Registry id of the plugin.
            config: Proposed config values; unknown keys are dropped.

        Returns:
            ``False`` when the plugin is not registered or no key survives the
            filter; otherwise the result reported by the settings repository.
        """
        plugin = registry.get(plugin_id)
        if not plugin:
            return False
        # Drop keys the plugin does not declare.
        safe_config = {k: v for k, v in config.items() if k in plugin.default_config}
        if not safe_config:
            return False
        # NOTE(review): the live plugin is mutated before the write is
        # confirmed, unlike toggle_plugin which persists first; a failed write
        # would leave memory and database out of sync - confirm desired order.
        plugin.update_config(safe_config)
        async with get_db() as db:
            return await self.plugin_settings_repo.set_config(db, plugin_id, plugin.config)

    def get_plugin(self, plugin_id: str) -> Optional[BaseCrawlerPlugin]:
        """Look up a plugin in the registry by id."""
        return registry.get(plugin_id)

    async def run_plugin(self, plugin_id: str) -> List[ProxyRaw]:
        """Run a single plugin's crawl and record its statistics.

        Crawl errors are logged and reported as an empty result. ``success_count``
        is incremented by the number of proxies returned, ``failure_count`` by one
        per failed crawl.

        Args:
            plugin_id: Registry id of the plugin.

        Returns:
            The crawled proxies, or ``[]`` when the plugin is disabled or the
            crawl failed.

        Raises:
            ValueError: If no plugin with ``plugin_id`` is registered.
        """
        plugin = self.get_plugin(plugin_id)
        if not plugin:
            raise ValueError(f"Plugin {plugin_id} not found")
        if not plugin.enabled:
            logger.warning(f"Plugin {plugin_id} is disabled, skip crawl")
            return []

        try:
            results = await plugin.crawl()
            self._record_stat(plugin_id, success=len(results))
            logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
            return results
        except Exception as e:
            self._record_stat(plugin_id, failure=1)
            logger.error(f"Plugin {plugin_id} crawl failed: {e}")
            return []
        finally:
            # Persisting stats is best-effort: an exception raised here would
            # otherwise replace the return value above and discard the crawl
            # results.
            try:
                await self._save_stats(plugin_id)
            except Exception as e:
                logger.error(f"Plugin {plugin_id} failed to persist stats: {e}")

    async def run_all_plugins(self, max_concurrency: int = 5) -> List[ProxyRaw]:
        """Crawl with every enabled plugin, limiting how many run at once.

        The limit exists to avoid tripping anti-crawling measures on target sites.

        Args:
            max_concurrency: Maximum number of plugins crawling simultaneously.

        Returns:
            All crawled proxies, de-duplicated on ``(ip, port, protocol)`` with
            the first occurrence kept.

        Raises:
            ValueError: If ``max_concurrency`` is less than 1.
        """
        if max_concurrency < 1:
            # Semaphore(0) would block every task forever.
            raise ValueError("max_concurrency must be >= 1")
        semaphore = asyncio.Semaphore(max_concurrency)

        async def _run_with_limit(plugin_name: str):
            async with semaphore:
                return await self.run_plugin(plugin_name)

        tasks = [
            _run_with_limit(plugin.name)
            for plugin in registry.list_plugins()
            if plugin.enabled
        ]
        results_list = await asyncio.gather(*tasks, return_exceptions=True)

        all_results: List[ProxyRaw] = []
        for results in results_list:
            # BaseException, not Exception: a cancelled child is reported as
            # CancelledError, which does not derive from Exception.
            if isinstance(results, BaseException):
                logger.error(f"Run all plugins error: {results}")
                continue
            all_results.extend(results)

        # De-duplicate, keeping the first proxy seen for each key.
        unique: dict = {}
        for proxy in all_results:
            unique.setdefault((proxy.ip, proxy.port, proxy.protocol), proxy)
        return list(unique.values())

    def _record_stat(self, plugin_id: str, success: int = 0, failure: int = 0):
        """Add to a plugin's in-memory counters.

        ``last_run`` is only refreshed when at least one of ``success`` or
        ``failure`` is non-zero, so a crawl that returns nothing leaves it
        untouched.
        """
        entry = self._stats.setdefault(plugin_id, {
            "success_count": 0,
            "failure_count": 0,
            "last_run": None,
        })
        entry["success_count"] += success
        entry["failure_count"] += failure
        if success or failure:
            entry["last_run"] = datetime.now()

    async def _save_stats(self, plugin_id: str):
        """Persist the plugin's in-memory counters to the database."""
        stats = self._stats.get(plugin_id, {})
        last_run = stats.get("last_run")
        payload = {
            "success_count": stats.get("success_count", 0),
            "failure_count": stats.get("failure_count", 0),
            # Stored as an ISO-8601 string; list_plugins parses it back.
            "last_run": last_run.isoformat() if last_run else None,
        }
        async with get_db() as db:
            await self.plugin_settings_repo.set_stats(db, plugin_id, payload)