839 lines
30 KiB
Python
839 lines
30 KiB
Python
import asyncio
|
|
import logging
|
|
import uuid
|
|
import re
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timedelta
|
|
from typing import cast, Callable
|
|
import secrets
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import func
|
|
from fastapi import FastAPI, Depends, HTTPException, status, WebSocket, WebSocketDisconnect, Request
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
from slowapi.util import get_remote_address
|
|
from slowapi.errors import RateLimitExceeded
|
|
|
|
from database import get_db
|
|
from database_functions import fill_card_pool, check_boosters, BOOSTER_MAX
|
|
from models import Card as CardModel
|
|
from models import User as UserModel
|
|
from models import Deck as DeckModel
|
|
from models import DeckCard as DeckCardModel
|
|
from auth import (
|
|
hash_password, verify_password, create_access_token, create_refresh_token,
|
|
decode_access_token, decode_refresh_token
|
|
)
|
|
from game_manager import (
|
|
queue, queue_lock, QueueEntry, try_match, handle_action, connections, active_games,
|
|
serialize_state, handle_disconnect, handle_timeout_claim, load_deck_cards, create_solo_game
|
|
)
|
|
from trade_manager import (
|
|
trade_queue, trade_queue_lock, TradeQueueEntry, try_trade_match,
|
|
handle_trade_action, active_trades, handle_trade_disconnect,
|
|
serialize_trade,
|
|
)
|
|
from card import compute_deck_type, _get_specific_card_async
|
|
from email_utils import send_password_reset_email, send_verification_email
|
|
from config import CORS_ORIGINS, STRIPE_SECRET_KEY, STRIPE_PUBLISHABLE_KEY, STRIPE_WEBHOOK_SECRET, FRONTEND_URL
|
|
import stripe
|
|
stripe.api_key = STRIPE_SECRET_KEY
|
|
|
|
# Module-wide logger; handlers in this file log under the "app" name.
logger = logging.getLogger("app")

# Auth
# Extracts the bearer token from the Authorization header for protected routes.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
|
|
|
class RegisterRequest(BaseModel):
    """JSON body accepted by the ``/register`` endpoint."""

    username: str
    email: str
    password: str
|
|
|
|
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> UserModel:
    """Resolve the request's bearer token to the matching ``UserModel`` row.

    Raises:
        HTTPException: 401 when the token does not decode to a user id, or
            when no user with that id exists.
    """
    subject = decode_access_token(token)
    if not subject:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")

    account = (
        db.query(UserModel)
        .filter(UserModel.id == uuid.UUID(subject))
        .first()
    )
    if account:
        return account
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
|
|
|
class ForgotPasswordRequest(BaseModel):
    """JSON body for ``/auth/forgot-password``: the address to send a reset link to."""

    email: str
|
|
|
|
class ResetPasswordWithTokenRequest(BaseModel):
    """JSON body for ``/auth/reset-password-with-token``."""

    token: str  # reset token previously emailed to the user
    new_password: str
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Kick off a background card-pool refill when the application starts.

    The task handle is bound to a local that stays alive for as long as the
    app runs. The event loop keeps only weak references to tasks, so a task
    whose handle is discarded may be garbage-collected before it finishes.
    """
    pool_fill_task = asyncio.create_task(fill_card_pool())  # noqa: F841 - held on purpose
    yield
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Rate limiting
# Limits are keyed on the value returned by ``get_remote_address``.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(
    RateLimitExceeded,
    cast(Callable, _rate_limit_exceeded_handler),
)

# Allowed browser origins come from configuration (CORS_ORIGINS).
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
try:
|
|
from disposable_email_domains import blocklist as _disposable_blocklist
|
|
except ImportError:
|
|
_disposable_blocklist: set[str] = set()
|
|
|
|
def validate_register(username: str, email: str, password: str) -> str | None:
|
|
if not username.strip():
|
|
return "Username is required"
|
|
if len(username) > 16:
|
|
return "Username must be 16 characters or fewer"
|
|
if not re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", email):
|
|
return "Please enter a valid email"
|
|
domain = email.split("@")[-1].lower()
|
|
if domain in _disposable_blocklist:
|
|
return "Disposable email addresses are not allowed"
|
|
if len(password) < 8:
|
|
return "Password must be at least 8 characters"
|
|
if len(password) > 256:
|
|
return "Password must be 256 characters or fewer"
|
|
return None
|
|
|
|
@app.post("/register")
def register(req: RegisterRequest, db: Session = Depends(get_db)):
    """Create an unverified account and email its verification link.

    A failure to send the email is logged but does not fail the request; the
    account has already been committed at that point.

    Raises:
        HTTPException: 400 on invalid input, a taken username, or an
            already-registered email.
    """
    problem = validate_register(req.username, req.email, req.password)
    if problem:
        raise HTTPException(status_code=400, detail=problem)

    username_owner = db.query(UserModel).filter(UserModel.username == req.username).first()
    if username_owner:
        raise HTTPException(status_code=400, detail="Username already taken")

    email_owner = db.query(UserModel).filter(UserModel.email == req.email).first()
    if email_owner:
        raise HTTPException(status_code=400, detail="Email already registered")

    verification_token = secrets.token_urlsafe(32)
    new_account = UserModel(
        id=uuid.uuid4(),
        username=req.username,
        email=req.email,
        password_hash=hash_password(req.password),
        email_verified=False,
        email_verification_token=verification_token,
        # The verification link is valid for 24 hours.
        email_verification_token_expires_at=datetime.now() + timedelta(hours=24),
    )
    db.add(new_account)
    db.commit()

    try:
        send_verification_email(req.email, req.username, verification_token)
    except Exception as e:
        logger.error(f"Failed to send verification email: {e}")

    return {"message": "Account created. Please check your email to verify your account."}
|
|
|
|
@app.post("/login")
def login(form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange a username/password form for an access + refresh token pair.

    Raises:
        HTTPException: 400 with the same message for an unknown username and
            for a wrong password.
    """
    account = db.query(UserModel).filter(UserModel.username == form.username).first()
    if not (account and verify_password(form.password, account.password_hash)):
        raise HTTPException(status_code=400, detail="Invalid username or password")

    subject = str(account.id)
    return {
        "access_token": create_access_token(subject),
        "refresh_token": create_refresh_token(subject),
        "token_type": "bearer",
    }
