# File: wiki-tcg/backend/trade_manager.py
# Snapshot: 2026-03-26 00:51:25 +01:00 (284 lines, 7.7 KiB, Python)

import asyncio
import uuid
import logging
from dataclasses import dataclass, field
from fastapi import WebSocket
from sqlalchemy.orm import Session
from models import Card as CardModel, DeckCard as DeckCardModel
# Module-level logger, registered under the shared name "app".
logger = logging.getLogger(name="app")
## Storage
@dataclass
class TradeOffer:
    """One participant's side of a trade.

    ``username`` is the name shown to the trade partner, ``cards`` holds the
    serialized card dicts currently on offer (as produced by
    ``serialize_card_model``), and ``accepted`` records whether this side has
    confirmed the current state of the trade.
    """

    username: str
    cards: list[dict] = field(default_factory=list)  # fresh list per instance
    accepted: bool = field(default=False)
@dataclass
class TradeSession:
    """State of one live trade.

    ``try_trade_match`` builds each session with two participants; both
    ``offers`` and ``connections`` are keyed by user_id.
    """

    trade_id: str
    offers: dict[str, TradeOffer]  # user_id -> that user's current offer
    connections: dict[str, WebSocket] = field(default_factory=dict)  # user_id -> socket


# Live trade sessions keyed by trade_id. Entries are added by try_trade_match
# and removed when a trade completes or a participant disconnects.
active_trades: dict[str, TradeSession] = dict()
@dataclass
class TradeQueueEntry:
    """A user waiting for a trade partner, plus the socket used to notify them."""

    user_id: str
    username: str
    websocket: WebSocket


# Users waiting to be matched. try_trade_match pairs entries from the front
# of this list while holding trade_queue_lock.
trade_queue: list[TradeQueueEntry] = list()
trade_queue_lock = asyncio.Lock()
## Serialization
def serialize_card_model(card: CardModel) -> dict:
    """Convert a Card ORM row into a plain dict for trade offers.

    ``id`` is stringified and ``created_at`` is rendered with ``isoformat()``
    (or ``None`` when unset); every other field is copied unchanged.
    """
    copied_fields = (
        "name",
        "card_rarity",
        "card_type",
        "image_link",
        "attack",
        "defense",
        "cost",
        "text",
    )
    result = {"id": str(card.id)}
    result.update((attr, getattr(card, attr)) for attr in copied_fields)
    stamp = card.created_at
    result["created_at"] = stamp.isoformat() if stamp else None
    return result
def serialize_trade(session: TradeSession, perspective_user_id: str) -> dict:
    """Render *session* from the point of view of *perspective_user_id*.

    The viewer's own offer is reported under ``my_offer``, the other
    participant's under ``their_offer``, together with the partner's username.

    Raises:
        StopIteration: if the session has no participant other than the viewer.
        KeyError: if the viewer has no offer in the session.
    """
    other_ids = (uid for uid in session.offers if uid != perspective_user_id)
    partner_id = next(other_ids)
    mine = session.offers[perspective_user_id]
    theirs = session.offers[partner_id]

    def side(offer):
        return {"cards": offer.cards, "accepted": offer.accepted}

    return {
        "trade_id": session.trade_id,
        "partner_username": theirs.username,
        "my_offer": side(mine),
        "their_offer": side(theirs),
    }
## Broadcasting
async def broadcast_trade(session: TradeSession) -> None:
    """Push the current trade state to every connected participant.

    Each socket receives ``{"type": "state", "state": ...}`` serialized from
    that user's own perspective. Delivery is best-effort: a failed send
    (e.g. a socket that has already closed) is logged and skipped so the
    remaining participants still receive the update.
    """
    for user_id, ws in list(session.connections.items()):
        # Serialize outside the try so a serialization bug surfaces instead of
        # being mistaken for a dead socket.
        payload = {"type": "state", "state": serialize_trade(session, user_id)}
        try:
            await ws.send_json(payload)
        except Exception:
            logger.debug(
                "Failed to send state of trade %s to user %s",
                session.trade_id,
                user_id,
                exc_info=True,
            )
