from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import func, distinct
from typing import Optional, List
from datetime import date, timedelta
from app.database import get_db
from app.models.habit import Habit, HabitGroup, HabitCheckin
from app.schemas.habit import (
    HabitGroupCreate,
    HabitGroupUpdate,
    HabitGroupResponse,
    HabitCreate,
    HabitUpdate,
    HabitResponse,
    CheckinCreate,
    CheckinResponse,
    HabitStatsResponse,
)
from app.schemas.common import DeleteResponse
from app.utils.crud import get_or_404
from app.utils.datetime import utcnow, today
from app.utils.logger import logger

# Top-level router; the three sub-routers below are mounted onto it at the end
# of this module.
router = APIRouter(tags=["习惯"])


# ============ Internal helpers ============

def _parse_iso_date(value: str, param_name: str) -> date:
    """Parse a ``YYYY-MM-DD`` query value, rejecting malformed input with HTTP 400.

    Args:
        value: Raw query-string value.
        param_name: Name of the query parameter, used in the error detail.

    Returns:
        The parsed ``date``.

    Raises:
        HTTPException: 400 if ``value`` is not a valid ISO date. Without this
            check, the ``ValueError`` was caught by the endpoint's generic
            handler and the client saw a misleading 500.
    """
    try:
        return date.fromisoformat(value)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"{param_name} 日期格式无效,应为 YYYY-MM-DD",
        ) from None


def _current_streak(completed: set, today_date: date, today_completed: bool) -> int:
    """Count consecutive completed days ending today (or yesterday).

    If today is not yet completed, counting starts from yesterday, so an
    unfinished day does not break an ongoing streak.

    Args:
        completed: Set of dates on which the habit target was met.
        today_date: The reference "today".
        today_completed: Whether today's target has already been met.
    """
    check_date = today_date if today_completed else today_date - timedelta(days=1)
    streak = 0
    while check_date in completed:  # set lookup: O(1) per day
        streak += 1
        check_date -= timedelta(days=1)
    return streak


def _longest_streak(sorted_dates: List[date]) -> int:
    """Return the longest run of consecutive days in an ascending, de-duplicated list."""
    longest = 0
    streak = 0
    prev_date = None
    for d in sorted_dates:
        if prev_date is not None and d == prev_date + timedelta(days=1):
            streak += 1
        else:
            streak = 1
        longest = max(longest, streak)
        prev_date = d
    return longest


# ============ Habit group CRUD ============

habit_group_router = APIRouter(prefix="/api/habit-groups", tags=["习惯分组"])


@habit_group_router.get("", response_model=List[HabitGroupResponse])
def get_habit_groups(db: Session = Depends(get_db)):
    """List all habit groups ordered by ``sort_order``, then ``id``."""
    try:
        groups = db.query(HabitGroup).order_by(HabitGroup.sort_order, HabitGroup.id).all()
        return groups
    except Exception as e:
        logger.error(f"获取习惯分组失败: {str(e)}")
        raise HTTPException(status_code=500, detail="获取习惯分组失败")


@habit_group_router.post("", response_model=HabitGroupResponse, status_code=201)
def create_habit_group(data: HabitGroupCreate, db: Session = Depends(get_db)):
    """Create a habit group from the request body and return it."""
    try:
        group = HabitGroup(**data.model_dump())
        db.add(group)
        db.commit()
        db.refresh(group)
        logger.info(f"创建习惯分组成功: id={group.id}, name={group.name}")
        return group
    except Exception as e:
        db.rollback()
        logger.error(f"创建习惯分组失败: {str(e)}")
        raise HTTPException(status_code=500, detail="创建习惯分组失败")


@habit_group_router.put("/{group_id}", response_model=HabitGroupResponse)
def update_habit_group(group_id: int, data: HabitGroupUpdate, db: Session = Depends(get_db)):
    """Partially update a habit group; only fields present in the request are applied.

    Raises:
        HTTPException: propagated from ``get_or_404`` when the group does not
            exist, or 500 on any other failure (after rollback).
    """
    try:
        group = get_or_404(db, HabitGroup, group_id, "习惯分组")
        update_data = data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(group, field, value)
        db.commit()
        db.refresh(group)
        logger.info(f"更新习惯分组成功: id={group_id}")
        return group
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"更新习惯分组失败: {str(e)}")
        raise HTTPException(status_code=500, detail="更新习惯分组失败")


