"""Two-party card trading over WebSockets.

Keeps the matchmaking queue and live trade sessions in process memory,
applies offer/accept/unaccept actions sent by clients, and commits card
ownership transfers to the database once both participants accept.
"""

import asyncio
import uuid
import logging
from dataclasses import dataclass, field

from fastapi import WebSocket
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from models import Card as CardModel, DeckCard as DeckCardModel

logger = logging.getLogger("app")


## Storage


@dataclass
class TradeOffer:
    """One participant's side of a trade."""

    username: str
    # Cards offered, as dicts produced by serialize_card_model.
    cards: list[dict] = field(default_factory=list)
    # Whether this participant has accepted the current state of both offers.
    accepted: bool = False


@dataclass
class TradeSession:
    """In-memory state of a single trade between two users."""

    trade_id: str
    offers: dict[str, TradeOffer]  # user_id -> TradeOffer
    connections: dict[str, WebSocket] = field(default_factory=dict)  # user_id -> socket


# trade_id -> live session. Module-level dict, so it only exists in this process.
active_trades: dict[str, TradeSession] = {}


@dataclass
class TradeQueueEntry:
    """A user waiting in the trade matchmaking queue."""

    user_id: str
    username: str
    websocket: WebSocket


# FIFO of users waiting for a partner; try_trade_match mutates it under trade_queue_lock.
trade_queue: list[TradeQueueEntry] = []
trade_queue_lock = asyncio.Lock()


## Serialization


def serialize_card_model(card: CardModel) -> dict:
    """Convert a Card ORM row into a JSON-serializable dict for clients."""
    return {
        "id": str(card.id),
        "name": card.name,
        "card_rarity": card.card_rarity,
        "card_type": card.card_type,
        "image_link": card.image_link,
        "attack": card.attack,
        "defense": card.defense,
        "cost": card.cost,
        "text": card.text,
        "created_at": card.created_at.isoformat() if card.created_at else None,
    }


def serialize_trade(session: TradeSession, perspective_user_id: str) -> dict:
    """Render the trade state as seen by ``perspective_user_id``.

    ``my_offer`` is the viewer's side and ``their_offer`` the partner's.
    Raises ``StopIteration`` if the session has no other participant and
    ``KeyError`` if ``perspective_user_id`` is not a participant.
    """
    partner_id = next(uid for uid in session.offers if uid != perspective_user_id)
    my_offer = session.offers[perspective_user_id]
    their_offer = session.offers[partner_id]
    return {
        "trade_id": session.trade_id,
        "partner_username": their_offer.username,
        "my_offer": {
            "cards": my_offer.cards,
            "accepted": my_offer.accepted,
        },
        "their_offer": {
            "cards": their_offer.cards,
            "accepted": their_offer.accepted,
        },
    }


## Broadcasting


async def _send_quietly(ws: WebSocket, payload: dict) -> None:
    """Best-effort send of ``payload``; failures are logged instead of raised."""
    try:
        await ws.send_json(payload)
    except Exception:
        # Peers disconnect at any moment; a dead socket must not abort the caller.
        logger.debug("WebSocket send failed (payload type %r)", payload.get("type"), exc_info=True)


async def _send_error(ws: "WebSocket | None", message: str) -> None:
    """Send an error frame to ``ws`` if the user still has a connection."""
    if ws:
        await ws.send_json({"type": "error", "message": message})


async def broadcast_trade(session: TradeSession) -> None:
    """Send every connected participant their own view of the trade state.

    Best-effort: a failure for one participant is logged and skipped so the
    other participant still receives the update.
    """
    for user_id, ws in list(session.connections.items()):
        try:
            await ws.send_json({
                "type": "state",
                "state": serialize_trade(session, user_id),
            })
        except Exception:
            logger.debug(
                "Failed to send trade %s state to user %s",
                session.trade_id,
                user_id,
                exc_info=True,
            )


## Matchmaking


