from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional, Literal
import logging
import pydantic

from ..db_engine import (
    db,
    PitchingStat,
    Team,
    Player,
    Current,
    model_to_dict,
    chunked,
    fn,
    per_season_weeks,
)
from ..dependencies import (
    oauth2_scheme,
    valid_token,
    PRIVATE_IN_SCHEMA,
    handle_db_errors,
    MAX_LIMIT,
    DEFAULT_LIMIT,
)

logger = logging.getLogger("discord_app")

router = APIRouter(prefix="/api/v3/pitchingstats", tags=["pitchingstats"])

# Rows per INSERT statement in post_pitstats.
_INSERT_BATCH_SIZE = 15
# Ids per IN (...) lookup, kept well under common SQL bound-parameter limits.
_LOOKUP_BATCH_SIZE = 500


class PitStatModel(pydantic.BaseModel):
    """One pitching stat line for a single player in a single game."""

    player_id: int
    team_id: int
    ip: Optional[float] = 0.0
    hit: Optional[int] = 0
    run: Optional[int] = 0
    erun: Optional[int] = 0
    so: Optional[int] = 0
    bb: Optional[int] = 0
    hbp: Optional[int] = 0
    wp: Optional[int] = 0
    balk: Optional[int] = 0
    hr: Optional[int] = 0
    gs: Optional[int] = 0
    win: Optional[int] = 0
    loss: Optional[int] = 0
    hold: Optional[int] = 0
    sv: Optional[int] = 0
    bsv: Optional[int] = 0
    ir: Optional[int] = 0
    irs: Optional[int] = 0
    week: int
    game: int
    season: int


class PitStatList(pydantic.BaseModel):
    """Request body for bulk-inserting pitching stat lines."""

    count: int
    stats: List[PitStatModel]


def _existing_ids(model, ids) -> set:
    """Return the subset of *ids* that exist as ``model.id``, using batched queries."""
    found = set()
    for chunk in chunked(list(ids), _LOOKUP_BATCH_SIZE):
        query = model.select(model.id).where(model.id << chunk)
        found.update(row.id for row in query)
    return found


def _dicts_by_id(model, ids) -> dict:
    """Fetch *model* rows whose id is in *ids* as shallow dicts keyed by id.

    Replaces per-row lazy foreign-key loads with one query per batch.
    """
    result = {}
    for chunk in chunked(list(ids), _LOOKUP_BATCH_SIZE):
        for obj in model.select().where(model.id << chunk):
            result[obj.id] = model_to_dict(obj, recurse=False)
    return result


