feat: add data backup/import, goal step ordering, and PostgreSQL migration
- Add GET /api/backup/export and POST /api/backup/import endpoints for full data backup
- Add drag-and-drop reorder for goal steps with PUT /api/goals/{id}/steps/reorder
- Auto-assign sort_order on step creation (preserves creation order)
- Fix duplicate milestone rendering in goal detail page
- Add category management button in goal dialog
- Migrate database default from SQLite to PostgreSQL
- Fix router guard redirect loop for logged-in users on setup/login pages
- Fix ALTER TABLE ADD COLUMN crash on callable defaults (uuid.uuid4)
- Add auth status rate limiter and token version caching
- Update CLAUDE.md to reflect current architecture
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.routers import tasks, categories, tags, user_settings, habits, anniversaries, auth, goals, sync
|
||||
from app.routers import tasks, categories, tags, user_settings, habits, anniversaries, auth, goals, sync, backup
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -12,3 +12,4 @@ api_router.include_router(habits.router)
|
||||
api_router.include_router(anniversaries.router)
|
||||
api_router.include_router(goals.router)
|
||||
api_router.include_router(sync.router)
|
||||
api_router.include_router(backup.router)
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
import time
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.user_settings import UserSettings
|
||||
from app.schemas.auth import LoginRequest, LoginResponse, ChangePasswordRequest, AuthStatusResponse, SetupPasswordRequest, AuthSetupStatusResponse
|
||||
from app.utils.auth import (
|
||||
hash_password, verify_password, create_access_token,
|
||||
get_current_user,
|
||||
get_current_user, set_cached_token_version,
|
||||
)
|
||||
from app.utils.datetime import utcnow
|
||||
from app.utils.rate_limiter import login_limiter
|
||||
from app.config import ACCESS_TOKEN_EXPIRE_SECONDS
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||||
|
||||
@@ -32,9 +34,31 @@ def _get_or_create_settings(db: Session) -> UserSettings:
|
||||
return settings
|
||||
|
||||
|
||||
class _StatusLimiter:
|
||||
MAX_REQUESTS = 30
|
||||
WINDOW_SECONDS = 60
|
||||
|
||||
def __init__(self):
|
||||
self._requests: dict[str, list[float]] = {}
|
||||
|
||||
def check(self, ip: str) -> bool:
|
||||
now = time.time()
|
||||
times = [t for t in self._requests.get(ip, []) if now - t < self.WINDOW_SECONDS]
|
||||
self._requests[ip] = times
|
||||
if len(times) >= self.MAX_REQUESTS:
|
||||
return False
|
||||
times.append(now)
|
||||
return True
|
||||
|
||||
_status_limiter = _StatusLimiter()
|
||||
|
||||
|
||||
@router.get("/status", response_model=AuthSetupStatusResponse)
|
||||
def auth_status(db: Session = Depends(get_db)):
|
||||
def auth_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""检查系统密码是否已设置"""
|
||||
ip = _get_client_ip(request)
|
||||
if not _status_limiter.check(ip):
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁")
|
||||
settings = db.query(UserSettings).filter(UserSettings.id == 1).first()
|
||||
has_password = bool(settings and settings.password_hash)
|
||||
return AuthSetupStatusResponse(has_password=has_password)
|
||||
@@ -68,7 +92,7 @@ def setup_password(data: SetupPasswordRequest, db: Session = Depends(get_db)):
|
||||
value=token,
|
||||
httponly=True,
|
||||
samesite="strict",
|
||||
max_age=86400,
|
||||
max_age=ACCESS_TOKEN_EXPIRE_SECONDS,
|
||||
path="/",
|
||||
)
|
||||
return response
|
||||
@@ -111,7 +135,7 @@ def login(data: LoginRequest, request: Request, db: Session = Depends(get_db)):
|
||||
value=token,
|
||||
httponly=True,
|
||||
samesite="strict",
|
||||
max_age=86400,
|
||||
max_age=ACCESS_TOKEN_EXPIRE_SECONDS,
|
||||
path="/",
|
||||
)
|
||||
return response
|
||||
@@ -147,6 +171,7 @@ def change_password(
|
||||
|
||||
settings.password_hash = hash_password(data.new_password)
|
||||
settings.token_version = (settings.token_version or 0) + 1
|
||||
set_cached_token_version(str(settings.id), settings.token_version)
|
||||
settings.updated_at = utcnow()
|
||||
db.commit()
|
||||
|
||||
|
||||
161
api/app/routers/backup.py
Normal file
161
api/app/routers/backup.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""数据备份导入导出路由"""
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from io import BytesIO
|
||||
|
||||
from app.database import get_db, Base, engine
|
||||
from app.utils.logger import logger
|
||||
from app.utils.datetime import utcnow
|
||||
|
||||
router = APIRouter(prefix="/api/backup", tags=["备份"])
|
||||
|
||||
# 导出顺序:按依赖关系(无 FK 的先导出)
|
||||
EXPORT_TABLES = [
|
||||
"categories", "tags", "user_settings", "sync_settings",
|
||||
"habit_groups", "anniversary_categories",
|
||||
"goals", "tasks", "habits", "anniversaries",
|
||||
"goal_steps", "goal_reviews", "habit_checkins",
|
||||
"task_tags", "goal_tasks",
|
||||
]
|
||||
|
||||
# 导入时的清表顺序:子表先删(避免 FK 约束报错)
|
||||
TRUNCATE_ORDER = [
|
||||
"task_tags", "goal_tasks",
|
||||
"habit_checkins",
|
||||
"goal_reviews", "goal_steps",
|
||||
"tasks", "habits", "anniversaries",
|
||||
"goals", "categories", "tags",
|
||||
"habit_groups", "anniversary_categories",
|
||||
"user_settings", "sync_settings",
|
||||
]
|
||||
|
||||
# 导入时的插入顺序:父表先插
|
||||
INSERT_ORDER = [
|
||||
"categories", "tags", "user_settings", "sync_settings",
|
||||
"habit_groups", "anniversary_categories",
|
||||
"goals", "tasks", "habits", "anniversaries",
|
||||
"goal_steps", "goal_reviews", "habit_checkins",
|
||||
"task_tags", "goal_tasks",
|
||||
]
|
||||
|
||||
|
||||
def _serialize_value(val):
|
||||
"""将 Python 对象转为 JSON 可序列化的值"""
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (datetime, date)):
|
||||
return val.isoformat()
|
||||
if isinstance(val, bytes):
|
||||
return val.decode("utf-8", errors="replace")
|
||||
return val
|
||||
|
||||
|
||||
@router.get("/export")
|
||||
def export_data(db: Session = Depends(get_db)):
|
||||
"""导出所有数据为 JSON 备份文件"""
|
||||
try:
|
||||
data: dict[str, list[dict]] = {}
|
||||
|
||||
for table_name in EXPORT_TABLES:
|
||||
rows = []
|
||||
try:
|
||||
result = db.execute(text(f"SELECT * FROM {table_name}"))
|
||||
columns = list(result.keys())
|
||||
for row in result:
|
||||
rows.append({
|
||||
col: _serialize_value(getattr(row, col))
|
||||
for col in columns
|
||||
})
|
||||
except Exception:
|
||||
# 表可能不存在
|
||||
rows = []
|
||||
data[table_name] = rows
|
||||
|
||||
backup = {
|
||||
"metadata": {
|
||||
"version": 1,
|
||||
"exported_at": utcnow().isoformat(),
|
||||
},
|
||||
"data": data,
|
||||
}
|
||||
|
||||
json_bytes = json.dumps(backup, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"elysia-backup-{utcnow().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
|
||||
logger.info(f"数据导出成功,共 {sum(len(v) for v in data.values())} 条记录")
|
||||
return StreamingResponse(
|
||||
BytesIO(json_bytes),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"导出数据失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="导出数据失败")
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
async def import_data(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""导入备份数据(覆盖当前所有数据)"""
|
||||
if not file.filename or not file.filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="请上传 JSON 格式的备份文件")
|
||||
|
||||
try:
|
||||
content = await file.read()
|
||||
backup = json.loads(content.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="备份文件格式不正确")
|
||||
except Exception as e:
|
||||
logger.error(f"读取备份文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail="读取备份文件失败")
|
||||
|
||||
payload = backup.get("data")
|
||||
if not payload:
|
||||
raise HTTPException(status_code=400, detail="备份文件内容为空")
|
||||
|
||||
# 验证必要表存在
|
||||
for table_name in INSERT_ORDER:
|
||||
if table_name not in payload:
|
||||
raise HTTPException(status_code=400, detail=f"备份文件缺少表: {table_name}")
|
||||
|
||||
imported_count = 0
|
||||
try:
|
||||
# 1. 按序清空所有表
|
||||
for table_name in TRUNCATE_ORDER:
|
||||
try:
|
||||
db.execute(text(f"DELETE FROM {table_name}"))
|
||||
except Exception:
|
||||
pass # 表可能不存在
|
||||
db.flush()
|
||||
|
||||
# 2. 按序插入数据
|
||||
for table_name in INSERT_ORDER:
|
||||
rows = payload.get(table_name, [])
|
||||
if not rows:
|
||||
continue
|
||||
columns = list(rows[0].keys())
|
||||
col_str = ", ".join(columns)
|
||||
placeholders = ", ".join([f":{c}" for c in columns])
|
||||
|
||||
for row_data in rows:
|
||||
db.execute(
|
||||
text(f"INSERT INTO {table_name} ({col_str}) VALUES ({placeholders})"),
|
||||
{c: row_data[c] for c in columns},
|
||||
)
|
||||
imported_count += 1
|
||||
|
||||
db.commit()
|
||||
logger.info(f"数据导入成功,共 {imported_count} 条记录")
|
||||
return {"message": "数据导入成功", "count": imported_count}
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"导入数据失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"导入数据失败: {str(e)}")
|
||||
@@ -9,6 +9,7 @@ from app.schemas.goal import (
|
||||
GoalCreate, GoalUpdate, GoalListResponse, GoalDetailResponse, GoalStatusUpdate,
|
||||
GoalStepCreate, GoalStepUpdate, GoalStepResponse,
|
||||
GoalReviewCreate, GoalReviewResponse,
|
||||
ReorderRequest,
|
||||
)
|
||||
from app.schemas.common import DeleteResponse
|
||||
from app.utils.crud import get_or_404
|
||||
@@ -35,10 +36,10 @@ def recalc_progress(db: Session, goal_id: int):
|
||||
|
||||
|
||||
def build_step_tree(steps: list[GoalStep]) -> list[dict]:
|
||||
"""将扁平的 step 列表转为树形结构(phase 包含子 milestone)"""
|
||||
"""将扁平的 step 列表转为树形结构(phase 包含子 milestone),按 sort_order 排序"""
|
||||
step_map = {}
|
||||
roots = []
|
||||
for s in steps:
|
||||
for s in sorted(steps, key=lambda x: (x.sort_order or 0)):
|
||||
step_map[s.id] = {
|
||||
"id": s.id,
|
||||
"goal_id": s.goal_id,
|
||||
@@ -52,7 +53,7 @@ def build_step_tree(steps: list[GoalStep]) -> list[dict]:
|
||||
"created_at": s.created_at,
|
||||
"children": [],
|
||||
}
|
||||
for s in steps:
|
||||
for s in sorted(steps, key=lambda x: (x.sort_order or 0)):
|
||||
node = step_map[s.id]
|
||||
if s.parent_id and s.parent_id in step_map:
|
||||
step_map[s.parent_id]["children"].append(node)
|
||||
@@ -206,7 +207,13 @@ def create_step(goal_id: int, data: GoalStepCreate, db: Session = Depends(get_db
|
||||
"""添加阶段/里程碑"""
|
||||
try:
|
||||
get_or_404(db, Goal, goal_id, "目标")
|
||||
step = GoalStep(goal_id=goal_id, **data.model_dump())
|
||||
# 自动分配 sort_order:同类步骤中取最大值 + 1
|
||||
max_sort = db.query(GoalStep).filter(
|
||||
GoalStep.goal_id == goal_id,
|
||||
GoalStep.step_type == data.step_type,
|
||||
).order_by(GoalStep.sort_order.desc()).first()
|
||||
next_sort = (max_sort.sort_order + 1) if max_sort and max_sort.sort_order is not None else 0
|
||||
step = GoalStep(goal_id=goal_id, sort_order=next_sort, **data.model_dump())
|
||||
db.add(step)
|
||||
db.commit()
|
||||
db.refresh(step)
|
||||
@@ -228,6 +235,31 @@ def create_step(goal_id: int, data: GoalStepCreate, db: Session = Depends(get_db
|
||||
raise HTTPException(status_code=500, detail="添加步骤失败")
|
||||
|
||||
|
||||
# ============ Reorder ============
|
||||
|
||||
@router.put("/{goal_id}/steps/reorder")
|
||||
def reorder_steps(goal_id: int, data: ReorderRequest, db: Session = Depends(get_db)):
|
||||
"""批量更新步骤排序"""
|
||||
try:
|
||||
get_or_404(db, Goal, goal_id, "目标")
|
||||
for item in data.items:
|
||||
step = db.query(GoalStep).filter(
|
||||
GoalStep.id == item.id,
|
||||
GoalStep.goal_id == goal_id,
|
||||
).first()
|
||||
if step:
|
||||
step.sort_order = item.sort_order
|
||||
db.commit()
|
||||
logger.info(f"步骤排序更新成功: goal_id={goal_id}, count={len(data.items)}")
|
||||
return {"message": "排序更新成功"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"更新步骤排序失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="更新排序失败")
|
||||
|
||||
|
||||
@router.put("/{goal_id}/steps/{step_id}", response_model=GoalStepResponse)
|
||||
def update_step(goal_id: int, step_id: int, data: GoalStepUpdate, db: Session = Depends(get_db)):
|
||||
"""更新阶段/里程碑"""
|
||||
|
||||
Reference in New Issue
Block a user