From 580b22ec6fbd68c2a938b395dcceb6e198f79b84 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Thu, 5 Mar 2026 17:33:32 -0600 Subject: [PATCH 1/2] fix: eliminate N+1 queries in batch POST endpoints (#25) Replace per-row Team/Player lookups with bulk IN-list queries before the validation loop in post_transactions, post_results, post_schedules, and post_batstats. A 50-move batch now uses 2 queries instead of 150. Co-Authored-By: Claude Sonnet 4.6 --- app/routers_v3/battingstats.py | 326 ++++++++++++++++++++------------- app/routers_v3/results.py | 109 +++++++---- app/routers_v3/schedules.py | 105 +++++++---- app/routers_v3/transactions.py | 116 ++++++++---- 4 files changed, 419 insertions(+), 237 deletions(-) diff --git a/app/routers_v3/battingstats.py b/app/routers_v3/battingstats.py index 7471cf9..334beb4 100644 --- a/app/routers_v3/battingstats.py +++ b/app/routers_v3/battingstats.py @@ -3,15 +3,27 @@ from typing import List, Optional, Literal import logging import pydantic -from ..db_engine import db, BattingStat, Team, Player, Current, model_to_dict, chunked, fn, per_season_weeks -from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/battingstats', - tags=['battingstats'] +from ..db_engine import ( + db, + BattingStat, + Team, + Player, + Current, + model_to_dict, + chunked, + fn, + per_season_weeks, ) +from ..dependencies import ( + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, +) + +logger = logging.getLogger("discord_app") + +router = APIRouter(prefix="/api/v3/battingstats", tags=["battingstats"]) class BatStatModel(pydantic.BaseModel): @@ -60,29 +72,37 @@ class BatStatList(pydantic.BaseModel): stats: List[BatStatModel] -@router.get('') +@router.get("") @handle_db_errors async def get_batstats( - season: int, s_type: Optional[str] = 'regular', team_abbrev: list = Query(default=None), - player_name: list = Query(default=None), player_id: list = Query(default=None), - week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None), - position: list = Query(default=None), limit: Optional[int] = None, sort: Optional[str] = None, - short_output: Optional[bool] = True): - if 'post' in s_type.lower(): + season: int, + s_type: Optional[str] = "regular", + team_abbrev: list = Query(default=None), + player_name: list = Query(default=None), + player_id: list = Query(default=None), + week_start: Optional[int] = None, + week_end: Optional[int] = None, + game_num: list = Query(default=None), + position: list = Query(default=None), + limit: Optional[int] = None, + sort: Optional[str] = None, + short_output: Optional[bool] = True, +): + if "post" in s_type.lower(): all_stats = BattingStat.post_season(season) if all_stats.count() == 0: db.close() - return {'count': 0, 'stats': []} - elif s_type.lower() in ['combined', 'total', 'all']: + return {"count": 0, "stats": []} + elif s_type.lower() in ["combined", "total", "all"]: all_stats = BattingStat.combined_season(season) if all_stats.count() == 0: db.close() - return {'count': 0, 'stats': []} + return {"count": 0, "stats": []} else: all_stats = BattingStat.regular_season(season) if all_stats.count() == 0: db.close() - return {'count': 0, 'stats': []} + return {"count": 0, "stats": []} if position is not None: all_stats = all_stats.where(BattingStat.pos << [x.upper() for x in position]) @@ -93,7 +113,9 @@ async def get_batstats( if player_id: all_stats = all_stats.where(BattingStat.player_id << player_id) else: - p_query = Player.select_season(season).where(fn.Lower(Player.name) << [x.lower() for x in player_name]) + p_query = Player.select_season(season).where( + fn.Lower(Player.name) << [x.lower() for x in player_name] + ) all_stats = all_stats.where(BattingStat.player << p_query) if game_num: all_stats = all_stats.where(BattingStat.game == game_num) @@ -108,21 +130,19 @@ async def get_batstats( db.close() raise HTTPException( status_code=404, - detail=f'Start week {start} is after end week {end} - cannot pull stats' + detail=f"Start week {start} is after end week {end} - cannot pull stats", ) - all_stats = all_stats.where( - (BattingStat.week >= start) & (BattingStat.week <= end) - ) + all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end)) if limit: all_stats = all_stats.limit(limit) if sort: - if sort == 'newest': + if sort == "newest": all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game) return_stats = { - 'count': all_stats.count(), - 'stats': [model_to_dict(x, recurse=not short_output) for x in all_stats], + "count": all_stats.count(), + "stats": [model_to_dict(x, recurse=not short_output) for x in all_stats], # 'stats': [{'id': x.id} for x in all_stats] } @@ -130,52 +150,82 @@ async def get_batstats( return return_stats -@router.get('/totals') +@router.get("/totals") @handle_db_errors async def get_totalstats( - season: int, s_type: Literal['regular', 'post', 'total', None] = None, team_abbrev: list = Query(default=None), - team_id: list = Query(default=None), player_name: list = Query(default=None), - week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None), - position: list = Query(default=None), sort: Optional[str] = None, player_id: list = Query(default=None), - group_by: Literal['team', 'player', 'playerteam'] = 'player', short_output: Optional[bool] = False, - min_pa: Optional[int] = 1, week: list = Query(default=None)): + season: int, + s_type: Literal["regular", "post", "total", None] = None, + team_abbrev: list = Query(default=None), + team_id: list = Query(default=None), + player_name: list = Query(default=None), + week_start: Optional[int] = None, + week_end: Optional[int] = None, + game_num: list = Query(default=None), + position: list = Query(default=None), + sort: Optional[str] = None, + player_id: list = Query(default=None), + group_by: Literal["team", "player", "playerteam"] = "player", + short_output: Optional[bool] = False, + min_pa: Optional[int] = 1, + week: list = Query(default=None), +): if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1: - raise HTTPException(status_code=400, detail=f'Only one of s_type, week_start/week_end, or week may be used.') + raise HTTPException( + status_code=400, + detail=f"Only one of s_type, week_start/week_end, or week may be used.", + ) # Build SELECT fields conditionally based on group_by to match GROUP BY exactly select_fields = [] - - if group_by == 'player': + + if group_by == "player": select_fields = [BattingStat.player] - elif group_by == 'team': + elif group_by == "team": select_fields = [BattingStat.team] - elif group_by == 'playerteam': + elif group_by == "playerteam": select_fields = [BattingStat.player, BattingStat.team] else: # Default case select_fields = [BattingStat.player] all_stats = ( - BattingStat - .select(*select_fields, - fn.SUM(BattingStat.pa).alias('sum_pa'), fn.SUM(BattingStat.ab).alias('sum_ab'), - fn.SUM(BattingStat.run).alias('sum_run'), fn.SUM(BattingStat.hit).alias('sum_hit'), - fn.SUM(BattingStat.rbi).alias('sum_rbi'), fn.SUM(BattingStat.double).alias('sum_double'), - fn.SUM(BattingStat.triple).alias('sum_triple'), fn.SUM(BattingStat.hr).alias('sum_hr'), - fn.SUM(BattingStat.bb).alias('sum_bb'), fn.SUM(BattingStat.so).alias('sum_so'), - fn.SUM(BattingStat.hbp).alias('sum_hbp'), fn.SUM(BattingStat.sac).alias('sum_sac'), - fn.SUM(BattingStat.ibb).alias('sum_ibb'), fn.SUM(BattingStat.gidp).alias('sum_gidp'), - fn.SUM(BattingStat.sb).alias('sum_sb'), fn.SUM(BattingStat.cs).alias('sum_cs'), - fn.SUM(BattingStat.bphr).alias('sum_bphr'), fn.SUM(BattingStat.bpfo).alias('sum_bpfo'), - fn.SUM(BattingStat.bp1b).alias('sum_bp1b'), fn.SUM(BattingStat.bplo).alias('sum_bplo'), - fn.SUM(BattingStat.xba).alias('sum_xba'), fn.SUM(BattingStat.xbt).alias('sum_xbt'), - fn.SUM(BattingStat.xch).alias('sum_xch'), fn.SUM(BattingStat.xhit).alias('sum_xhit'), - fn.SUM(BattingStat.error).alias('sum_error'), fn.SUM(BattingStat.pb).alias('sum_pb'), - fn.SUM(BattingStat.sbc).alias('sum_sbc'), fn.SUM(BattingStat.csc).alias('sum_csc'), - fn.SUM(BattingStat.roba).alias('sum_roba'), fn.SUM(BattingStat.robs).alias('sum_robs'), - fn.SUM(BattingStat.raa).alias('sum_raa'), fn.SUM(BattingStat.rto).alias('sum_rto')) - .where(BattingStat.season == season) - .having(fn.SUM(BattingStat.pa) >= min_pa) + BattingStat.select( + *select_fields, + fn.SUM(BattingStat.pa).alias("sum_pa"), + fn.SUM(BattingStat.ab).alias("sum_ab"), + fn.SUM(BattingStat.run).alias("sum_run"), + fn.SUM(BattingStat.hit).alias("sum_hit"), + fn.SUM(BattingStat.rbi).alias("sum_rbi"), + fn.SUM(BattingStat.double).alias("sum_double"), + fn.SUM(BattingStat.triple).alias("sum_triple"), + fn.SUM(BattingStat.hr).alias("sum_hr"), + fn.SUM(BattingStat.bb).alias("sum_bb"), + fn.SUM(BattingStat.so).alias("sum_so"), + fn.SUM(BattingStat.hbp).alias("sum_hbp"), + fn.SUM(BattingStat.sac).alias("sum_sac"), + fn.SUM(BattingStat.ibb).alias("sum_ibb"), + fn.SUM(BattingStat.gidp).alias("sum_gidp"), + fn.SUM(BattingStat.sb).alias("sum_sb"), + fn.SUM(BattingStat.cs).alias("sum_cs"), + fn.SUM(BattingStat.bphr).alias("sum_bphr"), + fn.SUM(BattingStat.bpfo).alias("sum_bpfo"), + fn.SUM(BattingStat.bp1b).alias("sum_bp1b"), + fn.SUM(BattingStat.bplo).alias("sum_bplo"), + fn.SUM(BattingStat.xba).alias("sum_xba"), + fn.SUM(BattingStat.xbt).alias("sum_xbt"), + fn.SUM(BattingStat.xch).alias("sum_xch"), + fn.SUM(BattingStat.xhit).alias("sum_xhit"), + fn.SUM(BattingStat.error).alias("sum_error"), + fn.SUM(BattingStat.pb).alias("sum_pb"), + fn.SUM(BattingStat.sbc).alias("sum_sbc"), + fn.SUM(BattingStat.csc).alias("sum_csc"), + fn.SUM(BattingStat.roba).alias("sum_roba"), + fn.SUM(BattingStat.robs).alias("sum_robs"), + fn.SUM(BattingStat.raa).alias("sum_raa"), + fn.SUM(BattingStat.rto).alias("sum_rto"), + ) + .where(BattingStat.season == season) + .having(fn.SUM(BattingStat.pa) >= min_pa) ) if True in [s_type is not None, week_start is not None, week_end is not None]: @@ -185,16 +235,20 @@ async def get_totalstats( elif week_start is not None or week_end is not None: if week_start is None or week_end is None: raise HTTPException( - status_code=400, detail='Both week_start and week_end must be included if either is used.' + status_code=400, + detail="Both week_start and week_end must be included if either is used.", + ) + weeks["start"] = week_start + if week_end < weeks["start"]: + raise HTTPException( + status_code=400, + detail="week_end must be greater than or equal to week_start", ) - weeks['start'] = week_start - if week_end < weeks['start']: - raise HTTPException(status_code=400, detail='week_end must be greater than or equal to week_start') else: - weeks['end'] = week_end + weeks["end"] = week_end all_stats = all_stats.where( - (BattingStat.week >= weeks['start']) & (BattingStat.week <= weeks['end']) + (BattingStat.week >= weeks["start"]) & (BattingStat.week <= weeks["end"]) ) elif week is not None: all_stats = all_stats.where(BattingStat.week << week) @@ -204,14 +258,20 @@ async def get_totalstats( if position is not None: p_list = [x.upper() for x in position] all_players = Player.select().where( - (Player.pos_1 << p_list) | (Player.pos_2 << p_list) | (Player.pos_3 << p_list) | ( Player.pos_4 << p_list) | - (Player.pos_5 << p_list) | (Player.pos_6 << p_list) | (Player.pos_7 << p_list) | ( Player.pos_8 << p_list) + (Player.pos_1 << p_list) + | (Player.pos_2 << p_list) + | (Player.pos_3 << p_list) + | (Player.pos_4 << p_list) + | (Player.pos_5 << p_list) + | (Player.pos_6 << p_list) + | (Player.pos_7 << p_list) + | (Player.pos_8 << p_list) ) all_stats = all_stats.where(BattingStat.player << all_players) if sort is not None: - if sort == 'player': + if sort == "player": all_stats = all_stats.order_by(BattingStat.player) - elif sort == 'team': + elif sort == "team": all_stats = all_stats.order_by(BattingStat.team) if group_by is not None: # Use the same fields for GROUP BY as we used for SELECT @@ -227,56 +287,63 @@ async def get_totalstats( all_teams = Team.select().where(Team.id << team_id) all_stats = all_stats.where(BattingStat.team << all_teams) elif team_abbrev is not None: - all_teams = Team.select().where(fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev]) + all_teams = Team.select().where( + fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev] + ) all_stats = all_stats.where(BattingStat.team << all_teams) if player_name is not None: - all_players = Player.select().where(fn.Lower(Player.name) << [x.lower() for x in player_name]) + all_players = Player.select().where( + fn.Lower(Player.name) << [x.lower() for x in player_name] + ) all_stats = all_stats.where(BattingStat.player << all_players) elif player_id is not None: all_players = Player.select().where(Player.id << player_id) all_stats = all_stats.where(BattingStat.player << all_players) - return_stats = { - 'count': all_stats.count(), - 'stats': [] - } - + return_stats = {"count": all_stats.count(), "stats": []} + for x in all_stats: # Handle player field based on grouping with safe access - this_player = 'TOT' - if 'player' in group_by and hasattr(x, 'player'): - this_player = x.player_id if short_output else model_to_dict(x.player, recurse=False) + this_player = "TOT" + if "player" in group_by and hasattr(x, "player"): + this_player = ( + x.player_id if short_output else model_to_dict(x.player, recurse=False) + ) - # Handle team field based on grouping with safe access - this_team = 'TOT' - if 'team' in group_by and hasattr(x, 'team'): - this_team = x.team_id if short_output else model_to_dict(x.team, recurse=False) - - return_stats['stats'].append({ - 'player': this_player, - 'team': this_team, - 'pa': x.sum_pa, - 'ab': x.sum_ab, - 'run': x.sum_run, - 'hit': x.sum_hit, - 'rbi': x.sum_rbi, - 'double': x.sum_double, - 'triple': x.sum_triple, - 'hr': x.sum_hr, - 'bb': x.sum_bb, - 'so': x.sum_so, - 'hbp': x.sum_hbp, - 'sac': x.sum_sac, - 'ibb': x.sum_ibb, - 'gidp': x.sum_gidp, - 'sb': x.sum_sb, - 'cs': x.sum_cs, - 'bphr': x.sum_bphr, - 'bpfo': x.sum_bpfo, - 'bp1b': x.sum_bp1b, - 'bplo': x.sum_bplo - }) + # Handle team field based on grouping with safe access + this_team = "TOT" + if "team" in group_by and hasattr(x, "team"): + this_team = ( + x.team_id if short_output else model_to_dict(x.team, recurse=False) + ) + + return_stats["stats"].append( + { + "player": this_player, + "team": this_team, + "pa": x.sum_pa, + "ab": x.sum_ab, + "run": x.sum_run, + "hit": x.sum_hit, + "rbi": x.sum_rbi, + "double": x.sum_double, + "triple": x.sum_triple, + "hr": x.sum_hr, + "bb": x.sum_bb, + "so": x.sum_so, + "hbp": x.sum_hbp, + "sac": x.sum_sac, + "ibb": x.sum_ibb, + "gidp": x.sum_gidp, + "sb": x.sum_sb, + "cs": x.sum_cs, + "bphr": x.sum_bphr, + "bpfo": x.sum_bpfo, + "bp1b": x.sum_bp1b, + "bplo": x.sum_bplo, + } + ) db.close() return return_stats @@ -287,15 +354,17 @@ async def get_totalstats( # pass # Keep Career Stats table and recalculate after posting stats -@router.patch('/{stat_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.patch("/{stat_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors -async def patch_batstats(stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme)): +async def patch_batstats( + stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme) +): if not valid_token(token): - logger.warning(f'patch_batstats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"patch_batstats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") if BattingStat.get_or_none(BattingStat.id == stat_id) is None: - raise HTTPException(status_code=404, detail=f'Stat ID {stat_id} not found') + raise HTTPException(status_code=404, detail=f"Stat ID {stat_id} not found") BattingStat.update(**new_stats.dict()).where(BattingStat.id == stat_id).execute() r_stat = model_to_dict(BattingStat.get_by_id(stat_id)) @@ -303,22 +372,33 @@ async def patch_batstats(stat_id: int, new_stats: BatStatModel, token: str = Dep return r_stat -@router.post('', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'post_batstats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"post_batstats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") all_stats = [] + all_team_ids = list(set(x.team_id for x in s_list.stats)) + all_player_ids = list(set(x.player_id for x in s_list.stats)) + found_team_ids = set( + t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + ) + found_player_ids = set( + p.id for p in Player.select(Player.id).where(Player.id << all_player_ids) + ) + for x in s_list.stats: - team = Team.get_or_none(Team.id == x.team_id) - this_player = Player.get_or_none(Player.id == x.player_id) - if team is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.team_id} not found') - if this_player is None: - raise HTTPException(status_code=404, detail=f'Player ID {x.player_id} not found') + if x.team_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.team_id} not found" + ) + if x.player_id not in found_player_ids: + raise HTTPException( + status_code=404, detail=f"Player ID {x.player_id} not found" + ) all_stats.append(BattingStat(**x.dict())) @@ -329,4 +409,4 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme) # Update career stats db.close() - return f'Added {len(all_stats)} batting lines' + return f"Added {len(all_stats)} batting lines" diff --git a/app/routers_v3/results.py b/app/routers_v3/results.py index 53cc085..1279fd3 100644 --- a/app/routers_v3/results.py +++ b/app/routers_v3/results.py @@ -4,15 +4,17 @@ import logging import pydantic from ..db_engine import db, Result, Team, model_to_dict, chunked -from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/results', - tags=['results'] +from ..dependencies import ( + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, ) +logger = logging.getLogger("discord_app") + +router = APIRouter(prefix="/api/v3/results", tags=["results"]) + class ResultModel(pydantic.BaseModel): week: int @@ -29,13 +31,18 @@ class ResultList(pydantic.BaseModel): results: List[ResultModel] -@router.get('') +@router.get("") @handle_db_errors async def get_results( - season: int, team_abbrev: list = Query(default=None), week_start: Optional[int] = None, - week_end: Optional[int] = None, game_num: list = Query(default=None), - away_abbrev: list = Query(default=None), home_abbrev: list = Query(default=None), - short_output: Optional[bool] = False): + season: int, + team_abbrev: list = Query(default=None), + week_start: Optional[int] = None, + week_end: Optional[int] = None, + game_num: list = Query(default=None), + away_abbrev: list = Query(default=None), + home_abbrev: list = Query(default=None), + short_output: Optional[bool] = False, +): all_results = Result.select_season(season) if team_abbrev is not None: @@ -68,14 +75,14 @@ async def get_results( all_results = all_results.where(Result.week <= week_end) return_results = { - 'count': all_results.count(), - 'results': [model_to_dict(x, recurse=not short_output) for x in all_results] + "count": all_results.count(), + "results": [model_to_dict(x, recurse=not short_output) for x in all_results], } db.close() return return_results -@router.get('/{result_id}') +@router.get("/{result_id}") @handle_db_errors async def get_one_result(result_id: int, short_output: Optional[bool] = False): this_result = Result.get_or_none(Result.id == result_id) @@ -87,20 +94,27 @@ async def get_one_result(result_id: int, short_output: Optional[bool] = False): return r_result -@router.patch('/{result_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.patch("/{result_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def patch_result( - result_id: int, week_num: Optional[int] = None, game_num: Optional[int] = None, - away_team_id: Optional[int] = None, home_team_id: Optional[int] = None, away_score: Optional[int] = None, - home_score: Optional[int] = None, season: Optional[int] = None, scorecard_url: Optional[str] = None, - token: str = Depends(oauth2_scheme)): + result_id: int, + week_num: Optional[int] = None, + game_num: Optional[int] = None, + away_team_id: Optional[int] = None, + home_team_id: Optional[int] = None, + away_score: Optional[int] = None, + home_score: Optional[int] = None, + season: Optional[int] = None, + scorecard_url: Optional[str] = None, + token: str = Depends(oauth2_scheme), +): if not valid_token(token): - logger.warning(f'patch_player - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"patch_player - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") this_result = Result.get_or_none(Result.id == result_id) if this_result is None: - raise HTTPException(status_code=404, detail=f'Result ID {result_id} not found') + raise HTTPException(status_code=404, detail=f"Result ID {result_id} not found") if week_num is not None: this_result.week = week_num @@ -132,22 +146,37 @@ async def patch_result( return r_result else: db.close() - raise HTTPException(status_code=500, detail=f'Unable to patch result {result_id}') + raise HTTPException( + status_code=500, detail=f"Unable to patch result {result_id}" + ) -@router.post('', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def post_results(result_list: ResultList, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'patch_player - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"patch_player - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") new_results = [] + + all_team_ids = list( + set(x.awayteam_id for x in result_list.results) + | set(x.hometeam_id for x in result_list.results) + ) + found_team_ids = set( + t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + ) + for x in result_list.results: - if Team.get_or_none(Team.id == x.awayteam_id) is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.awayteam_id} not found') - if Team.get_or_none(Team.id == x.hometeam_id) is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.hometeam_id} not found') + if x.awayteam_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.awayteam_id} not found" + ) + if x.hometeam_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.hometeam_id} not found" + ) new_results.append(x.dict()) @@ -156,27 +185,27 @@ async def post_results(result_list: ResultList, token: str = Depends(oauth2_sche Result.insert_many(batch).on_conflict_ignore().execute() db.close() - return f'Inserted {len(new_results)} results' + return f"Inserted {len(new_results)} results" -@router.delete('/{result_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.delete("/{result_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def delete_result(result_id: int, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'delete_result - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"delete_result - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") this_result = Result.get_or_none(Result.id == result_id) if not this_result: db.close() - raise HTTPException(status_code=404, detail=f'Result ID {result_id} not found') + raise HTTPException(status_code=404, detail=f"Result ID {result_id} not found") count = this_result.delete_instance() db.close() if count == 1: - return f'Result {result_id} has been deleted' + return f"Result {result_id} has been deleted" else: - raise HTTPException(status_code=500, detail=f'Result {result_id} could not be deleted') - - + raise HTTPException( + status_code=500, detail=f"Result {result_id} could not be deleted" + ) diff --git a/app/routers_v3/schedules.py b/app/routers_v3/schedules.py index 5b76a67..ce5ba2a 100644 --- a/app/routers_v3/schedules.py +++ b/app/routers_v3/schedules.py @@ -4,15 +4,17 @@ import logging import pydantic from ..db_engine import db, Schedule, Team, model_to_dict, chunked -from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/schedules', - tags=['schedules'] +from ..dependencies import ( + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, ) +logger = logging.getLogger("discord_app") + +router = APIRouter(prefix="/api/v3/schedules", tags=["schedules"]) + class ScheduleModel(pydantic.BaseModel): week: int @@ -26,12 +28,17 @@ class ScheduleList(pydantic.BaseModel): schedules: List[ScheduleModel] -@router.get('') +@router.get("") @handle_db_errors async def get_schedules( - season: int, team_abbrev: list = Query(default=None), away_abbrev: list = Query(default=None), - home_abbrev: list = Query(default=None), week_start: Optional[int] = None, week_end: Optional[int] = None, - short_output: Optional[bool] = True): + season: int, + team_abbrev: list = Query(default=None), + away_abbrev: list = Query(default=None), + home_abbrev: list = Query(default=None), + week_start: Optional[int] = None, + week_end: Optional[int] = None, + short_output: Optional[bool] = True, +): all_sched = Schedule.select_season(season) if team_abbrev is not None: @@ -63,14 +70,14 @@ async def get_schedules( all_sched = all_sched.order_by(Schedule.id) return_sched = { - 'count': all_sched.count(), - 'schedules': [model_to_dict(x, recurse=not short_output) for x in all_sched] + "count": all_sched.count(), + "schedules": [model_to_dict(x, recurse=not short_output) for x in all_sched], } db.close() return return_sched -@router.get('/{schedule_id}') +@router.get("/{schedule_id}") @handle_db_errors async def get_one_schedule(schedule_id: int): this_sched = Schedule.get_or_none(Schedule.id == schedule_id) @@ -82,19 +89,26 @@ async def get_one_schedule(schedule_id: int): return r_sched -@router.patch('/{schedule_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.patch("/{schedule_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def patch_schedule( - schedule_id: int, week: list = Query(default=None), awayteam_id: Optional[int] = None, - hometeam_id: Optional[int] = None, gamecount: Optional[int] = None, season: Optional[int] = None, - token: str = Depends(oauth2_scheme)): + schedule_id: int, + week: list = Query(default=None), + awayteam_id: Optional[int] = None, + hometeam_id: Optional[int] = None, + gamecount: Optional[int] = None, + season: Optional[int] = None, + token: str = Depends(oauth2_scheme), +): if not valid_token(token): - logger.warning(f'patch_schedule - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"patch_schedule - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") this_sched = Schedule.get_or_none(Schedule.id == schedule_id) if this_sched is None: - raise HTTPException(status_code=404, detail=f'Schedule ID {schedule_id} not found') + raise HTTPException( + status_code=404, detail=f"Schedule ID {schedule_id} not found" + ) if week is not None: this_sched.week = week @@ -117,22 +131,37 @@ async def patch_schedule( return r_sched else: db.close() - raise HTTPException(status_code=500, detail=f'Unable to patch schedule {schedule_id}') + raise HTTPException( + status_code=500, detail=f"Unable to patch schedule {schedule_id}" + ) -@router.post('', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'post_schedules - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"post_schedules - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") new_sched = [] + + all_team_ids = list( + set(x.awayteam_id for x in sched_list.schedules) + | set(x.hometeam_id for x in sched_list.schedules) + ) + found_team_ids = set( + t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + ) + for x in sched_list.schedules: - if Team.get_or_none(Team.id == x.awayteam_id) is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.awayteam_id} not found') - if Team.get_or_none(Team.id == x.hometeam_id) is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.hometeam_id} not found') + if x.awayteam_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.awayteam_id} not found" + ) + if x.hometeam_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.hometeam_id} not found" + ) new_sched.append(x.dict()) @@ -141,24 +170,28 @@ async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_s Schedule.insert_many(batch).on_conflict_ignore().execute() db.close() - return f'Inserted {len(new_sched)} schedules' + return f"Inserted {len(new_sched)} schedules" -@router.delete('/{schedule_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.delete("/{schedule_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def delete_schedule(schedule_id: int, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'delete_schedule - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"delete_schedule - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") this_sched = Schedule.get_or_none(Schedule.id == schedule_id) if this_sched is None: - raise HTTPException(status_code=404, detail=f'Schedule ID {schedule_id} not found') + raise HTTPException( + status_code=404, detail=f"Schedule ID {schedule_id} not found" + ) count = this_sched.delete_instance() db.close() if count == 1: - return f'Schedule {this_sched} has been deleted' + return f"Schedule {this_sched} has been deleted" else: - raise HTTPException(status_code=500, detail=f'Schedule {this_sched} could not be deleted') + raise HTTPException( + status_code=500, detail=f"Schedule {this_sched} could not be deleted" + ) diff --git a/app/routers_v3/transactions.py b/app/routers_v3/transactions.py index 79411d9..e0e58b8 100644 --- a/app/routers_v3/transactions.py +++ b/app/routers_v3/transactions.py @@ -5,15 +5,17 @@ import logging import pydantic from ..db_engine import db, Transaction, Team, Player, model_to_dict, chunked, fn -from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/transactions', - tags=['transactions'] +from ..dependencies import ( + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, ) +logger = logging.getLogger("discord_app") + +router = APIRouter(prefix="/api/v3/transactions", tags=["transactions"]) + class TransactionModel(pydantic.BaseModel): week: int @@ -31,13 +33,21 @@ class TransactionList(pydantic.BaseModel): moves: List[TransactionModel] -@router.get('') +@router.get("") @handle_db_errors async def get_transactions( - season, team_abbrev: list = Query(default=None), week_start: Optional[int] = 0, - week_end: Optional[int] = None, cancelled: Optional[bool] = None, frozen: Optional[bool] = None, - player_name: list = Query(default=None), player_id: list = Query(default=None), move_id: Optional[str] = None, - is_trade: Optional[bool] = None, short_output: Optional[bool] = False): + season, + team_abbrev: list = Query(default=None), + week_start: Optional[int] = 0, + week_end: Optional[int] = None, + cancelled: Optional[bool] = None, + frozen: Optional[bool] = None, + player_name: list = Query(default=None), + player_id: list = Query(default=None), + move_id: Optional[str] = None, + is_trade: Optional[bool] = None, + short_output: Optional[bool] = False, +): if season: transactions = Transaction.select_season(season) else: @@ -75,31 +85,39 @@ async def get_transactions( transactions = transactions.where(Transaction.frozen == 0) if is_trade is not None: - raise HTTPException(status_code=501, detail='The is_trade parameter is not implemented, yet') + raise HTTPException( + status_code=501, detail="The is_trade parameter is not implemented, yet" + ) transactions = transactions.order_by(-Transaction.week, Transaction.moveid) return_trans = { - 'count': transactions.count(), - 'transactions': [model_to_dict(x, recurse=not short_output) for x in transactions] + "count": transactions.count(), + "transactions": [ + model_to_dict(x, recurse=not short_output) for x in transactions + ], } db.close() return return_trans -@router.patch('/{move_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.patch("/{move_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def patch_transactions( - move_id, token: str = Depends(oauth2_scheme), frozen: Optional[bool] = None, cancelled: Optional[bool] = None): + move_id, + token: str = Depends(oauth2_scheme), + frozen: Optional[bool] = None, + cancelled: Optional[bool] = None, +): if not valid_token(token): - logger.warning(f'patch_transactions - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"patch_transactions - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") these_moves = Transaction.select().where(Transaction.moveid == move_id) if these_moves.count() == 0: db.close() - raise HTTPException(status_code=404, detail=f'Move ID {move_id} not found') + raise HTTPException(status_code=404, detail=f"Move ID {move_id} not found") if frozen is not None: for x in these_moves: @@ -111,25 +129,44 @@ async def patch_transactions( x.save() db.close() - return f'Updated {these_moves.count()} transactions' + return f"Updated {these_moves.count()} transactions" -@router.post('', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors -async def post_transactions(moves: TransactionList, token: str = Depends(oauth2_scheme)): +async def post_transactions( + moves: TransactionList, token: str = Depends(oauth2_scheme) +): if not valid_token(token): - logger.warning(f'post_transactions - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"post_transactions - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") all_moves = [] + all_team_ids = list( + set(x.oldteam_id for x in moves.moves) | set(x.newteam_id for x in moves.moves) + ) + all_player_ids = list(set(x.player_id for x in moves.moves)) + found_team_ids = set( + t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + ) + found_player_ids = set( + p.id for p in Player.select(Player.id).where(Player.id << all_player_ids) + ) + for x in moves.moves: - if Team.get_or_none(Team.id == x.oldteam_id) is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.oldteam_id} not found') - if Team.get_or_none(Team.id == x.newteam_id) is None: - raise HTTPException(status_code=404, detail=f'Team ID {x.newteam_id} not found') - if Player.get_or_none(Player.id == x.player_id) is None: - raise HTTPException(status_code=404, detail=f'Player ID {x.player_id} not found') + if x.oldteam_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.oldteam_id} not found" + ) + if x.newteam_id not in found_team_ids: + raise HTTPException( + status_code=404, detail=f"Team ID {x.newteam_id} not found" + ) + if x.player_id not in found_player_ids: + raise HTTPException( + status_code=404, detail=f"Player ID {x.player_id} not found" + ) all_moves.append(x.dict()) @@ -138,22 +175,25 @@ async def post_transactions(moves: TransactionList, token: str = Depends(oauth2_ Transaction.insert_many(batch).on_conflict_ignore().execute() db.close() - return f'{len(all_moves)} transactions have been added' + return f"{len(all_moves)} transactions have been added" -@router.delete('/{move_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.delete("/{move_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def delete_transactions(move_id, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'delete_transactions - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"delete_transactions - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") delete_query = Transaction.delete().where(Transaction.moveid == move_id) count = delete_query.execute() db.close() if count > 0: - return f'Removed {count} transactions' + return f"Removed {count} transactions" else: - raise HTTPException(status_code=418, detail=f'Well slap my ass and call me a teapot; ' - f'I did not delete any records') + raise HTTPException( + status_code=418, + detail=f"Well slap my ass and call me a teapot; " + f"I did not delete any records", + ) -- 2.25.1 From 17f67ff358e72b71256288400f719da4be99662f Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Sat, 7 Mar 2026 01:51:58 -0600 Subject: [PATCH 2/2] fix: address review feedback (#52) Guard bulk ID queries against empty lists to prevent PostgreSQL syntax error (WHERE id IN ()) when batch POST endpoints receive empty request bodies. Affected endpoints: - POST /api/v3/transactions - POST /api/v3/results - POST /api/v3/schedules - POST /api/v3/battingstats Co-Authored-By: Claude Sonnet 4.6 --- app/routers_v3/battingstats.py | 12 ++++++++---- app/routers_v3/results.py | 6 ++++-- app/routers_v3/schedules.py | 6 ++++-- app/routers_v3/transactions.py | 12 ++++++++---- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/app/routers_v3/battingstats.py b/app/routers_v3/battingstats.py index 334beb4..1ae0ade 100644 --- a/app/routers_v3/battingstats.py +++ b/app/routers_v3/battingstats.py @@ -383,11 +383,15 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme) all_team_ids = list(set(x.team_id for x in s_list.stats)) all_player_ids = list(set(x.player_id for x in s_list.stats)) - found_team_ids = set( - t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() ) - found_player_ids = set( - p.id for p in Player.select(Player.id).where(Player.id << all_player_ids) + found_player_ids = ( + set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids)) + if all_player_ids + else set() ) for x in s_list.stats: diff --git a/app/routers_v3/results.py b/app/routers_v3/results.py index 1279fd3..176fe72 100644 --- a/app/routers_v3/results.py +++ b/app/routers_v3/results.py @@ -164,8 +164,10 @@ async def post_results(result_list: ResultList, token: str = Depends(oauth2_sche set(x.awayteam_id for x in result_list.results) | set(x.hometeam_id for x in result_list.results) ) - found_team_ids = set( - t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() ) for x in result_list.results: diff --git a/app/routers_v3/schedules.py b/app/routers_v3/schedules.py index ce5ba2a..706b0d8 100644 --- a/app/routers_v3/schedules.py +++ b/app/routers_v3/schedules.py @@ -149,8 +149,10 @@ async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_s set(x.awayteam_id for x in sched_list.schedules) | set(x.hometeam_id for x in sched_list.schedules) ) - found_team_ids = set( - t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() ) for x in sched_list.schedules: diff --git a/app/routers_v3/transactions.py b/app/routers_v3/transactions.py index e0e58b8..f407bfd 100644 --- a/app/routers_v3/transactions.py +++ b/app/routers_v3/transactions.py @@ -147,11 +147,15 @@ async def post_transactions( set(x.oldteam_id for x in moves.moves) | set(x.newteam_id for x in moves.moves) ) all_player_ids = list(set(x.player_id for x in moves.moves)) - found_team_ids = set( - t.id for t in Team.select(Team.id).where(Team.id << all_team_ids) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() ) - found_player_ids = set( - p.id for p in Player.select(Player.id).where(Player.id << all_player_ids) + found_player_ids = ( + set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids)) + if all_player_ids + else set() ) for x in moves.moves: -- 2.25.1