102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
from sqlalchemy import create_engine, inspect, text, String, Integer, Text, Boolean, Float, DateTime, Date
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker
|
||
import os
|
||
|
||
from app.config import DATABASE_PATH, DATABASE_URL
|
||
|
||
# Make sure the directory that will hold the database file exists.
# os.path.dirname() returns "" for a bare filename, meaning the current
# working directory.
os.makedirs(os.path.dirname(DATABASE_PATH) or ".", exist_ok=True)

# Create the engine. "check_same_thread" is a connect() argument of Python's
# sqlite3 driver; setting it to False lets a connection be used from threads
# other than the one that opened it. Other DBAPI drivers may reject this
# unknown argument, so it is only passed for SQLite URLs.
_connect_args = (
    {"check_same_thread": False} if str(DATABASE_URL).startswith("sqlite") else {}
)
engine = create_engine(DATABASE_URL, connect_args=_connect_args)

# Session factory bound to the engine; sessions neither autocommit nor
# autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class that ORM models inherit from.
Base = declarative_base()
|
||
|
||
|
||
def get_db():
    """Yield one database session and close it when the generator finishes.

    A new ``SessionLocal`` session is yielded exactly once. The ``finally``
    clause closes it whether the generator is exhausted, closed early, or
    has an exception thrown into it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Release the session's connection resources unconditionally.
        session.close()
|
||
|
||
|
||
# Mapping from SQLAlchemy type classes to the SQLite type names written into
# "ALTER TABLE ... ADD COLUMN" statements. Keys are matched by exact class.
_TYPE_MAP = {
    String: "VARCHAR",      # a length, when set, is appended by _col_type_str
    Integer: "INTEGER",
    Text: "TEXT",
    Boolean: "BOOLEAN",     # SQLite has no boolean storage class
    Float: "REAL",          # SQLite's floating-point storage class
    DateTime: "DATETIME",
    Date: "DATE",
}
|
||
|
||
|
||
def _col_type_str(col_type) -> str:
    """Render a SQLAlchemy column type as a SQLite type string.

    If the type's exact class is a key of ``_TYPE_MAP``, the mapped name is
    used. Otherwise the name comes from ``str(col_type)`` with everything
    from the first "(" onward dropped. A truthy ``length`` attribute on the
    type is appended as ``NAME(length)``.
    """
    base = _TYPE_MAP.get(type(col_type))
    if base is None:
        # Fall back to SQLAlchemy's own rendering, without its arguments.
        base = str(col_type).partition("(")[0].strip()

    length = getattr(col_type, "length", None)
    return f"{base}({length})" if length else base
|
||
|
||
|
||
def init_db():
    """Create all tables, then add model columns missing from existing tables.

    ``Base.metadata.create_all`` only creates tables that do not exist yet and
    never alters an existing table. To pick up columns added to the ORM models
    after a table was first created, this compares every model table with the
    live schema. It then issues ``ALTER TABLE ... ADD COLUMN`` for each
    missing column, a form SQLite supports.

    Rules for a missing column:

    * A NOT NULL column with neither a server-side nor a Python-side default
      is skipped, since existing rows would have no value for it.
    * The column is added with its type only. No NOT NULL, PRIMARY KEY or
      UNIQUE constraint is emitted.
    * A ``DEFAULT`` clause is attached when the default can be rendered as
      constant SQL (see ``_column_default_sql``). SQLite then reports that
      value for pre-existing rows.
    """
    # Import every model module so Base.metadata contains all table definitions.
    from app.models import (  # noqa: F401
        task, category, tag, user_settings, habit, anniversary, account,
    )

    Base.metadata.create_all(bind=engine)

    # Generic auto-migration: compare the ORM models with the actual table
    # structure and add any missing columns.
    inspector = inspect(engine)
    table_names = set(inspector.get_table_names())
    # quote() only adds quotes when needed (reserved words such as "order",
    # upper-case letters, special characters), so ordinary names stay bare.
    quote = engine.dialect.identifier_preparer.quote

    with engine.begin() as conn:
        for table in Base.metadata.sorted_tables:
            if table.name not in table_names:
                continue

            existing_cols = {c["name"] for c in inspector.get_columns(table.name)}

            for col in table.columns:
                if col.name in existing_cols:
                    continue
                # Skip NOT NULL columns that have no default of any kind.
                if col.nullable is False and col.server_default is None and col.default is None:
                    continue

                ddl = (
                    f"ALTER TABLE {quote(table.name)} "
                    f"ADD COLUMN {quote(col.name)} {_col_type_str(col.type)}"
                )
                default_sql = _column_default_sql(col, engine.dialect)
                if default_sql is not None:
                    ddl += f" DEFAULT {default_sql}"

                conn.execute(text(ddl))


def _column_default_sql(col, dialect):
    """Return SQL text for ``col``'s DEFAULT clause, or ``None`` to omit it.

    A server-side default takes precedence. A plain string is quoted as a
    literal, as SQLAlchemy's own CREATE TABLE rendering does. A SQL
    expression such as ``text("0")`` is compiled for ``dialect``.

    Otherwise a constant (scalar) Python-side default is rendered with
    ``_sql_literal``. Callable, SQL-expression and sequence Python-side
    defaults are produced by SQLAlchemy at INSERT time and have no constant
    form, so no DEFAULT is emitted for them.
    """
    server_default = col.server_default
    if server_default is not None:
        # Server defaults without an ``arg`` (such as a bare FetchedValue)
        # have nothing to render.
        arg = getattr(server_default, "arg", None)
        if arg is None:
            return None
        if isinstance(arg, str):
            return _sql_literal(arg)
        return str(arg.compile(dialect=dialect))

    default = col.default
    # ``is_scalar`` separates constant defaults from callables (e.g.
    # datetime.utcnow), SQL expressions and sequences, whose ``arg`` is not
    # a value that can be written into DDL.
    if default is not None and getattr(default, "is_scalar", False):
        return _sql_literal(default.arg)
    return None


def _sql_literal(value):
    """Render a Python constant as a SQLite literal, or ``None`` if unsupported.

    Supported types are ``bool`` (rendered as 1/0), ``int``, ``float``,
    ``Decimal`` and ``str``. Any other type (e.g. ``date``) returns ``None``
    instead of emitting SQL that SQLite cannot parse.
    """
    from decimal import Decimal

    # Check bool before int: bool is a subclass of int.
    if isinstance(value, bool):
        return "1" if value else "0"
    if isinstance(value, (int, float, Decimal)):
        return f"{value}"
    if isinstance(value, str):
        # Escape embedded single quotes by doubling them (standard SQL).
        return "'" + value.replace("'", "''") + "'"
    return None
|