## Matchmaking
async def try_trade_match() -> None:
    """Pair the entry at the front of ``trade_queue`` with a partner and open a trade.

    At most one pair is formed per call. The head entry is matched with the
    first later entry that belongs to a *different* user. A user queued twice
    therefore no longer blocks everyone behind them; their extra entry simply
    stays in the queue. The new session is registered in ``active_trades``,
    and both sockets are sent ``{"type": "trade_start", "trade_id": ...}``.

    Notifications are sent after ``trade_queue_lock`` is released, so a slow
    or dead socket cannot stall other matchmaking calls. Send failures are
    logged and otherwise ignored.
    """
    async with trade_queue_lock:
        if len(trade_queue) < 2:
            return
        head_user = trade_queue[0].user_id
        # Index 0 never matches (same user), so this finds the first later
        # entry from someone else.
        partner_index = next(
            (i for i, entry in enumerate(trade_queue) if entry.user_id != head_user),
            None,
        )
        if partner_index is None:
            # Every queued entry belongs to the same user; nobody to pair with yet.
            return
        # Pop the later index first so index 0 still refers to the head entry.
        p2 = trade_queue.pop(partner_index)
        p1 = trade_queue.pop(0)

        trade_id = str(uuid.uuid4())
        active_trades[trade_id] = TradeSession(
            trade_id=trade_id,
            offers={
                p1.user_id: TradeOffer(username=p1.username),
                p2.user_id: TradeOffer(username=p2.username),
            },
            connections={
                p1.user_id: p1.websocket,
                p2.user_id: p2.websocket,
            },
        )

    for entry in (p1, p2):
        try:
            await entry.websocket.send_json({"type": "trade_start", "trade_id": trade_id})
        except Exception:
            logger.warning(
                "Could not notify user %s that trade %s started",
                entry.user_id,
                trade_id,
                exc_info=True,
            )
## Action handling
async def handle_trade_action(
    trade_id: str,
    user_id: str,
    message: dict,
    db: Session,
) -> None:
    """Apply one client message to the trade identified by *trade_id*.

    Supported ``message["type"]`` values:

    * ``"update_offer"``: replace the sender's offer with the cards listed in
      ``message["card_ids"]``, which must all be owned by the sender. Any
      offer change clears both sides' acceptance.
    * ``"accept"``: accept the current offers after re-checking ownership of
      the sender's cards. Once both sides have accepted, ``_complete_trade``
      runs.
    * ``"unaccept"``: withdraw the sender's acceptance.

    Messages for unknown trades, from users who are not participants in the
    trade, or with any other type are ignored. Validation failures are
    reported to the sender as ``{"type": "error", "message": ...}``.
    """
    session = active_trades.get(trade_id)
    if not session or user_id not in session.offers:
        # Unknown/finished trade, or a user who is not part of this trade.
        return

    action = message.get("type")
    ws = session.connections.get(user_id)

    if action == "update_offer":
        await _update_offer(session, user_id, message.get("card_ids", []), db, ws)
    elif action == "accept":
        await _accept_offer(session, trade_id, user_id, db, ws)
    elif action == "unaccept":
        session.offers[user_id].accepted = False
        await broadcast_trade(session)


def _parse_card_ids(raw):
    """Parse client-supplied card ids into unique UUIDs, keeping first-seen order.

    Returns ``None`` when *raw* is not a list/tuple of valid UUID strings.
    """
    if not isinstance(raw, (list, tuple)) or not all(isinstance(cid, str) for cid in raw):
        return None
    try:
        parsed = [uuid.UUID(cid) for cid in raw]
    except ValueError:
        return None
    # Collapse duplicates so "[X, X]" is not misreported as an unowned card.
    return list(dict.fromkeys(parsed))


async def _update_offer(session, user_id, card_ids, db, ws) -> None:
    """Validate *card_ids* against *user_id*'s collection and store them as the offer."""
    if card_ids:
        parsed_ids = _parse_card_ids(card_ids)
        if parsed_ids is None:
            if ws:
                await ws.send_json({"type": "error", "message": "Invalid card IDs"})
            return
        db_cards = db.query(CardModel).filter(
            CardModel.id.in_(parsed_ids),
            CardModel.user_id == uuid.UUID(user_id),
        ).all()
        if len(db_cards) != len(parsed_ids):
            if ws:
                await ws.send_json({"type": "error", "message": "Some cards are not in your collection"})
            return
        # Key by canonical UUID text on both sides so ids sent in a
        # non-canonical form (uppercase, no hyphens) are not silently dropped.
        card_map = {str(c.id): c for c in db_cards}
        ordered_keys = [str(cid) for cid in parsed_ids]
        ordered = [card_map[key] for key in ordered_keys if key in card_map]
        session.offers[user_id].cards = [serialize_card_model(c) for c in ordered]
    else:
        session.offers[user_id].cards = []
    # Any offer change unaccepts both sides.
    for offer in session.offers.values():
        offer.accepted = False
    await broadcast_trade(session)


async def _accept_offer(session, trade_id, user_id, db, ws) -> None:
    """Mark *user_id*'s side accepted; complete the trade once both sides agree."""
    if not any(o.cards for o in session.offers.values()):
        # Nothing is being exchanged, so an empty trade cannot be accepted.
        return
    my_offer = session.offers[user_id]
    if my_offer.cards:
        # Cards may have left the collection since they were offered.
        owned_count = db.query(CardModel).filter(
            CardModel.id.in_([uuid.UUID(c["id"]) for c in my_offer.cards]),
            CardModel.user_id == uuid.UUID(user_id),
        ).count()
        if owned_count != len(my_offer.cards):
            if ws:
                await ws.send_json({"type": "error", "message": "Some offered cards are no longer in your collection"})
            return
    my_offer.accepted = True
    if all(o.accepted for o in session.offers.values()):
        await _complete_trade(trade_id, db)
    else:
        await broadcast_trade(session)
