from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional, Literal
import logging
import pydantic

from ..db_engine import db, BattingStat, Team, Player, Current, model_to_dict, chunked, fn, per_season_weeks
from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors

logger = logging.getLogger('discord_app')

router = APIRouter(
    prefix='/api/v3/battingstats',
    tags=['battingstats']
)

# BattingStat columns that /totals aggregates with SUM(); each is selected as `sum_<name>`.
_SUMMED_FIELDS = (
    'pa', 'ab', 'run', 'hit', 'rbi', 'double', 'triple', 'hr', 'bb', 'so', 'hbp', 'sac', 'ibb', 'gidp',
    'sb', 'cs', 'bphr', 'bpfo', 'bp1b', 'bplo', 'xba', 'xbt', 'xch', 'xhit', 'error', 'pb', 'sbc', 'csc',
    'roba', 'robs', 'raa', 'rto',
)

# Subset of the summed columns that /totals actually returns, in response-key order.
_TOTALS_RESPONSE_FIELDS = (
    'pa', 'ab', 'run', 'hit', 'rbi', 'double', 'triple', 'hr', 'bb', 'so', 'hbp', 'sac', 'ibb', 'gidp',
    'sb', 'cs', 'bphr', 'bpfo', 'bp1b', 'bplo',
)


class BatStatModel(pydantic.BaseModel):
    """A single batting-stat row as accepted by the POST and PATCH endpoints.

    Counting stats default to 0; player, team, position, week, game and season are required.
    """
    player_id: int
    team_id: int
    pos: str
    pa: Optional[int] = 0
    ab: Optional[int] = 0
    run: Optional[int] = 0
    hit: Optional[int] = 0
    rbi: Optional[int] = 0
    double: Optional[int] = 0
    triple: Optional[int] = 0
    hr: Optional[int] = 0
    bb: Optional[int] = 0
    so: Optional[int] = 0
    hbp: Optional[int] = 0
    sac: Optional[int] = 0
    ibb: Optional[int] = 0
    gidp: Optional[int] = 0
    sb: Optional[int] = 0
    cs: Optional[int] = 0
    bphr: Optional[int] = 0
    bpfo: Optional[int] = 0
    bp1b: Optional[int] = 0
    bplo: Optional[int] = 0
    xba: Optional[int] = 0
    xbt: Optional[int] = 0
    xch: Optional[int] = 0
    xhit: Optional[int] = 0
    error: Optional[int] = 0
    pb: Optional[int] = 0
    sbc: Optional[int] = 0
    csc: Optional[int] = 0
    roba: Optional[int] = 0
    robs: Optional[int] = 0
    raa: Optional[int] = 0
    rto: Optional[int] = 0
    week: int
    game: int
    season: int


class BatStatList(pydantic.BaseModel):
    """Request body for bulk-inserting batting lines via POST."""
    count: int
    stats: List[BatStatModel]


@router.get('')
@handle_db_errors
async def get_batstats(
        season: int, s_type: Optional[str] = 'regular', team_abbrev: list = Query(default=None),
        player_name: list = Query(default=None), player_id: list = Query(default=None),
        week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None),
        position: list = Query(default=None), limit: Optional[int] = None, sort: Optional[str] = None,
        short_output: Optional[bool] = True):
    """Return individual batting lines for `season`, narrowed by the optional filters.

    Args:
        season: Season to query.
        s_type: Any value containing 'post' selects postseason lines; 'combined', 'total' or 'all'
            selects both; anything else selects the regular season.
        team_abbrev: Team abbreviations (case-insensitive; upper-cased before matching).
        player_name: Player names (case-insensitive). Ignored when `player_id` is given.
        player_id: Player ids.
        week_start: First week to include (default 1).
        week_end: Last week to include; clamped to the season's current week from `Current`.
        game_num: Game numbers to include.
        position: Positions (upper-cased before matching `BattingStat.pos`).
        limit: Maximum number of rows to return.
        sort: 'newest' orders by week then game, descending; other values are ignored.
        short_output: When False, related rows are expanded by `model_to_dict(recurse=True)`.

    Returns:
        {'count': int, 'stats': [dict, ...]}

    Raises:
        HTTPException: 404 when the resolved start week is after the end week.
    """
    season_type = s_type.lower()
    if 'post' in season_type:
        all_stats = BattingStat.post_season(season)
    elif season_type in ['combined', 'total', 'all']:
        all_stats = BattingStat.combined_season(season)
    else:
        all_stats = BattingStat.regular_season(season)

    if all_stats.count() == 0:
        db.close()
        return {'count': 0, 'stats': []}

    if position is not None:
        all_stats = all_stats.where(BattingStat.pos << [x.upper() for x in position])
    if team_abbrev is not None:
        t_query = Team.select().where(Team.abbrev << [x.upper() for x in team_abbrev])
        all_stats = all_stats.where(BattingStat.team << t_query)
    if player_name is not None or player_id is not None:
        if player_id:
            all_stats = all_stats.where(BattingStat.player_id << player_id)
        else:
            p_query = Player.select_season(season).where(fn.Lower(Player.name) << [x.lower() for x in player_name])
            all_stats = all_stats.where(BattingStat.player << p_query)
    if game_num:
        # game_num is a list of game numbers: use IN (`<<`), not `==`, which renders `game = (a, b)`
        # and fails as soon as more than one game is requested.
        all_stats = all_stats.where(BattingStat.game << game_num)

    start = 1
    end = Current.get(Current.season == season).week
    if week_start is not None:
        start = week_start
    if week_end is not None:
        end = min(week_end, end)
    if start > end:
        db.close()
        raise HTTPException(
            status_code=404,
            detail=f'Start week {start} is after end week {end} - cannot pull stats'
        )
    all_stats = all_stats.where(
        (BattingStat.week >= start) & (BattingStat.week <= end)
    )

    if limit:
        all_stats = all_stats.limit(limit)
    if sort == 'newest':
        all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game)

    return_stats = {
        'count': all_stats.count(),
        'stats': [model_to_dict(x, recurse=not short_output) for x in all_stats],
    }
    db.close()
    return return_stats


