refactor(backend): optimize database safety, validator performance, and scheduler concurrency

- Fix SQL injection risks in proxy_repo and task_repo
- Atomic acquire_pending with UPDATE ... RETURNING
- Reuse aiohttp ClientSession in ValidatorService
- Replace polling with asyncio.Event in SchedulerService
- Optimize ValidationQueue.drain with asyncio.Condition
- Concurrent plugin crawling with asyncio.gather
- Unify ProxyRaw model import path
- Fix test baseline and remove tracked __pycache__ files
This commit is contained in:
祀梦
2026-04-04 14:43:31 +08:00
parent abb8b32ed3
commit 635c524a7e
27 changed files with 103 additions and 89 deletions

View File

@@ -38,4 +38,5 @@ async def lifespan(app: FastAPI):
# 关闭调度器 # 关闭调度器
await scheduler_service.stop() await scheduler_service.stop()
await scheduler_service.validation_queue.validator.close()
logger.info("API server shutdown") logger.info("API server shutdown")

View File

@@ -49,7 +49,7 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_ser
@router.post("/validate-now") @router.post("/validate-now")
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)): async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
try: try:
scheduler.validate_all_now() await scheduler.validate_all_now()
return success_response("已开始全量验证", {"started": True}) return success_response("已开始全量验证", {"started": True})
except Exception as e: except Exception as e:
logger.error(f"Validate now failed: {e}") logger.error(f"Validate now failed: {e}")

View File

