Files
ProxyPool/app/repositories/task_repo.py
祀梦 635c524a7e refactor(backend): optimize database safety, validator performance, and scheduler concurrency
- Fix SQL injection risks in proxy_repo and task_repo
- Atomic acquire_pending with UPDATE ... RETURNING
- Reuse aiohttp ClientSession in ValidatorService
- Replace polling with asyncio.Event in SchedulerService
- Optimize ValidationQueue.drain with asyncio.Condition
- Concurrent plugin crawling with asyncio.gather
- Unify ProxyRaw model import path
- Fix test baseline and remove tracked __pycache__ files
2026-04-04 14:43:31 +08:00

131 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""验证任务队列持久化层"""
import aiosqlite
from typing import List, Optional
from app.models.domain import ProxyRaw
from app.core.log import logger
class ValidationTaskRepository:
    """Repository for the ``validation_tasks`` table, used to persist the validation queue.

    Every method takes an open ``aiosqlite.Connection``. Write methods commit on
    success. On failure they log the error, roll back the uncommitted work and
    return a falsy/zero value instead of raising. ``get_pending_count`` is the
    exception: it lets errors propagate.
    """

    @staticmethod
    async def _rollback_quietly(db: aiosqlite.Connection) -> None:
        """Roll back the current transaction after a failed write; never raises.

        Without this, statements that succeeded before the failure (for example
        the first rows of a failed ``executemany``) would stay in the open
        transaction. An unrelated later ``commit()`` on the same connection
        would then persist them.
        NOTE(review): if the connection is shared by concurrent coroutines, this
        also discards their uncommitted statements.
        """
        try:
            await db.rollback()
        except Exception as e:
            logger.error(f"rollback after failed validation task write failed: {e}")

    @staticmethod
    async def insert_batch(db: aiosqlite.Connection, proxies: List[ProxyRaw]) -> int:
        """Enqueue one ``pending`` task per proxy.

        Args:
            db: open database connection.
            proxies: proxies to enqueue. Only ``ip``, ``port`` and ``protocol``
                are stored.

        Returns:
            The number of tasks inserted. Returns 0 if ``proxies`` is empty or
            the insert/commit failed; any uncommitted part of the batch is
            rolled back in that case.
        """
        if not proxies:
            return 0
        try:
            rows = [(p.ip, p.port, p.protocol) for p in proxies]
            await db.executemany(
                """
                INSERT INTO validation_tasks (ip, port, protocol, status, created_at)
                VALUES (?, ?, ?, 'pending', CURRENT_TIMESTAMP)
                """,
                rows,
            )
            await db.commit()
        except Exception as e:
            logger.error(f"insert_batch validation tasks failed: {e}")
            await ValidationTaskRepository._rollback_quietly(db)
            return 0
        return len(rows)

    @staticmethod
    async def acquire_pending(db: aiosqlite.Connection) -> Optional[dict]:
        """Atomically claim the oldest ``pending`` task and mark it ``processing``.

        The claim is a single ``UPDATE ... RETURNING`` statement, so selecting
        and marking the row happen together. The change is committed before
        returning.

        Returns:
            ``{"id": ..., "ip": ..., "port": ..., "protocol": ...}`` for the
            claimed task, or ``None`` if no task is pending or the query failed.
        """
        try:
            async with db.execute(
                """
                UPDATE validation_tasks
                SET status = 'processing', updated_at = CURRENT_TIMESTAMP
                WHERE id = (SELECT id FROM validation_tasks WHERE status = 'pending' ORDER BY id ASC LIMIT 1)
                RETURNING id, ip, port, protocol
                """
            ) as cursor:
                row = await cursor.fetchone()
            # Commit after the cursor is closed so no statement is still in
            # progress. Without this commit, the 'processing' mark stays in an
            # open transaction. Other connections cannot see it, and it is only
            # persisted by whatever commit happens next.
            await db.commit()
        except Exception as e:
            logger.error(f"acquire_pending failed: {e}")
            await ValidationTaskRepository._rollback_quietly(db)
            return None
        if not row:
            return None
        return {"id": row[0], "ip": row[1], "port": row[2], "protocol": row[3]}

    @staticmethod
    async def complete_task(
        db: aiosqlite.Connection,
        task_id: int,
        is_valid: bool,
        response_time_ms: Optional[float] = None,
    ) -> bool:
        """Mark a task ``completed`` and record its validation result.

        Args:
            db: open database connection.
            task_id: id of the task to update.
            is_valid: stored as ``result = 'valid'`` if true, otherwise ``'invalid'``.
            response_time_ms: optional measured response time to store.

        Returns:
            ``True`` if the update was committed, ``False`` if it failed. A
            ``task_id`` that matches no row still returns ``True``.
        """
        try:
            await db.execute(
                """
                UPDATE validation_tasks
                SET status = 'completed',
                    result = ?,
                    response_time_ms = ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = ?
                """,
                ("valid" if is_valid else "invalid", response_time_ms, task_id),
            )
            await db.commit()
        except Exception as e:
            logger.error(f"complete_task failed: {e}")
            await ValidationTaskRepository._rollback_quietly(db)
            return False
        return True

    @staticmethod
    async def fail_task(db: aiosqlite.Connection, task_id: int) -> bool:
        """Mark a task ``failed`` with ``result = 'invalid'``.

        Returns:
            ``True`` if the update was committed, ``False`` if it failed.
        """
        try:
            await db.execute(
                """
                UPDATE validation_tasks
                SET status = 'failed',
                    result = 'invalid',
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = ?
                """,
                (task_id,),
            )
            await db.commit()
        except Exception as e:
            logger.error(f"fail_task failed: {e}")
            await ValidationTaskRepository._rollback_quietly(db)
            return False
        return True

    @staticmethod
    async def get_pending_count(db: aiosqlite.Connection) -> int:
        """Return the number of tasks whose status is ``pending``.

        Unlike the other methods, database errors are not caught here and
        propagate to the caller.
        """
        async with db.execute(
            "SELECT COUNT(*) FROM validation_tasks WHERE status = 'pending'"
        ) as cursor:
            row = await cursor.fetchone()
            return row[0] if row else 0

    @staticmethod
    async def reset_processing(db: aiosqlite.Connection) -> int:
        """Reset interrupted ``processing`` tasks back to ``pending`` (startup recovery).

        Returns:
            The number of tasks reset by this call, or 0 on failure.
        """
        try:
            async with db.execute(
                """
                UPDATE validation_tasks
                SET status = 'pending', updated_at = CURRENT_TIMESTAMP
                WHERE status = 'processing'
                """
            ) as cursor:
                # rowcount is scoped to this statement. db.total_changes is
                # cumulative for the whole connection, so it would over-report.
                reset_count = cursor.rowcount
            await db.commit()
        except Exception as e:
            logger.error(f"reset_processing failed: {e}")
            await ValidationTaskRepository._rollback_quietly(db)
            return 0
        return reset_count

    @staticmethod
    async def cleanup_old(db: aiosqlite.Connection, days: int = 7) -> int:
        """Delete tasks whose ``updated_at`` is more than ``days`` days in the past.

        NOTE(review): this deletes rows of any status, including ``pending`` and
        ``processing``. Rows with a NULL ``updated_at`` are never matched.
        ``insert_batch`` does not set ``updated_at``, so never-touched pending
        rows are kept unless the schema supplies a default. Confirm this is the
        intended retention policy.

        Args:
            db: open database connection.
            days: age threshold in days. It is bound as a SQL parameter, not
                interpolated.

        Returns:
            The number of rows deleted by this call, or 0 on failure.
        """
        try:
            async with db.execute(
                "DELETE FROM validation_tasks WHERE updated_at < datetime('now', '-' || ? || ' days')",
                (days,),
            ) as cursor:
                # Per-statement count, not the connection-wide db.total_changes.
                deleted_count = cursor.rowcount
            await db.commit()
        except Exception as e:
            logger.error(f"cleanup_old tasks failed: {e}")
            await ValidationTaskRepository._rollback_quietly(db)
            return 0
        return deleted_count