## Trade completion
async def _complete_trade(trade_id: str, db: Session) -> None:
    """Swap the offered cards between both participants and close the trade.

    The trade is aborted, and state is re-broadcast, in three cases:

    * either side is no longer accepted;
    * neither side offers any cards (acceptance is also reset);
    * the final ownership re-check fails (the transaction is rolled back, both
      sides get an error, and acceptance is reset).

    On success each card's ``user_id`` is reassigned and its deck links are
    deleted. The transaction is then committed, the session is removed from
    ``active_trades``, and every socket receives ``{"type": "trade_complete"}``.

    If writing or committing raises, the transaction is rolled back so *db*
    stays usable. Acceptance is then cleared and broadcast, and the exception
    is re-raised.
    """
    session = active_trades.get(trade_id)
    if not session:
        return
    offers = session.offers

    # A last-second unaccept or offer change (race or client bug) aborts.
    if not all(o.accepted for o in offers.values()):
        await broadcast_trade(session)
        return
    if not any(o.cards for o in offers.values()):
        _reset_acceptance(session)
        await broadcast_trade(session)
        return

    user_ids = list(offers.keys())
    u1, u2 = user_ids[0], user_ids[1]
    cards_u1 = offers[u1].cards  # u1 gives these to u2
    cards_u2 = offers[u2].cards  # u2 gives these to u1

    # Final ownership double-check before writing.
    if not (_owns_all(db, u1, cards_u1) and _owns_all(db, u2, cards_u2)):
        db.rollback()
        await _send_to_all(session, {
            "type": "error",
            "message": "Trade failed: ownership check failed. Offers have been reset.",
        })
        _reset_acceptance(session)
        await broadcast_trade(session)
        return

    try:
        _transfer_cards(db, cards_u1, u2)
        _transfer_cards(db, cards_u2, u1)
        db.commit()
    except Exception:
        # A failed flush/commit leaves the Session unusable until it is rolled back.
        db.rollback()
        logger.exception("Trade %s failed while writing ownership changes", trade_id)
        _reset_acceptance(session)
        await broadcast_trade(session)
        raise

    active_trades.pop(trade_id, None)
    await _send_to_all(session, {"type": "trade_complete"})


def _reset_acceptance(session: TradeSession) -> None:
    """Clear the accepted flag on every offer in *session*."""
    for offer in session.offers.values():
        offer.accepted = False


def _owns_all(db: Session, user_id: str, card_dicts: list[dict]) -> bool:
    """Return True if *user_id* owns every card in *card_dicts* (True when empty)."""
    if not card_dicts:
        return True
    card_uuids = [uuid.UUID(c["id"]) for c in card_dicts]
    owned = db.query(CardModel).filter(
        CardModel.id.in_(card_uuids),
        CardModel.user_id == uuid.UUID(user_id),
    ).count()
    return owned == len(card_uuids)


def _transfer_cards(db: Session, card_dicts: list[dict], to_user_id: str) -> None:
    """Reassign each card to *to_user_id* and delete its deck links (does not commit)."""
    new_owner = uuid.UUID(to_user_id)
    for card_dict in card_dicts:
        cid = uuid.UUID(card_dict["id"])
        card = db.query(CardModel).filter(CardModel.id == cid).first()
        if card:
            card.user_id = new_owner
        db.query(DeckCardModel).filter(DeckCardModel.card_id == cid).delete()


async def _send_to_all(session: TradeSession, payload: dict) -> None:
    """Best-effort send of *payload* to every socket in *session*; failures are logged."""
    for uid, ws in list(session.connections.items()):
        try:
            await ws.send_json(payload)
        except Exception:
            logger.debug(
                "Failed to send %s for trade %s to user %s",
                payload.get("type"),
                session.trade_id,
                uid,
                exc_info=True,
            )
## Disconnect handling
async def handle_trade_disconnect(trade_id: str, user_id: str) -> None:
    """Cancel trade *trade_id* because *user_id* disconnected.

    The session is removed from ``active_trades``, and every *other*
    participant is sent an error announcing the cancellation. Notification is
    best-effort: send failures are logged and skipped. This is a no-op when
    the trade is not (or no longer) active.
    """
    session = active_trades.pop(trade_id, None)
    if session is None:
        return
    for uid, ws in list(session.connections.items()):
        if uid == user_id:
            continue
        try:
            await ws.send_json({
                "type": "error",
                "message": "Your trade partner disconnected. Trade cancelled.",
            })
        except Exception:
            logger.debug(
                "Failed to notify user %s that trade %s was cancelled",
                uid,
                trade_id,
                exc_info=True,
            )