Files
ToDoList/api/app/database.py
祀梦 5048de4fa1 feat: add data backup/import, goal step ordering, and PostgreSQL migration
- Add GET /api/backup/export and POST /api/backup/import endpoints for full data backup
- Add drag-and-drop reorder for goal steps with PUT /api/goals/{id}/steps/reorder
- Auto-assign sort_order on step creation (preserves creation order)
- Fix duplicate milestone rendering in goal detail page
- Add category management button in goal dialog
- Migrate database default from SQLite to PostgreSQL
- Fix router guard redirect loop for logged-in users on setup/login pages
- Fix ALTER TABLE ADD COLUMN crash on callable defaults (uuid.uuid4)
- Add auth status rate limiter and token version caching
- Update CLAUDE.md to reflect current architecture

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 00:02:18 +08:00

148 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy import create_engine, inspect, text, String, Integer, Text, Boolean, Float, DateTime, Date, event
from sqlalchemy.orm import declarative_base, sessionmaker
from app.config import DATABASE_URL
# Shared engine, declarative base and session factory for the application.
# Pool settings: 10 persistent connections plus up to 20 overflow connections,
# connections recycled after 3600 seconds, and a liveness ping on checkout
# (pool_pre_ping) so stale connections are replaced transparently.
engine = create_engine(
    DATABASE_URL,
    pool_pre_ping=True,
    pool_recycle=3600,
    pool_size=10,
    max_overflow=20,
)
Base = declarative_base()
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def get_db():
    """Yield a session created by ``SessionLocal`` and close it once the generator finishes."""
    # Session.__exit__ calls close(), mirroring an explicit try/finally close.
    with SessionLocal() as session:
        yield session
# SQLAlchemy type class -> SQL type name used when init_db() adds a missing
# column with ALTER TABLE. Lookups use the exact class (col_type.__class__),
# so subclasses of these types are not found in this table.
_TYPE_MAP = dict(
    (
        (String, "VARCHAR"),
        (Integer, "INTEGER"),
        (Text, "TEXT"),
        (Boolean, "BOOLEAN"),
        (Float, "DOUBLE PRECISION"),
        (DateTime, "TIMESTAMP"),
        (Date, "DATE"),
    )
)
def _col_type_str(col_type) -> str:
    """Render a SQLAlchemy column type as the SQL type used in ALTER TABLE ADD COLUMN.

    Types whose exact class is in ``_TYPE_MAP`` use the mapped name, followed
    by ``(length)`` when the type carries a length (``String(50)`` ->
    ``VARCHAR(50)``). ``DateTime(timezone=True)`` keeps its time-zone awareness
    (``TIMESTAMP WITH TIME ZONE``).

    Any other type is rendered with SQLAlchemy's generic compiler
    (``str(col_type)``). Its full argument list is kept verbatim, so e.g.
    ``Numeric(10, 2)`` stays ``NUMERIC(10, 2)`` and an array type keeps its
    ``[]`` suffix. Previously everything after the first ``(`` was dropped and
    only ``length`` was re-appended.
    """
    base = _TYPE_MAP.get(col_type.__class__)
    if base is None:
        # The generic compilation already contains every type argument
        # (length, precision, scale, collation, array suffix, ...).
        return str(col_type)
    if base == "TIMESTAMP" and getattr(col_type, "timezone", False):
        # Without this a tz-aware DateTime column would be added as a naive
        # TIMESTAMP, unlike what create_all() emits for new tables.
        base = "TIMESTAMP WITH TIME ZONE"
    length = getattr(col_type, "length", None)
    return f"{base}({length})" if length else base
def init_db():
    """Initialise the schema, add newly declared columns and backfill missing uuids.

    Steps, in order:

    1. ``Base.metadata.create_all`` creates every table that does not exist yet.
    2. ``_add_missing_columns`` adds, via ``ALTER TABLE ... ADD COLUMN``, each
       model column that is missing from an already existing table.
    3. ``_backfill_missing_uuids`` gives rows whose ``uuid`` is NULL or empty a
       fresh UUID4 (best effort: failures are logged, not raised).
    4. ``_register_sync_version_listeners`` installs the ``sync_version`` bump.
       It runs after the backfill, so on a first call the backfill's updates
       do not increment ``sync_version``.
    """
    from app.utils.logger import logger  # imported here to avoid a circular import
    from app.models import (  # noqa: F401  imported only for side effects before create_all()
        task, category, tag, user_settings, habit, anniversary, goal, sync_settings,
    )

    Base.metadata.create_all(bind=engine)
    _add_missing_columns(logger)
    _backfill_missing_uuids(logger)
    # Register the before_update listeners that auto-increment sync_version.
    _register_sync_version_listeners()


