- Add GET /api/backup/export and POST /api/backup/import endpoints for full data backup
- Add drag-and-drop reorder for goal steps with PUT /api/goals/{id}/steps/reorder
- Auto-assign sort_order on step creation (preserves creation order)
- Fix duplicate milestone rendering in goal detail page
- Add category management button in goal dialog
- Migrate database default from SQLite to PostgreSQL
- Fix router guard redirect loop for logged-in users on setup/login pages
- Fix ALTER TABLE ADD COLUMN crash on callable defaults (uuid.uuid4)
- Add auth status rate limiter and token version caching
- Update CLAUDE.md to reflect current architecture
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
148 lines
5.4 KiB
Python
148 lines
5.4 KiB
Python
from sqlalchemy import create_engine, inspect, text, String, Integer, Text, Boolean, Float, DateTime, Date, event
|
||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||
|
||
from app.config import DATABASE_URL
|
||
|
||
# Shared engine for the whole application. Pool settings: keep up to 10
# pooled connections plus up to 20 overflow connections, recycle connections
# older than 3600 seconds, and ping each connection on checkout so stale ones
# are replaced instead of failing a request.
engine = create_engine(
    DATABASE_URL,
    pool_pre_ping=True,
    pool_recycle=3600,
    max_overflow=20,
    pool_size=10,
)

# Session factory bound to the engine; autoflush is off, so pending changes are
# not flushed implicitly before queries.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that every ORM model derives from.
Base = declarative_base()
|
||
|
||
|
||
def get_db():
    """Yield a new database session and close it when the consumer is done.

    The ``with`` block calls ``Session.close()`` on every exit path, including
    when the generator is closed early or the consumer raises.
    """
    with SessionLocal() as session:
        yield session
|
||
|
||
|
||
# SQLAlchemy type class -> SQL type name used by _col_type_str when rendering
# ALTER TABLE ... ADD COLUMN statements. Lookups are by exact class.
_TYPE_MAP = dict(
    [
        (String, "VARCHAR"),
        (Integer, "INTEGER"),
        (Text, "TEXT"),
        (Boolean, "BOOLEAN"),
        (Float, "DOUBLE PRECISION"),
        (DateTime, "TIMESTAMP"),
        (Date, "DATE"),
    ]
)
|
||
|
||
|
||
def _col_type_str(col_type) -> str:
    """Render a SQLAlchemy column type as SQL text for ``ALTER TABLE ... ADD COLUMN``.

    * Types whose exact class is a key of ``_TYPE_MAP`` use the mapped name,
      with the type's ``length`` appended when it has one (``VARCHAR(64)``).
    * ``DateTime(timezone=True)`` renders as ``TIMESTAMP WITH TIME ZONE`` so a
      column added here matches what ``create_all`` emits on PostgreSQL.
    * Any other type is rendered by ``str(col_type)`` (SQLAlchemy's
      default-dialect compilation) and returned whole, so arguments such as
      the precision/scale of ``NUMERIC(10, 2)`` are preserved.

    Args:
        col_type: a SQLAlchemy ``TypeEngine`` instance (``column.type``).

    Returns:
        The SQL type string.
    """
    type_cls = col_type.__class__
    if type_cls not in _TYPE_MAP:
        # The compiled form already carries length/precision/scale etc.;
        # truncating at "(" and re-adding only ``length`` would lose them.
        return str(col_type).strip()

    if type_cls is DateTime and getattr(col_type, "timezone", False):
        return "TIMESTAMP WITH TIME ZONE"

    base = _TYPE_MAP[type_cls]
    length = getattr(col_type, "length", None)
    return f"{base}({length})" if length else base
|
||
|
||
|
||
def _quote_sql_string(value: str) -> str:
    """Return *value* as a single-quoted SQL string literal (embedded quotes doubled)."""
    return "'" + value.replace("'", "''") + "'"


def _column_default_sql(col):
    """Return the SQL expression for *col*'s DEFAULT clause, or None if none applies.

    ``col.server_default`` takes precedence over ``col.default``. A plain
    string server default is rendered as a quoted literal, the way
    SQLAlchemy's DDL compiler renders it for ``create_all``. Other
    server-default expressions (``text(...)``, ``func.now()``) are inlined via
    ``str()``. Server defaults without an ``arg`` (e.g. ``FetchedValue``)
    produce no clause.

    Python-side defaults are rendered only for scalar values. Callables
    (e.g. ``uuid.uuid4``) have no SQL equivalent, and neither do defaults
    without an ``arg`` (e.g. ``Sequence``).
    """
    if col.server_default is not None:
        arg = getattr(col.server_default, "arg", None)
        if arg is None:
            return None
        if isinstance(arg, str):
            return _quote_sql_string(arg)
        return str(arg)

    if col.default is None:
        return None
    value = getattr(col.default, "arg", None)
    if value is None or callable(value):
        return None
    # bool must be tested before the generic fallback: bool is a subclass of int.
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, str):
        return _quote_sql_string(value)
    return str(value)