@habit_group_router.delete("/{group_id}")
def delete_habit_group(group_id: int, db: Session = Depends(get_db)):
    """Delete a habit group; habits in it are kept, with ``group_id`` set to NULL."""
    try:
        group = get_or_404(db, HabitGroup, group_id, "习惯分组")
        # Detach member habits first so they survive the group's deletion.
        db.query(Habit).filter(Habit.group_id == group_id).update({Habit.group_id: None})
        db.delete(group)
        db.commit()
        logger.info(f"删除习惯分组成功: id={group_id}")
        return DeleteResponse(message="习惯分组删除成功")
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"删除习惯分组失败: {str(e)}")
        raise HTTPException(status_code=500, detail="删除习惯分组失败")


# ============ Habit CRUD ============

habit_router = APIRouter(prefix="/api/habits", tags=["习惯"])


@habit_router.get("", response_model=List[HabitResponse])
def get_habits(
    include_archived: bool = Query(False, description="是否包含已归档的习惯"),
    db: Session = Depends(get_db),
):
    """List habits, newest first; archived habits are excluded unless requested."""
    try:
        query = db.query(Habit)
        if not include_archived:
            # SQL expression comparison, not a Python truthiness test.
            query = query.filter(Habit.is_archived == False)  # noqa: E712
        habits = query.order_by(Habit.created_at.desc()).all()
        return habits
    except Exception as e:
        logger.error(f"获取习惯列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail="获取习惯列表失败")


@habit_router.post("", response_model=HabitResponse, status_code=201)
def create_habit(data: HabitCreate, db: Session = Depends(get_db)):
    """Create a habit from the request body and return it."""
    try:
        habit = Habit(**data.model_dump())
        db.add(habit)
        db.commit()
        db.refresh(habit)
        logger.info(f"创建习惯成功: id={habit.id}, name={habit.name}")
        return habit
    except Exception as e:
        db.rollback()
        logger.error(f"创建习惯失败: {str(e)}")
        raise HTTPException(status_code=500, detail="创建习惯失败")


@habit_router.put("/{habit_id}", response_model=HabitResponse)
def update_habit(habit_id: int, data: HabitUpdate, db: Session = Depends(get_db)):
    """Partially update a habit and bump ``updated_at``.

    Fields explicitly sent as ``None`` are ignored, except for those listed
    in ``data.clearable_fields``, which may be cleared.
    """
    try:
        habit = get_or_404(db, Habit, habit_id, "习惯")
        update_data = data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            if value is not None or field in data.clearable_fields:
                setattr(habit, field, value)
        habit.updated_at = utcnow()
        db.commit()
        db.refresh(habit)
        logger.info(f"更新习惯成功: id={habit_id}")
        return habit
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"更新习惯失败: {str(e)}")
        raise HTTPException(status_code=500, detail="更新习惯失败")


@habit_router.delete("/{habit_id}")
def delete_habit(habit_id: int, db: Session = Depends(get_db)):
    """Delete a habit.

    Its check-in records are expected to be removed by a cascade configured
    on the model (not visible in this module).
    """
    try:
        habit = get_or_404(db, Habit, habit_id, "习惯")
        db.delete(habit)
        db.commit()
        logger.info(f"删除习惯成功: id={habit_id}")
        return DeleteResponse(message="习惯删除成功")
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"删除习惯失败: {str(e)}")
        raise HTTPException(status_code=500, detail="删除习惯失败")


@habit_router.patch("/{habit_id}/archive", response_model=HabitResponse)
def toggle_archive_habit(habit_id: int, db: Session = Depends(get_db)):
    """Flip a habit's ``is_archived`` flag, bump ``updated_at``, and return the habit."""
    try:
        habit = get_or_404(db, Habit, habit_id, "习惯")
        habit.is_archived = not habit.is_archived
        habit.updated_at = utcnow()
        db.commit()
        db.refresh(habit)
        logger.info(f"切换习惯归档状态成功: id={habit_id}, is_archived={habit.is_archived}")
        return habit
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"切换习惯归档状态失败: {str(e)}")
        raise HTTPException(status_code=500, detail="切换习惯归档状态失败")


# ============ Check-ins ============

checkin_router = APIRouter(prefix="/api/habits/{habit_id}/checkins", tags=["习惯打卡"])


