import uuid

from sqlalchemy import (
    Boolean,
    Date,
    DateTime,
    Float,
    Integer,
    String,
    Text,
    create_engine,
    event,
    inspect,
    text,
)
from sqlalchemy.orm import declarative_base, object_session, sessionmaker

from app.config import DATABASE_URL

engine = create_engine(
    DATABASE_URL,
    pool_size=10,
    max_overflow=20,
    pool_recycle=3600,
    pool_pre_ping=True,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


def get_db():
    """Yield a database session and close it when the caller is done."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# SQL type names used when adding a missing column via ALTER TABLE.
_TYPE_MAP = {
    String: "VARCHAR",
    Integer: "INTEGER",
    Text: "TEXT",
    Boolean: "BOOLEAN",
    Float: "DOUBLE PRECISION",
    DateTime: "TIMESTAMP",
    Date: "DATE",
}


def _col_type_str(col_type) -> str:
    """Return the SQL type string for a SQLAlchemy column type.

    Exact classes listed in ``_TYPE_MAP`` use the mapped name; any other type
    falls back to the part of ``str(col_type)`` before the first "(".
    A truthy ``length`` attribute is appended as ``NAME(length)``.
    """
    if col_type.__class__ in _TYPE_MAP:
        base = _TYPE_MAP[col_type.__class__]
    else:
        base = str(col_type).split("(")[0].strip()
    length = getattr(col_type, "length", None)
    if length:
        return f"{base}({length})"
    return base


def _sql_literal(value) -> str:
    """Render a Python scalar as an SQL literal for a DEFAULT clause.

    Booleans become TRUE/FALSE, strings are single-quoted with embedded
    quotes doubled, anything else is rendered with ``str()``.
    """
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return str(value)


def _column_default_sql(col):
    """Return the DEFAULT expression to use when adding *col*, or None.

    ``server_default`` takes precedence over the Python-side ``default``.
    A string server default is quoted as a literal (as SQLAlchemy's own
    CREATE TABLE does); an SQL expression such as ``text("0")`` is rendered
    with ``str()``. A Python-side default is only usable when it is a plain
    scalar: callables (e.g. ``default=lambda: ...``) and SQL expressions have
    no literal form, so None is returned and no DEFAULT clause is emitted.
    """
    server_default = col.server_default
    if server_default is not None:
        # FetchedValue and similar markers have no ``arg``; fall through.
        arg = getattr(server_default, "arg", None)
        if arg is not None:
            return _sql_literal(arg) if isinstance(arg, str) else str(arg)

    default = col.default
    if default is not None and getattr(default, "is_scalar", False):
        return _sql_literal(default.arg)
    return None


def _syncable_models():
    """Return the model classes handled by the uuid backfill and sync_version listeners."""
    from app.models import (
        Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
        AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview,
    )
    return [
        Task, Category, Tag, HabitGroup, Habit, HabitCheckin,
        AnniversaryCategory, Anniversary, Goal, GoalStep, GoalReview,
    ]


def init_db():
    """Create tables, add newly declared columns, backfill uuids, register listeners.

    1. ``Base.metadata.create_all`` creates any missing tables.
    2. For each model table that already exists, every declared column that
       the database lacks is added with ``ALTER TABLE ... ADD COLUMN``.
       Non-nullable columns with no default of any kind are skipped.
    3. Rows of the syncable models whose ``uuid`` is NULL or '' receive a new
       UUID4. This step is best-effort: errors are logged and rolled back.
    4. ``before_update`` listeners that bump ``sync_version`` are registered;
       registration is idempotent, so calling this twice is safe.
    """
    from app.utils.logger import logger  # local import avoids a circular import
    from app.models import (  # noqa: F401  (imported for side effects: load model modules)
        task, category, tag, user_settings, habit, anniversary, goal, sync_settings,
    )

    Base.metadata.create_all(bind=engine)

    inspector = inspect(engine)
    table_names = set(inspector.get_table_names())

    with engine.begin() as conn:
        # Quote identifiers the same way create_all does (reserved words, mixed case).
        preparer = conn.dialect.identifier_preparer
        for table in Base.metadata.sorted_tables:
            if table.name not in table_names:
                continue
            existing_cols = {c["name"] for c in inspector.get_columns(table.name)}
            for col in table.columns:
                if col.name in existing_cols:
                    continue
                if col.nullable is False and col.server_default is None and col.default is None:
                    # NOT NULL with no default would fail on a table that already has rows.
                    continue

                ddl = (
                    f"ALTER TABLE {preparer.quote(table.name)} "
                    f"ADD COLUMN {preparer.quote(col.name)} {_col_type_str(col.type)}"
                )
                default_sql = _column_default_sql(col)
                if default_sql is not None:
                    ddl += f" DEFAULT {default_sql}"
                    if not col.nullable:
                        ddl += " NOT NULL"
                elif not col.nullable:
                    # A callable / expression default has no DDL form, so existing rows
                    # would get no value; add the column as nullable instead of failing.
                    logger.warning(
                        f"Column {table.name}.{col.name} has a non-literal default; "
                        f"adding it as nullable"
                    )
                conn.execute(text(ddl))

    # Backfill a UUID4 for existing rows whose uuid is missing.
    db_session = SessionLocal()
    try:
        for model_cls in _syncable_models():
            if not hasattr(model_cls, "uuid"):
                continue
            null_uuid_records = (
                db_session.query(model_cls)
                .filter(model_cls.uuid.is_(None) | (model_cls.uuid == ""))
                .all()
            )
            for record in null_uuid_records:
                record.uuid = str(uuid.uuid4())
            if null_uuid_records:
                logger.info(f"为 {len(null_uuid_records)} 条 {model_cls.__name__} 记录回填了 uuid")
        db_session.commit()
    except Exception as e:
        # Best-effort step: log and roll back rather than abort startup.
        logger.warning(f"UUID 回填时出现异常(可忽略): {e}")
        db_session.rollback()
    finally:
        db_session.close()

    _register_sync_version_listeners()


def _bump_sync_version(mapper, connection, target):
    """before_update hook: increment ``target.sync_version``.

    Skipped while sync mode is active (``is_sync_mode()``), for objects
    without a ``sync_version`` attribute, and for objects that have no net
    column-level changes (before_update also fires for merely "dirty"
    objects, e.g. after collection-only changes).
    """
    from app.utils.sync_lock import is_sync_mode

    if is_sync_mode() or not hasattr(target, "sync_version"):
        return
    session = object_session(target)
    if session is not None and not session.is_modified(target, include_collections=False):
        return
    target.sync_version = (target.sync_version or 0) + 1


def _register_sync_version_listeners():
    """Attach ``_bump_sync_version`` as a before_update listener on syncable models.

    Models without a ``sync_version`` attribute are skipped. A listener that
    is already attached is not added again, so repeated ``init_db`` calls do
    not cause multiple increments per update.
    """
    for model_cls in _syncable_models():
        if not hasattr(model_cls, "sync_version"):
            continue
        if not event.contains(model_cls, "before_update", _bump_sync_version):
            event.listen(model_cls, "before_update", _bump_sync_version)