feat: add initial test suite with pytest (#28)

- Add SQLITE_DB_PATH env var to db_engine.py for test isolation
- Create tests/conftest.py with in-memory SQLite fixture and sample data helpers
- Add tests/test_dependencies.py: unit tests for valid_token, mround, param_char, get_req_url
- Add tests/test_card_pricing.py: tests for Player.change_on_sell/buy and get_all_pos
- Add tests/test_api_packs.py: integration tests for GET/POST/DELETE /api/v2/packs
- Add requirements-test.txt with pytest and httpx
- Add test job to CI workflow (build now requires tests to pass first)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-04 16:35:53 -06:00
parent 761c0a6dab
commit 64be5eabdc
8 changed files with 580 additions and 2 deletions

View File

@ -17,7 +17,31 @@ on:
- main - main
jobs: jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: https://github.com/actions/checkout@v4
- name: Set up Python
uses: https://github.com/actions/setup-python@v5
with:
python-version: "3.12"
- name: Install test dependencies
run: pip install -r requirements-test.txt
- name: Create required directories
run: mkdir -p logs/database storage
- name: Run tests
env:
API_TOKEN: test_token_12345
run: pytest tests/ -v
build: build:
needs: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

View File

@ -31,8 +31,9 @@ if DATABASE_TYPE.lower() == "postgresql":
) )
else: else:
# Default SQLite configuration for local development # Default SQLite configuration for local development
_sqlite_path = os.environ.get("SQLITE_DB_PATH", "storage/pd_master.db")
db = SqliteDatabase( db = SqliteDatabase(
"storage/pd_master.db", _sqlite_path,
pragmas={"journal_mode": "wal", "cache_size": -1 * 64000, "synchronous": 0}, pragmas={"journal_mode": "wal", "cache_size": -1 * 64000, "synchronous": 0},
) )
@ -925,7 +926,13 @@ CardPosition.add_index(pos_index)
if not SKIP_TABLE_CREATION: if not SKIP_TABLE_CREATION:
db.create_tables( db.create_tables(
[BattingCard, BattingCardRatings, PitchingCard, PitchingCardRatings, CardPosition], [
BattingCard,
BattingCardRatings,
PitchingCard,
PitchingCardRatings,
CardPosition,
],
safe=True, safe=True,
) )

3
requirements-test.txt Normal file
View File

@ -0,0 +1,3 @@
-r requirements.txt
pytest
httpx

0
tests/__init__.py Normal file
View File

202
tests/conftest.py Normal file
View File

@ -0,0 +1,202 @@
"""
Test configuration for Paper Dynasty Database.
Sets up an isolated SQLite test database before any app modules are imported.
Uses a temporary file (not :memory:) because db_engine.py closes the connection
at module load time, which would destroy an in-memory database.
"""
import logging
import os
import tempfile
import pytest
# --- Must be set BEFORE any app imports ---
_db_fd, _db_path = tempfile.mkstemp(suffix=".db")
os.close(_db_fd)
os.environ["SQLITE_DB_PATH"] = _db_path
os.environ.setdefault("API_TOKEN", "test_token_12345")
# Suppress file-based logging during tests
logging.disable(logging.CRITICAL)
# --- App imports (after env vars are set) ---
from app.db_engine import ( # noqa: E402
db,
Rarity,
Cardset,
Event,
MlbPlayer,
Player,
Team,
PackType,
Pack,
Card,
Roster,
Current,
BattingCard,
BattingCardRatings,
PitchingCard,
PitchingCardRatings,
CardPosition,
Result,
BattingStat,
PitchingStat,
Award,
Paperdex,
Reward,
GameRewards,
Notification,
GauntletReward,
GauntletRun,
StratGame,
StratPlay,
Decision,
)
ALL_MODELS = [
Current,
Rarity,
Event,
Cardset,
MlbPlayer,
Player,
Team,
PackType,
Pack,
Card,
Roster,
Result,
BattingStat,
PitchingStat,
Award,
Paperdex,
Reward,
GameRewards,
Notification,
GauntletReward,
GauntletRun,
BattingCard,
BattingCardRatings,
PitchingCard,
PitchingCardRatings,
CardPosition,
StratGame,
StratPlay,
Decision,
]
@pytest.fixture(autouse=True)
def db_transaction():
"""
Wrap each test in a transaction that rolls back afterwards.
This ensures tests are isolated no data bleeds between tests.
The fixture reconnects if the db was closed by the middleware.
"""
db.connect(reuse_if_open=True)
with db.atomic() as txn:
yield
txn.rollback()
if not db.is_closed():
db.close()
@pytest.fixture
def sample_rarity(db_transaction):
"""A saved Rarity row for use in tests."""
return Rarity.create(value=1, name="Common", color="#ffffff")
@pytest.fixture
def sample_cardset(db_transaction, sample_event):
"""A saved Cardset row for use in tests."""
return Cardset.create(
name="Test Set 2025",
description="Test cardset",
event=None,
for_purchase=True,
total_cards=100,
in_packs=True,
ranked_legal=True,
)
@pytest.fixture
def sample_event(db_transaction):
"""A saved Event row for use in tests."""
return Event.create(name="Test Event", active=False)
@pytest.fixture
def sample_player(db_transaction, sample_rarity, sample_cardset):
"""A saved Player row for use in tests."""
return Player.create(
player_id=99001,
p_name="Test Player",
cost=100,
image="test.png",
mlbclub="TST",
franchise="Test Franchise",
cardset=sample_cardset,
set_num=1,
rarity=sample_rarity,
pos_1="1B",
description="A test player",
)
@pytest.fixture
def sample_pack_type(db_transaction):
"""A saved PackType row for use in tests."""
return PackType.create(
name="Standard Pack",
card_count=5,
description="A standard pack",
cost=500,
available=True,
)
@pytest.fixture
def sample_current(db_transaction):
"""A saved Current row for use in tests."""
return Current.create(
season=1,
week=1,
gsheet_template="template",
gsheet_version="1.0",
live_scoreboard=123456789,
)
@pytest.fixture
def sample_team(db_transaction, sample_current, sample_event):
"""A saved Team row for use in tests."""
return Team.create(
abbrev="TST",
sname="Testers",
lname="The Test Team",
gmid=111222333,
gmname="testgm",
gsheet="https://example.com",
wallet=5000,
team_value=10000,
collection_value=8000,
season=1,
event=None,
)
@pytest.fixture
def sample_pack(db_transaction, sample_team, sample_pack_type):
"""A saved Pack row for use in tests."""
return Pack.create(
team=sample_team,
pack_type=sample_pack_type,
pack_team=None,
pack_cardset=None,
open_time=None,
)

