From 697152808b1aad22bb9e2c16e9360646ee52e398 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Thu, 5 Mar 2026 11:03:17 -0600 Subject: [PATCH] fix: validate sort_by parameter with Literal type in views.py (#36) Co-Authored-By: Claude Sonnet 4.6 --- app/routers_v3/views.py | 312 ++++++++++++++++++++++++++-------------- 1 file changed, 202 insertions(+), 110 deletions(-) diff --git a/app/routers_v3/views.py b/app/routers_v3/views.py index ba77a0d..c658262 100644 --- a/app/routers_v3/views.py +++ b/app/routers_v3/views.py @@ -3,239 +3,331 @@ from typing import List, Literal, Optional import logging import pydantic -from ..db_engine import SeasonBattingStats, SeasonPitchingStats, db, Manager, Team, Current, model_to_dict, fn, query_to_csv, StratPlay, StratGame -from ..dependencies import add_cache_headers, cache_result, oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors, update_season_batting_stats, update_season_pitching_stats, get_cache_stats - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/views', - tags=['views'] +from ..db_engine import ( + SeasonBattingStats, + SeasonPitchingStats, + db, + Manager, + Team, + Current, + model_to_dict, + fn, + query_to_csv, + StratPlay, + StratGame, +) +from ..dependencies import ( + add_cache_headers, + cache_result, + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, + update_season_batting_stats, + update_season_pitching_stats, + get_cache_stats, ) -@router.get('/season-stats/batting') +logger = logging.getLogger("discord_app") + +router = APIRouter(prefix="/api/v3/views", tags=["views"]) + + +@router.get("/season-stats/batting") @handle_db_errors -@add_cache_headers(max_age=10*60) -@cache_result(ttl=5*60, key_prefix='season-batting') +@add_cache_headers(max_age=10 * 60) +@cache_result(ttl=5 * 60, key_prefix="season-batting") async def get_season_batting_stats( season: Optional[int] = None, team_id: Optional[int] = None, player_id: Optional[int] = None, sbaplayer_id: Optional[int] = None, min_pa: Optional[int] = None, # Minimum plate appearances - sort_by: str = "woba", # Default sort field - sort_order: Literal['asc', 'desc'] = 'desc', # asc or desc + sort_by: Literal[ + "pa", + "ab", + "run", + "hit", + "double", + "triple", + "homerun", + "rbi", + "bb", + "so", + "bphr", + "bpfo", + "bp1b", + "bplo", + "gidp", + "hbp", + "sac", + "ibb", + "avg", + "obp", + "slg", + "ops", + "woba", + "k_pct", + "sb", + "cs", + ] = "woba", # Sort field + sort_order: Literal["asc", "desc"] = "desc", # asc or desc limit: Optional[int] = 200, offset: int = 0, - csv: Optional[bool] = False + csv: Optional[bool] = False, ): - logger.info(f'Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}') + logger.info( + f"Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}" + ) # Use the enhanced get_top_hitters method query = SeasonBattingStats.get_top_hitters( season=season, stat=sort_by, limit=limit if limit != 0 else None, - desc=(sort_order.lower() == 'desc'), + desc=(sort_order.lower() == "desc"), team_id=team_id, player_id=player_id, sbaplayer_id=sbaplayer_id, min_pa=min_pa, - offset=offset + offset=offset, ) - + # Build applied filters for response applied_filters = {} if season is not None: - applied_filters['season'] = season + applied_filters["season"] = season if team_id is not None: - applied_filters['team_id'] = team_id + applied_filters["team_id"] = team_id if player_id is not None: - applied_filters['player_id'] = player_id + applied_filters["player_id"] = player_id if min_pa is not None: - applied_filters['min_pa'] = min_pa - + applied_filters["min_pa"] = min_pa + if csv: return_val = query_to_csv(query) - return Response(content=return_val, media_type='text/csv') + return Response(content=return_val, media_type="text/csv") else: stat_list = [model_to_dict(stat) for stat in query] - return { - 'count': len(stat_list), - 'filters': applied_filters, - 'stats': stat_list - } + return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list} -@router.post('/season-stats/batting/refresh', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("/season-stats/batting/refresh", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def refresh_season_batting_stats( - season: int, - token: str = Depends(oauth2_scheme) + season: int, token: str = Depends(oauth2_scheme) ) -> dict: """ Refresh batting stats for all players in a specific season. Useful for full season updates. """ if not valid_token(token): - logger.warning(f'refresh_season_batting_stats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"refresh_season_batting_stats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") + + logger.info(f"Refreshing all batting stats for season {season}") - logger.info(f'Refreshing all batting stats for season {season}') - try: # Get all player IDs who have stratplay records in this season - batter_ids = [row.batter_id for row in - StratPlay.select(StratPlay.batter_id.distinct()) - .join(StratGame).where(StratGame.season == season)] - + batter_ids = [ + row.batter_id + for row in StratPlay.select(StratPlay.batter_id.distinct()) + .join(StratGame) + .where(StratGame.season == season) + ] + if batter_ids: update_season_batting_stats(batter_ids, season, db) - logger.info(f'Successfully refreshed {len(batter_ids)} players for season {season}') - + logger.info( + f"Successfully refreshed {len(batter_ids)} players for season {season}" + ) + return { - 'message': f'Season {season} batting stats refreshed', - 'players_updated': len(batter_ids) + "message": f"Season {season} batting stats refreshed", + "players_updated": len(batter_ids), } else: - logger.warning(f'No batting data found for season {season}') + logger.warning(f"No batting data found for season {season}") return { - 'message': f'No batting data found for season {season}', - 'players_updated': 0 + "message": f"No batting data found for season {season}", + "players_updated": 0, } - + except Exception as e: - logger.error(f'Error refreshing season {season}: {e}') - raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}') + logger.error(f"Error refreshing season {season}: {e}") + raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}") -@router.get('/season-stats/pitching') +@router.get("/season-stats/pitching") @handle_db_errors -@add_cache_headers(max_age=10*60) -@cache_result(ttl=5*60, key_prefix='season-pitching') +@add_cache_headers(max_age=10 * 60) +@cache_result(ttl=5 * 60, key_prefix="season-pitching") async def get_season_pitching_stats( season: Optional[int] = None, team_id: Optional[int] = None, player_id: Optional[int] = None, sbaplayer_id: Optional[int] = None, min_outs: Optional[int] = None, # Minimum outs pitched - sort_by: str = "era", # Default sort field - sort_order: Literal['asc', 'desc'] = 'asc', # asc or desc (asc default for ERA) + sort_by: Literal[ + "tbf", + "outs", + "games", + "gs", + "win", + "loss", + "hold", + "saves", + "bsave", + "ir", + "irs", + "ab", + "run", + "e_run", + "hits", + "double", + "triple", + "homerun", + "bb", + "so", + "hbp", + "sac", + "ibb", + "gidp", + "sb", + "cs", + "bphr", + "bpfo", + "bp1b", + "bplo", + "wp", + "balk", + "wpa", + "era", + "whip", + "avg", + "obp", + "slg", + "ops", + "woba", + "hper9", + "kper9", + "bbper9", + "kperbb", + "lob_2outs", + "rbipercent", + "re24", + ] = "era", # Sort field + sort_order: Literal["asc", "desc"] = "asc", # asc or desc (asc default for ERA) limit: Optional[int] = 200, offset: int = 0, - csv: Optional[bool] = False + csv: Optional[bool] = False, ): - logger.info(f'Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}') + logger.info( + f"Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}" + ) # Use the get_top_pitchers method query = SeasonPitchingStats.get_top_pitchers( season=season, stat=sort_by, limit=limit if limit != 0 else None, - desc=(sort_order.lower() == 'desc'), + desc=(sort_order.lower() == "desc"), team_id=team_id, player_id=player_id, sbaplayer_id=sbaplayer_id, min_outs=min_outs, - offset=offset + offset=offset, ) - + # Build applied filters for response applied_filters = {} if season is not None: - applied_filters['season'] = season + applied_filters["season"] = season if team_id is not None: - applied_filters['team_id'] = team_id + applied_filters["team_id"] = team_id if player_id is not None: - applied_filters['player_id'] = player_id + applied_filters["player_id"] = player_id if min_outs is not None: - applied_filters['min_outs'] = min_outs - + applied_filters["min_outs"] = min_outs + if csv: return_val = query_to_csv(query) - return Response(content=return_val, media_type='text/csv') + return Response(content=return_val, media_type="text/csv") else: stat_list = [model_to_dict(stat) for stat in query] - return { - 'count': len(stat_list), - 'filters': applied_filters, - 'stats': stat_list - } + return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list} -@router.post('/season-stats/pitching/refresh', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("/season-stats/pitching/refresh", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def refresh_season_pitching_stats( - season: int, - token: str = Depends(oauth2_scheme) + season: int, token: str = Depends(oauth2_scheme) ) -> dict: """ Refresh pitching statistics for a specific season by aggregating from individual games. Private endpoint - not included in public API documentation. """ if not valid_token(token): - logger.warning(f'refresh_season_batting_stats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"refresh_season_batting_stats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") + + logger.info(f"Refreshing season {season} pitching stats") - logger.info(f'Refreshing season {season} pitching stats') - try: # Get all pitcher IDs for this season pitcher_query = ( - StratPlay - .select(StratPlay.pitcher_id) + StratPlay.select(StratPlay.pitcher_id) .join(StratGame, on=(StratPlay.game_id == StratGame.id)) .where((StratGame.season == season) & (StratPlay.pitcher_id.is_null(False))) .distinct() ) pitcher_ids = [row.pitcher_id for row in pitcher_query] - + if not pitcher_ids: - logger.warning(f'No pitchers found for season {season}') + logger.warning(f"No pitchers found for season {season}") return { - 'status': 'success', - 'message': f'No pitchers found for season {season}', - 'players_updated': 0 + "status": "success", + "message": f"No pitchers found for season {season}", + "players_updated": 0, } - + # Use the dependency function to update pitching stats update_season_pitching_stats(pitcher_ids, season, db) - - logger.info(f'Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated') + + logger.info( + f"Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated" + ) return { - 'status': 'success', - 'message': f'Season {season} pitching stats refreshed', - 'players_updated': len(pitcher_ids) + "status": "success", + "message": f"Season {season} pitching stats refreshed", + "players_updated": len(pitcher_ids), } - + except Exception as e: - logger.error(f'Error refreshing season {season} pitching stats: {e}') - raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}') + logger.error(f"Error refreshing season {season} pitching stats: {e}") + raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}") -@router.get('/admin/cache', include_in_schema=PRIVATE_IN_SCHEMA) +@router.get("/admin/cache", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors -async def get_admin_cache_stats( - token: str = Depends(oauth2_scheme) -) -> dict: +async def get_admin_cache_stats(token: str = Depends(oauth2_scheme)) -> dict: """ Get Redis cache statistics and status. Private endpoint - requires authentication. """ if not valid_token(token): - logger.warning(f'get_admin_cache_stats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"get_admin_cache_stats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") + + logger.info("Getting cache statistics") - logger.info('Getting cache statistics') - try: cache_stats = get_cache_stats() - logger.info(f'Cache stats retrieved: {cache_stats}') - return { - 'status': 'success', - 'cache_info': cache_stats - } - + logger.info(f"Cache stats retrieved: {cache_stats}") + return {"status": "success", "cache_info": cache_stats} + except Exception as e: - logger.error(f'Error getting cache stats: {e}') - raise HTTPException(status_code=500, detail=f'Failed to get cache stats: {str(e)}') + logger.error(f"Error getting cache stats: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to get cache stats: {str(e)}" + )