def _add_column_ddl(preparer, table_name, col):
    """Build the ``ALTER TABLE ... ADD COLUMN`` statement for model column *col*.

    Identifiers go through the dialect's identifier preparer. It quotes a name
    only when required (reserved words, upper-case letters, ...), matching how
    ``create_all`` spells it.

    ``NOT NULL`` is emitted only together with a DEFAULT clause. Without a
    default, existing rows would hold NULL and the ALTER would fail:
    PostgreSQL rejects it on non-empty tables and SQLite rejects it always.
    Columns with callable defaults such as ``uuid.uuid4`` are therefore added
    as nullable.
    """
    ddl = (
        f"ALTER TABLE {preparer.quote(table_name)} "
        f"ADD COLUMN {preparer.quote(col.name)} {_col_type_str(col.type)}"
    )
    default_sql = _column_default_sql(col)
    if default_sql is not None:
        ddl += f" DEFAULT {default_sql}"
        if not col.nullable:
            ddl += " NOT NULL"
    return ddl


def _backfill_missing_uuids(logger):
    """Assign a fresh UUID4 to every syncable row whose ``uuid`` is NULL or ''.

    Best-effort: any exception, including a failed model import, is logged as
    a warning and the session is rolled back. The session is always closed.
    """
    import uuid

    db_session = SessionLocal()
    try:
        from app.models import Task, Category, Tag, HabitGroup, Habit, HabitCheckin
        from app.models import AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview

        for model_cls in (Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
                          AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview):
            if not hasattr(model_cls, 'uuid'):
                continue
            null_uuid_records = db_session.query(model_cls).filter(
                model_cls.uuid.is_(None) | (model_cls.uuid == '')
            ).all()
            for record in null_uuid_records:
                record.uuid = str(uuid.uuid4())
            if null_uuid_records:
                logger.info(f"为 {len(null_uuid_records)} 条 {model_cls.__name__} 记录回填了 uuid")
        db_session.commit()
    except Exception as e:
        logger.warning(f"UUID 回填时出现异常(可忽略): {e}")
        db_session.rollback()
    finally:
        db_session.close()


def init_db():
    """Create missing tables, add missing columns and backfill empty UUIDs.

    1. ``Base.metadata.create_all`` creates every table that does not exist yet.
    2. For each mapped table present in the database, model columns missing
       from the live table are added with ``ALTER TABLE ... ADD COLUMN``
       (see ``_add_column_ddl``). Columns declared ``nullable=False`` with
       neither a Python nor a server default are skipped. All ALTERs run
       inside one ``engine.begin()`` block, and an error propagates to the
       caller.
    3. ``_backfill_missing_uuids`` fills NULL/empty ``uuid`` values (best-effort).
    4. The ``sync_version`` before_update listeners are registered.
    """
    from app.utils.logger import logger  # imported here to avoid a circular import
    # Imported for side effects only (presumably so every model is registered
    # on Base.metadata before create_all runs).
    from app.models import (  # noqa: F401
        task, category, tag, user_settings, habit, anniversary, goal, sync_settings,
    )

    Base.metadata.create_all(bind=engine)

    inspector = inspect(engine)
    table_names = set(inspector.get_table_names())

    with engine.begin() as conn:
        preparer = conn.dialect.identifier_preparer
        for table_cls in Base.metadata.sorted_tables:
            table_name = table_cls.name
            if table_name not in table_names:
                continue

            existing_cols = {c["name"] for c in inspector.get_columns(table_name)}
            for col in table_cls.columns:
                if col.name in existing_cols:
                    continue
                # A NOT NULL column with no default cannot be added to a table
                # that may already hold rows, so it is skipped.
                if col.nullable is False and col.server_default is None and col.default is None:
                    continue
                # NOTE(review): text() treats ":name" as a bind parameter; a
                # DEFAULT literal containing such a sequence would need escaping.
                conn.execute(text(_add_column_ddl(preparer, table_name, col)))

    _backfill_missing_uuids(logger)

    # Register the sync_version auto-increment listeners.
    _register_sync_version_listeners()
|
||
|
||
|
||
def _bump_sync_version(mapper, connection, target):
    """``before_update`` handler: increment ``target.sync_version`` by one.

    Does nothing while sync mode is active (``is_sync_mode()`` is truthy) or
    when *target* has no ``sync_version`` attribute. A ``None``/falsy
    current value is treated as 0.
    """
    from app.utils.sync_lock import is_sync_mode

    if is_sync_mode() or not hasattr(target, 'sync_version'):
        return
    current = target.sync_version or 0
    target.sync_version = current + 1
|
||
|
||
|
||
def _register_sync_version_listeners():
    """Attach ``_bump_sync_version`` as a ``before_update`` listener.

    Only the listed syncable models that define a ``sync_version`` attribute
    are registered, in the order listed below.
    """
    from app.models import (
        Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
        AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview,
    )

    syncable_models = (
        Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
        AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview,
    )
    versioned_models = [m for m in syncable_models if hasattr(m, 'sync_version')]
    for model in versioned_models:
        event.listen(model, 'before_update', _bump_sync_version)