130
tests/test_api_packs.py Normal file
View File

@ -0,0 +1,130 @@
"""
Integration tests for the /api/v2/packs endpoints.
Uses FastAPI's TestClient (starlette) with an isolated SQLite test database.
The db_transaction fixture in conftest.py rolls back all changes after each test.
"""
import pytest
from starlette.testclient import TestClient
from app.main import app
TEST_TOKEN = "test_token_12345"
AUTH_HEADERS = {"Authorization": f"Bearer {TEST_TOKEN}"}
@pytest.fixture
def client():
"""Return a TestClient for the FastAPI app."""
return TestClient(app, raise_server_exceptions=True)
class TestGetPacks:
"""Tests for GET /api/v2/packs."""
def test_empty_database_returns_404(self, client):
"""With no packs in the DB, the endpoint should return 404."""
resp = client.get("/api/v2/packs")
assert resp.status_code == 404
def test_returns_pack_list(self, client, sample_pack):
"""With at least one pack, GET /packs returns count and list."""
resp = client.get("/api/v2/packs")
assert resp.status_code == 200
data = resp.json()
assert "count" in data
assert data["count"] >= 1
assert "packs" in data
assert len(data["packs"]) >= 1
def test_filter_by_team_id(self, client, sample_pack, sample_team):
"""team_id filter should return only packs for that team."""
resp = client.get(f"/api/v2/packs?team_id={sample_team.id}")
assert resp.status_code == 200
data = resp.json()
for pack in data["packs"]:
assert pack["team"]["id"] == sample_team.id
def test_filter_by_invalid_team_returns_404(self, client, sample_pack):
"""Filtering by a non-existent team_id should return 404."""
resp = client.get("/api/v2/packs?team_id=999999")
assert resp.status_code == 404
def test_filter_opened_false(self, client, sample_pack):
"""opened=false should return only packs with no open_time."""
resp = client.get("/api/v2/packs?opened=false")
assert resp.status_code == 200
data = resp.json()
for pack in data["packs"]:
assert pack["open_time"] is None
class TestGetOnePack:
"""Tests for GET /api/v2/packs/{pack_id}."""
def test_get_existing_pack(self, client, sample_pack):
"""GET with a valid pack_id should return that pack's data."""
resp = client.get(f"/api/v2/packs/{sample_pack.id}")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == sample_pack.id
def test_get_nonexistent_pack_returns_404(self, client):
"""GET with an unknown pack_id should return 404."""
resp = client.get("/api/v2/packs/999999")
assert resp.status_code == 404
class TestPostPack:
"""Tests for POST /api/v2/packs — requires auth token."""
def test_post_without_token_returns_401(
self, client, sample_team, sample_pack_type
):
"""POST without Authorization header should return 422 (missing field)."""
payload = {
"packs": [{"team_id": sample_team.id, "pack_type_id": sample_pack_type.id}]
}
resp = client.post("/api/v2/packs", json=payload)
# FastAPI returns 422 when required OAuth2 token is missing
assert resp.status_code == 422
def test_post_with_invalid_token_returns_401(
self, client, sample_team, sample_pack_type
):
"""POST with a wrong token should return 401."""
payload = {
"packs": [{"team_id": sample_team.id, "pack_type_id": sample_pack_type.id}]
}
resp = client.post(
"/api/v2/packs",
json=payload,
headers={"Authorization": "Bearer wrong_token"},
)
assert resp.status_code == 401
def test_post_with_valid_token_creates_packs(
self, client, sample_team, sample_pack_type
):
"""POST with valid token and payload should create the packs (returns 200)."""
payload = {
"packs": [{"team_id": sample_team.id, "pack_type_id": sample_pack_type.id}]
}
resp = client.post("/api/v2/packs", json=payload, headers=AUTH_HEADERS)
# The router raises HTTPException(200) on success
assert resp.status_code == 200
class TestDeletePack:
"""Tests for DELETE /api/v2/packs/{pack_id} — requires auth token."""
def test_delete_nonexistent_pack_returns_404(self, client):
"""Deleting a pack that doesn't exist should return 404."""
resp = client.delete("/api/v2/packs/999999", headers=AUTH_HEADERS)
assert resp.status_code == 404
def test_delete_without_token_returns_401_or_422(self, client, sample_pack):
"""DELETE without a token should fail auth."""
resp = client.delete(f"/api/v2/packs/{sample_pack.id}")
assert resp.status_code in (401, 422)

