import uuid
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from sqlalchemy.orm import Session

from services import notification_manager
from core.auth import decode_access_token
from core.database import get_db
from core.dependencies import get_current_user
from core.models import Notification as NotificationModel
from core.models import User as UserModel

router = APIRouter()


def _serialize_notification(n: NotificationModel) -> dict:
    """Convert a Notification row into a JSON-serializable dict."""
    return {
        "id": str(n.id),
        "type": n.type,
        "payload": n.payload,
        "read": n.read,
        "created_at": n.created_at.isoformat(),
        "expires_at": n.expires_at.isoformat() if n.expires_at else None,
    }


def _parse_uuid(value: object) -> Optional[uuid.UUID]:
    """Parse *value* as a UUID, returning None instead of raising on malformed input."""
    try:
        return uuid.UUID(value)
    except (AttributeError, TypeError, ValueError):
        return None


def _not_expired(now: datetime):
    """Build a SQL filter matching notifications with no expiry or an expiry after *now*.

    NOTE(review): callers pass a naive local ``datetime.now()``; confirm that
    ``expires_at`` is stored with the same (naive, local) convention rather than UTC.
    """
    return NotificationModel.expires_at.is_(None) | (NotificationModel.expires_at > now)


def _get_owned_notification(
    db: Session, user: UserModel, notification_id: str
) -> NotificationModel:
    """Return *user*'s notification with id *notification_id*.

    Raises:
        HTTPException: 404 if the id is malformed, does not exist, or belongs
            to another user (all treated identically so ids are not probeable).
    """
    notification_uuid = _parse_uuid(notification_id)
    n = None
    if notification_uuid is not None:
        n = (
            db.query(NotificationModel)
            .filter(
                NotificationModel.id == notification_uuid,
                NotificationModel.user_id == user.id,
            )
            .first()
        )
    if n is None:
        raise HTTPException(status_code=404, detail="Notification not found")
    return n


@router.websocket("/ws/notifications")
async def notifications_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
    """Push notifications to an authenticated client over a websocket.

    The first text frame from the client must be an access token. On success
    the socket is registered with ``notification_manager`` and every unread,
    non-expired notification is sent in one ``{"type": "flush", ...}`` message.
    Invalid tokens / unknown users are closed with code 1008 (policy violation).
    """
    await websocket.accept()

    token = await websocket.receive_text()
    user_id = decode_access_token(token)
    user_uuid = _parse_uuid(user_id) if user_id else None
    if user_uuid is None:
        await websocket.close(code=1008)
        return

    user = db.query(UserModel).filter(UserModel.id == user_uuid).first()
    if not user:
        await websocket.close(code=1008)
        return

    notification_manager.register(user_id, websocket)
    # Everything after registration is inside try/finally so the socket is
    # always unregistered: a disconnect during the flush, a DB error, or any
    # other exception would otherwise leave a stale registration behind.
    try:
        # Flush all unread (non-expired) notifications on connect
        pending = (
            db.query(NotificationModel)
            .filter(
                NotificationModel.user_id == user_uuid,
                NotificationModel.read == False,  # noqa: E712 - SQL expression, not truthiness
                _not_expired(datetime.now()),
            )
            .order_by(NotificationModel.created_at.asc())
            .all()
        )
        await websocket.send_json({
            "type": "flush",
            "notifications": [_serialize_notification(n) for n in pending],
        })

        while True:
            await websocket.receive_text()  # keep connection alive; server only pushes
    except WebSocketDisconnect:
        pass
    finally:
        notification_manager.unregister(user_id)


@router.get("/notifications")
def get_notifications(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the current user's non-expired notifications (read and unread), newest first."""
    notifications = (
        db.query(NotificationModel)
        .filter(
            NotificationModel.user_id == user.id,
            _not_expired(datetime.now()),
        )
        .order_by(NotificationModel.created_at.desc())
        .all()
    )
    return [_serialize_notification(n) for n in notifications]


@router.post("/notifications/{notification_id}/read")
def mark_notification_read(
    notification_id: str,
    user: UserModel = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Mark one of the current user's notifications as read (404 if not found)."""
    n = _get_owned_notification(db, user, notification_id)
    n.read = True
    db.commit()
    return {"ok": True}


@router.delete("/notifications/{notification_id}")
def delete_notification(
    notification_id: str,
    user: UserModel = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete one of the current user's notifications (404 if not found)."""
    n = _get_owned_notification(db, user, notification_id)
    db.delete(n)
    db.commit()
    return {"ok": True}