Files
ProxyPool/app/api/lifespan.py

211 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""应用生命周期管理"""
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI
from app.core.db import init_db, get_db
from app.core.config import settings as app_settings
from app.core.log import logger
from app.core.execution import AsyncWorkerPool, JobExecutor
from app.core.plugin_system.registry import registry
from app.repositories.proxy_repo import ProxyRepository
from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS
from app.services.validator_service import ValidatorService
from app.services.proxy_scoring import compute_proxy_quality_score
from app.services.plugin_runner import PluginRunner
from app.services.scheduler_service import SchedulerService
from app.services.proxy_service import ProxyService
from app.api.ws_manager import ConnectionManager
from app.api.realtime import stats_broadcaster_loop
# Module-level repository singletons: stateless data-access objects shared by
# the startup helper (_load_settings) and the lifespan handler below.
settings_repo = SettingsRepository()  # key/value app settings persisted in the DB
proxy_repo = ProxyRepository()  # CRUD access to stored proxy rows
async def _load_settings() -> dict:
    """Load all application settings from the database.

    Returns the persisted settings dict, or a fresh copy of
    ``DEFAULT_SETTINGS`` if the database read fails (the error is
    logged, never raised — startup must proceed regardless).
    """
    try:
        async with get_db() as db:
            return await settings_repo.get_all(db)
    except Exception as e:
        logger.error(f"Failed to load settings on startup: {e}")
        return DEFAULT_SETTINGS.copy()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup and shutdown lifecycle.

    Startup: initialize the DB, load persisted settings, build the
    validator / worker pool / job executor / scheduler, mount everything
    on ``app.state``, and start the background stats broadcaster.
    Shutdown (after ``yield``): tear the services down in reverse order;
    the ``AsyncExitStack`` closes the executor and worker pool itself.
    """
    await init_db()
    db_settings = await _load_settings()
    async with AsyncExitStack() as stack:
        # Proxy validator — DB-stored settings override static config defaults.
        validator = ValidatorService(
            timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
            connect_timeout=app_settings.validator_connect_timeout,
            max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
        )
        if db_settings.get("validation_targets") is not None:
            validator.update_test_urls(db_settings["validation_targets"])

        # Under high validation concurrency, parallel SQLite writes from many
        # workers raise "database is locked". Serialize only the write path;
        # the network phase runs without holding a DB connection.
        validation_write_lock = asyncio.Lock()

        async def validation_handler(proxy):
            # Snapshot the stored row before the (slow) network check so the
            # DB connection is released during validation.
            async with get_db() as db:
                existing = await proxy_repo.get_by_ip_port(db, proxy.ip, proxy.port)
            is_valid, latency = await validator.validate(
                proxy.ip, proxy.port, proxy.protocol
            )
            async with validation_write_lock:
                async with get_db() as db:
                    if not existing:
                        # Row vanished while we were validating — nothing to update.
                        return
                    if is_valid:
                        # The success path is identical whether or not the proxy
                        # was previously validated, so it is handled once here
                        # (the original duplicated this for both branches).
                        lat_ms = (
                            float(latency)
                            if latency is not None and float(latency) > 0
                            else None
                        )
                        uc = int(getattr(existing, "use_count", 0) or 0)
                        q_score = compute_proxy_quality_score(lat_ms, uc, app_settings)
                        await proxy_repo.insert_or_update(
                            db,
                            proxy.ip,
                            proxy.port,
                            proxy.protocol,
                            score=q_score,
                        )
                        # Recorded response time: measured latency, or the
                        # configured default when latency is unknown/non-positive.
                        rt_ms = (
                            lat_ms
                            if lat_ms is not None
                            else float(app_settings.score_default_latency_ms)
                        )
                        await proxy_repo.update_response_time(
                            db, proxy.ip, proxy.port, rt_ms
                        )
                    elif existing.validated == 0:
                        # Never-validated proxy that failed its first check: drop it.
                        await proxy_repo.delete(db, proxy.ip, proxy.port)
                    else:
                        # Previously-validated proxy that failed: penalize its score.
                        await proxy_repo.update_score(
                            db,
                            proxy.ip,
                            proxy.port,
                            app_settings.score_invalid,
                            app_settings.score_min,
                            app_settings.score_max,
                        )

        worker_pool = AsyncWorkerPool(
            worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
            handler=validation_handler,
            name="ValidationPool",
        )
        await stack.enter_async_context(worker_pool)

        # Job executor slots: bounded jointly by crawler_max_queue_size and
        # the plugin count so a crawl-all cannot exhaust slots and deadlock.
        _n_plugins = len(registry.list_plugins())
        _floor = max(24, _n_plugins + 8)
        _max_jobs = max(_floor, app_settings.crawler_max_queue_size)
        executor = JobExecutor(worker_pool=worker_pool, max_concurrent_jobs=_max_jobs)
        await stack.enter_async_context(executor)

        # Plugin runner and proxy service
        plugin_runner = PluginRunner()
        proxy_service = ProxyService()

        # Periodic re-validation scheduler
        scheduler = SchedulerService(
            executor=executor,
            worker_pool=worker_pool,
            interval_minutes=db_settings.get("validate_interval_minutes", 30),
            proxy_service=proxy_service,
            settings_repo=settings_repo,
        )

        # Mount services on app.state for request handlers / websockets.
        app.state.validator = validator
        app.state.worker_pool = worker_pool
        app.state.executor = executor
        app.state.plugin_runner = plugin_runner
        app.state.scheduler = scheduler
        app.state.ws_manager = ConnectionManager()
        app.state.stats_broadcaster_task = asyncio.create_task(stats_broadcaster_loop(app))

        # Start the scheduler (best-effort: failure is logged, not fatal).
        if db_settings.get("auto_validate", True):
            try:
                await scheduler.start()
            except Exception as e:
                logger.error(f"Failed to start scheduler on startup: {e}")

        logger.info("API server started")
        yield

        # ---- shutdown sequence ----
        app.state.stats_broadcaster_task.cancel()
        try:
            await app.state.stats_broadcaster_task
        except asyncio.CancelledError:
            pass
        await app.state.ws_manager.disconnect_all()
        # Stop the scheduler
        await scheduler.stop()
        # Cancel all running jobs
        await executor.cancel_all()
        # AsyncExitStack closes executor and worker_pool automatically.
        # Close the validator's HTTP session (best-effort).
        try:
            await validator.close()
        except Exception:
            pass
        # Close every plugin's HTTP client (best-effort).
        for plugin in registry.list_plugins():
            if hasattr(plugin, "close"):
                try:
                    await plugin.close()
                except Exception:
                    pass
        # Give aiosqlite / aiohttp background threads time to wind down.
        await asyncio.sleep(0.1)
        logger.info("API server shutdown")