first commit
This commit is contained in:
553
api_server.py
Normal file
553
api_server.py
Normal file
@@ -0,0 +1,553 @@
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, Header, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from pydantic import BaseModel, Field, field_validator, ValidationError
|
||||
from typing import Optional, List
|
||||
import asyncio
|
||||
import io
|
||||
import csv
|
||||
import json
|
||||
from datetime import datetime
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from core.sqlite import SQLiteManager
|
||||
from core.validator import ProxyValidator
|
||||
from core.plugin_manager import PluginManager
|
||||
from tasks_manager import TasksManager, ScheduledTasks
|
||||
from core.log import logger
|
||||
from config import Config
|
||||
from core.auth import verify_api_key, require_admin, PermissionLevel
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
db = SQLiteManager()
|
||||
await db.init_db()
|
||||
logger.info("API服务器启动啦~")
|
||||
yield
|
||||
logger.info("API服务器关闭啦~")
|
||||
|
||||
app = FastAPI(title="代理池API", version="1.1.0", lifespan=lifespan)
|
||||
|
||||
def format_datetime(datetime_str: str) -> str:
|
||||
"""将数据库时间格式统一转换为ISO 8601格式"""
|
||||
if not datetime_str:
|
||||
return None
|
||||
|
||||
if isinstance(datetime_str, str):
|
||||
if 'T' in datetime_str:
|
||||
return datetime_str
|
||||
|
||||
if re.match(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}', datetime_str):
|
||||
return datetime_str.replace(' ', 'T') + '.000Z'
|
||||
|
||||
return datetime_str
|
||||
|
||||
@app.exception_handler(ValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: ValidationError):
|
||||
logger.error(f"参数验证失败: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={"code": 422, "message": "参数验证失败呢~", "data": exc.errors()}
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
logger.error(f"HTTP异常: {exc.status_code} - {exc.detail}")
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"code": exc.status_code, "message": exc.detail, "data": None}
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
logger.error(f"未处理的异常: {exc}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"code": 500, "message": "服务器内部错误呢~", "data": None}
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
tasks_manager = TasksManager()
|
||||
scheduled_tasks = ScheduledTasks(tasks_manager)
|
||||
plugin_manager = PluginManager()
|
||||
active_websockets = set()
|
||||
websockets_lock = asyncio.Lock()
|
||||
|
||||
def optional_auth():
|
||||
if Config.REQUIRE_AUTH:
|
||||
return Depends(verify_api_key)
|
||||
return None
|
||||
|
||||
async def broadcast_message(message: dict):
|
||||
"""向所有WebSocket客户端广播消息"""
|
||||
async with websockets_lock:
|
||||
websockets_to_remove = []
|
||||
tasks = []
|
||||
|
||||
for ws in active_websockets:
|
||||
try:
|
||||
tasks.append(ws.send_json(message))
|
||||
except Exception as e:
|
||||
logger.error(f"发送WebSocket消息失败: {e}")
|
||||
websockets_to_remove.append(ws)
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for ws in websockets_to_remove:
|
||||
active_websockets.discard(ws)
|
||||
|
||||
class ProxyRequest(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="页码,必须大于等于1")
|
||||
page_size: int = Field(default=20, ge=1, le=100, description="每页数量,必须在1-100之间")
|
||||
protocol: Optional[str] = None
|
||||
min_score: int = Field(default=0, ge=0, description="最低分数")
|
||||
max_score: Optional[int] = Field(default=None, ge=0, description="最高分数")
|
||||
sort_by: str = 'last_check'
|
||||
sort_order: str = 'DESC'
|
||||
|
||||
@field_validator('protocol')
|
||||
@classmethod
|
||||
def validate_protocol(cls, v):
|
||||
if v is not None and v.lower() not in ['http', 'https', 'socks4', 'socks5']:
|
||||
raise ValueError('协议类型必须是 http, https, socks4 或 socks5')
|
||||
return v.lower() if v else v
|
||||
|
||||
@field_validator('sort_by')
|
||||
@classmethod
|
||||
def validate_sort_by(cls, v):
|
||||
if v not in ['ip', 'port', 'protocol', 'score', 'last_check']:
|
||||
raise ValueError('排序字段必须是 ip, port, protocol, score 或 last_check')
|
||||
return v
|
||||
|
||||
@field_validator('sort_order')
|
||||
@classmethod
|
||||
def validate_sort_order(cls, v):
|
||||
if v.upper() not in ['ASC', 'DESC']:
|
||||
raise ValueError('排序方式必须是 ASC 或 DESC')
|
||||
return v.upper()
|
||||
|
||||
class ProxyDeleteItem(BaseModel):
|
||||
ip: str
|
||||
port: int
|
||||
|
||||
@field_validator('port')
|
||||
@classmethod
|
||||
def validate_port(cls, v):
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError('端口号必须在1-65535范围内')
|
||||
return v
|
||||
|
||||
class DeleteProxiesRequest(BaseModel):
|
||||
proxies: List[ProxyDeleteItem]
|
||||
|
||||
@field_validator('proxies')
|
||||
@classmethod
|
||||
def validate_proxies_count(cls, v):
|
||||
if len(v) > 1000:
|
||||
raise ValueError('单次最多删除1000个代理')
|
||||
return v
|
||||
|
||||
class CrawlerRequest(BaseModel):
|
||||
num_validators: int = 50
|
||||
|
||||
class ScheduleRequest(BaseModel):
|
||||
enabled: bool
|
||||
interval_minutes: int = 60
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "欢迎使用代理池API~", "status": "running", "data": None}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
try:
|
||||
db = SQLiteManager()
|
||||
await db.count_proxies()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"database": "connected",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查失败: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"database": "disconnected",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
@app.get("/api/stats")
|
||||
async def get_stats(_permission: str = optional_auth()):
|
||||
try:
|
||||
db = SQLiteManager()
|
||||
stats = await db.get_stats()
|
||||
today_new = await db.get_today_new_count()
|
||||
stats['today_new'] = today_new
|
||||
return {"code": 200, "message": "获取统计信息成功啦~", "data": stats}
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计信息失败: {e}")
|
||||
return {"code": 500, "message": "获取统计信息失败呢~", "data": None}
|
||||
|
||||
@app.post("/api/proxies")
|
||||
async def get_proxies(request: ProxyRequest, _permission: str = optional_auth()):
|
||||
try:
|
||||
db = SQLiteManager()
|
||||
proxies = await db.get_proxies_paginated(
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
protocol=request.protocol,
|
||||
min_score=request.min_score,
|
||||
max_score=request.max_score,
|
||||
sort_by=request.sort_by,
|
||||
sort_order=request.sort_order
|
||||
)
|
||||
total = await db.get_proxies_total(
|
||||
protocol=request.protocol,
|
||||
min_score=request.min_score,
|
||||
max_score=request.max_score
|
||||
)
|
||||
|
||||
proxy_list = []
|
||||
for proxy in proxies:
|
||||
proxy_list.append({
|
||||
"ip": proxy[0],
|
||||
"port": proxy[1],
|
||||
"protocol": proxy[2],
|
||||
"score": proxy[3],
|
||||
"last_check": format_datetime(proxy[4])
|
||||
})
|
||||
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "获取代理列表成功啦~",
|
||||
"data": {
|
||||
"list": proxy_list,
|
||||
"total": total,
|
||||
"page": request.page,
|
||||
"page_size": request.page_size
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取代理列表失败: {e}")
|
||||
return {"code": 500, "message": "获取代理列表失败呢~", "data": None}
|
||||
|
||||
@app.get("/api/proxies/random")
|
||||
async def get_random_proxy(_permission: str = optional_auth()):
|
||||
db = SQLiteManager()
|
||||
proxy = await db.get_random_proxy()
|
||||
if proxy:
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "获取随机代理成功啦~",
|
||||
"data": {
|
||||
"ip": proxy[0],
|
||||
"port": proxy[1],
|
||||
"protocol": proxy[2],
|
||||
"score": proxy[3],
|
||||
"last_check": format_datetime(proxy[4])
|
||||
}
|
||||
}
|
||||
return {"code": 404, "message": "没有找到可用的代理呢~", "data": None}
|
||||
|
||||
@app.get("/api/proxies/{ip}/{port}")
|
||||
async def get_proxy_detail(ip: str, port: int, _permission: str = optional_auth()):
|
||||
db = SQLiteManager()
|
||||
proxy = await db.get_proxy_detail(ip, port)
|
||||
if proxy:
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "获取代理详情成功啦~",
|
||||
"data": {
|
||||
"ip": proxy[0],
|
||||
"port": proxy[1],
|
||||
"protocol": proxy[2],
|
||||
"score": proxy[3],
|
||||
"last_check": format_datetime(proxy[4])
|
||||
}
|
||||
}
|
||||
return {"code": 404, "message": "代理不存在呢~", "data": None}
|
||||
|
||||
@app.delete("/api/proxies/{ip}/{port}")
|
||||
async def delete_proxy(ip: str, port: int, _permission: str = Depends(require_admin)):
|
||||
db = SQLiteManager()
|
||||
await db.delete_proxy(ip, port)
|
||||
return {"code": 200, "message": "删除代理成功啦~", "data": None}
|
||||
|
||||
@app.post("/api/proxies/batch-delete")
|
||||
async def batch_delete_proxies(request: DeleteProxiesRequest, _permission: str = Depends(require_admin)):
|
||||
db = SQLiteManager()
|
||||
proxy_tuples = [(item.ip, item.port) for item in request.proxies]
|
||||
deleted_count = await db.batch_delete_proxies(proxy_tuples)
|
||||
return {"code": 200, "message": f"批量删除 {deleted_count} 个代理成功啦~", "data": {"deleted_count": deleted_count}}
|
||||
|
||||
@app.delete("/api/proxies/clean-invalid")
|
||||
async def clean_invalid_proxies(_permission: str = Depends(require_admin)):
|
||||
db = SQLiteManager()
|
||||
deleted_count = await db.clean_invalid_proxies()
|
||||
return {"code": 200, "message": f"清理了 {deleted_count} 个无效代理啦~", "data": {"deleted_count": deleted_count}}
|
||||
|
||||
@app.get("/api/proxies/export/{format}")
|
||||
async def export_proxies(format: str, protocol: Optional[str] = None, _permission: str = optional_auth(), limit: int = 10000):
|
||||
try:
|
||||
db = SQLiteManager()
|
||||
|
||||
if format not in ['csv', 'txt', 'json']:
|
||||
raise HTTPException(status_code=400, detail="不支持的导出格式呢~")
|
||||
|
||||
if limit > 100000:
|
||||
raise HTTPException(status_code=400, detail="导出数量不能超过100000条呢~")
|
||||
|
||||
async def generate_csv():
|
||||
proxies = await db.get_all_proxies()
|
||||
if protocol:
|
||||
proxies = [p for p in proxies if p[2].lower() == protocol.lower()]
|
||||
|
||||
proxies = proxies[:limit]
|
||||
|
||||
output = []
|
||||
output.append('IP,Port,Protocol,Score,Last Check')
|
||||
for proxy in proxies:
|
||||
output.append(f"{proxy[0]},{proxy[1]},{proxy[2]},{proxy[3]},{format_datetime(proxy[4])}")
|
||||
|
||||
for line in output:
|
||||
yield line + '\n'
|
||||
|
||||
async def generate_txt():
|
||||
proxies = await db.get_all_proxies()
|
||||
if protocol:
|
||||
proxies = [p for p in proxies if p[2].lower() == protocol.lower()]
|
||||
|
||||
proxies = proxies[:limit]
|
||||
|
||||
for proxy in proxies:
|
||||
yield f"{proxy[0]}:{proxy[1]}\n"
|
||||
|
||||
async def generate_json():
|
||||
proxies = await db.get_all_proxies()
|
||||
if protocol:
|
||||
proxies = [p for p in proxies if p[2].lower() == protocol.lower()]
|
||||
|
||||
proxies = proxies[:limit]
|
||||
|
||||
proxy_list = []
|
||||
for proxy in proxies:
|
||||
proxy_list.append({'ip': proxy[0], 'port': proxy[1], 'protocol': proxy[2], 'score': proxy[3], 'last_check': format_datetime(proxy[4])})
|
||||
|
||||
yield '[\n'
|
||||
for i, item in enumerate(proxy_list):
|
||||
if i > 0:
|
||||
yield ',\n'
|
||||
yield json.dumps(item, ensure_ascii=False, indent=2)
|
||||
yield '\n]'
|
||||
|
||||
if format == 'csv':
|
||||
return StreamingResponse(
|
||||
generate_csv(),
|
||||
media_type='text/csv',
|
||||
headers={'Content-Disposition': 'attachment; filename=proxies.csv'}
|
||||
)
|
||||
|
||||
elif format == 'txt':
|
||||
return StreamingResponse(
|
||||
generate_txt(),
|
||||
media_type='text/plain',
|
||||
headers={'Content-Disposition': 'attachment; filename=proxies.txt'}
|
||||
)
|
||||
|
||||
elif format == 'json':
|
||||
return StreamingResponse(
|
||||
generate_json(),
|
||||
media_type='application/json',
|
||||
headers={'Content-Disposition': 'attachment; filename=proxies.json'}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"导出代理失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="导出代理失败呢~")
|
||||
|
||||
@app.post("/api/crawler/start")
|
||||
async def start_crawler(request: CrawlerRequest, _permission: str = Depends(require_admin)):
|
||||
try:
|
||||
if tasks_manager.is_task_running():
|
||||
return {"code": 400, "message": "任务正在运行中呢~"}
|
||||
|
||||
async def progress_callback(data):
|
||||
await broadcast_message({"type": "progress", "data": data})
|
||||
|
||||
async def status_callback(data):
|
||||
await broadcast_message({"type": "status", "data": data})
|
||||
|
||||
tasks_manager.set_callbacks(progress_callback, status_callback)
|
||||
|
||||
db = SQLiteManager()
|
||||
async with ProxyValidator(max_concurrency=200) as validator:
|
||||
asyncio.create_task(tasks_manager.start_task(db, validator, request.num_validators))
|
||||
|
||||
return {"code": 200, "message": "爬虫任务开始啦~", "data": None}
|
||||
except Exception as e:
|
||||
logger.error(f"启动爬虫失败: {e}")
|
||||
return {"code": 500, "message": "启动爬虫失败呢~", "data": None}
|
||||
|
||||
@app.post("/api/crawler/stop")
|
||||
async def stop_crawler(_permission: str = Depends(require_admin)):
|
||||
if not tasks_manager.is_task_running():
|
||||
return {"code": 400, "message": "没有运行中的任务呢~", "data": None}
|
||||
|
||||
await tasks_manager.stop_task()
|
||||
return {"code": 200, "message": "爬虫任务停止啦~", "data": None}
|
||||
|
||||
@app.get("/api/crawler/status")
|
||||
async def get_crawler_status(_permission: str = optional_auth()):
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "获取爬虫状态成功啦~",
|
||||
"data": {
|
||||
"running": tasks_manager.is_task_running(),
|
||||
"stats": tasks_manager.get_stats()
|
||||
}
|
||||
}
|
||||
|
||||
@app.post("/api/scheduler")
|
||||
async def set_scheduler(request: ScheduleRequest, _permission: str = Depends(require_admin)):
|
||||
if request.enabled:
|
||||
scheduled_tasks.start_scheduled(request.interval_minutes)
|
||||
return {"code": 200, "message": f"定时任务已启动,间隔 {request.interval_minutes} 分钟~", "data": None}
|
||||
else:
|
||||
scheduled_tasks.stop_scheduled()
|
||||
return {"code": 200, "message": "定时任务已停止~", "data": None}
|
||||
|
||||
@app.get("/api/scheduler")
|
||||
async def get_scheduler_status(_permission: str = optional_auth()):
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "获取定时任务状态成功啦~",
|
||||
"data": {
|
||||
"enabled": scheduled_tasks.is_scheduled,
|
||||
"interval_minutes": scheduled_tasks.interval_minutes
|
||||
}
|
||||
}
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
|
||||
if Config.REQUIRE_AUTH:
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少认证token")
|
||||
logger.warning("WebSocket连接被拒绝:缺少token")
|
||||
return
|
||||
|
||||
if token != Config.API_KEY and token != Config.ADMIN_API_KEY:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的token")
|
||||
logger.warning(f"WebSocket连接被拒绝:无效的token {token[:8]}...")
|
||||
return
|
||||
|
||||
permission_level = PermissionLevel.ADMIN if token == Config.ADMIN_API_KEY else PermissionLevel.READ_ONLY
|
||||
logger.info(f"WebSocket连接成功,权限级别: {permission_level}")
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
async with websockets_lock:
|
||||
active_websockets.add(websocket)
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"data": {
|
||||
"status": "connected",
|
||||
"message": "WebSocket连接成功啦~",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
})
|
||||
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
async with websockets_lock:
|
||||
active_websockets.discard(websocket)
|
||||
logger.info("WebSocket断开连接")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket错误: {e}")
|
||||
async with websockets_lock:
|
||||
active_websockets.discard(websocket)
|
||||
|
||||
@app.get("/api/plugins")
|
||||
async def get_plugins(_permission: str = optional_auth()):
|
||||
try:
|
||||
plugins_info = plugin_manager.get_all_plugin_info()
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "获取插件列表成功啦~",
|
||||
"data": {
|
||||
"plugins": plugins_info
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件列表失败: {e}")
|
||||
return {"code": 500, "message": "获取插件列表失败呢~", "data": None}
|
||||
|
||||
class PluginToggleRequest(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
@app.put("/api/plugins/{plugin_id}/toggle")
|
||||
async def toggle_plugin(plugin_id: str, request: PluginToggleRequest, _permission: str = Depends(require_admin)):
|
||||
try:
|
||||
success = plugin_manager.toggle_plugin(plugin_id, request.enabled)
|
||||
if success:
|
||||
return {
|
||||
"code": 200,
|
||||
"message": f"插件 {plugin_id} 已{'启用' if request.enabled else '禁用'}啦~",
|
||||
"data": {
|
||||
"plugin_id": plugin_id,
|
||||
"enabled": request.enabled
|
||||
}
|
||||
}
|
||||
else:
|
||||
return {"code": 404, "message": "插件不存在呢~", "data": None}
|
||||
except Exception as e:
|
||||
logger.error(f"切换插件状态失败: {e}")
|
||||
return {"code": 500, "message": "切换插件状态失败呢~", "data": None}
|
||||
|
||||
@app.post("/api/plugins/{plugin_id}/crawl")
|
||||
async def crawl_plugin(plugin_id: str, _permission: str = Depends(require_admin)):
|
||||
try:
|
||||
async def progress_callback(data):
|
||||
await broadcast_message({"type": "progress", "data": data})
|
||||
|
||||
async def status_callback(data):
|
||||
await broadcast_message({"type": "status", "data": data})
|
||||
|
||||
tasks_manager.set_callbacks(progress_callback, status_callback)
|
||||
|
||||
db = SQLiteManager()
|
||||
results = await plugin_manager.run_plugin(plugin_id)
|
||||
|
||||
for ip, port, protocol in results:
|
||||
await db.insert_proxy(ip, port, protocol)
|
||||
|
||||
return {
|
||||
"code": 200,
|
||||
"message": f"插件 {plugin_id} 开始爬取啦~",
|
||||
"data": {
|
||||
"plugin_id": plugin_id,
|
||||
"proxy_count": len(results)
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"插件爬取失败: {e}")
|
||||
return {"code": 500, "message": "插件爬取失败呢~", "data": None}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8923)
|
||||
Reference in New Issue
Block a user