refactor(backend): optimize database safety, validator performance, and scheduler concurrency
- Fix SQL injection risks in proxy_repo and task_repo - Atomic acquire_pending with UPDATE ... RETURNING - Reuse aiohttp ClientSession in ValidatorService - Replace polling with asyncio.Event in SchedulerService - Optimize ValidationQueue.drain with asyncio.Condition - Concurrent plugin crawling with asyncio.gather - Unify ProxyRaw model import path - Fix test baseline and remove tracked __pycache__ files
This commit is contained in:
@@ -38,4 +38,5 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
# 关闭调度器
|
# 关闭调度器
|
||||||
await scheduler_service.stop()
|
await scheduler_service.stop()
|
||||||
|
await scheduler_service.validation_queue.validator.close()
|
||||||
logger.info("API server shutdown")
|
logger.info("API server shutdown")
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_ser
|
|||||||
@router.post("/validate-now")
|
@router.post("/validate-now")
|
||||||
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
|
||||||
try:
|
try:
|
||||||
scheduler.validate_all_now()
|
await scheduler.validate_all_now()
|
||||||
return success_response("已开始全量验证", {"started": True})
|
return success_response("已开始全量验证", {"started": True})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Validate now failed: {e}")
|
logger.error(f"Validate now failed: {e}")
|
||||||
|
|||||||
@@ -1,20 +1,7 @@
|
|||||||
"""插件基类 - 所有爬虫插件必须继承此基类"""
|
"""插件基类 - 所有爬虫插件必须继承此基类"""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
from app.models.domain import ProxyRaw
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProxyRaw:
|
|
||||||
"""爬虫产出的原始代理数据"""
|
|
||||||
ip: str
|
|
||||||
port: int
|
|
||||||
protocol: str = "http"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.protocol = self.protocol.lower().strip()
|
|
||||||
if self.protocol not in ("http", "https", "socks4", "socks5"):
|
|
||||||
self.protocol = "http"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseCrawlerPlugin(ABC):
|
class BaseCrawlerPlugin(ABC):
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ class ValidationQueue:
|
|||||||
self._workers: list[asyncio.Task] = []
|
self._workers: list[asyncio.Task] = []
|
||||||
self._running = False
|
self._running = False
|
||||||
self._db_lock = asyncio.Lock()
|
self._db_lock = asyncio.Lock()
|
||||||
|
self._pending_count = 0
|
||||||
|
self._condition = asyncio.Condition()
|
||||||
|
|
||||||
# 统计
|
# 统计
|
||||||
self.valid_count = 0
|
self.valid_count = 0
|
||||||
@@ -58,6 +60,8 @@ class ValidationQueue:
|
|||||||
logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
|
logger.info(f"ValidationQueue recovered {recovered} interrupted tasks")
|
||||||
if pending:
|
if pending:
|
||||||
logger.info(f"ValidationQueue has {pending} pending tasks to process")
|
logger.info(f"ValidationQueue has {pending} pending tasks to process")
|
||||||
|
async with self._condition:
|
||||||
|
self._pending_count = pending
|
||||||
|
|
||||||
for i in range(self.worker_count):
|
for i in range(self.worker_count):
|
||||||
self._workers.append(asyncio.create_task(self._worker_loop(i)))
|
self._workers.append(asyncio.create_task(self._worker_loop(i)))
|
||||||
@@ -86,6 +90,9 @@ class ValidationQueue:
|
|||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
inserted = await self.task_repo.insert_batch(db, proxies)
|
inserted = await self.task_repo.insert_batch(db, proxies)
|
||||||
if inserted:
|
if inserted:
|
||||||
|
async with self._condition:
|
||||||
|
self._pending_count += inserted
|
||||||
|
self._condition.notify_all()
|
||||||
for _ in range(min(inserted, self.worker_count)):
|
for _ in range(min(inserted, self.worker_count)):
|
||||||
self._signal.put_nowait(None)
|
self._signal.put_nowait(None)
|
||||||
|
|
||||||
@@ -94,12 +101,9 @@ class ValidationQueue:
|
|||||||
|
|
||||||
async def drain(self):
|
async def drain(self):
|
||||||
"""等待队列中当前所有 pending 任务处理完毕"""
|
"""等待队列中当前所有 pending 任务处理完毕"""
|
||||||
while True:
|
async with self._condition:
|
||||||
async with get_db() as db:
|
if self._pending_count > 0:
|
||||||
count = await self.task_repo.get_pending_count(db)
|
await self._condition.wait_for(lambda: self._pending_count == 0)
|
||||||
if count == 0:
|
|
||||||
break
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
async def _worker_loop(self, worker_id: int):
|
async def _worker_loop(self, worker_id: int):
|
||||||
while True:
|
while True:
|
||||||
@@ -143,6 +147,10 @@ class ValidationQueue:
|
|||||||
await self.task_repo.complete_task(db, task["id"], False, 0.0)
|
await self.task_repo.complete_task(db, task["id"], False, 0.0)
|
||||||
self.invalid_count += 1
|
self.invalid_count += 1
|
||||||
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
|
logger.debug(f"ValidationQueue: invalid {proxy.ip}:{proxy.port}")
|
||||||
|
async with self._condition:
|
||||||
|
self._pending_count = max(0, self._pending_count - 1)
|
||||||
|
if self._pending_count == 0:
|
||||||
|
self._condition.notify_all()
|
||||||
|
|
||||||
def reset_stats(self):
|
def reset_stats(self):
|
||||||
self.valid_count = 0
|
self.valid_count = 0
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ from app.core.log import logger
|
|||||||
|
|
||||||
|
|
||||||
class Fate0Plugin(BaseHTTPPlugin):
|
class Fate0Plugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
name = "fate0"
|
name = "fate0"
|
||||||
display_name = "Fate0聚合源"
|
display_name = "Fate0聚合站"
|
||||||
description = "从 GitHub 持续更新的高质量代理聚合列表"
|
description = "来自 GitHub 持续更新的高质量代理聚合列表"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -34,5 +35,5 @@ class Fate0Plugin(BaseHTTPPlugin):
|
|||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
if results:
|
if results:
|
||||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from app.core.log import logger
|
|||||||
|
|
||||||
|
|
||||||
class Ip89Plugin(BaseHTTPPlugin):
|
class Ip89Plugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
name = "ip89"
|
name = "ip89"
|
||||||
display_name = "89免费代理"
|
display_name = "89免费代理"
|
||||||
description = "从 89ip.cn 爬取免费代理"
|
description = "从 89ip.cn 爬取免费代理"
|
||||||
@@ -35,5 +36,5 @@ class Ip89Plugin(BaseHTTPPlugin):
|
|||||||
results.append(ProxyRaw(ip, int(port), "http"))
|
results.append(ProxyRaw(ip, int(port), "http"))
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
|
|||||||
|
|
||||||
|
|
||||||
class KuaiDaiLiPlugin(BaseHTTPPlugin):
|
class KuaiDaiLiPlugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
name = "kuaidaili"
|
name = "kuaidaili"
|
||||||
display_name = "快代理"
|
display_name = "快代理"
|
||||||
description = "从快代理网站爬取免费代理"
|
description = "从快代理网站爬取免费代理"
|
||||||
@@ -45,5 +46,5 @@ class KuaiDaiLiPlugin(BaseHTTPPlugin):
|
|||||||
results.append(ProxyRaw(ip, int(port), protocol))
|
results.append(ProxyRaw(ip, int(port), protocol))
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from app.core.log import logger
|
|||||||
|
|
||||||
|
|
||||||
class ProxyListDownloadPlugin(BaseHTTPPlugin):
|
class ProxyListDownloadPlugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
name = "proxylist_download"
|
name = "proxylist_download"
|
||||||
display_name = "ProxyListDownload"
|
display_name = "ProxyListDownload"
|
||||||
description = "从 ProxyListDownload API 获取代理"
|
description = "从 ProxyListDownload API 获取代理"
|
||||||
|
|||||||
@@ -6,19 +6,20 @@ from app.core.log import logger
|
|||||||
|
|
||||||
|
|
||||||
class ProxyScrapePlugin(BaseHTTPPlugin):
|
class ProxyScrapePlugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
"""
|
"""
|
||||||
从 ProxyScrape 公开 API 获取代理。
|
从 ProxyScrape 公开 API 获取代理库
|
||||||
覆盖 http/https/socks4/socks5 全协议,专门用于测试插件系统的可扩展性。
|
覆盖 http/https/socks4/socks5 全协议,专门用于测试插件系统的可扩展性
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "proxyscrape"
|
name = "proxyscrape"
|
||||||
display_name = "ProxyScrape测试源"
|
display_name = "ProxyScrape测试站"
|
||||||
description = "从 ProxyScrape API 获取各类型代理(HTTP/HTTPS/SOCKS4/SOCKS5),用于测试架构扩展"
|
description = "从 ProxyScrape API 获取各类型代理(HTTP/HTTPS/SOCKS4/SOCKS5),用于测试架构扩展"
|
||||||
enabled = True
|
enabled = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 使用多个公开 GitHub 代理列表作为源,稳定性较高
|
# 使用多个公开 GitHub 代理列表作为源,稳定性较差
|
||||||
self.urls = [
|
self.urls = [
|
||||||
("http", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/http.txt"),
|
("http", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/http.txt"),
|
||||||
("https", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/https.txt"),
|
("https", "https://raw.githubusercontent.com/monosans/proxy-list/main/proxies/https.txt"),
|
||||||
@@ -71,5 +72,5 @@ class ProxyScrapePlugin(BaseHTTPPlugin):
|
|||||||
ip = f"{random.randint(1, 223)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(1, 254)}"
|
ip = f"{random.randint(1, 223)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(1, 254)}"
|
||||||
port = random.randint(1024, 65535)
|
port = random.randint(1024, 65535)
|
||||||
test_proxies.append(ProxyRaw(ip, port, protocol))
|
test_proxies.append(ProxyRaw(ip, port, protocol))
|
||||||
logger.info(f"生成 {len(test_proxies)} 个测试代理: HTTP/HTTPS/SOCKS4/SOCKS5 各 3 个")
|
logger.info(f"生成 {len(test_proxies)} 个测试代理 HTTP/HTTPS/SOCKS4/SOCKS5 各 3 个")
|
||||||
return test_proxies
|
return test_proxies
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ from app.core.log import logger
|
|||||||
|
|
||||||
|
|
||||||
class SpeedXPlugin(BaseHTTPPlugin):
|
class SpeedXPlugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
name = "speedx"
|
name = "speedx"
|
||||||
display_name = "SpeedX代理源"
|
display_name = "SpeedX代理库"
|
||||||
description = "从 SpeedX GitHub 仓库获取 SOCKS 代理列表"
|
description = "从 SpeedX GitHub 仓库获取 SOCKS 代理列表"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -47,5 +48,5 @@ class SpeedXPlugin(BaseHTTPPlugin):
|
|||||||
results.append(ProxyRaw(ip, int(port), protocol))
|
results.append(ProxyRaw(ip, int(port), protocol))
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ VALID_PROTOCOLS = ("http", "https", "socks4", "socks5")
|
|||||||
|
|
||||||
|
|
||||||
class YunDaiLiPlugin(BaseHTTPPlugin):
|
class YunDaiLiPlugin(BaseHTTPPlugin):
|
||||||
|
default_config = {"max_pages": 5}
|
||||||
name = "yundaili"
|
name = "yundaili"
|
||||||
display_name = "云代理"
|
display_name = "云代理"
|
||||||
description = "从云代理网站爬取免费代理"
|
description = "从云代理网站爬取免费代理"
|
||||||
@@ -47,5 +48,5 @@ class YunDaiLiPlugin(BaseHTTPPlugin):
|
|||||||
results.append(ProxyRaw(ip, int(port), protocol))
|
results.append(ProxyRaw(ip, int(port), protocol))
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
logger.info(f"{self.display_name} 解析完成,获得 {len(results)} 个潜在代理")
|
logger.info(f"{self.display_name} 解析完成,获取 {len(results)} 个潜在代理")
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -190,7 +190,12 @@ class ProxyRepository:
|
|||||||
params.append(max_score)
|
params.append(max_score)
|
||||||
|
|
||||||
where_clause = " AND ".join(conditions)
|
where_clause = " AND ".join(conditions)
|
||||||
order_clause = f"{sort_by} {sort_order}"
|
allowed_sort_by = {"ip", "port", "protocol", "score", "last_check"}
|
||||||
|
allowed_sort_order = {"ASC", "DESC"}
|
||||||
|
if sort_by not in allowed_sort_by or sort_order.upper() not in allowed_sort_order:
|
||||||
|
order_clause = "last_check DESC"
|
||||||
|
else:
|
||||||
|
order_clause = f"{sort_by} {sort_order.upper()}"
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
count_query = f"SELECT COUNT(*) FROM proxies WHERE {where_clause}"
|
count_query = f"SELECT COUNT(*) FROM proxies WHERE {where_clause}"
|
||||||
@@ -268,7 +273,8 @@ class ProxyRepository:
|
|||||||
async def clean_expired(db: aiosqlite.Connection, days: int) -> int:
|
async def clean_expired(db: aiosqlite.Connection, days: int) -> int:
|
||||||
try:
|
try:
|
||||||
await db.execute(
|
await db.execute(
|
||||||
"DELETE FROM proxies WHERE last_check < datetime('now', '-{} days')".format(days)
|
"DELETE FROM proxies WHERE last_check < datetime('now', '-' || ? || ' days')",
|
||||||
|
(days,),
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return db.total_changes
|
return db.total_changes
|
||||||
|
|||||||
@@ -33,22 +33,16 @@ class ValidationTaskRepository:
|
|||||||
try:
|
try:
|
||||||
async with db.execute(
|
async with db.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, ip, port, protocol FROM validation_tasks
|
UPDATE validation_tasks
|
||||||
WHERE status = 'pending'
|
SET status = 'processing', updated_at = CURRENT_TIMESTAMP
|
||||||
ORDER BY id ASC
|
WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1)
|
||||||
LIMIT 1
|
RETURNING id, ip, port, protocol
|
||||||
"""
|
"""
|
||||||
) as cursor:
|
) as cursor:
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
task_id = row[0]
|
return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]}
|
||||||
await db.execute(
|
|
||||||
"UPDATE validation_tasks SET status = 'processing', updated_at = CURRENT_TIMESTAMP WHERE id = ?",
|
|
||||||
(task_id,),
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
return {"id": task_id, "ip": row[1], "port": row[2], "protocol": row[3]}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"acquire_pending failed: {e}")
|
logger.error(f"acquire_pending failed: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -126,7 +120,8 @@ class ValidationTaskRepository:
|
|||||||
async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
|
async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
|
||||||
try:
|
try:
|
||||||
await db.execute(
|
await db.execute(
|
||||||
"DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-{} days')".format(days)
|
"DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')",
|
||||||
|
(days,),
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return db.total_changes
|
return db.total_changes
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""插件业务服务"""
|
"""插件业务服务"""
|
||||||
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from app.core.db import get_db
|
from app.core.db import get_db
|
||||||
@@ -108,14 +109,13 @@ class PluginService:
|
|||||||
async def run_all_plugins(self) -> List[ProxyRaw]:
|
async def run_all_plugins(self) -> List[ProxyRaw]:
|
||||||
"""执行所有启用插件的爬取"""
|
"""执行所有启用插件的爬取"""
|
||||||
all_results: List[ProxyRaw] = []
|
all_results: List[ProxyRaw] = []
|
||||||
for plugin in registry.list_plugins():
|
tasks = [self.run_plugin(plugin.name) for plugin in registry.list_plugins() if plugin.enabled]
|
||||||
if not plugin.enabled:
|
results_list = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
for results in results_list:
|
||||||
|
if isinstance(results, Exception):
|
||||||
|
logger.error(f"Run all plugins error: {results}")
|
||||||
continue
|
continue
|
||||||
try:
|
|
||||||
results = await self.run_plugin(plugin.name)
|
|
||||||
all_results.extend(results)
|
all_results.extend(results)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Run all plugins error at {plugin.name}: {e}")
|
|
||||||
# 去重
|
# 去重
|
||||||
seen = set()
|
seen = set()
|
||||||
unique = []
|
unique = []
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ class SchedulerService:
|
|||||||
self.proxy_repo = proxy_repo
|
self.proxy_repo = proxy_repo
|
||||||
self.interval_minutes = 30
|
self.interval_minutes = 30
|
||||||
self.running = False
|
self.running = False
|
||||||
|
self._stop_event = asyncio.Event()
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
if self.running:
|
if self.running:
|
||||||
logger.warning("Scheduler already running")
|
logger.warning("Scheduler already running")
|
||||||
return
|
return
|
||||||
|
self._stop_event.clear()
|
||||||
self.running = True
|
self.running = True
|
||||||
await self.validation_queue.start()
|
await self.validation_queue.start()
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
self._task = asyncio.create_task(self._run_loop())
|
||||||
@@ -33,6 +35,7 @@ class SchedulerService:
|
|||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
self._stop_event.set()
|
||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
try:
|
try:
|
||||||
@@ -55,10 +58,10 @@ class SchedulerService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Scheduler loop error: {e}")
|
logger.error(f"Scheduler loop error: {e}")
|
||||||
# 等待下一次
|
# 等待下一次
|
||||||
for _ in range(self.interval_minutes * 60):
|
try:
|
||||||
if not self.running:
|
await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_minutes * 60)
|
||||||
break
|
except asyncio.TimeoutError:
|
||||||
await asyncio.sleep(1)
|
pass
|
||||||
|
|
||||||
async def _do_validate_all(self):
|
async def _do_validate_all(self):
|
||||||
"""验证数据库中所有存量代理"""
|
"""验证数据库中所有存量代理"""
|
||||||
|
|||||||
@@ -25,8 +25,24 @@ class ValidatorService:
|
|||||||
):
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.connect_timeout = connect_timeout
|
self.connect_timeout = connect_timeout
|
||||||
|
self.max_concurrency = max_concurrency
|
||||||
self.semaphore = asyncio.Semaphore(max_concurrency)
|
self.semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
# 共享 HTTP/HTTPS ClientSession
|
||||||
|
self._http_connector = aiohttp.TCPConnector(
|
||||||
|
ssl=False,
|
||||||
|
limit=max_concurrency,
|
||||||
|
limit_per_host=max_concurrency,
|
||||||
|
force_close=False,
|
||||||
|
)
|
||||||
|
self._timeout = aiohttp.ClientTimeout(
|
||||||
|
total=timeout, connect=connect_timeout
|
||||||
|
)
|
||||||
|
self._http_session = aiohttp.ClientSession(
|
||||||
|
connector=self._http_connector,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_test_url(self, protocol: str) -> str:
|
def _get_test_url(self, protocol: str) -> str:
|
||||||
"""获取测试 URL"""
|
"""获取测试 URL"""
|
||||||
urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
|
urls = self.TEST_URLS.get(protocol.lower(), self.TEST_URLS["http"])
|
||||||
@@ -53,20 +69,14 @@ class ValidatorService:
|
|||||||
async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
|
async def _validate_http(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
|
||||||
"""验证 HTTP/HTTPS 代理"""
|
"""验证 HTTP/HTTPS 代理"""
|
||||||
proxy_url = f"http://{ip}:{port}"
|
proxy_url = f"http://{ip}:{port}"
|
||||||
connector = aiohttp.TCPConnector(ssl=False, limit=0, force_close=True)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=self.timeout, connect=self.connect_timeout)
|
|
||||||
test_url = self._get_test_url(protocol)
|
test_url = self._get_test_url(protocol)
|
||||||
|
|
||||||
try:
|
async with self._http_session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
|
||||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
|
||||||
async with session.get(test_url, proxy=proxy_url, allow_redirects=True) as response:
|
|
||||||
if response.status in (200, 301, 302):
|
if response.status in (200, 301, 302):
|
||||||
latency = round((time.time() - start) * 1000, 2)
|
latency = round((time.time() - start) * 1000, 2)
|
||||||
logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
|
logger.info(f"HTTP valid: {ip}:{port} ({protocol}) {latency}ms")
|
||||||
return True, latency
|
return True, latency
|
||||||
return False, 0.0
|
return False, 0.0
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
|
async def _validate_socks(self, ip: str, port: int, protocol: str, start: float) -> Tuple[bool, float]:
|
||||||
"""验证 SOCKS4/SOCKS5 代理"""
|
"""验证 SOCKS4/SOCKS5 代理"""
|
||||||
@@ -95,3 +105,7 @@ class ValidatorService:
|
|||||||
return False, 0.0
|
return False, 0.0
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭共享的 HTTP ClientSession"""
|
||||||
|
await self._http_session.close()
|
||||||
|
|||||||
Binary file not shown.
@@ -1,7 +1,7 @@
|
|||||||
"""pytest 配置文件和 fixtures"""
|
"""pytest 配置文件和 fixtures"""
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
import pytest_asyncio
|
||||||
from typing import AsyncGenerator, Generator
|
from typing import AsyncGenerator
|
||||||
from httpx import AsyncClient, ASGITransport
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
from app.api import create_app
|
from app.api import create_app
|
||||||
@@ -9,24 +9,17 @@ from app.core.db import init_db, get_db
|
|||||||
from app.repositories.proxy_repo import ProxyRepository
|
from app.repositories.proxy_repo import ProxyRepository
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="function")
|
||||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
|
||||||
"""创建事件循环"""
|
|
||||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
|
||||||
yield loop
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
async def app():
|
async def app():
|
||||||
"""创建应用实例"""
|
"""创建应用实例"""
|
||||||
# 初始化测试数据库
|
# 初始化测试数据库
|
||||||
await init_db()
|
await init_db()
|
||||||
app = create_app()
|
app = create_app()
|
||||||
return app
|
async with app.router.lifespan_context(app):
|
||||||
|
yield app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def client(app) -> AsyncGenerator[AsyncClient, None]:
|
async def client(app) -> AsyncGenerator[AsyncClient, None]:
|
||||||
"""创建异步 HTTP 客户端"""
|
"""创建异步 HTTP 客户端"""
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
@@ -34,20 +27,20 @@ async def client(app) -> AsyncGenerator[AsyncClient, None]:
|
|||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def db():
|
async def db():
|
||||||
"""获取数据库连接"""
|
"""获取数据库连接"""
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
yield db
|
yield db
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def proxy_repo():
|
async def proxy_repo():
|
||||||
"""获取代理仓库"""
|
"""获取代理仓库"""
|
||||||
return ProxyRepository()
|
return ProxyRepository()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def sample_proxy(db, proxy_repo):
|
async def sample_proxy(db, proxy_repo):
|
||||||
"""创建一个测试代理"""
|
"""创建一个测试代理"""
|
||||||
await proxy_repo.insert_or_update(db, "192.168.1.1", 8080, "http", 50)
|
await proxy_repo.insert_or_update(db, "192.168.1.1", 8080, "http", 50)
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -121,14 +121,14 @@ class TestProxiesAPI:
|
|||||||
"""测试 GET /api/proxies/export/csv"""
|
"""测试 GET /api/proxies/export/csv"""
|
||||||
response = await client.get("/api/proxies/export/csv")
|
response = await client.get("/api/proxies/export/csv")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "text/csv"
|
assert response.headers["content-type"].startswith("text/csv")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_export_proxies_txt(self, client):
|
async def test_export_proxies_txt(self, client):
|
||||||
"""测试 GET /api/proxies/export/txt"""
|
"""测试 GET /api/proxies/export/txt"""
|
||||||
response = await client.get("/api/proxies/export/txt")
|
response = await client.get("/api/proxies/export/txt")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "text/plain"
|
assert response.headers["content-type"].startswith("text/plain")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_export_proxies_json(self, client):
|
async def test_export_proxies_json(self, client):
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user