🐐
This commit is contained in:
0
backend/ai/__init__.py
Normal file
0
backend/ai/__init__.py
Normal file
176
backend/ai/card_pick_nn.py
Normal file
176
backend/ai/card_pick_nn.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai.nn import NeuralNet, _softmax
|
||||
|
||||
# Separate weights file so this NN trains independently from the plan NN.
|
||||
CARD_PICK_WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), "card_pick_weights.json")
|
||||
|
||||
N_CARD_FEATURES = 15
|
||||
|
||||
# Normalization constants — chosen to cover the realistic stat range for generated cards.
|
||||
_MAX_ATK = 50.0
|
||||
_MAX_DEF = 100.0
|
||||
|
||||
|
||||
def _precompute_static_features(allowed: list) -> np.ndarray:
|
||||
"""
|
||||
Vectorized precomputation of the 7 per-card static features for the whole pool.
|
||||
Returns (n, 7) float32. Called once per choose_cards() invocation.
|
||||
"""
|
||||
n = len(allowed)
|
||||
atk = np.array([c.attack for c in allowed], dtype=np.float32)
|
||||
defn = np.array([c.defense for c in allowed], dtype=np.float32)
|
||||
cost = np.array([c.cost for c in allowed], dtype=np.float32)
|
||||
rar = np.array([c.card_rarity.value for c in allowed], dtype=np.float32)
|
||||
typ = np.array([c.card_type.value for c in allowed], dtype=np.float32)
|
||||
|
||||
exact_cost = np.minimum(10.0, np.maximum(1.0, ((atk**2 + defn**2)**0.18) / 1.5))
|
||||
total = atk + defn
|
||||
atk_ratio = np.where(total > 0, atk / total, 0.5)
|
||||
pcv_norm = np.clip(exact_cost - cost, 0.0, 1.0)
|
||||
|
||||
out = np.empty((n, 7), dtype=np.float32)
|
||||
out[:, 0] = atk / _MAX_ATK
|
||||
out[:, 1] = defn / _MAX_DEF
|
||||
out[:, 2] = cost / 10.0
|
||||
out[:, 3] = rar / 5.0
|
||||
out[:, 4] = atk_ratio
|
||||
out[:, 5] = pcv_norm
|
||||
out[:, 6] = typ / 9.0
|
||||
return out
|
||||
|
||||
|
||||
class CardPickPlayer:
|
||||
"""
|
||||
Uses a NeuralNet to sequentially select cards from a pool until the cost
|
||||
budget is exhausted. API mirrors NeuralPlayer so training code stays uniform.
|
||||
|
||||
In training mode: samples stochastically (softmax) and records the
|
||||
trajectory for a REINFORCE update after the game ends.
|
||||
In inference mode: picks the highest-scoring affordable card at each step.
|
||||
|
||||
Performance design:
|
||||
- Static per-card features (7) are computed once via vectorized numpy.
|
||||
- Context features (8) use running totals updated by O(1) increments.
|
||||
- Picked cards are tracked with a boolean mask; no list.remove() calls.
|
||||
- Each pick step does one small forward pass over the affordable subset only.
|
||||
"""
|
||||
|
||||
def __init__(self, net: NeuralNet, training: bool = False, temperature: float = 1.0):
|
||||
self.net = net
|
||||
self.training = training
|
||||
self.temperature = temperature
|
||||
self.trajectory: list[tuple[np.ndarray, int]] = [] # (features_matrix, chosen_idx)
|
||||
|
||||
def choose_cards(self, allowed: list, difficulty: int) -> list:
|
||||
"""
|
||||
allowed: pre-filtered list of Card objects (cost ≤ max_card_cost already applied).
|
||||
Returns the selected deck as a list of Cards.
|
||||
"""
|
||||
BUDGET = 50
|
||||
n = len(allowed)
|
||||
|
||||
static = _precompute_static_features(allowed) # (n, 7) — computed once
|
||||
costs = np.array([c.cost for c in allowed], dtype=np.float32)
|
||||
picked = np.zeros(n, dtype=bool)
|
||||
|
||||
budget_remaining = BUDGET
|
||||
selected: list = []
|
||||
|
||||
# Running totals for context features — incremented O(1) per pick.
|
||||
n_picked = 0
|
||||
sum_atk = 0.0
|
||||
sum_def = 0.0
|
||||
sum_cost = 0.0
|
||||
n_cheap = 0 # cost ≤ 3
|
||||
n_high = 0 # cost ≥ 6
|
||||
|
||||
diff_norm = difficulty / 10.0
|
||||
|
||||
while True:
|
||||
mask = (~picked) & (costs <= budget_remaining)
|
||||
if not mask.any():
|
||||
break
|
||||
|
||||
idxs = np.where(mask)[0]
|
||||
|
||||
# Context row — same for every candidate this step, broadcast via tile.
|
||||
if n_picked > 0:
|
||||
ctx = np.array([
|
||||
n_picked / 30.0,
|
||||
budget_remaining / 50.0,
|
||||
sum_atk / n_picked / _MAX_ATK,
|
||||
sum_def / n_picked / _MAX_DEF,
|
||||
sum_cost / n_picked / 10.0,
|
||||
n_cheap / n_picked,
|
||||
n_high / n_picked,
|
||||
diff_norm,
|
||||
], dtype=np.float32)
|
||||
else:
|
||||
ctx = np.array([
|
||||
0.0, budget_remaining / 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, diff_norm,
|
||||
], dtype=np.float32)
|
||||
|
||||
features = np.concatenate(
|
||||
[static[idxs], np.tile(ctx, (len(idxs), 1))],
|
||||
axis=1,
|
||||
)
|
||||
scores = self.net.forward(features)
|
||||
|
||||
if self.training:
|
||||
probs = _softmax((scores / self.temperature).astype(np.float64))
|
||||
probs = np.clip(probs, 1e-10, None)
|
||||
probs /= probs.sum()
|
||||
local_idx = int(np.random.choice(len(idxs), p=probs))
|
||||
self.trajectory.append((features, local_idx))
|
||||
else:
|
||||
local_idx = int(np.argmax(scores))
|
||||
|
||||
global_idx = idxs[local_idx]
|
||||
card = allowed[global_idx]
|
||||
picked[global_idx] = True
|
||||
selected.append(card)
|
||||
|
||||
# Incremental context update — O(1).
|
||||
budget_remaining -= card.cost
|
||||
n_picked += 1
|
||||
sum_atk += card.attack
|
||||
sum_def += card.defense
|
||||
sum_cost += card.cost
|
||||
if card.cost <= 3: n_cheap += 1
|
||||
if card.cost >= 6: n_high += 1
|
||||
|
||||
return selected
|
||||
|
||||
def compute_grads(self, outcome: float) -> tuple[list, list] | None:
|
||||
"""
|
||||
REINFORCE gradients averaged over the pick trajectory.
|
||||
outcome: centered reward (win/loss minus baseline).
|
||||
Returns (grads_w, grads_b), or None if no picks were made.
|
||||
"""
|
||||
if not self.trajectory:
|
||||
return None
|
||||
|
||||
acc_gw = [np.zeros_like(w) for w in self.net.weights]
|
||||
acc_gb = [np.zeros_like(b) for b in self.net.biases]
|
||||
|
||||
for features, chosen_idx in self.trajectory:
|
||||
scores = self.net.forward(features)
|
||||
probs = _softmax(scores.astype(np.float64)).astype(np.float32)
|
||||
upstream = -probs.copy()
|
||||
upstream[chosen_idx] += 1.0
|
||||
upstream *= outcome
|
||||
gw, gb = self.net.backward(upstream)
|
||||
for i in range(len(acc_gw)):
|
||||
acc_gw[i] += gw[i]
|
||||
acc_gb[i] += gb[i]
|
||||
|
||||
n = len(self.trajectory)
|
||||
for i in range(len(acc_gw)):
|
||||
acc_gw[i] /= n
|
||||
acc_gb[i] /= n
|
||||
|
||||
self.trajectory.clear()
|
||||
return acc_gw, acc_gb
|
||||
1
backend/ai/card_pick_weights.json
Normal file
1
backend/ai/card_pick_weights.json
Normal file
File diff suppressed because one or more lines are too long
@@ -1,12 +1,15 @@
|
||||
import asyncio
|
||||
import random
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from itertools import combinations, permutations
|
||||
|
||||
import numpy as np
|
||||
from card import Card
|
||||
from game import action_play_card, action_sacrifice, action_end_turn, BOARD_SIZE, STARTING_LIFE, PlayerState
|
||||
|
||||
from game.card import Card
|
||||
from game.rules import action_play_card, action_sacrifice, action_end_turn, BOARD_SIZE, STARTING_LIFE, PlayerState
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
@@ -77,7 +80,21 @@ def choose_cards(cards: list[Card], difficulty: int, personality: AIPersonality)
|
||||
elif personality == AIPersonality.CONTROL:
|
||||
# Small cost_norm keeps flavour without causing severe deck shrinkage at D10
|
||||
scores = 0.85 * pcv_norm + 0.15 * cost_norm
|
||||
elif personality in (AIPersonality.BALANCED, AIPersonality.JEBRASKA):
|
||||
elif personality == AIPersonality.BALANCED:
|
||||
scores = 0.60 * pcv_norm + 0.25 * atk_ratio + 0.15 * (1.0 - atk_ratio)
|
||||
elif personality == AIPersonality.JEBRASKA:
|
||||
# Delegate entirely to the card-pick NN; skip the heuristic scoring path.
|
||||
from ai.card_pick_nn import CardPickPlayer, CARD_PICK_WEIGHTS_PATH
|
||||
from ai.nn import NeuralNet
|
||||
if not hasattr(choose_cards, "_card_pick_net"):
|
||||
choose_cards._card_pick_net = (
|
||||
NeuralNet.load(CARD_PICK_WEIGHTS_PATH)
|
||||
if os.path.exists(CARD_PICK_WEIGHTS_PATH) else None
|
||||
)
|
||||
net = choose_cards._card_pick_net
|
||||
if net is not None:
|
||||
return CardPickPlayer(net, training=False).choose_cards(allowed, difficulty)
|
||||
# Fall through to BALANCED heuristic if weights aren't trained yet.
|
||||
scores = 0.60 * pcv_norm + 0.25 * atk_ratio + 0.15 * (1.0 - atk_ratio)
|
||||
else: # ARBITRARY
|
||||
w = 0.09 * difficulty
|
||||
@@ -97,7 +114,7 @@ def choose_cards(cards: list[Card], difficulty: int, personality: AIPersonality)
|
||||
AIPersonality.DEFENSIVE: 15, # raised: stable cheap-card base across difficulty levels
|
||||
AIPersonality.CONTROL: 8,
|
||||
AIPersonality.BALANCED: 25, # spread the deck across all cost levels
|
||||
AIPersonality.JEBRASKA: 25, # same as balanced
|
||||
AIPersonality.JEBRASKA: 25, # fallback (no trained weights yet)
|
||||
AIPersonality.ARBITRARY: 8,
|
||||
}[personality]
|
||||
|
||||
@@ -320,14 +337,14 @@ def choose_plan(player: PlayerState, opponent: PlayerState, personality: AIPerso
|
||||
plans = generate_plans(player, opponent)
|
||||
|
||||
if personality == AIPersonality.JEBRASKA:
|
||||
from nn import NeuralNet
|
||||
from ai.nn import NeuralNet
|
||||
import os
|
||||
_weights = os.path.join(os.path.dirname(__file__), "nn_weights.json")
|
||||
if not hasattr(choose_plan, "_neural_net"):
|
||||
choose_plan._neural_net = NeuralNet.load(_weights) if os.path.exists(_weights) else None
|
||||
net = choose_plan._neural_net
|
||||
if net is not None:
|
||||
from nn import extract_plan_features
|
||||
from ai.nn import extract_plan_features
|
||||
scores = net.forward(extract_plan_features(plans, player, opponent))
|
||||
else: # fallback to BALANCED if weights not found
|
||||
scores = score_plans_batch(plans, player, opponent, AIPersonality.BALANCED)
|
||||
@@ -339,7 +356,7 @@ def choose_plan(player: PlayerState, opponent: PlayerState, personality: AIPerso
|
||||
return plans[int(np.argmax(scores + noise))]
|
||||
|
||||
async def run_ai_turn(game_id: str):
|
||||
from game_manager import (
|
||||
from game.manager import (
|
||||
active_games, connections, active_deck_ids,
|
||||
serialize_state, record_game_result, calculate_combat_animation_time
|
||||
)
|
||||
@@ -421,7 +438,7 @@ async def run_ai_turn(game_id: str):
|
||||
await send_state(state)
|
||||
|
||||
if state.result:
|
||||
from database import SessionLocal
|
||||
from core.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
record_game_result(state, db)
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Layout: [state(8) | my_board(15) | opp_board(15) | plan(3) | result_board(15) | opp_deck_type(8)]
|
||||
N_FEATURES = 64
|
||||
|
||||
@@ -137,7 +138,7 @@ def extract_plan_features(plans: list, player, opponent) -> np.ndarray:
|
||||
Returns (n_plans, N_FEATURES) float32 array.
|
||||
Layout: [state(8) | my_board(15) | opp_board(15) | plan(3) | result_board(15)]
|
||||
"""
|
||||
from game import BOARD_SIZE, HAND_SIZE, MAX_ENERGY_CAP, STARTING_LIFE
|
||||
from game.rules import BOARD_SIZE, HAND_SIZE, MAX_ENERGY_CAP, STARTING_LIFE
|
||||
|
||||
n = len(plans)
|
||||
|
||||
@@ -217,7 +218,7 @@ class NeuralPlayer:
|
||||
self.trajectory: list[tuple[np.ndarray, int]] = [] # (features, chosen_idx)
|
||||
|
||||
def choose_plan(self, player, opponent):
|
||||
from ai import generate_plans
|
||||
from ai.engine import generate_plans
|
||||
plans = generate_plans(player, opponent)
|
||||
features = extract_plan_features(plans, player, opponent)
|
||||
scores = self.net.forward(features)
|
||||
1
backend/ai/nn_weights.json
Normal file
1
backend/ai/nn_weights.json
Normal file
File diff suppressed because one or more lines are too long
@@ -1,21 +1,21 @@
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
import asyncio
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from card import Card, CardType, CardRarity, generate_cards, compute_deck_type
|
||||
from game import (
|
||||
from game.card import Card, CardType, CardRarity, generate_cards, compute_deck_type
|
||||
from game.rules import (
|
||||
CardInstance, PlayerState, GameState,
|
||||
action_play_card, action_sacrifice, action_end_turn,
|
||||
)
|
||||
from ai import AIPersonality, choose_cards, choose_plan
|
||||
from ai.engine import AIPersonality, choose_cards, choose_plan
|
||||
|
||||
SIMULATION_CARDS_PATH = os.path.join(os.path.dirname(__file__), "simulation_cards.json")
|
||||
SIMULATION_CARD_COUNT = 1000
|
||||
@@ -24,7 +24,7 @@ SIMULATION_CARD_COUNT = 1000
|
||||
def _card_to_dict(card: Card) -> dict:
|
||||
return {
|
||||
"name": card.name,
|
||||
"created_at": card.created_at.isoformat(),
|
||||
"generated_at": card.generated_at.isoformat(),
|
||||
"image_link": card.image_link,
|
||||
"card_rarity": card.card_rarity.name,
|
||||
"card_type": card.card_type.name,
|
||||
@@ -39,7 +39,7 @@ def _card_to_dict(card: Card) -> dict:
|
||||
def _dict_to_card(d: dict) -> Card:
|
||||
return Card(
|
||||
name=d["name"],
|
||||
created_at=datetime.fromisoformat(d["created_at"]),
|
||||
generated_at=datetime.fromisoformat(d["generated_at"]),
|
||||
image_link=d["image_link"],
|
||||
card_rarity=CardRarity[d["card_rarity"]],
|
||||
card_type=CardType[d["card_type"]],
|
||||
@@ -609,7 +609,7 @@ def draw_grid(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
difficulties = list(range(7, 11))
|
||||
difficulties = list(range(8, 11))
|
||||
|
||||
card_pool = get_simulation_cards()
|
||||
players = _all_players(difficulties)
|
||||
@@ -1,27 +1,39 @@
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from card import compute_deck_type
|
||||
from ai import AIPersonality, choose_cards, choose_plan
|
||||
from game import PlayerState, GameState, action_play_card, action_sacrifice, action_end_turn
|
||||
from simulate import get_simulation_cards, _make_instances, MAX_TURNS
|
||||
from nn import NeuralNet, NeuralPlayer
|
||||
from game.card import compute_deck_type
|
||||
from ai.engine import AIPersonality, choose_cards, choose_plan
|
||||
from game.rules import PlayerState, GameState, action_play_card, action_sacrifice, action_end_turn
|
||||
from ai.simulate import get_simulation_cards, _make_instances, MAX_TURNS
|
||||
from ai.nn import NeuralNet, NeuralPlayer
|
||||
from ai.card_pick_nn import CardPickPlayer, N_CARD_FEATURES, CARD_PICK_WEIGHTS_PATH
|
||||
|
||||
NN_WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), "nn_weights.json")
|
||||
|
||||
P1 = "p1"
|
||||
P2 = "p2"
|
||||
|
||||
FIXED_PERSONALITIES = [p for p in AIPersonality if p != AIPersonality.ARBITRARY]
|
||||
FIXED_PERSONALITIES = [
|
||||
p for p in AIPersonality
|
||||
if p not in (
|
||||
AIPersonality.ARBITRARY,
|
||||
AIPersonality.JEBRASKA
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _build_player(pid: str, name: str, cards: list, difficulty: int, personality: AIPersonality) -> PlayerState:
|
||||
deck = choose_cards(cards, difficulty, personality)
|
||||
def _build_player(pid: str, name: str, cards: list, difficulty: int, personality: AIPersonality,
|
||||
deck_pool: dict | None = None) -> PlayerState:
|
||||
if deck_pool and personality in deck_pool:
|
||||
deck = random.choice(deck_pool[personality])
|
||||
else:
|
||||
deck = choose_cards(cards, difficulty, personality)
|
||||
instances = _make_instances(deck)
|
||||
random.shuffle(instances)
|
||||
p = PlayerState(
|
||||
@@ -32,6 +44,21 @@ def _build_player(pid: str, name: str, cards: list, difficulty: int, personality
|
||||
return p
|
||||
|
||||
|
||||
def _build_nn_player(pid: str, name: str, cards: list, difficulty: int,
|
||||
card_pick_player: CardPickPlayer) -> PlayerState:
|
||||
"""Build a PlayerState using the card-pick NN for deck selection."""
|
||||
max_card_cost = difficulty + 1 if difficulty >= 6 else 6
|
||||
allowed = [c for c in cards if c.cost <= max_card_cost] or list(cards)
|
||||
deck = card_pick_player.choose_cards(allowed, difficulty)
|
||||
instances = _make_instances(deck)
|
||||
random.shuffle(instances)
|
||||
return PlayerState(
|
||||
user_id=pid, username=name,
|
||||
deck_type=compute_deck_type(deck) or "Balanced",
|
||||
deck=instances,
|
||||
)
|
||||
|
||||
|
||||
def run_episode(
|
||||
p1_state: PlayerState,
|
||||
p2_state: PlayerState,
|
||||
@@ -81,25 +108,40 @@ def run_episode(
|
||||
|
||||
|
||||
def train(
|
||||
n_episodes: int = 20_000,
|
||||
self_play_start: int = 5_000,
|
||||
self_play_max_frac: float = 0.4,
|
||||
n_episodes: int = 50_000,
|
||||
self_play_start: int = 0,
|
||||
self_play_max_frac: float = 0.9,
|
||||
lr: float = 1e-3,
|
||||
opp_difficulty: int = 10,
|
||||
temperature: float = 1.0,
|
||||
batch_size: int = 50,
|
||||
batch_size: int = 500,
|
||||
save_every: int = 5_000,
|
||||
save_path: str = NN_WEIGHTS_PATH,
|
||||
) -> NeuralNet:
|
||||
cards = get_simulation_cards()
|
||||
|
||||
# Pre-build a pool of opponent decks per personality to avoid rebuilding from scratch each episode.
|
||||
DECK_POOL_SIZE = 100
|
||||
opp_deck_pool: dict[AIPersonality, list] = {
|
||||
p: [choose_cards(cards, opp_difficulty, p) for _ in range(DECK_POOL_SIZE)]
|
||||
for p in FIXED_PERSONALITIES
|
||||
}
|
||||
|
||||
if os.path.exists(save_path):
|
||||
print(f"Resuming from {save_path}")
|
||||
print(f"Resuming plan net from {save_path}")
|
||||
net = NeuralNet.load(save_path)
|
||||
else:
|
||||
print("Initializing new network")
|
||||
print("Initializing new plan network")
|
||||
net = NeuralNet(seed=42)
|
||||
|
||||
cp_path = CARD_PICK_WEIGHTS_PATH
|
||||
if os.path.exists(cp_path):
|
||||
print(f"Resuming card-pick net from {cp_path}")
|
||||
card_pick_net = NeuralNet.load(cp_path)
|
||||
else:
|
||||
print("Initializing new card-pick network")
|
||||
card_pick_net = NeuralNet(n_features=N_CARD_FEATURES, hidden=(32, 16), seed=43)
|
||||
|
||||
recent_outcomes: deque[int] = deque(maxlen=1000) # rolling window for win rate display
|
||||
baseline = 0.0 # EMA of recent outcomes; subtracted before each update
|
||||
baseline_alpha = 0.99 # decay — roughly a 100-episode window
|
||||
@@ -108,6 +150,10 @@ def train(
|
||||
batch_gb = [np.zeros_like(b) for b in net.biases]
|
||||
batch_count = 0
|
||||
|
||||
cp_batch_gw = [np.zeros_like(w) for w in card_pick_net.weights]
|
||||
cp_batch_gb = [np.zeros_like(b) for b in card_pick_net.biases]
|
||||
cp_batch_count = 0
|
||||
|
||||
for episode in range(1, n_episodes + 1):
|
||||
# Ramp self-play fraction linearly from 0 to self_play_max_frac
|
||||
if episode >= self_play_start:
|
||||
@@ -122,9 +168,11 @@ def train(
|
||||
if random.random() < self_play_prob:
|
||||
nn1 = NeuralPlayer(net, training=True, temperature=temperature)
|
||||
nn2 = NeuralPlayer(net, training=True, temperature=temperature)
|
||||
cp1 = CardPickPlayer(card_pick_net, training=True, temperature=temperature)
|
||||
cp2 = CardPickPlayer(card_pick_net, training=True, temperature=temperature)
|
||||
|
||||
p1_state = _build_player(P1, "NN1", cards, 10, AIPersonality.BALANCED)
|
||||
p2_state = _build_player(P2, "NN2", cards, 10, AIPersonality.BALANCED)
|
||||
p1_state = _build_nn_player(P1, "NN1", cards, 10, cp1)
|
||||
p2_state = _build_nn_player(P2, "NN2", cards, 10, cp2)
|
||||
|
||||
if not nn_goes_first:
|
||||
p1_state, p2_state = p2_state, p1_state
|
||||
@@ -142,20 +190,30 @@ def train(
|
||||
batch_gb[i] += gb[i]
|
||||
batch_count += 1
|
||||
|
||||
for cp_grads in [cp1.compute_grads(p1_outcome - baseline),
|
||||
cp2.compute_grads(-p1_outcome - baseline)]:
|
||||
if cp_grads is not None:
|
||||
gw, gb = cp_grads
|
||||
for i in range(len(cp_batch_gw)):
|
||||
cp_batch_gw[i] += gw[i]
|
||||
cp_batch_gb[i] += gb[i]
|
||||
cp_batch_count += 1
|
||||
|
||||
else:
|
||||
opp_personality = random.choice(FIXED_PERSONALITIES)
|
||||
nn_player = NeuralPlayer(net, training=True, temperature=temperature)
|
||||
cp_player = CardPickPlayer(card_pick_net, training=True, temperature=temperature)
|
||||
opp_ctrl = lambda p, o, pers=opp_personality, diff=opp_difficulty: choose_plan(p, o, pers, diff)
|
||||
|
||||
if nn_goes_first:
|
||||
nn_id = P1
|
||||
p1_state = _build_player(P1, "NN", cards, 10, AIPersonality.BALANCED)
|
||||
p2_state = _build_player(P2, "OPP", cards, opp_difficulty, opp_personality)
|
||||
p1_state = _build_nn_player(P1, "NN", cards, 10, cp_player)
|
||||
p2_state = _build_player(P2, "OPP", cards, opp_difficulty, opp_personality, opp_deck_pool)
|
||||
winner = run_episode(p1_state, p2_state, nn_player.choose_plan, opp_ctrl)
|
||||
else:
|
||||
nn_id = P2
|
||||
p1_state = _build_player(P1, "OPP", cards, opp_difficulty, opp_personality)
|
||||
p2_state = _build_player(P2, "NN", cards, 10, AIPersonality.BALANCED)
|
||||
p1_state = _build_player(P1, "OPP", cards, opp_difficulty, opp_personality, opp_deck_pool)
|
||||
p2_state = _build_nn_player(P2, "NN", cards, 10, cp_player)
|
||||
winner = run_episode(p1_state, p2_state, opp_ctrl, nn_player.choose_plan)
|
||||
|
||||
nn_outcome = 1.0 if winner == nn_id else -1.0
|
||||
@@ -169,6 +227,14 @@ def train(
|
||||
batch_gb[i] += gb[i]
|
||||
batch_count += 1
|
||||
|
||||
cp_grads = cp_player.compute_grads(nn_outcome - baseline)
|
||||
if cp_grads is not None:
|
||||
gw, gb = cp_grads
|
||||
for i in range(len(cp_batch_gw)):
|
||||
cp_batch_gw[i] += gw[i]
|
||||
cp_batch_gb[i] += gb[i]
|
||||
cp_batch_count += 1
|
||||
|
||||
recent_outcomes.append(1 if winner == nn_id else 0)
|
||||
|
||||
if batch_count >= batch_size:
|
||||
@@ -180,16 +246,29 @@ def train(
|
||||
batch_gb = [np.zeros_like(b) for b in net.biases]
|
||||
batch_count = 0
|
||||
|
||||
if cp_batch_count >= batch_size:
|
||||
for i in range(len(cp_batch_gw)):
|
||||
cp_batch_gw[i] /= cp_batch_count
|
||||
cp_batch_gb[i] /= cp_batch_count
|
||||
card_pick_net.adam_update(cp_batch_gw, cp_batch_gb, lr=lr)
|
||||
cp_batch_gw = [np.zeros_like(w) for w in card_pick_net.weights]
|
||||
cp_batch_gb = [np.zeros_like(b) for b in card_pick_net.biases]
|
||||
cp_batch_count = 0
|
||||
|
||||
if episode % 1000 == 0 or episode == n_episodes:
|
||||
wr = sum(recent_outcomes) / len(recent_outcomes) if recent_outcomes else 0.0
|
||||
print(f"[{episode:>6}/{n_episodes}] win rate (last {len(recent_outcomes)}): {wr:.1%} "
|
||||
print(f"\r[{episode:>6}/{n_episodes}] win rate (last {len(recent_outcomes)}): {wr:.1%} "
|
||||
f"self-play frac: {self_play_prob:.0%}", flush=True)
|
||||
else:
|
||||
print(f" {episode % 1000}/1000", end="\r", flush=True)
|
||||
|
||||
if episode % save_every == 0:
|
||||
net.save(save_path)
|
||||
print(f" → saved to {save_path}")
|
||||
card_pick_net.save(cp_path)
|
||||
print(f" → saved to {save_path} and {cp_path}")
|
||||
|
||||
net.save(save_path)
|
||||
card_pick_net.save(cp_path)
|
||||
wr = sum(recent_outcomes) / len(recent_outcomes) if recent_outcomes else 0.0
|
||||
print(f"Done. Final win rate (last {len(recent_outcomes)}): {wr:.1%}")
|
||||
return net
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool, create_engine
|
||||
|
||||
from alembic import context
|
||||
from models import Base
|
||||
from core.models import Base
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""add trade_wishlist to users
|
||||
|
||||
Revision ID: 0fc168f5970d
|
||||
Revises: e70b992e5d95
|
||||
Create Date: 2026-03-27 23:01:32.739184
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0fc168f5970d'
|
||||
down_revision: Union[str, Sequence[str], None] = 'e70b992e5d95'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('trade_wishlist', sa.Text(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('users', 'trade_wishlist')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,48 @@
|
||||
"""add_game_challenges_table
|
||||
|
||||
Revision ID: 29da7c818b01
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-03-28 23:20:21.949520
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '29da7c818b01'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a1b2c3d4e5f6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('game_challenges',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('challenger_id', sa.UUID(), nullable=False),
|
||||
sa.Column('challenged_id', sa.UUID(), nullable=False),
|
||||
sa.Column('challenger_deck_id', sa.UUID(), nullable=False),
|
||||
sa.Column('status', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['challenged_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['challenger_deck_id'], ['decks.id'], ),
|
||||
sa.ForeignKeyConstraint(['challenger_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.drop_index(op.f('ix_trade_proposals_proposer_status'), table_name='trade_proposals')
|
||||
op.drop_index(op.f('ix_trade_proposals_recipient_status'), table_name='trade_proposals')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index(op.f('ix_trade_proposals_recipient_status'), 'trade_proposals', ['recipient_id', 'status'], unique=False)
|
||||
op.create_index(op.f('ix_trade_proposals_proposer_status'), 'trade_proposals', ['proposer_id', 'status'], unique=False)
|
||||
op.drop_table('game_challenges')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add_processed_webhook_events
|
||||
|
||||
Revision ID: 4603709eb82d
|
||||
Revises: d1e2f3a4b5c6
|
||||
Create Date: 2026-03-30 00:30:05.493030
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4603709eb82d'
|
||||
down_revision: Union[str, Sequence[str], None] = 'd1e2f3a4b5c6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('processed_webhook_events',
|
||||
sa.Column('stripe_event_id', sa.String(), nullable=False),
|
||||
sa.Column('processed_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('stripe_event_id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('processed_webhook_events')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,55 @@
|
||||
"""trade_proposals_multi_requested_cards
|
||||
|
||||
Revision ID: 58fc464be769
|
||||
Revises: cfac344e21b4
|
||||
Create Date: 2026-03-28 22:09:44.129838
|
||||
|
||||
Replace single requested_card_id FK with requested_card_ids JSONB array so proposals
|
||||
can request zero or more cards, mirroring the real-time trade system's flexibility.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '58fc464be769'
|
||||
down_revision: Union[str, Sequence[str], None] = 'cfac344e21b4'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new column, migrate existing data, then drop the old column
|
||||
op.add_column('trade_proposals',
|
||||
sa.Column('requested_card_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True)
|
||||
)
|
||||
# Migrate any existing rows: wrap the single FK UUID into a JSON array
|
||||
op.execute("""
|
||||
UPDATE trade_proposals
|
||||
SET requested_card_ids = json_build_array(requested_card_id::text)::jsonb
|
||||
WHERE requested_card_id IS NOT NULL
|
||||
""")
|
||||
op.execute("""
|
||||
UPDATE trade_proposals
|
||||
SET requested_card_ids = '[]'::jsonb
|
||||
WHERE requested_card_ids IS NULL
|
||||
""")
|
||||
op.alter_column('trade_proposals', 'requested_card_ids', nullable=False)
|
||||
op.drop_constraint('trade_proposals_requested_card_id_fkey', 'trade_proposals', type_='foreignkey')
|
||||
op.drop_column('trade_proposals', 'requested_card_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column('trade_proposals',
|
||||
sa.Column('requested_card_id', sa.UUID(), nullable=True)
|
||||
)
|
||||
# Best-effort reverse: take first element of the array if present
|
||||
op.execute("""
|
||||
UPDATE trade_proposals
|
||||
SET requested_card_id = (requested_card_ids->0)::text::uuid
|
||||
WHERE jsonb_array_length(requested_card_ids) > 0
|
||||
""")
|
||||
op.drop_column('trade_proposals', 'requested_card_ids')
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add_fk_cascade_constraints
|
||||
|
||||
Revision ID: 8283acd4cbcc
|
||||
Revises: a2b3c4d5e6f7
|
||||
Create Date: 2026-03-29 13:55:46.488121
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '8283acd4cbcc'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a2b3c4d5e6f7'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.drop_constraint(op.f('cards_user_id_fkey'), 'cards', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('cards_user_id_fkey'), 'cards', 'users', ['user_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint(op.f('deck_cards_card_id_fkey'), 'deck_cards', type_='foreignkey')
|
||||
op.drop_constraint(op.f('deck_cards_deck_id_fkey'), 'deck_cards', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('deck_cards_deck_id_fkey'), 'deck_cards', 'decks', ['deck_id'], ['id'], ondelete='CASCADE')
|
||||
op.create_foreign_key(op.f('deck_cards_card_id_fkey'), 'deck_cards', 'cards', ['card_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint(op.f('decks_user_id_fkey'), 'decks', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('decks_user_id_fkey'), 'decks', 'users', ['user_id'], ['id'], ondelete='CASCADE')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_constraint(op.f('decks_user_id_fkey'), 'decks', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('decks_user_id_fkey'), 'decks', 'users', ['user_id'], ['id'])
|
||||
op.drop_constraint(op.f('deck_cards_deck_id_fkey'), 'deck_cards', type_='foreignkey')
|
||||
op.drop_constraint(op.f('deck_cards_card_id_fkey'), 'deck_cards', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('deck_cards_deck_id_fkey'), 'deck_cards', 'decks', ['deck_id'], ['id'])
|
||||
op.create_foreign_key(op.f('deck_cards_card_id_fkey'), 'deck_cards', 'cards', ['card_id'], ['id'])
|
||||
op.drop_constraint(op.f('cards_user_id_fkey'), 'cards', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('cards_user_id_fkey'), 'cards', 'users', ['user_id'], ['id'])
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add_received_at_rename_generated_at_on_cards
|
||||
|
||||
Revision ID: 98e23cab7057
|
||||
Revises: 0fc168f5970d
|
||||
Create Date: 2026-03-28 18:07:12.712311
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '98e23cab7057'
|
||||
down_revision: Union[str, Sequence[str], None] = '0fc168f5970d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.alter_column('cards', 'created_at', new_column_name='generated_at')
|
||||
op.add_column('cards', sa.Column('received_at', sa.DateTime(), nullable=True))
|
||||
op.execute("UPDATE cards SET received_at = generated_at")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_column('cards', 'received_at')
|
||||
op.alter_column('cards', 'generated_at', new_column_name='created_at')
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add last_active_at to users
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 58fc464be769
|
||||
Create Date: 2026-03-28
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a1b2c3d4e5f6'
|
||||
down_revision: Union[str, Sequence[str], None] = '58fc464be769'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('users', sa.Column('last_active_at', sa.DateTime(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('users', 'last_active_at')
|
||||
@@ -0,0 +1,48 @@
|
||||
"""add_unique_constraint_friendship
|
||||
|
||||
Revision ID: a2b3c4d5e6f7
|
||||
Revises: f4e8a1b2c3d9
|
||||
Create Date: 2026-03-29 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a2b3c4d5e6f7'
|
||||
down_revision: Union[str, Sequence[str], None] = 'f4e8a1b2c3d9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove duplicate (requester_id, addressee_id) pairs that already exist,
|
||||
# keeping the earliest row per pair before adding the unique constraint.
|
||||
conn = op.get_bind()
|
||||
conn.execute(text("""
|
||||
DELETE FROM friendships
|
||||
WHERE id IN (
|
||||
SELECT id FROM (
|
||||
SELECT id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY requester_id, addressee_id
|
||||
ORDER BY created_at
|
||||
) AS rn
|
||||
FROM friendships
|
||||
) sub
|
||||
WHERE rn > 1
|
||||
)
|
||||
"""))
|
||||
|
||||
op.create_unique_constraint(
|
||||
"uq_friendship_requester_addressee",
|
||||
"friendships",
|
||||
["requester_id", "addressee_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("uq_friendship_requester_addressee", "friendships", type_="unique")
|
||||
@@ -0,0 +1,41 @@
|
||||
"""add_friendships_table
|
||||
|
||||
Revision ID: b989aae3e37d
|
||||
Revises: de721927ff59
|
||||
Create Date: 2026-03-28 19:14:54.623287
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b989aae3e37d'
|
||||
down_revision: Union[str, Sequence[str], None] = 'de721927ff59'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('friendships',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('requester_id', sa.UUID(), nullable=False),
|
||||
sa.Column('addressee_id', sa.UUID(), nullable=False),
|
||||
sa.Column('status', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['addressee_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['requester_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('friendships')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add_check_constraints_on_status_fields
|
||||
|
||||
Revision ID: c1d2e3f4a5b6
|
||||
Revises: 8283acd4cbcc
|
||||
Create Date: 2026-03-29 14:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c1d2e3f4a5b6'
|
||||
down_revision: Union[str, Sequence[str], None] = '8283acd4cbcc'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_check_constraint("ck_friendships_status", "friendships", "status IN ('pending', 'accepted', 'declined')")
|
||||
op.create_check_constraint("ck_trade_proposals_status", "trade_proposals", "status IN ('pending', 'accepted', 'declined', 'expired', 'withdrawn')")
|
||||
op.create_check_constraint("ck_game_challenges_status", "game_challenges", "status IN ('pending', 'accepted', 'declined', 'expired', 'withdrawn')")
|
||||
op.create_check_constraint("ck_notifications_type", "notifications", "type IN ('friend_request', 'trade_offer', 'game_challenge')")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("ck_notifications_type", "notifications", type_="check")
|
||||
op.drop_constraint("ck_game_challenges_status", "game_challenges", type_="check")
|
||||
op.drop_constraint("ck_trade_proposals_status", "trade_proposals", type_="check")
|
||||
op.drop_constraint("ck_friendships_status", "friendships", type_="check")
|
||||
@@ -0,0 +1,49 @@
|
||||
"""add_trade_proposals_table
|
||||
|
||||
Revision ID: cfac344e21b4
|
||||
Revises: b989aae3e37d
|
||||
Create Date: 2026-03-28 22:01:28.188084
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'cfac344e21b4'
|
||||
down_revision: Union[str, Sequence[str], None] = 'b989aae3e37d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('trade_proposals',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('proposer_id', sa.UUID(), nullable=False),
|
||||
sa.Column('recipient_id', sa.UUID(), nullable=False),
|
||||
sa.Column('offered_card_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('requested_card_id', sa.UUID(), nullable=False),
|
||||
sa.Column('status', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['proposer_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['recipient_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['requested_card_id'], ['cards.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_trade_proposals_proposer_status', 'trade_proposals', ['proposer_id', 'status'])
|
||||
op.create_index('ix_trade_proposals_recipient_status', 'trade_proposals', ['recipient_id', 'status'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_trade_proposals_proposer_status', 'trade_proposals')
|
||||
op.drop_index('ix_trade_proposals_recipient_status', 'trade_proposals')
|
||||
op.drop_table('trade_proposals')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,69 @@
|
||||
"""add_fk_cascades_friendship_trade_challenge_notification
|
||||
|
||||
Revision ID: d1e2f3a4b5c6
|
||||
Revises: c1d2e3f4a5b6
|
||||
Create Date: 2026-03-29 15:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'd1e2f3a4b5c6'
|
||||
down_revision: Union[str, Sequence[str], None] = 'c1d2e3f4a5b6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# notifications
|
||||
op.drop_constraint(op.f('notifications_user_id_fkey'), 'notifications', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'notifications', 'users', ['user_id'], ['id'], ondelete='CASCADE')
|
||||
|
||||
# friendships
|
||||
op.drop_constraint(op.f('friendships_requester_id_fkey'), 'friendships', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'friendships', 'users', ['requester_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint(op.f('friendships_addressee_id_fkey'), 'friendships', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'friendships', 'users', ['addressee_id'], ['id'], ondelete='CASCADE')
|
||||
|
||||
# trade_proposals
|
||||
op.drop_constraint(op.f('trade_proposals_proposer_id_fkey'), 'trade_proposals', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'trade_proposals', 'users', ['proposer_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint(op.f('trade_proposals_recipient_id_fkey'), 'trade_proposals', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'trade_proposals', 'users', ['recipient_id'], ['id'], ondelete='CASCADE')
|
||||
|
||||
# game_challenges
|
||||
op.drop_constraint(op.f('game_challenges_challenger_id_fkey'), 'game_challenges', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'game_challenges', 'users', ['challenger_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint(op.f('game_challenges_challenged_id_fkey'), 'game_challenges', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'game_challenges', 'users', ['challenged_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint(op.f('game_challenges_challenger_deck_id_fkey'), 'game_challenges', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'game_challenges', 'decks', ['challenger_deck_id'], ['id'], ondelete='CASCADE')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# game_challenges
|
||||
op.drop_constraint(None, 'game_challenges', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('game_challenges_challenger_deck_id_fkey'), 'game_challenges', 'decks', ['challenger_deck_id'], ['id'])
|
||||
op.drop_constraint(None, 'game_challenges', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('game_challenges_challenged_id_fkey'), 'game_challenges', 'users', ['challenged_id'], ['id'])
|
||||
op.drop_constraint(None, 'game_challenges', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('game_challenges_challenger_id_fkey'), 'game_challenges', 'users', ['challenger_id'], ['id'])
|
||||
|
||||
# trade_proposals
|
||||
op.drop_constraint(None, 'trade_proposals', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('trade_proposals_recipient_id_fkey'), 'trade_proposals', 'users', ['recipient_id'], ['id'])
|
||||
op.drop_constraint(None, 'trade_proposals', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('trade_proposals_proposer_id_fkey'), 'trade_proposals', 'users', ['proposer_id'], ['id'])
|
||||
|
||||
# friendships
|
||||
op.drop_constraint(None, 'friendships', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('friendships_addressee_id_fkey'), 'friendships', 'users', ['addressee_id'], ['id'])
|
||||
op.drop_constraint(None, 'friendships', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('friendships_requester_id_fkey'), 'friendships', 'users', ['requester_id'], ['id'])
|
||||
|
||||
# notifications
|
||||
op.drop_constraint(None, 'notifications', type_='foreignkey')
|
||||
op.create_foreign_key(op.f('notifications_user_id_fkey'), 'notifications', 'users', ['user_id'], ['id'])
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add_notifications_table
|
||||
|
||||
Revision ID: de721927ff59
|
||||
Revises: 98e23cab7057
|
||||
Create Date: 2026-03-28 18:51:11.848830
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'de721927ff59'
|
||||
down_revision: Union[str, Sequence[str], None] = '98e23cab7057'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('notifications',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('type', sa.String(), nullable=False),
|
||||
sa.Column('payload', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('read', sa.Boolean(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('notifications')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,34 @@
|
||||
"""add is_favorite and willing_to_trade to cards
|
||||
|
||||
Revision ID: e70b992e5d95
|
||||
Revises: a9f2d4e7c301
|
||||
Create Date: 2026-03-27 17:41:30.462441
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e70b992e5d95'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a9f2d4e7c301'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('cards', sa.Column('is_favorite', sa.Boolean(), nullable=False, server_default=sa.false()))
|
||||
op.add_column('cards', sa.Column('willing_to_trade', sa.Boolean(), nullable=False, server_default=sa.false()))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('cards', 'willing_to_trade')
|
||||
op.drop_column('cards', 'is_favorite')
|
||||
# ### end Alembic commands ###
|
||||
40
backend/alembic/versions/f4e8a1b2c3d9_add_fk_indices.py
Normal file
40
backend/alembic/versions/f4e8a1b2c3d9_add_fk_indices.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""add fk indices
|
||||
|
||||
Revision ID: f4e8a1b2c3d9
|
||||
Revises: 29da7c818b01
|
||||
Create Date: 2026-03-29 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'f4e8a1b2c3d9'
|
||||
down_revision: Union[str, None] = '29da7c818b01'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add indices on FK columns that are missing them."""
|
||||
op.create_index('ix_cards_user_id', 'cards', ['user_id'])
|
||||
op.create_index('ix_decks_user_id', 'decks', ['user_id'])
|
||||
op.create_index('ix_notifications_user_id', 'notifications', ['user_id'])
|
||||
op.create_index('ix_friendships_requester_id', 'friendships', ['requester_id'])
|
||||
op.create_index('ix_friendships_addressee_id', 'friendships', ['addressee_id'])
|
||||
# Composite indices mirror the trade_proposals pattern: filter by owner + status together
|
||||
op.create_index('ix_game_challenges_challenger_status', 'game_challenges', ['challenger_id', 'status'])
|
||||
op.create_index('ix_game_challenges_challenged_status', 'game_challenges', ['challenged_id', 'status'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop FK indices."""
|
||||
op.drop_index('ix_game_challenges_challenged_status', table_name='game_challenges')
|
||||
op.drop_index('ix_game_challenges_challenger_status', table_name='game_challenges')
|
||||
op.drop_index('ix_friendships_addressee_id', table_name='friendships')
|
||||
op.drop_index('ix_friendships_requester_id', table_name='friendships')
|
||||
op.drop_index('ix_notifications_user_id', table_name='notifications')
|
||||
op.drop_index('ix_decks_user_id', table_name='decks')
|
||||
op.drop_index('ix_cards_user_id', table_name='cards')
|
||||
@@ -0,0 +1,28 @@
|
||||
"""add trade_response to notification type check constraint
|
||||
|
||||
Revision ID: f657d45be3ae
|
||||
Revises: 4603709eb82d
|
||||
Create Date: 2026-03-30 12:10:21.112505
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'f657d45be3ae'
|
||||
down_revision: Union[str, Sequence[str], None] = '4603709eb82d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint("ck_notifications_type", "notifications", type_="check")
|
||||
op.create_check_constraint("ck_notifications_type", "notifications", "type IN ('friend_request', 'trade_offer', 'trade_response', 'game_challenge')")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("ck_notifications_type", "notifications", type_="check")
|
||||
op.create_check_constraint("ck_notifications_type", "notifications", "type IN ('friend_request', 'trade_offer', 'game_challenge')")
|
||||
0
backend/core/__init__.py
Normal file
0
backend/core/__init__.py
Normal file
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from config import JWT_SECRET_KEY
|
||||
from core.config import JWT_SECRET_KEY
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
@@ -40,6 +41,8 @@ def decode_refresh_token(token: str) -> str | None:
|
||||
def decode_access_token(token: str) -> str | None:
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
if payload.get("type") != "access":
|
||||
return None
|
||||
return payload.get("sub")
|
||||
except JWTError:
|
||||
return None
|
||||
@@ -1,9 +1,14 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
|
||||
from config import DATABASE_URL
|
||||
from core.config import DATABASE_URL
|
||||
|
||||
engine = create_engine(DATABASE_URL)
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_timeout=30,
|
||||
)
|
||||
SessionLocal = sessionmaker(bind=engine)
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
43
backend/core/dependencies.py
Normal file
43
backend/core/dependencies.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.auth import decode_access_token
|
||||
from core.database import get_db
|
||||
from core.models import User as UserModel
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
||||
|
||||
# Shared rate limiter — registered on app.state in main.py
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> UserModel:
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
# Throttle to one write per 5 minutes so every authenticated request doesn't hammer the DB
|
||||
now = datetime.now()
|
||||
if not user.last_active_at or (now - user.last_active_at).total_seconds() > 300:
|
||||
user.last_active_at = now
|
||||
db.commit()
|
||||
return user
|
||||
|
||||
|
||||
# Per-user key for rate limiting authenticated endpoints — prevents shared IPs (NAT/VPN)
|
||||
# from having their limits pooled. Falls back to remote IP for unauthenticated requests.
|
||||
def get_user_id_from_request(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth.startswith("Bearer "):
|
||||
user_id = decode_access_token(auth[7:])
|
||||
if user_id:
|
||||
return f"user:{user_id}"
|
||||
return get_remote_address(request)
|
||||
189
backend/core/models.py
Normal file
189
backend/core/models.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, ForeignKey, DateTime, Text, Boolean, UniqueConstraint, CheckConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from core.database import Base
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
username: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
email: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
password_hash: Mapped[str] = mapped_column(String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
boosters: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
boosters_countdown: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
wins: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
losses: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
shards: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
last_refresh_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
reset_token: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
reset_token_expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
email_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
email_verification_token: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
email_verification_token_expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
trade_wishlist: Mapped[str | None] = mapped_column(Text, nullable=True, default="")
|
||||
last_active_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
cards: Mapped[list["Card"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||
decks: Mapped[list["Deck"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||
notifications: Mapped[list["Notification"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||
friendships_sent: Mapped[list["Friendship"]] = relationship(
|
||||
foreign_keys="Friendship.requester_id", back_populates="requester", cascade="all, delete-orphan"
|
||||
)
|
||||
friendships_received: Mapped[list["Friendship"]] = relationship(
|
||||
foreign_keys="Friendship.addressee_id", back_populates="addressee", cascade="all, delete-orphan"
|
||||
)
|
||||
proposals_sent: Mapped[list["TradeProposal"]] = relationship(
|
||||
foreign_keys="TradeProposal.proposer_id", back_populates="proposer", cascade="all, delete-orphan"
|
||||
)
|
||||
proposals_received: Mapped[list["TradeProposal"]] = relationship(
|
||||
foreign_keys="TradeProposal.recipient_id", back_populates="recipient", cascade="all, delete-orphan"
|
||||
)
|
||||
challenges_sent: Mapped[list["GameChallenge"]] = relationship(
|
||||
foreign_keys="GameChallenge.challenger_id", back_populates="challenger", cascade="all, delete-orphan"
|
||||
)
|
||||
challenges_received: Mapped[list["GameChallenge"]] = relationship(
|
||||
foreign_keys="GameChallenge.challenged_id", back_populates="challenged", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Card(Base):
|
||||
__tablename__ = "cards"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=True)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
image_link: Mapped[str] = mapped_column(String, nullable=True)
|
||||
card_rarity: Mapped[str] = mapped_column(String, nullable=False)
|
||||
card_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
text: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
attack: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
defense: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
cost: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
generated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
received_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
times_played: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
reported: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
ai_used: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
is_favorite: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
willing_to_trade: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
user: Mapped["User | None"] = relationship(back_populates="cards")
|
||||
deck_cards: Mapped[list["DeckCard"]] = relationship(back_populates="card", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Deck(Base):
|
||||
__tablename__ = "decks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
times_played: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
wins: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
losses: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="decks")
|
||||
deck_cards: Mapped[list["DeckCard"]] = relationship(back_populates="deck", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notifications"
|
||||
__table_args__ = (
|
||||
CheckConstraint("type IN ('friend_request', 'trade_offer', 'trade_response', 'game_challenge')", name="ck_notifications_type"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
# type is one of: friend_request, trade_offer, trade_response, game_challenge
|
||||
type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
payload: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict)
|
||||
read: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="notifications")
|
||||
|
||||
|
||||
class Friendship(Base):
|
||||
__tablename__ = "friendships"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("requester_id", "addressee_id", name="uq_friendship_requester_addressee"),
|
||||
CheckConstraint("status IN ('pending', 'accepted', 'declined')", name="ck_friendships_status"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
requester_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
addressee_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
# status: pending / accepted / declined
|
||||
status: Mapped[str] = mapped_column(String, nullable=False, default="pending")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
|
||||
requester: Mapped["User"] = relationship(foreign_keys=[requester_id], back_populates="friendships_sent")
|
||||
addressee: Mapped["User"] = relationship(foreign_keys=[addressee_id], back_populates="friendships_received")
|
||||
|
||||
|
||||
class TradeProposal(Base):
|
||||
__tablename__ = "trade_proposals"
|
||||
__table_args__ = (
|
||||
CheckConstraint("status IN ('pending', 'accepted', 'declined', 'expired')", name="ck_trade_proposals_status"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
proposer_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
recipient_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
# Both sides stored as JSONB lists of UUID strings so either party can offer 0 or more cards,
|
||||
# mirroring the flexibility of the real-time trade system
|
||||
offered_card_ids: Mapped[list] = mapped_column(JSONB, nullable=False, default=list)
|
||||
requested_card_ids: Mapped[list] = mapped_column(JSONB, nullable=False, default=list)
|
||||
# status: pending / accepted / declined / expired
|
||||
status: Mapped[str] = mapped_column(String, nullable=False, default="pending")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
proposer: Mapped["User"] = relationship(foreign_keys=[proposer_id])
|
||||
recipient: Mapped["User"] = relationship(foreign_keys=[recipient_id])
|
||||
|
||||
|
||||
class GameChallenge(Base):
|
||||
__tablename__ = "game_challenges"
|
||||
__table_args__ = (
|
||||
CheckConstraint("status IN ('pending', 'accepted', 'declined', 'expired')", name="ck_game_challenges_status"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
challenger_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
challenged_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
challenger_deck_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("decks.id", ondelete="CASCADE"), nullable=False)
|
||||
# status: pending / accepted / declined / expired
|
||||
status: Mapped[str] = mapped_column(String, nullable=False, default="pending")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
challenger: Mapped["User"] = relationship(foreign_keys=[challenger_id], back_populates="challenges_sent")
|
||||
challenged: Mapped["User"] = relationship(foreign_keys=[challenged_id], back_populates="challenges_received")
|
||||
challenger_deck: Mapped["Deck"] = relationship(foreign_keys=[challenger_deck_id])
|
||||
|
||||
|
||||
class DeckCard(Base):
|
||||
__tablename__ = "deck_cards"
|
||||
|
||||
deck_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("decks.id", ondelete="CASCADE"), primary_key=True)
|
||||
card_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("cards.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
deck: Mapped["Deck"] = relationship(back_populates="deck_cards")
|
||||
card: Mapped["Card"] = relationship(back_populates="deck_cards")
|
||||
|
||||
|
||||
class ProcessedWebhookEvent(Base):
|
||||
__tablename__ = "processed_webhook_events"
|
||||
|
||||
# stripe_event_id is the primary key — acts as unique constraint to prevent duplicate processing
|
||||
stripe_event_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
processed_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, nullable=False)
|
||||
@@ -1,89 +0,0 @@
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from card import _get_cards_async
|
||||
from models import Card as CardModel
|
||||
from models import User as UserModel
|
||||
from database import SessionLocal
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
POOL_MINIMUM = 1000
|
||||
POOL_TARGET = 2000
|
||||
POOL_BATCH_SIZE = 10
|
||||
POOL_SLEEP = 4.0
|
||||
|
||||
pool_filling = False
|
||||
|
||||
async def fill_card_pool():
|
||||
global pool_filling
|
||||
if pool_filling:
|
||||
logger.info("Pool fill already in progress, skipping")
|
||||
return
|
||||
|
||||
db: Session = SessionLocal()
|
||||
while True:
|
||||
try:
|
||||
unassigned = db.query(CardModel).filter(CardModel.user_id == None, CardModel.ai_used == False).count()
|
||||
logger.info(f"Card pool has {unassigned} unassigned cards")
|
||||
if unassigned >= POOL_MINIMUM:
|
||||
logger.info("Pool sufficiently stocked, skipping fill")
|
||||
return
|
||||
|
||||
pool_filling = True
|
||||
needed = POOL_TARGET - unassigned
|
||||
logger.info(f"Filling pool with {needed} cards")
|
||||
|
||||
fetched = 0
|
||||
while fetched < needed:
|
||||
batch_size = min(POOL_BATCH_SIZE, needed - fetched)
|
||||
cards = await _get_cards_async(batch_size)
|
||||
|
||||
for card in cards:
|
||||
db.add(CardModel(
|
||||
name=card.name,
|
||||
image_link=card.image_link,
|
||||
card_rarity=card.card_rarity.name,
|
||||
card_type=card.card_type.name,
|
||||
text=card.text,
|
||||
attack=card.attack,
|
||||
defense=card.defense,
|
||||
cost=card.cost,
|
||||
user_id=None,
|
||||
))
|
||||
db.commit()
|
||||
fetched += batch_size
|
||||
logger.info(f"Pool fill progress: {fetched}/{needed}")
|
||||
await asyncio.sleep(POOL_SLEEP)
|
||||
|
||||
finally:
|
||||
pool_filling = False
|
||||
db.close()
|
||||
|
||||
BOOSTER_MAX = 5
|
||||
BOOSTER_COOLDOWN_HOURS = 5
|
||||
|
||||
def check_boosters(user: UserModel, db: Session) -> tuple[int, datetime|None]:
|
||||
if user.boosters_countdown is None:
|
||||
if user.boosters < BOOSTER_MAX:
|
||||
user.boosters = BOOSTER_MAX
|
||||
db.commit()
|
||||
return (user.boosters, user.boosters_countdown)
|
||||
|
||||
now = datetime.now()
|
||||
countdown = user.boosters_countdown
|
||||
|
||||
while user.boosters < BOOSTER_MAX:
|
||||
next_tick = countdown + timedelta(hours=BOOSTER_COOLDOWN_HOURS)
|
||||
if now >= next_tick:
|
||||
user.boosters += 1
|
||||
countdown = next_tick
|
||||
else:
|
||||
break
|
||||
|
||||
user.boosters_countdown = countdown if user.boosters < BOOSTER_MAX else None
|
||||
db.commit()
|
||||
return (user.boosters, user.boosters_countdown)
|
||||
0
backend/game/__init__.py
Normal file
0
backend/game/__init__.py
Normal file
@@ -6,7 +6,7 @@ from urllib.parse import quote
|
||||
from datetime import datetime, timedelta
|
||||
from time import sleep
|
||||
|
||||
from config import WIKIRANK_USER_AGENT
|
||||
from core.config import WIKIRANK_USER_AGENT
|
||||
HEADERS = {"User-Agent": WIKIRANK_USER_AGENT}
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
@@ -33,7 +33,7 @@ class CardRarity(Enum):
|
||||
|
||||
class Card(NamedTuple):
|
||||
name: str
|
||||
created_at: datetime
|
||||
generated_at: datetime
|
||||
image_link: str
|
||||
card_rarity: CardRarity
|
||||
card_type: CardType
|
||||
@@ -81,7 +81,7 @@ class Card(NamedTuple):
|
||||
return_string += "┃"+f"{l:{' '}<50}"+"┃\n"
|
||||
return_string += "┠"+"─"*50+"┨\n"
|
||||
|
||||
date_text = str(self.created_at.date())
|
||||
date_text = str(self.generated_at.date())
|
||||
stats = f"{self.attack}/{self.defense}"
|
||||
spaces = 50 - (len(date_text) + len(stats))
|
||||
return_string += "┃"+date_text + " "*spaces + stats + "┃\n"
|
||||
@@ -123,6 +123,7 @@ WIKIDATA_INSTANCE_TYPE_MAP = {
|
||||
"Q1446621": CardType.artwork, # recital
|
||||
"Q1868552": CardType.artwork, # local newspaper
|
||||
"Q3244175": CardType.artwork, # tabletop game
|
||||
"Q2031291": CardType.artwork, # musical release
|
||||
"Q63952888": CardType.artwork, # anime television series
|
||||
"Q47461344": CardType.artwork, # written work
|
||||
"Q71631512": CardType.artwork, # tabletop role-playing game supplement
|
||||
@@ -167,6 +168,7 @@ WIKIDATA_INSTANCE_TYPE_MAP = {
|
||||
|
||||
"Q198": CardType.event, # war
|
||||
"Q8465": CardType.event, # civil war
|
||||
"Q844482": CardType.event, # killing
|
||||
"Q141022": CardType.event, # eclipse
|
||||
"Q103495": CardType.event, # world war
|
||||
"Q350604": CardType.event, # armed conflict
|
||||
@@ -180,7 +182,7 @@ WIKIDATA_INSTANCE_TYPE_MAP = {
|
||||
"Q1361229": CardType.event, # conquest
|
||||
"Q2223653": CardType.event, # terrorist attack
|
||||
"Q2672648": CardType.event, # social conflict
|
||||
"Q2627975": CardType.event, # ceremony
|
||||
"Q2627975": CardType.event, # ceremony"
|
||||
"Q16510064": CardType.event, # sporting event
|
||||
"Q10688145": CardType.event, # season
|
||||
"Q13418847": CardType.event, # historical event
|
||||
@@ -275,6 +277,7 @@ WIKIDATA_INSTANCE_TYPE_MAP = {
|
||||
"Q1428357": CardType.vehicle, # submarine class
|
||||
"Q1499623": CardType.vehicle, # destroyer escort
|
||||
"Q4818021": CardType.vehicle, # attack submarine
|
||||
"Q45296117": CardType.vehicle, # aircraft type
|
||||
"Q15141321": CardType.vehicle, # train service
|
||||
"Q19832486": CardType.vehicle, # locomotive class
|
||||
"Q23866334": CardType.vehicle, # motorcycle model
|
||||
@@ -544,7 +547,7 @@ async def _get_card_async(client: httpx.AsyncClient, page_title: str|None = None
|
||||
|
||||
return Card(
|
||||
name=summary["title"],
|
||||
created_at=datetime.now(),
|
||||
generated_at=datetime.now(),
|
||||
image_link=summary.get("thumbnail", {}).get("source", ""),
|
||||
card_rarity=rarity,
|
||||
card_type=card_type,
|
||||
@@ -1,20 +1,21 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import random
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from game import (
|
||||
from game.rules import (
|
||||
GameState, CardInstance, PlayerState, action_play_card, action_sacrifice,
|
||||
action_end_turn, create_game, CombatEvent, GameResult, BOARD_SIZE
|
||||
)
|
||||
from models import Card as CardModel, Deck as DeckModel, DeckCard as DeckCardModel, User as UserModel
|
||||
from card import compute_deck_type
|
||||
from ai import AI_USER_ID, run_ai_turn, get_random_personality, choose_cards
|
||||
from core.models import Card as CardModel, Deck as DeckModel, DeckCard as DeckCardModel, User as UserModel
|
||||
from game.card import compute_deck_type
|
||||
from ai.engine import AI_USER_ID, run_ai_turn, get_random_personality, choose_cards
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
@@ -90,7 +91,9 @@ def serialize_card(card: CardInstance|None) -> dict | None:
|
||||
"card_type": card.card_type,
|
||||
"card_rarity": card.card_rarity,
|
||||
"image_link": card.image_link,
|
||||
"text": card.text
|
||||
"text": card.text,
|
||||
"is_favorite": card.is_favorite,
|
||||
"willing_to_trade": card.willing_to_trade,
|
||||
}
|
||||
|
||||
def serialize_player(player: PlayerState, hide_hand=False) -> dict:
|
||||
@@ -150,8 +153,8 @@ async def broadcast_state(game_id: str):
|
||||
"type": "state",
|
||||
"state": serialize_state(state, user_id),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
|
||||
if state.active_player_id == AI_USER_ID and not state.result:
|
||||
asyncio.create_task(run_ai_turn(game_id))
|
||||
@@ -221,6 +224,33 @@ async def try_match(db: Session):
|
||||
await broadcast_state(state.game_id)
|
||||
|
||||
|
||||
## Direct challenge game creation (no WebSocket needed at creation time)
|
||||
|
||||
def create_challenge_game(
|
||||
challenger_id: str, challenger_deck_id: str,
|
||||
challenged_id: str, challenged_deck_id: str,
|
||||
db: Session
|
||||
) -> str:
|
||||
challenger = db.query(UserModel).filter(UserModel.id == uuid.UUID(challenger_id)).first()
|
||||
challenged = db.query(UserModel).filter(UserModel.id == uuid.UUID(challenged_id)).first()
|
||||
p1_cards = load_deck_cards(challenger_deck_id, challenger_id, db)
|
||||
p2_cards = load_deck_cards(challenged_deck_id, challenged_id, db)
|
||||
if not p1_cards or not p2_cards or not challenger or not challenged:
|
||||
raise ValueError("Could not load decks or players")
|
||||
p1_deck_type = compute_deck_type(p1_cards)
|
||||
p2_deck_type = compute_deck_type(p2_cards)
|
||||
state = create_game(
|
||||
challenger_id, challenger.username, p1_deck_type or "", p1_cards,
|
||||
challenged_id, challenged.username, p2_deck_type or "", p2_cards,
|
||||
)
|
||||
active_games[state.game_id] = state
|
||||
# Initialize with no websockets; players connect via /ws/game/{game_id} after redirect
|
||||
connections[state.game_id] = {challenger_id: None, challenged_id: None}
|
||||
active_deck_ids[challenger_id] = challenger_deck_id
|
||||
active_deck_ids[challenged_id] = challenged_deck_id
|
||||
return state.game_id
|
||||
|
||||
|
||||
## Action handler
|
||||
|
||||
async def handle_action(game_id: str, user_id: str, message: dict, db: Session):
|
||||
@@ -255,7 +285,7 @@ async def handle_action(game_id: str, user_id: str, message: dict, db: Session):
|
||||
if card:
|
||||
card.times_played += 1
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
except (SQLAlchemyError, ValueError) as e:
|
||||
logger.warning(f"Failed to increment times_played for card {card_instance.card_id}: {e}")
|
||||
db.rollback()
|
||||
elif action == "sacrifice":
|
||||
@@ -275,8 +305,8 @@ async def handle_action(game_id: str, user_id: str, message: dict, db: Session):
|
||||
"type": "sacrifice_animation",
|
||||
"instance_id": card.instance_id,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
await asyncio.sleep(0.65)
|
||||
err = action_sacrifice(state, slot)
|
||||
elif action == "end_turn":
|
||||
@@ -325,7 +355,7 @@ async def handle_disconnect(game_id: str, user_id: str):
|
||||
)
|
||||
state.phase = "end"
|
||||
|
||||
from database import SessionLocal
|
||||
from core.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
record_game_result(state, db)
|
||||
@@ -340,8 +370,8 @@ async def handle_disconnect(game_id: str, user_id: str):
|
||||
"type": "state",
|
||||
"state": serialize_state(state, winner_id),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
|
||||
active_deck_ids.pop(user_id, None)
|
||||
active_deck_ids.pop(winner_id, None)
|
||||
@@ -1,10 +1,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
import random
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from models import Card as CardModel
|
||||
from core.models import Card as CardModel
|
||||
|
||||
STARTING_LIFE = 1000
|
||||
MAX_ENERGY_CAP = 6
|
||||
@@ -24,6 +24,8 @@ class CardInstance:
|
||||
card_rarity: str
|
||||
image_link: str
|
||||
text: str
|
||||
is_favorite: bool = False
|
||||
willing_to_trade: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_db_card(cls, card: CardModel) -> "CardInstance":
|
||||
@@ -38,7 +40,9 @@ class CardInstance:
|
||||
card_type=card.card_type,
|
||||
card_rarity=card.card_rarity,
|
||||
image_link=card.image_link or "",
|
||||
text=card.text
|
||||
text=card.text,
|
||||
is_favorite=card.is_favorite,
|
||||
willing_to_trade=card.willing_to_trade,
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -8,15 +8,17 @@ Example:
|
||||
python give_card.py nikolaj "Marie Curie"
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from database import SessionLocal
|
||||
from models import User as UserModel, Card as CardModel
|
||||
from card import _get_specific_card_async
|
||||
import uuid
|
||||
from game.card import _get_specific_card_async
|
||||
from core.database import SessionLocal
|
||||
from core.models import User as UserModel, Card as CardModel
|
||||
|
||||
|
||||
async def main(username: str, page_title: str) -> None:
|
||||
@@ -44,6 +46,7 @@ async def main(username: str, page_title: str) -> None:
|
||||
attack=card.attack,
|
||||
defense=card.defense,
|
||||
cost=card.cost,
|
||||
received_at=datetime.now(),
|
||||
)
|
||||
db.add(db_card)
|
||||
db.commit()
|
||||
|
||||
844
backend/main.py
844
backend/main.py
@@ -1,846 +1,54 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from typing import cast, Callable
|
||||
import secrets
|
||||
from typing import Callable, cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from fastapi import FastAPI, Depends, HTTPException, status, WebSocket, WebSocketDisconnect, Request
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
from database import get_db
|
||||
from database_functions import fill_card_pool, check_boosters, BOOSTER_MAX
|
||||
from models import Card as CardModel
|
||||
from models import User as UserModel
|
||||
from models import Deck as DeckModel
|
||||
from models import DeckCard as DeckCardModel
|
||||
from auth import (
|
||||
hash_password, verify_password, create_access_token, create_refresh_token,
|
||||
decode_access_token, decode_refresh_token
|
||||
)
|
||||
from game_manager import (
|
||||
queue, queue_lock, QueueEntry, try_match, handle_action, connections, active_games,
|
||||
serialize_state, handle_disconnect, handle_timeout_claim, load_deck_cards, create_solo_game
|
||||
)
|
||||
from trade_manager import (
|
||||
trade_queue, trade_queue_lock, TradeQueueEntry, try_trade_match,
|
||||
handle_trade_action, active_trades, handle_trade_disconnect,
|
||||
serialize_trade,
|
||||
)
|
||||
from card import compute_deck_type, _get_specific_card_async
|
||||
from email_utils import send_password_reset_email, send_verification_email
|
||||
from config import CORS_ORIGINS, STRIPE_SECRET_KEY, STRIPE_PUBLISHABLE_KEY, STRIPE_WEBHOOK_SECRET, FRONTEND_URL
|
||||
import stripe
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
|
||||
from core.config import CORS_ORIGINS, STRIPE_SECRET_KEY
|
||||
from core.dependencies import limiter
|
||||
from services.database_functions import fill_card_pool, run_cleanup_loop
|
||||
|
||||
from routers import auth, cards, decks, games, health, notifications, profile, friends, store, trades
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
# Auth
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
email: str
|
||||
password: str
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> UserModel:
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
return user
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
class ResetPasswordWithTokenRequest(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
asyncio.create_task(fill_card_pool())
|
||||
asyncio.create_task(run_cleanup_loop())
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Rate limiting
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, cast(Callable, _rate_limit_exceeded_handler))
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=CORS_ORIGINS, # SvelteKit's default dev port
|
||||
allow_origins=CORS_ORIGINS,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
try:
|
||||
from disposable_email_domains import blocklist as _disposable_blocklist
|
||||
except ImportError:
|
||||
_disposable_blocklist: set[str] = set()
|
||||
|
||||
def validate_register(username: str, email: str, password: str) -> str | None:
|
||||
if not username.strip():
|
||||
return "Username is required"
|
||||
if len(username) > 16:
|
||||
return "Username must be 16 characters or fewer"
|
||||
if not re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", email):
|
||||
return "Please enter a valid email"
|
||||
domain = email.split("@")[-1].lower()
|
||||
if domain in _disposable_blocklist:
|
||||
return "Disposable email addresses are not allowed"
|
||||
if len(password) < 8:
|
||||
return "Password must be at least 8 characters"
|
||||
if len(password) > 256:
|
||||
return "Password must be 256 characters or fewer"
|
||||
return None
|
||||
|
||||
@app.post("/register")
|
||||
def register(req: RegisterRequest, db: Session = Depends(get_db)):
|
||||
err = validate_register(req.username, req.email, req.password)
|
||||
if err:
|
||||
raise HTTPException(status_code=400, detail=err)
|
||||
if db.query(UserModel).filter(UserModel.username == req.username).first():
|
||||
raise HTTPException(status_code=400, detail="Username already taken")
|
||||
if db.query(UserModel).filter(UserModel.email == req.email).first():
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
user = UserModel(
|
||||
id=uuid.uuid4(),
|
||||
username=req.username,
|
||||
email=req.email,
|
||||
password_hash=hash_password(req.password),
|
||||
email_verified=False,
|
||||
email_verification_token=verification_token,
|
||||
email_verification_token_expires_at=datetime.now() + timedelta(hours=24),
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
try:
|
||||
send_verification_email(req.email, req.username, verification_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email: {e}")
|
||||
return {"message": "Account created. Please check your email to verify your account."}
|
||||
|
||||
@app.post("/login")
|
||||
def login(form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.username == form.username).first()
|
||||
if not user or not verify_password(form.password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="Invalid username or password")
|
||||
return {
|
||||
"access_token": create_access_token(str(user.id)),
|
||||
"refresh_token": create_refresh_token(str(user.id)),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@app.get("/boosters")
|
||||
def get_boosters(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
count, countdown = check_boosters(user, db)
|
||||
return {"count": count, "countdown": countdown, "email_verified": user.email_verified}
|
||||
|
||||
@app.get("/cards")
|
||||
def get_cards(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
cards = db.query(CardModel).filter(CardModel.user_id == user.id).all()
|
||||
return [
|
||||
{**{c.name: getattr(card, c.name) for c in card.__table__.columns},
|
||||
"card_rarity": card.card_rarity,
|
||||
"card_type": card.card_type}
|
||||
for card in cards
|
||||
]
|
||||
|
||||
@app.get("/cards/in-decks")
|
||||
def get_cards_in_decks(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck_ids = [d.id for d in db.query(DeckModel).filter(DeckModel.user_id == user.id, DeckModel.deleted == False).all()]
|
||||
if not deck_ids:
|
||||
return []
|
||||
card_ids = db.query(DeckCardModel.card_id).filter(DeckCardModel.deck_id.in_(deck_ids)).distinct().all()
|
||||
return [str(row.card_id) for row in card_ids]
|
||||
|
||||
@app.post("/open_pack")
|
||||
@limiter.limit("10/minute")
|
||||
async def open_pack(request: Request, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if not user.email_verified:
|
||||
raise HTTPException(status_code=403, detail="You must verify your email before opening packs")
|
||||
|
||||
check_boosters(user, db)
|
||||
|
||||
if user.boosters == 0:
|
||||
raise HTTPException(status_code=400, detail="No booster packs available")
|
||||
|
||||
cards = (
|
||||
db.query(CardModel)
|
||||
.filter(CardModel.user_id == None, CardModel.ai_used == False)
|
||||
.limit(5)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(cards) < 5:
|
||||
asyncio.create_task(fill_card_pool())
|
||||
raise HTTPException(status_code=503, detail="Card pool is low, please try again shortly")
|
||||
|
||||
for card in cards:
|
||||
card.user_id = user.id
|
||||
|
||||
was_full = user.boosters == BOOSTER_MAX
|
||||
user.boosters -= 1
|
||||
if was_full:
|
||||
user.boosters_countdown = datetime.now()
|
||||
|
||||
db.commit()
|
||||
|
||||
asyncio.create_task(fill_card_pool())
|
||||
|
||||
return [
|
||||
{**{c.name: getattr(card, c.name) for c in card.__table__.columns},
|
||||
"card_rarity": card.card_rarity,
|
||||
"card_type": card.card_type}
|
||||
for card in cards
|
||||
]
|
||||
|
||||
@app.get("/decks")
|
||||
def get_decks(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
decks = db.query(DeckModel).filter(
|
||||
DeckModel.user_id == user.id,
|
||||
DeckModel.deleted == False
|
||||
).order_by(DeckModel.created_at).all()
|
||||
result = []
|
||||
for deck in decks:
|
||||
card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
|
||||
cards = db.query(CardModel).filter(CardModel.id.in_(card_ids)).all()
|
||||
result.append({
|
||||
"id": str(deck.id),
|
||||
"name": deck.name,
|
||||
"card_count": len(cards),
|
||||
"total_cost": sum(card.cost for card in cards),
|
||||
"times_played": deck.times_played,
|
||||
"wins": deck.wins,
|
||||
"losses": deck.losses,
|
||||
"deck_type": compute_deck_type(cards),
|
||||
})
|
||||
return result
|
||||
|
||||
@app.post("/decks")
|
||||
def create_deck(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
count = db.query(DeckModel).filter(DeckModel.user_id == user.id).count()
|
||||
deck = DeckModel(id=uuid.uuid4(), user_id=user.id, name=f"Deck #{count + 1}")
|
||||
db.add(deck)
|
||||
db.commit()
|
||||
return {"id": str(deck.id), "name": deck.name, "card_count": 0}
|
||||
|
||||
@app.patch("/decks/{deck_id}")
|
||||
def update_deck(deck_id: str, body: dict, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
if "name" in body:
|
||||
deck.name = body["name"]
|
||||
if "card_ids" in body:
|
||||
db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).delete()
|
||||
for card_id in body["card_ids"]:
|
||||
db.add(DeckCardModel(deck_id=deck.id, card_id=uuid.UUID(card_id)))
|
||||
if deck.times_played > 0:
|
||||
deck.wins = 0
|
||||
deck.losses = 0
|
||||
deck.times_played = 0
|
||||
db.commit()
|
||||
return {"id": str(deck.id), "name": deck.name}
|
||||
|
||||
@app.delete("/decks/{deck_id}")
|
||||
def delete_deck(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
if deck.times_played > 0:
|
||||
deck.deleted = True
|
||||
else:
|
||||
db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).delete()
|
||||
db.delete(deck)
|
||||
db.commit()
|
||||
return {"message": "Deleted"}
|
||||
|
||||
@app.get("/decks/{deck_id}/cards")
|
||||
def get_deck_cards(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
deck_cards = db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()
|
||||
return [str(dc.card_id) for dc in deck_cards]
|
||||
|
||||
@app.websocket("/ws/queue")
|
||||
async def queue_endpoint(websocket: WebSocket, deck_id: str, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
deck = db.query(DeckModel).filter(
|
||||
DeckModel.id == uuid.UUID(deck_id),
|
||||
DeckModel.user_id == uuid.UUID(user_id)
|
||||
).first()
|
||||
|
||||
if not deck:
|
||||
await websocket.send_json({"type": "error", "message": "Deck not found"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
|
||||
total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
|
||||
if total_cost == 0 or total_cost > 50:
|
||||
await websocket.send_json({"type": "error", "message": "Deck total cost must be between 1 and 50"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
entry = QueueEntry(user_id=user_id, deck_id=deck_id, websocket=websocket)
|
||||
|
||||
async with queue_lock:
|
||||
queue.append(entry)
|
||||
|
||||
await websocket.send_json({"type": "queued"})
|
||||
await try_match(db)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Keeping socket alive
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
async with queue_lock:
|
||||
queue[:] = [e for e in queue if e.user_id != user_id]
|
||||
|
||||
|
||||
@app.websocket("/ws/game/{game_id}")
|
||||
async def game_endpoint(websocket: WebSocket, game_id: str, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
if game_id not in active_games:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Register this connection (handles reconnects)
|
||||
connections[game_id][user_id] = websocket
|
||||
|
||||
# Send current state immediately on connect
|
||||
await websocket.send_json({
|
||||
"type": "state",
|
||||
"state": serialize_state(active_games[game_id], user_id),
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
await handle_action(game_id, user_id, data, db)
|
||||
except WebSocketDisconnect:
|
||||
if game_id in connections:
|
||||
connections[game_id].pop(user_id, None)
|
||||
asyncio.create_task(handle_disconnect(game_id, user_id))
|
||||
|
||||
@app.websocket("/ws/trade/queue")
|
||||
async def trade_queue_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
if not user.email_verified:
|
||||
await websocket.send_json({"type": "error", "message": "You must verify your email before trading."})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
entry = TradeQueueEntry(user_id=user_id, username=user.username, websocket=websocket)
|
||||
|
||||
async with trade_queue_lock:
|
||||
trade_queue.append(entry)
|
||||
|
||||
await websocket.send_json({"type": "queued"})
|
||||
await try_trade_match()
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
async with trade_queue_lock:
|
||||
trade_queue[:] = [e for e in trade_queue if e.user_id != user_id]
|
||||
|
||||
|
||||
@app.websocket("/ws/trade/{trade_id}")
|
||||
async def trade_endpoint(websocket: WebSocket, trade_id: str, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
session = active_trades.get(trade_id)
|
||||
if not session or user_id not in session.offers:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
session.connections[user_id] = websocket
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "state",
|
||||
"state": serialize_trade(session, user_id),
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
await handle_trade_action(trade_id, user_id, data, db)
|
||||
except WebSocketDisconnect:
|
||||
session.connections.pop(user_id, None)
|
||||
asyncio.create_task(handle_trade_disconnect(trade_id, user_id))
|
||||
|
||||
|
||||
@app.get("/profile")
|
||||
def get_profile(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
total_games = user.wins + user.losses
|
||||
|
||||
most_played_deck = (
|
||||
db.query(DeckModel)
|
||||
.filter(DeckModel.user_id == user.id, DeckModel.times_played > 0)
|
||||
.order_by(DeckModel.times_played.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
most_played_card = (
|
||||
db.query(CardModel)
|
||||
.filter(CardModel.user_id == user.id, CardModel.times_played > 0)
|
||||
.order_by(CardModel.times_played.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
return {
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"email_verified": user.email_verified,
|
||||
"created_at": user.created_at,
|
||||
"wins": user.wins,
|
||||
"losses": user.losses,
|
||||
"shards": user.shards,
|
||||
"win_rate": round((user.wins / total_games) * 100) if total_games > 0 else None,
|
||||
"most_played_deck": {
|
||||
"name": most_played_deck.name,
|
||||
"times_played": most_played_deck.times_played,
|
||||
} if most_played_deck else None,
|
||||
"most_played_card": {
|
||||
"name": most_played_card.name,
|
||||
"times_played": most_played_card.times_played,
|
||||
"card_type": most_played_card.card_type,
|
||||
"card_rarity": most_played_card.card_rarity,
|
||||
"image_link": most_played_card.image_link,
|
||||
} if most_played_card else None,
|
||||
}
|
||||
|
||||
class ShatterRequest(BaseModel):
|
||||
card_ids: list[str]
|
||||
|
||||
@app.post("/shards/shatter")
|
||||
def shatter_cards(req: ShatterRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if not req.card_ids:
|
||||
raise HTTPException(status_code=400, detail="No cards selected")
|
||||
try:
|
||||
parsed_ids = [uuid.UUID(cid) for cid in req.card_ids]
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid card IDs")
|
||||
|
||||
cards = db.query(CardModel).filter(
|
||||
CardModel.id.in_(parsed_ids),
|
||||
CardModel.user_id == user.id,
|
||||
).all()
|
||||
|
||||
if len(cards) != len(parsed_ids):
|
||||
raise HTTPException(status_code=400, detail="Some cards are not in your collection")
|
||||
|
||||
total = sum(c.cost for c in cards)
|
||||
|
||||
for card in cards:
|
||||
db.query(DeckCardModel).filter(DeckCardModel.card_id == card.id).delete()
|
||||
db.delete(card)
|
||||
|
||||
user.shards += total
|
||||
db.commit()
|
||||
return {"shards": user.shards, "gained": total}
|
||||
|
||||
# Shard packages sold for real money.
|
||||
# price_oere is in Danish øre (1 DKK = 100 øre). Stripe minimum is 250 øre.
|
||||
SHARD_PACKAGES = {
|
||||
"s1": {"base": 100, "bonus": 0, "shards": 100, "price_oere": 1000, "price_label": "10 DKK"},
|
||||
"s2": {"base": 250, "bonus": 50, "shards": 300, "price_oere": 2500, "price_label": "25 DKK"},
|
||||
"s3": {"base": 500, "bonus": 200, "shards": 700, "price_oere": 5000, "price_label": "50 DKK"},
|
||||
"s4": {"base": 1000, "bonus": 600, "shards": 1600, "price_oere": 10000, "price_label": "100 DKK"},
|
||||
"s5": {"base": 2500, "bonus": 2000, "shards": 4500, "price_oere": 25000, "price_label": "250 DKK"},
|
||||
"s6": {"base": 5000, "bonus": 5000, "shards": 10000, "price_oere": 50000, "price_label": "500 DKK"},
|
||||
}
|
||||
|
||||
class StripeCheckoutRequest(BaseModel):
|
||||
package_id: str
|
||||
|
||||
@app.post("/store/stripe/checkout")
|
||||
def create_stripe_checkout(req: StripeCheckoutRequest, user: UserModel = Depends(get_current_user)):
|
||||
package = SHARD_PACKAGES.get(req.package_id)
|
||||
if not package:
|
||||
raise HTTPException(status_code=400, detail="Invalid package")
|
||||
session = stripe.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
line_items=[{
|
||||
"price_data": {
|
||||
"currency": "dkk",
|
||||
"product_data": {"name": f"WikiTCG Shards — {package['price_label']}"},
|
||||
"unit_amount": package["price_oere"],
|
||||
},
|
||||
"quantity": 1,
|
||||
}],
|
||||
mode="payment",
|
||||
success_url=f"{FRONTEND_URL}/store?payment=success",
|
||||
cancel_url=f"{FRONTEND_URL}/store",
|
||||
metadata={"user_id": str(user.id), "shards": str(package["shards"])},
|
||||
)
|
||||
return {"url": session.url}
|
||||
|
||||
@app.post("/stripe/webhook")
|
||||
async def stripe_webhook(request: Request, db: Session = Depends(get_db)):
|
||||
payload = await request.body()
|
||||
sig = request.headers.get("stripe-signature", "")
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(payload, sig, STRIPE_WEBHOOK_SECRET)
|
||||
except stripe.error.SignatureVerificationError: # type: ignore
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
if event["type"] == "checkout.session.completed":
|
||||
data = event["data"]["object"]
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
shards = data.get("metadata", {}).get("shards")
|
||||
if user_id and shards:
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if user:
|
||||
user.shards += int(shards)
|
||||
db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/store/config")
|
||||
def store_config():
|
||||
return {
|
||||
"publishable_key": STRIPE_PUBLISHABLE_KEY,
|
||||
"shard_packages": SHARD_PACKAGES,
|
||||
}
|
||||
|
||||
STORE_PACKAGES = {
|
||||
1: 15,
|
||||
5: 65,
|
||||
10: 120,
|
||||
25: 260,
|
||||
}
|
||||
|
||||
class StoreBuyRequest(BaseModel):
|
||||
quantity: int
|
||||
|
||||
class BuySpecificCardRequest(BaseModel):
|
||||
wiki_title: str
|
||||
|
||||
SPECIFIC_CARD_COST = 1000
|
||||
|
||||
@app.post("/store/buy-specific-card")
|
||||
@limiter.limit("10/hour")
|
||||
async def buy_specific_card(request: Request, req: BuySpecificCardRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if user.shards < SPECIFIC_CARD_COST:
|
||||
raise HTTPException(status_code=400, detail="Not enough shards")
|
||||
|
||||
card = await _get_specific_card_async(req.wiki_title)
|
||||
if card is None:
|
||||
raise HTTPException(status_code=404, detail="Could not generate a card for that Wikipedia page")
|
||||
|
||||
db_card = CardModel(
|
||||
name=card.name,
|
||||
image_link=card.image_link,
|
||||
card_rarity=card.card_rarity.name,
|
||||
card_type=card.card_type.name,
|
||||
text=card.text,
|
||||
attack=card.attack,
|
||||
defense=card.defense,
|
||||
cost=card.cost,
|
||||
user_id=user.id,
|
||||
)
|
||||
db.add(db_card)
|
||||
user.shards -= SPECIFIC_CARD_COST
|
||||
db.commit()
|
||||
db.refresh(db_card)
|
||||
|
||||
return {
|
||||
**{c.name: getattr(db_card, c.name) for c in db_card.__table__.columns},
|
||||
"card_rarity": db_card.card_rarity,
|
||||
"card_type": db_card.card_type,
|
||||
"shards": user.shards,
|
||||
}
|
||||
|
||||
@app.post("/store/buy")
|
||||
def store_buy(req: StoreBuyRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
cost = STORE_PACKAGES.get(req.quantity)
|
||||
if cost is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid package")
|
||||
if user.shards < cost:
|
||||
raise HTTPException(status_code=400, detail="Not enough shards")
|
||||
user.shards -= cost
|
||||
user.boosters += req.quantity
|
||||
db.commit()
|
||||
return {"shards": user.shards, "boosters": user.boosters}
|
||||
|
||||
@app.post("/cards/{card_id}/report")
|
||||
def report_card(card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
card = db.query(CardModel).filter(
|
||||
CardModel.id == uuid.UUID(card_id),
|
||||
CardModel.user_id == user.id
|
||||
).first()
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Card not found")
|
||||
card.reported = True
|
||||
db.commit()
|
||||
return {"message": "Card reported"}
|
||||
|
||||
@app.post("/cards/{card_id}/refresh")
|
||||
@limiter.limit("5/hour")
|
||||
async def refresh_card(request: Request, card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
card = db.query(CardModel).filter(
|
||||
CardModel.id == uuid.UUID(card_id),
|
||||
CardModel.user_id == user.id
|
||||
).first()
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Card not found")
|
||||
|
||||
if user.last_refresh_at and datetime.now() - user.last_refresh_at < timedelta(hours=2):
|
||||
remaining = (user.last_refresh_at + timedelta(hours=2)) - datetime.now()
|
||||
hours = int(remaining.total_seconds() // 3600)
|
||||
minutes = int((remaining.total_seconds() % 3600) // 60)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"You can refresh again in {hours}h {minutes}m"
|
||||
)
|
||||
|
||||
new_card = await _get_specific_card_async(card.name)
|
||||
if not new_card:
|
||||
raise HTTPException(status_code=502, detail="Failed to regenerate card from Wikipedia")
|
||||
|
||||
card.image_link = new_card.image_link
|
||||
card.card_rarity = new_card.card_rarity.name
|
||||
card.card_type = new_card.card_type.name
|
||||
card.text = new_card.text
|
||||
card.attack = new_card.attack
|
||||
card.defense = new_card.defense
|
||||
card.cost = new_card.cost
|
||||
card.reported = False
|
||||
|
||||
user.last_refresh_at = datetime.now()
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
**{c.name: getattr(card, c.name) for c in card.__table__.columns},
|
||||
"card_rarity": card.card_rarity,
|
||||
"card_type": card.card_type,
|
||||
}
|
||||
|
||||
@app.get("/profile/refresh-status")
|
||||
def refresh_status(user: UserModel = Depends(get_current_user)):
|
||||
if not user.last_refresh_at:
|
||||
return {"can_refresh": True, "next_refresh_at": None}
|
||||
next_refresh = user.last_refresh_at + timedelta(hours=2)
|
||||
can_refresh = datetime.now() >= next_refresh
|
||||
return {
|
||||
"can_refresh": can_refresh,
|
||||
"next_refresh_at": next_refresh.isoformat() if not can_refresh else None,
|
||||
}
|
||||
|
||||
@app.post("/game/{game_id}/claim-timeout-win")
|
||||
async def claim_timeout_win(game_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
err = await handle_timeout_claim(game_id, str(user.id), db)
|
||||
if err:
|
||||
raise HTTPException(status_code=400, detail=err)
|
||||
return {"message": "Win claimed"}
|
||||
|
||||
@app.post("/game/solo")
|
||||
async def start_solo_game(deck_id: str, difficulty: int = 5, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if difficulty < 1 or difficulty > 10:
|
||||
raise HTTPException(status_code=400, detail="Difficulty must be between 1 and 10")
|
||||
|
||||
deck = db.query(DeckModel).filter(
|
||||
DeckModel.id == uuid.UUID(deck_id),
|
||||
DeckModel.user_id == user.id
|
||||
).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
|
||||
card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
|
||||
total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
|
||||
if total_cost == 0 or total_cost > 50:
|
||||
raise HTTPException(status_code=400, detail="Deck total cost must be between 1 and 50")
|
||||
|
||||
player_cards = load_deck_cards(deck_id, str(user.id), db)
|
||||
if player_cards is None:
|
||||
raise HTTPException(status_code=503, detail="Couldn't load deck")
|
||||
|
||||
ai_cards = db.query(CardModel).filter(
|
||||
CardModel.user_id == None,
|
||||
).order_by(func.random()).limit(500).all()
|
||||
|
||||
if len(ai_cards) == 0:
|
||||
raise HTTPException(status_code=503, detail="Not enough cards in pool for AI deck")
|
||||
|
||||
for card in ai_cards:
|
||||
card.ai_used = True
|
||||
db.commit()
|
||||
|
||||
game_id = create_solo_game(str(user.id), user.username, player_cards, ai_cards, deck_id, difficulty)
|
||||
asyncio.create_task(fill_card_pool())
|
||||
|
||||
return {"game_id": game_id}
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@app.post("/auth/reset-password")
|
||||
def reset_password(req: ResetPasswordRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if not verify_password(req.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if len(req.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
if len(req.new_password) > 256:
|
||||
raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")
|
||||
if req.current_password == req.new_password:
|
||||
raise HTTPException(status_code=400, detail="New password must be different from current password")
|
||||
user.password_hash = hash_password(req.new_password)
|
||||
db.commit()
|
||||
return {"message": "Password updated"}
|
||||
|
||||
@app.post("/auth/forgot-password")
|
||||
def forgot_password(req: ForgotPasswordRequest, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.email == req.email).first()
|
||||
# Always return success even if email not found. Prevents user enumeration
|
||||
if user:
|
||||
token = secrets.token_urlsafe(32)
|
||||
user.reset_token = token
|
||||
user.reset_token_expires_at = datetime.now() + timedelta(hours=1)
|
||||
db.commit()
|
||||
try:
|
||||
send_password_reset_email(user.email, user.username, token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send reset email: {e}")
|
||||
return {"message": "If that email is registered you will receive a reset link shortly"}
|
||||
|
||||
@app.post("/auth/reset-password-with-token")
|
||||
def reset_password_with_token(req: ResetPasswordWithTokenRequest, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.reset_token == req.token).first()
|
||||
if not user or not user.reset_token_expires_at or user.reset_token_expires_at < datetime.now():
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired reset link")
|
||||
if len(req.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
if len(req.new_password) > 256:
|
||||
raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")
|
||||
user.password_hash = hash_password(req.new_password)
|
||||
user.reset_token = None
|
||||
user.reset_token_expires_at = None
|
||||
db.commit()
|
||||
return {"message": "Password updated"}
|
||||
|
||||
@app.get("/auth/verify-email")
|
||||
def verify_email(token: str, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.email_verification_token == token).first()
|
||||
if not user or not user.email_verification_token_expires_at or user.email_verification_token_expires_at < datetime.now():
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired verification link")
|
||||
user.email_verified = True
|
||||
user.email_verification_token = None
|
||||
user.email_verification_token_expires_at = None
|
||||
db.commit()
|
||||
return {"message": "Email verified"}
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@app.post("/auth/resend-verification")
|
||||
def resend_verification(req: ResendVerificationRequest, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.email == req.email).first()
|
||||
# Always return success to prevent user enumeration
|
||||
if user and not user.email_verified:
|
||||
token = secrets.token_urlsafe(32)
|
||||
user.email_verification_token = token
|
||||
user.email_verification_token_expires_at = datetime.now() + timedelta(hours=24)
|
||||
db.commit()
|
||||
try:
|
||||
send_verification_email(user.email, user.username, token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resend verification email: {e}")
|
||||
return {"message": "If that email is registered and unverified, you will receive a new verification link shortly"}
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
@app.post("/auth/refresh")
|
||||
def refresh(req: RefreshRequest, db: Session = Depends(get_db)):
|
||||
user_id = decode_refresh_token(req.refresh_token)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
return {
|
||||
"access_token": create_access_token(str(user.id)),
|
||||
"refresh_token": create_refresh_token(str(user.id)),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ai import AIPersonality, choose_cards
|
||||
from card import generate_cards, Card
|
||||
from time import sleep
|
||||
|
||||
all_cards = generate_cards(500)
|
||||
|
||||
all_cards.sort(key=lambda x: x.cost, reverse=True)
|
||||
|
||||
print(len(all_cards))
|
||||
def write_cards(cards: list[Card], file: str):
|
||||
with open(file, "w") as fp:
|
||||
fp.write('\n'.join([
|
||||
f"{c.name} - {c.attack}/{c.defense} - {c.cost}"
|
||||
for c in cards
|
||||
]))
|
||||
|
||||
write_cards(all_cards, "output/all.txt")
|
||||
|
||||
for personality in AIPersonality:
|
||||
print(personality.value)
|
||||
for difficulty in range(1,11):
|
||||
chosen_cards = choose_cards(all_cards, difficulty, personality)
|
||||
chosen_cards.sort(key=lambda x: x.cost, reverse=True)
|
||||
write_cards(chosen_cards, f"output/{personality.value}-{difficulty}.txt")
|
||||
app.include_router(health.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(cards.router)
|
||||
app.include_router(decks.router)
|
||||
app.include_router(games.router)
|
||||
app.include_router(notifications.router)
|
||||
app.include_router(profile.router)
|
||||
app.include_router(friends.router)
|
||||
app.include_router(store.router)
|
||||
app.include_router(trades.router)
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Integer, ForeignKey, DateTime, Text, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from database import Base
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
username: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
email: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
password_hash: Mapped[str] = mapped_column(String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
boosters: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
boosters_countdown: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
wins: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
losses: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
shards: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
last_refresh_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
reset_token: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
reset_token_expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
email_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
email_verification_token: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
email_verification_token_expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
cards: Mapped[list["Card"]] = relationship(back_populates="user")
|
||||
decks: Mapped[list["Deck"]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class Card(Base):
|
||||
__tablename__ = "cards"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
image_link: Mapped[str] = mapped_column(String, nullable=True)
|
||||
card_rarity: Mapped[str] = mapped_column(String, nullable=False)
|
||||
card_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
text: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
attack: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
defense: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
cost: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
times_played: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
reported: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
ai_used: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
user: Mapped["User | None"] = relationship(back_populates="cards")
|
||||
deck_cards: Mapped[list["DeckCard"]] = relationship(back_populates="card")
|
||||
|
||||
|
||||
class Deck(Base):
|
||||
__tablename__ = "decks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
times_played: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
wins: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
losses: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="decks")
|
||||
deck_cards: Mapped[list["DeckCard"]] = relationship(back_populates="deck")
|
||||
|
||||
|
||||
class DeckCard(Base):
|
||||
__tablename__ = "deck_cards"
|
||||
|
||||
deck_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("decks.id"), primary_key=True)
|
||||
card_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("cards.id"), primary_key=True)
|
||||
|
||||
deck: Mapped["Deck"] = relationship(back_populates="deck_cards")
|
||||
card: Mapped["Card"] = relationship(back_populates="deck_cards")
|
||||
File diff suppressed because one or more lines are too long
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
224
backend/routers/auth.py
Normal file
224
backend/routers/auth.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.auth import (
|
||||
create_access_token, create_refresh_token,
|
||||
decode_refresh_token, hash_password, verify_password,
|
||||
)
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user, limiter
|
||||
from services.email_utils import send_password_reset_email, send_verification_email
|
||||
from core.models import User as UserModel
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
try:
|
||||
from disposable_email_domains import blocklist as _disposable_blocklist
|
||||
except ImportError:
|
||||
_disposable_blocklist: set[str] = set()
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
email: str
|
||||
password: str
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
class ResetPasswordWithTokenRequest(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
def validate_register(username: str, email: str, password: str) -> str | None:
|
||||
if not username.strip():
|
||||
return "Username is required"
|
||||
if len(username) < 2:
|
||||
return "Username must be at least 2 characters"
|
||||
if len(username) > 16:
|
||||
return "Username must be 16 characters or fewer"
|
||||
if not re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", email):
|
||||
return "Please enter a valid email"
|
||||
domain = email.split("@")[-1].lower()
|
||||
if domain in _disposable_blocklist:
|
||||
return "Disposable email addresses are not allowed"
|
||||
if len(password) < 8:
|
||||
return "Password must be at least 8 characters"
|
||||
if len(password) > 256:
|
||||
return "Password must be 256 characters or fewer"
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
@limiter.limit("5/minute")
|
||||
def register(request: Request, req: RegisterRequest, db: Session = Depends(get_db)):
|
||||
err = validate_register(req.username, req.email, req.password)
|
||||
if err:
|
||||
raise HTTPException(status_code=400, detail=err)
|
||||
if db.query(UserModel).filter(UserModel.username.ilike(req.username)).first():
|
||||
raise HTTPException(status_code=400, detail="Username already taken")
|
||||
if db.query(UserModel).filter(UserModel.email == req.email).first():
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
user = UserModel(
|
||||
id=uuid.uuid4(),
|
||||
username=req.username,
|
||||
email=req.email,
|
||||
password_hash=hash_password(req.password),
|
||||
email_verified=False,
|
||||
email_verification_token=verification_token,
|
||||
email_verification_token_expires_at=datetime.now() + timedelta(hours=24),
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
try:
|
||||
send_verification_email(req.email, req.username, verification_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Account created but we couldn't send the verification email. Please use 'Resend verification' to try again."
|
||||
)
|
||||
return {"message": "Account created. Please check your email to verify your account."}
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
@limiter.limit("10/minute")
|
||||
def login(request: Request, form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.username.ilike(form.username)).first()
|
||||
if not user or not verify_password(form.password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="Invalid username or password")
|
||||
user.last_active_at = datetime.now()
|
||||
db.commit()
|
||||
return {
|
||||
"access_token": create_access_token(str(user.id)),
|
||||
"refresh_token": create_refresh_token(str(user.id)),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/auth/reset-password")
|
||||
@limiter.limit("5/minute")
|
||||
def reset_password(request: Request, req: ResetPasswordRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if not verify_password(req.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if len(req.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
if len(req.new_password) > 256:
|
||||
raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")
|
||||
if req.current_password == req.new_password:
|
||||
raise HTTPException(status_code=400, detail="New password must be different from current password")
|
||||
user.password_hash = hash_password(req.new_password)
|
||||
db.commit()
|
||||
return {"message": "Password updated"}
|
||||
|
||||
|
||||
@router.post("/auth/forgot-password")
|
||||
@limiter.limit("5/minute")
|
||||
def forgot_password(request: Request, req: ForgotPasswordRequest, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.email == req.email).first()
|
||||
# Always return success even if email not found. Prevents user enumeration
|
||||
if user:
|
||||
token = secrets.token_urlsafe(32)
|
||||
user.reset_token = token
|
||||
user.reset_token_expires_at = datetime.now() + timedelta(hours=1)
|
||||
db.commit()
|
||||
try:
|
||||
send_password_reset_email(user.email, user.username, token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send reset email: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to send the password reset email. Please try again later."
|
||||
)
|
||||
return {"message": "If that email is registered you will receive a reset link shortly"}
|
||||
|
||||
|
||||
@router.post("/auth/reset-password-with-token")
|
||||
@limiter.limit("5/minute")
|
||||
def reset_password_with_token(request: Request, req: ResetPasswordWithTokenRequest, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.reset_token == req.token).first()
|
||||
if not user or not user.reset_token_expires_at or user.reset_token_expires_at < datetime.now():
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired reset link")
|
||||
if len(req.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
if len(req.new_password) > 256:
|
||||
raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")
|
||||
user.password_hash = hash_password(req.new_password)
|
||||
user.reset_token = None
|
||||
user.reset_token_expires_at = None
|
||||
db.commit()
|
||||
return {"message": "Password updated"}
|
||||
|
||||
|
||||
@router.get("/auth/verify-email")
|
||||
@limiter.limit("10/minute")
|
||||
def verify_email(request: Request, token: str, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.email_verification_token == token).first()
|
||||
if not user or not user.email_verification_token_expires_at or user.email_verification_token_expires_at < datetime.now():
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired verification link")
|
||||
user.email_verified = True
|
||||
user.email_verification_token = None
|
||||
user.email_verification_token_expires_at = None
|
||||
db.commit()
|
||||
return {"message": "Email verified"}
|
||||
|
||||
|
||||
@router.post("/auth/resend-verification")
|
||||
@limiter.limit("5/minute")
|
||||
def resend_verification(request: Request, req: ResendVerificationRequest, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.email == req.email).first()
|
||||
# Always return success to prevent user enumeration
|
||||
if user and not user.email_verified:
|
||||
token = secrets.token_urlsafe(32)
|
||||
user.email_verification_token = token
|
||||
user.email_verification_token_expires_at = datetime.now() + timedelta(hours=24)
|
||||
db.commit()
|
||||
try:
|
||||
send_verification_email(user.email, user.username, token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resend verification email: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to send the verification email. Please try again later."
|
||||
)
|
||||
return {"message": "If that email is registered and unverified, you will receive a new verification link shortly"}
|
||||
|
||||
|
||||
@router.post("/auth/refresh")
|
||||
@limiter.limit("20/minute")
|
||||
def refresh(request: Request, req: RefreshRequest, db: Session = Depends(get_db)):
|
||||
user_id = decode_refresh_token(req.refresh_token)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
user.last_active_at = datetime.now()
|
||||
db.commit()
|
||||
return {
|
||||
"access_token": create_access_token(str(user.id)),
|
||||
"refresh_token": create_refresh_token(str(user.id)),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
229
backend/routers/cards.py
Normal file
229
backend/routers/cards.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import asc, case, desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from game.card import _get_specific_card_async
|
||||
from core.database import get_db
|
||||
from services.database_functions import check_boosters, fill_card_pool, BOOSTER_MAX
|
||||
from core.dependencies import get_current_user, limiter
|
||||
from core.models import Card as CardModel
|
||||
from core.models import Deck as DeckModel
|
||||
from core.models import DeckCard as DeckCardModel
|
||||
from core.models import User as UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/boosters")
|
||||
def get_boosters(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
count, countdown = check_boosters(user, db)
|
||||
return {"count": count, "countdown": countdown, "email_verified": user.email_verified}
|
||||
|
||||
|
||||
@router.get("/cards")
|
||||
def get_cards(
|
||||
skip: int = 0,
|
||||
limit: int = 40,
|
||||
search: str = "",
|
||||
rarities: list[str] = Query(default=[]),
|
||||
types: list[str] = Query(default=[]),
|
||||
cost_min: int = 1,
|
||||
cost_max: int = 10,
|
||||
favorites_only: bool = False,
|
||||
wtt_only: bool = False,
|
||||
sort_by: str = "name",
|
||||
sort_dir: str = "asc",
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
q = db.query(CardModel).filter(CardModel.user_id == user.id)
|
||||
|
||||
if search:
|
||||
q = q.filter(CardModel.name.ilike(f"%{search}%"))
|
||||
if rarities:
|
||||
q = q.filter(CardModel.card_rarity.in_(rarities))
|
||||
if types:
|
||||
q = q.filter(CardModel.card_type.in_(types))
|
||||
q = q.filter(CardModel.cost >= cost_min, CardModel.cost <= cost_max)
|
||||
if favorites_only:
|
||||
q = q.filter(CardModel.is_favorite == True)
|
||||
if wtt_only:
|
||||
q = q.filter(CardModel.willing_to_trade == True)
|
||||
|
||||
total = q.count()
|
||||
|
||||
# case() for rarity ordering matches frontend RARITY_ORDER constant
|
||||
rarity_order_expr = case(
|
||||
(CardModel.card_rarity == 'common', 0),
|
||||
(CardModel.card_rarity == 'uncommon', 1),
|
||||
(CardModel.card_rarity == 'rare', 2),
|
||||
(CardModel.card_rarity == 'super_rare', 3),
|
||||
(CardModel.card_rarity == 'epic', 4),
|
||||
(CardModel.card_rarity == 'legendary', 5),
|
||||
else_=0
|
||||
)
|
||||
# coalesce mirrors frontend: received_at ?? generated_at
|
||||
date_received_expr = func.coalesce(CardModel.received_at, CardModel.generated_at)
|
||||
|
||||
sort_map = {
|
||||
"name": CardModel.name,
|
||||
"cost": CardModel.cost,
|
||||
"attack": CardModel.attack,
|
||||
"defense": CardModel.defense,
|
||||
"rarity": rarity_order_expr,
|
||||
"date_generated": CardModel.generated_at,
|
||||
"date_received": date_received_expr,
|
||||
}
|
||||
sort_col = sort_map.get(sort_by, CardModel.name)
|
||||
order_fn = desc if sort_dir == "desc" else asc
|
||||
# Secondary sort by name keeps pages stable when primary values are tied
|
||||
q = q.order_by(order_fn(sort_col), asc(CardModel.name))
|
||||
|
||||
cards = q.offset(skip).limit(limit).all()
|
||||
return {
|
||||
"cards": [
|
||||
{c.name: getattr(card, c.name) for c in card.__table__.columns}
|
||||
for card in cards
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cards/in-decks")
|
||||
def get_cards_in_decks(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck_ids = [d.id for d in db.query(DeckModel).filter(DeckModel.user_id == user.id, DeckModel.deleted == False).all()]
|
||||
if not deck_ids:
|
||||
return []
|
||||
card_ids = db.query(DeckCardModel.card_id).filter(DeckCardModel.deck_id.in_(deck_ids)).distinct().all()
|
||||
return [str(row.card_id) for row in card_ids]
|
||||
|
||||
|
||||
@router.post("/open_pack")
|
||||
@limiter.limit("10/minute")
|
||||
async def open_pack(request: Request, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if not user.email_verified:
|
||||
raise HTTPException(status_code=403, detail="You must verify your email before opening packs")
|
||||
|
||||
check_boosters(user, db)
|
||||
|
||||
if user.boosters == 0:
|
||||
raise HTTPException(status_code=400, detail="No booster packs available")
|
||||
|
||||
cards = (
|
||||
db.query(CardModel)
|
||||
.filter(CardModel.user_id == None, CardModel.ai_used == False)
|
||||
.limit(5)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(cards) < 5:
|
||||
asyncio.create_task(fill_card_pool())
|
||||
raise HTTPException(status_code=503, detail="Card pool is low, please try again shortly")
|
||||
|
||||
now = datetime.now()
|
||||
for card in cards:
|
||||
card.user_id = user.id
|
||||
card.received_at = now
|
||||
|
||||
was_full = user.boosters == BOOSTER_MAX
|
||||
user.boosters -= 1
|
||||
if was_full:
|
||||
user.boosters_countdown = datetime.now()
|
||||
|
||||
db.commit()
|
||||
|
||||
asyncio.create_task(fill_card_pool())
|
||||
|
||||
return [
|
||||
{**{c.name: getattr(card, c.name) for c in card.__table__.columns},
|
||||
"card_rarity": card.card_rarity,
|
||||
"card_type": card.card_type}
|
||||
for card in cards
|
||||
]
|
||||
|
||||
|
||||
@router.post("/cards/{card_id}/report")
|
||||
def report_card(card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
card = db.query(CardModel).filter(
|
||||
CardModel.id == uuid.UUID(card_id),
|
||||
CardModel.user_id == user.id
|
||||
).first()
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Card not found")
|
||||
card.reported = True
|
||||
db.commit()
|
||||
return {"message": "Card reported"}
|
||||
|
||||
|
||||
@router.post("/cards/{card_id}/refresh")
|
||||
@limiter.limit("5/hour")
|
||||
async def refresh_card(request: Request, card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
card = db.query(CardModel).filter(
|
||||
CardModel.id == uuid.UUID(card_id),
|
||||
CardModel.user_id == user.id
|
||||
).first()
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Card not found")
|
||||
|
||||
if user.last_refresh_at and datetime.now() - user.last_refresh_at < timedelta(minutes=10):
|
||||
remaining = (user.last_refresh_at + timedelta(minutes=10)) - datetime.now()
|
||||
minutes = int(remaining.total_seconds() // 60)
|
||||
seconds = int(remaining.total_seconds() % 60)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"You can refresh again in {minutes}m {seconds}s"
|
||||
)
|
||||
|
||||
new_card = await _get_specific_card_async(card.name)
|
||||
if not new_card:
|
||||
raise HTTPException(status_code=502, detail="Failed to regenerate card from Wikipedia")
|
||||
|
||||
card.image_link = new_card.image_link
|
||||
card.card_rarity = new_card.card_rarity.name
|
||||
card.card_type = new_card.card_type.name
|
||||
card.text = new_card.text
|
||||
card.attack = new_card.attack
|
||||
card.defense = new_card.defense
|
||||
card.cost = new_card.cost
|
||||
card.reported = False
|
||||
card.generated_at = datetime.now()
|
||||
card.received_at = datetime.now()
|
||||
|
||||
user.last_refresh_at = datetime.now()
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
**{c.name: getattr(card, c.name) for c in card.__table__.columns},
|
||||
"card_rarity": card.card_rarity,
|
||||
"card_type": card.card_type,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/cards/{card_id}/favorite")
|
||||
def toggle_favorite(card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
card = db.query(CardModel).filter(
|
||||
CardModel.id == uuid.UUID(card_id),
|
||||
CardModel.user_id == user.id
|
||||
).first()
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Card not found")
|
||||
card.is_favorite = not card.is_favorite
|
||||
db.commit()
|
||||
return {"is_favorite": card.is_favorite}
|
||||
|
||||
|
||||
@router.post("/cards/{card_id}/willing-to-trade")
|
||||
def toggle_willing_to_trade(card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
card = db.query(CardModel).filter(
|
||||
CardModel.id == uuid.UUID(card_id),
|
||||
CardModel.user_id == user.id
|
||||
).first()
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Card not found")
|
||||
card.willing_to_trade = not card.willing_to_trade
|
||||
db.commit()
|
||||
return {"willing_to_trade": card.willing_to_trade}
|
||||
97
backend/routers/decks.py
Normal file
97
backend/routers/decks.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from game.card import compute_deck_type
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user
|
||||
from core.models import Card as CardModel
|
||||
from core.models import Deck as DeckModel
|
||||
from core.models import DeckCard as DeckCardModel
|
||||
from core.models import User as UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class DeckUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, max_length=64)
|
||||
card_ids: Optional[List[str]] = None
|
||||
|
||||
|
||||
@router.get("/decks")
|
||||
def get_decks(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
decks = db.query(DeckModel).options(
|
||||
selectinload(DeckModel.deck_cards).selectinload(DeckCardModel.card)
|
||||
).filter(
|
||||
DeckModel.user_id == user.id,
|
||||
DeckModel.deleted == False
|
||||
).order_by(DeckModel.created_at).all()
|
||||
result = []
|
||||
for deck in decks:
|
||||
cards = [dc.card for dc in deck.deck_cards]
|
||||
result.append({
|
||||
"id": str(deck.id),
|
||||
"name": deck.name,
|
||||
"card_count": len(cards),
|
||||
"total_cost": sum(card.cost for card in cards),
|
||||
"times_played": deck.times_played,
|
||||
"wins": deck.wins,
|
||||
"losses": deck.losses,
|
||||
"deck_type": compute_deck_type(cards),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/decks")
|
||||
def create_deck(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
count = db.query(DeckModel).filter(DeckModel.user_id == user.id).count()
|
||||
deck = DeckModel(id=uuid.uuid4(), user_id=user.id, name=f"Deck #{count + 1}")
|
||||
db.add(deck)
|
||||
db.commit()
|
||||
return {"id": str(deck.id), "name": deck.name, "card_count": 0}
|
||||
|
||||
|
||||
@router.patch("/decks/{deck_id}")
|
||||
def update_deck(deck_id: str, body: DeckUpdate, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
if body.name is not None:
|
||||
deck.name = body.name
|
||||
if body.card_ids is not None:
|
||||
db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).delete()
|
||||
for card_id in body.card_ids:
|
||||
db.add(DeckCardModel(deck_id=deck.id, card_id=uuid.UUID(card_id)))
|
||||
if deck.times_played > 0:
|
||||
deck.wins = 0
|
||||
deck.losses = 0
|
||||
deck.times_played = 0
|
||||
db.commit()
|
||||
return {"id": str(deck.id), "name": deck.name}
|
||||
|
||||
|
||||
@router.delete("/decks/{deck_id}")
|
||||
def delete_deck(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
if deck.times_played > 0:
|
||||
deck.deleted = True
|
||||
else:
|
||||
db.delete(deck)
|
||||
db.commit()
|
||||
return {"message": "Deleted"}
|
||||
|
||||
|
||||
@router.get("/decks/{deck_id}/cards")
|
||||
def get_deck_cards(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
deck_cards = db.query(DeckCardModel).options(
|
||||
selectinload(DeckCardModel.card)
|
||||
).filter(DeckCardModel.deck_id == deck.id).all()
|
||||
return [{"id": str(dc.card_id), "cost": dc.card.cost} for dc in deck_cards]
|
||||
134
backend/routers/friends.py
Normal file
134
backend/routers/friends.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from services import notification_manager
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user, get_user_id_from_request, limiter
|
||||
from core.models import Friendship as FriendshipModel
|
||||
from core.models import Notification as NotificationModel
|
||||
from core.models import User as UserModel
|
||||
from routers.notifications import _serialize_notification
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/users/{username}/friend-request")
|
||||
@limiter.limit("10/minute", key_func=get_user_id_from_request)
|
||||
async def send_friend_request(request: Request, username: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
addressee = db.query(UserModel).filter(UserModel.username == username).first()
|
||||
if not addressee:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if addressee.id == user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot send friend request to yourself")
|
||||
|
||||
# Check for any existing friendship in either direction
|
||||
existing = db.query(FriendshipModel).filter(
|
||||
((FriendshipModel.requester_id == user.id) & (FriendshipModel.addressee_id == addressee.id)) |
|
||||
((FriendshipModel.requester_id == addressee.id) & (FriendshipModel.addressee_id == user.id)),
|
||||
).first()
|
||||
if existing and existing.status != "declined":
|
||||
raise HTTPException(status_code=400, detail="Friend request already exists or already friends")
|
||||
# Clear stale declined row so the unique constraint allows re-requesting
|
||||
if existing:
|
||||
db.delete(existing)
|
||||
db.flush()
|
||||
|
||||
friendship = FriendshipModel(requester_id=user.id, addressee_id=addressee.id, status="pending")
|
||||
db.add(friendship)
|
||||
db.flush() # get friendship.id before notification
|
||||
|
||||
notif = NotificationModel(
|
||||
user_id=addressee.id,
|
||||
type="friend_request",
|
||||
payload={"friendship_id": str(friendship.id), "from_username": user.username},
|
||||
)
|
||||
db.add(notif)
|
||||
db.commit()
|
||||
|
||||
await notification_manager.send_notification(str(addressee.id), _serialize_notification(notif))
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/friendships/{friendship_id}/accept")
|
||||
def accept_friend_request(friendship_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
friendship = db.query(FriendshipModel).filter(FriendshipModel.id == uuid.UUID(friendship_id)).first()
|
||||
if not friendship:
|
||||
raise HTTPException(status_code=404, detail="Friendship not found")
|
||||
if friendship.addressee_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
if friendship.status != "pending":
|
||||
raise HTTPException(status_code=400, detail="Friendship is not pending")
|
||||
friendship.status = "accepted"
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/friendships/{friendship_id}/decline")
|
||||
def decline_friend_request(friendship_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
friendship = db.query(FriendshipModel).filter(FriendshipModel.id == uuid.UUID(friendship_id)).first()
|
||||
if not friendship:
|
||||
raise HTTPException(status_code=404, detail="Friendship not found")
|
||||
if friendship.addressee_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
if friendship.status != "pending":
|
||||
raise HTTPException(status_code=400, detail="Friendship is not pending")
|
||||
friendship.status = "declined"
|
||||
# Clean up the associated notification so it disappears from the bell
|
||||
db.query(NotificationModel).filter(
|
||||
NotificationModel.user_id == user.id,
|
||||
NotificationModel.type == "friend_request",
|
||||
NotificationModel.payload["friendship_id"].astext == friendship_id,
|
||||
).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/friends")
|
||||
def get_friends(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
friendships = db.query(FriendshipModel).options(
|
||||
joinedload(FriendshipModel.requester),
|
||||
joinedload(FriendshipModel.addressee),
|
||||
).filter(
|
||||
(FriendshipModel.requester_id == user.id) | (FriendshipModel.addressee_id == user.id),
|
||||
FriendshipModel.status == "accepted",
|
||||
).all()
|
||||
result = []
|
||||
for f in friendships:
|
||||
other = f.addressee if f.requester_id == user.id else f.requester
|
||||
result.append({"id": str(other.id), "username": other.username, "friendship_id": str(f.id)})
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/friendship-status/{username}")
|
||||
def get_friendship_status(username: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
"""Returns the friendship status between the current user and the given username."""
|
||||
other = db.query(UserModel).filter(UserModel.username == username).first()
|
||||
if not other:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
friendship = db.query(FriendshipModel).filter(
|
||||
((FriendshipModel.requester_id == user.id) & (FriendshipModel.addressee_id == other.id)) |
|
||||
((FriendshipModel.requester_id == other.id) & (FriendshipModel.addressee_id == user.id)),
|
||||
FriendshipModel.status != "declined",
|
||||
).first()
|
||||
if not friendship:
|
||||
return {"status": "none"}
|
||||
if friendship.status == "accepted":
|
||||
return {"status": "friends", "friendship_id": str(friendship.id)}
|
||||
# pending: distinguish sent vs received
|
||||
if friendship.requester_id == user.id:
|
||||
return {"status": "pending_sent", "friendship_id": str(friendship.id)}
|
||||
return {"status": "pending_received", "friendship_id": str(friendship.id)}
|
||||
|
||||
|
||||
@router.delete("/friendships/{friendship_id}")
|
||||
def remove_friend(friendship_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
friendship = db.query(FriendshipModel).filter(FriendshipModel.id == uuid.UUID(friendship_id)).first()
|
||||
if not friendship:
|
||||
raise HTTPException(status_code=404, detail="Friendship not found")
|
||||
if friendship.requester_id != user.id and friendship.addressee_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
db.delete(friendship)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
404
backend/routers/games.py
Normal file
404
backend/routers/games.py
Normal file
@@ -0,0 +1,404 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from services import notification_manager
|
||||
from core.auth import decode_access_token
|
||||
from core.database import get_db
|
||||
from services.database_functions import fill_card_pool
|
||||
from core.dependencies import get_current_user, get_user_id_from_request, limiter
|
||||
from game.manager import (
|
||||
QueueEntry, active_games, connections, create_challenge_game, create_solo_game,
|
||||
handle_action, handle_disconnect, handle_timeout_claim, load_deck_cards,
|
||||
queue, queue_lock, serialize_state, try_match,
|
||||
)
|
||||
from core.models import Card as CardModel
|
||||
from core.models import Deck as DeckModel
|
||||
from core.models import DeckCard as DeckCardModel
|
||||
from core.models import GameChallenge as GameChallengeModel
|
||||
from core.models import Notification as NotificationModel
|
||||
from core.models import User as UserModel
|
||||
from routers.notifications import _serialize_notification
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _serialize_challenge(c: GameChallengeModel, current_user_id: uuid.UUID) -> dict:
|
||||
deck = c.challenger_deck
|
||||
return {
|
||||
"id": str(c.id),
|
||||
"status": c.status,
|
||||
"direction": "outgoing" if c.challenger_id == current_user_id else "incoming",
|
||||
"challenger_username": c.challenger.username,
|
||||
"challenged_username": c.challenged.username,
|
||||
"deck_name": deck.name if deck else "Unknown Deck",
|
||||
"deck_id": str(c.challenger_deck_id),
|
||||
"created_at": c.created_at.isoformat(),
|
||||
"expires_at": c.expires_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ── WebSocket game matchmaking ────────────────────────────────────────────────
|
||||
|
||||
@router.websocket("/ws/queue")
|
||||
async def queue_endpoint(websocket: WebSocket, deck_id: str, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
deck = db.query(DeckModel).filter(
|
||||
DeckModel.id == uuid.UUID(deck_id),
|
||||
DeckModel.user_id == uuid.UUID(user_id)
|
||||
).first()
|
||||
|
||||
if not deck:
|
||||
await websocket.send_json({"type": "error", "message": "Deck not found"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
|
||||
total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
|
||||
if total_cost == 0 or total_cost > 50:
|
||||
await websocket.send_json({"type": "error", "message": "Deck total cost must be between 1 and 50"})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
entry = QueueEntry(user_id=user_id, deck_id=deck_id, websocket=websocket)
|
||||
|
||||
async with queue_lock:
|
||||
queue.append(entry)
|
||||
|
||||
await websocket.send_json({"type": "queued"})
|
||||
await try_match(db)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Keeping socket alive
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
async with queue_lock:
|
||||
queue[:] = [e for e in queue if e.user_id != user_id]
|
||||
|
||||
|
||||
@router.websocket("/ws/game/{game_id}")
|
||||
async def game_endpoint(websocket: WebSocket, game_id: str, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
if game_id not in active_games:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
if user_id not in active_games[game_id].players:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Register this connection (handles reconnects)
|
||||
connections[game_id][user_id] = websocket
|
||||
|
||||
# Send current state immediately on connect
|
||||
await websocket.send_json({
|
||||
"type": "state",
|
||||
"state": serialize_state(active_games[game_id], user_id),
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
await handle_action(game_id, user_id, data, db)
|
||||
except WebSocketDisconnect:
|
||||
if game_id in connections:
|
||||
connections[game_id].pop(user_id, None)
|
||||
asyncio.create_task(handle_disconnect(game_id, user_id))
|
||||
|
||||
|
||||
# ── Game challenges ───────────────────────────────────────────────────────────
|
||||
|
||||
class CreateGameChallengeRequest(BaseModel):
|
||||
deck_id: str
|
||||
|
||||
class AcceptGameChallengeRequest(BaseModel):
|
||||
deck_id: str
|
||||
|
||||
|
||||
@router.post("/users/{username}/challenge")
|
||||
@limiter.limit("10/minute", key_func=get_user_id_from_request)
|
||||
async def create_game_challenge(
|
||||
request: Request,
|
||||
username: str,
|
||||
req: CreateGameChallengeRequest,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
target = db.query(UserModel).filter(UserModel.username == username).first()
|
||||
if not target:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if target.id == user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot challenge yourself")
|
||||
|
||||
try:
|
||||
deck_id = uuid.UUID(req.deck_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid deck_id")
|
||||
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == deck_id, DeckModel.user_id == user.id, DeckModel.deleted == False).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
|
||||
existing = db.query(GameChallengeModel).filter(
|
||||
GameChallengeModel.status == "pending",
|
||||
(
|
||||
((GameChallengeModel.challenger_id == user.id) & (GameChallengeModel.challenged_id == target.id)) |
|
||||
((GameChallengeModel.challenger_id == target.id) & (GameChallengeModel.challenged_id == user.id))
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="A pending challenge already exists between you two")
|
||||
|
||||
now = datetime.now()
|
||||
challenge = GameChallengeModel(
|
||||
challenger_id=user.id,
|
||||
challenged_id=target.id,
|
||||
challenger_deck_id=deck_id,
|
||||
expires_at=now + timedelta(minutes=5),
|
||||
)
|
||||
db.add(challenge)
|
||||
db.flush()
|
||||
|
||||
notif = NotificationModel(
|
||||
user_id=target.id,
|
||||
type="game_challenge",
|
||||
expires_at=challenge.expires_at,
|
||||
payload={
|
||||
"challenge_id": str(challenge.id),
|
||||
"from_username": user.username,
|
||||
"deck_name": deck.name,
|
||||
},
|
||||
)
|
||||
db.add(notif)
|
||||
db.commit()
|
||||
|
||||
await notification_manager.send_notification(str(target.id), _serialize_notification(notif))
|
||||
return {"challenge_id": str(challenge.id)}
|
||||
|
||||
|
||||
@router.post("/challenges/{challenge_id}/accept")
|
||||
async def accept_game_challenge(
|
||||
challenge_id: str,
|
||||
req: AcceptGameChallengeRequest,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
cid = uuid.UUID(challenge_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid challenge_id")
|
||||
|
||||
challenge = db.query(GameChallengeModel).filter(GameChallengeModel.id == cid).with_for_update().first()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
if challenge.challenged_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
now = datetime.now()
|
||||
if challenge.status == "pending" and now > challenge.expires_at:
|
||||
challenge.status = "expired"
|
||||
db.commit()
|
||||
raise HTTPException(status_code=400, detail="Challenge has expired")
|
||||
if challenge.status != "pending":
|
||||
raise HTTPException(status_code=400, detail=f"Challenge is already {challenge.status}")
|
||||
|
||||
try:
|
||||
deck_id = uuid.UUID(req.deck_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid deck_id")
|
||||
|
||||
deck = db.query(DeckModel).filter(DeckModel.id == deck_id, DeckModel.user_id == user.id, DeckModel.deleted == False).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
|
||||
# Verify challenger's deck still exists — it could have been deleted since the challenge was sent
|
||||
challenger_deck = db.query(DeckModel).filter(
|
||||
DeckModel.id == challenge.challenger_deck_id,
|
||||
DeckModel.deleted == False,
|
||||
).first()
|
||||
if not challenger_deck:
|
||||
raise HTTPException(status_code=400, detail="The challenger's deck no longer exists")
|
||||
|
||||
try:
|
||||
game_id = create_challenge_game(
|
||||
str(challenge.challenger_id), str(challenge.challenger_deck_id),
|
||||
str(challenge.challenged_id), str(deck_id),
|
||||
db,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
challenge.status = "accepted"
|
||||
|
||||
# Delete the original challenge notification from the challenged player's bell
|
||||
old_notif = db.query(NotificationModel).filter(
|
||||
NotificationModel.user_id == user.id,
|
||||
NotificationModel.type == "game_challenge",
|
||||
NotificationModel.payload["challenge_id"].astext == str(challenge.id),
|
||||
).first()
|
||||
deleted_notif_id = str(old_notif.id) if old_notif else None
|
||||
if old_notif:
|
||||
db.delete(old_notif)
|
||||
|
||||
# Notify the challenger that their challenge was accepted
|
||||
response_notif = NotificationModel(
|
||||
user_id=challenge.challenger_id,
|
||||
type="game_challenge",
|
||||
payload={
|
||||
"challenge_id": str(challenge.id),
|
||||
"status": "accepted",
|
||||
"game_id": game_id,
|
||||
"from_username": user.username,
|
||||
},
|
||||
)
|
||||
db.add(response_notif)
|
||||
db.commit()
|
||||
|
||||
if deleted_notif_id:
|
||||
await notification_manager.send_delete(str(user.id), deleted_notif_id)
|
||||
await notification_manager.send_notification(str(challenge.challenger_id), _serialize_notification(response_notif))
|
||||
|
||||
return {"game_id": game_id}
|
||||
|
||||
|
||||
@router.post("/challenges/{challenge_id}/decline")
|
||||
async def decline_game_challenge(
|
||||
challenge_id: str,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
cid = uuid.UUID(challenge_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid challenge_id")
|
||||
|
||||
challenge = db.query(GameChallengeModel).filter(GameChallengeModel.id == cid).first()
|
||||
if not challenge:
|
||||
raise HTTPException(status_code=404, detail="Challenge not found")
|
||||
if challenge.challenger_id != user.id and challenge.challenged_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
now = datetime.now()
|
||||
if challenge.status == "pending" and now > challenge.expires_at:
|
||||
challenge.status = "expired"
|
||||
db.commit()
|
||||
raise HTTPException(status_code=400, detail="Challenge has already expired")
|
||||
if challenge.status != "pending":
|
||||
raise HTTPException(status_code=400, detail=f"Challenge is already {challenge.status}")
|
||||
|
||||
is_withdrawal = challenge.challenger_id == user.id
|
||||
challenge.status = "withdrawn" if is_withdrawal else "declined"
|
||||
|
||||
# Remove the notification from the other party's bell
|
||||
if is_withdrawal:
|
||||
# Challenger withdrawing: remove challenge notif from challenged player's bell
|
||||
notif = db.query(NotificationModel).filter(
|
||||
NotificationModel.user_id == challenge.challenged_id,
|
||||
NotificationModel.type == "game_challenge",
|
||||
NotificationModel.payload["challenge_id"].astext == str(challenge.id),
|
||||
).first()
|
||||
recipient_id = str(challenge.challenged_id)
|
||||
else:
|
||||
# Challenged player declining: remove challenge notif from their own bell
|
||||
notif = db.query(NotificationModel).filter(
|
||||
NotificationModel.user_id == user.id,
|
||||
NotificationModel.type == "game_challenge",
|
||||
NotificationModel.payload["challenge_id"].astext == str(challenge.id),
|
||||
).first()
|
||||
recipient_id = str(user.id)
|
||||
|
||||
deleted_notif_id = str(notif.id) if notif else None
|
||||
if notif:
|
||||
db.delete(notif)
|
||||
db.commit()
|
||||
|
||||
if deleted_notif_id:
|
||||
await notification_manager.send_delete(recipient_id, deleted_notif_id)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/challenges")
|
||||
def get_challenges(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
now = datetime.now()
|
||||
# Lazy-expire pending challenges past deadline
|
||||
db.query(GameChallengeModel).filter(
|
||||
GameChallengeModel.status == "pending",
|
||||
GameChallengeModel.expires_at < now,
|
||||
(GameChallengeModel.challenger_id == user.id) | (GameChallengeModel.challenged_id == user.id),
|
||||
).update({"status": "expired"})
|
||||
db.commit()
|
||||
|
||||
challenges = db.query(GameChallengeModel).options(
|
||||
joinedload(GameChallengeModel.challenger_deck)
|
||||
).filter(
|
||||
(GameChallengeModel.challenger_id == user.id) | (GameChallengeModel.challenged_id == user.id)
|
||||
).order_by(GameChallengeModel.created_at.desc()).all()
|
||||
|
||||
return [_serialize_challenge(c, user.id) for c in challenges]
|
||||
|
||||
|
||||
@router.post("/game/{game_id}/claim-timeout-win")
|
||||
async def claim_timeout_win(game_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
err = await handle_timeout_claim(game_id, str(user.id), db)
|
||||
if err:
|
||||
raise HTTPException(status_code=400, detail=err)
|
||||
return {"message": "Win claimed"}
|
||||
|
||||
|
||||
@router.post("/game/solo")
|
||||
async def start_solo_game(deck_id: str, difficulty: int = 5, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if difficulty < 1 or difficulty > 10:
|
||||
raise HTTPException(status_code=400, detail="Difficulty must be between 1 and 10")
|
||||
|
||||
deck = db.query(DeckModel).filter(
|
||||
DeckModel.id == uuid.UUID(deck_id),
|
||||
DeckModel.user_id == user.id
|
||||
).first()
|
||||
if not deck:
|
||||
raise HTTPException(status_code=404, detail="Deck not found")
|
||||
|
||||
card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
|
||||
total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
|
||||
if total_cost == 0 or total_cost > 50:
|
||||
raise HTTPException(status_code=400, detail="Deck total cost must be between 1 and 50")
|
||||
|
||||
player_cards = load_deck_cards(deck_id, str(user.id), db)
|
||||
if player_cards is None:
|
||||
raise HTTPException(status_code=503, detail="Couldn't load deck")
|
||||
|
||||
ai_cards = db.query(CardModel).filter(
|
||||
CardModel.user_id == None,
|
||||
).order_by(func.random()).limit(500).all()
|
||||
|
||||
if len(ai_cards) == 0:
|
||||
raise HTTPException(status_code=503, detail="Not enough cards in pool for AI deck")
|
||||
|
||||
for card in ai_cards:
|
||||
card.ai_used = True
|
||||
db.commit()
|
||||
|
||||
game_id = create_solo_game(str(user.id), user.username, player_cards, ai_cards, deck_id, difficulty)
|
||||
asyncio.create_task(fill_card_pool())
|
||||
|
||||
return {"game_id": game_id}
|
||||
17
backend/routers/health.py
Normal file
17
backend/routers/health.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from core.database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/health")
|
||||
def health_check(db: Session = Depends(get_db)):
|
||||
# Validates that the DB is reachable, not just that the process is up
|
||||
db.execute(text("SELECT 1"))
|
||||
return {"status": "ok"}
|
||||
|
||||
@router.get("/teapot")
|
||||
def teapot():
|
||||
return JSONResponse(status_code=418, content={"message": "I'm a teapot"})
|
||||
115
backend/routers/notifications.py
Normal file
115
backend/routers/notifications.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services import notification_manager
|
||||
from core.auth import decode_access_token
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user
|
||||
from core.models import Notification as NotificationModel
|
||||
from core.models import User as UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _serialize_notification(n: NotificationModel) -> dict:
|
||||
return {
|
||||
"id": str(n.id),
|
||||
"type": n.type,
|
||||
"payload": n.payload,
|
||||
"read": n.read,
|
||||
"created_at": n.created_at.isoformat(),
|
||||
"expires_at": n.expires_at.isoformat() if n.expires_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/ws/notifications")
|
||||
async def notifications_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
notification_manager.register(user_id, websocket)
|
||||
|
||||
# Flush all unread (non-expired) notifications on connect
|
||||
now = datetime.now()
|
||||
pending = (
|
||||
db.query(NotificationModel)
|
||||
.filter(
|
||||
NotificationModel.user_id == uuid.UUID(user_id),
|
||||
NotificationModel.read == False,
|
||||
(NotificationModel.expires_at == None) | (NotificationModel.expires_at > now),
|
||||
)
|
||||
.order_by(NotificationModel.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "flush",
|
||||
"notifications": [_serialize_notification(n) for n in pending],
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text() # keep connection alive; server only pushes
|
||||
except WebSocketDisconnect:
|
||||
notification_manager.unregister(user_id)
|
||||
|
||||
|
||||
@router.get("/notifications")
|
||||
def get_notifications(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
now = datetime.now()
|
||||
notifications = (
|
||||
db.query(NotificationModel)
|
||||
.filter(
|
||||
NotificationModel.user_id == user.id,
|
||||
(NotificationModel.expires_at == None) | (NotificationModel.expires_at > now),
|
||||
)
|
||||
.order_by(NotificationModel.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [_serialize_notification(n) for n in notifications]
|
||||
|
||||
|
||||
@router.post("/notifications/{notification_id}/read")
|
||||
def mark_notification_read(
|
||||
notification_id: str,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
n = db.query(NotificationModel).filter(
|
||||
NotificationModel.id == uuid.UUID(notification_id),
|
||||
NotificationModel.user_id == user.id,
|
||||
).first()
|
||||
if not n:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
n.read = True
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.delete("/notifications/{notification_id}")
|
||||
def delete_notification(
|
||||
notification_id: str,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
n = db.query(NotificationModel).filter(
|
||||
NotificationModel.id == uuid.UUID(notification_id),
|
||||
NotificationModel.user_id == user.id,
|
||||
).first()
|
||||
if not n:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
db.delete(n)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
147
backend/routers/profile.py
Normal file
147
backend/routers/profile.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user
|
||||
from core.models import Card as CardModel
|
||||
from core.models import Deck as DeckModel
|
||||
from core.models import User as UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _serialize_card_public(card: CardModel) -> dict:
|
||||
"""Card fields safe to expose on public profiles (no user_id)."""
|
||||
return {
|
||||
"id": str(card.id),
|
||||
"name": card.name,
|
||||
"image_link": card.image_link,
|
||||
"card_rarity": card.card_rarity,
|
||||
"card_type": card.card_type,
|
||||
"text": card.text,
|
||||
"attack": card.attack,
|
||||
"defense": card.defense,
|
||||
"cost": card.cost,
|
||||
"is_favorite": card.is_favorite,
|
||||
"willing_to_trade": card.willing_to_trade,
|
||||
}
|
||||
|
||||
|
||||
class UpdateProfileRequest(BaseModel):
|
||||
trade_wishlist: str
|
||||
|
||||
|
||||
@router.get("/profile")
|
||||
def get_profile(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
total_games = user.wins + user.losses
|
||||
|
||||
most_played_deck = (
|
||||
db.query(DeckModel)
|
||||
.filter(DeckModel.user_id == user.id, DeckModel.times_played > 0)
|
||||
.order_by(DeckModel.times_played.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
most_played_card = (
|
||||
db.query(CardModel)
|
||||
.filter(CardModel.user_id == user.id, CardModel.times_played > 0)
|
||||
.order_by(CardModel.times_played.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
return {
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"email_verified": user.email_verified,
|
||||
"created_at": user.created_at,
|
||||
"wins": user.wins,
|
||||
"losses": user.losses,
|
||||
"shards": user.shards,
|
||||
"win_rate": round((user.wins / total_games) * 100) if total_games > 0 else None,
|
||||
"trade_wishlist": user.trade_wishlist or "",
|
||||
"most_played_deck": {
|
||||
"name": most_played_deck.name,
|
||||
"times_played": most_played_deck.times_played,
|
||||
} if most_played_deck else None,
|
||||
"most_played_card": {
|
||||
"name": most_played_card.name,
|
||||
"times_played": most_played_card.times_played,
|
||||
"card_type": most_played_card.card_type,
|
||||
"card_rarity": most_played_card.card_rarity,
|
||||
"image_link": most_played_card.image_link,
|
||||
} if most_played_card else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/profile/refresh-status")
|
||||
def refresh_status(user: UserModel = Depends(get_current_user)):
|
||||
if not user.last_refresh_at:
|
||||
return {"can_refresh": True, "next_refresh_at": None}
|
||||
next_refresh = user.last_refresh_at + timedelta(minutes=10)
|
||||
can_refresh = datetime.now() >= next_refresh
|
||||
return {
|
||||
"can_refresh": can_refresh,
|
||||
"next_refresh_at": next_refresh.isoformat() if not can_refresh else None,
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/profile")
|
||||
def update_profile(req: UpdateProfileRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
user.trade_wishlist = req.trade_wishlist
|
||||
db.commit()
|
||||
return {"trade_wishlist": user.trade_wishlist}
|
||||
|
||||
|
||||
@router.get("/users")
|
||||
def search_users(q: str, current_user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
# Require auth to prevent scraping
|
||||
if len(q) < 2:
|
||||
return []
|
||||
results = (
|
||||
db.query(UserModel)
|
||||
.filter(UserModel.username.ilike(f"%{q}%"))
|
||||
.limit(20)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"username": u.username,
|
||||
"wins": u.wins,
|
||||
"losses": u.losses,
|
||||
"win_rate": round(u.wins / (u.wins + u.losses) * 100) if (u.wins + u.losses) > 0 else 0,
|
||||
}
|
||||
for u in results
|
||||
]
|
||||
|
||||
|
||||
@router.get("/users/{username}")
|
||||
def get_public_profile(username: str, db: Session = Depends(get_db)):
|
||||
user = db.query(UserModel).filter(UserModel.username == username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
total_games = user.wins + user.losses
|
||||
favorite_cards = (
|
||||
db.query(CardModel)
|
||||
.filter(CardModel.user_id == user.id, CardModel.is_favorite == True)
|
||||
.order_by(CardModel.received_at.desc())
|
||||
.all()
|
||||
)
|
||||
wtt_cards = (
|
||||
db.query(CardModel)
|
||||
.filter(CardModel.user_id == user.id, CardModel.willing_to_trade == True)
|
||||
.order_by(CardModel.received_at.desc())
|
||||
.all()
|
||||
)
|
||||
return {
|
||||
"username": user.username,
|
||||
"wins": user.wins,
|
||||
"losses": user.losses,
|
||||
"win_rate": round((user.wins / total_games) * 100) if total_games > 0 else None,
|
||||
"trade_wishlist": user.trade_wishlist or "",
|
||||
"last_active_at": user.last_active_at.isoformat() if user.last_active_at else None,
|
||||
"favorite_cards": [_serialize_card_public(c) for c in favorite_cards],
|
||||
"wtt_cards": [_serialize_card_public(c) for c in wtt_cards],
|
||||
}
|
||||
189
backend/routers/store.py
Normal file
189
backend/routers/store.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import stripe
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from game.card import _get_specific_card_async
|
||||
from core.config import FRONTEND_URL, STRIPE_PUBLISHABLE_KEY, STRIPE_WEBHOOK_SECRET
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user, limiter
|
||||
from core.models import Card as CardModel
|
||||
from core.models import ProcessedWebhookEvent
|
||||
from core.models import User as UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Shard packages sold for real money.
|
||||
# price_oere is in Danish øre (1 DKK = 100 øre). Stripe minimum is 250 øre.
|
||||
SHARD_PACKAGES = {
|
||||
"s1": {"base": 100, "bonus": 0, "shards": 100, "price_oere": 1000, "price_label": "10 DKK"},
|
||||
"s2": {"base": 250, "bonus": 50, "shards": 300, "price_oere": 2500, "price_label": "25 DKK"},
|
||||
"s3": {"base": 500, "bonus": 200, "shards": 700, "price_oere": 5000, "price_label": "50 DKK"},
|
||||
"s4": {"base": 1000, "bonus": 600, "shards": 1600, "price_oere": 10000, "price_label": "100 DKK"},
|
||||
"s5": {"base": 2500, "bonus": 2000, "shards": 4500, "price_oere": 25000, "price_label": "250 DKK"},
|
||||
"s6": {"base": 5000, "bonus": 5000, "shards": 10000, "price_oere": 50000, "price_label": "500 DKK"},
|
||||
}
|
||||
|
||||
STORE_PACKAGES = {
|
||||
1: 15,
|
||||
5: 65,
|
||||
10: 120,
|
||||
25: 260,
|
||||
}
|
||||
|
||||
SPECIFIC_CARD_COST = 1000
|
||||
|
||||
|
||||
class ShatterRequest(BaseModel):
|
||||
card_ids: list[str]
|
||||
|
||||
class StripeCheckoutRequest(BaseModel):
|
||||
package_id: str
|
||||
|
||||
class StoreBuyRequest(BaseModel):
|
||||
quantity: int
|
||||
|
||||
class BuySpecificCardRequest(BaseModel):
|
||||
wiki_title: str
|
||||
|
||||
|
||||
@router.post("/shards/shatter")
|
||||
def shatter_cards(req: ShatterRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if not req.card_ids:
|
||||
raise HTTPException(status_code=400, detail="No cards selected")
|
||||
try:
|
||||
parsed_ids = [uuid.UUID(cid) for cid in req.card_ids]
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid card IDs")
|
||||
|
||||
cards = db.query(CardModel).filter(
|
||||
CardModel.id.in_(parsed_ids),
|
||||
CardModel.user_id == user.id,
|
||||
).all()
|
||||
|
||||
if len(cards) != len(parsed_ids):
|
||||
raise HTTPException(status_code=400, detail="Some cards are not in your collection")
|
||||
|
||||
total = sum(c.cost for c in cards)
|
||||
|
||||
for card in cards:
|
||||
db.delete(card)
|
||||
|
||||
user.shards += total
|
||||
db.commit()
|
||||
return {"shards": user.shards, "gained": total}
|
||||
|
||||
|
||||
@router.post("/store/stripe/checkout")
|
||||
def create_stripe_checkout(req: StripeCheckoutRequest, user: UserModel = Depends(get_current_user)):
|
||||
package = SHARD_PACKAGES.get(req.package_id)
|
||||
if not package:
|
||||
raise HTTPException(status_code=400, detail="Invalid package")
|
||||
session = stripe.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
line_items=[{
|
||||
"price_data": {
|
||||
"currency": "dkk",
|
||||
"product_data": {"name": f"WikiTCG Shards — {package['price_label']}"},
|
||||
"unit_amount": package["price_oere"],
|
||||
},
|
||||
"quantity": 1,
|
||||
}],
|
||||
mode="payment",
|
||||
success_url=f"{FRONTEND_URL}/store?payment=success",
|
||||
cancel_url=f"{FRONTEND_URL}/store",
|
||||
metadata={"user_id": str(user.id), "shards": str(package["shards"])},
|
||||
)
|
||||
return {"url": session.url}
|
||||
|
||||
|
||||
@router.post("/stripe/webhook")
|
||||
async def stripe_webhook(request: Request, db: Session = Depends(get_db)):
|
||||
payload = await request.body()
|
||||
sig = request.headers.get("stripe-signature", "")
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(payload, sig, STRIPE_WEBHOOK_SECRET)
|
||||
except stripe.error.SignatureVerificationError: # type: ignore
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Guard against duplicate delivery: Stripe retries on timeout/5xx, so the same
|
||||
# event can arrive more than once. The PK constraint on stripe_event_id is the
|
||||
# arbiter — if the INSERT fails, we've already processed this event.
|
||||
try:
|
||||
db.add(ProcessedWebhookEvent(stripe_event_id=event["id"]))
|
||||
db.flush()
|
||||
except IntegrityError:
|
||||
db.rollback()
|
||||
return {"ok": True}
|
||||
|
||||
if event["type"] == "checkout.session.completed":
|
||||
data = event["data"]["object"]
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
shards = data.get("metadata", {}).get("shards")
|
||||
if user_id and shards:
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if user:
|
||||
user.shards += int(shards)
|
||||
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/store/config")
|
||||
def store_config():
|
||||
return {
|
||||
"publishable_key": STRIPE_PUBLISHABLE_KEY,
|
||||
"shard_packages": SHARD_PACKAGES,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/store/buy-specific-card")
|
||||
@limiter.limit("10/hour")
|
||||
async def buy_specific_card(request: Request, req: BuySpecificCardRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
if user.shards < SPECIFIC_CARD_COST:
|
||||
raise HTTPException(status_code=400, detail="Not enough shards")
|
||||
|
||||
card = await _get_specific_card_async(req.wiki_title)
|
||||
if card is None:
|
||||
raise HTTPException(status_code=404, detail="Could not generate a card for that Wikipedia page")
|
||||
|
||||
db_card = CardModel(
|
||||
name=card.name,
|
||||
image_link=card.image_link,
|
||||
card_rarity=card.card_rarity.name,
|
||||
card_type=card.card_type.name,
|
||||
text=card.text,
|
||||
attack=card.attack,
|
||||
defense=card.defense,
|
||||
cost=card.cost,
|
||||
user_id=user.id,
|
||||
received_at=datetime.now(),
|
||||
)
|
||||
db.add(db_card)
|
||||
user.shards -= SPECIFIC_CARD_COST
|
||||
db.commit()
|
||||
db.refresh(db_card)
|
||||
|
||||
return {
|
||||
**{c.name: getattr(db_card, c.name) for c in db_card.__table__.columns},
|
||||
"card_rarity": db_card.card_rarity,
|
||||
"card_type": db_card.card_type,
|
||||
"shards": user.shards,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/store/buy")
|
||||
def store_buy(req: StoreBuyRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
cost = STORE_PACKAGES.get(req.quantity)
|
||||
if cost is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid package")
|
||||
if user.shards < cost:
|
||||
raise HTTPException(status_code=400, detail="Not enough shards")
|
||||
user.shards -= cost
|
||||
user.boosters += req.quantity
|
||||
db.commit()
|
||||
return {"shards": user.shards, "boosters": user.boosters}
|
||||
411
backend/routers/trades.py
Normal file
411
backend/routers/trades.py
Normal file
@@ -0,0 +1,411 @@
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services import notification_manager
|
||||
from core.auth import decode_access_token
|
||||
from core.database import get_db
|
||||
from core.dependencies import get_current_user, get_user_id_from_request, limiter
|
||||
from core.models import Card as CardModel
|
||||
from core.models import Notification as NotificationModel
|
||||
from core.models import TradeProposal as TradeProposalModel
|
||||
from core.models import User as UserModel
|
||||
from routers.notifications import _serialize_notification
|
||||
from services.trade_manager import (
|
||||
TradeQueueEntry, active_trades, handle_trade_action,
|
||||
handle_trade_disconnect, serialize_trade, trade_queue, trade_queue_lock, try_trade_match,
|
||||
)
|
||||
from services.trade_manager import transfer_cards
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _fetch_cards_for_ids(id_strings: list, db: Session) -> list:
|
||||
"""Fetch CardModel rows for a JSONB list of UUID strings, preserving nothing if list is empty."""
|
||||
if not id_strings:
|
||||
return []
|
||||
uuids = [uuid.UUID(cid) for cid in id_strings]
|
||||
return db.query(CardModel).filter(CardModel.id.in_(uuids)).all()
|
||||
|
||||
|
||||
def _serialize_proposal(p: TradeProposalModel, current_user_id: uuid.UUID, card_map: dict) -> dict:
|
||||
offered_cards = [card_map[cid] for cid in p.offered_card_ids if cid in card_map]
|
||||
requested_cards = [card_map[cid] for cid in p.requested_card_ids if cid in card_map]
|
||||
def card_summary(c: CardModel) -> dict:
|
||||
return {
|
||||
"id": str(c.id),
|
||||
"name": c.name,
|
||||
"card_rarity": c.card_rarity,
|
||||
"card_type": c.card_type,
|
||||
"image_link": c.image_link,
|
||||
"cost": c.cost,
|
||||
"text": c.text,
|
||||
"attack": c.attack,
|
||||
"defense": c.defense,
|
||||
"generated_at": c.generated_at.isoformat() if c.generated_at else None,
|
||||
}
|
||||
return {
|
||||
"id": str(p.id),
|
||||
"status": p.status,
|
||||
"direction": "outgoing" if p.proposer_id == current_user_id else "incoming",
|
||||
"proposer_username": p.proposer.username,
|
||||
"recipient_username": p.recipient.username,
|
||||
"offered_cards": [card_summary(c) for c in offered_cards],
|
||||
"requested_cards": [card_summary(c) for c in requested_cards],
|
||||
"created_at": p.created_at.isoformat(),
|
||||
"expires_at": p.expires_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ── WebSocket trade matchmaking ───────────────────────────────────────────────
|
||||
|
||||
@router.websocket("/ws/trade/queue")
|
||||
async def trade_queue_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
user = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
if not user.email_verified:
|
||||
await websocket.send_json({"type": "error", "message": "You must verify your email before trading."})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
entry = TradeQueueEntry(user_id=user_id, username=user.username, websocket=websocket)
|
||||
|
||||
async with trade_queue_lock:
|
||||
trade_queue.append(entry)
|
||||
|
||||
await websocket.send_json({"type": "queued"})
|
||||
await try_trade_match()
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
async with trade_queue_lock:
|
||||
trade_queue[:] = [e for e in trade_queue if e.user_id != user_id]
|
||||
|
||||
|
||||
@router.websocket("/ws/trade/{trade_id}")
|
||||
async def trade_endpoint(websocket: WebSocket, trade_id: str, db: Session = Depends(get_db)):
|
||||
await websocket.accept()
|
||||
|
||||
token = await websocket.receive_text()
|
||||
user_id = decode_access_token(token)
|
||||
if not user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
session = active_trades.get(trade_id)
|
||||
if not session or user_id not in session.offers:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
session.connections[user_id] = websocket
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "state",
|
||||
"state": serialize_trade(session, user_id),
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
await handle_trade_action(trade_id, user_id, data, db)
|
||||
except WebSocketDisconnect:
|
||||
session.connections.pop(user_id, None)
|
||||
import asyncio
|
||||
asyncio.create_task(handle_trade_disconnect(trade_id, user_id))
|
||||
|
||||
|
||||
# ── Trade proposals ───────────────────────────────────────────────────────────
|
||||
|
||||
class CreateTradeProposalRequest(BaseModel):
|
||||
recipient_username: str
|
||||
offered_card_ids: list[str]
|
||||
requested_card_ids: list[str]
|
||||
|
||||
|
||||
@router.post("/trade-proposals")
|
||||
@limiter.limit("10/minute", key_func=get_user_id_from_request)
|
||||
async def create_trade_proposal(
|
||||
request: Request,
|
||||
req: CreateTradeProposalRequest,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# Parse UUIDs early so we give a clear error if malformed
|
||||
try:
|
||||
offered_uuids = [uuid.UUID(cid) for cid in req.offered_card_ids]
|
||||
requested_uuids = [uuid.UUID(cid) for cid in req.requested_card_ids]
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid card IDs")
|
||||
|
||||
recipient = db.query(UserModel).filter(UserModel.username == req.recipient_username).first()
|
||||
if not recipient:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if recipient.id == user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot propose a trade with yourself")
|
||||
if not offered_uuids and not requested_uuids:
|
||||
raise HTTPException(status_code=400, detail="At least one side must include cards")
|
||||
|
||||
# Verify proposer owns all offered cards
|
||||
if offered_uuids:
|
||||
owned_count = db.query(CardModel).filter(
|
||||
CardModel.id.in_(offered_uuids),
|
||||
CardModel.user_id == user.id,
|
||||
).count()
|
||||
if owned_count != len(offered_uuids):
|
||||
raise HTTPException(status_code=400, detail="Some offered cards are not in your collection")
|
||||
|
||||
# Verify all requested cards belong to recipient and are marked WTT
|
||||
if requested_uuids:
|
||||
wtt_count = db.query(CardModel).filter(
|
||||
CardModel.id.in_(requested_uuids),
|
||||
CardModel.user_id == recipient.id,
|
||||
CardModel.willing_to_trade == True,
|
||||
).count()
|
||||
if wtt_count != len(requested_uuids):
|
||||
raise HTTPException(status_code=400, detail="Some requested cards are not available for trade")
|
||||
|
||||
# One pending proposal per direction between two users prevents spam
|
||||
duplicate = db.query(TradeProposalModel).filter(
|
||||
TradeProposalModel.proposer_id == user.id,
|
||||
TradeProposalModel.recipient_id == recipient.id,
|
||||
TradeProposalModel.status == "pending",
|
||||
).first()
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="You already have a pending proposal with this user")
|
||||
|
||||
now = datetime.now()
|
||||
proposal = TradeProposalModel(
|
||||
proposer_id=user.id,
|
||||
recipient_id=recipient.id,
|
||||
offered_card_ids=[str(cid) for cid in offered_uuids],
|
||||
requested_card_ids=[str(cid) for cid in requested_uuids],
|
||||
expires_at=now + timedelta(hours=72),
|
||||
)
|
||||
db.add(proposal)
|
||||
db.flush() # get proposal.id before notification
|
||||
|
||||
notif = NotificationModel(
|
||||
user_id=recipient.id,
|
||||
type="trade_offer",
|
||||
payload={
|
||||
"proposal_id": str(proposal.id),
|
||||
"from_username": user.username,
|
||||
"offered_count": len(offered_uuids),
|
||||
"requested_count": len(requested_uuids),
|
||||
},
|
||||
expires_at=proposal.expires_at,
|
||||
)
|
||||
db.add(notif)
|
||||
db.commit()
|
||||
await notification_manager.send_notification(str(recipient.id), _serialize_notification(notif))
|
||||
return {"proposal_id": str(proposal.id)}
|
||||
|
||||
|
||||
@router.get("/trade-proposals/{proposal_id}")
|
||||
def get_trade_proposal(
|
||||
proposal_id: str,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
pid = uuid.UUID(proposal_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid proposal ID")
|
||||
proposal = db.query(TradeProposalModel).filter(TradeProposalModel.id == pid).first()
|
||||
if not proposal:
|
||||
raise HTTPException(status_code=404, detail="Proposal not found")
|
||||
if proposal.proposer_id != user.id and proposal.recipient_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
# Lazy-expire before returning so the UI always sees accurate status
|
||||
if proposal.status == "pending" and datetime.now() > proposal.expires_at:
|
||||
proposal.status = "expired"
|
||||
db.commit()
|
||||
all_ids = set(proposal.offered_card_ids + proposal.requested_card_ids)
|
||||
card_map = {str(c.id): c for c in _fetch_cards_for_ids(list(all_ids), db)}
|
||||
return _serialize_proposal(proposal, user.id, card_map)
|
||||
|
||||
|
||||
@router.post("/trade-proposals/{proposal_id}/accept")
|
||||
async def accept_trade_proposal(
|
||||
proposal_id: str,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
proposal = db.query(TradeProposalModel).filter(TradeProposalModel.id == uuid.UUID(proposal_id)).with_for_update().first()
|
||||
if not proposal:
|
||||
raise HTTPException(status_code=404, detail="Proposal not found")
|
||||
if proposal.recipient_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Only the recipient can accept a proposal")
|
||||
if proposal.status != "pending":
|
||||
raise HTTPException(status_code=400, detail=f"Proposal is already {proposal.status}")
|
||||
|
||||
now = datetime.now()
|
||||
if now > proposal.expires_at:
|
||||
proposal.status = "expired"
|
||||
db.commit()
|
||||
raise HTTPException(status_code=400, detail="This trade proposal has expired")
|
||||
|
||||
offered_uuids = [uuid.UUID(cid) for cid in proposal.offered_card_ids]
|
||||
requested_uuids = [uuid.UUID(cid) for cid in proposal.requested_card_ids]
|
||||
|
||||
# Re-verify proposer still owns all offered cards at accept time
|
||||
if offered_uuids:
|
||||
owned_count = db.query(CardModel).filter(
|
||||
CardModel.id.in_(offered_uuids),
|
||||
CardModel.user_id == proposal.proposer_id,
|
||||
).count()
|
||||
if owned_count != len(offered_uuids):
|
||||
proposal.status = "expired"
|
||||
db.commit()
|
||||
raise HTTPException(status_code=400, detail="The proposer no longer owns all offered cards")
|
||||
|
||||
# Re-verify all requested cards still belong to recipient and are still WTT
|
||||
if requested_uuids:
|
||||
wtt_count = db.query(CardModel).filter(
|
||||
CardModel.id.in_(requested_uuids),
|
||||
CardModel.user_id == user.id,
|
||||
CardModel.willing_to_trade == True,
|
||||
).count()
|
||||
if wtt_count != len(requested_uuids):
|
||||
raise HTTPException(status_code=400, detail="Some requested cards are no longer available for trade")
|
||||
|
||||
# Execute both sides of the transfer atomically
|
||||
transfer_cards(proposal.proposer_id, user.id, offered_uuids, db, now)
|
||||
transfer_cards(user.id, proposal.proposer_id, requested_uuids, db, now)
|
||||
|
||||
proposal.status = "accepted"
|
||||
|
||||
# Clean up the trade_offer notification from the recipient's bell
|
||||
deleted_notif = db.query(NotificationModel).filter(
|
||||
NotificationModel.user_id == proposal.recipient_id,
|
||||
NotificationModel.type == "trade_offer",
|
||||
NotificationModel.payload["proposal_id"].astext == proposal_id,
|
||||
).first()
|
||||
deleted_notif_id = str(deleted_notif.id) if deleted_notif else None
|
||||
if deleted_notif:
|
||||
db.delete(deleted_notif)
|
||||
|
||||
# Notify the proposer that their offer was accepted
|
||||
response_notif = NotificationModel(
|
||||
user_id=proposal.proposer_id,
|
||||
type="trade_response",
|
||||
payload={
|
||||
"proposal_id": proposal_id,
|
||||
"status": "accepted",
|
||||
"from_username": user.username,
|
||||
},
|
||||
)
|
||||
db.add(response_notif)
|
||||
|
||||
# Withdraw any other pending proposals that involve cards that just changed hands.
|
||||
# Both sides are now non-tradeable: offered cards left the proposer, requested cards left the recipient.
|
||||
transferred_strs = {str(c) for c in offered_uuids + requested_uuids}
|
||||
if transferred_strs:
|
||||
for p in db.query(TradeProposalModel).filter(
|
||||
TradeProposalModel.status == "pending",
|
||||
TradeProposalModel.id != proposal.id,
|
||||
(
|
||||
(TradeProposalModel.proposer_id == proposal.proposer_id) |
|
||||
(TradeProposalModel.proposer_id == proposal.recipient_id) |
|
||||
(TradeProposalModel.recipient_id == proposal.proposer_id) |
|
||||
(TradeProposalModel.recipient_id == proposal.recipient_id)
|
||||
),
|
||||
).all():
|
||||
if set(p.offered_card_ids) & transferred_strs or set(p.requested_card_ids) & transferred_strs:
|
||||
p.status = "withdrawn"
|
||||
|
||||
db.commit()
|
||||
|
||||
if deleted_notif_id:
|
||||
await notification_manager.send_delete(str(proposal.recipient_id), deleted_notif_id)
|
||||
await notification_manager.send_notification(str(proposal.proposer_id), _serialize_notification(response_notif))
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/trade-proposals/{proposal_id}/decline")
|
||||
async def decline_trade_proposal(
|
||||
proposal_id: str,
|
||||
user: UserModel = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
proposal = db.query(TradeProposalModel).filter(TradeProposalModel.id == uuid.UUID(proposal_id)).first()
|
||||
if not proposal:
|
||||
raise HTTPException(status_code=404, detail="Proposal not found")
|
||||
if proposal.proposer_id != user.id and proposal.recipient_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
if proposal.status != "pending":
|
||||
raise HTTPException(status_code=400, detail=f"Proposal is already {proposal.status}")
|
||||
|
||||
is_withdrawal = proposal.proposer_id == user.id
|
||||
proposal.status = "withdrawn" if is_withdrawal else "declined"
|
||||
|
||||
# Clean up the trade_offer notification from the recipient's bell
|
||||
deleted_notif = db.query(NotificationModel).filter(
|
||||
NotificationModel.user_id == proposal.recipient_id,
|
||||
NotificationModel.type == "trade_offer",
|
||||
NotificationModel.payload["proposal_id"].astext == proposal_id,
|
||||
).first()
|
||||
deleted_notif_id = str(deleted_notif.id) if deleted_notif else None
|
||||
if deleted_notif:
|
||||
db.delete(deleted_notif)
|
||||
|
||||
# Notify the proposer if the recipient declined (not a withdrawal)
|
||||
response_notif = None
|
||||
if not is_withdrawal:
|
||||
response_notif = NotificationModel(
|
||||
user_id=proposal.proposer_id,
|
||||
type="trade_response",
|
||||
payload={
|
||||
"proposal_id": proposal_id,
|
||||
"status": "declined",
|
||||
"from_username": user.username,
|
||||
},
|
||||
)
|
||||
db.add(response_notif)
|
||||
|
||||
db.commit()
|
||||
|
||||
if deleted_notif_id:
|
||||
await notification_manager.send_delete(str(proposal.recipient_id), deleted_notif_id)
|
||||
if response_notif:
|
||||
await notification_manager.send_notification(str(proposal.proposer_id), _serialize_notification(response_notif))
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/trade-proposals")
|
||||
def get_trade_proposals(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
# Lazy-expire any pending proposals that have passed their deadline
|
||||
now = datetime.now()
|
||||
db.query(TradeProposalModel).filter(
|
||||
TradeProposalModel.status == "pending",
|
||||
TradeProposalModel.expires_at < now,
|
||||
(TradeProposalModel.proposer_id == user.id) | (TradeProposalModel.recipient_id == user.id),
|
||||
).update({"status": "expired"})
|
||||
db.commit()
|
||||
|
||||
proposals = db.query(TradeProposalModel).filter(
|
||||
(TradeProposalModel.proposer_id == user.id) | (TradeProposalModel.recipient_id == user.id)
|
||||
).order_by(TradeProposalModel.created_at.desc()).all()
|
||||
|
||||
# Batch-fetch all cards referenced across all proposals in one query
|
||||
all_ids = {cid for p in proposals for cid in p.offered_card_ids + p.requested_card_ids}
|
||||
card_map = {str(c.id): c for c in _fetch_cards_for_ids(list(all_ids), db)}
|
||||
|
||||
return [_serialize_proposal(p, user.id, card_map) for p in proposals]
|
||||
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
161
backend/services/database_functions.py
Normal file
161
backend/services/database_functions.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import delete, insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from game.card import _get_cards_async
|
||||
from core.models import Card as CardModel
|
||||
from core.models import GameChallenge as GameChallengeModel
|
||||
from core.models import Notification as NotificationModel
|
||||
from core.models import TradeProposal as TradeProposalModel
|
||||
from core.models import User as UserModel
|
||||
from core.database import SessionLocal
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
## Card pool management
|
||||
|
||||
POOL_MINIMUM = 1000
|
||||
POOL_TARGET = 2000
|
||||
POOL_BATCH_SIZE = 10
|
||||
POOL_SLEEP = 4.0
|
||||
# After this many consecutive empty batches, stop trying and wait for the cooldown.
|
||||
POOL_MAX_CONSECUTIVE_EMPTY = 5
|
||||
POOL_CIRCUIT_BREAKER_COOLDOWN = 600.0 # seconds
|
||||
|
||||
pool_filling = False
|
||||
# asyncio monotonic timestamp; 0 means breaker is closed (no cooldown active)
|
||||
_cb_open_until: float = 0.0
|
||||
|
||||
async def fill_card_pool():
|
||||
global pool_filling, _cb_open_until
|
||||
|
||||
if pool_filling:
|
||||
logger.info("Pool fill already in progress, skipping")
|
||||
return
|
||||
|
||||
loop_time = asyncio.get_event_loop().time()
|
||||
if loop_time < _cb_open_until:
|
||||
remaining = int(_cb_open_until - loop_time)
|
||||
logger.warning(f"Card generation circuit breaker open, skipping fill ({remaining}s remaining)")
|
||||
return
|
||||
|
||||
pool_filling = True
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
unassigned = db.query(CardModel).filter(CardModel.user_id == None, CardModel.ai_used == False).count()
|
||||
logger.info(f"Card pool has {unassigned} unassigned cards")
|
||||
if unassigned >= POOL_MINIMUM:
|
||||
logger.info("Pool sufficiently stocked, skipping fill")
|
||||
return
|
||||
|
||||
needed = POOL_TARGET - unassigned
|
||||
logger.info(f"Filling pool with {needed} cards")
|
||||
|
||||
fetched = 0
|
||||
consecutive_empty = 0
|
||||
while fetched < needed:
|
||||
batch_size = min(POOL_BATCH_SIZE, needed - fetched)
|
||||
cards = await _get_cards_async(batch_size)
|
||||
|
||||
if not cards:
|
||||
consecutive_empty += 1
|
||||
logger.warning(
|
||||
f"Card generation batch returned 0 cards "
|
||||
f"({consecutive_empty}/{POOL_MAX_CONSECUTIVE_EMPTY} consecutive empty batches)"
|
||||
)
|
||||
if consecutive_empty >= POOL_MAX_CONSECUTIVE_EMPTY:
|
||||
_cb_open_until = asyncio.get_event_loop().time() + POOL_CIRCUIT_BREAKER_COOLDOWN
|
||||
logger.error(
|
||||
f"ALERT: Card generation circuit breaker tripped — {consecutive_empty} consecutive "
|
||||
f"empty batches. Wikipedia/Wikirank API may be down. "
|
||||
f"Next retry in {int(POOL_CIRCUIT_BREAKER_COOLDOWN)}s."
|
||||
)
|
||||
return
|
||||
await asyncio.sleep(POOL_SLEEP)
|
||||
continue
|
||||
|
||||
consecutive_empty = 0
|
||||
db.execute(insert(CardModel).values([
|
||||
dict(
|
||||
name=card.name,
|
||||
image_link=card.image_link,
|
||||
card_rarity=card.card_rarity.name,
|
||||
card_type=card.card_type.name,
|
||||
text=card.text,
|
||||
attack=card.attack,
|
||||
defense=card.defense,
|
||||
cost=card.cost,
|
||||
user_id=None,
|
||||
)
|
||||
for card in cards
|
||||
]))
|
||||
db.commit()
|
||||
fetched += len(cards)
|
||||
logger.info(f"Pool fill progress: {fetched}/{needed}")
|
||||
await asyncio.sleep(POOL_SLEEP)
|
||||
|
||||
finally:
|
||||
pool_filling = False
|
||||
db.close()
|
||||
|
||||
## Booster management
|
||||
|
||||
BOOSTER_MAX = 5
|
||||
BOOSTER_COOLDOWN_HOURS = 5
|
||||
|
||||
def check_boosters(user: UserModel, db: Session) -> tuple[int, datetime|None]:
|
||||
if user.boosters_countdown is None:
|
||||
if user.boosters < BOOSTER_MAX:
|
||||
user.boosters = BOOSTER_MAX
|
||||
db.commit()
|
||||
return (user.boosters, user.boosters_countdown)
|
||||
|
||||
now = datetime.now()
|
||||
countdown = user.boosters_countdown
|
||||
|
||||
while user.boosters < BOOSTER_MAX:
|
||||
next_tick = countdown + timedelta(hours=BOOSTER_COOLDOWN_HOURS)
|
||||
if now >= next_tick:
|
||||
user.boosters += 1
|
||||
countdown = next_tick
|
||||
else:
|
||||
break
|
||||
|
||||
user.boosters_countdown = countdown if user.boosters < BOOSTER_MAX else None
|
||||
db.commit()
|
||||
return (user.boosters, user.boosters_countdown)
|
||||
|
||||
## Periodic cleanup
|
||||
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
async def run_cleanup_loop():
|
||||
# Brief startup delay so the DB is fully ready before first run
|
||||
await asyncio.sleep(60)
|
||||
while True:
|
||||
try:
|
||||
_delete_expired_records()
|
||||
except Exception:
|
||||
logger.exception("Periodic cleanup job failed")
|
||||
await asyncio.sleep(CLEANUP_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def _delete_expired_records():
|
||||
now = datetime.now()
|
||||
with SessionLocal() as db:
|
||||
for model in (NotificationModel, TradeProposalModel, GameChallengeModel):
|
||||
# Notification.expires_at is nullable — skip rows without an expiry.
|
||||
# TradeProposal and GameChallenge always have expires_at, but the
|
||||
# guard is harmless and makes the intent explicit.
|
||||
result = db.execute(
|
||||
delete(model).where(
|
||||
model.expires_at != None, # noqa: E711
|
||||
model.expires_at < now,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
logger.info("Cleanup: deleted %d expired %s rows", result.rowcount, model.__tablename__)
|
||||
@@ -1,6 +1,8 @@
|
||||
import resend
|
||||
import os
|
||||
from config import RESEND_API_KEY, EMAIL_FROM, FRONTEND_URL
|
||||
|
||||
import resend
|
||||
|
||||
from core.config import RESEND_API_KEY, EMAIL_FROM, FRONTEND_URL
|
||||
|
||||
def send_verification_email(to_email: str, username: str, token: str):
|
||||
resend.api_key = RESEND_API_KEY
|
||||
41
backend/services/notification_manager.py
Normal file
41
backend/services/notification_manager.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Manages persistent per-user WebSocket connections for the notification channel.
|
||||
The DB is the source of truth — this layer just delivers live pushes to connected clients.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
# user_id (str) -> active WebSocket; replaced on reconnect
|
||||
connections: dict[str, WebSocket] = {}
|
||||
|
||||
|
||||
def register(user_id: str, ws: WebSocket) -> None:
|
||||
connections[user_id] = ws
|
||||
|
||||
|
||||
def unregister(user_id: str) -> None:
|
||||
connections.pop(user_id, None)
|
||||
|
||||
|
||||
async def send_notification(user_id: str, notification: dict) -> None:
|
||||
"""Push a single notification to the user if they're connected. No-op otherwise."""
|
||||
ws = connections.get(user_id)
|
||||
if ws:
|
||||
try:
|
||||
await ws.send_json({"type": "push", "notification": notification})
|
||||
except Exception as e:
|
||||
# Stale connection — the disconnect handler will clean it up
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
|
||||
|
||||
async def send_delete(user_id: str, notification_id: str) -> None:
|
||||
"""Tell the client to remove a notification from its local list."""
|
||||
ws = connections.get(user_id)
|
||||
if ws:
|
||||
try:
|
||||
await ws.send_json({"type": "delete", "notification_id": notification_id})
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
@@ -1,14 +1,54 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Card as CardModel, DeckCard as DeckCardModel
|
||||
from core.models import Card as CardModel, DeckCard as DeckCardModel
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
## Card transfer
|
||||
|
||||
def transfer_cards(
|
||||
from_user_id: uuid.UUID,
|
||||
to_user_id: uuid.UUID,
|
||||
card_ids: list[uuid.UUID],
|
||||
db: Session,
|
||||
now: datetime,
|
||||
) -> None:
|
||||
"""
|
||||
Reassigns card ownership, stamps received_at, removes deck memberships, and clears WTT.
|
||||
Does NOT commit — caller owns the transaction.
|
||||
Clearing WTT on transfer prevents a card from auto-appearing as tradeable on the new owner's
|
||||
profile without them explicitly opting in.
|
||||
"""
|
||||
if not card_ids:
|
||||
return
|
||||
|
||||
matched_cards = db.query(CardModel).filter(
|
||||
CardModel.id.in_(card_ids),
|
||||
CardModel.user_id == from_user_id,
|
||||
).all()
|
||||
|
||||
# Bail out if any card is missing or no longer owned by the sender — a partial
|
||||
# transfer would silently give the receiver fewer cards than agreed upon.
|
||||
if len(matched_cards) != len(card_ids):
|
||||
raise ValueError(
|
||||
f"Expected {len(card_ids)} cards owned by {from_user_id}, "
|
||||
f"found {len(matched_cards)}"
|
||||
)
|
||||
|
||||
for card in matched_cards:
|
||||
card.user_id = to_user_id
|
||||
card.received_at = now
|
||||
card.willing_to_trade = False
|
||||
db.query(DeckCardModel).filter(DeckCardModel.card_id == card.id).delete(synchronize_session=False)
|
||||
|
||||
|
||||
## Storage
|
||||
|
||||
@dataclass
|
||||
@@ -47,7 +87,10 @@ def serialize_card_model(card: CardModel) -> dict:
|
||||
"defense": card.defense,
|
||||
"cost": card.cost,
|
||||
"text": card.text,
|
||||
"created_at": card.created_at.isoformat() if card.created_at else None,
|
||||
"generated_at": card.generated_at.isoformat() if card.generated_at else None,
|
||||
"received_at": card.received_at.isoformat() if card.received_at else None,
|
||||
"is_favorite": card.is_favorite,
|
||||
"willing_to_trade": card.willing_to_trade,
|
||||
}
|
||||
|
||||
def serialize_trade(session: TradeSession, perspective_user_id: str) -> dict:
|
||||
@@ -76,8 +119,8 @@ async def broadcast_trade(session: TradeSession) -> None:
|
||||
"type": "state",
|
||||
"state": serialize_trade(session, user_id),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
|
||||
## Matchmaking
|
||||
|
||||
@@ -108,8 +151,8 @@ async def try_trade_match() -> None:
|
||||
for entry in [p1, p2]:
|
||||
try:
|
||||
await entry.websocket.send_json({"type": "trade_start", "trade_id": trade_id})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
|
||||
## Action handling
|
||||
|
||||
@@ -230,28 +273,17 @@ async def _complete_trade(trade_id: str, db: Session) -> None:
|
||||
"type": "error",
|
||||
"message": "Trade failed: ownership check failed. Offers have been reset.",
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
for offer in session.offers.values():
|
||||
offer.accepted = False
|
||||
await broadcast_trade(session)
|
||||
return
|
||||
|
||||
# Transfer ownership and clear deck relationships
|
||||
for cid_str in [c["id"] for c in cards_u1]:
|
||||
cid = uuid.UUID(cid_str)
|
||||
card = db.query(CardModel).filter(CardModel.id == cid).first()
|
||||
if card:
|
||||
card.user_id = uuid.UUID(u2)
|
||||
db.query(DeckCardModel).filter(DeckCardModel.card_id == cid).delete()
|
||||
|
||||
for cid_str in [c["id"] for c in cards_u2]:
|
||||
cid = uuid.UUID(cid_str)
|
||||
card = db.query(CardModel).filter(CardModel.id == cid).first()
|
||||
if card:
|
||||
card.user_id = uuid.UUID(u1)
|
||||
db.query(DeckCardModel).filter(DeckCardModel.card_id == cid).delete()
|
||||
|
||||
now = datetime.now()
|
||||
transfer_cards(uuid.UUID(u1), uuid.UUID(u2), [uuid.UUID(c["id"]) for c in cards_u1], db, now)
|
||||
transfer_cards(uuid.UUID(u2), uuid.UUID(u1), [uuid.UUID(c["id"]) for c in cards_u2], db, now)
|
||||
db.commit()
|
||||
|
||||
active_trades.pop(trade_id, None)
|
||||
@@ -259,8 +291,8 @@ async def _complete_trade(trade_id: str, db: Session) -> None:
|
||||
for ws in list(session.connections.values()):
|
||||
try:
|
||||
await ws.send_json({"type": "trade_complete"})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
|
||||
## Disconnect handling
|
||||
|
||||
@@ -279,5 +311,5 @@ async def handle_trade_disconnect(trade_id: str, user_id: str) -> None:
|
||||
"type": "error",
|
||||
"message": "Your trade partner disconnected. Trade cancelled.",
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket send failed (stale connection): {e}")
|
||||
@@ -1,13 +1,14 @@
|
||||
import uuid
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from game import (
|
||||
from game.rules import (
|
||||
GameState, PlayerState, CardInstance, CombatEvent, GameResult,
|
||||
create_game, resolve_combat, check_win_condition,
|
||||
action_play_card, action_sacrifice, action_end_turn,
|
||||
BOARD_SIZE, HAND_SIZE, STARTING_LIFE, MAX_ENERGY_CAP,
|
||||
)
|
||||
import uuid
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -79,6 +80,8 @@ class TestCreateGame:
|
||||
card_rarity = "common"
|
||||
image_link = ""
|
||||
text = ""
|
||||
is_favorite = False
|
||||
willing_to_trade = False
|
||||
|
||||
cards = [FakeCard() for _ in range(20)]
|
||||
state = create_game("p1", "player 1", "test", cards, "p2", "player 2", "test", cards)
|
||||
@@ -96,6 +99,8 @@ class TestCreateGame:
|
||||
card_rarity = "common"
|
||||
image_link = ""
|
||||
text = ""
|
||||
is_favorite = False
|
||||
willing_to_trade = False
|
||||
|
||||
cards = [FakeCard() for _ in range(20)]
|
||||
state = create_game("p1", "player 1", "test", cards, "p2", "player 2", "test", cards)
|
||||
@@ -113,6 +118,8 @@ class TestCreateGame:
|
||||
card_rarity = "common"
|
||||
image_link = ""
|
||||
text = ""
|
||||
is_favorite = False
|
||||
willing_to_trade = False
|
||||
|
||||
cards = [FakeCard() for _ in range(20)]
|
||||
state = create_game("p1", "player 1", "test", cards, "p2", "player 2", "test", cards)
|
||||
@@ -131,6 +138,8 @@ class TestCreateGame:
|
||||
card_rarity = "common"
|
||||
image_link = ""
|
||||
text = ""
|
||||
is_favorite = False
|
||||
willing_to_trade = False
|
||||
|
||||
cards = [FakeCard() for _ in range(20)]
|
||||
state = create_game("p1", "player 1", "test", cards, "p2", "player 2", "test", cards)
|
||||
|
||||
Reference in New Issue
Block a user