def _add_missing_columns(logger):
    """Add model columns that are missing from tables that already exist.

    A column gets a ``DEFAULT`` clause whenever one can be rendered as SQL
    (see ``_column_default_sql``), so that existing rows receive that value.
    ``NOT NULL`` is only applied together with such a default. Adding a NOT
    NULL column without a default fails on SQLite, and on PostgreSQL tables
    that already contain rows. NOT NULL columns that have no default at all
    are skipped with a warning.
    """
    inspector = inspect(engine)
    table_names = set(inspector.get_table_names())
    with engine.begin() as conn:
        # Quote identifiers the same way create_all() does (reserved words,
        # mixed-case names on PostgreSQL, ...).
        quote = conn.dialect.identifier_preparer.quote
        for table_cls in Base.metadata.sorted_tables:
            table_name = table_cls.name
            if table_name not in table_names:
                continue
            existing_cols = {c["name"] for c in inspector.get_columns(table_name)}
            for col in table_cls.columns:
                if col.name in existing_cols:
                    continue
                if col.nullable is False and col.server_default is None and col.default is None:
                    # Existing rows would have no value for a NOT NULL column.
                    logger.warning(
                        f"无法自动添加列 {table_name}.{col.name}: NOT NULL 且没有默认值, 请手动迁移"
                    )
                    continue
                ddl = (
                    f"ALTER TABLE {quote(table_name)} "
                    f"ADD COLUMN {quote(col.name)} {_col_type_str(col.type)}"
                )
                default_sql = _column_default_sql(col, conn.dialect)
                if default_sql is not None:
                    ddl += f" DEFAULT {default_sql}"
                    if not col.nullable:
                        ddl += " NOT NULL"
                elif not col.nullable:
                    # e.g. a Python callable default such as uuid.uuid4: the value
                    # cannot be expressed as SQL DEFAULT, so the column is added as
                    # nullable; the uuid backfill then fills existing rows.
                    logger.warning(
                        f"列 {table_name}.{col.name} 的默认值无法写入 SQL DEFAULT, 已按可空列添加"
                    )
                conn.execute(text(ddl))


def _column_default_sql(col, dialect) -> "str | None":
    """Return the SQL text for *col*'s DEFAULT clause, or None if none can be rendered.

    A ``server_default`` carrying an ``arg`` wins. A plain string is rendered
    as a quoted literal and a SQL expression is compiled, which is the same
    rendering ``create_all()`` uses. Otherwise the Python-side ``default`` is
    used, unless it is callable (e.g. ``uuid.uuid4``) or has no SQL literal
    form.
    """
    server_arg = getattr(col.server_default, "arg", None)
    if server_arg is not None:
        return _sql_literal(server_arg, dialect)
    python_default = getattr(col.default, "arg", None)
    if python_default is None or callable(python_default):
        return None
    return _sql_literal(python_default, dialect)


def _sql_literal(value, dialect) -> "str | None":
    """Render *value* as an inline SQL literal, or return None if it has no safe form.

    Handles bool (TRUE/FALSE), str (quoted, escaped), SQL expressions (compiled
    with literal binds) and int/float/Decimal. Other types (dates, enums, ...)
    return None rather than emitting unquoted text that would be invalid SQL.
    """
    from decimal import Decimal

    if isinstance(value, bool):  # before the int check: bool is an int subclass
        return "TRUE" if value else "FALSE"
    if isinstance(value, str):
        return _sql_string_literal(value)
    if hasattr(value, "compile"):  # SQLAlchemy expression, e.g. text(...) or func.now()
        return str(value.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    if isinstance(value, (int, float, Decimal)):
        return str(value)
    return None


def _sql_string_literal(value: str) -> str:
    """Quote *value* as a standard SQL string literal, doubling embedded single quotes."""
    return "'" + value.replace("'", "''") + "'"


def _backfill_missing_uuids(logger):
    """Assign a fresh UUID4 to every row whose ``uuid`` is NULL or ''.

    Best effort: any exception is logged as a warning and the whole backfill
    is rolled back. All models are committed together at the end.
    """
    import uuid

    db_session = SessionLocal()
    try:
        from app.models import (
            Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
            AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview,
        )
        for model_cls in (Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
                          AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview):
            if not hasattr(model_cls, 'uuid'):
                continue
            missing = db_session.query(model_cls).filter(
                (model_cls.uuid == None) | (model_cls.uuid == '')  # noqa: E711
            ).all()
            for record in missing:
                record.uuid = str(uuid.uuid4())
            if missing:
                logger.info(f"为 {len(missing)} 条 {model_cls.__name__} 记录回填了 uuid")
        db_session.commit()
    except Exception as e:
        logger.warning(f"UUID 回填时出现异常(可忽略): {e}")
        db_session.rollback()
    finally:
        db_session.close()
def _bump_sync_version(mapper, connection, target):
    """``before_update`` hook: increment ``target.sync_version`` unless sync mode is active.

    Objects without a ``sync_version`` attribute are left untouched; a
    ``None`` version is treated as 0.
    """
    from app.utils.sync_lock import is_sync_mode
    if is_sync_mode() or not hasattr(target, 'sync_version'):
        return
    current_version = target.sync_version or 0
    target.sync_version = current_version + 1
def _register_sync_version_listeners():
    """Attach ``_bump_sync_version`` as a ``before_update`` listener on each syncable model.

    Only models that define a ``sync_version`` attribute are hooked. The call
    is idempotent: a model that already has this listener is skipped. Without
    this check, calling ``init_db()`` more than once would stack duplicate
    listeners, and ``sync_version`` would grow by more than one per update.
    """
    from app.models import (
        Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
        AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview,
    )
    for model_cls in (Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
                      AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview):
        if not hasattr(model_cls, 'sync_version'):
            continue
        if event.contains(model_cls, 'before_update', _bump_sync_version):
            continue
        event.listen(model_cls, 'before_update', _bump_sync_version)