@router.get("")
@handle_db_errors
async def get_pitstats(
    season: int,
    s_type: Optional[str] = "regular",
    team_abbrev: list = Query(default=None),
    player_name: list = Query(default=None),
    player_id: list = Query(default=None),
    week_start: Optional[int] = None,
    week_end: Optional[int] = None,
    game_num: list = Query(default=None),
    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
    ip_min: Optional[float] = None,
    sort: Optional[str] = None,
    short_output: Optional[bool] = True,
):
    """List individual pitching stat lines for a season.

    ``s_type`` selects the season slice: anything containing "post" picks
    postseason; "combined"/"total"/"all" picks both; anything else picks the
    regular season. The week window defaults to week 1 through the current
    week from ``Current``; ``week_end`` is capped at that current week.

    Returns ``{"count": <rows returned>, "stats": [...]}``.

    Raises:
        HTTPException(404): if the resolved start week is after the end week.
    """
    try:
        season_type = (s_type or "regular").lower()
        if "post" in season_type:
            all_stats = PitchingStat.post_season(season)
        elif season_type in ("combined", "total", "all"):
            all_stats = PitchingStat.combined_season(season)
        else:
            all_stats = PitchingStat.regular_season(season)

        if all_stats.count() == 0:
            return {"count": 0, "stats": []}

        if team_abbrev is not None:
            t_query = Team.select().where(
                Team.abbrev << [x.upper() for x in team_abbrev]
            )
            all_stats = all_stats.where(PitchingStat.team << t_query)

        # Explicit player ids take precedence over player names.
        if player_id:
            all_stats = all_stats.where(PitchingStat.player_id << player_id)
        elif player_name is not None:
            p_query = Player.select_season(season).where(
                fn.Lower(Player.name) << [x.lower() for x in player_name]
            )
            all_stats = all_stats.where(PitchingStat.player << p_query)

        if game_num:
            # game_num is a list; IN handles one or many values.
            all_stats = all_stats.where(PitchingStat.game << game_num)

        if ip_min is not None:
            all_stats = all_stats.where(PitchingStat.ip >= ip_min)

        start = 1
        end = Current.get(Current.season == season).week
        if week_start is not None:
            start = week_start
        if week_end is not None:
            end = min(week_end, end)
        if start > end:
            raise HTTPException(
                status_code=404,
                detail=f"Start week {start} is after end week {end} - cannot pull stats",
            )
        all_stats = all_stats.where(
            (PitchingStat.week >= start) & (PitchingStat.week <= end)
        )

        all_stats = all_stats.limit(limit)
        if sort == "newest":
            all_stats = all_stats.order_by(-PitchingStat.week, -PitchingStat.game)

        rows = [model_to_dict(x, recurse=not short_output) for x in all_stats]
        # A COUNT over the limited query equals the number of rows returned,
        # so use len() instead of a second query.
        return {"count": len(rows), "stats": rows}
    finally:
        db.close()