@checkin_router.get("", response_model=List[CheckinResponse])
def get_checkins(
    habit_id: int,
    from_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
    to_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
    db: Session = Depends(get_db),
):
    """List a habit's check-ins, newest first, optionally within an inclusive date range.

    Raises:
        HTTPException: 404 (via ``get_or_404``) for an unknown habit, 400 for
            a malformed ``from_date``/``to_date``, 500 on other failures.
    """
    try:
        get_or_404(db, Habit, habit_id, "习惯")
        query = db.query(HabitCheckin).filter(HabitCheckin.habit_id == habit_id)
        if from_date:
            query = query.filter(
                HabitCheckin.checkin_date >= _parse_iso_date(from_date, "from_date")
            )
        if to_date:
            query = query.filter(
                HabitCheckin.checkin_date <= _parse_iso_date(to_date, "to_date")
            )
        checkins = query.order_by(HabitCheckin.checkin_date.desc()).all()
        return checkins
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取打卡记录失败: {str(e)}")
        raise HTTPException(status_code=500, detail="获取打卡记录失败")


@checkin_router.post("", response_model=CheckinResponse)
def create_checkin(
    habit_id: int,
    data: Optional[CheckinCreate] = None,
    db: Session = Depends(get_db),
):
    """Check in for today, adding to today's count (by ``data.count``, default 1).

    Creates today's record if none exists yet.
    """
    try:
        get_or_404(db, Habit, habit_id, "习惯")  # existence check only
        today_date = today()
        add_count = data.count if data else 1

        # Reuse today's record if there is one.
        checkin = db.query(HabitCheckin).filter(
            HabitCheckin.habit_id == habit_id,
            HabitCheckin.checkin_date == today_date,
        ).first()

        if checkin:
            checkin.count += add_count
        else:
            checkin = HabitCheckin(
                habit_id=habit_id,
                checkin_date=today_date,
                count=add_count,
            )
            db.add(checkin)

        db.commit()
        db.refresh(checkin)
        logger.info(f"打卡成功: habit_id={habit_id}, date={today_date}, count={checkin.count}")
        return checkin
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"打卡失败: {str(e)}")
        raise HTTPException(status_code=500, detail="打卡失败")


@checkin_router.delete("")
def cancel_checkin(
    habit_id: int,
    count: int = Query(1, ge=1, description="取消的打卡次数"),
    db: Session = Depends(get_db),
):
    """Undo ``count`` of today's check-ins; delete the record once it reaches zero."""
    try:
        get_or_404(db, Habit, habit_id, "习惯")  # existence check only
        today_date = today()

        checkin = db.query(HabitCheckin).filter(
            HabitCheckin.habit_id == habit_id,
            HabitCheckin.checkin_date == today_date,
        ).first()

        if not checkin:
            return DeleteResponse(message="今日无打卡记录")

        checkin.count = max(0, checkin.count - count)
        if checkin.count <= 0:
            db.delete(checkin)

        db.commit()
        logger.info(f"取消打卡: habit_id={habit_id}, date={today_date}")
        return DeleteResponse(message="取消打卡成功")
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"取消打卡失败: {str(e)}")
        raise HTTPException(status_code=500, detail="取消打卡失败")


@checkin_router.get("/stats", response_model=HabitStatsResponse)
def get_habit_stats(habit_id: int, db: Session = Depends(get_db)):
    """Compute habit statistics.

    A day counts as completed when its check-in count is at least
    ``habit.target_count``.

    Returns:
        ``HabitStatsResponse`` with the total completed days, current and
        longest streaks, today's count, and whether today is completed.
    """
    try:
        habit = get_or_404(db, Habit, habit_id, "习惯")
        today_date = today()

        # Today's check-in.
        today_checkin = db.query(HabitCheckin).filter(
            HabitCheckin.habit_id == habit_id,
            HabitCheckin.checkin_date == today_date,
        ).first()
        today_count = today_checkin.count if today_checkin else 0
        today_completed = today_count >= habit.target_count

        # All days on which the target was met. A set gives O(1) streak lookups;
        # the sorted, de-duplicated list drives the longest-streak scan.
        rows = db.query(HabitCheckin.checkin_date).filter(
            HabitCheckin.habit_id == habit_id,
            HabitCheckin.count >= habit.target_count,
        ).all()
        completed_set = {row[0] for row in rows}
        completed_dates = sorted(completed_set)

        return HabitStatsResponse(
            total_days=len(completed_dates),
            current_streak=_current_streak(completed_set, today_date, today_completed),
            longest_streak=_longest_streak(completed_dates),
            today_count=today_count,
            today_completed=today_completed,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取习惯统计失败: {str(e)}")
        raise HTTPException(status_code=500, detail="获取习惯统计失败")


# Mount the sub-routers onto the top-level router.
router.include_router(habit_group_router)
router.include_router(habit_router)
router.include_router(checkin_router)