Compare commits

..

1 Commits

Author SHA1 Message Date
Cal Corum
d3b9355f26 fix: batch standings updates to eliminate N+1 queries in recalculate (#75)
All checks were successful
Build Docker Image / build (pull_request) Successful in 2m54s
Replace per-game update_standings() calls with pre-fetched dicts and
in-memory accumulation, then a single bulk_update at the end.
Reduces ~1,100+ queries for a full season to ~5 queries.

Closes #75

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 03:37:45 -05:00
29 changed files with 796 additions and 1316 deletions

File diff suppressed because it is too large Load Diff

View File

@ -57,9 +57,6 @@ priv_help = (
)
PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False
MAX_LIMIT = 500
DEFAULT_LIMIT = 200
def valid_token(token):
return token == os.environ.get("API_TOKEN")
@ -379,14 +376,14 @@ def update_season_pitching_stats(player_ids, season, db_connection):
-- RBI allowed (excluding HR) per runner opportunity
CASE
WHEN (SUM(CASE WHEN sp.on_first_id IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second_id IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third_id IS NOT NULL THEN 1 ELSE 0 END)) > 0
WHEN (SUM(CASE WHEN sp.on_first IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third IS NOT NULL THEN 1 ELSE 0 END)) > 0
THEN ROUND(
(SUM(sp.rbi) - SUM(sp.homerun))::DECIMAL /
(SUM(CASE WHEN sp.on_first_id IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second_id IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third_id IS NOT NULL THEN 1 ELSE 0 END)),
(SUM(CASE WHEN sp.on_first IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third IS NOT NULL THEN 1 ELSE 0 END)),
3
)
ELSE 0.000
@ -807,10 +804,6 @@ def handle_db_errors(func):
return result
except HTTPException:
# Let intentional HTTP errors (401, 404, etc.) pass through unchanged
raise
except Exception as e:
elapsed_time = time.time() - start_time
error_trace = traceback.format_exc()

View File

@ -2,7 +2,6 @@ import datetime
import logging
from logging.handlers import RotatingFileHandler
import os
from urllib.parse import parse_qsl, urlencode
from fastapi import Depends, FastAPI, Request
from fastapi.openapi.docs import get_swagger_ui_html
@ -71,19 +70,6 @@ app = FastAPI(
logger.info(f"Starting up now...")
@app.middleware("http")
async def strip_empty_query_params(request: Request, call_next):
qs = request.scope.get("query_string", b"")
if qs:
pairs = parse_qsl(qs.decode(), keep_blank_values=True)
filtered = [(k, v) for k, v in pairs if v != ""]
new_qs = urlencode(filtered).encode()
request.scope["query_string"] = new_qs
if hasattr(request, "_query_params"):
del request._query_params
return await call_next(request)
app.include_router(current.router)
app.include_router(players.router)
app.include_router(results.router)

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -45,8 +43,6 @@ async def get_awards(
team_id: list = Query(default=None),
short_output: Optional[bool] = False,
player_name: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_awards = Award.select()
@ -71,11 +67,8 @@ async def get_awards(
all_players = Player.select().where(fn.Lower(Player.name) << pname_list)
all_awards = all_awards.where(Award.player << all_players)
total_count = all_awards.count()
all_awards = all_awards.offset(offset).limit(limit)
return_awards = {
"count": total_count,
"count": all_awards.count(),
"awards": [model_to_dict(x, recurse=not short_output) for x in all_awards],
}
db.close()

View File

@ -19,8 +19,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -86,7 +84,7 @@ async def get_batstats(
week_end: Optional[int] = None,
game_num: list = Query(default=None),
position: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = None,
sort: Optional[str] = None,
short_output: Optional[bool] = True,
):
@ -136,7 +134,8 @@ async def get_batstats(
)
all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end))
all_stats = all_stats.limit(limit)
if limit:
all_stats = all_stats.limit(limit)
if sort:
if sort == "newest":
all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game)
@ -169,8 +168,6 @@ async def get_totalstats(
short_output: Optional[bool] = False,
min_pa: Optional[int] = 1,
week: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1:
raise HTTPException(
@ -304,10 +301,7 @@ async def get_totalstats(
all_players = Player.select().where(Player.id << player_id)
all_stats = all_stats.where(BattingStat.player << all_players)
total_count = all_stats.count()
all_stats = all_stats.offset(offset).limit(limit)
return_stats = {"count": total_count, "stats": []}
return_stats = {"count": all_stats.count(), "stats": []}
for x in all_stats:
# Handle player field based on grouping with safe access

View File

@ -19,8 +19,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -75,7 +73,7 @@ async def get_decisions(
irunners_scored: list = Query(default=None),
game_id: list = Query(default=None),
player_id: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = None,
short_output: Optional[bool] = False,
):
all_dec = Decision.select().order_by(
@ -137,7 +135,10 @@ async def get_decisions(
if irunners_scored is not None:
all_dec = all_dec.where(Decision.irunners_scored << irunners_scored)
all_dec = all_dec.limit(limit)
if limit is not None:
if limit < 1:
limit = 1
all_dec = all_dec.limit(limit)
return_dec = {
"count": all_dec.count(),

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -34,8 +32,6 @@ async def get_divisions(
div_abbrev: Optional[str] = None,
lg_name: Optional[str] = None,
lg_abbrev: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_divisions = Division.select().where(Division.season == season)
@ -48,11 +44,8 @@ async def get_divisions(
if lg_abbrev is not None:
all_divisions = all_divisions.where(Division.league_abbrev == lg_abbrev)
total_count = all_divisions.count()
all_divisions = all_divisions.offset(offset).limit(limit)
return_div = {
"count": total_count,
"count": all_divisions.count(),
"divisions": [model_to_dict(x) for x in all_divisions],
}
db.close()

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -36,8 +34,6 @@ async def get_draftlist(
season: Optional[int],
team_id: list = Query(default=None),
token: str = Depends(oauth2_scheme),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
if not valid_token(token):
logger.warning(f"get_draftlist - Bad Token: {token}")
@ -50,10 +46,7 @@ async def get_draftlist(
if team_id is not None:
all_list = all_list.where(DraftList.team_id << team_id)
total_count = all_list.count()
all_list = all_list.offset(offset).limit(limit)
r_list = {"count": total_count, "picks": [model_to_dict(x) for x in all_list]}
r_list = {"count": all_list.count(), "picks": [model_to_dict(x) for x in all_list]}
db.close()
return r_list

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -52,7 +50,7 @@ async def get_picks(
overall_end: Optional[int] = None,
short_output: Optional[bool] = False,
sort: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = None,
player_id: list = Query(default=None),
player_taken: Optional[bool] = None,
):
@ -112,7 +110,8 @@ async def get_picks(
all_picks = all_picks.where(DraftPick.overall <= overall_end)
if player_taken is not None:
all_picks = all_picks.where(DraftPick.player.is_null(not player_taken))
all_picks = all_picks.limit(limit)
if limit is not None:
all_picks = all_picks.limit(limit)
if sort is not None:
if sort == "order-asc":

View File

@ -3,61 +3,40 @@ from typing import List, Optional, Literal
import logging
import pydantic
from ..db_engine import (
db,
BattingStat,
Team,
Player,
Current,
model_to_dict,
chunked,
fn,
per_season_weeks,
)
from ..dependencies import (
oauth2_scheme,
valid_token,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
from ..db_engine import db, BattingStat, Team, Player, Current, model_to_dict, chunked, fn, per_season_weeks
from ..dependencies import oauth2_scheme, valid_token, handle_db_errors
logger = logging.getLogger('discord_app')
router = APIRouter(
prefix='/api/v3/fieldingstats',
tags=['fieldingstats']
)
logger = logging.getLogger("discord_app")
router = APIRouter(prefix="/api/v3/fieldingstats", tags=["fieldingstats"])
@router.get("")
@router.get('')
@handle_db_errors
async def get_fieldingstats(
season: int,
s_type: Optional[str] = "regular",
team_abbrev: list = Query(default=None),
player_name: list = Query(default=None),
player_id: list = Query(default=None),
week_start: Optional[int] = None,
week_end: Optional[int] = None,
game_num: list = Query(default=None),
position: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
sort: Optional[str] = None,
short_output: Optional[bool] = True,
):
if "post" in s_type.lower():
season: int, s_type: Optional[str] = 'regular', team_abbrev: list = Query(default=None),
player_name: list = Query(default=None), player_id: list = Query(default=None),
week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None),
position: list = Query(default=None), limit: Optional[int] = None, sort: Optional[str] = None,
short_output: Optional[bool] = True):
if 'post' in s_type.lower():
all_stats = BattingStat.post_season(season)
if all_stats.count() == 0:
db.close()
return {"count": 0, "stats": []}
elif s_type.lower() in ["combined", "total", "all"]:
return {'count': 0, 'stats': []}
elif s_type.lower() in ['combined', 'total', 'all']:
all_stats = BattingStat.combined_season(season)
if all_stats.count() == 0:
db.close()
return {"count": 0, "stats": []}
return {'count': 0, 'stats': []}
else:
all_stats = BattingStat.regular_season(season)
if all_stats.count() == 0:
db.close()
return {"count": 0, "stats": []}
return {'count': 0, 'stats': []}
all_stats = all_stats.where(
(BattingStat.xch > 0) | (BattingStat.pb > 0) | (BattingStat.sbc > 0)
@ -72,9 +51,7 @@ async def get_fieldingstats(
if player_id:
all_stats = all_stats.where(BattingStat.player_id << player_id)
else:
p_query = Player.select_season(season).where(
fn.Lower(Player.name) << [x.lower() for x in player_name]
)
p_query = Player.select_season(season).where(fn.Lower(Player.name) << [x.lower() for x in player_name])
all_stats = all_stats.where(BattingStat.player << p_query)
if game_num:
all_stats = all_stats.where(BattingStat.game == game_num)
@ -89,92 +66,72 @@ async def get_fieldingstats(
db.close()
raise HTTPException(
status_code=404,
detail=f"Start week {start} is after end week {end} - cannot pull stats",
detail=f'Start week {start} is after end week {end} - cannot pull stats'
)
all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end))
all_stats = all_stats.where(
(BattingStat.week >= start) & (BattingStat.week <= end)
)
total_count = all_stats.count()
all_stats = all_stats.limit(limit)
if limit:
all_stats = all_stats.limit(limit)
if sort:
if sort == "newest":
if sort == 'newest':
all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game)
return_stats = {
"count": total_count,
"stats": [
{
"player": x.player_id
if short_output
else model_to_dict(x.player, recurse=False),
"team": x.team_id
if short_output
else model_to_dict(x.team, recurse=False),
"pos": x.pos,
"xch": x.xch,
"xhit": x.xhit,
"error": x.error,
"pb": x.pb,
"sbc": x.sbc,
"csc": x.csc,
"week": x.week,
"game": x.game,
"season": x.season,
}
for x in all_stats
],
'count': all_stats.count(),
'stats': [{
'player': x.player_id if short_output else model_to_dict(x.player, recurse=False),
'team': x.team_id if short_output else model_to_dict(x.team, recurse=False),
'pos': x.pos,
'xch': x.xch,
'xhit': x.xhit,
'error': x.error,
'pb': x.pb,
'sbc': x.sbc,
'csc': x.csc,
'week': x.week,
'game': x.game,
'season': x.season
} for x in all_stats]
}
db.close()
return return_stats
@router.get("/totals")
@router.get('/totals')
@handle_db_errors
async def get_totalstats(
season: int,
s_type: Literal["regular", "post", "total", None] = None,
team_abbrev: list = Query(default=None),
team_id: list = Query(default=None),
player_name: list = Query(default=None),
week_start: Optional[int] = None,
week_end: Optional[int] = None,
game_num: list = Query(default=None),
position: list = Query(default=None),
sort: Optional[str] = None,
player_id: list = Query(default=None),
group_by: Literal["team", "player", "playerteam"] = "player",
short_output: Optional[bool] = False,
min_ch: Optional[int] = 1,
week: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
season: int, s_type: Literal['regular', 'post', 'total', None] = None, team_abbrev: list = Query(default=None),
team_id: list = Query(default=None), player_name: list = Query(default=None),
week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None),
position: list = Query(default=None), sort: Optional[str] = None, player_id: list = Query(default=None),
group_by: Literal['team', 'player', 'playerteam'] = 'player', short_output: Optional[bool] = False,
min_ch: Optional[int] = 1, week: list = Query(default=None)):
# Build SELECT fields conditionally based on group_by to match GROUP BY exactly
select_fields = []
if group_by == "player":
if group_by == 'player':
select_fields = [BattingStat.player, BattingStat.pos]
elif group_by == "team":
elif group_by == 'team':
select_fields = [BattingStat.team, BattingStat.pos]
elif group_by == "playerteam":
elif group_by == 'playerteam':
select_fields = [BattingStat.player, BattingStat.team, BattingStat.pos]
else:
# Default case
select_fields = [BattingStat.player, BattingStat.pos]
all_stats = (
BattingStat.select(
*select_fields,
fn.SUM(BattingStat.xch).alias("sum_xch"),
fn.SUM(BattingStat.xhit).alias("sum_xhit"),
fn.SUM(BattingStat.error).alias("sum_error"),
fn.SUM(BattingStat.pb).alias("sum_pb"),
fn.SUM(BattingStat.sbc).alias("sum_sbc"),
fn.SUM(BattingStat.csc).alias("sum_csc"),
)
.where(BattingStat.season == season)
.having(fn.SUM(BattingStat.xch) >= min_ch)
BattingStat
.select(*select_fields,
fn.SUM(BattingStat.xch).alias('sum_xch'),
fn.SUM(BattingStat.xhit).alias('sum_xhit'), fn.SUM(BattingStat.error).alias('sum_error'),
fn.SUM(BattingStat.pb).alias('sum_pb'), fn.SUM(BattingStat.sbc).alias('sum_sbc'),
fn.SUM(BattingStat.csc).alias('sum_csc'))
.where(BattingStat.season == season)
.having(fn.SUM(BattingStat.xch) >= min_ch)
)
if True in [s_type is not None, week_start is not None, week_end is not None]:
@ -184,20 +141,16 @@ async def get_totalstats(
elif week_start is not None or week_end is not None:
if week_start is None or week_end is None:
raise HTTPException(
status_code=400,
detail="Both week_start and week_end must be included if either is used.",
)
weeks["start"] = week_start
if week_end < weeks["start"]:
raise HTTPException(
status_code=400,
detail="week_end must be greater than or equal to week_start",
status_code=400, detail='Both week_start and week_end must be included if either is used.'
)
weeks['start'] = week_start
if week_end < weeks['start']:
raise HTTPException(status_code=400, detail='week_end must be greater than or equal to week_start')
else:
weeks["end"] = week_end
weeks['end'] = week_end
all_stats = all_stats.where(
(BattingStat.week >= weeks["start"]) & (BattingStat.week <= weeks["end"])
(BattingStat.week >= weeks['start']) & (BattingStat.week <= weeks['end'])
)
elif week is not None:
@ -208,20 +161,14 @@ async def get_totalstats(
if position is not None:
p_list = [x.upper() for x in position]
all_players = Player.select().where(
(Player.pos_1 << p_list)
| (Player.pos_2 << p_list)
| (Player.pos_3 << p_list)
| (Player.pos_4 << p_list)
| (Player.pos_5 << p_list)
| (Player.pos_6 << p_list)
| (Player.pos_7 << p_list)
| (Player.pos_8 << p_list)
(Player.pos_1 << p_list) | (Player.pos_2 << p_list) | (Player.pos_3 << p_list) | (Player.pos_4 << p_list) |
(Player.pos_5 << p_list) | (Player.pos_6 << p_list) | (Player.pos_7 << p_list) | (Player.pos_8 << p_list)
)
all_stats = all_stats.where(BattingStat.player << all_players)
if sort is not None:
if sort == "player":
if sort == 'player':
all_stats = all_stats.order_by(BattingStat.player)
elif sort == "team":
elif sort == 'team':
all_stats = all_stats.order_by(BattingStat.team)
if group_by is not None:
# Use the same fields for GROUP BY as we used for SELECT
@ -230,57 +177,47 @@ async def get_totalstats(
all_teams = Team.select().where(Team.id << team_id)
all_stats = all_stats.where(BattingStat.team << all_teams)
elif team_abbrev is not None:
all_teams = Team.select().where(
fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev]
)
all_teams = Team.select().where(fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev])
all_stats = all_stats.where(BattingStat.team << all_teams)
if player_name is not None:
all_players = Player.select().where(
fn.Lower(Player.name) << [x.lower() for x in player_name]
)
all_players = Player.select().where(fn.Lower(Player.name) << [x.lower() for x in player_name])
all_stats = all_stats.where(BattingStat.player << all_players)
elif player_id is not None:
all_players = Player.select().where(Player.id << player_id)
all_stats = all_stats.where(BattingStat.player << all_players)
total_count = all_stats.count()
all_stats = all_stats.offset(offset).limit(limit)
return_stats = {"count": total_count, "stats": []}
return_stats = {
'count': 0,
'stats': []
}
for x in all_stats:
if x.sum_xch + x.sum_sbc <= 0:
continue
# Handle player field based on grouping with safe access
this_player = "TOT"
if "player" in group_by and hasattr(x, "player"):
this_player = (
x.player_id if short_output else model_to_dict(x.player, recurse=False)
)
this_player = 'TOT'
if 'player' in group_by and hasattr(x, 'player'):
this_player = x.player_id if short_output else model_to_dict(x.player, recurse=False)
# Handle team field based on grouping with safe access
this_team = "TOT"
if "team" in group_by and hasattr(x, "team"):
this_team = (
x.team_id if short_output else model_to_dict(x.team, recurse=False)
)
return_stats["stats"].append(
{
"player": this_player,
"team": this_team,
"pos": x.pos,
"xch": x.sum_xch,
"xhit": x.sum_xhit,
"error": x.sum_error,
"pb": x.sum_pb,
"sbc": x.sum_sbc,
"csc": x.sum_csc,
}
)
return_stats["count"] = len(return_stats["stats"])
# Handle team field based on grouping with safe access
this_team = 'TOT'
if 'team' in group_by and hasattr(x, 'team'):
this_team = x.team_id if short_output else model_to_dict(x.team, recurse=False)
return_stats['stats'].append({
'player': this_player,
'team': this_team,
'pos': x.pos,
'xch': x.sum_xch,
'xhit': x.sum_xhit,
'error': x.sum_error,
'pb': x.sum_pb,
'sbc': x.sum_sbc,
'csc': x.sum_csc
})
return_stats['count'] = len(return_stats['stats'])
db.close()
return return_stats

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -40,8 +38,6 @@ async def get_injuries(
is_active: bool = None,
short_output: bool = False,
sort: Optional[str] = "start-asc",
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_injuries = Injury.select()
@ -68,11 +64,8 @@ async def get_injuries(
elif sort == "start-desc":
all_injuries = all_injuries.order_by(-Injury.start_week, -Injury.start_game)
total_count = all_injuries.count()
all_injuries = all_injuries.offset(offset).limit(limit)
return_injuries = {
"count": total_count,
"count": all_injuries.count(),
"injuries": [model_to_dict(x, recurse=not short_output) for x in all_injuries],
}
db.close()

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -36,8 +34,6 @@ async def get_keepers(
team_id: list = Query(default=None),
player_id: list = Query(default=None),
short_output: bool = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_keepers = Keeper.select()
@ -48,11 +44,8 @@ async def get_keepers(
if player_id is not None:
all_keepers = all_keepers.where(Keeper.player_id << player_id)
total_count = all_keepers.count()
all_keepers = all_keepers.offset(offset).limit(limit)
return_keepers = {
"count": total_count,
"count": all_keepers.count(),
"keepers": [model_to_dict(x, recurse=not short_output) for x in all_keepers],
}
db.close()

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -31,8 +29,6 @@ async def get_managers(
name: list = Query(default=None),
active: Optional[bool] = None,
short_output: Optional[bool] = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
if active is not None:
current = Current.latest()
@ -65,9 +61,7 @@ async def get_managers(
i_mgr.append(z)
final_mgrs = [model_to_dict(y, recurse=not short_output) for y in i_mgr]
total_count = len(final_mgrs)
final_mgrs = final_mgrs[offset : offset + limit]
return_managers = {"count": total_count, "managers": final_mgrs}
return_managers = {"count": len(final_mgrs), "managers": final_mgrs}
else:
all_managers = Manager.select()
@ -75,10 +69,8 @@ async def get_managers(
name_list = [x.lower() for x in name]
all_managers = all_managers.where(fn.Lower(Manager.name) << name_list)
total_count = all_managers.count()
all_managers = all_managers.offset(offset).limit(limit)
return_managers = {
"count": total_count,
"count": all_managers.count(),
"managers": [
model_to_dict(x, recurse=not short_output) for x in all_managers
],

View File

@ -19,8 +19,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -70,7 +68,7 @@ async def get_pitstats(
week_start: Optional[int] = None,
week_end: Optional[int] = None,
game_num: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = None,
ip_min: Optional[float] = None,
sort: Optional[str] = None,
short_output: Optional[bool] = True,
@ -123,7 +121,8 @@ async def get_pitstats(
(PitchingStat.week >= start) & (PitchingStat.week <= end)
)
all_stats = all_stats.limit(limit)
if limit:
all_stats = all_stats.limit(limit)
if sort:
if sort == "newest":
all_stats = all_stats.order_by(-PitchingStat.week, -PitchingStat.game)
@ -155,8 +154,6 @@ async def get_totalstats(
short_output: Optional[bool] = False,
group_by: Literal["team", "player", "playerteam"] = "player",
week: list = Query(default=None),
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1:
raise HTTPException(
@ -262,10 +259,7 @@ async def get_totalstats(
all_players = Player.select().where(Player.id << player_id)
all_stats = all_stats.where(PitchingStat.player << all_players)
total_count = all_stats.count()
all_stats = all_stats.offset(offset).limit(limit)
return_stats = {"count": total_count, "stats": []}
return_stats = {"count": all_stats.count(), "stats": []}
for x in all_stats:
# Handle player field based on grouping with safe access

View File

@ -6,11 +6,7 @@ Thin HTTP layer using PlayerService for business logic.
from fastapi import APIRouter, Query, Response, Depends
from typing import Optional, List
from ..dependencies import (
oauth2_scheme,
cache_result,
handle_db_errors,
)
from ..dependencies import oauth2_scheme, cache_result, handle_db_errors
from ..services.base import BaseService
from ..services.player_service import PlayerService
@ -28,7 +24,9 @@ async def get_players(
strat_code: list = Query(default=None),
is_injured: Optional[bool] = None,
sort: Optional[str] = None,
limit: Optional[int] = Query(default=None, ge=1),
limit: Optional[int] = Query(
default=None, ge=1, description="Maximum number of results to return"
),
offset: Optional[int] = Query(
default=None, ge=0, description="Number of results to skip for pagination"
),

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -44,8 +42,6 @@ async def get_results(
away_abbrev: list = Query(default=None),
home_abbrev: list = Query(default=None),
short_output: Optional[bool] = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_results = Result.select_season(season)
@ -78,11 +74,8 @@ async def get_results(
if week_end is not None:
all_results = all_results.where(Result.week <= week_end)
total_count = all_results.count()
all_results = all_results.offset(offset).limit(limit)
return_results = {
"count": total_count,
"count": all_results.count(),
"results": [model_to_dict(x, recurse=not short_output) for x in all_results],
}
db.close()

View File

@ -12,8 +12,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -46,8 +44,6 @@ async def get_players(
key_mlbam: list = Query(default=None),
sort: Optional[str] = None,
csv: Optional[bool] = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_players = SbaPlayer.select()
@ -105,11 +101,8 @@ async def get_players(
db.close()
return Response(content=return_val, media_type="text/csv")
total_count = all_players.count()
all_players = all_players.offset(offset).limit(limit)
return_val = {
"count": total_count,
"count": all_players.count(),
"players": [model_to_dict(x) for x in all_players],
}
db.close()

View File

@ -9,8 +9,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -40,8 +38,6 @@ async def get_schedules(
week_start: Optional[int] = None,
week_end: Optional[int] = None,
short_output: Optional[bool] = True,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
all_sched = Schedule.select_season(season)
@ -73,11 +69,8 @@ async def get_schedules(
all_sched = all_sched.order_by(Schedule.id)
total_count = all_sched.count()
all_sched = all_sched.offset(offset).limit(limit)
return_sched = {
"count": total_count,
"count": all_sched.count(),
"schedules": [model_to_dict(x, recurse=not short_output) for x in all_sched],
}
db.close()

View File

@ -8,8 +8,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -25,8 +23,6 @@ async def get_standings(
league_abbrev: Optional[str] = None,
division_abbrev: Optional[str] = None,
short_output: Optional[bool] = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
standings = Standings.select_season(season)
@ -61,11 +57,8 @@ async def get_standings(
div_teams = [x for x in standings]
div_teams.sort(key=lambda team: win_pct(team), reverse=True)
total_count = len(div_teams)
div_teams = div_teams[offset : offset + limit]
return_standings = {
"count": total_count,
"count": len(div_teams),
"standings": [model_to_dict(x, recurse=not short_output) for x in div_teams],
}

View File

@ -13,8 +13,6 @@ from ..dependencies import (
PRIVATE_IN_SCHEMA,
handle_db_errors,
update_season_batting_stats,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -61,8 +59,6 @@ async def get_games(
division_id: Optional[int] = None,
short_output: Optional[bool] = False,
sort: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
) -> Any:
all_games = StratGame.select()
@ -123,11 +119,8 @@ async def get_games(
StratGame.season, StratGame.week, StratGame.game_num
)
total_count = all_games.count()
all_games = all_games.offset(offset).limit(limit)
return_games = {
"count": total_count,
"count": all_games.count(),
"games": [model_to_dict(x, recurse=not short_output) for x in all_games],
}
db.close()

View File

@ -13,13 +13,7 @@ from ...db_engine import (
fn,
model_to_dict,
)
from ...dependencies import (
add_cache_headers,
cache_result,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
from ...dependencies import add_cache_headers, cache_result, handle_db_errors
from .common import build_season_games
router = APIRouter()
@ -58,7 +52,7 @@ async def get_batting_totals(
risp: Optional[bool] = None,
inning: list = Query(default=None),
sort: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = 200,
short_output: Optional[bool] = False,
page_num: Optional[int] = 1,
week_start: Optional[int] = None,
@ -429,6 +423,8 @@ async def get_batting_totals(
run_plays = run_plays.order_by(StratPlay.game.asc())
# For other group_by values, skip game_id/play_num sorting since they're not in GROUP BY
if limit < 1:
limit = 1
bat_plays = bat_plays.paginate(page_num, limit)
logger.info(f"bat_plays query: {bat_plays}")

View File

@ -13,13 +13,7 @@ from ...db_engine import (
fn,
SQL,
)
from ...dependencies import (
handle_db_errors,
add_cache_headers,
cache_result,
MAX_LIMIT,
DEFAULT_LIMIT,
)
from ...dependencies import handle_db_errors, add_cache_headers, cache_result
from .common import build_season_games
logger = logging.getLogger("discord_app")
@ -57,7 +51,7 @@ async def get_fielding_totals(
team_id: list = Query(default=None),
manager_id: list = Query(default=None),
sort: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = 200,
short_output: Optional[bool] = False,
page_num: Optional[int] = 1,
):
@ -243,6 +237,8 @@ async def get_fielding_totals(
def_plays = def_plays.order_by(StratPlay.game.asc())
# For other group_by values, skip game_id/play_num sorting since they're not in GROUP BY
if limit < 1:
limit = 1
def_plays = def_plays.paginate(page_num, limit)
logger.info(f"def_plays query: {def_plays}")

View File

@ -16,13 +16,7 @@ from ...db_engine import (
SQL,
complex_data_to_csv,
)
from ...dependencies import (
handle_db_errors,
add_cache_headers,
cache_result,
MAX_LIMIT,
DEFAULT_LIMIT,
)
from ...dependencies import handle_db_errors, add_cache_headers, cache_result
from .common import build_season_games
router = APIRouter()
@ -57,7 +51,7 @@ async def get_pitching_totals(
risp: Optional[bool] = None,
inning: list = Query(default=None),
sort: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = 200,
short_output: Optional[bool] = False,
csv: Optional[bool] = False,
page_num: Optional[int] = 1,
@ -170,6 +164,8 @@ async def get_pitching_totals(
if group_by in ["playergame", "teamgame"]:
pitch_plays = pitch_plays.order_by(StratPlay.game.asc())
if limit < 1:
limit = 1
pitch_plays = pitch_plays.paginate(page_num, limit)
# Execute the Peewee query

View File

@ -16,8 +16,6 @@ from ...dependencies import (
handle_db_errors,
add_cache_headers,
cache_result,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -72,7 +70,7 @@ async def get_plays(
pitcher_team_id: list = Query(default=None),
short_output: Optional[bool] = False,
sort: Optional[str] = None,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = 200,
page_num: Optional[int] = 1,
s_type: Literal["regular", "post", "total", None] = None,
):
@ -187,6 +185,8 @@ async def get_plays(
season_games = season_games.where(StratGame.week > 18)
all_plays = all_plays.where(StratPlay.game << season_games)
if limit < 1:
limit = 1
bat_plays = all_plays.paginate(page_num, limit)
if sort == "wpa-desc":

View File

@ -11,8 +11,6 @@ from ..dependencies import (
PRIVATE_IN_SCHEMA,
handle_db_errors,
cache_result,
MAX_LIMIT,
DEFAULT_LIMIT,
)
from ..services.base import BaseService
from ..services.team_service import TeamService

View File

@ -10,8 +10,6 @@ from ..dependencies import (
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -38,7 +36,7 @@ class TransactionList(pydantic.BaseModel):
@router.get("")
@handle_db_errors
async def get_transactions(
season: int,
season,
team_abbrev: list = Query(default=None),
week_start: Optional[int] = 0,
week_end: Optional[int] = None,
@ -47,9 +45,8 @@ async def get_transactions(
player_name: list = Query(default=None),
player_id: list = Query(default=None),
move_id: Optional[str] = None,
is_trade: Optional[bool] = None,
short_output: Optional[bool] = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
):
if season:
transactions = Transaction.select_season(season)
@ -87,15 +84,15 @@ async def get_transactions(
else:
transactions = transactions.where(Transaction.frozen == 0)
if is_trade is not None:
raise HTTPException(
status_code=501, detail="The is_trade parameter is not implemented, yet"
)
transactions = transactions.order_by(-Transaction.week, Transaction.moveid)
total_count = transactions.count()
transactions = transactions.offset(offset).limit(limit)
return_trans = {
"count": total_count,
"limit": limit,
"offset": offset,
"count": transactions.count(),
"transactions": [
model_to_dict(x, recurse=not short_output) for x in transactions
],

View File

@ -26,8 +26,6 @@ from ..dependencies import (
update_season_batting_stats,
update_season_pitching_stats,
get_cache_stats,
MAX_LIMIT,
DEFAULT_LIMIT,
)
logger = logging.getLogger("discord_app")
@ -74,7 +72,7 @@ async def get_season_batting_stats(
"cs",
] = "woba", # Sort field
sort_order: Literal["asc", "desc"] = "desc", # asc or desc
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = 200,
offset: int = 0,
csv: Optional[bool] = False,
):
@ -220,7 +218,7 @@ async def get_season_pitching_stats(
"re24",
] = "era", # Sort field
sort_order: Literal["asc", "desc"] = "asc", # asc or desc (asc default for ERA)
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
limit: Optional[int] = 200,
offset: int = 0,
csv: Optional[bool] = False,
):

View File

@ -81,9 +81,9 @@ class TestRouteRegistration:
for route, methods in EXPECTED_PLAY_ROUTES.items():
assert route in paths, f"Route {route} missing from OpenAPI schema"
for method in methods:
assert method in paths[route], (
f"Method {method.upper()} missing for {route}"
)
assert (
method in paths[route]
), f"Method {method.upper()} missing for {route}"
def test_play_routes_have_plays_tag(self, api):
"""All play routes should be tagged with 'plays'."""
@ -96,9 +96,9 @@ class TestRouteRegistration:
for method, spec in paths[route].items():
if method in ("get", "post", "patch", "delete"):
tags = spec.get("tags", [])
assert "plays" in tags, (
f"{method.upper()} {route} missing 'plays' tag, has {tags}"
)
assert (
"plays" in tags
), f"{method.upper()} {route} missing 'plays' tag, has {tags}"
@pytest.mark.post_deploy
@pytest.mark.skip(
@ -124,9 +124,9 @@ class TestRouteRegistration:
]:
params = paths[route]["get"].get("parameters", [])
param_names = [p["name"] for p in params]
assert "sbaplayer_id" in param_names, (
f"sbaplayer_id parameter missing from {route}"
)
assert (
"sbaplayer_id" in param_names
), f"sbaplayer_id parameter missing from {route}"
# ---------------------------------------------------------------------------
@ -493,9 +493,10 @@ class TestPlayCrud:
assert result["id"] == play_id
def test_get_nonexistent_play(self, api):
"""GET /plays/999999999 returns 404 Not Found."""
"""GET /plays/999999999 returns an error (wrapped by handle_db_errors)."""
r = requests.get(f"{api}/api/v3/plays/999999999", timeout=10)
assert r.status_code == 404
# handle_db_errors wraps HTTPException as 500 with detail message
assert r.status_code == 500
assert "not found" in r.json().get("detail", "").lower()
@ -574,9 +575,9 @@ class TestGroupBySbaPlayer:
)
assert r_seasons.status_code == 200
season_pas = [s["pa"] for s in r_seasons.json()["stats"]]
assert career_pa >= max(season_pas), (
f"Career PA ({career_pa}) should be >= max season PA ({max(season_pas)})"
)
assert career_pa >= max(
season_pas
), f"Career PA ({career_pa}) should be >= max season PA ({max(season_pas)})"
@pytest.mark.post_deploy
def test_batting_sbaplayer_short_output(self, api):

View File

@ -1,154 +0,0 @@
"""
Tests for query limit/offset parameter validation and middleware behavior.
Verifies that:
- FastAPI enforces MAX_LIMIT cap (returns 422 for limit > 500)
- FastAPI enforces ge=1 on limit (returns 422 for limit=0 or limit=-1)
- Transactions endpoint returns limit/offset keys in the response
- strip_empty_query_params middleware treats ?param= as absent
These tests exercise FastAPI parameter validation which fires before any
handler code runs, so most tests don't require a live DB connection.
The app imports redis and psycopg2 at module level, so we mock those
system-level packages before importing app.main.
"""
import sys
import pytest
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# Stub out C-extension / system packages that aren't installed in the test
# environment before any app code is imported.
# ---------------------------------------------------------------------------
_redis_stub = MagicMock()
_redis_stub.Redis = MagicMock(return_value=MagicMock(ping=MagicMock(return_value=True)))
sys.modules.setdefault("redis", _redis_stub)
_psycopg2_stub = MagicMock()
sys.modules.setdefault("psycopg2", _psycopg2_stub)
_playhouse_pool_stub = MagicMock()
sys.modules.setdefault("playhouse.pool", _playhouse_pool_stub)
_playhouse_pool_stub.PooledPostgresqlDatabase = MagicMock()
_pandas_stub = MagicMock()
sys.modules.setdefault("pandas", _pandas_stub)
_pandas_stub.DataFrame = MagicMock()
@pytest.fixture(scope="module")
def client():
"""
TestClient with the Peewee db object mocked so the app can be imported
without a running PostgreSQL instance. FastAPI validates query params
before calling handler code, so 422 responses don't need a real DB.
"""
mock_db = MagicMock()
mock_db.is_closed.return_value = False
mock_db.connect.return_value = None
mock_db.close.return_value = None
with patch("app.db_engine.db", mock_db):
from fastapi.testclient import TestClient
from app.main import app
with TestClient(app, raise_server_exceptions=False) as c:
yield c
def test_limit_exceeds_max_returns_422(client):
"""
GET /api/v3/decisions with limit=1000 should return 422.
MAX_LIMIT is 500; the decisions endpoint declares
limit: int = Query(ge=1, le=MAX_LIMIT), so FastAPI rejects values > 500
before any handler code runs.
"""
response = client.get("/api/v3/decisions?limit=1000")
assert response.status_code == 422
def test_limit_zero_returns_422(client):
"""
GET /api/v3/decisions with limit=0 should return 422.
Query(ge=1) rejects zero values.
"""
response = client.get("/api/v3/decisions?limit=0")
assert response.status_code == 422
def test_limit_negative_returns_422(client):
"""
GET /api/v3/decisions with limit=-1 should return 422.
Query(ge=1) rejects negative values.
"""
response = client.get("/api/v3/decisions?limit=-1")
assert response.status_code == 422
def test_transactions_has_limit_in_response(client):
"""
GET /api/v3/transactions?season=12 should include 'limit' and 'offset'
keys in the JSON response body.
The transactions endpoint was updated to return pagination metadata
alongside results so callers know the applied page size.
"""
mock_qs = MagicMock()
mock_qs.count.return_value = 0
mock_qs.where.return_value = mock_qs
mock_qs.order_by.return_value = mock_qs
mock_qs.offset.return_value = mock_qs
mock_qs.limit.return_value = mock_qs
mock_qs.__iter__ = MagicMock(return_value=iter([]))
with (
patch("app.routers_v3.transactions.Transaction") as mock_txn,
patch("app.routers_v3.transactions.Team") as mock_team,
patch("app.routers_v3.transactions.Player") as mock_player,
):
mock_txn.select_season.return_value = mock_qs
mock_txn.select.return_value = mock_qs
mock_team.select.return_value = mock_qs
mock_player.select.return_value = mock_qs
response = client.get("/api/v3/transactions?season=12")
# If the mock is sufficient the response is 200 with pagination keys;
# if some DB path still fires we at least confirm limit param is accepted.
assert response.status_code != 422
if response.status_code == 200:
data = response.json()
assert "limit" in data, "Response missing 'limit' key"
assert "offset" in data, "Response missing 'offset' key"
def test_empty_string_param_stripped(client):
"""
Query params with an empty string value should be treated as absent.
The strip_empty_query_params middleware rewrites the query string before
FastAPI parses it, so ?league_abbrev= is removed entirely rather than
forwarded as an empty string to the handler.
Expected: the request is accepted (not 422) and the empty param is ignored.
"""
mock_qs = MagicMock()
mock_qs.count.return_value = 0
mock_qs.where.return_value = mock_qs
mock_qs.__iter__ = MagicMock(return_value=iter([]))
with patch("app.routers_v3.standings.Standings") as mock_standings:
mock_standings.select_season.return_value = mock_qs
# ?league_abbrev= should be stripped → treated as absent (None), not ""
response = client.get("/api/v3/standings?season=12&league_abbrev=")
assert response.status_code != 422, (
"Empty string query param caused a 422 — middleware may not be stripping it"
)