diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml index 6c32bfc..659e28f 100644 --- a/.gitea/workflows/build.yml +++ b/.gitea/workflows/build.yml @@ -17,7 +17,31 @@ on: - main jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: https://github.com/actions/checkout@v4 + + - name: Set up Python + uses: https://github.com/actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install test dependencies + run: pip install -r requirements-test.txt + + - name: Create required directories + run: mkdir -p logs/database storage + + - name: Run tests + env: + API_TOKEN: test_token_12345 + run: pytest tests/ -v + build: + needs: test runs-on: ubuntu-latest steps: diff --git a/app/db_engine.py b/app/db_engine.py index b7849b0..22a063f 100644 --- a/app/db_engine.py +++ b/app/db_engine.py @@ -31,8 +31,9 @@ if DATABASE_TYPE.lower() == "postgresql": ) else: # Default SQLite configuration for local development + _sqlite_path = os.environ.get("SQLITE_DB_PATH", "storage/pd_master.db") db = SqliteDatabase( - "storage/pd_master.db", + _sqlite_path, pragmas={"journal_mode": "wal", "cache_size": -1 * 64000, "synchronous": 0}, ) @@ -925,7 +926,13 @@ CardPosition.add_index(pos_index) if not SKIP_TABLE_CREATION: db.create_tables( - [BattingCard, BattingCardRatings, PitchingCard, PitchingCardRatings, CardPosition], + [ + BattingCard, + BattingCardRatings, + PitchingCard, + PitchingCardRatings, + CardPosition, + ], safe=True, ) diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..7af6bb0 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,3 @@ +-r requirements.txt +pytest +httpx diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9534f7c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,202 @@ +""" +Test configuration for Paper Dynasty Database. + +Sets up an isolated SQLite test database before any app modules are imported. +Uses a temporary file (not :memory:) because db_engine.py closes the connection +at module load time, which would destroy an in-memory database. +""" + +import logging +import os +import tempfile + +import pytest + +# --- Must be set BEFORE any app imports --- +_db_fd, _db_path = tempfile.mkstemp(suffix=".db") +os.close(_db_fd) +os.environ["SQLITE_DB_PATH"] = _db_path +os.environ.setdefault("API_TOKEN", "test_token_12345") + +# Suppress file-based logging during tests +logging.disable(logging.CRITICAL) + +# --- App imports (after env vars are set) --- +from app.db_engine import ( # noqa: E402 + db, + Rarity, + Cardset, + Event, + MlbPlayer, + Player, + Team, + PackType, + Pack, + Card, + Roster, + Current, + BattingCard, + BattingCardRatings, + PitchingCard, + PitchingCardRatings, + CardPosition, + Result, + BattingStat, + PitchingStat, + Award, + Paperdex, + Reward, + GameRewards, + Notification, + GauntletReward, + GauntletRun, + StratGame, + StratPlay, + Decision, +) + +ALL_MODELS = [ + Current, + Rarity, + Event, + Cardset, + MlbPlayer, + Player, + Team, + PackType, + Pack, + Card, + Roster, + Result, + BattingStat, + PitchingStat, + Award, + Paperdex, + Reward, + GameRewards, + Notification, + GauntletReward, + GauntletRun, + BattingCard, + BattingCardRatings, + PitchingCard, + PitchingCardRatings, + CardPosition, + StratGame, + StratPlay, + Decision, +] + + +@pytest.fixture(autouse=True) +def db_transaction(): + """ + Wrap each test in a transaction that rolls back afterwards. + + This ensures tests are isolated — no data bleeds between tests. + The fixture reconnects if the db was closed by the middleware. + """ + db.connect(reuse_if_open=True) + with db.atomic() as txn: + yield + txn.rollback() + if not db.is_closed(): + db.close() + + +@pytest.fixture +def sample_rarity(db_transaction): + """A saved Rarity row for use in tests.""" + return Rarity.create(value=1, name="Common", color="#ffffff") + + +@pytest.fixture +def sample_cardset(db_transaction, sample_event): + """A saved Cardset row for use in tests.""" + return Cardset.create( + name="Test Set 2025", + description="Test cardset", + event=None, + for_purchase=True, + total_cards=100, + in_packs=True, + ranked_legal=True, + ) + + +@pytest.fixture +def sample_event(db_transaction): + """A saved Event row for use in tests.""" + return Event.create(name="Test Event", active=False) + + +@pytest.fixture +def sample_player(db_transaction, sample_rarity, sample_cardset): + """A saved Player row for use in tests.""" + return Player.create( + player_id=99001, + p_name="Test Player", + cost=100, + image="test.png", + mlbclub="TST", + franchise="Test Franchise", + cardset=sample_cardset, + set_num=1, + rarity=sample_rarity, + pos_1="1B", + description="A test player", + ) + + +@pytest.fixture +def sample_pack_type(db_transaction): + """A saved PackType row for use in tests.""" + return PackType.create( + name="Standard Pack", + card_count=5, + description="A standard pack", + cost=500, + available=True, + ) + + +@pytest.fixture +def sample_current(db_transaction): + """A saved Current row for use in tests.""" + return Current.create( + season=1, + week=1, + gsheet_template="template", + gsheet_version="1.0", + live_scoreboard=123456789, + ) + + +@pytest.fixture +def sample_team(db_transaction, sample_current, sample_event): + """A saved Team row for use in tests.""" + return Team.create( + abbrev="TST", + sname="Testers", + lname="The Test Team", + gmid=111222333, + gmname="testgm", + gsheet="https://example.com", + wallet=5000, + team_value=10000, + collection_value=8000, + season=1, + event=None, + ) + + +@pytest.fixture +def sample_pack(db_transaction, sample_team, sample_pack_type): + """A saved Pack row for use in tests.""" + return Pack.create( + team=sample_team, + pack_type=sample_pack_type, + pack_team=None, + pack_cardset=None, + open_time=None, + ) diff --git a/tests/test_api_packs.py b/tests/test_api_packs.py new file mode 100644 index 0000000..1c50b76 --- /dev/null +++ b/tests/test_api_packs.py @@ -0,0 +1,130 @@ +""" +Integration tests for the /api/v2/packs endpoints. + +Uses FastAPI's TestClient (starlette) with an isolated SQLite test database. +The db_transaction fixture in conftest.py rolls back all changes after each test. +""" + +import pytest +from starlette.testclient import TestClient + +from app.main import app + +TEST_TOKEN = "test_token_12345" +AUTH_HEADERS = {"Authorization": f"Bearer {TEST_TOKEN}"} + + +@pytest.fixture +def client(): + """Return a TestClient for the FastAPI app.""" + return TestClient(app, raise_server_exceptions=True) + + +class TestGetPacks: + """Tests for GET /api/v2/packs.""" + + def test_empty_database_returns_404(self, client): + """With no packs in the DB, the endpoint should return 404.""" + resp = client.get("/api/v2/packs") + assert resp.status_code == 404 + + def test_returns_pack_list(self, client, sample_pack): + """With at least one pack, GET /packs returns count and list.""" + resp = client.get("/api/v2/packs") + assert resp.status_code == 200 + data = resp.json() + assert "count" in data + assert data["count"] >= 1 + assert "packs" in data + assert len(data["packs"]) >= 1 + + def test_filter_by_team_id(self, client, sample_pack, sample_team): + """team_id filter should return only packs for that team.""" + resp = client.get(f"/api/v2/packs?team_id={sample_team.id}") + assert resp.status_code == 200 + data = resp.json() + for pack in data["packs"]: + assert pack["team"]["id"] == sample_team.id + + def test_filter_by_invalid_team_returns_404(self, client, sample_pack): + """Filtering by a non-existent team_id should return 404.""" + resp = client.get("/api/v2/packs?team_id=999999") + assert resp.status_code == 404 + + def test_filter_opened_false(self, client, sample_pack): + """opened=false should return only packs with no open_time.""" + resp = client.get("/api/v2/packs?opened=false") + assert resp.status_code == 200 + data = resp.json() + for pack in data["packs"]: + assert pack["open_time"] is None + + +class TestGetOnePack: + """Tests for GET /api/v2/packs/{pack_id}.""" + + def test_get_existing_pack(self, client, sample_pack): + """GET with a valid pack_id should return that pack's data.""" + resp = client.get(f"/api/v2/packs/{sample_pack.id}") + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == sample_pack.id + + def test_get_nonexistent_pack_returns_404(self, client): + """GET with an unknown pack_id should return 404.""" + resp = client.get("/api/v2/packs/999999") + assert resp.status_code == 404 + + +class TestPostPack: + """Tests for POST /api/v2/packs — requires auth token.""" + + def test_post_without_token_returns_401( + self, client, sample_team, sample_pack_type + ): + """POST without Authorization header should return 422 (missing field).""" + payload = { + "packs": [{"team_id": sample_team.id, "pack_type_id": sample_pack_type.id}] + } + resp = client.post("/api/v2/packs", json=payload) + # FastAPI returns 422 when required OAuth2 token is missing + assert resp.status_code == 422 + + def test_post_with_invalid_token_returns_401( + self, client, sample_team, sample_pack_type + ): + """POST with a wrong token should return 401.""" + payload = { + "packs": [{"team_id": sample_team.id, "pack_type_id": sample_pack_type.id}] + } + resp = client.post( + "/api/v2/packs", + json=payload, + headers={"Authorization": "Bearer wrong_token"}, + ) + assert resp.status_code == 401 + + def test_post_with_valid_token_creates_packs( + self, client, sample_team, sample_pack_type + ): + """POST with valid token and payload should create the packs (returns 200).""" + payload = { + "packs": [{"team_id": sample_team.id, "pack_type_id": sample_pack_type.id}] + } + resp = client.post("/api/v2/packs", json=payload, headers=AUTH_HEADERS) + # The router raises HTTPException(200) on success + assert resp.status_code == 200 + + +class TestDeletePack: + """Tests for DELETE /api/v2/packs/{pack_id} — requires auth token.""" + + def test_delete_nonexistent_pack_returns_404(self, client): + """Deleting a pack that doesn't exist should return 404.""" + resp = client.delete("/api/v2/packs/999999", headers=AUTH_HEADERS) + assert resp.status_code == 404 + + def test_delete_without_token_returns_401_or_422(self, client, sample_pack): + """DELETE without a token should fail auth.""" + resp = client.delete(f"/api/v2/packs/{sample_pack.id}") + assert resp.status_code in (401, 422) diff --git a/tests/test_card_pricing.py b/tests/test_card_pricing.py new file mode 100644 index 0000000..dde6a54 --- /dev/null +++ b/tests/test_card_pricing.py @@ -0,0 +1,116 @@ +""" +Unit tests for card pricing logic in Player model. + +Player.change_on_sell() and Player.change_on_buy() are critical business logic +used whenever cards are traded. These tests verify the price update math and +the floor/ceiling behaviour. +""" + +import math + +import pytest + + +class TestPlayerChangeOnSell: + """Tests for Player.change_on_sell() — price decreases 5% on each sale.""" + + def test_sell_reduces_cost_by_5_percent(self, sample_player): + """Selling a card should reduce its cost to floor(cost * 0.95).""" + sample_player.cost = 100 + sample_player.change_on_sell() + assert sample_player.cost == math.floor(100 * 0.95) # 95 + + def test_sell_saves_to_db(self, sample_player): + """change_on_sell() should persist the new price to the database.""" + from app.db_engine import Player + + sample_player.cost = 200 + sample_player.change_on_sell() + refreshed = Player.get_by_id(sample_player.pk) + assert refreshed.cost == math.floor(200 * 0.95) # 190 + + def test_sell_floors_at_1(self, sample_player): + """Price should never drop below 1, even at very low starting values.""" + sample_player.cost = 1 + sample_player.change_on_sell() + assert sample_player.cost == 1 + + def test_sell_large_price(self, sample_player): + """Large prices should still apply the 5% reduction correctly.""" + sample_player.cost = 10000 + sample_player.change_on_sell() + assert sample_player.cost == math.floor(10000 * 0.95) # 9500 + + def test_sell_rounds_down(self, sample_player): + """floor() means fractional results are rounded down, not up.""" + sample_player.cost = 21 # 21 * 0.95 = 19.95 → floor → 19 + sample_player.change_on_sell() + assert sample_player.cost == 19 + + +class TestPlayerChangeOnBuy: + """Tests for Player.change_on_buy() — price increases 10% on each purchase.""" + + def test_buy_increases_cost_by_10_percent(self, sample_player): + """Buying a card should increase its cost to ceil(cost * 1.1).""" + sample_player.cost = 100 + sample_player.change_on_buy() + assert sample_player.cost == math.ceil(100 * 1.1) # 110 + + def test_buy_saves_to_db(self, sample_player): + """change_on_buy() should persist the new price to the database.""" + from app.db_engine import Player + + sample_player.cost = 200 + sample_player.change_on_buy() + refreshed = Player.get_by_id(sample_player.pk) + assert refreshed.cost == math.ceil(200 * 1.1) # 220 + + def test_buy_from_low_price(self, sample_player): + """Low prices should still apply the 10% increase.""" + sample_player.cost = 1 + sample_player.change_on_buy() + assert sample_player.cost == math.ceil(1 * 1.1) # 2 (ceil of 1.1) + + def test_buy_rounds_up(self, sample_player): + """ceil() means fractional results are rounded up.""" + sample_player.cost = 9 # 9 * 1.1 = 9.9 → ceil → 10 + sample_player.change_on_buy() + assert sample_player.cost == 10 + + def test_buy_large_price(self, sample_player): + """Large prices should still apply the 10% increase correctly.""" + sample_player.cost = 5000 + sample_player.change_on_buy() + assert sample_player.cost == math.ceil(5000 * 1.1) # 5500 + + +class TestPlayerGetAllPos: + """Tests for Player.get_all_pos() — returns non-null, non-CP position list.""" + + def test_returns_primary_position(self, sample_player): + """A player with only pos_1 set should return a list with one position.""" + sample_player.pos_1 = "1B" + assert sample_player.get_all_pos() == ["1B"] + + def test_excludes_cp_position(self, sample_player): + """CP (closing pitcher) is excluded from the position list.""" + sample_player.pos_1 = "SP" + sample_player.pos_2 = "CP" + positions = sample_player.get_all_pos() + assert "CP" not in positions + assert "SP" in positions + + def test_excludes_null_positions(self, sample_player): + """None positions should not appear in the result.""" + sample_player.pos_1 = "CF" + sample_player.pos_2 = None + assert sample_player.get_all_pos() == ["CF"] + + def test_multiple_positions(self, sample_player): + """Players with multiple eligible positions should return all of them.""" + sample_player.pos_1 = "1B" + sample_player.pos_2 = "OF" + sample_player.pos_3 = "DH" + positions = sample_player.get_all_pos() + assert positions == ["1B", "OF", "DH"] diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py new file mode 100644 index 0000000..a825ac3 --- /dev/null +++ b/tests/test_dependencies.py @@ -0,0 +1,96 @@ +""" +Unit tests for app/dependencies.py utility functions. + +Tests pure functions that have no database or HTTP dependencies. +These run fast and verify foundational logic used across all API routes. +""" + +import os + +import pytest + +from app.dependencies import valid_token, mround, param_char, get_req_url + + +class TestValidToken: + """Tests for valid_token() — verifies API bearer token auth.""" + + def test_matching_token_returns_true(self): + """valid_token should return True when the token matches API_TOKEN env var.""" + expected = os.environ.get("API_TOKEN", "") + assert valid_token(expected) is True + + def test_wrong_token_returns_false(self): + """valid_token should return False for any non-matching string.""" + assert valid_token("wrong_token") is False + + def test_empty_token_returns_false(self): + """valid_token should return False for an empty string.""" + assert valid_token("") is False + + +class TestMround: + """Tests for mround() — rounds a float to the nearest multiple of `base`.""" + + def test_rounds_to_nearest_0_05(self): + """Default base=0.05 should round 0.06 to 0.05.""" + assert mround(0.06) == 0.05 + + def test_rounds_up_to_nearest_0_05(self): + """0.08 should round up to 0.10 with base=0.05.""" + assert mround(0.08) == 0.10 + + def test_exact_multiple_unchanged(self): + """A value that is already a multiple of base should be unchanged.""" + assert mround(0.25) == 0.25 + + def test_custom_base(self): + """Custom base=0.25 should round to nearest quarter.""" + assert mround(0.3, base=0.25) == 0.25 + + def test_custom_precision(self): + """Custom prec=4 should return more decimal places.""" + result = mround(0.123456, prec=4, base=0.01) + assert result == 0.12 + + +class TestParamChar: + """Tests for param_char() — returns ? or & for URL query string building.""" + + def test_returns_question_mark_when_no_other_params(self): + """First parameter in a URL should use ?.""" + assert param_char(False) == "?" + + def test_returns_ampersand_when_other_params_exist(self): + """Subsequent parameters should use &.""" + assert param_char(True) == "&" + + +class TestGetReqUrl: + """Tests for get_req_url() — builds API URLs with optional params.""" + + def test_basic_endpoint(self): + """Endpoint with no object_id or params should produce a clean URL.""" + url = get_req_url("players") + assert url.endswith("/v2/players") + + def test_endpoint_with_object_id(self): + """object_id should be appended to the URL path.""" + url = get_req_url("players", object_id=42) + assert url.endswith("/v2/players/42") + + def test_endpoint_with_params(self): + """Params list should be appended as query string.""" + url = get_req_url("players", params=[("season", "1")]) + assert "?season=1" in url + + def test_endpoint_with_multiple_params(self): + """Multiple params should be joined with &.""" + url = get_req_url("players", params=[("season", "1"), ("limit", "10")]) + assert "?season=1" in url + assert "&limit=10" in url + + def test_api_version_override(self): + """api_ver parameter controls the version segment.""" + url = get_req_url("players", api_ver=1) + assert "/v1/players" in url