async def try_trade_match() -> None:
    """Pair the oldest queued user with the next queued *different* user.

    Entries belonging to the same user as the head are skipped (and left in
    the queue) rather than halting matchmaking: previously two adjacent
    entries for one user blocked every user queued behind them forever.
    On a match, a TradeSession is registered in ``active_trades`` and both
    users are sent a ``trade_start`` message.
    """
    async with trade_queue_lock:
        if len(trade_queue) < 2:
            return
        head = trade_queue[0]
        partner_index = next(
            (i for i, entry in enumerate(trade_queue) if entry.user_id != head.user_id),
            None,
        )
        if partner_index is None:
            # Every queued entry belongs to the same user; nothing to pair.
            return
        # Pop the later index first so index 0 is still valid afterwards.
        p2 = trade_queue.pop(partner_index)
        p1 = trade_queue.pop(0)

        trade_id = str(uuid.uuid4())
        session = TradeSession(
            trade_id=trade_id,
            offers={
                p1.user_id: TradeOffer(username=p1.username),
                p2.user_id: TradeOffer(username=p2.username),
            },
            connections={
                p1.user_id: p1.websocket,
                p2.user_id: p2.websocket,
            },
        )
        active_trades[trade_id] = session

    # Network I/O happens outside the lock so a slow socket cannot stall the queue.
    for entry in (p1, p2):
        await _send_quietly(entry.websocket, {"type": "trade_start", "trade_id": trade_id})


## Action handling


async def handle_trade_action(
    trade_id: str,
    user_id: str,
    message: dict,
    db: Session,
) -> None:
    """Apply one client action (``message["type"]``) to a live trade.

    Supported actions: ``update_offer`` (with ``card_ids``), ``accept`` and
    ``unaccept``. Unknown trades, non-participants and unknown actions are
    ignored.
    """
    session = active_trades.get(trade_id)
    if not session:
        return
    # A user who is not a party to this trade must not touch it (and would
    # otherwise raise KeyError on session.offers[user_id]).
    if user_id not in session.offers:
        return

    action = message.get("type")
    ws = session.connections.get(user_id)

    if action == "update_offer":
        await _handle_update_offer(session, user_id, message, db, ws)
    elif action == "accept":
        await _handle_accept(session, trade_id, user_id, db, ws)
    elif action == "unaccept":
        session.offers[user_id].accepted = False
        await broadcast_trade(session)


async def _handle_update_offer(
    session: TradeSession,
    user_id: str,
    message: dict,
    db: Session,
    ws: "WebSocket | None",
) -> None:
    """Replace ``user_id``'s offer with the cards listed in ``message["card_ids"]``."""
    card_ids = message.get("card_ids", [])
    if card_ids:
        # Untrusted client input: require a JSON array of strings before parsing.
        if not isinstance(card_ids, list) or not all(isinstance(cid, str) for cid in card_ids):
            await _send_error(ws, "Invalid card IDs")
            return
        try:
            parsed_ids = [uuid.UUID(cid) for cid in card_ids]
        except ValueError:
            await _send_error(ws, "Invalid card IDs")
            return
        # A card can only be offered once; drop repeats, keeping first-seen order.
        unique_ids = list(dict.fromkeys(parsed_ids))

        db_cards = db.query(CardModel).filter(
            CardModel.id.in_(unique_ids),
            CardModel.user_id == uuid.UUID(user_id),
        ).all()
        if len(db_cards) != len(unique_ids):
            await _send_error(ws, "Some cards are not in your collection")
            return

        # Preserve the client's order. Keying by UUID (not the raw client string)
        # means IDs sent in non-canonical form are no longer silently dropped.
        card_map = {uuid.UUID(str(c.id)): c for c in db_cards}
        session.offers[user_id].cards = [serialize_card_model(card_map[pid]) for pid in unique_ids]
    else:
        session.offers[user_id].cards = []

    # Any offer change unaccepts both sides.
    _reset_acceptance(session)
    await broadcast_trade(session)


async def _handle_accept(
    session: TradeSession,
    trade_id: str,
    user_id: str,
    db: Session,
    ws: "WebSocket | None",
) -> None:
    """Mark ``user_id`` as accepted, completing the trade once both sides have."""
    # An empty trade (no cards on either side) cannot be accepted.
    if not any(o.cards for o in session.offers.values()):
        return

    # Validate ownership of this user's offered cards one more time.
    my_offer = session.offers[user_id]
    if my_offer.cards:
        owned_count = db.query(CardModel).filter(
            CardModel.id.in_([uuid.UUID(c["id"]) for c in my_offer.cards]),
            CardModel.user_id == uuid.UUID(user_id),
        ).count()
        if owned_count != len(my_offer.cards):
            await _send_error(ws, "Some offered cards are no longer in your collection")
            return

    my_offer.accepted = True
    if all(o.accepted for o in session.offers.values()):
        await _complete_trade(trade_id, db)
    else:
        await broadcast_trade(session)