116
tests/test_card_pricing.py Normal file
View File

@ -0,0 +1,116 @@
"""
Unit tests for card pricing logic in Player model.
Player.change_on_sell() and Player.change_on_buy() are critical business logic
used whenever cards are traded. These tests verify the price update math and
the floor/ceiling behaviour.
"""
import math
import pytest
class TestPlayerChangeOnSell:
"""Tests for Player.change_on_sell() — price decreases 5% on each sale."""
def test_sell_reduces_cost_by_5_percent(self, sample_player):
"""Selling a card should reduce its cost to floor(cost * 0.95)."""
sample_player.cost = 100
sample_player.change_on_sell()
assert sample_player.cost == math.floor(100 * 0.95) # 95
def test_sell_saves_to_db(self, sample_player):
"""change_on_sell() should persist the new price to the database."""
from app.db_engine import Player
sample_player.cost = 200
sample_player.change_on_sell()
refreshed = Player.get_by_id(sample_player.pk)
assert refreshed.cost == math.floor(200 * 0.95) # 190
def test_sell_floors_at_1(self, sample_player):
"""Price should never drop below 1, even at very low starting values."""
sample_player.cost = 1
sample_player.change_on_sell()
assert sample_player.cost == 1
def test_sell_large_price(self, sample_player):
"""Large prices should still apply the 5% reduction correctly."""
sample_player.cost = 10000
sample_player.change_on_sell()
assert sample_player.cost == math.floor(10000 * 0.95) # 9500
def test_sell_rounds_down(self, sample_player):
"""floor() means fractional results are rounded down, not up."""
sample_player.cost = 21 # 21 * 0.95 = 19.95 → floor → 19
sample_player.change_on_sell()
assert sample_player.cost == 19
class TestPlayerChangeOnBuy:
"""Tests for Player.change_on_buy() — price increases 10% on each purchase."""
def test_buy_increases_cost_by_10_percent(self, sample_player):
"""Buying a card should increase its cost to ceil(cost * 1.1)."""
sample_player.cost = 100
sample_player.change_on_buy()
assert sample_player.cost == math.ceil(100 * 1.1) # 110
def test_buy_saves_to_db(self, sample_player):
"""change_on_buy() should persist the new price to the database."""
from app.db_engine import Player
sample_player.cost = 200
sample_player.change_on_buy()
refreshed = Player.get_by_id(sample_player.pk)
assert refreshed.cost == math.ceil(200 * 1.1) # 220
def test_buy_from_low_price(self, sample_player):
"""Low prices should still apply the 10% increase."""
sample_player.cost = 1
sample_player.change_on_buy()
assert sample_player.cost == math.ceil(1 * 1.1) # 2 (ceil of 1.1)
def test_buy_rounds_up(self, sample_player):
"""ceil() means fractional results are rounded up."""
sample_player.cost = 9 # 9 * 1.1 = 9.9 → ceil → 10
sample_player.change_on_buy()
assert sample_player.cost == 10
def test_buy_large_price(self, sample_player):
"""Large prices should still apply the 10% increase correctly."""
sample_player.cost = 5000
sample_player.change_on_buy()
assert sample_player.cost == math.ceil(5000 * 1.1) # 5500
class TestPlayerGetAllPos:
"""Tests for Player.get_all_pos() — returns non-null, non-CP position list."""
def test_returns_primary_position(self, sample_player):
"""A player with only pos_1 set should return a list with one position."""
sample_player.pos_1 = "1B"
assert sample_player.get_all_pos() == ["1B"]
def test_excludes_cp_position(self, sample_player):
"""CP (closing pitcher) is excluded from the position list."""
sample_player.pos_1 = "SP"
sample_player.pos_2 = "CP"
positions = sample_player.get_all_pos()
assert "CP" not in positions
assert "SP" in positions
def test_excludes_null_positions(self, sample_player):
"""None positions should not appear in the result."""
sample_player.pos_1 = "CF"
sample_player.pos_2 = None
assert sample_player.get_all_pos() == ["CF"]
def test_multiple_positions(self, sample_player):
"""Players with multiple eligible positions should return all of them."""
sample_player.pos_1 = "1B"
sample_player.pos_2 = "OF"
sample_player.pos_3 = "DH"
positions = sample_player.get_all_pos()
assert positions == ["1B", "OF", "DH"]

