feat: add data backup/import, goal step ordering, and PostgreSQL migration

- Add GET /api/backup/export and POST /api/backup/import endpoints for full data backup
- Add drag-and-drop reorder for goal steps with PUT /api/goals/{id}/steps/reorder
- Auto-assign sort_order on step creation (preserves creation order)
- Fix duplicate milestone rendering in goal detail page
- Add category management button in goal dialog
- Migrate database default from SQLite to PostgreSQL
- Fix router guard redirect loop for logged-in users on setup/login pages
- Fix ALTER TABLE ADD COLUMN crash on callable defaults (uuid.uuid4)
- Add auth status rate limiter and token version caching
- Update CLAUDE.md to reflect current architecture

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
祀梦
2026-05-18 00:02:18 +08:00
parent 0ab719500b
commit 5048de4fa1
21 changed files with 543 additions and 225 deletions

View File

@@ -9,8 +9,10 @@ _logger = logging.getLogger("app.config")
_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 数据库配置
DATABASE_PATH = os.path.join(_BASE_DIR, "data", "todo.db")
DATABASE_URL = f"sqlite:///{DATABASE_PATH}"
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://ToDoList:53N2PTSjMBPDy6zY@192.168.1.86:5432/ToDoList",
)
# WebUI 配置
WEBUI_PATH = os.path.join(_BASE_DIR, "webui")
@@ -49,3 +51,4 @@ def _load_jwt_secret() -> str:
JWT_SECRET = _load_jwt_secret()
ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24小时
ACCESS_TOKEN_EXPIRE_SECONDS = ACCESS_TOKEN_EXPIRE_MINUTES * 60

View File

@@ -84,7 +84,11 @@ def init_db():
elif col.default is not None:
default_val = col.default.arg
ddl = f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type_str}"
if isinstance(default_val, bool):
if callable(default_val):
# callable 类型的默认值(如 uuid.uuid4无法写入 SQL DEFAULT
# 后续的 UUID 回填逻辑会处理已有记录
pass
elif isinstance(default_val, bool):
ddl += f" DEFAULT {'TRUE' if default_val else 'FALSE'}"
elif isinstance(default_val, str):
ddl += f" DEFAULT '{default_val}'"

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.routers import tasks, categories, tags, user_settings, habits, anniversaries, auth, goals, sync
from app.routers import tasks, categories, tags, user_settings, habits, anniversaries, auth, goals, sync, backup
api_router = APIRouter()
@@ -12,3 +12,4 @@ api_router.include_router(habits.router)
api_router.include_router(anniversaries.router)
api_router.include_router(goals.router)
api_router.include_router(sync.router)
api_router.include_router(backup.router)

View File

