refactor(backend): optimize database safety, validator performance, and scheduler concurrency
- Fix SQL injection risks in proxy_repo and task_repo - Atomic acquire_pending with UPDATE ... RETURNING - Reuse aiohttp ClientSession in ValidatorService - Replace polling with asyncio.Event in SchedulerService - Optimize ValidationQueue.drain with asyncio.Condition - Concurrent plugin crawling with asyncio.gather - Unify ProxyRaw model import path - Fix test baseline and remove tracked __pycache__ files
This commit is contained in:
@@ -38,4 +38,5 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# 关闭调度器
|
||||
await scheduler_service.stop()
|
||||
await scheduler_service.validation_queue.validator.close()
|
||||
logger.info("API server shutdown")
|
||||
|
||||
@@ -49,7 +49,7 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_ser
|
||||
@router.post("/validate-now")
|
||||
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
||||
try:
|
||||
scheduler.validate_all_now()
|
||||
await scheduler.validate_all_now()
|
||||
return success_response("已开始全量验证", {"started": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Validate now failed: {e}")
|
||||
|
||||
@@ -1,20 +1,7 @@
|
||||
"""插件基类 - 所有爬虫插件必须继承此基类"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProxyRaw:
|
||||
"""爬虫产出的原始代理数据"""
|
||||
ip: str
|
||||
port: int
|
||||
protocol: str = "http"
|
||||
|
||||
def __post_init__(self):
|
||||
self.protocol = self.protocol.lower().strip()
|
||||
if self.protocol not in ("http", "https", "socks4", "socks5"):
|
||||
self.protocol = "http"
|
||||
from app.models.domain import ProxyRaw
|
||||
|
||||
|
||||
class BaseCrawlerPlugin(ABC):
|
||||
|
||||
@@ -40,6 +40,8 @@ class ValidationQueue:
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._running = False
|
||||
self._db_lock = asyncio.Lock()
|
||||
self._pending_count = 0
|
||||
self._condition = asyncio.Condition()
|
||||
|
||||
# 统计
|
||||
self.valid_count = 0
|
||||
@@ -58,6 +60,8 @@ class ValidationQueue:
|
||||
logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
|
||||
if pending:
|
||||
logger.info(f"ValidationQueue has {pending} pending tasks to process")
|
||||
async with self._condition:
|
||||
self._pending_count = pending
|
||||
|
||||
for i in range(self.worker_count):
|
||||
self._workers.append(asyncio.create_task(self._worker_loop(i)))
|
||||
@@ -86,6 +90,9 @@ class ValidationQueue:
|
||||
async with get_db() as db:
|
||||
inserted = await self.task_repo.insert_batch(db, proxies)
|
||||
if inserted:
|
||||
async with self._condition:
|
||||
self._pending_count += inserted
|
||||
self._condition.notify_all()
|
||||
for _ in range(min(inserted, self.worker_count)):
|
||||
self._signal.put_nowait(None)
|
||||
|
||||
@@ -94,12 +101,9 @@ class ValidationQueue:
|
||||
|
||||
async def drain(self):
|
||||
"""等待队列中当前所有 pending 任务处理完毕"""
|
||||
while True:
|
||||
async with get_db() as db:
|
||||
count = await self.task_repo.get_pending_count(db)
|
||||
if count == 0:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
async with self._condition:
|
||||
if self._pending_count > 0:
|
||||
await self._condition.wait_for(lambda: self._pending_count == 0)
|
||||
|
||||
async def _worker_loop(self, worker_id: int):
|
||||
while True:
|
||||
@@ -143,6 +147,10 @@ class ValidationQueue:
|
||||
await self.task_repo.complete_task(db, task["id"], False, 0.0)
|
||||
self.invalid_count += 1
|
||||
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
|
||||
async with self._condition:
|
||||
self._pending_count = max(0, self._pending_count - 1)
|
||||
if self._pending_count == 0:
|
||||
self._condition.notify_all()
|
||||
|
||||
def reset_stats(self):
|
||||
self.valid_count = 0
|
||||
|
||||
@@ -6,9 +6,10 @@ from app.core.log import logger
|
||||
|
||||
|
||||
class Fate0Plugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
name = "fate0"
|
||||
display_name = "Fate0聚合源"
|
||||
description = "从 GitHub 持续更新的高质量代理聚合列表"
|
||||
display_name = "Fate0聚合站"
|
||||
description = "来自 GitHub 持续更新的高质量代理聚合列表"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -34,5 +35,5 @@ class Fate0Plugin(BaseHTTPPlugin):
|
||||
except Exception:
|
||||
continue
|
||||
if results:
|
||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
||||
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||
return results
|
||||
|
||||
@@ -7,6 +7,7 @@ from app.core.log import logger
|
||||
|
||||
|
||||
class Ip89Plugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
name = "ip89"
|
||||
display_name = "89免费代理"
|
||||
description = "从 89ip.cn 爬取免费代理"
|
||||
@@ -35,5 +36,5 @@ class Ip89Plugin(BaseHTTPPlugin):
|
||||
results.append(ProxyRaw(ip, int(port), "http"))
|
||||
|
||||
if results:
|
||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
||||
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||
return results
|
||||
|
||||
@@ -9,6 +9,7 @@ VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
|
||||
|
||||
|
||||
class KuaiDaiLiPlugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
name = "kuaidaili"
|
||||
display_name = "快代理"
|
||||
description = "从快代理网站爬取免费代理"
|
||||
@@ -45,5 +46,5 @@ class KuaiDaiLiPlugin(BaseHTTPPlugin):
|
||||
results.append(ProxyRaw(ip, int(port), protocol))
|
||||
|
||||
if results:
|
||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
||||
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||
return results
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.core.log import logger
|
||||
|
||||
|
||||
class ProxyListDownloadPlugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
name = "proxylist_download"
|
||||
display_name = "ProxyListDownload"
|
||||
description = "从 ProxyListDownload API 获取代理"
|
||||
|
||||
@@ -6,19 +6,20 @@ from app.core.log import logger
|
||||
|
||||
|
||||
class ProxyScrapePlugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
"""
|
||||
从 ProxyScrape 公开 API 获取代理。
|
||||
覆盖 http/https/socks4/socks5 全协议,专门用于测试插件系统的可扩展性。
|
||||
从 ProxyScrape 公开 API 获取代理库
|
||||
覆盖 http/https/socks4/socks5 全协议,专门用于测试插件系统的可扩展性
|
||||
"""
|
||||
|
||||
name = "proxyscrape"
|
||||
display_name = "ProxyScrape测试源"
|
||||
display_name = "ProxyScrape测试站"
|
||||
description = "从 ProxyScrape API 获取各类型代理(HTTP/HTTPS/SOCKS4/SOCKS5),用于测试架构扩展"
|
||||
enabled = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 使用多个公开 GitHub 代理列表作为源,稳定性较高
|
||||
# 使用多个公开 GitHub 代理列表作为源,稳定性较差
|
||||
self.urls = [
|
||||
("http", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/http.txt"),
|
||||
("https", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/https.txt"),
|
||||
@@ -71,5 +72,5 @@ class ProxyScrapePlugin(BaseHTTPPlugin):
|
||||
ip = f"{random.randint(1, 223)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(1, 254)}"
|
||||
port = random.randint(1024, 65535)
|
||||
test_proxies.append(ProxyRaw(ip, port, protocol))
|
||||
logger.info(f"生成 {len(test_proxies)} 个测试代理: HTTP/HTTPS/SOCKS4/SOCKS5 各 3 个")
|
||||
logger.info(f"生成 {len(test_proxies)} 个测试代理 HTTP/HTTPS/SOCKS4/SOCKS5 各 3 个")
|
||||
return test_proxies
|
||||
|
||||
@@ -6,8 +6,9 @@ from app.core.log import logger
|
||||
|
||||
|
||||
class SpeedXPlugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
name = "speedx"
|
||||
display_name = "SpeedX代理源"
|
||||
display_name = "SpeedX代理库"
|
||||
description = "从 SpeedX GitHub 仓库获取 SOCKS 代理列表"
|
||||
|
||||
def __init__(self):
|
||||
@@ -47,5 +48,5 @@ class SpeedXPlugin(BaseHTTPPlugin):
|
||||
results.append(ProxyRaw(ip, int(port), protocol))
|
||||
|
||||
if results:
|
||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
||||
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||
return results
|
||||
|
||||
@@ -9,6 +9,7 @@ VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
|
||||
|
||||
|
||||
class YunDaiLiPlugin(BaseHTTPPlugin):
|
||||
default_config = {"max_pages": 5}
|
||||
name = "yundaili"
|
||||
display_name = "云代理"
|
||||
description = "从云代理网站爬取免费代理"
|
||||
@@ -47,5 +48,5 @@ class YunDaiLiPlugin(BaseHTTPPlugin):
|
||||
results.append(ProxyRaw(ip, int(port), protocol))
|
||||
|
||||
if results:
|
||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
||||
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||
return results
|
||||
|
||||
@@ -190,7 +190,12 @@ class ProxyRepository:
|
||||
params.append(max_score)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
order_clause = f"{sort_by} {sort_order}"
|
||||
allowed_sort_by = {"ip", "port", "protocol", "score", "last_check"}
|
||||
allowed_sort_order = {"ASC", "DESC"}
|
||||
if sort_by not in allowed_sort_by or sort_order.upper() not in allowed_sort_order:
|
||||
order_clause = "last_check DESC"
|
||||
else:
|
||||
order_clause = f"{sort_by} {sort_order.upper()}"
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
count_query = f"SELECT COUNT(*) FROM proxies WHERE {where_clause}"
|
||||
@@ -268,7 +273,8 @@ class ProxyRepository:
|
||||
async def clean_expired(db: aiosqlite.Connection, days: int) -> int:
|
||||
try:
|
||||
await db.execute(
|
||||
"DELETE FROM proxies WHERE last_check < datetime('now', '-{} days')".format(days)
|
||||
"DELETE FROM proxies WHERE last_check < datetime('now', '-' || ? || ' days')",
|
||||
(days,),
|
||||
)
|
||||
await db.commit()
|
||||
return db.total_changes
|
||||
|
||||
@@ -33,22 +33,16 @@ class ValidationTaskRepository:
|
||||
try:
|
||||
async with db.execute(
|
||||
"""
|
||||
SELECT id, ip, port, protocol FROM validation_tasks
|
||||
WHERE status = 'pending'
|
||||
ORDER BY id ASC
|
||||
LIMIT 1
|
||||
UPDATE validation_tasks
|
||||
SET status = 'processing', updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1)
|
||||
RETURNING id, ip, port, protocol
|
||||
"""
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
task_id = row[0]
|
||||
await db.execute(
|
||||
"UPDATE validation_tasks SET status = 'processing', updated_at = CURRENT_TIMESTAMP WHERE id = ?",
|
||||
(task_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return {"id": task_id, "ip": row[1], "port": row[2], "protocol": row[3]}
|
||||
return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]}
|
||||
except Exception as e:
|
||||
logger.error(f"acquire_pending failed: {e}")
|
||||
return None
|
||||
@@ -126,7 +120,8 @@ class ValidationTaskRepository:
|
||||
async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
|
||||
try:
|
||||
await db.execute(
|
||||
"DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-{} days')".format(days)
|
||||
"DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')",
|
||||
(days,),
|
||||
)
|
||||
await db.commit()
|
||||
return db.total_changes
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""插件业务服务"""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from app.core.db import get_db
|
||||
@@ -108,14 +109,13 @@ class PluginService:
|
||||
async def run_all_plugins(self) -> List[ProxyRaw]:
|
||||
"""执行所有启用插件的爬取"""
|
||||
all_results: List[ProxyRaw] = []
|
||||
for plugin in registry.list_plugins():
|
||||
if not plugin.enabled:
|
||||
tasks = [self.run_plugin(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
|
||||
results_list = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for results in results_list:
|
||||
if isinstance(results, Exception):
|
||||
logger.error(f"Run all plugins error: {results}")
|
||||
continue
|
||||
try:
|
||||
results = await self.run_plugin(plugin.name)
|
||||
all_results.extend(results)
|
||||
except Exception as e:
|
||||
logger.error(f"Run all plugins error at {plugin.name}: {e}")
|
||||
all_results.extend(results)
|
||||
# 去重
|
||||
seen = set()
|
||||
unique = []
|
||||
|
||||
@@ -20,12 +20,14 @@ class SchedulerService:
|
||||
self.proxy_repo = proxy_repo
|
||||
self.interval_minutes = 30
|
||||
self.running = False
|
||||
self._stop_event = asyncio.Event()
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
async def start(self):
|
||||
if self.running:
|
||||
logger.warning("Scheduler already running")
|
||||
return
|
||||
self._stop_event.clear()
|
||||
self.running = True
|
||||
await self.validation_queue.start()
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
@@ -33,6 +35,7 @@ class SchedulerService:
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
@@ -55,10 +58,10 @@ class SchedulerService:
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler loop error: {e}")
|
||||
# 等待下一次
|
||||
for _ in range(self.interval_minutes * 60):
|
||||
if not self.running:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
try:
|
||||
await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
async def _do_validate_all(self):
|
||||
"""验证数据库中所有存量代理"""
|
||||
|
||||
@@ -25,8 +25,24 @@ class ValidatorService:
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.connect_timeout = connect_timeout
|
||||
self.max_concurrency = max_concurrency
|
||||
self.semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
# 共享 HTTP/HTTPS ClientSession
|
||||
self._http_connector = aiohttp.TCPConnector(
|
||||
ssl=False,
|
||||
limit=max_concurrency,
|
||||
limit_per_host=max_concurrency,
|
||||
force_close=False,
|
||||
)
|
||||
self._timeout = aiohttp.ClientTimeout(
|
||||
total=timeout, connect=connect_timeout
|
||||
)
|
||||
self._http_session = aiohttp.ClientSession(
|
||||
connector=self._http_connector,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
|
||||
def _get_test_url(self, protocol: str) -> str:
|
||||
"""获取测试 URL"""
|
||||
urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
|
||||
@@ -53,20 +69,14 @@ class ValidatorService:
|
||||
async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
|
||||
"""验证 HTTP/HTTPS 代理"""
|
||||
proxy_url = f"http://{ip}:{port}"
|
||||
connector = aiohttp.TCPConnector(ssl=False, limit=0, force_close=True)
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout, connect=self.connect_timeout)
|
||||
test_url = self._get_test_url(protocol)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||
async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
|
||||
if response.status in (200, 301, 302):
|
||||
latency = round((time.time() - start) * 1000, 2)
|
||||
logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
|
||||
return True, latency
|
||||
return False, 0.0
|
||||
finally:
|
||||
await connector.close()
|
||||
async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
|
||||
if response.status in (200, 301, 302):
|
||||
latency = round((time.time() - start) * 1000, 2)
|
||||
logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
|
||||
return True, latency
|
||||
return False, 0.0
|
||||
|
||||
async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
|
||||
"""验证 SOCKS4/SOCKS5 代理"""
|
||||
@@ -95,3 +105,7 @@ class ValidatorService:
|
||||
return False, 0.0
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
async def close(self):
|
||||
"""关闭共享的 HTTP ClientSession"""
|
||||
await self._http_session.close()
|
||||
|
||||
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
"""pytest 配置文件和 fixtures"""
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Generator
|
||||
import pytest_asyncio
|
||||
from typing import AsyncGenerator
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from app.api import create_app
|
||||
@@ -9,24 +9,17 @@ from app.core.db import init_db, get_db
|
||||
from app.repositories.proxy_repo import ProxyRepository
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""创建事件循环"""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def app():
|
||||
"""创建应用实例"""
|
||||
# 初始化测试数据库
|
||||
await init_db()
|
||||
app = create_app()
|
||||
return app
|
||||
async with app.router.lifespan_context(app):
|
||||
yield app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def client(app) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""创建异步 HTTP 客户端"""
|
||||
transport = ASGITransport(app=app)
|
||||
@@ -34,20 +27,20 @@ async def client(app) -> AsyncGenerator[AsyncClient, None]:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def db():
|
||||
"""获取数据库连接"""
|
||||
async with get_db() as db:
|
||||
yield db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def proxy_repo():
|
||||
"""获取代理仓库"""
|
||||
return ProxyRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_proxy(db, proxy_repo):
|
||||
"""创建一个测试代理"""
|
||||
await proxy_repo.insert_or_update(db, "192.168.1.1", 8080, "http", 50)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -121,14 +121,14 @@ class TestProxiesAPI:
|
||||
"""测试 GET /api/proxies/export/csv"""
|
||||
response = await client.get("/api/proxies/export/csv")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv"
|
||||
assert response.headers["content-type"].startswith("text/csv")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_proxies_txt(self, client):
|
||||
"""测试 GET /api/proxies/export/txt"""
|
||||
response = await client.get("/api/proxies/export/txt")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/plain"
|
||||
assert response.headers["content-type"].startswith("text/plain")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_proxies_json(self, client):
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user