Files
ToDoList/api/app/main.py

171 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
import os
import time
import json
from app.config import CORS_ORIGINS, WEBUI_PATH, HOST, PORT
from app.database import init_db
from app.routers import api_router
from app.utils.logger import logger
from app.utils.auth import decode_access_token
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application's lifetime.

    Code before ``yield`` runs once at startup (it initialises the database
    via ``init_db()``); code after ``yield`` runs once at shutdown.
    """
    # ---- startup ----
    logger.info("应用启动中...")
    init_db()
    logger.info("数据库初始化完成")

    # The application serves requests while the generator is suspended here.
    yield

    # ---- shutdown ----
    logger.info("应用关闭中...")
# Create the FastAPI application; `lifespan` runs the startup/shutdown hooks.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="爱莉希雅待办事项 API",
    description="Elysia ToDo - 个人信息管理应用",
)

# CORS: allow the configured origins, with credentials, for any method/header.
# NOTE(review): browsers reject a wildcard origin on credentialed requests;
# confirm CORS_ORIGINS is an explicit list of origins rather than ["*"].
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=CORS_ORIGINS,
)
# Request logging middleware.

# Dict keys containing any of these fragments (case-insensitive) are masked in logs.
_SENSITIVE_FIELD_MARKERS = ("password", "passwd", "pwd", "token", "secret")
# Upper bounds on how many characters of a request body are written to the log.
_MAX_JSON_BODY_LOG_CHARS = 1000
_MAX_RAW_BODY_LOG_CHARS = 200


def _redact_sensitive(value):
    """Return a copy of a decoded JSON value with sensitive fields masked.

    Values under dict keys that contain any of ``_SENSITIVE_FIELD_MARKERS``
    (case-insensitive) are replaced by ``"***"``; nested dicts and lists are
    walked recursively. Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        return {
            key: "***"
            if isinstance(key, str)
            and any(marker in key.lower() for marker in _SENSITIVE_FIELD_MARKERS)
            else _redact_sensitive(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact_sensitive(item) for item in value]
    return value


async def _describe_request_body(request: Request):
    """Return a log-safe rendering of the request body, or None if there is nothing to log.

    JSON bodies are re-serialised with sensitive fields masked and truncated to
    ``_MAX_JSON_BODY_LOG_CHARS``; anything else is decoded leniently as UTF-8
    and truncated to ``_MAX_RAW_BODY_LOG_CHARS``.
    """
    # /api/auth/* is the unauthenticated area (it is exempt in auth_middleware)
    # and presumably receives raw credentials in any encoding, so never log it.
    if request.url.path.startswith("/api/auth/"):
        return "<omitted>"
    try:
        # NOTE(review): on older Starlette releases, reading the body inside an
        # @app.middleware("http") middleware could starve the endpoint of it;
        # confirm the installed version caches the body for downstream use.
        raw = await request.body()
    except Exception:
        # Logging is best-effort and must never break the request itself.
        logger.debug("Failed to read request body for logging", exc_info=True)
        return None
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except ValueError:
        # Covers JSONDecodeError and UnicodeDecodeError (json.loads on non-UTF-8
        # bytes). NOTE: non-JSON bodies are not redacted.
        return raw.decode("utf-8", errors="replace")[:_MAX_RAW_BODY_LOG_CHARS]
    rendered = json.dumps(_redact_sensitive(parsed), ensure_ascii=False)
    return rendered[:_MAX_JSON_BODY_LOG_CHARS]


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every request when it arrives and when it completes, with its duration.

    The arrival line includes the query parameters and, for POST/PUT/PATCH, a
    rendering of the body, both with sensitive fields masked. If the downstream
    handler raises, a failure line with the elapsed time is logged and the
    exception is re-raised unchanged.
    """
    # perf_counter is monotonic, so durations are immune to wall-clock changes.
    start = time.perf_counter()
    method = request.method
    path = request.url.path

    log_parts = [f"请求开始 -> {method} {path}"]
    if request.query_params:
        query = _redact_sensitive(dict(request.query_params))
        log_parts.append(f"Query参数: {json.dumps(query, ensure_ascii=False)}")
    if method in ("POST", "PUT", "PATCH"):
        body_info = await _describe_request_body(request)
        if body_info:
            log_parts.append(f"Body: {body_info}")
    logger.info(" | ".join(log_parts))

    try:
        response = await call_next(request)
    except Exception:
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.warning(f"请求失败 <- {method} {path} | 耗时: {elapsed_ms:.2f}ms")
        raise

    elapsed_ms = (time.perf_counter() - start) * 1000
    logger.info(
        f"请求完成 <- {method} {path} | "
        f"状态码: {response.status_code} | 耗时: {elapsed_ms:.2f}ms"
    )
    return response
# Authentication middleware: protects every /api/* route except /api/auth/* and /health.


def _extract_bearer_token(auth_header: str) -> str:
    """Return the token carried by an ``Authorization`` header value, or "" if none.

    Accepts ``Bearer <token>`` with a case-insensitive scheme (RFC 7235). For
    backward compatibility, a value without the Bearer scheme is returned as-is
    (stripped), exactly as the previous ``replace("Bearer ", "")`` allowed.
    """
    header = auth_header.strip()
    scheme, _, credentials = header.partition(" ")
    if scheme.lower() == "bearer":
        return credentials.strip()
    return header


@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Reject unauthenticated requests to protected API routes with HTTP 401.

    The request is passed through untouched when it is a CORS preflight
    (``OPTIONS``), targets ``/health``, lies outside ``/api/``, or lies under
    ``/api/auth/``. Otherwise the ``Authorization`` header must carry a token
    accepted by ``decode_access_token``.

    NOTE(review): this middleware is registered after CORSMiddleware and
    log_requests, so in Starlette it runs outermost. Its 401 responses therefore
    carry no CORS headers and are not request-logged. Consider registering
    CORSMiddleware last.
    """
    path = request.url.path

    # Preflight requests never carry credentials; rejecting them would block
    # every cross-origin API call before the real request is sent.
    if request.method == "OPTIONS":
        return await call_next(request)

    # Unprotected areas: health check, static/SPA files, login/registration.
    if path == "/health" or not path.startswith("/api/") or path.startswith("/api/auth/"):
        return await call_next(request)

    token = _extract_bearer_token(request.headers.get("Authorization", ""))
    if not token:
        return JSONResponse(status_code=401, content={"detail": "未登录"})

    try:
        payload = decode_access_token(token)
    except Exception:
        return JSONResponse(status_code=401, content={"detail": "登录已过期,请重新登录"})
    # NOTE(review): defensive — in case decode_access_token signals an invalid
    # token by returning None instead of raising; confirm its contract.
    if payload is None:
        return JSONResponse(status_code=401, content={"detail": "登录已过期,请重新登录"})

    return await call_next(request)
# Global exception handler: log the full traceback and return a generic 500 payload.
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a generic HTTP 500 JSON body.

    The response body never contains exception details; they only go to the log.
    """
    # Include the exception type: many exceptions stringify poorly on their own
    # (str(KeyError('x')) is just "'x'", some are empty). Passing the instance as
    # exc_info logs its traceback even if no ``except`` block is active here.
    logger.error(
        f"全局异常: {request.method} {request.url.path} | "
        f"错误: {type(exc).__name__}: {exc}",
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content={
            "success": False,
            "message": "服务器内部错误",
            "error_code": "INTERNAL_ERROR",
        },
    )
# Mount all API routers.
app.include_router(api_router)


# Health check. It must be registered before the SPA catch-all route
# "/{full_path:path}" further down; otherwise that route would answer /health.
@app.get("/health")
async def health_check():
    """Report that the service is running."""
    status = {"status": "ok", "message": "服务运行正常"}
    return status
# SPA static file serving with a fallback to index.html (frontend History-mode routing).
if os.path.exists(WEBUI_PATH):
    # Fully resolved WebUI root, used for containment checks.
    _WEBUI_ROOT = os.path.realpath(WEBUI_PATH)
    _INDEX_HTML = os.path.join(WEBUI_PATH, "index.html")

    def _resolve_webui_file(full_path: str):
        """Return the resolved path of an existing file inside the WebUI root, else None.

        ``..`` segments and symlinks are resolved first. The result must then be
        the root itself or lie inside it, compared per path component, so a
        sibling directory such as ``<root>-backup`` is rejected. A plain string
        prefix check would let such a sibling through.
        """
        try:
            candidate = os.path.realpath(os.path.join(_WEBUI_ROOT, full_path))
            inside_root = os.path.commonpath([_WEBUI_ROOT, candidate]) == _WEBUI_ROOT
        except (OSError, ValueError):
            # e.g. an embedded NUL byte, or a different drive on Windows.
            return None
        if inside_root and os.path.isfile(candidate):
            return candidate
        return None

    @app.get("/")
    async def serve_index():
        """Serve the SPA entry page."""
        return FileResponse(_INDEX_HTML)

    @app.get("/{full_path:path}")
    async def spa_fallback(request: Request, full_path: str):
        """Serve a real file under the WebUI root if one exists, else index.html.

        Unknown ``/api/...`` paths get a JSON 404 instead of the SPA page, so API
        clients never receive HTML with status 200 for a mistyped endpoint.
        """
        if full_path == "api" or full_path.startswith("api/"):
            return JSONResponse(status_code=404, content={"detail": "Not Found"})
        file_path = _resolve_webui_file(full_path)
        return FileResponse(file_path or _INDEX_HTML)

    logger.info(f"SPA 静态文件服务已配置: {WEBUI_PATH}")
else:
    logger.warning(f"WebUI 目录不存在: {WEBUI_PATH}")
# Script entry point: serve the app with uvicorn on the configured host and port.
if __name__ == "__main__":
    from uvicorn import run as serve

    logger.info(f"启动服务: {HOST}:{PORT}")
    serve(app, host=HOST, port=PORT)