View File

@ -0,0 +1,96 @@
"""
Unit tests for app/dependencies.py utility functions.
Tests pure functions that have no database or HTTP dependencies.
These run fast and verify foundational logic used across all API routes.
"""
import os
import pytest
from app.dependencies import valid_token, mround, param_char, get_req_url
class TestValidToken:
"""Tests for valid_token() — verifies API bearer token auth."""
def test_matching_token_returns_true(self):
"""valid_token should return True when the token matches API_TOKEN env var."""
expected = os.environ.get("API_TOKEN", "")
assert valid_token(expected) is True
def test_wrong_token_returns_false(self):
"""valid_token should return False for any non-matching string."""
assert valid_token("wrong_token") is False
def test_empty_token_returns_false(self):
"""valid_token should return False for an empty string."""
assert valid_token("") is False
class TestMround:
"""Tests for mround() — rounds a float to the nearest multiple of `base`."""
def test_rounds_to_nearest_0_05(self):
"""Default base=0.05 should round 0.06 to 0.05."""
assert mround(0.06) == 0.05
def test_rounds_up_to_nearest_0_05(self):
"""0.08 should round up to 0.10 with base=0.05."""
assert mround(0.08) == 0.10
def test_exact_multiple_unchanged(self):
"""A value that is already a multiple of base should be unchanged."""
assert mround(0.25) == 0.25
def test_custom_base(self):
"""Custom base=0.25 should round to nearest quarter."""
assert mround(0.3, base=0.25) == 0.25
def test_custom_precision(self):
"""Custom prec=4 should return more decimal places."""
result = mround(0.123456, prec=4, base=0.01)
assert result == 0.12
class TestParamChar:
"""Tests for param_char() — returns ? or & for URL query string building."""
def test_returns_question_mark_when_no_other_params(self):
"""First parameter in a URL should use ?."""
assert param_char(False) == "?"
def test_returns_ampersand_when_other_params_exist(self):
"""Subsequent parameters should use &."""
assert param_char(True) == "&"
class TestGetReqUrl:
"""Tests for get_req_url() — builds API URLs with optional params."""
def test_basic_endpoint(self):
"""Endpoint with no object_id or params should produce a clean URL."""
url = get_req_url("players")
assert url.endswith("/v2/players")
def test_endpoint_with_object_id(self):
"""object_id should be appended to the URL path."""
url = get_req_url("players", object_id=42)
assert url.endswith("/v2/players/42")
def test_endpoint_with_params(self):
"""Params list should be appended as query string."""
url = get_req_url("players", params=[("season", "1")])
assert "?season=1" in url
def test_endpoint_with_multiple_params(self):
"""Multiple params should be joined with &."""
url = get_req_url("players", params=[("season", "1"), ("limit", "10")])
assert "?season=1" in url
assert "&limit=10" in url
def test_api_version_override(self):
"""api_ver parameter controls the version segment."""
url = get_req_url("players", api_ver=1)
assert "/v1/players" in url