Files
ProxyPool/api_server.py
2026-01-27 21:17:36 +08:00

554 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import csv
import hmac
import io
import json
import re
from contextlib import AsyncExitStack, asynccontextmanager
from datetime import datetime
from typing import Optional, List

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Header, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field, field_validator, ValidationError

from core.sqlite import SQLiteManager
from core.validator import ProxyValidator
from core.plugin_manager import PluginManager
from tasks_manager import TasksManager, ScheduledTasks
from core.log import logger
from config import Config
from core.auth import verify_api_key, require_admin, PermissionLevel
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Everything before ``yield`` runs once at startup: it calls
    ``SQLiteManager.init_db()`` and logs a startup line. The code after
    ``yield`` runs at shutdown and only logs.
    """
    startup_db = SQLiteManager()
    await startup_db.init_db()
    logger.info("API服务器启动啦~")
    yield
    logger.info("API服务器关闭啦~")


# The ASGI application; `lifespan` above is wired in as its startup/shutdown hook.
app = FastAPI(title="代理池API", version="1.1.0", lifespan=lifespan)
def format_datetime(datetime_str: Optional[str]) -> Optional[str]:
    """Normalise a database timestamp to ISO 8601.

    The space-separated form ``YYYY-MM-DD HH:MM:SS[.fraction]`` becomes
    ``YYYY-MM-DDTHH:MM:SS<fraction>Z``. The existing fractional part is kept,
    or ``.000`` is used when there is none. Strings that already contain a
    ``T``, strings of any other shape, and non-string values are returned
    unchanged.

    NOTE(review): the ``Z`` suffix marks the value as UTC. That is only
    correct if the database stores UTC timestamps — confirm.

    Args:
        datetime_str: Timestamp read from the database; may be ``None`` or empty.

    Returns:
        ``None`` for a falsy input; otherwise the normalised string or the
        input unchanged.
    """
    if not datetime_str:
        return None
    if not isinstance(datetime_str, str) or 'T' in datetime_str:
        return datetime_str
    # fullmatch, not match: a prefix-only match also accepted inputs with
    # fractional seconds or offsets and produced malformed output such as
    # "2024-01-02T03:04:05.123.000Z".
    match = re.fullmatch(r'(\d{4}-\d{2}-\d{2}) (\d{2}:\d{2}:\d{2})(\.\d+)?', datetime_str)
    if match is None:
        return datetime_str
    date_part, time_part, fraction = match.groups()
    return f"{date_part}T{time_part}{fraction or '.000'}Z"
# FastAPI reports invalid request bodies/params as RequestValidationError,
# which is not a pydantic ValidationError subclass. Both are registered so
# request validation failures also get the project's response envelope.
@app.exception_handler(RequestValidationError)
@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: Exception):
    """Return validation failures as a 422 ``{code, message, data}`` envelope.

    ``exc.errors()`` can contain values the stdlib JSON encoder rejects. For
    example, pydantic v2 stores the ``ValueError`` raised by a custom
    ``field_validator`` under ``ctx``. The errors are therefore passed
    through ``jsonable_encoder`` before being serialised.
    """
    logger.error(f"参数验证失败: {exc}")
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"code": 422, "message": "参数验证失败呢~", "data": jsonable_encoder(exc.errors())}
    )
# NOTE(review): this is registered for fastapi.HTTPException. Routing-level
# 404/405 errors are raised as Starlette's base HTTPException and bypass it.
# Register starlette.exceptions.HTTPException instead if those should use
# the envelope too.
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Wrap an HTTPException in the project's ``{code, message, data}`` envelope.

    Any headers attached to the exception are forwarded on the response,
    e.g. ``WWW-Authenticate`` on a 401 or ``Retry-After`` on a 429.
    """
    logger.error(f"HTTP异常: {exc.status_code} - {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.status_code, "message": exc.detail, "data": None},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the traceback and return a generic 500 envelope."""
    logger.error(f"未处理的异常: {exc}", exc_info=True)
    body = {"code": 500, "message": "服务器内部错误呢~", "data": None}
    return JSONResponse(content=body, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Service-wide singletons shared by the route handlers below.
tasks_manager = TasksManager()
scheduled_tasks = ScheduledTasks(tasks_manager)
plugin_manager = PluginManager()

# Currently connected WebSocket clients. Every add, discard and iteration
# over this set in this module happens while holding `websockets_lock`.
active_websockets = set()
websockets_lock = asyncio.Lock()

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests. Consider an explicit origin
# list if cookies or other credentials matter here.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
async def _no_auth_dependency() -> None:
    """Dependency used while auth is disabled; it always resolves to ``None``."""
    return None


def optional_auth():
    """Return the dependency marker for endpoints whose auth is configurable.

    When ``Config.REQUIRE_AUTH`` is true, ``verify_api_key`` is required.
    Otherwise a no-op dependency is returned, so the parameter still resolves
    to ``None``. A bare ``None`` default would instead make FastAPI publish
    the parameter as an optional ``?_permission=`` query argument.

    Because this is called as a parameter default, it is evaluated once per
    endpoint when the module is imported.
    """
    if Config.REQUIRE_AUTH:
        return Depends(verify_api_key)
    return Depends(_no_auth_dependency)
async def broadcast_message(message: dict):
    """Send *message* as JSON to every connected WebSocket client.

    All sends run concurrently while ``websockets_lock`` is held. A client
    whose send fails is logged and removed from ``active_websockets``, so it
    is not retried on every later broadcast.
    """
    async with websockets_lock:
        # Snapshot the set so `targets` and `results` line up by index.
        targets = list(active_websockets)
        if not targets:
            return
        # Calling ws.send_json() only creates a coroutine. Send failures show
        # up here as exception objects in `results` (return_exceptions=True),
        # not when the coroutine is created.
        results = await asyncio.gather(
            *(ws.send_json(message) for ws in targets),
            return_exceptions=True,
        )
        for ws, result in zip(targets, results):
            if isinstance(result, BaseException):
                logger.error(f"发送WebSocket消息失败: {result}")
                active_websockets.discard(ws)
class ProxyRequest(BaseModel):
    """Body of ``POST /api/proxies``: pagination, filters and sort order."""

    page: int = Field(default=1, ge=1, description="页码必须大于等于1")
    page_size: int = Field(default=20, ge=1, le=100, description="每页数量必须在1-100之间")
    protocol: Optional[str] = None
    min_score: int = Field(default=0, ge=0, description="最低分数")
    max_score: Optional[int] = Field(default=None, ge=0, description="最高分数")
    # Both sort fields are restricted to fixed whitelists by the validators below.
    sort_by: str = 'last_check'
    sort_order: str = 'DESC'

    @field_validator('protocol')
    @classmethod
    def validate_protocol(cls, v):
        """Lower-case the protocol; reject anything other than http/https/socks4/socks5."""
        if v is None:
            return v
        normalized = v.lower()
        if normalized not in ['http', 'https', 'socks4', 'socks5']:
            raise ValueError('协议类型必须是 http, https, socks4 或 socks5')
        return normalized

    @field_validator('sort_by')
    @classmethod
    def validate_sort_by(cls, v):
        """Allow only the known sortable columns."""
        allowed_columns = ('ip', 'port', 'protocol', 'score', 'last_check')
        if v in allowed_columns:
            return v
        raise ValueError('排序字段必须是 ip, port, protocol, score 或 last_check')

    @field_validator('sort_order')
    @classmethod
    def validate_sort_order(cls, v):
        """Upper-case the direction; only ASC and DESC are accepted."""
        direction = v.upper()
        if direction in ('ASC', 'DESC'):
            return direction
        raise ValueError('排序方式必须是 ASC 或 DESC')
class ProxyDeleteItem(BaseModel):
    """One ``ip``/``port`` pair identifying a proxy to delete."""

    ip: str
    port: int

    @field_validator('port')
    @classmethod
    def validate_port(cls, v):
        """Reject ports outside the TCP range 1-65535."""
        if v < 1 or v > 65535:
            raise ValueError('端口号必须在1-65535范围内')
        return v
class DeleteProxiesRequest(BaseModel):
    """Body of ``POST /api/proxies/batch-delete`` (at most 1000 entries)."""

    proxies: List[ProxyDeleteItem]

    @field_validator('proxies')
    @classmethod
    def validate_proxies_count(cls, v):
        """Cap a single batch at 1000 proxies."""
        if len(v) <= 1000:
            return v
        raise ValueError('单次最多删除1000个代理')
class CrawlerRequest(BaseModel):
    """Body of ``POST /api/crawler/start``."""

    # Passed to TasksManager.start_task. Presumably the number of concurrent
    # validator workers (confirm); values below 1 are rejected up front
    # instead of being handed to the task manager.
    num_validators: int = Field(default=50, ge=1)
class ScheduleRequest(BaseModel):
    """Body of ``POST /api/scheduler``."""

    enabled: bool
    # Passed to ScheduledTasks.start_scheduled. A zero or negative interval
    # is rejected rather than handed to the scheduler.
    interval_minutes: int = Field(default=60, ge=1)
@app.get("/")
async def root():
    """Static welcome payload; touches no other component."""
    payload = {"message": "欢迎使用代理池API~", "status": "running", "data": None}
    return payload
@app.get("/health")
async def health_check():
    """Report service health by running ``SQLiteManager.count_proxies()``.

    On success, responds with HTTP 200 and ``status: healthy``. If the query
    raises, responds with HTTP 503 and ``status: unhealthy`` so that probes
    and load balancers that only inspect the status code see the failure.
    The version is read from the FastAPI app so it cannot drift from it.
    """
    try:
        db = SQLiteManager()
        await db.count_proxies()
    except Exception as e:
        logger.error(f"健康检查失败: {e}")
        return JSONResponse(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            content={
                "status": "unhealthy",
                "timestamp": datetime.now().isoformat(),
                "database": "disconnected",
                "version": app.version,
                "error": str(e),
            },
        )
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "database": "connected",
        "version": app.version,
    }
@app.get("/api/stats")
async def get_stats(_permission: str = optional_auth()):
    """Return the DB statistics dict, extended with a ``today_new`` count."""
    try:
        db = SQLiteManager()
        stats = await db.get_stats()
        stats['today_new'] = await db.get_today_new_count()
    except Exception as e:
        logger.error(f"获取统计信息失败: {e}")
        return {"code": 500, "message": "获取统计信息失败呢~", "data": None}
    return {"code": 200, "message": "获取统计信息成功啦~", "data": stats}
@app.post("/api/proxies")
async def get_proxies(request: ProxyRequest, _permission: str = optional_auth()):
    """Return one page of proxies that match the request filters, plus the total match count."""
    try:
        db = SQLiteManager()
        # The same filters are used for the page query and the total count.
        filters = {
            "protocol": request.protocol,
            "min_score": request.min_score,
            "max_score": request.max_score,
        }
        rows = await db.get_proxies_paginated(
            page=request.page,
            page_size=request.page_size,
            sort_by=request.sort_by,
            sort_order=request.sort_order,
            **filters,
        )
        total = await db.get_proxies_total(**filters)
        # Rows are indexed positionally: (ip, port, protocol, score, last_check).
        items = [
            {
                "ip": row[0],
                "port": row[1],
                "protocol": row[2],
                "score": row[3],
                "last_check": format_datetime(row[4]),
            }
            for row in rows
        ]
    except Exception as e:
        logger.error(f"获取代理列表失败: {e}")
        return {"code": 500, "message": "获取代理列表失败呢~", "data": None}
    page_data = {
        "list": items,
        "total": total,
        "page": request.page,
        "page_size": request.page_size,
    }
    return {"code": 200, "message": "获取代理列表成功啦~", "data": page_data}
@app.get("/api/proxies/random")
async def get_random_proxy(_permission: str = optional_auth()):
    """Return the proxy picked by ``SQLiteManager.get_random_proxy()``, or a ``code: 404`` body."""
    proxy = await SQLiteManager().get_random_proxy()
    if not proxy:
        return {"code": 404, "message": "没有找到可用的代理呢~", "data": None}
    detail = {
        "ip": proxy[0],
        "port": proxy[1],
        "protocol": proxy[2],
        "score": proxy[3],
        "last_check": format_datetime(proxy[4]),
    }
    return {"code": 200, "message": "获取随机代理成功啦~", "data": detail}
# `{port:int}` makes the router itself require a numeric port. With a plain
# `{port}`, this route (registered before the export route) also matched
# GET /api/proxies/export/<format> with ip="export", so every export request
# failed port validation with a 422 instead of reaching export_proxies.
@app.get("/api/proxies/{ip}/{port:int}")
async def get_proxy_detail(ip: str, port: int, _permission: str = optional_auth()):
    """Return the proxy stored under ``ip``/``port``, or a ``code: 404`` body if absent."""
    db = SQLiteManager()
    proxy = await db.get_proxy_detail(ip, port)
    if not proxy:
        return {"code": 404, "message": "代理不存在呢~", "data": None}
    return {
        "code": 200,
        "message": "获取代理详情成功啦~",
        "data": {
            "ip": proxy[0],
            "port": proxy[1],
            "protocol": proxy[2],
            "score": proxy[3],
            "last_check": format_datetime(proxy[4])
        }
    }
@app.delete("/api/proxies/{ip}/{port}")
async def delete_proxy(ip: str, port: int, _permission: str = Depends(require_admin)):
    """Delete the proxy at ``ip``/``port`` (admin only).

    Always reports success; the return value of ``delete_proxy`` is not inspected.
    """
    manager = SQLiteManager()
    await manager.delete_proxy(ip, port)
    return {"code": 200, "message": "删除代理成功啦~", "data": None}
@app.post("/api/proxies/batch-delete")
async def batch_delete_proxies(request: DeleteProxiesRequest, _permission: str = Depends(require_admin)):
    """Delete up to 1000 proxies in one call (admin only) and report how many were removed."""
    manager = SQLiteManager()
    targets = [(entry.ip, entry.port) for entry in request.proxies]
    removed = await manager.batch_delete_proxies(targets)
    return {
        "code": 200,
        "message": f"批量删除 {removed} 个代理成功啦~",
        "data": {"deleted_count": removed},
    }
@app.delete("/api/proxies/clean-invalid")
async def clean_invalid_proxies(_permission: str = Depends(require_admin)):
    """Run ``SQLiteManager.clean_invalid_proxies()`` (admin only) and report the count."""
    removed = await SQLiteManager().clean_invalid_proxies()
    return {
        "code": 200,
        "message": f"清理了 {removed} 个无效代理啦~",
        "data": {"deleted_count": removed},
    }
@app.get("/api/proxies/export/{format}")
async def export_proxies(format: str, protocol: Optional[str] = None, _permission: str = optional_auth(), limit: int = 10000):
    """Stream proxies as a downloadable CSV, TXT (``ip:port`` lines) or JSON file.

    Args:
        format: ``csv``, ``txt`` or ``json``, taken from the URL path.
        protocol: Optional protocol filter; compared case-insensitively.
        limit: Maximum number of proxies to export, between 1 and 100000.

    Raises:
        HTTPException: 400 for an unsupported format or an out-of-range
            ``limit``; 500 if the proxies cannot be loaded.
    """
    if format not in ('csv', 'txt', 'json'):
        raise HTTPException(status_code=400, detail="不支持的导出格式呢~")
    if limit > 100000:
        raise HTTPException(status_code=400, detail="导出数量不能超过100000条呢~")
    # A zero or negative limit would otherwise slice from the end of the list
    # (proxies[:-1] silently drops the last row) instead of being rejected.
    if limit < 1:
        raise HTTPException(status_code=400, detail="导出数量必须大于0呢~")

    # Load and filter before the response starts. Once StreamingResponse has
    # sent its headers, a DB error can no longer be reported as an HTTP 500.
    try:
        db = SQLiteManager()
        proxies = await db.get_all_proxies()
    except Exception as e:
        logger.error(f"导出代理失败: {e}")
        raise HTTPException(status_code=500, detail="导出代理失败呢~") from e

    if protocol:
        wanted = protocol.lower()
        # `or ''` keeps a row with a NULL protocol from crashing the filter.
        proxies = [p for p in proxies if (p[2] or '').lower() == wanted]
    proxies = proxies[:limit]

    async def generate_csv():
        # csv.writer quotes fields that contain delimiters or quotes. A single
        # buffer is drained after each row so the output stays streamed.
        buffer = io.StringIO()
        writer = csv.writer(buffer, lineterminator='\n')
        writer.writerow(['IP', 'Port', 'Protocol', 'Score', 'Last Check'])
        yield buffer.getvalue()
        for proxy in proxies:
            buffer.seek(0)
            buffer.truncate(0)
            writer.writerow([proxy[0], proxy[1], proxy[2], proxy[3], format_datetime(proxy[4])])
            yield buffer.getvalue()

    async def generate_txt():
        for proxy in proxies:
            yield f"{proxy[0]}:{proxy[1]}\n"

    async def generate_json():
        # Streamed as a JSON array, one indented object per proxy.
        yield '[\n'
        for i, proxy in enumerate(proxies):
            if i > 0:
                yield ',\n'
            item = {
                'ip': proxy[0],
                'port': proxy[1],
                'protocol': proxy[2],
                'score': proxy[3],
                'last_check': format_datetime(proxy[4]),
            }
            yield json.dumps(item, ensure_ascii=False, indent=2)
        yield '\n]'

    renderers = {
        'csv': (generate_csv, 'text/csv'),
        'txt': (generate_txt, 'text/plain'),
        'json': (generate_json, 'application/json'),
    }
    generator, media_type = renderers[format]
    return StreamingResponse(
        generator(),
        media_type=media_type,
        headers={'Content-Disposition': f'attachment; filename=proxies.{format}'},
    )
# Strong references to running crawler tasks. The event loop holds tasks only
# weakly, so an otherwise unreferenced task may be garbage-collected mid-run.
_crawler_tasks = set()


async def _run_crawler_task(stack: AsyncExitStack, db, validator, num_validators: int) -> None:
    """Run one crawl, then close the validator context held open by *stack*.

    Assumes ``TasksManager.start_task`` only returns once the crawl is over
    (TODO confirm); the validator is closed as soon as it returns.
    """
    async with stack:
        try:
            await tasks_manager.start_task(db, validator, num_validators)
        except Exception as e:
            logger.error(f"爬虫任务运行失败: {e}", exc_info=True)


@app.post("/api/crawler/start")
async def start_crawler(request: CrawlerRequest, _permission: str = Depends(require_admin)):
    """Start a crawler run in the background (admin only).

    Progress and status updates from the task manager are broadcast to all
    WebSocket clients. The validator is opened here, so a failure to open it
    still produces the ``code: 500`` response. Closing it is then handed to
    the background task, so it stays open for the whole run instead of being
    closed the moment the task is scheduled.
    """
    try:
        if tasks_manager.is_task_running():
            return {"code": 400, "message": "任务正在运行中呢~", "data": None}

        async def progress_callback(data):
            await broadcast_message({"type": "progress", "data": data})

        async def status_callback(data):
            await broadcast_message({"type": "status", "data": data})

        tasks_manager.set_callbacks(progress_callback, status_callback)
        db = SQLiteManager()
        stack = AsyncExitStack()
        validator = await stack.enter_async_context(ProxyValidator(max_concurrency=200))
        task = asyncio.create_task(
            _run_crawler_task(stack, db, validator, request.num_validators)
        )
        _crawler_tasks.add(task)
        task.add_done_callback(_crawler_tasks.discard)
        return {"code": 200, "message": "爬虫任务开始啦~", "data": None}
    except Exception as e:
        logger.error(f"启动爬虫失败: {e}")
        return {"code": 500, "message": "启动爬虫失败呢~", "data": None}
@app.post("/api/crawler/stop")
async def stop_crawler(_permission: str = Depends(require_admin)):
    """Stop the running crawler task (admin only); ``code: 400`` body if nothing is running."""
    if tasks_manager.is_task_running():
        await tasks_manager.stop_task()
        return {"code": 200, "message": "爬虫任务停止啦~", "data": None}
    return {"code": 400, "message": "没有运行中的任务呢~", "data": None}
@app.get("/api/crawler/status")
async def get_crawler_status(_permission: str = optional_auth()):
    """Report whether a crawler task is running, along with the task manager's stats."""
    running = tasks_manager.is_task_running()
    stats = tasks_manager.get_stats()
    return {
        "code": 200,
        "message": "获取爬虫状态成功啦~",
        "data": {"running": running, "stats": stats},
    }
@app.post("/api/scheduler")
async def set_scheduler(request: ScheduleRequest, _permission: str = Depends(require_admin)):
    """Enable (with the given interval) or disable the scheduled crawl (admin only)."""
    if not request.enabled:
        scheduled_tasks.stop_scheduled()
        return {"code": 200, "message": "定时任务已停止~", "data": None}
    scheduled_tasks.start_scheduled(request.interval_minutes)
    return {"code": 200, "message": f"定时任务已启动,间隔 {request.interval_minutes} 分钟~", "data": None}
@app.get("/api/scheduler")
async def get_scheduler_status(_permission: str = optional_auth()):
    """Report whether the scheduled crawl is enabled and its interval in minutes."""
    schedule_info = {
        "enabled": scheduled_tasks.is_scheduled,
        "interval_minutes": scheduled_tasks.interval_minutes,
    }
    return {"code": 200, "message": "获取定时任务状态成功啦~", "data": schedule_info}
def _ws_token_matches(token: str, expected) -> bool:
    """Compare *token* against a configured key in constant time.

    A key that is unset, empty or not a string never matches.
    """
    if not isinstance(expected, str) or not expected:
        return False
    # compare_digest avoids leaking, via response timing, how many leading
    # characters of a guessed key are correct.
    return hmac.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
    """Register a client for broadcast messages until it disconnects.

    When ``Config.REQUIRE_AUTH`` is set, ``token`` (a query parameter) must
    equal ``Config.API_KEY`` or ``Config.ADMIN_API_KEY``. Otherwise the
    connection is closed with policy-violation code 1008 before being
    accepted. Messages sent by the client are read and ignored.
    """
    if Config.REQUIRE_AUTH:
        if not token:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少认证token")
            logger.warning("WebSocket连接被拒绝缺少token")
            return
        is_admin = _ws_token_matches(token, Config.ADMIN_API_KEY)
        if not is_admin and not _ws_token_matches(token, Config.API_KEY):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的token")
            logger.warning(f"WebSocket连接被拒绝无效的token {token[:8]}...")
            return
        permission_level = PermissionLevel.ADMIN if is_admin else PermissionLevel.READ_ONLY
        logger.info(f"WebSocket连接成功权限级别: {permission_level}")
    await websocket.accept()
    async with websockets_lock:
        active_websockets.add(websocket)
    try:
        await websocket.send_json({
            "type": "status",
            "data": {
                "status": "connected",
                "message": "WebSocket连接成功啦~",
                "timestamp": datetime.now().isoformat()
            }
        })
        # Reading keeps the connection open and surfaces the client's
        # disconnect as WebSocketDisconnect; the content is discarded.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        logger.info("WebSocket断开连接")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Runs on every exit path, including cancellation, so the socket is
        # always removed from the broadcast set.
        async with websockets_lock:
            active_websockets.discard(websocket)
@app.get("/api/plugins")
async def get_plugins(_permission: str = optional_auth()):
    """Return ``plugin_manager.get_all_plugin_info()`` under ``data.plugins``."""
    try:
        info = plugin_manager.get_all_plugin_info()
    except Exception as e:
        logger.error(f"获取插件列表失败: {e}")
        return {"code": 500, "message": "获取插件列表失败呢~", "data": None}
    return {"code": 200, "message": "获取插件列表成功啦~", "data": {"plugins": info}}
class PluginToggleRequest(BaseModel):
    """Body of ``PUT /api/plugins/{plugin_id}/toggle``."""

    # True to enable the plugin, False to disable it.
    enabled: bool
@app.put("/api/plugins/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, request: PluginToggleRequest, _permission: str = Depends(require_admin)):
    """Enable or disable a plugin (admin only).

    Returns a ``code: 404`` body when ``plugin_manager.toggle_plugin`` reports failure.
    """
    try:
        toggled = plugin_manager.toggle_plugin(plugin_id, request.enabled)
    except Exception as e:
        logger.error(f"切换插件状态失败: {e}")
        return {"code": 500, "message": "切换插件状态失败呢~", "data": None}
    if not toggled:
        return {"code": 404, "message": "插件不存在呢~", "data": None}
    action = '启用' if request.enabled else '禁用'
    return {
        "code": 200,
        "message": f"插件 {plugin_id}{action}啦~",
        "data": {"plugin_id": plugin_id, "enabled": request.enabled},
    }
@app.post("/api/plugins/{plugin_id}/crawl")
async def crawl_plugin(plugin_id: str, _permission: str = Depends(require_admin)):
    """Run one plugin (admin only) and insert every (ip, port, protocol) it returns.

    The plugin run and the inserts finish before this responds. The task
    manager's broadcast callbacks are also (re)installed as a side effect.
    """
    try:
        async def progress_callback(data):
            await broadcast_message({"type": "progress", "data": data})

        async def status_callback(data):
            await broadcast_message({"type": "status", "data": data})

        tasks_manager.set_callbacks(progress_callback, status_callback)
        db = SQLiteManager()
        found = await plugin_manager.run_plugin(plugin_id)
        for ip, port, protocol in found:
            await db.insert_proxy(ip, port, protocol)
        summary = {"plugin_id": plugin_id, "proxy_count": len(found)}
    except Exception as e:
        logger.error(f"插件爬取失败: {e}")
        return {"code": 500, "message": "插件爬取失败呢~", "data": None}
    return {"code": 200, "message": f"插件 {plugin_id} 开始爬取啦~", "data": summary}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 8923.
    import uvicorn

    uvicorn.run(app, port=8923, host="0.0.0.0")