@@ -1,20 +1,7 @@
"""插件基类 - 所有爬虫插件必须继承此基类""" """插件基类 - 所有爬虫插件必须继承此基类"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Dict, Any from typing import List, Dict, Any
from app.models.domain import ProxyRaw
@dataclass
class ProxyRaw:
"""爬虫产出的原始代理数据"""
ip: str
port: int
protocol: str = "http"
def __post_init__(self):
self.protocol = self.protocol.lower().strip()
if self.protocol not in ("http", "https", "socks4", "socks5"):
self.protocol = "http"
class BaseCrawlerPlugin(ABC): class BaseCrawlerPlugin(ABC):

View File

@@ -40,6 +40,8 @@ class ValidationQueue:
self._workers: list[asyncio.Task] = [] self._workers: list[asyncio.Task] = []
self._running = False self._running = False
self._db_lock = asyncio.Lock() self._db_lock = asyncio.Lock()
self._pending_count = 0
self._condition = asyncio.Condition()
# 统计 # 统计
self.valid_count = 0 self.valid_count = 0
@@ -58,6 +60,8 @@ class ValidationQueue:
logger.info(f"ValidationQueue recovered {recovered} interrupted tasks") logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
if pending: if pending:
logger.info(f"ValidationQueue has {pending} pending tasks to process") logger.info(f"ValidationQueue has {pending} pending tasks to process")
async with self._condition:
self._pending_count = pending
for i in range(self.worker_count): for i in range(self.worker_count):
self._workers.append(asyncio.create_task(self._worker_loop(i))) self._workers.append(asyncio.create_task(self._worker_loop(i)))
@@ -86,6 +90,9 @@ class ValidationQueue:
async with get_db() as db: async with get_db() as db:
inserted = await self.task_repo.insert_batch(db, proxies) inserted = await self.task_repo.insert_batch(db, proxies)
if inserted: if inserted:
async with self._condition:
self._pending_count += inserted
self._condition.notify_all()
for _ in range(min(inserted, self.worker_count)): for _ in range(min(inserted, self.worker_count)):
self._signal.put_nowait(None) self._signal.put_nowait(None)
@@ -94,12 +101,9 @@ class ValidationQueue:
async def drain(self): async def drain(self):
"""等待队列中当前所有 pending 任务处理完毕""" """等待队列中当前所有 pending 任务处理完毕"""
while True: async with self._condition:
async with get_db() as db: if self._pending_count > 0:
count = await self.task_repo.get_pending_count(db) await self._condition.wait_for(lambda: self._pending_count == 0)
if count == 0:
break
await asyncio.sleep(0.5)
async def _worker_loop(self, worker_id: int): async def _worker_loop(self, worker_id: int):
while True: while True:
@@ -143,6 +147,10 @@ class ValidationQueue:
await self.task_repo.complete_task(db, task["id"], False, 0.0) await self.task_repo.complete_task(db, task["id"], False, 0.0)
self.invalid_count += 1 self.invalid_count += 1
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}") logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
async with self._condition:
self._pending_count = max(0, self._pending_count - 1)
if self._pending_count == 0:
self._condition.notify_all()
def reset_stats(self): def reset_stats(self):
self.valid_count = 0 self.valid_count = 0

View File

@@ -6,9 +6,10 @@ from app.core.log import logger
class Fate0Plugin(BaseHTTPPlugin): class Fate0Plugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
name = "fate0" name = "fate0"
display_name = "Fate0聚合" display_name = "Fate0聚合"
description = " GitHub 持续更新的高质量代理聚合列表" description = "来自 GitHub 持续更新的高质量代理聚合列表"
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -34,5 +35,5 @@ class Fate0Plugin(BaseHTTPPlugin):
except Exception: except Exception:
continue continue
if results: if results:
logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理") logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理")
return results return results

View File

@@ -7,6 +7,7 @@ from app.core.log import logger
class Ip89Plugin(BaseHTTPPlugin): class Ip89Plugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
name = "ip89" name = "ip89"
display_name = "89免费代理" display_name = "89免费代理"
description = "从 89ip.cn 爬取免费代理" description = "从 89ip.cn 爬取免费代理"
@@ -35,5 +36,5 @@ class Ip89Plugin(BaseHTTPPlugin):
results.append(ProxyRaw(ip, int(port), "http")) results.append(ProxyRaw(ip, int(port), "http"))
if results: if results:
logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理") logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理")
return results return results

View File

@@ -9,6 +9,7 @@ VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
class KuaiDaiLiPlugin(BaseHTTPPlugin): class KuaiDaiLiPlugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
name = "kuaidaili" name = "kuaidaili"
display_name = "快代理" display_name = "快代理"
description = "从快代理网站爬取免费代理" description = "从快代理网站爬取免费代理"
@@ -45,5 +46,5 @@ class KuaiDaiLiPlugin(BaseHTTPPlugin):
results.append(ProxyRaw(ip, int(port), protocol)) results.append(ProxyRaw(ip, int(port), protocol))
if results: if results:
logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理") logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理")
return results return results

View File

@@ -5,6 +5,7 @@ from app.core.log import logger
class ProxyListDownloadPlugin(BaseHTTPPlugin): class ProxyListDownloadPlugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
name = "proxylist_download" name = "proxylist_download"
display_name = "ProxyListDownload" display_name = "ProxyListDownload"
description = "从 ProxyListDownload API 获取代理" description = "从 ProxyListDownload API 获取代理"

View File

@@ -6,19 +6,20 @@ from app.core.log import logger
class ProxyScrapePlugin(BaseHTTPPlugin): class ProxyScrapePlugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
""" """
从 ProxyScrape 公开 API 获取代理 从 ProxyScrape 公开 API 获取代理
覆盖 http/https/socks4/socks5 全协议,专门用于测试插件系统的可扩展性 覆盖 http/https/socks4/socks5 全协议,专门用于测试插件系统的可扩展性
""" """
name = "proxyscrape" name = "proxyscrape"
display_name = "ProxyScrape测试" display_name = "ProxyScrape测试"
description = "从 ProxyScrape API 获取各类型代理HTTP/HTTPS/SOCKS4/SOCKS5用于测试架构扩展" description = "从 ProxyScrape API 获取各类型代理HTTP/HTTPS/SOCKS4/SOCKS5用于测试架构扩展"
enabled = True enabled = True
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# 使用多个公开 GitHub 代理列表作为源,稳定性较好 # 使用多个公开 GitHub 代理列表作为源,稳定性较好
self.urls = [ self.urls = [
("http", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/http.txt"), ("http", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/http.txt"),
("https", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/https.txt"), ("https", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/https.txt"),
@@ -71,5 +72,5 @@ class ProxyScrapePlugin(BaseHTTPPlugin):
ip = f"{random.randint(1, 223)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(1, 254)}" ip = f"{random.randint(1, 223)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(1, 254)}"
port = random.randint(1024, 65535) port = random.randint(1024, 65535)
test_proxies.append(ProxyRaw(ip, port, protocol)) test_proxies.append(ProxyRaw(ip, port, protocol))
logger.info(f"生成 {len(test_proxies)} 个测试代理: HTTP/HTTPS/SOCKS4/SOCKS5 各 3 个") logger.info(f"生成 {len(test_proxies)} 个测试代理 HTTP/HTTPS/SOCKS4/SOCKS5 各 3 个")
return test_proxies return test_proxies

View File

@@ -6,8 +6,9 @@ from app.core.log import logger
class SpeedXPlugin(BaseHTTPPlugin): class SpeedXPlugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
name = "speedx" name = "speedx"
display_name = "SpeedX代理" display_name = "SpeedX代理"
description = "从 SpeedX GitHub 仓库获取 SOCKS 代理列表" description = "从 SpeedX GitHub 仓库获取 SOCKS 代理列表"
def __init__(self): def __init__(self):
@@ -47,5 +48,5 @@ class SpeedXPlugin(BaseHTTPPlugin):
results.append(ProxyRaw(ip, int(port), protocol)) results.append(ProxyRaw(ip, int(port), protocol))
if results: if results:
logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理") logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理")
return results return results

View File

@@ -9,6 +9,7 @@ VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
class YunDaiLiPlugin(BaseHTTPPlugin): class YunDaiLiPlugin(BaseHTTPPlugin):
default_config = {"max_pages": 5}
name = "yundaili" name = "yundaili"
display_name = "云代理" display_name = "云代理"
description = "从云代理网站爬取免费代理" description = "从云代理网站爬取免费代理"
@@ -47,5 +48,5 @@ class YunDaiLiPlugin(BaseHTTPPlugin):
results.append(ProxyRaw(ip, int(port), protocol)) results.append(ProxyRaw(ip, int(port), protocol))
if results: if results:
logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理") logger.info(f"{self.display_name} 解析完成,获 {len(results)} 个潜在代理")
return results return results

View File

@@ -190,7 +190,12 @@ class ProxyRepository:
params.append(max_score) params.append(max_score)
where_clause = " AND ".join(conditions) where_clause = " AND ".join(conditions)
order_clause = f"{sort_by} {sort_order}" allowed_sort_by = {"ip", "port", "protocol", "score", "last_check"}
allowed_sort_order = {"ASC", "DESC"}
if sort_by not in allowed_sort_by or sort_order.upper() not in allowed_sort_order:
order_clause = "last_check DESC"
else:
order_clause = f"{sort_by} {sort_order.upper()}"
offset = (page - 1) * page_size offset = (page - 1) * page_size
count_query = f"SELECT COUNT(*) FROM proxies WHERE {where_clause}" count_query = f"SELECT COUNT(*) FROM proxies WHERE {where_clause}"
@@ -268,7 +273,8 @@ class ProxyRepository:
async def clean_expired(db: aiosqlite.Connection, days: int) -> int: async def clean_expired(db: aiosqlite.Connection, days: int) -> int:
try: try:
await db.execute( await db.execute(
"DELETE FROM proxies WHERE last_check < datetime('now', '-{} days')".format(days) "DELETE FROM proxies WHERE last_check < datetime('now', '-' || ? || ' days')",
(days,),
) )
await db.commit() await db.commit()
return db.total_changes return db.total_changes

View File

@@ -33,22 +33,16 @@ class ValidationTaskRepository:
try: try:
async with db.execute( async with db.execute(
""" """
SELECT id, ip, port, protocol FROM validation_tasks UPDATE validation_tasks
WHERE status = 'pending' SET status = 'processing', updated_at = CURRENT_TIMESTAMP
ORDER BY id ASC WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1)
LIMIT 1 RETURNING id, ip, port, protocol
""" """
) as cursor: ) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
if not row: if not row:
return None return None
task_id = row[0] return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]}
await db.execute(
"UPDATE validation_tasks SET status = 'processing', updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(task_id,),
)
await db.commit()
return {"id": task_id, "ip": row[1], "port": row[2], "protocol": row[3]}
except Exception as e: except Exception as e:
logger.error(f"acquire_pending failed: {e}") logger.error(f"acquire_pending failed: {e}")
return None return None
@@ -126,7 +120,8 @@ class ValidationTaskRepository:
async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int: async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
try: try:
await db.execute( await db.execute(
"DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-{} days')".format(days) "DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')",
(days,),
) )
await db.commit() await db.commit()
return db.total_changes return db.total_changes

View File

@@ -1,4 +1,5 @@
"""插件业务服务""" """插件业务服务"""
import asyncio
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from app.core.db import get_db from app.core.db import get_db
@@ -108,14 +109,13 @@ class PluginService:
async def run_all_plugins(self) -> List[ProxyRaw]: async def run_all_plugins(self) -> List[ProxyRaw]:
"""执行所有启用插件的爬取""" """执行所有启用插件的爬取"""
all_results: List[ProxyRaw] = [] all_results: List[ProxyRaw] = []
for plugin in registry.list_plugins(): tasks = [self.run_plugin(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
if not plugin.enabled: results_list = await asyncio.gather(*tasks, return_exceptions=True)
for results in results_list:
if isinstance(results, Exception):
logger.error(f"Run all plugins error: {results}")
continue continue
try:
results = await self.run_plugin(plugin.name)
all_results.extend(results) all_results.extend(results)
except Exception as e:
logger.error(f"Run all plugins error at {plugin.name}: {e}")
# 去重 # 去重
seen = set() seen = set()
unique = [] unique = []

View File

@@ -20,12 +20,14 @@ class SchedulerService:
self.proxy_repo = proxy_repo self.proxy_repo = proxy_repo
self.interval_minutes = 30 self.interval_minutes = 30
self.running = False self.running = False
self._stop_event = asyncio.Event()
self._task: asyncio.Task | None = None self._task: asyncio.Task | None = None
async def start(self): async def start(self):
if self.running: if self.running:
logger.warning("Scheduler already running") logger.warning("Scheduler already running")
return return
self._stop_event.clear()
self.running = True self.running = True
await self.validation_queue.start() await self.validation_queue.start()
self._task = asyncio.create_task(self._run_loop()) self._task = asyncio.create_task(self._run_loop())
@@ -33,6 +35,7 @@ class SchedulerService:
async def stop(self): async def stop(self):
self.running = False self.running = False
self._stop_event.set()
if self._task: if self._task:
self._task.cancel() self._task.cancel()
try: try:
@@ -55,10 +58,10 @@ class SchedulerService:
except Exception as e: except Exception as e:
logger.error(f"Scheduler loop error: {e}") logger.error(f"Scheduler loop error: {e}")
# 等待下一次 # 等待下一次
for _ in range(self.interval_minutes * 60): try:
if not self.running: await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
break except asyncio.TimeoutError:
await asyncio.sleep(1) pass
async def _do_validate_all(self): async def _do_validate_all(self):
"""验证数据库中所有存量代理""" """验证数据库中所有存量代理"""

View File

@@ -25,8 +25,24 @@ class ValidatorService:
): ):
self.timeout = timeout self.timeout = timeout
self.connect_timeout = connect_timeout self.connect_timeout = connect_timeout
self.max_concurrency = max_concurrency
self.semaphore = asyncio.Semaphore(max_concurrency) self.semaphore = asyncio.Semaphore(max_concurrency)
# 共享 HTTP/HTTPS ClientSession
self._http_connector = aiohttp.TCPConnector(
ssl=False,
limit=max_concurrency,
limit_per_host=max_concurrency,
force_close=False,
)
self._timeout = aiohttp.ClientTimeout(
total=timeout, connect=connect_timeout
)
self._http_session = aiohttp.ClientSession(
connector=self._http_connector,
timeout=self._timeout,
)
def _get_test_url(self, protocol: str) -> str: def _get_test_url(self, protocol: str) -> str:
"""获取测试 URL""" """获取测试 URL"""
urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"]) urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
@@ -53,20 +69,14 @@ class ValidatorService:
async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]: async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
"""验证 HTTP/HTTPS 代理""" """验证 HTTP/HTTPS 代理"""
proxy_url = f"http://{ip}:{port}" proxy_url = f"http://{ip}:{port}"
connector = aiohttp.TCPConnector(ssl=False, limit=0, force_close=True)
timeout = aiohttp.ClientTimeout(total=self.timeout, connect=self.connect_timeout)
test_url = self._get_test_url(protocol) test_url = self._get_test_url(protocol)
try: async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
if response.status in (200, 301, 302): if response.status in (200, 301, 302):
latency = round((time.time() - start) * 1000, 2) latency = round((time.time() - start) * 1000, 2)
logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms") logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
return True, latency return True, latency
return False, 0.0 return False, 0.0
finally:
await connector.close()
async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]: async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
"""验证 SOCKS4/SOCKS5 代理""" """验证 SOCKS4/SOCKS5 代理"""
@@ -95,3 +105,7 @@ class ValidatorService:
return False, 0.0 return False, 0.0
finally: finally:
await connector.close() await connector.close()
async def close(self):
"""关闭共享的 HTTP ClientSession"""
await self._http_session.close()

View File

@@ -1,7 +1,7 @@
"""pytest 配置文件和 fixtures""" """pytest 配置文件和 fixtures"""
import pytest import pytest
import asyncio import pytest_asyncio
from typing import AsyncGenerator, Generator from typing import AsyncGenerator
from httpx import AsyncClient, ASGITransport from httpx import AsyncClient, ASGITransport
from app.api import create_app from app.api import create_app
@@ -9,24 +9,17 @@ from app.core.db import init_db, get_db
from app.repositories.proxy_repo import ProxyRepository from app.repositories.proxy_repo import ProxyRepository
@pytest.fixture(scope="session") @pytest_asyncio.fixture(scope="function")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
"""创建事件循环"""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def app(): async def app():
"""创建应用实例""" """创建应用实例"""
# 初始化测试数据库 # 初始化测试数据库
await init_db() await init_db()
app = create_app() app = create_app()
return app async with app.router.lifespan_context(app):
yield app
@pytest.fixture @pytest_asyncio.fixture
async def client(app) -> AsyncGenerator[AsyncClient, None]: async def client(app) -> AsyncGenerator[AsyncClient, None]:
"""创建异步 HTTP 客户端""" """创建异步 HTTP 客户端"""
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
@@ -34,20 +27,20 @@ async def client(app) -> AsyncGenerator[AsyncClient, None]:
yield client yield client
@pytest.fixture @pytest_asyncio.fixture
async def db(): async def db():
"""获取数据库连接""" """获取数据库连接"""
async with get_db() as db: async with get_db() as db:
yield db yield db
@pytest.fixture @pytest_asyncio.fixture
async def proxy_repo(): async def proxy_repo():
"""获取代理仓库""" """获取代理仓库"""
return ProxyRepository() return ProxyRepository()
@pytest.fixture @pytest_asyncio.fixture
async def sample_proxy(db, proxy_repo): async def sample_proxy(db, proxy_repo):
"""创建一个测试代理""" """创建一个测试代理"""
await proxy_repo.insert_or_update(db, "192.168.1.1", 8080, "http", 50) await proxy_repo.insert_or_update(db, "192.168.1.1", 8080, "http", 50)

View File

@@ -121,14 +121,14 @@ class TestProxiesAPI:
"""测试 GET /api/proxies/export/csv""" """测试 GET /api/proxies/export/csv"""
response = await client.get("/api/proxies/export/csv") response = await client.get("/api/proxies/export/csv")
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/csv" assert response.headers["content-type"].startswith("text/csv")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_export_proxies_txt(self, client): async def test_export_proxies_txt(self, client):
"""测试 GET /api/proxies/export/txt""" """测试 GET /api/proxies/export/txt"""
response = await client.get("/api/proxies/export/txt") response = await client.get("/api/proxies/export/txt")
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/plain" assert response.headers["content-type"].startswith("text/plain")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_export_proxies_json(self, client): async def test_export_proxies_json(self, client):