@router.get("/totals")
@handle_db_errors
async def get_totalstats(
    season: int,
    s_type: Literal["regular", "post", "total", None] = None,
    team_abbrev: list = Query(default=None),
    team_id: list = Query(default=None),
    player_name: list = Query(default=None),
    week_start: Optional[int] = None,
    week_end: Optional[int] = None,
    game_num: list = Query(default=None),
    is_sp: Optional[bool] = None,
    ip_min: Optional[float] = 0.25,
    sort: Optional[str] = None,
    player_id: list = Query(default=None),
    short_output: Optional[bool] = False,
    group_by: Literal["team", "player", "playerteam"] = "player",
    week: list = Query(default=None),
    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
    offset: int = Query(default=0, ge=0),
):
    """Aggregate (SUM/COUNT) pitching stats for a season, grouped by player/team.

    At most one of ``s_type``, ``week_start``/``week_end`` or ``week`` may
    narrow the week range. Groups with total IP below ``ip_min`` are dropped
    (HAVING).

    Results are ordered by the grouping columns so that offset/limit paging
    is stable. ``sort`` ("player" or "team") sets the primary ordering column.
    It is honoured only when that column is part of ``group_by``, because
    ordering by a non-grouped column is invalid SQL on strict engines.

    ``count`` is the total number of groups before offset/limit. A
    "player"/"team" value of "TOT" means that dimension was not grouped.

    Raises:
        HTTPException(400): on conflicting or incomplete week filters.
    """
    try:
        exclusive_filters = [
            s_type is not None,
            week_start is not None or week_end is not None,
            week is not None,
        ]
        if sum(exclusive_filters) > 1:
            raise HTTPException(
                status_code=400,
                detail="Only one of s_type, week_start/week_end, or week may be used.",
            )

        # SELECT and GROUP BY must use exactly the same non-aggregate columns.
        if group_by == "team":
            group_fields = [PitchingStat.team]
        elif group_by == "playerteam":
            group_fields = [PitchingStat.player, PitchingStat.team]
        else:
            group_fields = [PitchingStat.player]

        all_stats = (
            PitchingStat.select(
                *group_fields,
                fn.SUM(PitchingStat.ip).alias("sum_ip"),
                fn.SUM(PitchingStat.hit).alias("sum_hit"),
                fn.SUM(PitchingStat.run).alias("sum_run"),
                fn.SUM(PitchingStat.erun).alias("sum_erun"),
                fn.SUM(PitchingStat.so).alias("sum_so"),
                fn.SUM(PitchingStat.bb).alias("sum_bb"),
                fn.SUM(PitchingStat.hbp).alias("sum_hbp"),
                fn.SUM(PitchingStat.wp).alias("sum_wp"),
                fn.SUM(PitchingStat.balk).alias("sum_balk"),
                fn.SUM(PitchingStat.hr).alias("sum_hr"),
                fn.SUM(PitchingStat.ir).alias("sum_ir"),
                fn.SUM(PitchingStat.win).alias("sum_win"),
                fn.SUM(PitchingStat.loss).alias("sum_loss"),
                fn.SUM(PitchingStat.hold).alias("sum_hold"),
                fn.SUM(PitchingStat.sv).alias("sum_sv"),
                fn.SUM(PitchingStat.bsv).alias("sum_bsv"),
                fn.SUM(PitchingStat.irs).alias("sum_irs"),
                fn.SUM(PitchingStat.gs).alias("sum_gs"),
                fn.COUNT(PitchingStat.game).alias("sum_games"),
            )
            .where(PitchingStat.season == season)
            .having(fn.SUM(PitchingStat.ip) >= ip_min)
        )

        if s_type is not None:
            weeks = per_season_weeks(season, s_type)
            all_stats = all_stats.where(
                (PitchingStat.week >= weeks["start"])
                & (PitchingStat.week <= weeks["end"])
            )
        elif week_start is not None or week_end is not None:
            if week_start is None or week_end is None:
                raise HTTPException(
                    status_code=400,
                    detail="Both week_start and week_end must be included if either is used.",
                )
            if week_end < week_start:
                raise HTTPException(
                    status_code=400,
                    detail="week_end must be greater than or equal to week_start",
                )
            all_stats = all_stats.where(
                (PitchingStat.week >= week_start) & (PitchingStat.week <= week_end)
            )
        elif week is not None:
            all_stats = all_stats.where(PitchingStat.week << week)

        if game_num is not None:
            all_stats = all_stats.where(PitchingStat.game << game_num)

        if is_sp is not None:
            # Starters have gs == 1, relievers gs == 0.
            all_stats = all_stats.where(PitchingStat.gs == (1 if is_sp else 0))

        if team_id is not None:
            all_teams = Team.select().where(Team.id << team_id)
            all_stats = all_stats.where(PitchingStat.team << all_teams)
        elif team_abbrev is not None:
            all_teams = Team.select().where(
                fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev]
            )
            all_stats = all_stats.where(PitchingStat.team << all_teams)

        if player_name is not None:
            all_players = Player.select().where(
                fn.Lower(Player.name) << [x.lower() for x in player_name]
            )
            all_stats = all_stats.where(PitchingStat.player << all_players)
        elif player_id is not None:
            all_players = Player.select().where(Player.id << player_id)
            all_stats = all_stats.where(PitchingStat.player << all_players)

        all_stats = all_stats.group_by(*group_fields)

        # Deterministic ordering keeps offset/limit pages stable.
        # Compare by name/identity, not `in` on field lists: peewee overloads `==`.
        ordering = list(group_fields)
        if sort in ("player", "team") and sort in group_by:
            primary = PitchingStat.player if sort == "player" else PitchingStat.team
            ordering = [primary] + [f for f in group_fields if f is not primary]
        all_stats = all_stats.order_by(*ordering)

        total_count = all_stats.count()
        rows = list(all_stats.offset(offset).limit(limit))

        by_player = "player" in group_by
        by_team = "team" in group_by

        # Read *_id columns directly; touching x.player / x.team would lazily
        # run one query per row. Related rows are fetched in batches instead.
        player_dicts = {}
        team_dicts = {}
        if not short_output:
            if by_player:
                player_dicts = _dicts_by_id(Player, {x.player_id for x in rows})
            if by_team:
                team_dicts = _dicts_by_id(Team, {x.team_id for x in rows})

        return_stats = {"count": total_count, "stats": []}
        for x in rows:
            this_player = "TOT"
            if by_player:
                this_player = (
                    x.player_id
                    if short_output
                    else player_dicts.get(x.player_id, x.player_id)
                )

            this_team = "TOT"
            if by_team:
                this_team = (
                    x.team_id if short_output else team_dicts.get(x.team_id, x.team_id)
                )

            return_stats["stats"].append(
                {
                    "player": this_player,
                    "team": this_team,
                    "ip": x.sum_ip,
                    "hit": x.sum_hit,
                    "run": x.sum_run,
                    "erun": x.sum_erun,
                    "so": x.sum_so,
                    "bb": x.sum_bb,
                    "hbp": x.sum_hbp,
                    "wp": x.sum_wp,
                    "balk": x.sum_balk,
                    "hr": x.sum_hr,
                    "ir": x.sum_ir,
                    "irs": x.sum_irs,
                    "gs": x.sum_gs,
                    "games": x.sum_games,
                    "win": x.sum_win,
                    "loss": x.sum_loss,
                    "hold": x.sum_hold,
                    "sv": x.sum_sv,
                    "bsv": x.sum_bsv,
                }
            )
        return return_stats
    finally:
        db.close()


