"""HTTP and WebSocket API for the proxy pool service.

Exposes endpoints for listing, exporting and deleting proxies, for driving
the crawler and its scheduler, for managing plugins, and a WebSocket channel
that receives crawler progress/status broadcasts.

JSON endpoints answer with an envelope of the form
``{"code": ..., "message": ..., "data": ...}``.
"""

import asyncio
import csv
import io
import json
import re
import secrets
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional, List

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Header, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field, field_validator, ValidationError

from core.sqlite import SQLiteManager
from core.validator import ProxyValidator
from core.plugin_manager import PluginManager
from tasks_manager import TasksManager, ScheduledTasks
from core.log import logger
from config import Config
from core.auth import verify_api_key, require_admin, PermissionLevel


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: initialise the database on startup."""
    db = SQLiteManager()
    await db.init_db()
    logger.info("API服务器启动啦~")
    yield
    logger.info("API服务器关闭啦~")


app = FastAPI(title="代理池API", version="1.1.0", lifespan=lifespan)


# "YYYY-MM-DD HH:MM:SS" with an optional fractional-seconds suffix.
_DB_DATETIME_RE = re.compile(r'(\d{4}-\d{2}-\d{2}) (\d{2}:\d{2}:\d{2})(\.\d+)?')


def format_datetime(datetime_str: Optional[str]) -> Optional[str]:
    """Normalise a database timestamp to an ISO 8601 string.

    Args:
        datetime_str: Timestamp as read from the database, or a falsy value.

    Returns:
        ``None`` for falsy input. Strings that already contain ``'T'`` are
        returned unchanged. Strings of the form ``YYYY-MM-DD HH:MM:SS`` (with
        an optional fractional part) become ``YYYY-MM-DDTHH:MM:SS.fffZ``; when
        no fraction is present ``.000`` is used. Anything else is returned
        unchanged.
    """
    if not datetime_str:
        return None
    if isinstance(datetime_str, str):
        if 'T' in datetime_str:
            return datetime_str
        # fullmatch (not a prefix match) so that values carrying extra
        # trailing text are left alone instead of getting ".000Z" glued on.
        matched = _DB_DATETIME_RE.fullmatch(datetime_str)
        if matched:
            day, clock, fraction = matched.groups()
            # NOTE(review): the "Z" suffix assumes stored timestamps are UTC
            # -- confirm against the database layer.
            return f"{day}T{clock}{fraction or '.000'}Z"
    return datetime_str


def _proxy_to_dict(proxy) -> dict:
    """Map a proxy row (ip, port, protocol, score, last_check) to a dict."""
    return {
        "ip": proxy[0],
        "port": proxy[1],
        "protocol": proxy[2],
        "score": proxy[3],
        "last_check": format_datetime(proxy[4]),
    }


# Registered for both exception types: FastAPI reports request validation
# failures as RequestValidationError, which is not a pydantic ValidationError.
@app.exception_handler(RequestValidationError)
@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: Exception):
    """Return validation failures in the API envelope with HTTP 422."""
    logger.error(f"参数验证失败: {exc}")
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={
            "code": 422,
            "message": "参数验证失败呢~",
            # Error details may embed exception objects; encode them so the
            # response can always be serialised.
            "data": jsonable_encoder(exc.errors()),
        },
    )


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Return an HTTPException in the API envelope, keeping its status code."""
    logger.error(f"HTTP异常: {exc.status_code} - {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.status_code, "message": exc.detail, "data": None},
        # Preserve headers attached to the exception (e.g. WWW-Authenticate).
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log with traceback and answer with a generic 500."""
    logger.error(f"未处理的异常: {exc}", exc_info=True)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"code": 500, "message": "服务器内部错误呢~", "data": None},
    )


# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the CORS specification; consider restricting allow_origins
# to the known front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

tasks_manager = TasksManager()
scheduled_tasks = ScheduledTasks(tasks_manager)
plugin_manager = PluginManager()

active_websockets = set()
websockets_lock = asyncio.Lock()

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Forget a finished background task and log its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.error(f"后台任务异常退出: {error}")


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task that is kept alive until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


async def _no_auth() -> None:
    """No-op dependency used when authentication is disabled."""
    return None


def optional_auth():
    """Return the auth dependency for read endpoints.

    ``Config.REQUIRE_AUTH`` is read when the route is declared. When it is
    falsy a no-op dependency is returned, so ``_permission`` is not exposed
    as a request parameter.
    """
    if Config.REQUIRE_AUTH:
        return Depends(verify_api_key)
    return Depends(_no_auth)


async def broadcast_message(message: dict):
    """Send *message* as JSON to every registered WebSocket client.

    Sockets whose send fails are removed from ``active_websockets``.
    """
    async with websockets_lock:
        sockets = list(active_websockets)
        if not sockets:
            return
        # Send failures surface as results here, not when the coroutine is
        # created, so the outcomes must be inspected after gathering.
        outcomes = await asyncio.gather(
            *(ws.send_json(message) for ws in sockets),
            return_exceptions=True,
        )
        for ws, outcome in zip(sockets, outcomes):
            if isinstance(outcome, BaseException):
                logger.error(f"发送WebSocket消息失败: {outcome}")
                active_websockets.discard(ws)


class ProxyRequest(BaseModel):
    """Paging, filtering and sorting options for the proxy list endpoint."""

    page: int = Field(default=1, ge=1, description="页码,必须大于等于1")
    page_size: int = Field(default=20, ge=1, le=100, description="每页数量,必须在1-100之间")
    protocol: Optional[str] = None
    min_score: int = Field(default=0, ge=0, description="最低分数")
    max_score: Optional[int] = Field(default=None, ge=0, description="最高分数")
    sort_by: str = 'last_check'
    sort_order: str = 'DESC'

    @field_validator('protocol')
    @classmethod
    def validate_protocol(cls, v):
        """Accept only known protocols; normalise to lower case."""
        if v is not None and v.lower() not in ('http', 'https', 'socks4', 'socks5'):
            raise ValueError('协议类型必须是 http, https, socks4 或 socks5')
        return v.lower() if v else v

    @field_validator('sort_by')
    @classmethod
    def validate_sort_by(cls, v):
        """Restrict the sort column to a fixed allow-list."""
        if v not in ('ip', 'port', 'protocol', 'score', 'last_check'):
            raise ValueError('排序字段必须是 ip, port, protocol, score 或 last_check')
        return v

    @field_validator('sort_order')
    @classmethod
    def validate_sort_order(cls, v):
        """Accept ASC/DESC in any case; normalise to upper case."""
        if v.upper() not in ('ASC', 'DESC'):
            raise ValueError('排序方式必须是 ASC 或 DESC')
        return v.upper()


class ProxyDeleteItem(BaseModel):
    """One proxy, identified by IP and port, in a batch-delete request."""

    ip: str
    port: int

    @field_validator('port')
    @classmethod
    def validate_port(cls, v):
        """Require a port in the range 1-65535."""
        if not 1 <= v <= 65535:
            raise ValueError('端口号必须在1-65535范围内')
        return v


class DeleteProxiesRequest(BaseModel):
    """Body of the batch-delete endpoint."""

    proxies: List[ProxyDeleteItem]

    @field_validator('proxies')
    @classmethod
    def validate_proxies_count(cls, v):
        """Cap a single request at 1000 proxies."""
        if len(v) > 1000:
            raise ValueError('单次最多删除1000个代理')
        return v


class CrawlerRequest(BaseModel):
    """Body of the crawler start endpoint."""

    # At least one validator is required.
    num_validators: int = Field(default=50, ge=1)


class ScheduleRequest(BaseModel):
    """Body of the scheduler configuration endpoint."""

    enabled: bool
    # A non-positive interval is rejected.
    interval_minutes: int = Field(default=60, ge=1)


class PluginToggleRequest(BaseModel):
    """Body of the plugin toggle endpoint."""

    enabled: bool


@app.get("/")
async def root():
    """Landing endpoint reporting that the service is running."""
    return {"message": "欢迎使用代理池API~", "status": "running", "data": None}


@app.get("/health")
async def health_check():
    """Report service health by issuing a count query against the database."""
    try:
        db = SQLiteManager()
        await db.count_proxies()
        return {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "database": "connected",
            # Single source of truth: the version declared on the app.
            "version": app.version,
        }
    except Exception as e:
        logger.error(f"健康检查失败: {e}")
        return {
            "status": "unhealthy",
            "timestamp": datetime.now().isoformat(),
            "database": "disconnected",
            "error": str(e),
        }


@app.get("/api/stats")
async def get_stats(_permission: str = optional_auth()):
    """Return database statistics plus the count of proxies added today."""
    try:
        db = SQLiteManager()
        stats = await db.get_stats()
        stats['today_new'] = await db.get_today_new_count()
        return {"code": 200, "message": "获取统计信息成功啦~", "data": stats}
    except Exception as e:
        logger.error(f"获取统计信息失败: {e}")
        return {"code": 500, "message": "获取统计信息失败呢~", "data": None}


@app.post("/api/proxies")
async def get_proxies(request: ProxyRequest, _permission: str = optional_auth()):
    """Return one page of proxies matching the filters, plus the total count."""
    try:
        db = SQLiteManager()
        proxies = await db.get_proxies_paginated(
            page=request.page,
            page_size=request.page_size,
            protocol=request.protocol,
            min_score=request.min_score,
            max_score=request.max_score,
            sort_by=request.sort_by,
            sort_order=request.sort_order,
        )
        total = await db.get_proxies_total(
            protocol=request.protocol,
            min_score=request.min_score,
            max_score=request.max_score,
        )
        return {
            "code": 200,
            "message": "获取代理列表成功啦~",
            "data": {
                "list": [_proxy_to_dict(proxy) for proxy in proxies],
                "total": total,
                "page": request.page,
                "page_size": request.page_size,
            },
        }
    except Exception as e:
        logger.error(f"获取代理列表失败: {e}")
        return {"code": 500, "message": "获取代理列表失败呢~", "data": None}


@app.get("/api/proxies/random")
async def get_random_proxy(_permission: str = optional_auth()):
    """Return one proxy chosen by the database layer, or a 404 envelope."""
    db = SQLiteManager()
    proxy = await db.get_random_proxy()
    if not proxy:
        return {"code": 404, "message": "没有找到可用的代理呢~", "data": None}
    return {"code": 200, "message": "获取随机代理成功啦~", "data": _proxy_to_dict(proxy)}


@app.get("/api/proxies/{ip}/{port}")
async def get_proxy_detail(ip: str, port: int, _permission: str = optional_auth()):
    """Return the proxy identified by *ip* and *port*, or a 404 envelope."""
    db = SQLiteManager()
    proxy = await db.get_proxy_detail(ip, port)
    if not proxy:
        return {"code": 404, "message": "代理不存在呢~", "data": None}
    return {"code": 200, "message": "获取代理详情成功啦~", "data": _proxy_to_dict(proxy)}


@app.delete("/api/proxies/{ip}/{port}")
async def delete_proxy(ip: str, port: int, _permission: str = Depends(require_admin)):
    """Delete the proxy identified by *ip* and *port* (admin only)."""
    db = SQLiteManager()
    await db.delete_proxy(ip, port)
    return {"code": 200, "message": "删除代理成功啦~", "data": None}


@app.post("/api/proxies/batch-delete")
async def batch_delete_proxies(request: DeleteProxiesRequest, _permission: str = Depends(require_admin)):
    """Delete the listed proxies and report how many were removed (admin only)."""
    db = SQLiteManager()
    proxy_tuples = [(item.ip, item.port) for item in request.proxies]
    deleted_count = await db.batch_delete_proxies(proxy_tuples)
    return {
        "code": 200,
        "message": f"批量删除 {deleted_count} 个代理成功啦~",
        "data": {"deleted_count": deleted_count},
    }


@app.delete("/api/proxies/clean-invalid")
async def clean_invalid_proxies(_permission: str = Depends(require_admin)):
    """Ask the database layer to purge invalid proxies (admin only)."""
    db = SQLiteManager()
    deleted_count = await db.clean_invalid_proxies()
    return {
        "code": 200,
        "message": f"清理了 {deleted_count} 个无效代理啦~",
        "data": {"deleted_count": deleted_count},
    }


async def _load_export_proxies(db, protocol: Optional[str], limit: int) -> list:
    """Fetch all proxies, filter by protocol (case-insensitive), cap at *limit*."""
    proxies = await db.get_all_proxies()
    if protocol:
        wanted = protocol.lower()
        proxies = [p for p in proxies if p[2].lower() == wanted]
    return proxies[:limit]


def _csv_line(row) -> str:
    """Render one CSV record, newline-terminated, with standard quoting."""
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator='\n').writerow(row)
    return buffer.getvalue()


async def _export_csv(proxies):
    """Yield a CSV header line followed by one line per proxy."""
    yield _csv_line(['IP', 'Port', 'Protocol', 'Score', 'Last Check'])
    for proxy in proxies:
        # A missing last_check is written as an empty field.
        yield _csv_line([proxy[0], proxy[1], proxy[2], proxy[3], format_datetime(proxy[4])])


async def _export_txt(proxies):
    """Yield one ``ip:port`` line per proxy."""
    for proxy in proxies:
        yield f"{proxy[0]}:{proxy[1]}\n"


async def _export_json(proxies):
    """Yield a JSON array of proxy objects, one indented object at a time."""
    yield '[\n'
    for index, proxy in enumerate(proxies):
        if index > 0:
            yield ',\n'
        yield json.dumps(_proxy_to_dict(proxy), ensure_ascii=False, indent=2)
    yield '\n]'


# Export format -> (chunk generator, media type).
_EXPORTERS = {
    'csv': (_export_csv, 'text/csv'),
    'txt': (_export_txt, 'text/plain'),
    'json': (_export_json, 'application/json'),
}


@app.get("/api/proxies/export/{format}")
async def export_proxies(format: str, protocol: Optional[str] = None, _permission: str = optional_auth(), limit: int = 10000):
    """Download proxies as a csv, txt or json attachment.

    Args:
        format: Export format; one of ``csv``, ``txt`` or ``json``.
        protocol: Optional case-insensitive protocol filter.
        limit: Maximum number of proxies to export (0 to 100000).

    Raises:
        HTTPException: 400 for an unsupported format or out-of-range limit,
            500 when the proxies cannot be loaded.
    """
    if format not in _EXPORTERS:
        raise HTTPException(status_code=400, detail="不支持的导出格式呢~")
    if limit > 100000:
        raise HTTPException(status_code=400, detail="导出数量不能超过100000条呢~")
    if limit < 0:
        # A negative limit would otherwise slice from the end of the list.
        raise HTTPException(status_code=400, detail="导出数量不能为负数呢~")

    # Load the rows before the response starts so that a database failure
    # becomes a proper 500 instead of a truncated download.
    try:
        db = SQLiteManager()
        proxies = await _load_export_proxies(db, protocol, limit)
    except Exception as e:
        logger.error(f"导出代理失败: {e}")
        raise HTTPException(status_code=500, detail="导出代理失败呢~") from e

    render, media_type = _EXPORTERS[format]
    return StreamingResponse(
        render(proxies),
        media_type=media_type,
        headers={'Content-Disposition': f'attachment; filename=proxies.{format}'},
    )


async def _broadcast_progress(data):
    """Forward a crawler progress update to all WebSocket clients."""
    await broadcast_message({"type": "progress", "data": data})


async def _broadcast_status(data):
    """Forward a crawler status update to all WebSocket clients."""
    await broadcast_message({"type": "status", "data": data})


@app.post("/api/crawler/start")
async def start_crawler(request: CrawlerRequest, _permission: str = Depends(require_admin)):
    """Start the crawler task in the background (admin only).

    Answers with a 400 envelope when a task is already running.
    """
    try:
        if tasks_manager.is_task_running():
            return {"code": 400, "message": "任务正在运行中呢~", "data": None}

        tasks_manager.set_callbacks(_broadcast_progress, _broadcast_status)

        db = SQLiteManager()
        # Keep a reference to the task so it cannot be collected mid-run.
        _spawn_background(tasks_manager.start_task(db, request.num_validators))

        return {"code": 200, "message": "爬虫任务开始啦~", "data": None}
    except Exception as e:
        logger.error(f"启动爬虫失败: {e}")
        return {"code": 500, "message": "启动爬虫失败呢~", "data": None}


@app.post("/api/crawler/stop")
async def stop_crawler(_permission: str = Depends(require_admin)):
    """Stop the running crawler task (admin only)."""
    if not tasks_manager.is_task_running():
        return {"code": 400, "message": "没有运行中的任务呢~", "data": None}

    await tasks_manager.stop_task()
    return {"code": 200, "message": "爬虫任务停止啦~", "data": None}


@app.get("/api/crawler/status")
async def get_crawler_status(_permission: str = optional_auth()):
    """Return whether a crawler task is running, with the manager's stats."""
    return {
        "code": 200,
        "message": "获取爬虫状态成功啦~",
        "data": {
            "running": tasks_manager.is_task_running(),
            "stats": tasks_manager.get_stats(),
        },
    }


@app.post("/api/scheduler")
async def set_scheduler(request: ScheduleRequest, _permission: str = Depends(require_admin)):
    """Enable or disable the scheduled crawl (admin only)."""
    if request.enabled:
        scheduled_tasks.start_scheduled(request.interval_minutes)
        return {
            "code": 200,
            "message": f"定时任务已启动,间隔 {request.interval_minutes} 分钟~",
            "data": None,
        }

    scheduled_tasks.stop_scheduled()
    return {"code": 200, "message": "定时任务已停止~", "data": None}


@app.get("/api/scheduler")
async def get_scheduler_status(_permission: str = optional_auth()):
    """Return the scheduler's enabled flag and interval."""
    return {
        "code": 200,
        "message": "获取定时任务状态成功啦~",
        "data": {
            "enabled": scheduled_tasks.is_scheduled,
            "interval_minutes": scheduled_tasks.interval_minutes,
        },
    }


def _token_matches(token: str, expected) -> bool:
    """Compare *token* with a configured key in constant time.

    A key that is unset, empty or not a string never matches.
    """
    if not isinstance(expected, str) or not expected:
        return False
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    return secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
    """WebSocket channel that receives broadcast messages.

    When ``Config.REQUIRE_AUTH`` is set, the ``token`` query parameter must
    equal ``Config.API_KEY`` or ``Config.ADMIN_API_KEY``; otherwise the
    connection is closed with a policy-violation code.
    """
    if Config.REQUIRE_AUTH:
        if not token:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少认证token")
            logger.warning("WebSocket连接被拒绝:缺少token")
            return

        is_admin = _token_matches(token, Config.ADMIN_API_KEY)
        is_reader = _token_matches(token, Config.API_KEY)
        if not (is_admin or is_reader):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的token")
            # The rejected token is deliberately not logged: a mistyped real
            # key would otherwise leak into the log.
            logger.warning("WebSocket连接被拒绝:无效的token")
            return

        permission_level = PermissionLevel.ADMIN if is_admin else PermissionLevel.READ_ONLY
        logger.info(f"WebSocket连接成功,权限级别: {permission_level}")

    await websocket.accept()
    async with websockets_lock:
        active_websockets.add(websocket)

    try:
        await websocket.send_json({
            "type": "status",
            "data": {
                "status": "connected",
                "message": "WebSocket连接成功啦~",
                "timestamp": datetime.now().isoformat(),
            },
        })

        # Incoming messages are read and discarded; the loop exists to
        # detect the client disconnecting.
        while True:
            await websocket.receive_text()

    except WebSocketDisconnect:
        logger.info("WebSocket断开连接")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Unregister on every exit path.
        async with websockets_lock:
            active_websockets.discard(websocket)


@app.get("/api/plugins")
async def get_plugins(_permission: str = optional_auth()):
    """Return the plugin manager's info for all plugins."""
    try:
        plugins_info = plugin_manager.get_all_plugin_info()
        return {
            "code": 200,
            "message": "获取插件列表成功啦~",
            "data": {"plugins": plugins_info},
        }
    except Exception as e:
        logger.error(f"获取插件列表失败: {e}")
        return {"code": 500, "message": "获取插件列表失败呢~", "data": None}


@app.put("/api/plugins/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, request: PluginToggleRequest, _permission: str = Depends(require_admin)):
    """Enable or disable a plugin (admin only); 404 envelope if unknown."""
    try:
        success = plugin_manager.toggle_plugin(plugin_id, request.enabled)
        if not success:
            return {"code": 404, "message": "插件不存在呢~", "data": None}
        return {
            "code": 200,
            "message": f"插件 {plugin_id} 已{'启用' if request.enabled else '禁用'}啦~",
            "data": {"plugin_id": plugin_id, "enabled": request.enabled},
        }
    except Exception as e:
        logger.error(f"切换插件状态失败: {e}")
        return {"code": 500, "message": "切换插件状态失败呢~", "data": None}


@app.post("/api/plugins/{plugin_id}/crawl")
async def crawl_plugin(plugin_id: str, _permission: str = Depends(require_admin)):
    """Run one plugin and insert every proxy it returns (admin only).

    The plugin run is awaited, so the response is sent after the returned
    proxies have been inserted.
    """
    try:
        tasks_manager.set_callbacks(_broadcast_progress, _broadcast_status)

        db = SQLiteManager()
        results = await plugin_manager.run_plugin(plugin_id)

        for ip, port, protocol in results:
            await db.insert_proxy(ip, port, protocol)

        return {
            "code": 200,
            "message": f"插件 {plugin_id} 开始爬取啦~",
            "data": {"plugin_id": plugin_id, "proxy_count": len(results)},
        }
    except Exception as e:
        logger.error(f"插件爬取失败: {e}")
        return {"code": 500, "message": "插件爬取失败呢~", "data": None}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8923)