feat: add WebDAV sync support and startup/shutdown scripts
Backend: - Add uuid, sync_version, is_deleted fields to all syncable models - Add SyncSettings model for WebDAV configuration (AES-256-GCM encrypted passwords) - Add crypto.py: AES-256-GCM encryption derived from JWT_SECRET via PBKDF2 - Add sync_lock.py: thread-level sync lock with 503 middleware for write blocking - Add webdav.py: WebDAV client using requests (PUT/GET/MKCOL/DELETE) - Add sync_service.py: push/pull/bidirectional merge with LWW conflict resolution - Add sync router with 8 endpoints: config, test, push, pull, sync, status, remote delete - Add UUID backfill for existing records in init_db() - Add SQLAlchemy before_update event to auto-increment sync_version - Register sync middleware to block writes during sync (503) Frontend: - Add sync API client (WebUI/src/api/sync.ts) - Add useSyncStore with config, test, push/pull/sync operations - Add WebDAV config + sync UI in SettingsView - Add 503 status code handling in axios interceptor - Add uuid field to all TypeScript type definitions Scripts: - Add scripts/start.bat and scripts/stop.bat for project management Design doc: docs/plan/webdav-sync-design.md
This commit is contained in:
53
api/app/utils/crypto.py
Normal file
53
api/app/utils/crypto.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
AES-256-GCM 加解密工具
|
||||
密钥从 JWT_SECRET 派生,用于加密 WebDAV 密码等敏感信息
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
import hashlib
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
from app.config import JWT_SECRET
|
||||
|
||||
_SALT = b"elysia-todo-sync-v1"
|
||||
_NONCE_SIZE = 12 # AES-GCM 标准 nonce 长度
|
||||
|
||||
|
||||
def _derive_key() -> bytes:
|
||||
"""从 JWT_SECRET 派生 256-bit AES 密钥"""
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=_SALT,
|
||||
iterations=480000,
|
||||
)
|
||||
return kdf.derive(JWT_SECRET.encode("utf-8"))
|
||||
|
||||
|
||||
def encrypt(plaintext: str) -> str:
|
||||
"""AES-256-GCM 加密,返回 base64(iv + ciphertext + tag)"""
|
||||
if not plaintext:
|
||||
return ""
|
||||
key = _derive_key()
|
||||
nonce = os.urandom(_NONCE_SIZE)
|
||||
aesgcm = AESGCM(key)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), None)
|
||||
return base64.b64encode(nonce + ciphertext).decode("ascii")
|
||||
|
||||
|
||||
def decrypt(encrypted: str) -> str | None:
|
||||
"""AES-256-GCM 解密,解密失败返回 None"""
|
||||
if not encrypted:
|
||||
return None
|
||||
try:
|
||||
key = _derive_key()
|
||||
raw = base64.b64decode(encrypted)
|
||||
nonce = raw[:_NONCE_SIZE]
|
||||
ciphertext = raw[_NONCE_SIZE:]
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
|
||||
return plaintext.decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
37
api/app/utils/sync_lock.py
Normal file
37
api/app/utils/sync_lock.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
全局同步锁与同步模式标记
|
||||
同步期间禁止所有写操作,前端显示同步遮罩
|
||||
"""
|
||||
import threading
|
||||
|
||||
_sync_lock = threading.Lock()
|
||||
_sync_in_progress = False
|
||||
_sync_mode = threading.local()
|
||||
|
||||
|
||||
def acquire_sync_lock() -> bool:
|
||||
"""非阻塞获取同步锁,成功返回 True"""
|
||||
acquired = _sync_lock.acquire(blocking=False)
|
||||
if acquired:
|
||||
global _sync_in_progress
|
||||
_sync_in_progress = True
|
||||
_sync_mode.active = True
|
||||
return acquired
|
||||
|
||||
|
||||
def release_sync_lock():
|
||||
"""释放同步锁"""
|
||||
global _sync_in_progress
|
||||
_sync_in_progress = False
|
||||
_sync_mode.active = False
|
||||
_sync_lock.release()
|
||||
|
||||
|
||||
def is_syncing() -> bool:
|
||||
"""检查是否正在同步"""
|
||||
return _sync_in_progress
|
||||
|
||||
|
||||
def is_sync_mode() -> bool:
|
||||
"""检查当前线程是否在同步模式中(跳过 sync_version 自增)"""
|
||||
return getattr(_sync_mode, 'active', False)
|
||||
567
api/app/utils/sync_service.py
Normal file
567
api/app/utils/sync_service.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""
|
||||
同步核心服务
|
||||
处理 push / pull / bidirectional merge 逻辑
|
||||
"""
|
||||
from datetime import datetime, date as date_type
|
||||
import json
|
||||
import os
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models import (
|
||||
Task, Category, Tag, task_tags, UserSettings,
|
||||
HabitGroup, Habit, HabitCheckin,
|
||||
AnniversaryCategory, Anniversary,
|
||||
Goal, GoalStep, GoalReview, goal_tasks,
|
||||
SyncSettings,
|
||||
)
|
||||
from app.utils.crypto import encrypt, decrypt
|
||||
from app.utils.webdav import WebDAVClient
|
||||
from app.utils.sync_lock import acquire_sync_lock, release_sync_lock
|
||||
from app.utils.logger import logger
|
||||
from app.utils.datetime import utcnow
|
||||
|
||||
SYNC_COLLECTIONS = [
|
||||
("categories", Category),
|
||||
("tags", Tag),
|
||||
("tasks", Task),
|
||||
("habit_groups", HabitGroup),
|
||||
("habits", Habit),
|
||||
("habit_checkins", HabitCheckin),
|
||||
("anniversary_categories", AnniversaryCategory),
|
||||
("anniversaries", Anniversary),
|
||||
("goals", Goal),
|
||||
("goal_steps", GoalStep),
|
||||
("goal_reviews", GoalReview),
|
||||
]
|
||||
|
||||
ASSOCIATION_COLLECTIONS = [
|
||||
("task_tags", task_tags, "tasks", "tags"),
|
||||
("goal_tasks", goal_tasks, "goals", "tasks"),
|
||||
]
|
||||
|
||||
USER_SETTINGS_SYNC_FIELDS = [
|
||||
"nickname", "avatar", "signature", "birthday", "email",
|
||||
"site_name", "theme", "language", "default_view",
|
||||
"default_sort_by", "default_sort_order",
|
||||
]
|
||||
|
||||
MODEL_MAP = {
|
||||
"tasks": Task,
|
||||
"categories": Category,
|
||||
"tags": Tag,
|
||||
"habit_groups": HabitGroup,
|
||||
"habits": Habit,
|
||||
"anniversary_categories": AnniversaryCategory,
|
||||
"anniversaries": Anniversary,
|
||||
"goals": Goal,
|
||||
"goal_steps": GoalStep,
|
||||
"goal_reviews": GoalReview,
|
||||
"user_settings": UserSettings,
|
||||
}
|
||||
|
||||
|
||||
def _get_sync_settings(db: Session) -> SyncSettings:
|
||||
settings = db.query(SyncSettings).filter(SyncSettings.id == 1).first()
|
||||
if not settings:
|
||||
settings = SyncSettings(id=1)
|
||||
db.add(settings)
|
||||
db.commit()
|
||||
db.refresh(settings)
|
||||
return settings
|
||||
|
||||
|
||||
def _create_webdav_client(settings: SyncSettings) -> WebDAVClient | None:
|
||||
if not settings.webdav_url:
|
||||
return None
|
||||
password = decrypt(settings.webdav_password) if settings.webdav_password else ""
|
||||
if settings.webdav_password and password is None:
|
||||
logger.error("WebDAV 密码解密失败,可能 JWT_SECRET 已更改")
|
||||
return None
|
||||
return WebDAVClient(
|
||||
url=settings.webdav_url,
|
||||
username=settings.webdav_username or "",
|
||||
password=password,
|
||||
path=settings.webdav_path or "/elysia-todo/",
|
||||
)
|
||||
|
||||
|
||||
def _serialize_model(obj, model_class) -> dict:
|
||||
result = {}
|
||||
for col in model_class.__table__.columns:
|
||||
val = getattr(obj, col.name, None)
|
||||
if isinstance(val, datetime):
|
||||
val = val.isoformat() if val else None
|
||||
elif isinstance(val, date_type):
|
||||
val = val.isoformat() if val else None
|
||||
result[col.name] = val
|
||||
return result
|
||||
|
||||
|
||||
def _serialize_association(row, left_model, right_model, db: Session) -> dict | None:
|
||||
left_id, right_id = row[0], row[1]
|
||||
left_obj = db.query(left_model).filter(left_model.id == left_id).first()
|
||||
right_obj = db.query(right_model).filter(right_model.id == right_id).first()
|
||||
if not left_obj or not right_obj or not left_obj.uuid or not right_obj.uuid:
|
||||
return None
|
||||
return {
|
||||
f"{left_model.__tablename__}_uuid": left_obj.uuid,
|
||||
f"{right_model.__tablename__}_uuid": right_obj.uuid,
|
||||
}
|
||||
|
||||
|
||||
def _convert_value(val, col_type_name: str):
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (datetime, date_type)):
|
||||
return val
|
||||
if not isinstance(val, str):
|
||||
return val
|
||||
if "DateTime" in col_type_name:
|
||||
try:
|
||||
return datetime.fromisoformat(val.replace("Z", "+00:00"))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
elif "Date" in col_type_name and "Time" not in col_type_name:
|
||||
try:
|
||||
return date_type.fromisoformat(val)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return val
|
||||
|
||||
|
||||
def _item_to_model_kwargs(item: dict, model_class) -> dict:
|
||||
"""将远程 JSON 对象转换为可用于创建模型的 kwargs,保留 uuid 和 sync_version"""
|
||||
kwargs = {}
|
||||
for col in model_class.__table__.columns:
|
||||
if col.name not in item:
|
||||
continue
|
||||
val = item[col.name]
|
||||
if col.name == "id":
|
||||
continue
|
||||
col_type_name = type(col.type).__name__
|
||||
val = _convert_value(val, col_type_name)
|
||||
kwargs[col.name] = val
|
||||
return kwargs
|
||||
|
||||
|
||||
def _backup_local(db: Session) -> str:
|
||||
timestamp = utcnow().strftime("%Y-%m-%dT%H-%M-%S")
|
||||
backup_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"data", "backups", timestamp
|
||||
)
|
||||
os.makedirs(backup_dir, exist_ok=True)
|
||||
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
rows = db.query(model_class).filter(model_class.is_deleted == False).all()
|
||||
items = [_serialize_model(r, model_class) for r in rows]
|
||||
filepath = os.path.join(backup_dir, f"{coll_name}.json")
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump({"collection": coll_name, "items": items}, f, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
logger.info(f"本地数据已备份到: {backup_dir}")
|
||||
return backup_dir
|
||||
|
||||
|
||||
def push_to_remote(db: Session) -> dict:
|
||||
settings = _get_sync_settings(db)
|
||||
client = _create_webdav_client(settings)
|
||||
if not client:
|
||||
return {"success": False, "message": "WebDAV 未配置或密码解密失败"}
|
||||
|
||||
if not acquire_sync_lock():
|
||||
return {"success": False, "message": "同步正在进行中"}
|
||||
|
||||
try:
|
||||
client.ensure_dirs()
|
||||
timestamp = utcnow().strftime("%Y-%m-%dT%H-%M-%S")
|
||||
client.backup_remote(timestamp)
|
||||
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
rows = db.query(model_class).all()
|
||||
items = [_serialize_model(r, model_class) for r in rows]
|
||||
data = {
|
||||
"version": 1,
|
||||
"collection": coll_name,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"items": items,
|
||||
}
|
||||
if not client.upload_json(f"{coll_name}.json", data):
|
||||
return {"success": False, "message": f"上传 {coll_name} 失败"}
|
||||
|
||||
for assoc_name, assoc_table, left_name, right_name in ASSOCIATION_COLLECTIONS:
|
||||
left_model = MODEL_MAP.get(left_name)
|
||||
right_model = MODEL_MAP.get(right_name)
|
||||
if not left_model or not right_model:
|
||||
continue
|
||||
rows = db.execute(assoc_table.select()).fetchall()
|
||||
items = []
|
||||
for row in rows:
|
||||
item = _serialize_association(row, left_model, right_model, db)
|
||||
if item:
|
||||
items.append(item)
|
||||
client.upload_json(f"{assoc_name}.json", {
|
||||
"version": 1,
|
||||
"collection": assoc_name,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"items": items,
|
||||
})
|
||||
|
||||
user_settings = db.query(UserSettings).filter(UserSettings.id == 1).first()
|
||||
if user_settings:
|
||||
pref_data = {}
|
||||
for field in USER_SETTINGS_SYNC_FIELDS:
|
||||
val = getattr(user_settings, field, None)
|
||||
if isinstance(val, (datetime, date_type)):
|
||||
val = val.isoformat() if val else None
|
||||
pref_data[field] = val
|
||||
client.upload_json("user_settings.json", {
|
||||
"version": 1,
|
||||
"collection": "user_settings",
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"items": [pref_data],
|
||||
})
|
||||
|
||||
manifest = {
|
||||
"version": 1,
|
||||
"last_sync_at": utcnow().isoformat(),
|
||||
"collections": {},
|
||||
}
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
count = db.query(model_class).filter(model_class.is_deleted == False).count()
|
||||
manifest["collections"][coll_name] = {
|
||||
"count": count,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
}
|
||||
client.upload_json("manifest.json", manifest)
|
||||
|
||||
settings.last_sync_at = utcnow()
|
||||
settings.last_sync_version = (settings.last_sync_version or 0) + 1
|
||||
settings.sync_enabled = True
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": "推送成功"}
|
||||
except Exception as e:
|
||||
logger.error(f"推送失败: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
return {"success": False, "message": f"推送失败: {str(e)}"}
|
||||
finally:
|
||||
release_sync_lock()
|
||||
|
||||
|
||||
def pull_from_remote(db: Session) -> dict:
|
||||
settings = _get_sync_settings(db)
|
||||
client = _create_webdav_client(settings)
|
||||
if not client:
|
||||
return {"success": False, "message": "WebDAV 未配置或密码解密失败"}
|
||||
|
||||
if not acquire_sync_lock():
|
||||
return {"success": False, "message": "同步正在进行中"}
|
||||
|
||||
try:
|
||||
_backup_local(db)
|
||||
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
remote_data = client.download_json(f"{coll_name}.json")
|
||||
if remote_data is None:
|
||||
continue
|
||||
db.query(model_class).delete()
|
||||
db.commit()
|
||||
|
||||
for item in remote_data.get("items", []):
|
||||
kwargs = _item_to_model_kwargs(item, model_class)
|
||||
is_deleted = kwargs.pop("is_deleted", False)
|
||||
obj = model_class(**kwargs)
|
||||
obj.is_deleted = bool(is_deleted)
|
||||
db.add(obj)
|
||||
|
||||
db.commit()
|
||||
|
||||
for assoc_name, assoc_table, left_name, right_name in ASSOCIATION_COLLECTIONS:
|
||||
remote_data = client.download_json(f"{assoc_name}.json")
|
||||
if remote_data is None:
|
||||
continue
|
||||
db.execute(assoc_table.delete())
|
||||
db.commit()
|
||||
|
||||
left_model = MODEL_MAP.get(left_name)
|
||||
right_model = MODEL_MAP.get(right_name)
|
||||
if not left_model or not right_model:
|
||||
continue
|
||||
|
||||
for item in remote_data.get("items", []):
|
||||
left_uuid = item.get(f"{left_name}_uuid")
|
||||
right_uuid = item.get(f"{right_name}_uuid")
|
||||
if not left_uuid or not right_uuid:
|
||||
continue
|
||||
left_obj = db.query(left_model).filter(left_model.uuid == left_uuid).first()
|
||||
right_obj = db.query(right_model).filter(right_model.uuid == right_uuid).first()
|
||||
if left_obj and right_obj:
|
||||
db.execute(assoc_table.insert().values(
|
||||
left_id=left_obj.id, right_id=right_obj.id
|
||||
))
|
||||
db.commit()
|
||||
|
||||
remote_prefs = client.download_json("user_settings.json")
|
||||
if remote_prefs and remote_prefs.get("items"):
|
||||
pref = remote_prefs["items"][0]
|
||||
user_settings = db.query(UserSettings).filter(UserSettings.id == 1).first()
|
||||
if user_settings:
|
||||
for field in USER_SETTINGS_SYNC_FIELDS:
|
||||
if field in pref and pref[field] is not None:
|
||||
val = pref[field]
|
||||
if isinstance(val, str) and field == "birthday":
|
||||
try:
|
||||
val = date_type.fromisoformat(val)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
setattr(user_settings, field, val)
|
||||
db.commit()
|
||||
|
||||
settings.last_sync_at = utcnow()
|
||||
settings.last_sync_version = (settings.last_sync_version or 0) + 1
|
||||
settings.sync_enabled = True
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": "拉取成功"}
|
||||
except Exception as e:
|
||||
logger.error(f"拉取失败: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
return {"success": False, "message": f"拉取失败: {str(e)}"}
|
||||
finally:
|
||||
release_sync_lock()
|
||||
|
||||
|
||||
def bidirectional_sync(db: Session) -> dict:
|
||||
settings = _get_sync_settings(db)
|
||||
client = _create_webdav_client(settings)
|
||||
if not client:
|
||||
return {"success": False, "message": "WebDAV 未配置或密码解密失败"}
|
||||
|
||||
if not acquire_sync_lock():
|
||||
return {"success": False, "message": "同步正在进行中"}
|
||||
|
||||
try:
|
||||
client.ensure_dirs()
|
||||
stats = {"pushed": 0, "pulled": 0, "merged": 0, "deleted": 0}
|
||||
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
remote_data = client.download_json(f"{coll_name}.json")
|
||||
remote_by_uuid = {}
|
||||
remote_deleted_uuids = set()
|
||||
if remote_data:
|
||||
for item in remote_data.get("items", []):
|
||||
uuid_val = item.get("uuid")
|
||||
if not uuid_val:
|
||||
continue
|
||||
if item.get("is_deleted"):
|
||||
remote_deleted_uuids.add(uuid_val)
|
||||
else:
|
||||
remote_by_uuid[uuid_val] = item
|
||||
|
||||
local_objs = db.query(model_class).all()
|
||||
local_by_uuid = {}
|
||||
local_deleted_uuids = set()
|
||||
for obj in local_objs:
|
||||
if obj.uuid:
|
||||
if obj.is_deleted:
|
||||
local_deleted_uuids.add(obj.uuid)
|
||||
local_by_uuid[obj.uuid] = obj
|
||||
|
||||
all_uuids = set(remote_by_uuid.keys()) | set(local_by_uuid.keys()) | remote_deleted_uuids | local_deleted_uuids
|
||||
|
||||
for uuid_val in all_uuids:
|
||||
remote_item = remote_by_uuid.get(uuid_val)
|
||||
local_obj = local_by_uuid.get(uuid_val)
|
||||
remote_deleted = uuid_val in remote_deleted_uuids and uuid_val not in remote_by_uuid
|
||||
local_deleted = uuid_val in local_deleted_uuids and uuid_val not in local_by_uuid
|
||||
|
||||
# 两边都删除了 → 什么都不做
|
||||
if remote_deleted and local_deleted:
|
||||
continue
|
||||
|
||||
# 远端删除了,本地还在 → 删除本地
|
||||
if remote_deleted and local_obj:
|
||||
local_obj.is_deleted = True
|
||||
local_obj.sync_version = (local_obj.sync_version or 0) + 1
|
||||
stats["deleted"] += 1
|
||||
continue
|
||||
|
||||
# 本地删除了,远端还在 → 删除远端标记(本地占优在这里意味着:远端也应标记删除)
|
||||
# 这里简单处理:如果本地标记了删除但远端还活着,以远端为准拉取回来
|
||||
if local_deleted and not remote_deleted and not local_obj and remote_item:
|
||||
kwargs = _item_to_model_kwargs(remote_item, model_class)
|
||||
kwargs.pop("is_deleted", None)
|
||||
new_obj = model_class(**kwargs)
|
||||
db.add(new_obj)
|
||||
stats["pulled"] += 1
|
||||
continue
|
||||
|
||||
# 仅远端有 → 拉取到本地
|
||||
if remote_item and not local_obj:
|
||||
kwargs = _item_to_model_kwargs(remote_item, model_class)
|
||||
kwargs.pop("is_deleted", None)
|
||||
new_obj = model_class(**kwargs)
|
||||
db.add(new_obj)
|
||||
stats["pulled"] += 1
|
||||
continue
|
||||
|
||||
# 仅本地有 → 会在最后统一上传时推送到远端
|
||||
if not remote_item and local_obj and not local_deleted:
|
||||
stats["pushed"] += 1
|
||||
continue
|
||||
|
||||
# 两边都有 → LWW 合并
|
||||
if remote_item and local_obj:
|
||||
remote_ver = remote_item.get("sync_version", 1) or 1
|
||||
local_ver = local_obj.sync_version or 1
|
||||
|
||||
if remote_ver > local_ver:
|
||||
kwargs = _item_to_model_kwargs(remote_item, model_class)
|
||||
kwargs.pop("is_deleted", None)
|
||||
kwargs.pop("sync_version", None)
|
||||
for key, val in kwargs.items():
|
||||
if val is not None or key in getattr(local_obj, '__clearable_fields__', set()):
|
||||
setattr(local_obj, key, val)
|
||||
local_obj.sync_version = remote_ver
|
||||
stats["merged"] += 1
|
||||
elif local_ver > remote_ver:
|
||||
local_obj.sync_version = local_ver
|
||||
stats["pushed"] += 1
|
||||
else:
|
||||
# 版本相同,以远端为准
|
||||
kwargs = _item_to_model_kwargs(remote_item, model_class)
|
||||
kwargs.pop("is_deleted", None)
|
||||
kwargs.pop("sync_version", None)
|
||||
for key, val in kwargs.items():
|
||||
setattr(local_obj, key, val)
|
||||
local_obj.sync_version = local_ver + 1
|
||||
stats["merged"] += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
# 合并关联表
|
||||
for assoc_name, assoc_table, left_name, right_name in ASSOCIATION_COLLECTIONS:
|
||||
remote_data = client.download_json(f"{assoc_name}.json")
|
||||
if remote_data is None:
|
||||
continue
|
||||
|
||||
left_model = MODEL_MAP.get(left_name)
|
||||
right_model = MODEL_MAP.get(right_name)
|
||||
if not left_model or not right_model:
|
||||
continue
|
||||
|
||||
remote_pairs = set()
|
||||
for item in remote_data.get("items", []):
|
||||
left_uuid = item.get(f"{left_name}_uuid")
|
||||
right_uuid = item.get(f"{right_name}_uuid")
|
||||
if left_uuid and right_uuid:
|
||||
remote_pairs.add((left_uuid, right_uuid))
|
||||
|
||||
local_pairs = set()
|
||||
rows = db.execute(assoc_table.select()).fetchall()
|
||||
for row in rows:
|
||||
left_id, right_id = row[0], row[1]
|
||||
left_obj = db.query(left_model).filter(left_model.id == left_id).first()
|
||||
right_obj = db.query(right_model).filter(right_model.id == right_id).first()
|
||||
if left_obj and right_obj and left_obj.uuid and right_obj.uuid:
|
||||
local_pairs.add((left_obj.uuid, right_obj.uuid))
|
||||
|
||||
merged_pairs = local_pairs | remote_pairs
|
||||
db.execute(assoc_table.delete())
|
||||
for left_uuid, right_uuid in merged_pairs:
|
||||
left_obj = db.query(left_model).filter(left_model.uuid == left_uuid).first()
|
||||
right_obj = db.query(right_model).filter(right_model.uuid == right_uuid).first()
|
||||
if left_obj and right_obj:
|
||||
db.execute(assoc_table.insert().values(left_id=left_obj.id, right_id=right_obj.id))
|
||||
db.commit()
|
||||
|
||||
# 合并 user_settings
|
||||
remote_prefs = client.download_json("user_settings.json")
|
||||
if remote_prefs and remote_prefs.get("items"):
|
||||
pref = remote_prefs["items"][0]
|
||||
user_settings = db.query(UserSettings).filter(UserSettings.id == 1).first()
|
||||
if user_settings:
|
||||
for field in USER_SETTINGS_SYNC_FIELDS:
|
||||
if field in pref and pref[field] is not None:
|
||||
val = pref[field]
|
||||
if isinstance(val, str) and field == "birthday":
|
||||
try:
|
||||
val = date_type.fromisoformat(val)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
setattr(user_settings, field, val)
|
||||
db.commit()
|
||||
|
||||
# 统一上传合并后的数据到远端
|
||||
_upload_all_to_remote(db, client)
|
||||
|
||||
settings.last_sync_at = utcnow()
|
||||
settings.last_sync_version = (settings.last_sync_version or 0) + 1
|
||||
settings.sync_enabled = True
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": f"同步完成: 推送 {stats['pushed']}, 拉取 {stats['pulled']}, 合并 {stats['merged']}, 删除 {stats['deleted']}"}
|
||||
except Exception as e:
|
||||
logger.error(f"双向同步失败: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
return {"success": False, "message": f"同步失败: {str(e)}"}
|
||||
finally:
|
||||
release_sync_lock()
|
||||
|
||||
|
||||
def _upload_all_to_remote(db: Session, client: WebDAVClient):
|
||||
"""将本地所有数据上传到远端"""
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
items = [_serialize_model(obj, model_class) for obj in db.query(model_class).all()]
|
||||
client.upload_json(f"{coll_name}.json", {
|
||||
"version": 1,
|
||||
"collection": coll_name,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"items": items,
|
||||
})
|
||||
|
||||
for assoc_name, assoc_table, left_name, right_name in ASSOCIATION_COLLECTIONS:
|
||||
left_model = MODEL_MAP.get(left_name)
|
||||
right_model = MODEL_MAP.get(right_name)
|
||||
if not left_model or not right_model:
|
||||
continue
|
||||
rows = db.execute(assoc_table.select()).fetchall()
|
||||
items = [_serialize_association(row, left_model, right_model, db) for row in rows]
|
||||
items = [i for i in items if i is not None]
|
||||
client.upload_json(f"{assoc_name}.json", {
|
||||
"version": 1,
|
||||
"collection": assoc_name,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"items": items,
|
||||
})
|
||||
|
||||
user_settings = db.query(UserSettings).filter(UserSettings.id == 1).first()
|
||||
if user_settings:
|
||||
pref_data = {}
|
||||
for field in USER_SETTINGS_SYNC_FIELDS:
|
||||
val = getattr(user_settings, field, None)
|
||||
if isinstance(val, (datetime, date_type)):
|
||||
val = val.isoformat() if val else None
|
||||
pref_data[field] = val
|
||||
client.upload_json("user_settings.json", {
|
||||
"version": 1,
|
||||
"collection": "user_settings",
|
||||
"updated_at": utcnow().isoformat(),
|
||||
"items": [pref_data],
|
||||
})
|
||||
|
||||
manifest = {
|
||||
"version": 1,
|
||||
"last_sync_at": utcnow().isoformat(),
|
||||
"collections": {},
|
||||
}
|
||||
for coll_name, model_class in SYNC_COLLECTIONS:
|
||||
count = db.query(model_class).filter(model_class.is_deleted == False).count()
|
||||
manifest["collections"][coll_name] = {
|
||||
"count": count,
|
||||
"updated_at": utcnow().isoformat(),
|
||||
}
|
||||
client.upload_json("manifest.json", manifest)
|
||||
150
api/app/utils/webdav.py
Normal file
150
api/app/utils/webdav.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
WebDAV 客户端工具
|
||||
基于 requests 实现,兼容 Alist 等 WebDAV 服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class WebDAVClient:
|
||||
"""WebDAV 客户端,用于与 Alist 等 WebDAV 服务交互"""
|
||||
|
||||
def __init__(self, url: str, username: str, password: str, path: str = "/elysia-todo/"):
|
||||
self.base_url = url.rstrip("/")
|
||||
self.username = username
|
||||
self.password = username # Alist 使用用户名作为密码
|
||||
self.auth = HTTPBasicAuth(username, self.password)
|
||||
self.path = path if path.startswith("/") else f"/{path}"
|
||||
self._session = requests.Session()
|
||||
self._session.auth = self.auth
|
||||
self._session.timeout = 30
|
||||
self._session.headers.update({"Content-Type": "application/json"})
|
||||
|
||||
@property
|
||||
def _data_url(self) -> str:
|
||||
return f"{self.base_url}{self.path}data/"
|
||||
|
||||
@property
|
||||
def _backups_url(self) -> str:
|
||||
return f"{self.base_url}{self.path}backups/"
|
||||
|
||||
def _url(self, filename: str) -> str:
|
||||
return f"{self._data_url}{filename}"
|
||||
|
||||
def _manifest_url(self) -> str:
|
||||
return f"{self.base_url}{self.path}manifest.json"
|
||||
|
||||
def test_connection(self) -> tuple[bool, str]:
|
||||
"""测试 WebDAV 连接,返回 (成功, 消息)"""
|
||||
try:
|
||||
resp = self._session.request("PROPFIND", f"{self.base_url}{self.path}", headers={"Depth": "0"})
|
||||
if resp.status_code in (200, 207, 404):
|
||||
return True, "连接成功"
|
||||
return False, f"连接失败: HTTP {resp.status_code}"
|
||||
except requests.ConnectionError:
|
||||
return False, "连接失败: 无法连接到服务器"
|
||||
except requests.Timeout:
|
||||
return False, "连接超时"
|
||||
except Exception as e:
|
||||
return False, f"连接失败: {str(e)}"
|
||||
|
||||
def ensure_dirs(self) -> bool:
|
||||
"""确保远端目录结构存在"""
|
||||
try:
|
||||
for path in [self.path, f"{self.path}data/", f"{self.path}backups/"]:
|
||||
url = f"{self.base_url}{path}"
|
||||
self._session.request("PROPFIND", url, headers={"Depth": "0"})
|
||||
resp = self._session.request("MKCOL", url)
|
||||
if resp.status_code in (200, 201, 405, 301):
|
||||
pass
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建远端目录失败: {e}")
|
||||
return False
|
||||
|
||||
def upload_json(self, filename: str, data: Any) -> bool:
|
||||
"""上传 JSON 数据到 WebDAV"""
|
||||
try:
|
||||
url = self._url(filename) if filename != "manifest.json" else self._manifest_url()
|
||||
content = json.dumps(data, ensure_ascii=False, indent=2, default=str).encode("utf-8")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
resp = self._session.put(url, data=content, headers=headers)
|
||||
if resp.status_code in (200, 201, 204):
|
||||
return True
|
||||
logger.error(f"上传 {filename} 失败: HTTP {resp.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"上传 {filename} 异常: {e}")
|
||||
return False
|
||||
|
||||
def download_json(self, filename: str) -> Any | None:
|
||||
"""从 WebDAV 下载 JSON 数据"""
|
||||
try:
|
||||
url = self._url(filename) if filename != "manifest.json" else self._manifest_url()
|
||||
resp = self._session.get(url)
|
||||
if resp.status_code == 200:
|
||||
return resp.json()
|
||||
if resp.status_code == 404:
|
||||
return None
|
||||
logger.error(f"下载 {filename} 失败: HTTP {resp.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载 {filename} 异常: {e}")
|
||||
return None
|
||||
|
||||
def delete_file(self, filename: str) -> bool:
|
||||
"""删除 WebDAV 上的文件"""
|
||||
try:
|
||||
url = self._url(filename) if filename != "manifest.json" else self._manifest_url()
|
||||
resp = self._session.delete(url)
|
||||
return resp.status_code in (200, 204, 404)
|
||||
except Exception as e:
|
||||
logger.error(f"删除 {filename} 异常: {e}")
|
||||
return False
|
||||
|
||||
def backup_remote(self, timestamp: str) -> bool:
|
||||
"""备份远端数据到 backups/{timestamp}/"""
|
||||
try:
|
||||
backup_path = f"{self.path}backups/{timestamp}/data/"
|
||||
backup_url = f"{self.base_url}{backup_path}"
|
||||
self._session.request("MKCOL", f"{self.base_url}{self.path}backups/")
|
||||
self._session.request("MKCOL", f"{self.base_url}{self.path}backups/{timestamp}/")
|
||||
self._session.request("MKCOL", backup_url)
|
||||
|
||||
for filename in [
|
||||
"manifest.json", "user_settings.json", "categories.json",
|
||||
"tasks.json", "tags.json", "task_tags.json",
|
||||
"habit_groups.json", "habits.json", "habit_checkins.json",
|
||||
"anniversary_categories.json", "anniversaries.json",
|
||||
"goals.json", "goal_steps.json", "goal_reviews.json", "goal_tasks.json",
|
||||
]:
|
||||
data = self.download_json(filename)
|
||||
if data is not None:
|
||||
src_url = self._url(filename) if filename != "manifest.json" else self._manifest_url()
|
||||
dst_url = f"{backup_url}{filename}" if filename != "manifest.json" else f"{self.base_url}{backup_path}../manifest.json"
|
||||
content = json.dumps(data, ensure_ascii=False, indent=2, default=str).encode("utf-8")
|
||||
self._session.put(dst_url, data=content, headers={"Content-Type": "application/json"})
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"备份远端数据失败: {e}")
|
||||
return False
|
||||
|
||||
def clear_remote(self) -> bool:
|
||||
"""清空远端数据目录"""
|
||||
filenames = [
|
||||
"manifest.json", "user_settings.json", "categories.json",
|
||||
"tasks.json", "tags.json", "task_tags.json",
|
||||
"habit_groups.json", "habits.json", "habit_checkins.json",
|
||||
"anniversary_categories.json", "anniversaries.json",
|
||||
"goals.json", "goal_steps.json", "goal_reviews.json", "goal_tasks.json",
|
||||
]
|
||||
for f in filenames:
|
||||
self.delete_file(f)
|
||||
return True
|
||||
Reference in New Issue
Block a user