refactor: overhaul the core architecture to eliminate the root cause of repeated rework
- Remove the dual-track persistent ValidationQueue and replace it with a pure in-memory AsyncWorkerPool
- Introduce a unified background-job framework, JobExecutor (Job / CrawlJob / ValidateAllJob)
- Add PluginRunner to unify plugin execution (timeouts, retries, health checks, statistics)
- Narrow SchedulerService's responsibility to only triggering ValidateAllJob on a timer
- Rebuild lifespan around AsyncExitStack for safe management of long-lived resources
- Slim the routing layer by 50%+; business exceptions propagate up to a global middleware handler
- Make all settings hot-reloadable (WorkerPool concurrency and validator timeout take effect immediately)
- Force the frontend store to re-fetch after every write, eliminating optimistic-update drift
- Delete queue.py / task_repo.py / task_service.py
- Add unit tests for execution; all 85 tests pass
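The heart of the new design is the handler-driven in-memory pool: N workers drain one asyncio.Queue through a caller-supplied coroutine. A minimal, self-contained sketch of that pattern (the real AsyncWorkerPool in app.core.execution may differ; the constructor arguments mirror the lifespan code below):

    import asyncio
    from typing import Awaitable, Callable

    class AsyncWorkerPool:
        """Pure in-memory pool: workers drain a queue through a handler coroutine."""

        def __init__(self, worker_count: int, handler: Callable[[object], Awaitable[None]], name: str = "pool"):
            self.worker_count = worker_count
            self.handler = handler
            self.name = name
            self._queue: asyncio.Queue = asyncio.Queue()
            self._workers: list[asyncio.Task] = []

        async def __aenter__(self):
            # spawn workers when the pool is entered as an async context manager
            self._workers = [asyncio.create_task(self._run()) for _ in range(self.worker_count)]
            return self

        async def __aexit__(self, *exc):
            # cancel workers and wait for them; swallow the CancelledErrors
            for w in self._workers:
                w.cancel()
            await asyncio.gather(*self._workers, return_exceptions=True)

        async def _run(self):
            while True:
                item = await self._queue.get()
                try:
                    await self.handler(item)
                except Exception:
                    pass  # a real pool would record stats / log here
                finally:
                    self._queue.task_done()

        async def submit(self, item) -> None:
            await self._queue.put(item)

        async def drain(self) -> None:
            await self._queue.join()

    async def main():
        async def handler(n):  # stands in for the proxy-validation handler
            await asyncio.sleep(0.01)
            print("validated", n)

        async with AsyncWorkerPool(worker_count=4, handler=handler) as pool:
            for n in range(8):
                await pool.submit(n)
            await pool.drain()

    asyncio.run(main())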
@@ -3,10 +3,8 @@ from fastapi import Request
 from app.services.proxy_service import ProxyService
 from app.services.plugin_service import PluginService
 from app.services.scheduler_service import SchedulerService
-from app.services.validator_service import ValidatorService
-from app.repositories.proxy_repo import ProxyRepository
-from app.core.tasks.queue import ValidationQueue
-from app.core.config import settings as app_settings
+from app.services.plugin_runner import PluginRunner
+from app.core.execution import JobExecutor, AsyncWorkerPool


 def get_proxy_service() -> ProxyService:
@@ -18,36 +16,16 @@ def get_plugin_service() -> PluginService:


 def get_scheduler_service(request: Request) -> SchedulerService:
-    return request.app.state.scheduler_service
+    return request.app.state.scheduler


-def get_validation_queue(request: Request) -> ValidationQueue:
-    return request.app.state.validation_queue
+def get_executor(request: Request) -> JobExecutor:
+    return request.app.state.executor


-def create_scheduler_service(db_settings: dict | None = None) -> SchedulerService:
-    """Create the SchedulerService at application startup (outside any request context)
-
-    db_settings comes from settings persisted in the database and overrides .env / default values.
-    """
-    db_settings = db_settings or {}
-    validator = ValidatorService(
-        timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
-        connect_timeout=db_settings.get("connect_timeout", app_settings.validator_connect_timeout),
-        max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
-    )
-    proxy_repo = ProxyRepository()
-    queue = ValidationQueue(
-        validator=validator,
-        proxy_repo=proxy_repo,
-        worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
-        score_valid=app_settings.score_valid,
-        score_invalid=app_settings.score_invalid,
-        score_min=app_settings.score_min,
-        score_max=app_settings.score_max,
-    )
-    svc = SchedulerService(validation_queue=queue, proxy_repo=proxy_repo)
-    svc.interval_minutes = db_settings.get(
-        "validate_interval_minutes", 30
-    )
-    return svc
+def get_plugin_runner(request: Request) -> PluginRunner:
+    return request.app.state.plugin_runner


+def get_worker_pool(request: Request) -> AsyncWorkerPool:
+    return request.app.state.worker_pool

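These accessors exist so routes receive the lifespan-created singletons through FastAPI's dependency injection instead of importing module globals. A hedged sketch of the consuming side (route path and response shape are illustrative, not from this repo):

    from fastapi import APIRouter, Depends, Request

    router = APIRouter()

    def get_executor(request: Request):
        return request.app.state.executor  # set once in lifespan()

    @router.get("/api/jobs/{job_id}")
    async def job_status(job_id: str, executor=Depends(get_executor)):
        job = executor.get_job(job_id)
        return {"found": job is not None}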
@@ -1,56 +1,118 @@
 """Application lifecycle management"""
 import asyncio
-from contextlib import asynccontextmanager
+from contextlib import AsyncExitStack, asynccontextmanager
 from fastapi import FastAPI

 from app.core.db import init_db, get_db
+from app.core.config import settings as app_settings
 from app.core.log import logger
-from app.api.deps import create_scheduler_service
+from app.core.execution import AsyncWorkerPool, JobExecutor
+from app.core.plugin_system.registry import registry
+from app.repositories.proxy_repo import ProxyRepository
 from app.repositories.settings_repo import SettingsRepository, DEFAULT_SETTINGS
+from app.services.validator_service import ValidatorService
+from app.services.plugin_runner import PluginRunner
+from app.services.scheduler_service import SchedulerService

 settings_repo = SettingsRepository()
+proxy_repo = ProxyRepository()


 async def _load_settings() -> dict:
     db_settings = DEFAULT_SETTINGS.copy()
     try:
         async with get_db() as db:
             db_settings = await settings_repo.get_all(db)
     except Exception as e:
         logger.error(f"Failed to load settings on startup: {e}")
     return db_settings


 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Lifecycle management for application startup and shutdown"""
-    # Initialize the database
     await init_db()
+    db_settings = await _load_settings()

-    # Load settings and decide whether to start the scheduler
-    db_settings = await _load_settings()
-
-    # Create the scheduler and mount it on app.state (DB settings override defaults)
-    scheduler_service = create_scheduler_service(db_settings)
-    app.state.scheduler_service = scheduler_service
-    app.state.validation_queue = scheduler_service.validation_queue
-
-    if db_settings.get("auto_validate", True):
-        try:
-            await scheduler_service.start()
-        except Exception as e:
-            logger.error(f"Failed to start scheduler on startup: {e}")
-
-    logger.info("API server started")
-    yield
-
-    # Shut down the scheduler
-    scheduler_service.cancel_validate_task()
-    await scheduler_service.stop()
-
-    # Close the validator HTTP session
-    try:
-        await scheduler_service.validation_queue.validator.close()
-    except Exception:
-        pass
+    async with AsyncExitStack() as stack:
+        # Validator
+        validator = ValidatorService(
+            timeout=db_settings.get("validation_timeout", app_settings.validator_timeout),
+            connect_timeout=app_settings.validator_connect_timeout,
+            max_concurrency=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
+        )
+
+        # Validation worker pool
+        async def validation_handler(proxy):
+            from app.models.domain import ProxyRaw
+            is_valid, latency = await validator.validate(
+                proxy.ip, proxy.port, proxy.protocol
+            )
+            async with get_db() as db:
+                if is_valid:
+                    await proxy_repo.insert_or_update(
+                        db, proxy.ip, proxy.port, proxy.protocol, score=app_settings.score_valid
+                    )
+                    if latency:
+                        await proxy_repo.update_response_time(db, proxy.ip, proxy.port, latency)
+                else:
+                    await proxy_repo.update_score(
+                        db, proxy.ip, proxy.port, app_settings.score_invalid,
+                        app_settings.score_min, app_settings.score_max
+                    )
+
+        worker_pool = AsyncWorkerPool(
+            worker_count=db_settings.get("default_concurrency", app_settings.validator_max_concurrency),
+            handler=validation_handler,
+            name="ValidationPool",
+        )
+        await stack.enter_async_context(worker_pool)
+
+        # Job executor
+        executor = JobExecutor(worker_pool=worker_pool, max_concurrent_jobs=10)
+        await stack.enter_async_context(executor)
+
+        # Plugin runner
+        plugin_runner = PluginRunner(timeout=db_settings.get("crawl_timeout", 30))
+
+        # Scheduler
+        scheduler = SchedulerService(
+            executor=executor,
+            interval_minutes=db_settings.get("validate_interval_minutes", 30),
+        )
+
+        # Mount on app.state
+        app.state.validator = validator
+        app.state.worker_pool = worker_pool
+        app.state.executor = executor
+        app.state.plugin_runner = plugin_runner
+        app.state.scheduler = scheduler
+
+        # Start the scheduler
+        if db_settings.get("auto_validate", True):
+            try:
+                await scheduler.start()
+            except Exception as e:
+                logger.error(f"Failed to start scheduler on startup: {e}")
+
+        logger.info("API server started")
+        yield
+
+        # Stop the scheduler
+        await scheduler.stop()
+
+        # Cancel all running jobs
+        await executor.cancel_all()
+
+        # AsyncExitStack closes the executor and worker pool automatically
+
+        # Close the validator session
+        try:
+            await validator.close()
+        except Exception:
+            pass

-    # Close every plugin's HTTP client
-    from app.core.plugin_system.registry import registry
-    for plugin in registry.list_plugins():
-        if hasattr(plugin, "close"):
-            try:
+        # Close every plugin's HTTP client
+        for plugin in registry.list_plugins():
+            if hasattr(plugin, "close"):
+                try:
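AsyncExitStack unwinds contexts in LIFO order, which is why the executor is entered after the worker pool it depends on: on shutdown the executor closes first, then the pool. A toy demonstration of that ordering:

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager

    @asynccontextmanager
    async def resource(name: str):
        print("open", name)
        try:
            yield name
        finally:
            print("close", name)

    async def main():
        async with AsyncExitStack() as stack:
            await stack.enter_async_context(resource("worker_pool"))
            await stack.enter_async_context(resource("executor"))
        # prints: open worker_pool, open executor, close executor, close worker_pool

    asyncio.run(main())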
@@ -43,7 +43,7 @@ def create_app() -> FastAPI:
     @app.get("/health")
     async def health_check():
         from datetime import datetime
-        scheduler = app.state.scheduler_service
+        scheduler = app.state.scheduler
         return {
             "status": "healthy",
             "timestamp": datetime.now().isoformat(),
@@ -1,16 +1,25 @@
 """Plugin routes"""
-import asyncio
 from fastapi import APIRouter, Depends
+from pydantic import BaseModel

 from app.services.plugin_service import PluginService
-from app.services.scheduler_service import SchedulerService
-from app.services.task_service import task_service
-from app.api.deps import get_plugin_service, get_scheduler_service
-from app.api.common import success_response, error_response, format_plugin
-from app.core.log import logger
+from app.services.plugin_runner import PluginRunner
+from app.core.execution import JobExecutor, CrawlJob
+from app.core.exceptions import PluginNotFoundException
+from app.api.deps import get_plugin_service, get_plugin_runner, get_executor
+from app.api.common import success_response

 router = APIRouter(prefix="/api/plugins", tags=["plugins"])


+class ToggleRequest(BaseModel):
+    enabled: bool
+
+
+class ConfigRequest(BaseModel):
+    config: dict
+
+
 @router.get("")
 async def list_plugins(service: PluginService = Depends(get_plugin_service)):
     plugins = await service.list_plugins()
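Replacing the raw dict bodies with Pydantic models moves the manual enabled-is-None and isinstance checks into FastAPI itself, which rejects malformed payloads with a 422 before the handler runs. A self-contained sketch of that behavior (the standalone app here is illustrative, not the repo's wiring):

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from pydantic import BaseModel

    app = FastAPI()

    class ToggleRequest(BaseModel):
        enabled: bool

    @app.put("/api/plugins/{plugin_id}/toggle")
    async def toggle(plugin_id: str, request: ToggleRequest):
        return {"plugin_id": plugin_id, "enabled": request.enabled}

    client = TestClient(app)
    assert client.put("/api/plugins/demo/toggle", json={"enabled": True}).status_code == 200
    assert client.put("/api/plugins/demo/toggle", json={}).status_code == 422  # validation is now automatic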
@@ -20,19 +29,13 @@ async def list_plugins(service: PluginService = Depends(get_plugin_service)):
 @router.put("/{plugin_id}/toggle")
 async def toggle_plugin(
     plugin_id: str,
-    request: dict,
+    request: ToggleRequest,
     service: PluginService = Depends(get_plugin_service),
 ):
-    enabled = request.get("enabled")
-    if enabled is None:
-        return error_response("缺少 enabled 参数", 400)
-
-    success = await service.toggle_plugin(plugin_id, enabled)
-    if not success:
-        return error_response("插件不存在", 404)
+    await service.toggle_plugin(plugin_id, request.enabled)
     return success_response(
-        f"插件 {plugin_id} 已{'启用' if enabled else '禁用'}",
-        {"plugin_id": plugin_id, "enabled": enabled},
+        f"插件 {plugin_id} 已{'启用' if request.enabled else '禁用'}",
+        {"plugin_id": plugin_id, "enabled": request.enabled},
     )


@@ -42,128 +45,104 @@ async def get_plugin_config(
     service: PluginService = Depends(get_plugin_service),
 ):
     config = await service.get_plugin_config(plugin_id)
-    if config is None:
-        return error_response("插件不存在", 404)
     return success_response("获取插件配置成功", {"plugin_id": plugin_id, "config": config})


 @router.post("/{plugin_id}/config")
 async def update_plugin_config(
     plugin_id: str,
-    request: dict,
+    request: ConfigRequest,
     service: PluginService = Depends(get_plugin_service),
 ):
-    config = request.get("config", {})
-    if not isinstance(config, dict):
-        return error_response("config 必须是对象", 400)
-
-    success = await service.update_plugin_config(plugin_id, config)
+    success = await service.update_plugin_config(plugin_id, request.config)
     if not success:
-        return error_response("插件不存在或配置无效", 404)
-    return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": config})
+        raise PluginNotFoundException(plugin_id)
+    return success_response("保存插件配置成功", {"plugin_id": plugin_id, "config": request.config})


 @router.post("/{plugin_id}/crawl")
 async def crawl_plugin(
     plugin_id: str,
     plugin_service: PluginService = Depends(get_plugin_service),
-    scheduler_service: SchedulerService = Depends(get_scheduler_service),
+    plugin_runner: PluginRunner = Depends(get_plugin_runner),
+    executor: JobExecutor = Depends(get_executor),
 ):
-    plugin = plugin_service.get_plugin(plugin_id)
-    if not plugin:
-        return error_response("插件不存在", 404)
-
-    task_id = task_service.create_task("crawl_plugin")
-
-    async def _background_crawl():
-        try:
-            task_service.update_task(task_id, status="running")
-            results = await plugin_service.run_plugin(plugin_id)
-            if not results:
-                task_service.update_task(
-                    task_id,
-                    status="completed",
-                    message=f"插件 {plugin_id} 爬取完成,未获取到代理",
-                    data={"plugin_id": plugin_id, "proxy_count": 0, "valid_count": 0},
-                )
-                return
-
-            logger.info(f"Plugin {plugin_id} crawled {len(results)} proxies")
-            scheduler_service.validation_queue.reset_stats()
-            await scheduler_service.validation_queue.submit(results)
-
-            try:
-                await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
-            except asyncio.TimeoutError:
-                pass
-
-            task_service.update_task(
-                task_id,
-                status="completed",
-                message=f"插件 {plugin_id} 爬取并验证完成",
-                data={
-                    "plugin_id": plugin_id,
-                    "proxy_count": len(results),
-                    "valid_count": scheduler_service.validation_queue.valid_count,
-                    "invalid_count": scheduler_service.validation_queue.invalid_count,
-                },
-            )
-        except Exception as e:
-            logger.error(f"Crawl plugin {plugin_id} failed: {e}")
-            task_service.update_task(
-                task_id, status="failed", message=f"插件爬取失败: {str(e)}"
-            )
-
-    asyncio.create_task(_background_crawl())
+    plugin = plugin_service.get_plugin_or_raise(plugin_id)
+    job = CrawlJob(
+        plugin_id=plugin_id,
+        plugin_runner=plugin_runner,
+        proxy_service=plugin_service,
+        validator_pool=executor.worker_pool,
+    )
+    job_id = executor.submit_job(job)
     return success_response(
-        "爬取任务已启动", {"task_id": task_id, "plugin_id": plugin_id}
+        "爬取任务已启动", {"task_id": job_id, "plugin_id": plugin_id}
     )


 @router.post("/crawl-all")
 async def crawl_all(
     plugin_service: PluginService = Depends(get_plugin_service),
-    scheduler_service: SchedulerService = Depends(get_scheduler_service),
+    plugin_runner: PluginRunner = Depends(get_plugin_runner),
+    executor: JobExecutor = Depends(get_executor),
 ):
-    task_id = task_service.create_task("crawl_all")
-
-    async def _background_crawl_all():
-        try:
-            task_service.update_task(task_id, status="running")
-            results = await plugin_service.run_all_plugins()
-            if not results:
-                task_service.update_task(
-                    task_id,
-                    status="completed",
-                    message="所有插件爬取完成,未获取到代理",
-                    data={"total_crawled": 0, "valid_count": 0, "invalid_count": 0},
-                )
-                return
-
-            logger.info(f"All plugins crawled {len(results)} unique proxies")
-            scheduler_service.validation_queue.reset_stats()
-            await scheduler_service.validation_queue.submit(results)
-
-            try:
-                await asyncio.wait_for(scheduler_service.validation_queue.drain(), timeout=5.0)
-            except asyncio.TimeoutError:
-                pass
-
-            task_service.update_task(
-                task_id,
-                status="completed",
-                message="所有插件爬取并验证完成",
-                data={
-                    "total_crawled": len(results),
-                    "valid_count": scheduler_service.validation_queue.valid_count,
-                    "invalid_count": scheduler_service.validation_queue.invalid_count,
-                },
-            )
-        except Exception as e:
-            logger.error(f"Crawl all failed: {e}")
-            task_service.update_task(
-                task_id, status="failed", message=f"批量爬取失败: {str(e)}"
-            )
-
-    asyncio.create_task(_background_crawl_all())
-    return success_response("爬取任务已启动", {"task_id": task_id})
+    """Create a CrawlJob for each enabled plugin; return one aggregate task ID to keep the API compatible"""
+    from app.core.plugin_system.registry import registry
+    job_ids = []
+    for plugin in registry.list_plugins():
+        if not plugin.enabled:
+            continue
+        job = CrawlJob(
+            plugin_id=plugin.name,
+            plugin_runner=plugin_runner,
+            proxy_service=plugin_service,
+            validator_pool=executor.worker_pool,
+        )
+        job_ids.append(executor.submit_job(job))
+
+    # To keep the API compatible (frontend/tests expect a single task_id),
+    # create a virtual aggregate Job for status queries; it waits for all crawl jobs to finish
+    aggregator = _create_crawl_all_aggregator(job_ids, executor)
+    agg_id = executor.submit_job(aggregator)
+    return success_response("爬取任务已启动", {"task_id": agg_id})
+
+
+def _create_crawl_all_aggregator(job_ids, executor):
+    """Create a simple aggregate Job that summarizes the status of all child jobs"""
+    from app.core.execution.job import Job
+    import asyncio
+
+    class CrawlAllAggregator(Job):
+        async def run(self):
+            # Wait for all child jobs to finish (30 seconds at most)
+            for _ in range(300):
+                all_done = all(
+                    executor.get_job(jid) and executor.get_job(jid).status.value in ("completed", "failed", "cancelled")
+                    for jid in job_ids
+                )
+                if all_done:
+                    break
+                await asyncio.sleep(0.1)
+            total = 0
+            valid = 0
+            for jid in job_ids:
+                job = executor.get_job(jid)
+                if job and job.result:
+                    total += job.result.get("proxy_count", 0)
+                    valid += job.result.get("success_count", 0)
+            return {"total_crawled": total, "valid_count": valid, "invalid_count": 0}
+
+    return CrawlAllAggregator()
+
+
+def format_plugin(plugin) -> dict:
+    return {
+        "id": plugin.id,
+        "name": plugin.display_name,
+        "display_name": plugin.display_name,
+        "description": plugin.description,
+        "enabled": plugin.enabled,
+        "last_run": plugin.last_run.isoformat() if plugin.last_run else None,
+        "success_count": plugin.success_count,
+        "failure_count": plugin.failure_count,
+    }
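Clients keep polling one task ID exactly as before; only the producer changed. A hedged client-side sketch (the {"data": ...} envelope of success_response is assumed from this diff, not confirmed):

    import asyncio
    import httpx

    async def crawl_all_and_wait(base_url: str = "http://localhost:8000") -> dict:
        async with httpx.AsyncClient(base_url=base_url) as client:
            resp = await client.post("/api/plugins/crawl-all")
            task_id = resp.json()["data"]["task_id"]  # the aggregator's job ID
            while True:
                job = (await client.get(f"/api/tasks/{task_id}")).json()["data"]
                if job["status"] in ("completed", "failed", "cancelled"):
                    return job
                await asyncio.sleep(0.5)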
@@ -1,11 +1,14 @@
 """Proxy routes (with statistics)"""
 from typing import Optional
 from fastapi import APIRouter, Depends, Query
+from fastapi.responses import StreamingResponse

 from app.services.proxy_service import ProxyService
 from app.services.scheduler_service import SchedulerService
 from app.models.schemas import ProxyListRequest, BatchDeleteRequest
 from app.api.deps import get_proxy_service, get_scheduler_service
-from app.api.common import success_response, error_response, format_proxy
+from app.api.common import success_response, format_proxy
+from app.core.exceptions import ProxyPoolException, ProxyNotFoundException

 router = APIRouter(prefix="/api/proxies", tags=["proxies"])
@@ -49,7 +52,7 @@ async def list_proxies(
 async def get_random_proxy(service: ProxyService = Depends(get_proxy_service)):
     proxy = await service.get_random_proxy()
     if not proxy:
-        return error_response("没有找到可用的代理", 404)
+        raise ProxyNotFoundException("", 0)
     return success_response("获取随机代理成功", format_proxy(proxy))

@@ -61,9 +64,7 @@ async def export_proxies(
     service: ProxyService = Depends(get_proxy_service),
 ):
     if fmt not in ("csv", "txt", "json"):
-        return error_response("不支持的导出格式", 400)
-
-        from fastapi.responses import StreamingResponse
+        raise ProxyPoolException("不支持的导出格式", 400)

     media_types = {"csv": "text/csv", "txt": "text/plain", "json": "application/json"}
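The routes above now raise domain exceptions instead of building error_response payloads; per the commit message, a global middleware turns them into HTTP responses. That handler is not part of this diff; a minimal sketch of the pattern, assuming ProxyPoolException carries (message, status_code):

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    class ProxyPoolException(Exception):
        def __init__(self, message: str, status_code: int = 400):
            self.message = message
            self.status_code = status_code

    app = FastAPI()

    @app.exception_handler(ProxyPoolException)
    async def proxy_pool_exception_handler(request: Request, exc: ProxyPoolException):
        # one place to translate business exceptions into API error bodies
        return JSONResponse(
            status_code=exc.status_code,
            content={"success": False, "message": exc.message},
        )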
@@ -1,10 +1,11 @@
 """Scheduler routes"""
 from fastapi import APIRouter, Depends

 from app.services.scheduler_service import SchedulerService
 from app.repositories.settings_repo import SettingsRepository
 from app.core.db import get_db
 from app.api.deps import get_scheduler_service
-from app.api.common import success_response, error_response
+from app.api.common import success_response

 router = APIRouter(prefix="/api/scheduler", tags=["scheduler"])
+settings_repo = SettingsRepository()
@@ -39,8 +40,8 @@ async def stop_scheduler(scheduler: SchedulerService = Depends(get_scheduler_service)):

 @router.post("/validate-now")
 async def validate_now(scheduler: SchedulerService = Depends(get_scheduler_service)):
-    await scheduler.validate_all_now()
-    return success_response("已开始全量验证", {"started": True})
+    job_id = scheduler.validate_all_now()
+    return success_response("已开始全量验证", {"started": True, "job_id": job_id})


 @router.get("/status")
@@ -3,7 +3,7 @@ from fastapi import APIRouter, Request
 from app.core.db import get_db
 from app.repositories.settings_repo import SettingsRepository
 from app.models.schemas import SettingsSchema
-from app.api.common import success_response, error_response
+from app.api.common import success_response
 from app.core.log import logger

 router = APIRouter(prefix="/api/settings", tags=["settings"])
@@ -22,14 +22,32 @@ async def save_settings(request: SettingsSchema, http_request: Request):
     async with get_db() as db:
         success = await settings_repo.save(db, request.model_dump())
     if not success:
-        return error_response("保存设置失败", 500)
+        raise RuntimeError("保存设置失败")

     # Hot-reload the running scheduler's interval
-    scheduler = getattr(http_request.app.state, "scheduler_service", None)
-    if scheduler and scheduler.running:
+    scheduler = getattr(http_request.app.state, "scheduler", None)
+    worker_pool = getattr(http_request.app.state, "worker_pool", None)
+    validator = getattr(http_request.app.state, "validator", None)
+
+    if scheduler:
         new_interval = request.validate_interval_minutes
         if scheduler.interval_minutes != new_interval:
             scheduler.interval_minutes = new_interval
             logger.info(f"Scheduler interval updated to {new_interval} minutes")

+    # Hot-reload the worker pool size
+    if worker_pool and worker_pool.worker_count != request.default_concurrency:
+        await worker_pool.resize(request.default_concurrency)
+        logger.info(f"Worker pool resized to {request.default_concurrency}")
+
+    # Hot-reload validator timeout and concurrency (takes effect on the next validation)
+    if validator:
+        validator._init_timeout = request.validation_timeout
+        validator._init_connect_timeout = request.validation_timeout
+        validator._init_max_concurrency = request.default_concurrency
+        # Recreate the semaphore and session
+        validator._semaphore = None
+        await validator.close()
+        logger.info(f"Validator config updated: timeout={request.validation_timeout}, concurrency={request.default_concurrency}")
+
     return success_response("保存设置成功", request.model_dump())
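worker_pool.resize() is called here but not defined in this diff. One plausible implementation, continuing the AsyncWorkerPool sketch under the commit message above (an assumption, not the repository's code):

    import asyncio

    async def resize(self, new_count: int) -> None:
        if new_count > len(self._workers):
            # Grow: spawn extra workers on the shared queue
            self._workers += [
                asyncio.create_task(self._run())
                for _ in range(new_count - len(self._workers))
            ]
        else:
            # Shrink: cancel surplus workers; queued items stay in the queue
            # and are picked up by the remaining workers
            for task in self._workers[new_count:]:
                task.cancel()
            self._workers = self._workers[:new_count]
        self.worker_count = new_count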
@@ -1,14 +1,16 @@
 """Async task status routes"""
-from fastapi import APIRouter
-from app.services.task_service import task_service
+from fastapi import APIRouter, Depends
+
+from app.core.execution import JobExecutor
+from app.api.deps import get_executor
 from app.api.common import success_response, error_response

 router = APIRouter(prefix="/api/tasks", tags=["tasks"])


 @router.get("/{task_id}")
-async def get_task_status(task_id: str):
-    task = task_service.get_task(task_id)
-    if not task:
+async def get_task_status(task_id: str, executor: JobExecutor = Depends(get_executor)):
+    job = executor.get_job(task_id)
+    if not job:
         return error_response("任务不存在", 404)
-    return success_response("获取任务状态成功", task)
+    return success_response("获取任务状态成功", job.to_dict())
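executor.get_job() and job.to_dict() imply a small Job interface; a sketch consistent with the status values and result fields used elsewhere in this diff (any field beyond those is an assumption):

    import enum
    import uuid

    class JobStatus(enum.Enum):
        PENDING = "pending"
        RUNNING = "running"
        COMPLETED = "completed"
        FAILED = "failed"
        CANCELLED = "cancelled"

    class Job:
        def __init__(self):
            self.id = str(uuid.uuid4())
            self.status = JobStatus.PENDING
            self.result: dict | None = None  # e.g. proxy_count / success_count for CrawlJob

        async def run(self):  # overridden by CrawlJob / ValidateAllJob
            raise NotImplementedError

        def to_dict(self) -> dict:
            return {"id": self.id, "status": self.status.value, "result": self.result}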