|
|
|
|
@app.get("/boosters")
def get_boosters(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Report the caller's booster count, countdown and email-verification flag."""
    available, countdown = check_boosters(user, db)
    return {
        "count": available,
        "countdown": countdown,
        "email_verified": user.email_verified,
    }
|
|
|
|
@app.get("/cards")
def get_cards(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return every card owned by the caller, one dict per card.

    Each dict holds all table columns, with ``card_rarity`` and ``card_type``
    set from the model attributes of the same name.
    """
    payload = []
    for owned in db.query(CardModel).filter(CardModel.user_id == user.id).all():
        row = {col.name: getattr(owned, col.name) for col in owned.__table__.columns}
        row["card_rarity"] = owned.card_rarity
        row["card_type"] = owned.card_type
        payload.append(row)
    return payload
|
|
|
|
@app.post("/open_pack")
@limiter.limit("10/minute")
async def open_pack(request: Request, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Spend one booster and give the caller five unowned cards from the pool.

    ``request`` is unused in the body; it is part of the signature alongside
    the rate-limit decorator.

    Raises:
        HTTPException: 403 if the caller's email is unverified, 400 if they
            have no boosters, 503 if fewer than five pool cards are available
            (a pool refill is scheduled in that case).
    """
    if not user.email_verified:
        raise HTTPException(status_code=403, detail="You must verify your email before opening packs")

    # Return value unused here; presumably this brings ``user.boosters`` up to
    # date before it is read below — confirm in database_functions.
    check_boosters(user, db)

    # "<= 0" rather than "== 0": a negative balance must not open packs either.
    if user.boosters <= 0:
        raise HTTPException(status_code=400, detail="No booster packs available")

    # ``.is_()`` is the explicit SQL "IS" comparison for NULL / boolean values.
    # NOTE(review): rows are not locked, so two workers could select the same
    # cards concurrently — consider SELECT ... FOR UPDATE SKIP LOCKED.
    cards = (
        db.query(CardModel)
        .filter(CardModel.user_id.is_(None), CardModel.ai_used.is_(False))
        .limit(5)
        .all()
    )

    if len(cards) < 5:
        asyncio.create_task(fill_card_pool())
        raise HTTPException(status_code=503, detail="Card pool is low, please try again shortly")

    for card in cards:
        card.user_id = user.id

    # The countdown timestamp is reset only when the user drops below the cap.
    was_full = user.boosters == BOOSTER_MAX
    user.boosters -= 1
    if was_full:
        user.boosters_countdown = datetime.now()

    db.commit()

    # Top the pool back up in the background.
    asyncio.create_task(fill_card_pool())

    return [
        {
            **{c.name: getattr(card, c.name) for c in card.__table__.columns},
            "card_rarity": card.card_rarity,
            "card_type": card.card_type,
        }
        for card in cards
    ]
|
|
|
|
@app.get("/decks")
def get_decks(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """List the caller's non-deleted decks, oldest first, with summary stats."""

    def summarize(deck):
        # Resolve the deck's card links to card rows, then aggregate.
        links = db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()
        member_ids = [link.card_id for link in links]
        members = db.query(CardModel).filter(CardModel.id.in_(member_ids)).all()
        return {
            "id": str(deck.id),
            "name": deck.name,
            "card_count": len(members),
            "total_cost": sum(member.cost for member in members),
            "times_played": deck.times_played,
            "wins": deck.wins,
            "losses": deck.losses,
            "deck_type": compute_deck_type(members),
        }

    visible_decks = (
        db.query(DeckModel)
        .filter(DeckModel.user_id == user.id, DeckModel.deleted == False)  # noqa: E712
        .order_by(DeckModel.created_at)
        .all()
    )
    return [summarize(deck) for deck in visible_decks]
|
|
|
|
@app.post("/decks")
def create_deck(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Create an empty deck named ``Deck #<n>``.

    ``n`` is one more than the number of deck rows the caller already has.
    """
    existing = db.query(DeckModel).filter(DeckModel.user_id == user.id).count()
    new_deck = DeckModel(
        id=uuid.uuid4(),
        user_id=user.id,
        name=f"Deck #{existing + 1}",
    )
    db.add(new_deck)
    db.commit()
    return {"id": str(new_deck.id), "name": new_deck.name, "card_count": 0}
|
|
|
|
@app.patch("/decks/{deck_id}")
def update_deck(deck_id: str, body: dict, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Rename a deck and/or replace its card list.

    ``body`` may contain ``"name"`` and/or ``"card_ids"`` (a list of card UUID
    strings). Replacing the card list resets the deck's win/loss/played
    counters once the deck has been played.

    Raises:
        HTTPException: 404 if the deck id is malformed or the deck is not the
            caller's; 400 if a card id is malformed or a card is not in the
            caller's collection.
    """
    # A malformed id cannot match any deck; answer 404 instead of crashing.
    try:
        deck_uuid = uuid.UUID(deck_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Deck not found")

    deck = db.query(DeckModel).filter(DeckModel.id == deck_uuid, DeckModel.user_id == user.id).first()
    if not deck:
        raise HTTPException(status_code=404, detail="Deck not found")

    # Validate the card list up front so nothing is modified on a bad request.
    card_uuids = None
    if "card_ids" in body:
        try:
            card_uuids = [uuid.UUID(card_id) for card_id in body["card_ids"]]
        except (ValueError, TypeError, AttributeError):
            raise HTTPException(status_code=400, detail="Invalid card IDs")

        # Only cards the caller owns may be placed in their deck.
        distinct_ids = list(set(card_uuids))
        if distinct_ids:
            owned_count = (
                db.query(CardModel)
                .filter(CardModel.id.in_(distinct_ids), CardModel.user_id == user.id)
                .count()
            )
            if owned_count != len(distinct_ids):
                raise HTTPException(status_code=400, detail="Some cards are not in your collection")

    if "name" in body:
        deck.name = body["name"]

    if card_uuids is not None:
        db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).delete()
        for card_uuid in card_uuids:
            db.add(DeckCardModel(deck_id=deck.id, card_id=card_uuid))
        # NOTE(review): the stats reset is taken to belong to the card-list
        # change (not to a plain rename) — confirm against the original layout.
        if deck.times_played > 0:
            deck.wins = 0
            deck.losses = 0
            deck.times_played = 0

    db.commit()
    return {"id": str(deck.id), "name": deck.name}
|
|
|
|
@app.delete("/decks/{deck_id}")
def delete_deck(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Delete one of the caller's decks.

    A deck that has been played is only flagged ``deleted``; an unplayed deck
    is removed together with its card links.

    Raises:
        HTTPException: 404 if the id is malformed or the deck is not the caller's.
    """
    # A malformed id cannot match any deck; answer 404 instead of crashing.
    try:
        deck_uuid = uuid.UUID(deck_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Deck not found")

    deck = db.query(DeckModel).filter(DeckModel.id == deck_uuid, DeckModel.user_id == user.id).first()
    if not deck:
        raise HTTPException(status_code=404, detail="Deck not found")

    if deck.times_played > 0:
        deck.deleted = True
    else:
        db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).delete()
        db.delete(deck)
    db.commit()
    return {"message": "Deleted"}
|
|
|
|
@app.get("/decks/{deck_id}/cards")
def get_deck_cards(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the card ids (as strings) linked to one of the caller's decks.

    Raises:
        HTTPException: 404 if the id is malformed or the deck is not the caller's.
    """
    # A malformed id cannot match any deck; answer 404 instead of crashing.
    try:
        deck_uuid = uuid.UUID(deck_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Deck not found")

    deck = db.query(DeckModel).filter(DeckModel.id == deck_uuid, DeckModel.user_id == user.id).first()
    if not deck:
        raise HTTPException(status_code=404, detail="Deck not found")

    deck_cards = db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()
    return [str(dc.card_id) for dc in deck_cards]
|
|
|
|
@app.websocket("/ws/queue")
async def queue_endpoint(websocket: WebSocket, deck_id: str, db: Session = Depends(get_db)):
    """Matchmaking queue socket.

    Protocol: the first text frame from the client must be its access token.
    The socket is closed with code 1008 if the token is invalid, the deck is
    not the caller's, or the deck's total card cost is outside 1..50.
    Otherwise the caller is queued, receives ``{"type": "queued"}``, and a
    match attempt is made. On disconnect the caller's queue entries are removed.
    """
    await websocket.accept()

    token = await websocket.receive_text()
    user_id = decode_access_token(token)
    if not user_id:
        await websocket.close(code=1008)
        return

    # A malformed deck id is reported like an unknown deck rather than letting
    # the ValueError escape the handler.
    try:
        deck_uuid = uuid.UUID(deck_id)
    except ValueError:
        deck_uuid = None

    deck = None
    if deck_uuid is not None:
        deck = db.query(DeckModel).filter(
            DeckModel.id == deck_uuid,
            DeckModel.user_id == uuid.UUID(user_id),
        ).first()

    if not deck:
        await websocket.send_json({"type": "error", "message": "Deck not found"})
        await websocket.close(code=1008)
        return

    card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
    total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
    if total_cost == 0 or total_cost > 50:
        await websocket.send_json({"type": "error", "message": "Deck total cost must be between 1 and 50"})
        await websocket.close(code=1008)
        return

    entry = QueueEntry(user_id=user_id, deck_id=deck_id, websocket=websocket)

    async with queue_lock:
        queue.append(entry)

    await websocket.send_json({"type": "queued"})
    await try_match(db)

    try:
        while True:
            # Keeping socket alive
            await websocket.receive_text()
    except WebSocketDisconnect:
        async with queue_lock:
            queue[:] = [e for e in queue if e.user_id != user_id]
|
|
|
|
|
|
@app.websocket("/ws/game/{game_id}")
async def game_endpoint(websocket: WebSocket, game_id: str, db: Session = Depends(get_db)):
    """In-game socket: authenticates, sends the current state, relays actions.

    The first text frame must be the access token. The socket is closed with
    code 1008 when the token is invalid or the game id is not active.

    NOTE(review): no check is visible here that the caller is a participant
    of ``game_id`` — confirm it is enforced downstream.
    """
    await websocket.accept()

    first_frame = await websocket.receive_text()
    user_id = decode_access_token(first_frame)

    # Both rejection cases close the socket identically.
    if not user_id or game_id not in active_games:
        await websocket.close(code=1008)
        return

    # Register this connection (handles reconnects)
    connections[game_id][user_id] = websocket

    # Send current state immediately on connect
    snapshot = serialize_state(active_games[game_id], user_id)
    await websocket.send_json({"type": "state", "state": snapshot})

    try:
        while True:
            action = await websocket.receive_json()
            await handle_action(game_id, user_id, action, db)
    except WebSocketDisconnect:
        if game_id in connections:
            connections[game_id].pop(user_id, None)
        # NOTE(review): scheduled regardless of the check above, mirroring the
        # trade socket — confirm against the original indentation.
        asyncio.create_task(handle_disconnect(game_id, user_id))
|
|
|
|
@app.websocket("/ws/trade/queue")
async def trade_queue_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
    """Trade matchmaking socket.

    The first text frame must be the access token. Callers with an invalid
    token, an unknown account, or an unverified email are disconnected with
    code 1008 (the unverified case gets an error message first). Everyone
    else is queued and a trade match is attempted.
    """
    await websocket.accept()

    first_frame = await websocket.receive_text()
    user_id = decode_access_token(first_frame)
    if not user_id:
        await websocket.close(code=1008)
        return

    trader = db.query(UserModel).filter(UserModel.id == uuid.UUID(user_id)).first()
    if not trader:
        await websocket.close(code=1008)
        return
    if not trader.email_verified:
        await websocket.send_json({"type": "error", "message": "You must verify your email before trading."})
        await websocket.close(code=1008)
        return

    async with trade_queue_lock:
        trade_queue.append(
            TradeQueueEntry(user_id=user_id, username=trader.username, websocket=websocket)
        )

    await websocket.send_json({"type": "queued"})
    await try_trade_match()

    try:
        # Block on incoming frames so the socket stays open.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        async with trade_queue_lock:
            trade_queue[:] = [queued for queued in trade_queue if queued.user_id != user_id]
|
|
|
|
|
|
@app.websocket("/ws/trade/{trade_id}")
async def trade_endpoint(websocket: WebSocket, trade_id: str, db: Session = Depends(get_db)):
    """Socket for an active trade: sends the current state, then relays actions.

    The first text frame must be the access token. The socket is closed with
    code 1008 when the token is invalid, the trade is unknown, or the caller
    has no entry in the trade's ``offers``.
    """
    await websocket.accept()

    first_frame = await websocket.receive_text()
    user_id = decode_access_token(first_frame)
    if not user_id:
        await websocket.close(code=1008)
        return

    session = active_trades.get(trade_id)
    is_participant = bool(session) and user_id in session.offers
    if not is_participant:
        await websocket.close(code=1008)
        return

    session.connections[user_id] = websocket

    current = serialize_trade(session, user_id)
    await websocket.send_json({"type": "state", "state": current})

    try:
        while True:
            message = await websocket.receive_json()
            await handle_trade_action(trade_id, user_id, message, db)
    except WebSocketDisconnect:
        session.connections.pop(user_id, None)
        asyncio.create_task(handle_trade_disconnect(trade_id, user_id))
|
|
|
|
|
|
@app.get("/profile")
def get_profile(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the caller's profile: account fields, win rate, most-played deck and card.

    ``win_rate`` is a rounded percentage, or ``None`` when no games are
    recorded. The two ``most_played_*`` entries are ``None`` when nothing has
    been played yet.
    """
    games_played = user.wins + user.losses

    top_deck = (
        db.query(DeckModel)
        .filter(DeckModel.user_id == user.id, DeckModel.times_played > 0)
        .order_by(DeckModel.times_played.desc())
        .first()
    )
    top_card = (
        db.query(CardModel)
        .filter(CardModel.user_id == user.id, CardModel.times_played > 0)
        .order_by(CardModel.times_played.desc())
        .first()
    )

    win_rate = None
    if games_played > 0:
        win_rate = round((user.wins / games_played) * 100)

    deck_summary = None
    if top_deck:
        deck_summary = {
            "name": top_deck.name,
            "times_played": top_deck.times_played,
        }

    card_summary = None
    if top_card:
        card_summary = {
            "name": top_card.name,
            "times_played": top_card.times_played,
            "card_type": top_card.card_type,
            "card_rarity": top_card.card_rarity,
            "image_link": top_card.image_link,
        }

    return {
        "username": user.username,
        "email": user.email,
        "email_verified": user.email_verified,
        "created_at": user.created_at,
        "wins": user.wins,
        "losses": user.losses,
        "shards": user.shards,
        "win_rate": win_rate,
        "most_played_deck": deck_summary,
        "most_played_card": card_summary,
    }
|
|
|
|
class ShatterRequest(BaseModel):
    """JSON body for ``/shards/shatter``: the cards to convert into shards."""

    card_ids: list[str]  # card UUIDs as strings
|
|
|
|
@app.post("/shards/shatter")
def shatter_cards(req: ShatterRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Destroy the given cards and credit their summed ``cost`` as shards.

    Raises:
        HTTPException: 400 if no ids are given, an id is not a valid UUID, or
            the number of matching owned cards differs from the number of ids
            sent.
    """
    if not req.card_ids:
        raise HTTPException(status_code=400, detail="No cards selected")

    parsed_ids = []
    try:
        for raw_id in req.card_ids:
            parsed_ids.append(uuid.UUID(raw_id))
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid card IDs")

    owned = (
        db.query(CardModel)
        .filter(CardModel.id.in_(parsed_ids), CardModel.user_id == user.id)
        .all()
    )
    if len(owned) != len(parsed_ids):
        raise HTTPException(status_code=400, detail="Some cards are not in your collection")

    gained = sum(item.cost for item in owned)

    for item in owned:
        # Remove the card's deck links before the card row itself.
        db.query(DeckCardModel).filter(DeckCardModel.card_id == item.id).delete()
        db.delete(item)

    user.shards += gained
    db.commit()
    return {"shards": user.shards, "gained": gained}
|
|
|
|
# Shard packages sold for real money.
# price_oere is in Danish øre (1 DKK = 100 øre). Stripe minimum is 250 øre.
# Each row is (package id, base shards, bonus shards, price in whole DKK);
# "shards" is base + bonus.
SHARD_PACKAGES = {
    package_id: {
        "base": base,
        "bonus": bonus,
        "shards": base + bonus,
        "price_oere": price_dkk * 100,
        "price_label": f"{price_dkk} DKK",
    }
    for package_id, base, bonus, price_dkk in (
        ("s1", 100, 0, 10),
        ("s2", 250, 50, 25),
        ("s3", 500, 200, 50),
        ("s4", 1000, 600, 100),
        ("s5", 2500, 2000, 250),
        ("s6", 5000, 5000, 500),
    )
}
|
|
|
|
class StripeCheckoutRequest(BaseModel):
    """JSON body for ``/store/stripe/checkout``."""

    package_id: str  # key into SHARD_PACKAGES
|
|
|
|
@app.post("/store/stripe/checkout")
def create_stripe_checkout(req: StripeCheckoutRequest, user: UserModel = Depends(get_current_user)):
    """Create a Stripe Checkout session for a shard package and return its URL.

    The buyer's id and the shard amount are stored in the session metadata.

    Raises:
        HTTPException: 400 if ``package_id`` is not a known package.
    """
    package = SHARD_PACKAGES.get(req.package_id)
    if not package:
        raise HTTPException(status_code=400, detail="Invalid package")

    line_item = {
        "price_data": {
            "currency": "dkk",
            "product_data": {"name": f"WikiTCG Shards — {package['price_label']}"},
            "unit_amount": package["price_oere"],
        },
        "quantity": 1,
    }
    purchase_metadata = {"user_id": str(user.id), "shards": str(package["shards"])}

    session = stripe.checkout.Session.create(
        payment_method_types=["card"],
        line_items=[line_item],
        mode="payment",
        success_url=f"{FRONTEND_URL}/store?payment=success",
        cancel_url=f"{FRONTEND_URL}/store",
        metadata=purchase_metadata,
    )
    return {"url": session.url}
|
|
|
|
@app.post("/stripe/webhook")
async def stripe_webhook(request: Request, db: Session = Depends(get_db)):
    """Handle Stripe webhook events; credit shards on completed checkouts.

    The payload is verified against ``STRIPE_WEBHOOK_SECRET``. For
    ``checkout.session.completed`` events, the shard amount in the session
    metadata is added to the user named in the metadata. Other event types
    are acknowledged and ignored.

    Raises:
        HTTPException: 400 if the payload cannot be parsed or the signature
            does not verify.
    """
    payload = await request.body()
    sig = request.headers.get("stripe-signature", "")
    try:
        event = stripe.Webhook.construct_event(payload, sig, STRIPE_WEBHOOK_SECRET)
    except ValueError:
        # Raised for a payload that is not valid JSON.
        raise HTTPException(status_code=400, detail="Invalid payload")
    except stripe.error.SignatureVerificationError:  # type: ignore
        raise HTTPException(status_code=400, detail="Invalid signature")

    if event["type"] != "checkout.session.completed":
        return {"ok": True}

    # ``metadata`` may be absent or null on the session object.
    metadata = event["data"]["object"].get("metadata") or {}
    raw_user_id = metadata.get("user_id")
    raw_shards = metadata.get("shards")
    if not (raw_user_id and raw_shards):
        return {"ok": True}

    try:
        buyer_id = uuid.UUID(raw_user_id)
        shard_amount = int(raw_shards)
    except ValueError:
        # Acknowledge anyway: a retry would carry the same malformed metadata.
        logger.error("Stripe webhook metadata is malformed: user_id=%r shards=%r", raw_user_id, raw_shards)
        return {"ok": True}

    # NOTE(review): Stripe may deliver the same event more than once; nothing
    # here records processed event ids, so a redelivery credits shards again.
    user = db.query(UserModel).filter(UserModel.id == buyer_id).first()
    if user:
        user.shards += shard_amount
        db.commit()

    return {"ok": True}
|
|
|
|
@app.get("/store/config")
def store_config():
    """Expose the Stripe publishable key and the shard package table."""
    config = {"publishable_key": STRIPE_PUBLISHABLE_KEY}
    config["shard_packages"] = SHARD_PACKAGES
    return config
|
|
|
|
# Booster bundles purchasable with shards: booster quantity -> shard cost.
STORE_PACKAGES = {1: 15, 5: 65, 10: 120, 25: 260}
|
|
|
|
class StoreBuyRequest(BaseModel):
    """JSON body for ``/store/buy``."""

    quantity: int  # must be a key of STORE_PACKAGES
|
|
|
|
class BuySpecificCardRequest(BaseModel):
    """JSON body for ``/store/buy-specific-card``."""

    wiki_title: str  # passed to the card generator as the page to build from
|
|
|
|
# Shard price of buying one card generated from a chosen page.
SPECIFIC_CARD_COST = 1_000
|
|
|
|
@app.post("/store/buy-specific-card")
@limiter.limit("10/hour")
async def buy_specific_card(request: Request, req: BuySpecificCardRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Generate a card for the requested page and sell it for SPECIFIC_CARD_COST shards.

    ``request`` is unused in the body; it is part of the signature alongside
    the rate-limit decorator.

    Returns:
        The new card's columns plus the caller's remaining ``shards``.

    Raises:
        HTTPException: 400 if the caller cannot afford the card, 404 if no
            card could be generated for the page.
    """
    # Cheap early check so generation is not started for callers who are broke.
    if user.shards < SPECIFIC_CARD_COST:
        raise HTTPException(status_code=400, detail="Not enough shards")

    card = await _get_specific_card_async(req.wiki_title)
    if card is None:
        raise HTTPException(status_code=404, detail="Could not generate a card for that Wikipedia page")

    # The await above yields to other requests, which may have spent this
    # user's shards in the meantime. Reload and re-check before deducting so
    # concurrent purchases cannot push the balance negative.
    db.refresh(user)
    if user.shards < SPECIFIC_CARD_COST:
        raise HTTPException(status_code=400, detail="Not enough shards")

    db_card = CardModel(
        name=card.name,
        image_link=card.image_link,
        card_rarity=card.card_rarity.name,
        card_type=card.card_type.name,
        text=card.text,
        attack=card.attack,
        defense=card.defense,
        cost=card.cost,
        user_id=user.id,
    )
    db.add(db_card)
    user.shards -= SPECIFIC_CARD_COST
    db.commit()
    db.refresh(db_card)

    return {
        **{c.name: getattr(db_card, c.name) for c in db_card.__table__.columns},
        "card_rarity": db_card.card_rarity,
        "card_type": db_card.card_type,
        "shards": user.shards,
    }
|
|
|
|
@app.post("/store/buy")
def store_buy(req: StoreBuyRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Buy a booster bundle with shards.

    Raises:
        HTTPException: 400 if the quantity is not a STORE_PACKAGES bundle or
            the caller has too few shards.
    """
    price = STORE_PACKAGES.get(req.quantity)
    if price is None:
        raise HTTPException(status_code=400, detail="Invalid package")
    if price > user.shards:
        raise HTTPException(status_code=400, detail="Not enough shards")

    user.shards -= price
    user.boosters += req.quantity
    db.commit()
    return {"shards": user.shards, "boosters": user.boosters}
|
|
|
|
@app.post("/cards/{card_id}/report")
def report_card(card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Flag one of the caller's cards as reported.

    Raises:
        HTTPException: 404 if the id is malformed or the card is not the caller's.
    """
    # A malformed id cannot match any card; answer 404 instead of crashing.
    try:
        card_uuid = uuid.UUID(card_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Card not found")

    card = db.query(CardModel).filter(
        CardModel.id == card_uuid,
        CardModel.user_id == user.id,
    ).first()
    if not card:
        raise HTTPException(status_code=404, detail="Card not found")

    card.reported = True
    db.commit()
    return {"message": "Card reported"}
|
|
|
|
@app.post("/cards/{card_id}/refresh")
@limiter.limit("5/hour")
async def refresh_card(request: Request, card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Regenerate one of the caller's cards from its name and clear its report flag.

    A per-user cooldown of two hours applies, tracked in
    ``user.last_refresh_at``. ``request`` is unused in the body; it is part of
    the signature alongside the rate-limit decorator.

    Raises:
        HTTPException: 404 if the id is malformed or the card is not the
            caller's, 429 while the cooldown is active, 502 if regeneration
            fails.
    """
    # A malformed id cannot match any card; answer 404 instead of crashing.
    try:
        card_uuid = uuid.UUID(card_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Card not found")

    card = db.query(CardModel).filter(
        CardModel.id == card_uuid,
        CardModel.user_id == user.id,
    ).first()
    if not card:
        raise HTTPException(status_code=404, detail="Card not found")

    if user.last_refresh_at and datetime.now() - user.last_refresh_at < timedelta(hours=2):
        remaining = (user.last_refresh_at + timedelta(hours=2)) - datetime.now()
        hours = int(remaining.total_seconds() // 3600)
        minutes = int((remaining.total_seconds() % 3600) // 60)
        raise HTTPException(
            status_code=429,
            detail=f"You can refresh again in {hours}h {minutes}m",
        )

    # NOTE(review): the cooldown is checked before this await and stamped
    # after it, so overlapping requests can both pass the check.
    new_card = await _get_specific_card_async(card.name)
    if not new_card:
        raise HTTPException(status_code=502, detail="Failed to regenerate card from Wikipedia")

    # Overwrite the generated fields in place; the card keeps its id and name.
    card.image_link = new_card.image_link
    card.card_rarity = new_card.card_rarity.name
    card.card_type = new_card.card_type.name
    card.text = new_card.text
    card.attack = new_card.attack
    card.defense = new_card.defense
    card.cost = new_card.cost
    card.reported = False

    user.last_refresh_at = datetime.now()
    db.commit()

    return {
        **{c.name: getattr(card, c.name) for c in card.__table__.columns},
        "card_rarity": card.card_rarity,
        "card_type": card.card_type,
    }
|
|
|
|
@app.get("/profile/refresh-status")
def refresh_status(user: UserModel = Depends(get_current_user)):
    """Report whether the caller may refresh a card now, and if not, when.

    ``next_refresh_at`` is an ISO-8601 string while the two-hour cooldown is
    active and ``None`` otherwise.
    """
    last_refresh = user.last_refresh_at
    if not last_refresh:
        return {"can_refresh": True, "next_refresh_at": None}

    next_refresh = last_refresh + timedelta(hours=2)
    can_refresh = datetime.now() >= next_refresh
    if can_refresh:
        next_refresh_at = None
    else:
        next_refresh_at = next_refresh.isoformat()
    return {"can_refresh": can_refresh, "next_refresh_at": next_refresh_at}
|
|
|
|
@app.post("/game/{game_id}/claim-timeout-win")
async def claim_timeout_win(game_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Claim a win by timeout in the given game.

    Raises:
        HTTPException: 400 carrying the message returned by
            ``handle_timeout_claim`` when the claim is rejected.
    """
    rejection = await handle_timeout_claim(game_id, str(user.id), db)
    if not rejection:
        return {"message": "Win claimed"}
    raise HTTPException(status_code=400, detail=rejection)
|
|
|
|
@app.post("/game/solo")
async def start_solo_game(deck_id: str, difficulty: int = 5, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Start a single-player game against the AI and return its ``game_id``.

    The AI's cards are drawn at random (up to 500) from unowned cards and
    marked ``ai_used``; a pool refill is then scheduled in the background.

    Raises:
        HTTPException: 400 if ``difficulty`` is outside 1..10 or the deck's
            total cost is outside 1..50; 404 if the deck id is malformed or
            the deck is not the caller's; 503 if the deck cannot be loaded or
            no unowned cards exist.
    """
    if difficulty < 1 or difficulty > 10:
        raise HTTPException(status_code=400, detail="Difficulty must be between 1 and 10")

    # A malformed id cannot match any deck; answer 404 instead of crashing.
    try:
        deck_uuid = uuid.UUID(deck_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Deck not found")

    deck = db.query(DeckModel).filter(
        DeckModel.id == deck_uuid,
        DeckModel.user_id == user.id,
    ).first()
    if not deck:
        raise HTTPException(status_code=404, detail="Deck not found")

    card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
    total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
    if total_cost == 0 or total_cost > 50:
        raise HTTPException(status_code=400, detail="Deck total cost must be between 1 and 50")

    player_cards = load_deck_cards(deck_id, str(user.id), db)
    if player_cards is None:
        raise HTTPException(status_code=503, detail="Couldn't load deck")

    # ``.is_(None)`` is the explicit SQL "IS NULL" comparison.
    ai_cards = (
        db.query(CardModel)
        .filter(CardModel.user_id.is_(None))
        .order_by(func.random())
        .limit(500)
        .all()
    )
    if not ai_cards:
        raise HTTPException(status_code=503, detail="Not enough cards in pool for AI deck")

    # Mark the drawn cards so pack opening no longer hands them out.
    for card in ai_cards:
        card.ai_used = True
    db.commit()

    game_id = create_solo_game(str(user.id), user.username, player_cards, ai_cards, deck_id, difficulty)
    asyncio.create_task(fill_card_pool())

    return {"game_id": game_id}
|
|
|
|
class ResetPasswordRequest(BaseModel):
    """JSON body for ``/auth/reset-password`` (password change while logged in)."""

    current_password: str
    new_password: str
|
|
|
|
@app.post("/auth/reset-password")
def reset_password(req: ResetPasswordRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Change the caller's password after confirming the current one.

    Raises:
        HTTPException: 400 if the current password is wrong, the new one is
            shorter than 8 or longer than 256 characters, or it equals the
            current password.
    """
    if not verify_password(req.current_password, user.password_hash):
        raise HTTPException(status_code=400, detail="Current password is incorrect")

    candidate = req.new_password
    rejection = None
    if len(candidate) < 8:
        rejection = "Password must be at least 8 characters"
    elif len(candidate) > 256:
        rejection = "Password must be 256 characters or fewer"
    elif req.current_password == candidate:
        rejection = "New password must be different from current password"
    if rejection is not None:
        raise HTTPException(status_code=400, detail=rejection)

    user.password_hash = hash_password(candidate)
    db.commit()
    return {"message": "Password updated"}
|
|
|
|
@app.post("/auth/forgot-password")
def forgot_password(req: ForgotPasswordRequest, db: Session = Depends(get_db)):
    """Issue a one-hour password-reset token and email it, if the address is known.

    The response is identical whether or not the email is registered, so the
    endpoint cannot be used to enumerate users. A failure to send the email is
    logged, not raised.
    """
    reply = {"message": "If that email is registered you will receive a reset link shortly"}

    user = db.query(UserModel).filter(UserModel.email == req.email).first()
    if not user:
        return reply

    token = secrets.token_urlsafe(32)
    user.reset_token = token
    user.reset_token_expires_at = datetime.now() + timedelta(hours=1)
    db.commit()

    try:
        send_password_reset_email(user.email, user.username, token)
    except Exception as e:
        logger.error(f"Failed to send reset email: {e}")
    return reply
|
|
|
|
@app.post("/auth/reset-password-with-token")
def reset_password_with_token(req: ResetPasswordWithTokenRequest, db: Session = Depends(get_db)):
    """Set a new password using an emailed reset token, then invalidate the token.

    Raises:
        HTTPException: 400 if the token is unknown or expired, or the new
            password is shorter than 8 or longer than 256 characters.
    """
    account = db.query(UserModel).filter(UserModel.reset_token == req.token).first()
    expires_at = account.reset_token_expires_at if account else None
    if not expires_at or expires_at < datetime.now():
        raise HTTPException(status_code=400, detail="Invalid or expired reset link")

    password_length = len(req.new_password)
    if password_length < 8:
        raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
    if password_length > 256:
        raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")

    account.password_hash = hash_password(req.new_password)
    # Clear the token so the link cannot be used again.
    account.reset_token = None
    account.reset_token_expires_at = None
    db.commit()
    return {"message": "Password updated"}
|
|
|
|
@app.get("/auth/verify-email")
def verify_email(token: str, db: Session = Depends(get_db)):
    """Mark the account that owns ``token`` as email-verified and clear the token.

    Raises:
        HTTPException: 400 if the token is unknown or has expired.
    """
    account = db.query(UserModel).filter(UserModel.email_verification_token == token).first()
    expires_at = account.email_verification_token_expires_at if account else None
    if not expires_at or expires_at < datetime.now():
        raise HTTPException(status_code=400, detail="Invalid or expired verification link")

    account.email_verified = True
    # Clear the token so the link cannot be used again.
    account.email_verification_token = None
    account.email_verification_token_expires_at = None
    db.commit()
    return {"message": "Email verified"}
|
|
|
|
class ResendVerificationRequest(BaseModel):
    """JSON body for ``/auth/resend-verification``."""

    email: str
|
|
|
|
@app.post("/auth/resend-verification")
def resend_verification(req: ResendVerificationRequest, db: Session = Depends(get_db)):
    """Issue a fresh 24-hour verification token for an unverified account and email it.

    The response is identical for unknown, verified and unverified addresses,
    so the endpoint cannot be used to enumerate users. A failure to send the
    email is logged, not raised.
    """
    reply = {"message": "If that email is registered and unverified, you will receive a new verification link shortly"}

    user = db.query(UserModel).filter(UserModel.email == req.email).first()
    if not user or user.email_verified:
        return reply

    token = secrets.token_urlsafe(32)
    user.email_verification_token = token
    user.email_verification_token_expires_at = datetime.now() + timedelta(hours=24)
    db.commit()

    try:
        send_verification_email(user.email, user.username, token)
    except Exception as e:
        logger.error(f"Failed to resend verification email: {e}")
    return reply
|
|
|
|
class RefreshRequest(BaseModel):
    """JSON body for ``/auth/refresh``."""

    refresh_token: str
|
|
|
|
@app.post("/auth/refresh")
def refresh(req: RefreshRequest, db: Session = Depends(get_db)):
    """Exchange a valid refresh token for a new access + refresh token pair.

    Raises:
        HTTPException: 401 if the refresh token does not decode to a user id
            or that user no longer exists.
    """
    subject = decode_refresh_token(req.refresh_token)
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

    account = db.query(UserModel).filter(UserModel.id == uuid.UUID(subject)).first()
    if not account:
        raise HTTPException(status_code=401, detail="User not found")

    account_id = str(account.id)
    return {
        "access_token": create_access_token(account_id),
        "refresh_token": create_refresh_token(account_id),
        "token_type": "bearer",
    }
|
|
|
|
if __name__ == "__main__":
    # Developer script: generate 500 cards and dump, for every AI personality
    # and difficulty 1..10, the cards the AI would choose into output/*.txt.
    from pathlib import Path

    from ai import AIPersonality, choose_cards
    from card import generate_cards, Card

    output_dir = Path("output")
    # The directory is not guaranteed to exist on a fresh checkout.
    output_dir.mkdir(parents=True, exist_ok=True)

    def write_cards(cards: list[Card], file: str):
        """Write one ``name - attack/defense - cost`` line per card to *file*."""
        lines = [f"{c.name} - {c.attack}/{c.defense} - {c.cost}" for c in cards]
        # Explicit UTF-8: card names may contain non-ASCII characters, which
        # the platform default encoding cannot always represent.
        with open(file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(lines))

    all_cards = generate_cards(500)
    all_cards.sort(key=lambda x: x.cost, reverse=True)
    print(len(all_cards))

    write_cards(all_cards, str(output_dir / "all.txt"))

    for personality in AIPersonality:
        print(personality.value)
        for difficulty in range(1, 11):
            chosen_cards = choose_cards(all_cards, difficulty, personality)
            chosen_cards.sort(key=lambda x: x.cost, reverse=True)
            write_cards(chosen_cards, str(output_dir / f"{personality.value}-{difficulty}.txt"))
|