137 lines
5.7 KiB
Python
137 lines
5.7 KiB
Python
"""插件注册中心 - 显式注册,类型安全,测试友好"""
|
||
import importlib
|
||
import importlib.util
|
||
import inspect
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Dict, List, Type, Optional
|
||
from app.core.plugin_system.base import BaseCrawlerPlugin
|
||
from app.core.log import logger
|
||
|
||
|
||
class PluginRegistry:
    """Registry mapping plugin names to ``BaseCrawlerPlugin`` subclasses.

    Classes are added explicitly with :meth:`register` (also usable as a class
    decorator) or discovered with :meth:`auto_discover` /
    :meth:`load_external_plugins_directory`. An instance is created lazily on
    the first :meth:`get` for a name and cached afterwards.
    """

    def __init__(self):
        # plugin name -> registered plugin class
        self._plugins: Dict[str, Type[BaseCrawlerPlugin]] = {}
        # plugin name -> instance created lazily by get()
        self._instances: Dict[str, BaseCrawlerPlugin] = {}

    @staticmethod
    def _is_plugin_class(obj: object) -> bool:
        """Return True if *obj* is a concrete, strict subclass of BaseCrawlerPlugin."""
        return (
            inspect.isclass(obj)
            and issubclass(obj, BaseCrawlerPlugin)
            and obj is not BaseCrawlerPlugin
            # get() instantiates registered classes, so abstract ones are never usable.
            and not inspect.isabstract(obj)
        )

    def _plugin_classes_in(self, module) -> List[Type[BaseCrawlerPlugin]]:
        """Return plugin classes reachable as attributes of *module*, in dir() order.

        This includes classes the module merely imports; callers filter out
        classes that are already registered.
        """
        return [
            obj
            for obj in (getattr(module, attr_name) for attr_name in dir(module))
            if self._is_plugin_class(obj)
        ]

    def register(self, plugin_cls: Type[BaseCrawlerPlugin]) -> Type[BaseCrawlerPlugin]:
        """Register a plugin class under its ``name``; usable as a class decorator.

        Registering a *different* class under an existing name replaces the
        previous class (with a warning) and discards the instance cached for
        that name, so the next :meth:`get` builds the new class.

        Args:
            plugin_cls: Subclass of ``BaseCrawlerPlugin`` with a non-empty ``name``.

        Returns:
            ``plugin_cls`` unchanged, so the method works as a decorator.

        Raises:
            ValueError: If ``plugin_cls`` is not a ``BaseCrawlerPlugin`` subclass
                or its ``name`` is missing or empty.
        """
        if not inspect.isclass(plugin_cls) or not issubclass(plugin_cls, BaseCrawlerPlugin):
            raise ValueError("Plugin must be a subclass of BaseCrawlerPlugin")
        name = getattr(plugin_cls, "name", None)
        if not name:
            raise ValueError(f"Plugin {plugin_cls.__name__} must have a 'name' attribute")

        previous = self._plugins.get(name)
        if previous is not None and previous is not plugin_cls:
            logger.warning(
                f"Plugin {name} re-registered: {previous.__name__} replaced by {plugin_cls.__name__}"
            )
            # The cached instance belongs to the replaced class; drop it so
            # get() does not keep serving the old implementation.
            self._instances.pop(name, None)

        self._plugins[name] = plugin_cls
        logger.info(f"Plugin registered: {name} ({plugin_cls.__name__})")
        return plugin_cls

    def get(self, name: str) -> Optional[BaseCrawlerPlugin]:
        """Return the cached instance of plugin *name*, or None if it is not registered.

        On first access the instance is created by calling the class with no
        arguments; later calls return the same object.
        """
        if name not in self._instances:
            cls = self._plugins.get(name)
            if cls is not None:
                self._instances[name] = cls()
        return self._instances.get(name)

    def list_plugins(self) -> List[BaseCrawlerPlugin]:
        """Return an instance of every registered plugin, in registration order."""
        # Snapshot the names: a plugin constructor run by get() could register more plugins.
        instances = (self.get(name) for name in list(self._plugins))
        return [inst for inst in instances if inst is not None]

    def get_plugin_names(self) -> List[str]:
        """Return the names of all registered plugins, in registration order."""
        return list(self._plugins.keys())

    def clear(self) -> None:
        """Remove all registered plugin classes and cached instances (mainly for tests)."""
        self._plugins.clear()
        self._instances.clear()

    def auto_discover(self, package_name: str):
        """Import each top-level ``.py`` module of *package_name* and register its plugins.

        Explicit registration is preferred for type safety and control; this
        method exists for compatibility. Files are visited in sorted name order.
        Files starting with ``__`` and sub-packages are not scanned.

        A module that fails to import is logged and skipped. A class rejected
        by :meth:`register` (e.g. missing ``name``) is logged and skipped
        without affecting the other classes of the same module.
        """
        try:
            package = importlib.import_module(package_name)
            package_dir = os.path.dirname(package.__file__)
        except Exception as e:
            logger.error(f"Auto discover failed for package {package_name}: {e}")
            return

        for filename in sorted(os.listdir(package_dir)):
            if not filename.endswith(".py") or filename.startswith("__"):
                continue
            module_name = f"{package_name}.{filename[:-3]}"
            try:
                module = importlib.import_module(module_name)
                candidates = self._plugin_classes_in(module)
            except Exception as e:
                logger.error(f"Failed to load module {module_name}: {e}")
                continue

            for obj in candidates:
                if obj in self._plugins.values():
                    continue
                try:
                    self.register(obj)
                except ValueError as e:
                    # One bad class (e.g. an imported base without a name) must
                    # not hide the remaining plugins defined in this module.
                    logger.warning(f"Skipped plugin class {obj.__name__} in {module_name}: {e}")

    def load_external_plugins_directory(self, directory: Path) -> int:
        """Load ``BaseCrawlerPlugin`` subclasses from each ``.py`` file in *directory*.

        Each file is executed as its own module; files whose name starts with
        ``_`` are skipped, and the scan is not recursive. This coexists with the
        built-in ``app.plugins``. A class with a missing ``name``, or a ``name``
        that is already registered, is skipped with a warning. An error in one
        file is logged and does not stop the others.

        Args:
            directory: Directory to scan (a ``str`` path is accepted as well).

        Returns:
            Number of plugin classes newly registered by this call.
        """
        import zlib

        directory = Path(directory).resolve()
        if not directory.is_dir():
            logger.info("外部插件目录不存在,已跳过: %s", directory)
            return 0

        loaded = 0
        for path in sorted(directory.glob("*.py")):
            if path.name.startswith("_"):
                continue
            # crc32 is stable across runs; built-in str hash() is randomized per
            # process (PYTHONHASHSEED), which made module names non-reproducible.
            path_tag = zlib.crc32(str(path).encode("utf-8"))
            mod_name = f"proxypool_ext_{path.stem}_{path_tag:08x}"
            try:
                spec = importlib.util.spec_from_file_location(mod_name, path)
                if spec is None or spec.loader is None:
                    continue
                module = importlib.util.module_from_spec(spec)
                # Registered before execution, as in the importlib
                # "import a source file directly" recipe.
                sys.modules[mod_name] = module
                try:
                    spec.loader.exec_module(module)
                except BaseException:
                    # Like the import system, don't leave a half-initialised module behind.
                    sys.modules.pop(mod_name, None)
                    raise

                for obj in self._plugin_classes_in(module):
                    if obj in self._plugins.values():
                        continue
                    if not getattr(obj, "name", None):
                        logger.warning(
                            "跳过外部插件类(缺少 name): %s in %s",
                            obj.__name__,
                            path,
                        )
                        continue
                    if obj.name in self._plugins:
                        logger.warning(
                            "外部插件 %s 与已注册插件重名,已跳过: %s",
                            obj.name,
                            path,
                        )
                        continue
                    self.register(obj)
                    loaded += 1
            except Exception as e:
                logger.error("加载外部插件失败 %s: %s", path, e, exc_info=True)

        if loaded:
            logger.info("从 %s 额外加载 %s 个插件", directory, loaded)
        return loaded
|
||
|
||
|
||
# Global registry instance shared by every module that imports it.
registry: PluginRegistry = PluginRegistry()
|