@router.get('/totals')
@handle_db_errors
async def get_totalstats(
        season: int, s_type: Literal['regular', 'post', 'total', None] = None, team_abbrev: list = Query(default=None),
        team_id: list = Query(default=None), player_name: list = Query(default=None),
        week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None),
        position: list = Query(default=None), sort: Optional[str] = None, player_id: list = Query(default=None),
        group_by: Literal['team', 'player', 'playerteam'] = 'player', short_output: Optional[bool] = False,
        min_pa: Optional[int] = 1, week: list = Query(default=None)):
    """Return summed batting stats for `season`, grouped by player, team, or player+team.

    The week window is set by at most one of: `s_type` (resolved via `per_season_weeks`), the
    `week_start`/`week_end` pair (both required together), or an explicit `week` list.
    Groups with fewer than `min_pa` summed plate appearances are dropped (HAVING clause).

    Args:
        position: Keeps players whose pos_1..pos_8 matches any listed position.
        team_id / team_abbrev: Team filter; `team_id` wins when both are given.
        player_name / player_id: Player filter; `player_name` wins when both are given.
        sort: 'player' or 'team' orders by that column; other values are ignored.
        short_output: When True, player/team are returned as ids, otherwise as dicts.

    Returns:
        {'count': int, 'stats': [dict, ...]}. Each entry has 'player' and 'team' (or 'TOT' when
        that dimension is not grouped) plus the summed counting stats.

    Raises:
        HTTPException: 400 for conflicting or incomplete week-window parameters.
    """
    # Each of these defines the week window on its own, so at most one may be supplied.
    # Use explicit `is not None` tests so a falsy week (0) still counts as supplied.
    window_params_used = [
        s_type is not None,
        week_start is not None or week_end is not None,
        week is not None,
    ]
    if sum(window_params_used) > 1:
        db.close()
        raise HTTPException(status_code=400, detail='Only one of s_type, week_start/week_end, or week may be used.')

    # Build SELECT fields conditionally based on group_by to match GROUP BY exactly
    if group_by == 'team':
        select_fields = [BattingStat.team]
    elif group_by == 'playerteam':
        select_fields = [BattingStat.player, BattingStat.team]
    else:
        # 'player' and any unexpected value
        select_fields = [BattingStat.player]

    sum_columns = [fn.SUM(getattr(BattingStat, name)).alias(f'sum_{name}') for name in _SUMMED_FIELDS]
    all_stats = (
        BattingStat
        .select(*select_fields, *sum_columns)
        .where(BattingStat.season == season)
        .having(fn.SUM(BattingStat.pa) >= min_pa)
    )

    if s_type is not None or week_start is not None or week_end is not None:
        if s_type is not None:
            weeks = per_season_weeks(season, s_type)
        else:
            if week_start is None or week_end is None:
                db.close()
                raise HTTPException(
                    status_code=400, detail='Both week_start and week_end must be included if either is used.'
                )
            if week_end < week_start:
                db.close()
                raise HTTPException(status_code=400, detail='week_end must be greater than or equal to week_start')
            weeks = {'start': week_start, 'end': week_end}
        all_stats = all_stats.where(
            (BattingStat.week >= weeks['start']) & (BattingStat.week <= weeks['end'])
        )
    elif week is not None:
        all_stats = all_stats.where(BattingStat.week << week)

    if game_num is not None:
        all_stats = all_stats.where(BattingStat.game << game_num)

    if position is not None:
        p_list = [x.upper() for x in position]
        all_players = Player.select().where(
            (Player.pos_1 << p_list) | (Player.pos_2 << p_list) | (Player.pos_3 << p_list) |
            (Player.pos_4 << p_list) | (Player.pos_5 << p_list) | (Player.pos_6 << p_list) |
            (Player.pos_7 << p_list) | (Player.pos_8 << p_list)
        )
        all_stats = all_stats.where(BattingStat.player << all_players)

    if sort == 'player':
        all_stats = all_stats.order_by(BattingStat.player)
    elif sort == 'team':
        all_stats = all_stats.order_by(BattingStat.team)

    if group_by is not None:
        # Use the same fields for GROUP BY as we used for SELECT
        all_stats = all_stats.group_by(*select_fields)

    # if team_abbrev is None and team_id is None and player_name is None and player_id is None:
    #     raise HTTPException(
    #         status_code=400,
    #         detail=f'Must include team_id/team_abbrev and/or player_name/player_id'
    #     )

    if team_id is not None:
        all_teams = Team.select().where(Team.id << team_id)
        all_stats = all_stats.where(BattingStat.team << all_teams)
    elif team_abbrev is not None:
        all_teams = Team.select().where(fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev])
        all_stats = all_stats.where(BattingStat.team << all_teams)

    if player_name is not None:
        all_players = Player.select().where(fn.Lower(Player.name) << [x.lower() for x in player_name])
        all_stats = all_stats.where(BattingStat.player << all_players)
    elif player_id is not None:
        all_players = Player.select().where(Player.id << player_id)
        all_stats = all_stats.where(BattingStat.player << all_players)

    return_stats = {
        'count': all_stats.count(),
        'stats': []
    }

    for x in all_stats:
        # Handle player field based on grouping with safe access
        this_player = 'TOT'
        if 'player' in group_by and hasattr(x, 'player'):
            this_player = x.player_id if short_output else model_to_dict(x.player, recurse=False)

        # Handle team field based on grouping with safe access
        this_team = 'TOT'
        if 'team' in group_by and hasattr(x, 'team'):
            this_team = x.team_id if short_output else model_to_dict(x.team, recurse=False)

        return_stats['stats'].append({
            'player': this_player,
            'team': this_team,
            **{name: getattr(x, f'sum_{name}') for name in _TOTALS_RESPONSE_FIELDS},
        })

    db.close()
    return return_stats


