"""数据备份导入导出路由""" import json from datetime import datetime, date from fastapi import APIRouter, Depends, HTTPException, UploadFile, File from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from sqlalchemy import text from io import BytesIO from app.database import get_db, Base, engine from app.utils.logger import logger from app.utils.datetime import utcnow router = APIRouter(prefix="/api/backup", tags=["备份"]) # 导出顺序:按依赖关系(无 FK 的先导出) EXPORT_TABLES = [ "categories", "tags", "user_settings", "sync_settings", "habit_groups", "anniversary_categories", "certificate_categories", "goals", "tasks", "habits", "anniversaries", "certificates", "goal_steps", "goal_reviews", "goal_checkins", "habit_checkins", "task_tags", "goal_tasks", ] # 导入时的清表顺序:子表先删(避免 FK 约束报错) TRUNCATE_ORDER = [ "task_tags", "goal_tasks", "habit_checkins", "goal_checkins", "goal_reviews", "goal_steps", "tasks", "habits", "anniversaries", "certificates", "goals", "categories", "tags", "habit_groups", "anniversary_categories", "certificate_categories", "user_settings", "sync_settings", ] # 导入时的插入顺序:父表先插 INSERT_ORDER = [ "categories", "tags", "user_settings", "sync_settings", "habit_groups", "anniversary_categories", "certificate_categories", "goals", "tasks", "habits", "anniversaries", "certificates", "goal_steps", "goal_reviews", "goal_checkins", "habit_checkins", "task_tags", "goal_tasks", ] def _serialize_value(val): """将 Python 对象转为 JSON 可序列化的值""" if val is None: return None if isinstance(val, (datetime, date)): return val.isoformat() if isinstance(val, bytes): return val.decode("utf-8", errors="replace") return val @router.get("/export") def export_data(db: Session = Depends(get_db)): """导出所有数据为 JSON 备份文件""" try: data: dict[str, list[dict]] = {} for table_name in EXPORT_TABLES: rows = [] try: result = db.execute(text(f"SELECT * FROM {table_name}")) columns = list(result.keys()) for row in result: rows.append({ col: _serialize_value(getattr(row, col)) for col in columns }) except Exception: # 表可能不存在 rows = [] data[table_name] = rows backup = { "metadata": { "version": 1, "exported_at": utcnow().isoformat(), }, "data": data, } json_bytes = json.dumps(backup, ensure_ascii=False, indent=2).encode("utf-8") filename = f"elysia-backup-{utcnow().strftime('%Y%m%d-%H%M%S')}.json" logger.info(f"数据导出成功,共 {sum(len(v) for v in data.values())} 条记录") return StreamingResponse( BytesIO(json_bytes), media_type="application/json", headers={"Content-Disposition": f'attachment; filename="{filename}"'}, ) except Exception as e: logger.error(f"导出数据失败: {str(e)}") raise HTTPException(status_code=500, detail="导出数据失败") @router.post("/import") async def import_data( file: UploadFile = File(...), db: Session = Depends(get_db), ): """导入备份数据(覆盖当前所有数据)""" if not file.filename or not file.filename.endswith(".json"): raise HTTPException(status_code=400, detail="请上传 JSON 格式的备份文件") try: content = await file.read() backup = json.loads(content.decode("utf-8")) except json.JSONDecodeError: raise HTTPException(status_code=400, detail="备份文件格式不正确") except Exception as e: logger.error(f"读取备份文件失败: {str(e)}") raise HTTPException(status_code=400, detail="读取备份文件失败") payload = backup.get("data") if not payload: raise HTTPException(status_code=400, detail="备份文件内容为空") # 验证必要表存在 for table_name in INSERT_ORDER: if table_name not in payload: raise HTTPException(status_code=400, detail=f"备份文件缺少表: {table_name}") imported_count = 0 try: # 1. 按序清空所有表 for table_name in TRUNCATE_ORDER: try: db.execute(text(f"DELETE FROM {table_name}")) except Exception: pass # 表可能不存在 db.flush() # 2. 按序插入数据 for table_name in INSERT_ORDER: rows = payload.get(table_name, []) if not rows: continue columns = list(rows[0].keys()) col_str = ", ".join(columns) placeholders = ", ".join([f":{c}" for c in columns]) for row_data in rows: db.execute( text(f"INSERT INTO {table_name} ({col_str}) VALUES ({placeholders})"), {c: row_data[c] for c in columns}, ) imported_count += 1 db.commit() logger.info(f"数据导入成功,共 {imported_count} 条记录") return {"message": "数据导入成功", "count": imported_count} except Exception as e: db.rollback() logger.error(f"导入数据失败: {str(e)}") raise HTTPException(status_code=500, detail=f"导入数据失败: {str(e)}")