@@ -1,16 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
import time
from app.database import get_db
from app.models.user_settings import UserSettings
from app.schemas.auth import LoginRequest, LoginResponse, ChangePasswordRequest, AuthStatusResponse, SetupPasswordRequest, AuthSetupStatusResponse
from app.utils.auth import (
hash_password, verify_password, create_access_token,
get_current_user,
get_current_user, set_cached_token_version,
)
from app.utils.datetime import utcnow
from app.utils.rate_limiter import login_limiter
from app.config import ACCESS_TOKEN_EXPIRE_SECONDS
router = APIRouter(prefix="/api/auth", tags=["认证"])
@@ -32,9 +34,31 @@ def _get_or_create_settings(db: Session) -> UserSettings:
return settings
class _StatusLimiter:
MAX_REQUESTS = 30
WINDOW_SECONDS = 60
def __init__(self):
self._requests: dict[str, list[float]] = {}
def check(self, ip: str) -> bool:
now = time.time()
times = [t for t in self._requests.get(ip, []) if now - t < self.WINDOW_SECONDS]
self._requests[ip] = times
if len(times) >= self.MAX_REQUESTS:
return False
times.append(now)
return True
_status_limiter = _StatusLimiter()
@router.get("/status", response_model=AuthSetupStatusResponse)
def auth_status(db: Session = Depends(get_db)):
def auth_status(request: Request, db: Session = Depends(get_db)):
"""检查系统密码是否已设置"""
ip = _get_client_ip(request)
if not _status_limiter.check(ip):
raise HTTPException(status_code=429, detail="请求过于频繁")
settings = db.query(UserSettings).filter(UserSettings.id == 1).first()
has_password = bool(settings and settings.password_hash)
return AuthSetupStatusResponse(has_password=has_password)
@@ -68,7 +92,7 @@ def setup_password(data: SetupPasswordRequest, db: Session = Depends(get_db)):
value=token,
httponly=True,
samesite="strict",
max_age=86400,
max_age=ACCESS_TOKEN_EXPIRE_SECONDS,
path="/",
)
return response
@@ -111,7 +135,7 @@ def login(data: LoginRequest, request: Request, db: Session = Depends(get_db)):
value=token,
httponly=True,
samesite="strict",
max_age=86400,
max_age=ACCESS_TOKEN_EXPIRE_SECONDS,
path="/",
)
return response
@@ -147,6 +171,7 @@ def change_password(
settings.password_hash = hash_password(data.new_password)
settings.token_version = (settings.token_version or 0) + 1
set_cached_token_version(str(settings.id), settings.token_version)
settings.updated_at = utcnow()
db.commit()

161
api/app/routers/backup.py Normal file
View File

@@ -0,0 +1,161 @@
"""数据备份导入导出路由"""
import json
from datetime import datetime, date
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import text
from io import BytesIO
from app.database import get_db, Base, engine
from app.utils.logger import logger
from app.utils.datetime import utcnow
router = APIRouter(prefix="/api/backup", tags=["备份"])
# 导出顺序:按依赖关系(无 FK 的先导出)
EXPORT_TABLES = [
"categories", "tags", "user_settings", "sync_settings",
"habit_groups", "anniversary_categories",
"goals", "tasks", "habits", "anniversaries",
"goal_steps", "goal_reviews", "habit_checkins",
"task_tags", "goal_tasks",
]
# 导入时的清表顺序:子表先删(避免 FK 约束报错)
TRUNCATE_ORDER = [
"task_tags", "goal_tasks",
"habit_checkins",
"goal_reviews", "goal_steps",
"tasks", "habits", "anniversaries",
"goals", "categories", "tags",
"habit_groups", "anniversary_categories",
"user_settings", "sync_settings",
]
# 导入时的插入顺序:父表先插
INSERT_ORDER = [
"categories", "tags", "user_settings", "sync_settings",
"habit_groups", "anniversary_categories",
"goals", "tasks", "habits", "anniversaries",
"goal_steps", "goal_reviews", "habit_checkins",
"task_tags", "goal_tasks",
]
def _serialize_value(val):
"""将 Python 对象转为 JSON 可序列化的值"""
if val is None:
return None
if isinstance(val, (datetime, date)):
return val.isoformat()
if isinstance(val, bytes):
return val.decode("utf-8", errors="replace")
return val
@router.get("/export")
def export_data(db: Session = Depends(get_db)):
"""导出所有数据为 JSON 备份文件"""
try:
data: dict[str, list[dict]] = {}
for table_name in EXPORT_TABLES:
rows = []
try:
result = db.execute(text(f"SELECT * FROM {table_name}"))
columns = list(result.keys())
for row in result:
rows.append({
col: _serialize_value(getattr(row, col))
for col in columns
})
except Exception:
# 表可能不存在
rows = []
data[table_name] = rows
backup = {
"metadata": {
"version": 1,
"exported_at": utcnow().isoformat(),
},
"data": data,
}
json_bytes = json.dumps(backup, ensure_ascii=False, indent=2).encode("utf-8")
filename = f"elysia-backup-{utcnow().strftime('%Y%m%d-%H%M%S')}.json"
logger.info(f"数据导出成功,共 {sum(len(v) for v in data.values())} 条记录")
return StreamingResponse(
BytesIO(json_bytes),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
except Exception as e:
logger.error(f"导出数据失败: {str(e)}")
raise HTTPException(status_code=500, detail="导出数据失败")
@router.post("/import")
async def import_data(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""导入备份数据(覆盖当前所有数据)"""
if not file.filename or not file.filename.endswith(".json"):
raise HTTPException(status_code=400, detail="请上传 JSON 格式的备份文件")
try:
content = await file.read()
backup = json.loads(content.decode("utf-8"))
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="备份文件格式不正确")
except Exception as e:
logger.error(f"读取备份文件失败: {str(e)}")
raise HTTPException(status_code=400, detail="读取备份文件失败")
payload = backup.get("data")
if not payload:
raise HTTPException(status_code=400, detail="备份文件内容为空")
# 验证必要表存在
for table_name in INSERT_ORDER:
if table_name not in payload:
raise HTTPException(status_code=400, detail=f"备份文件缺少表: {table_name}")
imported_count = 0
try:
# 1. 按序清空所有表
for table_name in TRUNCATE_ORDER:
try:
db.execute(text(f"DELETE FROM {table_name}"))
except Exception:
pass # 表可能不存在
db.flush()
# 2. 按序插入数据
for table_name in INSERT_ORDER:
rows = payload.get(table_name, [])
if not rows:
continue
columns = list(rows[0].keys())
col_str = ", ".join(columns)
placeholders = ", ".join([f":{c}" for c in columns])
for row_data in rows:
db.execute(
text(f"INSERT INTO {table_name} ({col_str}) VALUES ({placeholders})"),
{c: row_data[c] for c in columns},
)
imported_count += 1
db.commit()
logger.info(f"数据导入成功,共 {imported_count} 条记录")
return {"message": "数据导入成功", "count": imported_count}
except Exception as e:
db.rollback()
logger.error(f"导入数据失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"导入数据失败: {str(e)}")

View File

@@ -9,6 +9,7 @@ from app.schemas.goal import (
GoalCreate, GoalUpdate, GoalListResponse, GoalDetailResponse, GoalStatusUpdate,
GoalStepCreate, GoalStepUpdate, GoalStepResponse,
GoalReviewCreate, GoalReviewResponse,
ReorderRequest,
)
from app.schemas.common import DeleteResponse
from app.utils.crud import get_or_404
@@ -35,10 +36,10 @@ def recalc_progress(db: Session, goal_id: int):
def build_step_tree(steps: list[GoalStep]) -> list[dict]:
"""将扁平的 step 列表转为树形结构phase 包含子 milestone"""
"""将扁平的 step 列表转为树形结构phase 包含子 milestone,按 sort_order 排序"""
step_map = {}
roots = []
for s in steps:
for s in sorted(steps, key=lambda x: (x.sort_order or 0)):
step_map[s.id] = {
"id": s.id,
"goal_id": s.goal_id,
@@ -52,7 +53,7 @@ def build_step_tree(steps: list[GoalStep]) -> list[dict]:
"created_at": s.created_at,
"children": [],
}
for s in steps:
for s in sorted(steps, key=lambda x: (x.sort_order or 0)):
node = step_map[s.id]
if s.parent_id and s.parent_id in step_map:
step_map[s.parent_id]["children"].append(node)
@@ -206,7 +207,13 @@ def create_step(goal_id: int, data: GoalStepCreate, db: Session = Depends(get_db
"""添加阶段/里程碑"""
try:
get_or_404(db, Goal, goal_id, "目标")
step = GoalStep(goal_id=goal_id, **data.model_dump())
# 自动分配 sort_order同类步骤中取最大值 + 1
max_sort = db.query(GoalStep).filter(
GoalStep.goal_id == goal_id,
GoalStep.step_type == data.step_type,
).order_by(GoalStep.sort_order.desc()).first()
next_sort = (max_sort.sort_order + 1) if max_sort and max_sort.sort_order is not None else 0
step = GoalStep(goal_id=goal_id, sort_order=next_sort, **data.model_dump())
db.add(step)
db.commit()
db.refresh(step)
@@ -228,6 +235,31 @@ def create_step(goal_id: int, data: GoalStepCreate, db: Session = Depends(get_db
raise HTTPException(status_code=500, detail="添加步骤失败")
# ============ Reorder ============
@router.put("/{goal_id}/steps/reorder")
def reorder_steps(goal_id: int, data: ReorderRequest, db: Session = Depends(get_db)):
"""批量更新步骤排序"""
try:
get_or_404(db, Goal, goal_id, "目标")
for item in data.items:
step = db.query(GoalStep).filter(
GoalStep.id == item.id,
GoalStep.goal_id == goal_id,
).first()
if step:
step.sort_order = item.sort_order
db.commit()
logger.info(f"步骤排序更新成功: goal_id={goal_id}, count={len(data.items)}")
return {"message": "排序更新成功"}
except HTTPException:
raise
except Exception as e:
db.rollback()
logger.error(f"更新步骤排序失败: {str(e)}")
raise HTTPException(status_code=500, detail="更新排序失败")
@router.put("/{goal_id}/steps/{step_id}", response_model=GoalStepResponse)
def update_step(goal_id: int, step_id: int, data: GoalStepUpdate, db: Session = Depends(get_db)):
"""更新阶段/里程碑"""

14
api/app/schemas/backup.py Normal file
View File

@@ -0,0 +1,14 @@
"""数据备份导入导出 Schema"""
from pydantic import BaseModel
from datetime import datetime, date
from typing import Optional, Any
class BackupMetadata(BaseModel):
version: int = 1
exported_at: datetime
class BackupPayload(BaseModel):
metadata: BackupMetadata
data: dict[str, list[dict[str, Any]]]

View File

@@ -120,3 +120,14 @@ class GoalDetailResponse(GoalListResponse):
class GoalStatusUpdate(BaseModel):
status: str = Field(..., pattern="^(active|paused|completed|abandoned)$")
# ============ Reorder Schema ============
class ReorderItem(BaseModel):
id: int
sort_order: int
class ReorderRequest(BaseModel):
items: list[ReorderItem]

View File

@@ -9,6 +9,16 @@ from app.config import JWT_SECRET, ACCESS_TOKEN_EXPIRE_MINUTES
ALGORITHM = "HS256"
_token_version_cache: dict[str, int] = {}
def get_cached_token_version(user_id: str) -> int | None:
return _token_version_cache.get(user_id)
def set_cached_token_version(user_id: str, version: int):
_token_version_cache[user_id] = version
def hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
@@ -30,6 +40,8 @@ def decode_access_token(token: str) -> dict:
def get_current_user(request: Request) -> dict:
if hasattr(request.state, "user") and request.state.user:
return request.state.user
token = request.cookies.get("access_token", "")
if not token:
raise HTTPException(status_code=401, detail="未登录")