@router.patch("/{stat_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors
async def patch_pitstats(
    stat_id: int, new_stats: PitStatModel, token: str = Depends(oauth2_scheme)
):
    """Overwrite pitching stat line ``stat_id`` with ``new_stats``; return the updated row.

    Raises:
        HTTPException(401): if the token is invalid.
        HTTPException(404): if no stat line has id ``stat_id``.
    """
    if not valid_token(token):
        logger.warning(f"patch_pitstats - Bad Token: {token}")
        raise HTTPException(status_code=401, detail="Unauthorized")

    try:
        if PitchingStat.get_or_none(PitchingStat.id == stat_id) is None:
            raise HTTPException(status_code=404, detail=f"Stat ID {stat_id} not found")

        PitchingStat.update(**new_stats.dict()).where(
            PitchingStat.id == stat_id
        ).execute()
        return model_to_dict(PitchingStat.get_by_id(stat_id))
    finally:
        db.close()


@router.post("/", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors
async def post_pitstats(s_list: PitStatList, token: str = Depends(oauth2_scheme)):
    """Bulk-insert pitching stat lines in one transaction.

    Every referenced team and player must exist. Validation stops at the
    first stat line whose team or player id is unknown (team checked first),
    before anything is inserted. Rows that conflict with existing rows are
    ignored (ON CONFLICT IGNORE).

    Raises:
        HTTPException(401): if the token is invalid.
        HTTPException(404): if a team or player id is unknown.
    """
    if not valid_token(token):
        logger.warning(f"post_pitstats - Bad Token: {token}")
        raise HTTPException(status_code=401, detail="Unauthorized")

    try:
        # Look up all referenced ids up front instead of two queries per line.
        known_teams = _existing_ids(Team, {x.team_id for x in s_list.stats})
        known_players = _existing_ids(Player, {x.player_id for x in s_list.stats})

        for x in s_list.stats:
            if x.team_id not in known_teams:
                raise HTTPException(
                    status_code=404, detail=f"Team ID {x.team_id} not found"
                )
            if x.player_id not in known_players:
                raise HTTPException(
                    status_code=404, detail=f"Player ID {x.player_id} not found"
                )

        all_stats = [PitchingStat(**x.dict()) for x in s_list.stats]

        with db.atomic():
            for batch in chunked(all_stats, _INSERT_BATCH_SIZE):
                PitchingStat.insert_many(batch).on_conflict_ignore().execute()

        return f"Added {len(all_stats)} pitching lines"
    finally:
        db.close()