## Trade completion


def _reset_acceptance(session: TradeSession) -> None:
    """Clear the accepted flag on every offer in ``session``."""
    for offer in session.offers.values():
        offer.accepted = False


def _load_owned_cards(db: Session, owner_id: str, card_dicts: list[dict]) -> "list[CardModel] | None":
    """Return the Card rows for ``card_dicts`` if ``owner_id`` still owns all of them, else None."""
    if not card_dicts:
        return []
    card_uuids = [uuid.UUID(c["id"]) for c in card_dicts]
    rows = db.query(CardModel).filter(
        CardModel.id.in_(card_uuids),
        CardModel.user_id == uuid.UUID(owner_id),
    ).all()
    return rows if len(rows) == len(card_uuids) else None


def _transfer_cards(db: Session, cards: list, new_owner_id: str) -> None:
    """Reassign ``cards`` to ``new_owner_id`` and remove them from every deck (not committed)."""
    new_owner = uuid.UUID(new_owner_id)
    for card in cards:
        card.user_id = new_owner
        db.query(DeckCardModel).filter(DeckCardModel.card_id == card.id).delete()


async def _abort_trade(session: TradeSession, message: str) -> None:
    """Notify both sides of a failed completion, unaccept, and re-broadcast state."""
    for ws in list(session.connections.values()):
        await _send_quietly(ws, {"type": "error", "message": message})
    _reset_acceptance(session)
    await broadcast_trade(session)


async def _complete_trade(trade_id: str, db: Session) -> None:
    """Swap the offered cards between both participants and close the trade.

    Re-checks acceptance, non-emptiness and ownership immediately before
    writing. On an ownership mismatch or a database error the transaction is
    rolled back, both users are told, and acceptance is reset so the trade
    stays open; on success the session is removed and both users receive
    ``trade_complete``.
    """
    session = active_trades.get(trade_id)
    if not session:
        return

    # Re-check that both sides are still accepted and have a non-empty offer.
    # A last-second unaccept or offer change (race or client bug) should abort.
    if not all(o.accepted for o in session.offers.values()):
        await broadcast_trade(session)
        return
    if not any(o.cards for o in session.offers.values()):
        _reset_acceptance(session)
        await broadcast_trade(session)
        return

    u1, u2 = list(session.offers)[:2]
    cards_u1 = session.offers[u1].cards  # u1 gives these to u2
    cards_u2 = session.offers[u2].cards  # u2 gives these to u1

    # Final ownership check; the loaded rows are reused for the transfer.
    rows_u1 = _load_owned_cards(db, u1, cards_u1)
    rows_u2 = _load_owned_cards(db, u2, cards_u2)
    if rows_u1 is None or rows_u2 is None:
        db.rollback()
        await _abort_trade(session, "Trade failed: ownership check failed. Offers have been reset.")
        return

    try:
        _transfer_cards(db, rows_u1, u2)
        _transfer_cards(db, rows_u2, u1)
        db.commit()
    except SQLAlchemyError:
        # Without a rollback the session would be left half-written and both
        # sides stuck in the accepted state.
        db.rollback()
        logger.exception("Trade %s failed while writing ownership changes", trade_id)
        await _abort_trade(session, "Trade failed due to a server error. Please try again.")
        return

    active_trades.pop(trade_id, None)
    for ws in list(session.connections.values()):
        await _send_quietly(ws, {"type": "trade_complete"})


## Disconnect handling


async def handle_trade_disconnect(trade_id: str, user_id: str) -> None:
    """Cancel the trade when a participant disconnects and tell the other side.

    Disconnects from users who are not participants are ignored, so they
    cannot cancel someone else's trade.
    """
    session = active_trades.get(trade_id)
    if not session:
        return
    if user_id not in session.offers:
        return

    active_trades.pop(trade_id, None)
    for uid, ws in list(session.connections.items()):
        if uid == user_id:
            continue
        await _send_quietly(ws, {
            "type": "error",
            "message": "Your trade partner disconnected. Trade cancelled.",
        })