# @router.get('/career/{player_name}')
# async def get_careerstats(
#         s_type: Literal['regular', 'post', 'total'] = 'regular', player_name: list = Query(default=None)):
#     pass  # Keep Career Stats table and recalculate after posting stats


@router.patch('/{stat_id}', include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors
async def patch_batstats(stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme)):
    """Overwrite batting line `stat_id` with every field of `new_stats` and return the updated row.

    Raises:
        HTTPException: 401 for an invalid token, 404 when no row has id `stat_id`.
    """
    if not valid_token(token):
        logger.warning(f'patch_batstats - Bad Token: {token}')
        raise HTTPException(status_code=401, detail='Unauthorized')

    if BattingStat.get_or_none(BattingStat.id == stat_id) is None:
        db.close()
        raise HTTPException(status_code=404, detail=f'Stat ID {stat_id} not found')

    BattingStat.update(**new_stats.dict()).where(BattingStat.id == stat_id).execute()
    r_stat = model_to_dict(BattingStat.get_by_id(stat_id))
    db.close()
    return r_stat


@router.post('', include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors
async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)):
    """Bulk-insert the batting lines in `s_list`.

    First checks that every referenced team and player exists. Then inserts all lines in batches
    of 15 inside one transaction; rows that hit a uniqueness conflict are skipped
    (`on_conflict_ignore`).

    Raises:
        HTTPException: 401 for an invalid token, 404 when a team or player id does not exist.
    """
    if not valid_token(token):
        logger.warning(f'post_batstats - Bad Token: {token}')
        raise HTTPException(status_code=401, detail='Unauthorized')

    all_stats = []
    for x in s_list.stats:
        team = Team.get_or_none(Team.id == x.team_id)
        this_player = Player.get_or_none(Player.id == x.player_id)
        if team is None:
            db.close()
            raise HTTPException(status_code=404, detail=f'Team ID {x.team_id} not found')
        if this_player is None:
            db.close()
            raise HTTPException(status_code=404, detail=f'Player ID {x.player_id} not found')

        all_stats.append(BattingStat(**x.dict()))

    with db.atomic():
        for batch in chunked(all_stats, 15):
            BattingStat.insert_many(batch).on_conflict_ignore().execute()

    # TODO(review): 'Update career stats' - no career-stat recalculation happens here yet.
    # NOTE(review): the count below is of submitted lines; rows skipped by on_conflict_ignore are included.
    db.close()
    return f'Added {len(all_stats)} batting lines'