import numbers
import os

from sqlalchemy import create_engine, inspect, text, String, Integer, Text, Boolean, Float, DateTime, Date
from sqlalchemy.orm import sessionmaker

try:
    # SQLAlchemy >= 1.4 exposes declarative_base from sqlalchemy.orm; the
    # sqlalchemy.ext.declarative location is deprecated.
    from sqlalchemy.orm import declarative_base
except ImportError:  # older SQLAlchemy releases
    from sqlalchemy.ext.declarative import declarative_base

from app.config import DATABASE_PATH, DATABASE_URL

# Make sure the directory holding the database file exists
# (falls back to "." when DATABASE_PATH has no directory component).
os.makedirs(os.path.dirname(DATABASE_PATH) or ".", exist_ok=True)

# Create the engine. "check_same_thread" is a sqlite3 driver option; False
# disables its check that a connection is only used by the thread that opened it.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory: explicit commits, no automatic flush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class for the ORM models.
Base = declarative_base()


def get_db():
    """Yield a new database session and close it afterwards.

    The close happens in ``finally`` so the session is released even when the
    generator is closed or an exception is thrown into it.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Mapping from SQLAlchemy column type classes to SQLite type names.
# Lookup is by exact class; other types fall back to str(col_type) below.
_TYPE_MAP = {
    String: "VARCHAR",
    Integer: "INTEGER",
    Text: "TEXT",
    Boolean: "BOOLEAN",
    Float: "REAL",
    DateTime: "DATETIME",
    Date: "DATE",
}


def _col_type_str(col_type) -> str:
    """Return the SQLite type string used in ADD COLUMN for a SQLAlchemy type.

    Known classes come from ``_TYPE_MAP``; anything else uses the part of
    ``str(col_type)`` before the first ``(`` (e.g. ``"NUMERIC(10, 2)"`` ->
    ``"NUMERIC"``). A truthy ``length`` attribute (as on ``String(50)``) is
    appended, giving e.g. ``"VARCHAR(50)"``.
    """
    base = _TYPE_MAP.get(col_type.__class__)
    if base is None:
        base = str(col_type).split("(")[0].strip()
    length = getattr(col_type, "length", None)
    if length:
        return f"{base}({length})"
    return base


def _quote_ident(name: str) -> str:
    """Quote a table/column name for the engine's dialect when required.

    Uses the dialect's identifier preparer, so reserved words (e.g. ``order``)
    or names with special characters remain valid in the generated DDL.
    """
    return engine.dialect.identifier_preparer.quote(name)


def _sql_string_literal(value: str) -> str:
    """Render *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + value.replace("'", "''") + "'"


def _default_clause(col) -> str:
    """Build the `` DEFAULT ...`` suffix for an ``ADD COLUMN`` statement.

    Returns ``""`` when *col* has no default that can be written as a constant
    in DDL.

    - A server default takes precedence: a plain string argument is rendered as
      a quoted literal; other arguments (e.g. ``text(...)``) are rendered with
      ``str()``.
    - Otherwise only *scalar* Python-side defaults are rendered. Callable
      defaults (e.g. ``default=datetime.now``) have no constant DDL form, so
      they are skipped instead of producing invalid SQL.
    """
    server_default = col.server_default
    if server_default is not None:
        arg = getattr(server_default, "arg", None)
        if arg is None:
            # e.g. a bare FetchedValue marker: nothing to render.
            return ""
        if isinstance(arg, str):
            return f" DEFAULT {_sql_string_literal(arg)}"
        return f" DEFAULT {arg}"

    default = col.default
    if default is None or not getattr(default, "is_scalar", False):
        return ""
    value = default.arg
    # Check bool before Number (bool is a Number subclass); booleans become 1/0.
    if isinstance(value, bool):
        return f" DEFAULT {1 if value else 0}"
    if isinstance(value, str):
        return f" DEFAULT {_sql_string_literal(value)}"
    if isinstance(value, numbers.Number):
        return f" DEFAULT {value}"
    # Other scalar types (dates, enum members, ...) have no safe literal form
    # here; add the column without a DEFAULT rather than emit invalid SQL.
    return ""


def init_db():
    """Create all tables, then add model columns missing from existing tables.

    A lightweight, SQLite-compatible auto-migration: for every model table that
    exists in the database, compare the ORM columns with the reflected ones and
    run ``ALTER TABLE ... ADD COLUMN`` for each missing column. Existing columns
    are never dropped, renamed or altered, and added columns never receive
    ``NOT NULL``. NOT NULL columns with neither a server nor a Python default
    are skipped. Statements run inside one ``engine.begin()`` block; any error
    propagates to the caller.
    """
    # Import every model module so Base.metadata contains all table definitions.
    from app.models import (  # noqa: F401
        task,
        category,
        tag,
        user_settings,
        habit,
        anniversary,
        account,
    )

    Base.metadata.create_all(bind=engine)

    # Compare ORM models with the actual table structure and add missing columns.
    inspector = inspect(engine)
    existing_tables = set(inspector.get_table_names())

    with engine.begin() as conn:
        for table in Base.metadata.sorted_tables:
            if table.name not in existing_tables:
                continue
            existing_cols = {c["name"] for c in inspector.get_columns(table.name)}
            for col in table.columns:
                if col.name in existing_cols:
                    continue
                # Skip NOT NULL columns without any default (error-prone to add).
                if col.nullable is False and col.server_default is None and col.default is None:
                    continue
                ddl = (
                    f"ALTER TABLE {_quote_ident(table.name)} "
                    f"ADD COLUMN {_quote_ident(col.name)} {_col_type_str(col.type)}"
                    f"{_default_clause(col)}"
                )
                # text() treats ":word" as a bind parameter; escape colons so
                # literal defaults such as '09:00' are sent verbatim.
                conn.execute(text(ddl.replace(":", "\\:")))