diff --git a/app/api/lifespan.py b/app/api/lifespan.py index 1c05226..c84ac3e 100644 --- a/app/api/lifespan.py +++ b/app/api/lifespan.py @@ -3,7 +3,7 @@ import asyncio from contextlib import AsyncExitStack, asynccontextmanager from fastapi import FastAPI -from app.core.db import init_db, get_db, get_db_connection +from app.core.db import init_db, get_db from app.core.config import settings as app_settings from app.core.log import logger from app.core.execution import AsyncWorkerPool, JobExecutor @@ -46,52 +46,58 @@ async def lifespan(app: FastAPI): if db_settings.get("validation_targets") is not None: validator.update_test_urls(db_settings["validation_targets"]) - # 验证 WorkerPool + # 高并发验证时若多 worker 同时写 SQLite 会 database is locked;网络阶段不占连接,写路径串行化 + validation_write_lock = asyncio.Lock() + async def validation_handler(proxy): - async with get_db_connection() as db: + async with get_db() as db: existing = await proxy_repo.get_by_ip_port(db, proxy.ip, proxy.port) - is_valid, latency = await validator.validate( - proxy.ip, proxy.port, proxy.protocol - ) - if not existing: - return - if existing.validated == 0: - if is_valid: - await proxy_repo.insert_or_update( - db, - proxy.ip, - proxy.port, - proxy.protocol, - score=app_settings.score_valid, - ) - if latency: - await proxy_repo.update_response_time( - db, proxy.ip, proxy.port, latency + + is_valid, latency = await validator.validate( + proxy.ip, proxy.port, proxy.protocol + ) + + async with validation_write_lock: + async with get_db() as db: + if not existing: + return + if existing.validated == 0: + if is_valid: + await proxy_repo.insert_or_update( + db, + proxy.ip, + proxy.port, + proxy.protocol, + score=app_settings.score_valid, ) + if latency: + await proxy_repo.update_response_time( + db, proxy.ip, proxy.port, latency + ) + else: + await proxy_repo.delete(db, proxy.ip, proxy.port) else: - await proxy_repo.delete(db, proxy.ip, proxy.port) - else: - if is_valid: - await proxy_repo.insert_or_update( - db, - proxy.ip, - proxy.port, - proxy.protocol, - score=app_settings.score_valid, - ) - if latency: - await proxy_repo.update_response_time( - db, proxy.ip, proxy.port, latency + if is_valid: + await proxy_repo.insert_or_update( + db, + proxy.ip, + proxy.port, + proxy.protocol, + score=app_settings.score_valid, + ) + if latency: + await proxy_repo.update_response_time( + db, proxy.ip, proxy.port, latency + ) + else: + await proxy_repo.update_score( + db, + proxy.ip, + proxy.port, + app_settings.score_invalid, + app_settings.score_min, + app_settings.score_max, ) - else: - await proxy_repo.update_score( - db, - proxy.ip, - proxy.port, - app_settings.score_invalid, - app_settings.score_min, - app_settings.score_max, - ) worker_pool = AsyncWorkerPool( worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency), diff --git a/app/core/db.py b/app/core/db.py index 7bcd73f..8afd324 100644 --- a/app/core/db.py +++ b/app/core/db.py @@ -16,12 +16,21 @@ def ensure_db_dir(): os.makedirs(db_dir, exist_ok=True) +async def _apply_connection_pragmas(db: aiosqlite.Connection) -> None: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute("PRAGMA synchronous=NORMAL") + await db.execute("PRAGMA busy_timeout=30000") + + +# aiosqlite/sqlite3:等待锁的最长时间(秒),与高并发验证写入配合 +_SQLITE_CONNECT_TIMEOUT = 30.0 + + async def init_db(): """初始化数据库表结构(支持迁移)""" ensure_db_dir() - async with aiosqlite.connect(DB_PATH) as db: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("PRAGMA synchronous=NORMAL") + async with aiosqlite.connect(DB_PATH, timeout=_SQLITE_CONNECT_TIMEOUT) as db: + await _apply_connection_pragmas(db) await db.execute("PRAGMA cache_size=-64000") await db.execute("PRAGMA temp_store=MEMORY") @@ -120,10 +129,9 @@ async def init_db(): async def get_db() -> AsyncIterator[aiosqlite.Connection]: """获取数据库连接的异步上下文管理器""" ensure_db_dir() - db = await aiosqlite.connect(DB_PATH) + db = await aiosqlite.connect(DB_PATH, timeout=_SQLITE_CONNECT_TIMEOUT) try: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("PRAGMA synchronous=NORMAL") + await _apply_connection_pragmas(db) yield db finally: await db.close() @@ -131,12 +139,11 @@ async def get_db() -> AsyncIterator[aiosqlite.Connection]: @asynccontextmanager async def get_db_connection() -> AsyncIterator[aiosqlite.Connection]: - """单连接贯穿「读库 → await 网络 I/O → 写库」,减少验证 worker 每条代理两次 connect。""" + """与 get_db 相同 pragma/超时;保留别名供需「长连接」语义处使用。""" ensure_db_dir() - db = await aiosqlite.connect(DB_PATH) + db = await aiosqlite.connect(DB_PATH, timeout=_SQLITE_CONNECT_TIMEOUT) try: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("PRAGMA synchronous=NORMAL") + await _apply_connection_pragmas(db) yield db finally: await db.close() @@ -152,10 +159,9 @@ async def transaction() -> AsyncIterator[aiosqlite.Connection]: # 如果抛出异常,自动 rollback """ ensure_db_dir() - db = await aiosqlite.connect(DB_PATH) + db = await aiosqlite.connect(DB_PATH, timeout=_SQLITE_CONNECT_TIMEOUT) try: - await db.execute("PRAGMA journal_mode=WAL") - await db.execute("PRAGMA synchronous=NORMAL") + await _apply_connection_pragmas(db) await db.execute("BEGIN") yield db await db.commit()