Files
wiki-tcg/backend/main.py
2026-03-19 22:34:02 +01:00

565 lines
20 KiB
Python

import asyncio
import logging
import uuid
import re
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import cast, Callable
import secrets
from dotenv import load_dotenv
load_dotenv()
from sqlalchemy.orm import Session
from sqlalchemy import func
from fastapi import FastAPI, Depends, HTTPException, status, WebSocket, WebSocketDisconnect, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from database import get_db
from database_functions import fill_card_pool, check_boosters, BOOSTER_MAX
from models import Card as CardModel
from models import User as UserModel
from models import Deck as DeckModel
from models import DeckCard as DeckCardModel
from auth import (
hash_password, verify_password, create_access_token, create_refresh_token,
decode_access_token, decode_refresh_token
)
from game_manager import (
queue, queue_lock, QueueEntry, try_match, handle_action, connections, active_games,
serialize_state, handle_disconnect, handle_timeout_claim, load_deck_cards, create_solo_game
)
from card import compute_deck_type, _get_specific_card_async
from email_utils import send_password_reset_email
from config import CORS_ORIGINS
# Application-wide logger; note it is named "app", not after this module.
logger = logging.getLogger("app")
# Auth
# Extracts the bearer token for get_current_user. tokenUrl names the /login
# endpoint defined below so the generated OpenAPI docs know where to get a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
class RegisterRequest(BaseModel):
    """Request body for POST /register.

    Field contents are checked by validate_register, not by this model.
    """
    username: str
    email: str
    password: str
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> UserModel:
    """FastAPI dependency: resolve the request's bearer token to its user.

    Raises:
        HTTPException(401): the token does not decode to a user id
            ("Invalid token"), or no user has that id ("User not found").
    """
    subject = decode_access_token(token)
    if not subject:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    account = (
        db.query(UserModel)
        .filter(UserModel.id == uuid.UUID(subject))
        .first()
    )
    if account:
        return account
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
class ForgotPasswordRequest(BaseModel):
    """Request body for POST /auth/forgot-password: the address to send a reset link to."""
    email: str
class ResetPasswordWithTokenRequest(BaseModel):
    """Request body for POST /auth/reset-password-with-token."""
    # Reset token issued by /auth/forgot-password and sent by email.
    token: str
    # Replacement password (length is checked by the endpoint).
    new_password: str
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: start a card-pool refill at startup.

    The refill runs as a background task so startup does not wait for it.
    The task object is kept on ``app.state`` because the event loop holds only
    weak references to tasks. A task nobody references can be garbage-collected
    before it finishes.
    """
    app.state.card_pool_fill_task = asyncio.create_task(fill_card_pool())
    yield
# ASGI application; `lifespan` starts the initial card-pool refill.
app = FastAPI(lifespan=lifespan)
# Rate limiting
# slowapi limiter keyed on the client address; endpoints opt in with
# @limiter.limit(...).
limiter = Limiter(key_func=get_remote_address)
# slowapi's setup docs expect the limiter to be reachable as app.state.limiter.
app.state.limiter = limiter
# Answer RateLimitExceeded with slowapi's stock handler. typing.cast is a no-op
# at runtime; it only quiets the type checker about the handler's signature.
app.add_exception_handler(RateLimitExceeded, cast(Callable, _rate_limit_exceeded_handler))
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS, # Allowed origins come from config.CORS_ORIGINS.
    allow_methods=["*"],
    allow_headers=["*"],
)
def validate_register(username: str, email: str, password: str) -> str | None:
if not username.strip():
return "Username is required"
if len(username) > 16:
return "Username must be 16 characters or fewer"
if not re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", email):
return "Please enter a valid email"
if len(password) < 8:
return "Password must be at least 8 characters"
if len(password) > 256:
return "Password must be 256 characters or fewer"
return None
@app.post("/register")
def register(req: RegisterRequest, db: Session = Depends(get_db)):
    """Create a new account from a RegisterRequest.

    Raises:
        HTTPException(400): the input fails validate_register, or the username
            or the email is already in use (checked in that order).
    """
    problem = validate_register(req.username, req.email, req.password)
    if problem:
        raise HTTPException(status_code=400, detail=problem)

    uniqueness_checks = (
        (UserModel.username == req.username, "Username already taken"),
        (UserModel.email == req.email, "Email already registered"),
    )
    for clause, message in uniqueness_checks:
        if db.query(UserModel).filter(clause).first():
            raise HTTPException(status_code=400, detail=message)

    db.add(
        UserModel(
            id=uuid.uuid4(),
            username=req.username,
            email=req.email,
            password_hash=hash_password(req.password),
        )
    )
    db.commit()
    return {"message": "User created"}
@app.post("/login")
def login(form: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange a username/password form for a new access + refresh token pair.

    Raises:
        HTTPException(400): unknown username or wrong password. Both cases
            give the same message, so the response does not say which was wrong.
    """
    account = db.query(UserModel).filter(UserModel.username == form.username).first()
    if account and verify_password(form.password, account.password_hash):
        subject = str(account.id)
        return {
            "access_token": create_access_token(subject),
            "refresh_token": create_refresh_token(subject),
            "token_type": "bearer",
        }
    raise HTTPException(status_code=400, detail="Invalid username or password")
@app.get("/boosters")
def get_boosters(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)) -> tuple[int,datetime|None]:
    """Report the caller's booster-pack status.

    Returns whatever check_boosters(user, db) returns. Per the annotation, that
    is a (count, datetime-or-None) pair.
    """
    booster_status = check_boosters(user, db)
    return booster_status
@app.get("/cards")
def get_cards(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """List every card owned by the caller.

    Each card is serialized as all of its table columns. ``card_rarity`` and
    ``card_type`` are then re-read through the model attributes, overriding any
    column values of the same name.
    """
    owned = db.query(CardModel).filter(CardModel.user_id == user.id).all()
    payload = []
    for card in owned:
        row = {column.name: getattr(card, column.name) for column in card.__table__.columns}
        row["card_rarity"] = card.card_rarity
        row["card_type"] = card.card_type
        payload.append(row)
    return payload
@app.post("/open_pack")
@limiter.limit("10/minute")
async def open_pack(request: Request, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Open one booster pack: move five unclaimed pool cards into the caller's collection.

    Calls check_boosters first, which presumably refreshes ``user.boosters``
    (TODO confirm). It then spends one booster. If the user was at BOOSTER_MAX,
    the regeneration countdown restarts now. A background pool refill is
    scheduled afterwards.

    ``request`` is otherwise unused. Presumably the @limiter.limit decorator
    needs it.

    Returns:
        The five cards, serialized like ``GET /cards``.

    Raises:
        HTTPException(400): the caller has no booster packs.
        HTTPException(503): fewer than five unclaimed cards are available. A
            refill is scheduled before raising.
    """
    check_boosters(user, db)
    # `<= 0` rather than `== 0` so a counter that ever went negative can't be spent.
    if user.boosters <= 0:
        raise HTTPException(status_code=400, detail="No booster packs available")
    cards = (
        db.query(CardModel)
        .filter(CardModel.user_id == None, CardModel.ai_used == False)
        .limit(5)
        # Lock the chosen pool rows until commit, and skip rows another request
        # has already locked. Without this, two concurrent openings could both
        # pick the same cards, and one user would be shown cards that end up
        # owned by someone else. (SQLAlchemy omits the clause on backends
        # without row locking, e.g. SQLite.)
        .with_for_update(skip_locked=True)
        .all()
    )
    if len(cards) < 5:
        asyncio.create_task(fill_card_pool())
        raise HTTPException(status_code=503, detail="Card pool is low, please try again shortly")
    for card in cards:
        card.user_id = user.id
    was_full = user.boosters == BOOSTER_MAX
    user.boosters -= 1
    if was_full:
        # The user drops below the cap now, so regeneration starts counting from here.
        user.boosters_countdown = datetime.now()
    db.commit()
    asyncio.create_task(fill_card_pool())
    return [
        {**{c.name: getattr(card, c.name) for c in card.__table__.columns},
         "card_rarity": card.card_rarity,
         "card_type": card.card_type}
        for card in cards
    ]
@app.get("/decks")
def get_decks(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """List the caller's decks that are not soft-deleted, oldest first.

    Each entry holds the deck's id, name and play record. It also holds values
    derived from the deck's cards: ``card_count``, ``total_cost`` and
    ``deck_type`` (from compute_deck_type).

    Links and cards are each loaded with a single query covering all decks,
    instead of two queries per deck.
    """
    decks = (
        db.query(DeckModel)
        .filter(DeckModel.user_id == user.id, DeckModel.deleted == False)
        .order_by(DeckModel.created_at)
        .all()
    )
    if not decks:
        return []

    # deck id -> ids of the cards linked to that deck
    card_ids_by_deck = {deck.id: [] for deck in decks}
    links = (
        db.query(DeckCardModel)
        .filter(DeckCardModel.deck_id.in_(list(card_ids_by_deck)))
        .all()
    )
    for link in links:
        card_ids_by_deck[link.deck_id].append(link.card_id)

    all_card_ids = {card_id for ids in card_ids_by_deck.values() for card_id in ids}
    cards_by_id = {}
    if all_card_ids:
        for card in db.query(CardModel).filter(CardModel.id.in_(list(all_card_ids))).all():
            cards_by_id[card.id] = card

    result = []
    for deck in decks:
        # Each card counts once even if linked twice (IN (...) semantics).
        # Links whose card row no longer exists are skipped.
        cards = [
            cards_by_id[card_id]
            for card_id in dict.fromkeys(card_ids_by_deck[deck.id])
            if card_id in cards_by_id
        ]
        result.append({
            "id": str(deck.id),
            "name": deck.name,
            "card_count": len(cards),
            "total_cost": sum(card.cost for card in cards),
            "times_played": deck.times_played,
            "wins": deck.wins,
            "losses": deck.losses,
            "deck_type": compute_deck_type(cards),
        })
    return result
@app.post("/decks")
def create_deck(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Create an empty deck for the caller, named ``Deck #<n>``.

    ``n`` is one more than the number of decks the caller already has,
    soft-deleted decks included.
    """
    existing = db.query(DeckModel).filter(DeckModel.user_id == user.id).count()
    new_deck = DeckModel(id=uuid.uuid4(), user_id=user.id, name="Deck #" + str(existing + 1))
    db.add(new_deck)
    db.commit()
    return {"id": str(new_deck.id), "name": new_deck.name, "card_count": 0}
@app.patch("/decks/{deck_id}")
def update_deck(deck_id: str, body: dict, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Rename one of the caller's decks and/or replace its card list.

    Body keys (both optional):
        name: the new deck name. Must be a string.
        card_ids: list of card id strings that becomes the deck's full
            contents. Every id must be a card the caller owns. Repeated ids
            are stored once. Replacing the cards of a deck that has been
            played resets its wins, losses and times_played.

    Returns:
        ``{"id": ..., "name": ...}`` for the updated deck.

    Raises:
        HTTPException(404): deck_id is malformed or not one of the caller's decks.
        HTTPException(400): a field has the wrong type, a card id is malformed,
            or a card is not owned by the caller.
    """
    try:
        deck_uuid = uuid.UUID(deck_id)
    except ValueError:
        # A malformed id can't match any deck. Answer 404 instead of crashing with a 500.
        raise HTTPException(status_code=404, detail="Deck not found") from None
    deck = db.query(DeckModel).filter(DeckModel.id == deck_uuid, DeckModel.user_id == user.id).first()
    if not deck:
        raise HTTPException(status_code=404, detail="Deck not found")

    if "name" in body:
        if not isinstance(body["name"], str):
            raise HTTPException(status_code=400, detail="Deck name must be a string")
        deck.name = body["name"]

    if "card_ids" in body:
        raw_ids = body["card_ids"]
        if not isinstance(raw_ids, list):
            raise HTTPException(status_code=400, detail="card_ids must be a list")
        try:
            # dict.fromkeys drops repeats while keeping the client's order.
            card_uuids = list(dict.fromkeys(uuid.UUID(str(raw)) for raw in raw_ids))
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid card id") from None
        if card_uuids:
            # Without this check a client could put unowned pool cards or other
            # players' cards into its deck just by sending their ids.
            owned = (
                db.query(func.count(CardModel.id))
                .filter(CardModel.id.in_(card_uuids), CardModel.user_id == user.id)
                .scalar()
            )
            if owned != len(card_uuids):
                raise HTTPException(status_code=400, detail="Deck can only contain cards you own")
        db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).delete()
        for card_uuid in card_uuids:
            db.add(DeckCardModel(deck_id=deck.id, card_id=card_uuid))
        if deck.times_played > 0:
            # The old record described the old card list, so start it over.
            deck.wins = 0
            deck.losses = 0
            deck.times_played = 0
    db.commit()
    return {"id": str(deck.id), "name": deck.name}
@app.delete("/decks/{deck_id}")
def delete_deck(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Delete one of the caller's decks.

    A deck that has been played (``times_played > 0``) is only flagged
    ``deleted``, so its row survives. Presumably this is so game statistics
    that refer to it stay valid. A never-played deck is removed together
    with its card links.

    Raises:
        HTTPException(404): the caller has no deck with this id.
    """
    target = (
        db.query(DeckModel)
        .filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Deck not found")

    was_played = target.times_played > 0
    if was_played:
        target.deleted = True
    else:
        db.query(DeckCardModel).filter(DeckCardModel.deck_id == target.id).delete()
        db.delete(target)
    db.commit()
    return {"message": "Deleted"}
@app.get("/decks/{deck_id}/cards")
def get_deck_cards(deck_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the ids (as strings) of the cards linked to one of the caller's decks.

    Raises:
        HTTPException(404): the caller has no deck with this id.
    """
    owned_deck = (
        db.query(DeckModel)
        .filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id)
        .first()
    )
    if not owned_deck:
        raise HTTPException(status_code=404, detail="Deck not found")
    links = db.query(DeckCardModel).filter(DeckCardModel.deck_id == owned_deck.id).all()
    return list(map(lambda link: str(link.card_id), links))
@app.websocket("/ws/queue")
async def queue_endpoint(websocket: WebSocket, deck_id: str, db: Session = Depends(get_db)):
    """Matchmaking queue socket.

    Protocol:

    1. The client connects with ``?deck_id=...`` and sends its access token as
       the first text frame.
    2. The deck must belong to that user and have a total card cost between 1
       and 50. Otherwise an error frame is sent and the socket is closed with
       code 1008.
    3. On success the user is queued, ``{"type": "queued"}`` is sent, and a
       match attempt is made.
    4. The socket is then kept open until it closes. Frames received from the
       client are read and discarded.

    The queue entry created by this connection is removed however the
    connection ends: disconnect, error or cancellation. Only that entry is
    removed; other entries for the same user are left alone.
    """
    await websocket.accept()
    try:
        token = await websocket.receive_text()
    except WebSocketDisconnect:
        # The client left before authenticating; nothing to clean up.
        return
    user_id = decode_access_token(token)
    if not user_id:
        await websocket.close(code=1008)
        return

    deck = None
    try:
        deck_uuid = uuid.UUID(deck_id)
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        # Malformed id: treat it as "deck not found" instead of crashing.
        pass
    else:
        deck = db.query(DeckModel).filter(
            DeckModel.id == deck_uuid,
            DeckModel.user_id == user_uuid,
        ).first()
    if not deck:
        await websocket.send_json({"type": "error", "message": "Deck not found"})
        await websocket.close(code=1008)
        return

    card_ids = [dc.card_id for dc in db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()]
    total_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(card_ids)).scalar() or 0
    if total_cost == 0 or total_cost > 50:
        await websocket.send_json({"type": "error", "message": "Deck total cost must be between 1 and 50"})
        await websocket.close(code=1008)
        return

    entry = QueueEntry(user_id=user_id, deck_id=deck_id, websocket=websocket)
    async with queue_lock:
        queue.append(entry)
    try:
        await websocket.send_json({"type": "queued"})
        await try_match(db)
        while True:
            # Keeping socket alive
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Runs on every exit path, so a failed send/match or an unexpected error
        # can't leave a ghost entry behind that might later be matched.
        async with queue_lock:
            queue[:] = [e for e in queue if e is not entry]
@app.websocket("/ws/game/{game_id}")
async def game_endpoint(websocket: WebSocket, game_id: str, db: Session = Depends(get_db)):
    """Game socket for ``game_id``.

    The client sends its access token as the first text frame. If the token is
    invalid or the game is not in active_games, the socket is closed with code
    1008. Otherwise:

    - The socket is registered as ``connections[game_id][user_id]``. This
      replaces any earlier socket for that user, which is how reconnects work.
    - The current state is sent.
    - Every JSON frame received is passed to handle_action.

    When the socket ends for any reason, it is unregistered and
    handle_disconnect is scheduled. The exception is when a newer socket for
    the same user has already taken its place (a reconnect). Then that newer
    registration is left untouched and no disconnect is reported.

    NOTE(review): nothing here checks that user_id is a player in this game.
    Presumably serialize_state / handle_action deal with outsiders; confirm.
    """
    await websocket.accept()
    try:
        token = await websocket.receive_text()
    except WebSocketDisconnect:
        # The client left before authenticating; nothing was registered.
        return
    user_id = decode_access_token(token)
    if not user_id:
        await websocket.close(code=1008)
        return
    if game_id not in active_games:
        await websocket.close(code=1008)
        return
    # Register this connection (handles reconnects). Assumes connections[game_id]
    # exists for every active game, or that connections is a defaultdict; TODO confirm.
    connections[game_id][user_id] = websocket
    try:
        # Send current state immediately on connect
        await websocket.send_json({
            "type": "state",
            "state": serialize_state(active_games[game_id], user_id),
        })
        while True:
            data = await websocket.receive_json()
            await handle_action(game_id, user_id, data, db)
    except WebSocketDisconnect:
        pass
    finally:
        game_connections = connections.get(game_id)
        current = game_connections.get(user_id) if game_connections is not None else None
        if current is websocket:
            del game_connections[user_id]
        # A different socket registered for this user means they reconnected
        # while this one was closing. Popping it would cut off the live
        # connection, and reporting a disconnect would be wrong.
        if current is None or current is websocket:
            asyncio.create_task(handle_disconnect(game_id, user_id))
@app.get("/profile")
def get_profile(user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the caller's profile: account details, record and favourites.

    ``win_rate`` is a rounded whole-number percentage, or None before any game.
    ``most_played_deck`` and ``most_played_card`` are the caller's deck and
    card with the highest ``times_played`` (ties resolved by the database), or
    None when nothing has been played.
    """
    games_played = user.wins + user.losses
    win_rate = None
    if games_played > 0:
        win_rate = round((user.wins / games_played) * 100)

    top_deck = (
        db.query(DeckModel)
        .filter(DeckModel.user_id == user.id, DeckModel.times_played > 0)
        .order_by(DeckModel.times_played.desc())
        .first()
    )
    top_card = (
        db.query(CardModel)
        .filter(CardModel.user_id == user.id, CardModel.times_played > 0)
        .order_by(CardModel.times_played.desc())
        .first()
    )

    deck_summary = None
    if top_deck:
        deck_summary = {
            "name": top_deck.name,
            "times_played": top_deck.times_played,
        }
    card_summary = None
    if top_card:
        card_summary = {
            "name": top_card.name,
            "times_played": top_card.times_played,
            "card_type": top_card.card_type,
            "card_rarity": top_card.card_rarity,
            "image_link": top_card.image_link,
        }

    return {
        "username": user.username,
        "email": user.email,
        "created_at": user.created_at,
        "wins": user.wins,
        "losses": user.losses,
        "win_rate": win_rate,
        "most_played_deck": deck_summary,
        "most_played_card": card_summary,
    }
@app.post("/cards/{card_id}/report")
def report_card(card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Flag one of the caller's cards as reported.

    Raises:
        HTTPException(404): the caller owns no card with this id.
    """
    owned_card = (
        db.query(CardModel)
        .filter(CardModel.id == uuid.UUID(card_id), CardModel.user_id == user.id)
        .first()
    )
    if not owned_card:
        raise HTTPException(status_code=404, detail="Card not found")
    owned_card.reported = True
    db.commit()
    return {"message": "Card reported"}
@app.post("/cards/{card_id}/refresh")
@limiter.limit("5/hour")
async def refresh_card(request: Request, card_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Regenerate one of the caller's cards from Wikipedia, at most once per 48 hours.

    The card keeps its row (id, owner, name). Its image link, rarity, type,
    text, attack, defense and cost are replaced by a fresh generation for the
    same name, and ``reported`` is cleared. On success ``user.last_refresh_at``
    is stamped. ``request`` is otherwise unused; presumably the
    @limiter.limit decorator needs it.

    Returns:
        The updated card, serialized like ``GET /cards``.

    Raises:
        HTTPException(404): the caller owns no card with this id.
        HTTPException(429): the 48-hour cooldown has not elapsed. The detail
            says how much time is left.
        HTTPException(502): regeneration returned nothing.
    """
    target = (
        db.query(CardModel)
        .filter(CardModel.id == uuid.UUID(card_id), CardModel.user_id == user.id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Card not found")

    cooldown = timedelta(hours=48)
    previous_refresh = user.last_refresh_at
    # NOTE(review): the cooldown is checked here but only stamped after the
    # await below, so two overlapping requests can both get through.
    if previous_refresh and datetime.now() - previous_refresh < cooldown:
        seconds_left = ((previous_refresh + cooldown) - datetime.now()).total_seconds()
        hours, remainder = divmod(seconds_left, 3600)
        raise HTTPException(
            status_code=429,
            detail=f"You can refresh again in {int(hours)}h {int(remainder // 60)}m",
        )

    regenerated = await _get_specific_card_async(target.name)
    if not regenerated:
        raise HTTPException(status_code=502, detail="Failed to regenerate card from Wikipedia")
    target.image_link = regenerated.image_link
    target.card_rarity = regenerated.card_rarity.name
    target.card_type = regenerated.card_type.name
    for field in ("text", "attack", "defense", "cost"):
        setattr(target, field, getattr(regenerated, field))
    target.reported = False
    user.last_refresh_at = datetime.now()
    db.commit()

    payload = {column.name: getattr(target, column.name) for column in target.__table__.columns}
    payload["card_rarity"] = target.card_rarity
    payload["card_type"] = target.card_type
    return payload
@app.get("/profile/refresh-status")
def refresh_status(user: UserModel = Depends(get_current_user)):
    """Tell the caller whether a card refresh is allowed right now.

    The cooldown is 48 hours from ``user.last_refresh_at``. While it is
    running, ``can_refresh`` is False and ``next_refresh_at`` is an ISO-8601
    timestamp. Otherwise ``next_refresh_at`` is None.
    """
    last_refresh = user.last_refresh_at
    if last_refresh:
        available_at = last_refresh + timedelta(hours=48)
        if datetime.now() < available_at:
            return {"can_refresh": False, "next_refresh_at": available_at.isoformat()}
    return {"can_refresh": True, "next_refresh_at": None}
@app.post("/game/{game_id}/claim-timeout-win")
async def claim_timeout_win(game_id: str, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Let the caller claim the win in ``game_id`` because the opponent timed out.

    handle_timeout_claim decides whether the claim stands. A truthy return
    value is an error message, which is sent back as a 400.
    """
    rejection = await handle_timeout_claim(game_id, str(user.id), db)
    if not rejection:
        return {"message": "Win claimed"}
    raise HTTPException(status_code=400, detail=rejection)
@app.post("/game/solo")
async def start_solo_game(deck_id: str, difficulty: int = 5, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Start a single-player game against the AI with one of the caller's decks.

    Args:
        deck_id: id of the caller's deck.
        difficulty: AI difficulty, 1-10 (default 5).

    The deck's total card cost must be between 1 and 50. The AI draws from up
    to 500 random unowned pool cards. Each of those is flagged ``ai_used``
    before the game is created, which keeps it out of booster packs
    (open_pack only takes cards with ``ai_used == False``). A background pool
    refill is scheduled afterwards.

    Returns:
        ``{"game_id": ...}``

    Raises:
        HTTPException(400): difficulty or deck cost out of range.
        HTTPException(404): the caller has no deck with this id.
        HTTPException(503): the deck couldn't be loaded, or the pool is empty.
    """
    if not 1 <= difficulty <= 10:
        raise HTTPException(status_code=400, detail="Difficulty must be between 1 and 10")

    deck = (
        db.query(DeckModel)
        .filter(DeckModel.id == uuid.UUID(deck_id), DeckModel.user_id == user.id)
        .first()
    )
    if not deck:
        raise HTTPException(status_code=404, detail="Deck not found")

    links = db.query(DeckCardModel).filter(DeckCardModel.deck_id == deck.id).all()
    deck_card_ids = [link.card_id for link in links]
    deck_cost = db.query(func.sum(CardModel.cost)).filter(CardModel.id.in_(deck_card_ids)).scalar() or 0
    if deck_cost == 0 or deck_cost > 50:
        raise HTTPException(status_code=400, detail="Deck total cost must be between 1 and 50")

    player_cards = load_deck_cards(deck_id, str(user.id), db)
    if player_cards is None:
        raise HTTPException(status_code=503, detail="Couldn't load deck")

    ai_pool = (
        db.query(CardModel)
        .filter(CardModel.user_id == None)
        .order_by(func.random())
        .limit(500)
        .all()
    )
    if not ai_pool:
        raise HTTPException(status_code=503, detail="Not enough cards in pool for AI deck")
    for pool_card in ai_pool:
        pool_card.ai_used = True
    db.commit()

    game_id = create_solo_game(str(user.id), user.username, player_cards, ai_pool, deck_id, difficulty)
    asyncio.create_task(fill_card_pool())
    return {"game_id": game_id}
class ResetPasswordRequest(BaseModel):
    """Request body for POST /auth/reset-password (signed-in password change)."""
    # Must match the user's existing password.
    current_password: str
    # Replacement password (length and "differs from current" are checked by the endpoint).
    new_password: str
@app.post("/auth/reset-password")
def reset_password(req: ResetPasswordRequest, user: UserModel = Depends(get_current_user), db: Session = Depends(get_db)):
    """Change the signed-in user's password.

    Requires the current password. The new one must be 8-256 characters and
    differ from the current one.

    Any pending emailed reset link is voided too. Otherwise a link requested
    before this change could still set a different password until it expires
    (forgot_password issues links valid for one hour).

    Raises:
        HTTPException(400): wrong current password, new password out of
            bounds, or unchanged password.
    """
    if not verify_password(req.current_password, user.password_hash):
        raise HTTPException(status_code=400, detail="Current password is incorrect")
    if len(req.new_password) < 8:
        raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
    if len(req.new_password) > 256:
        raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")
    if req.current_password == req.new_password:
        raise HTTPException(status_code=400, detail="New password must be different from current password")
    user.password_hash = hash_password(req.new_password)
    user.reset_token = None
    user.reset_token_expires_at = None
    db.commit()
    return {"message": "Password updated"}
@app.post("/auth/forgot-password")
def forgot_password(req: ForgotPasswordRequest, db: Session = Depends(get_db)):
    """Email a password-reset link if ``req.email`` belongs to an account.

    A random token, valid for one hour, is stored on the user and passed to
    send_password_reset_email. The response body is the same whether or not
    the address is registered.

    NOTE(review): the token is stored as-is (not hashed). Also, the email is
    sent inline, so response time may still differ for registered addresses.
    """
    user = db.query(UserModel).filter(UserModel.email == req.email).first()
    # Always return success even if email not found. Prevents user enumeration
    if user:
        token = secrets.token_urlsafe(32)
        user.reset_token = token
        user.reset_token_expires_at = datetime.now() + timedelta(hours=1)
        db.commit()
        try:
            send_password_reset_email(user.email, user.username, token)
        except Exception:
            # Best effort: a mail failure must not change the response, because
            # that would reveal that the address exists. logger.exception keeps
            # the full traceback for diagnosis.
            logger.exception("Failed to send reset email")
    return {"message": "If that email is registered you will receive a reset link shortly"}
@app.post("/auth/reset-password-with-token")
def reset_password_with_token(req: ResetPasswordWithTokenRequest, db: Session = Depends(get_db)):
    """Set a new password using the token from a password-reset email.

    The token is single-use: it and its expiry are cleared on success.

    Raises:
        HTTPException(400): unknown or expired token, or a new password
            outside 8-256 characters.
    """
    account = db.query(UserModel).filter(UserModel.reset_token == req.token).first()
    token_valid = (
        bool(account)
        and bool(account.reset_token_expires_at)
        and not account.reset_token_expires_at < datetime.now()
    )
    if not token_valid:
        raise HTTPException(status_code=400, detail="Invalid or expired reset link")

    new_password = req.new_password
    if len(new_password) < 8:
        raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
    if len(new_password) > 256:
        raise HTTPException(status_code=400, detail="Password must be 256 characters or fewer")

    account.password_hash = hash_password(new_password)
    account.reset_token = None
    account.reset_token_expires_at = None
    db.commit()
    return {"message": "Password updated"}
class RefreshRequest(BaseModel):
    """Request body for POST /auth/refresh."""
    # Refresh token previously issued by /login or /auth/refresh.
    refresh_token: str
@app.post("/auth/refresh")
def refresh(req: RefreshRequest, db: Session = Depends(get_db)):
    """Exchange a valid refresh token for a new access + refresh token pair.

    The presented refresh token is not revoked here.

    Raises:
        HTTPException(401): the refresh token is invalid or expired, or its
            user no longer exists.
    """
    subject = decode_refresh_token(req.refresh_token)
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
    account = db.query(UserModel).filter(UserModel.id == uuid.UUID(subject)).first()
    if not account:
        raise HTTPException(status_code=401, detail="User not found")
    account_id = str(account.id)
    return {
        "access_token": create_access_token(account_id),
        "refresh_token": create_refresh_token(account_id),
        "token_type": "bearer",
    }
if __name__ == "__main__":
    # Offline tuning script:
    #   - generate 30 batches of 10 cards;
    #   - dump them all to output/all.txt;
    #   - for every AI personality and difficulty 1-10, dump the cards that
    #     choose_cards picks to output/<personality>-<difficulty>.txt.
    # Each line reads "name - attack/defense - cost", sorted by cost, highest first.
    import os
    from time import sleep

    from ai import AIPersonality, choose_cards
    from card import generate_cards, Card

    # open(..., "w") below fails if the directory does not exist yet.
    os.makedirs("output", exist_ok=True)

    all_cards: list[Card] = []
    for i in range(30):
        print(i)
        all_cards += generate_cards(10)
        sleep(5)  # presumably throttles calls to the card source between batches
    all_cards.sort(key=lambda x: x.cost, reverse=True)
    print(len(all_cards))

    def write_cards(cards: list[Card], file: str):
        """Write one "name - attack/defense - cost" line per card to ``file``."""
        # Card names come from Wikipedia and may contain any Unicode, so don't
        # rely on the platform's default encoding.
        with open(file, "w", encoding="utf-8") as fp:
            fp.write('\n'.join(
                f"{c.name} - {c.attack}/{c.defense} - {c.cost}"
                for c in cards
            ))

    write_cards(all_cards, "output/all.txt")
    for personality in AIPersonality:
        print(personality.value)
        for difficulty in range(1, 11):
            chosen_cards = choose_cards(all_cards, difficulty, personality)
            chosen_cards.sort(key=lambda x: x.cost, reverse=True)
            write_cards(chosen_cards, f"output/{personality.value}-{difficulty}.txt")