Guard bulk ID queries against empty lists to prevent PostgreSQL syntax error (WHERE id IN ()) when batch POST endpoints receive empty request bodies. Affected endpoints: - POST /api/v3/transactions - POST /api/v3/results - POST /api/v3/schedules - POST /api/v3/battingstats Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
417 lines
14 KiB
Python
417 lines
14 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from typing import List, Optional, Literal
|
|
import logging
|
|
import pydantic
|
|
|
|
from ..db_engine import (
|
|
db,
|
|
BattingStat,
|
|
Team,
|
|
Player,
|
|
Current,
|
|
model_to_dict,
|
|
chunked,
|
|
fn,
|
|
per_season_weeks,
|
|
)
|
|
from ..dependencies import (
|
|
oauth2_scheme,
|
|
valid_token,
|
|
PRIVATE_IN_SCHEMA,
|
|
handle_db_errors,
|
|
)
|
|
|
|
# Module-wide logger; every handler in this file logs under the "discord_app" name.
logger = logging.getLogger("discord_app")

# All routes declared below are mounted under /api/v3/battingstats.
router = APIRouter(prefix="/api/v3/battingstats", tags=["battingstats"])
|
|
|
|
|
|
class BatStatModel(pydantic.BaseModel):
    """Request payload describing one batting stat line.

    ``player_id``, ``team_id``, ``pos``, ``week``, ``game`` and ``season`` are
    required. Every counting field is optional and falls back to 0 when the
    client omits it.
    """

    # --- Who the line belongs to -------------------------------------------
    player_id: int
    team_id: int
    pos: str

    # --- Counting stats (all default to 0) ---------------------------------
    pa: Optional[int] = 0
    ab: Optional[int] = 0
    run: Optional[int] = 0
    hit: Optional[int] = 0
    rbi: Optional[int] = 0
    double: Optional[int] = 0
    triple: Optional[int] = 0
    hr: Optional[int] = 0
    bb: Optional[int] = 0
    so: Optional[int] = 0
    hbp: Optional[int] = 0
    sac: Optional[int] = 0
    ibb: Optional[int] = 0
    gidp: Optional[int] = 0
    sb: Optional[int] = 0
    cs: Optional[int] = 0
    bphr: Optional[int] = 0
    bpfo: Optional[int] = 0
    bp1b: Optional[int] = 0
    bplo: Optional[int] = 0
    xba: Optional[int] = 0
    xbt: Optional[int] = 0
    xch: Optional[int] = 0
    xhit: Optional[int] = 0
    error: Optional[int] = 0
    pb: Optional[int] = 0
    sbc: Optional[int] = 0
    csc: Optional[int] = 0
    roba: Optional[int] = 0
    robs: Optional[int] = 0
    raa: Optional[int] = 0
    rto: Optional[int] = 0

    # --- When the line was recorded ----------------------------------------
    week: int
    game: int
    season: int
|
|
|
|
|
|
class BatStatList(pydantic.BaseModel):
    """Request body for the bulk POST endpoint: a batch of batting stat lines."""

    # Client-supplied count; the handlers in this file do not read it.
    count: int
    # The stat lines to insert.
    stats: List[BatStatModel]
|
|
|
|
|
|
@router.get("")
@handle_db_errors
async def get_batstats(
    season: int,
    s_type: Optional[str] = "regular",
    team_abbrev: list = Query(default=None),
    player_name: list = Query(default=None),
    player_id: list = Query(default=None),
    week_start: Optional[int] = None,
    week_end: Optional[int] = None,
    game_num: list = Query(default=None),
    position: list = Query(default=None),
    limit: Optional[int] = None,
    sort: Optional[str] = None,
    short_output: Optional[bool] = True,
):
    """Return individual batting stat lines for a season, optionally filtered.

    Args:
        season: Season whose stat lines are queried.
        s_type: Season segment, case-insensitive. A value containing "post"
            uses ``BattingStat.post_season``; "combined", "total" or "all"
            use ``BattingStat.combined_season``; anything else (including a
            missing value) uses ``BattingStat.regular_season``.
        team_abbrev: Team abbreviations; upper-cased before matching.
        player_name: Player names, matched case-insensitively against the
            season's players. Ignored when ``player_id`` is supplied.
        player_id: Player IDs to include.
        week_start: First week to include (defaults to 1).
        week_end: Last week to include; capped at the week stored on the
            season's ``Current`` row.
        game_num: Game numbers to include.
        position: Positions to include; upper-cased before matching
            ``BattingStat.pos``.
        limit: Maximum number of lines to return; a falsy value means no limit.
        sort: Only "newest" is recognised (week descending, then game
            descending); any other value leaves the order unspecified.
        short_output: When true, rows are serialised without recursing into
            related models.

    Returns:
        ``{"count": <int>, "stats": [<dict>, ...]}``.

    Raises:
        HTTPException: 404 when the resolved start week is after the end week.
    """
    # Tolerate an explicit None by falling back to the regular season.
    season_kind = (s_type or "regular").lower()
    if "post" in season_kind:
        all_stats = BattingStat.post_season(season)
    elif season_kind in ["combined", "total", "all"]:
        all_stats = BattingStat.combined_season(season)
    else:
        all_stats = BattingStat.regular_season(season)

    # Nothing recorded for this segment: skip the filters and the Current lookup.
    if all_stats.count() == 0:
        db.close()
        return {"count": 0, "stats": []}

    if position is not None:
        all_stats = all_stats.where(BattingStat.pos << [x.upper() for x in position])
    if team_abbrev is not None:
        t_query = Team.select().where(Team.abbrev << [x.upper() for x in team_abbrev])
        all_stats = all_stats.where(BattingStat.team << t_query)

    # IDs take precedence over names. The name branch is only entered when
    # names were actually supplied, so it can never iterate over None.
    if player_id:
        all_stats = all_stats.where(BattingStat.player_id << player_id)
    elif player_name is not None:
        p_query = Player.select_season(season).where(
            fn.Lower(Player.name) << [x.lower() for x in player_name]
        )
        all_stats = all_stats.where(BattingStat.player << p_query)

    if game_num:
        # game_num is a list, so it needs an IN test; an equality comparison
        # against the list only works by accident for a single value.
        all_stats = all_stats.where(BattingStat.game << game_num)

    start = 1
    end = Current.get(Current.season == season).week
    if week_start is not None:
        start = week_start
    if week_end is not None:
        end = min(week_end, end)
    if start > end:
        db.close()
        raise HTTPException(
            status_code=404,
            detail=f"Start week {start} is after end week {end} - cannot pull stats",
        )
    all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end))

    if limit:
        all_stats = all_stats.limit(limit)
    if sort == "newest":
        all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game)

    return_stats = {
        "count": all_stats.count(),
        "stats": [model_to_dict(x, recurse=not short_output) for x in all_stats],
    }

    db.close()
    return return_stats
|
|
|
|
|
|
@router.get("/totals")
@handle_db_errors
async def get_totalstats(
    season: int,
    s_type: Literal["regular", "post", "total", None] = None,
    team_abbrev: list = Query(default=None),
    team_id: list = Query(default=None),
    player_name: list = Query(default=None),
    week_start: Optional[int] = None,
    week_end: Optional[int] = None,
    game_num: list = Query(default=None),
    position: list = Query(default=None),
    sort: Optional[str] = None,
    player_id: list = Query(default=None),
    group_by: Literal["team", "player", "playerteam"] = "player",
    short_output: Optional[bool] = False,
    min_pa: Optional[int] = 1,
    week: list = Query(default=None),
):
    """Return summed batting stats for a season, grouped by player and/or team.

    Args:
        season: Season to aggregate.
        s_type: Season segment; its week range comes from ``per_season_weeks``.
        team_abbrev: Team abbreviations (case-insensitive). Ignored when
            ``team_id`` is given.
        team_id: Team IDs to include.
        player_name: Player names (case-insensitive).
        week_start: First week of an explicit range; requires ``week_end``.
        week_end: Last week of an explicit range; requires ``week_start``.
        game_num: Game numbers to include.
        position: Positions; a player qualifies when any of ``pos_1`` ..
            ``pos_8`` matches (values are upper-cased first).
        sort: "player" or "team"; other values are ignored.
        player_id: Player IDs to include. Ignored when ``player_name`` is given.
        group_by: Grouping key: "player", "team" or "playerteam".
        short_output: When true, player/team are returned as IDs instead of
            serialised models.
        min_pa: Minimum summed plate appearances for a group to be returned.
        week: Explicit list of weeks to include.

    Returns:
        ``{"count": <int>, "stats": [<dict>, ...]}``. In each entry the
        "player" / "team" value is the string "TOT" when that dimension is
        not part of the grouping.

    Raises:
        HTTPException: 400 when more than one of ``s_type``,
            ``week_start``/``week_end`` and ``week`` is used, when only one
            end of the week range is given, or when ``week_end`` precedes
            ``week_start``.
    """
    # The three ways of selecting weeks are mutually exclusive.
    week_selectors = (s_type, (week_start or week_end), week)
    if len([sel for sel in week_selectors if sel is not None]) > 1:
        raise HTTPException(
            status_code=400,
            detail=f"Only one of s_type, week_start/week_end, or week may be used.",
        )

    # SELECT and GROUP BY must list exactly the same non-aggregate columns.
    grouping_columns = {
        "player": [BattingStat.player],
        "team": [BattingStat.team],
        "playerteam": [BattingStat.player, BattingStat.team],
    }
    # Anything unexpected falls back to grouping by player.
    group_columns = grouping_columns.get(group_by, [BattingStat.player])

    # Every column here is selected as SUM(<col>) AS sum_<col>.
    summed_stats = (
        "pa", "ab", "run", "hit", "rbi", "double", "triple", "hr",
        "bb", "so", "hbp", "sac", "ibb", "gidp", "sb", "cs",
        "bphr", "bpfo", "bp1b", "bplo",
        "xba", "xbt", "xch", "xhit", "error", "pb", "sbc", "csc",
        "roba", "robs", "raa", "rto",
    )
    # Only the first twenty sums are written into the response.
    # NOTE(review): xba..rto are summed but never returned - confirm whether
    # that omission is intentional.
    reported_stats = summed_stats[:20]

    sum_columns = [
        fn.SUM(getattr(BattingStat, stat)).alias(f"sum_{stat}")
        for stat in summed_stats
    ]
    all_stats = (
        BattingStat.select(*group_columns, *sum_columns)
        .where(BattingStat.season == season)
        .having(fn.SUM(BattingStat.pa) >= min_pa)
    )

    if s_type is not None or week_start is not None or week_end is not None:
        if s_type is not None:
            weeks = per_season_weeks(season, s_type)
        else:
            if week_start is None or week_end is None:
                raise HTTPException(
                    status_code=400,
                    detail="Both week_start and week_end must be included if either is used.",
                )
            if week_end < week_start:
                raise HTTPException(
                    status_code=400,
                    detail="week_end must be greater than or equal to week_start",
                )
            weeks = {"start": week_start, "end": week_end}

        all_stats = all_stats.where(
            (BattingStat.week >= weeks["start"]) & (BattingStat.week <= weeks["end"])
        )
    elif week is not None:
        all_stats = all_stats.where(BattingStat.week << week)

    if game_num is not None:
        all_stats = all_stats.where(BattingStat.game << game_num)

    if position is not None:
        wanted_positions = [x.upper() for x in position]
        # OR together an IN test for each of the eight position slots.
        slot_match = Player.pos_1 << wanted_positions
        for slot in (
            Player.pos_2,
            Player.pos_3,
            Player.pos_4,
            Player.pos_5,
            Player.pos_6,
            Player.pos_7,
            Player.pos_8,
        ):
            slot_match = slot_match | (slot << wanted_positions)
        positioned_players = Player.select().where(slot_match)
        all_stats = all_stats.where(BattingStat.player << positioned_players)

    sort_columns = {"player": BattingStat.player, "team": BattingStat.team}
    if sort in sort_columns:
        all_stats = all_stats.order_by(sort_columns[sort])

    if group_by is not None:
        all_stats = all_stats.group_by(*group_columns)

    # Team filter: IDs win over abbreviations.
    if team_id is not None:
        matching_teams = Team.select().where(Team.id << team_id)
        all_stats = all_stats.where(BattingStat.team << matching_teams)
    elif team_abbrev is not None:
        lowered_abbrevs = [x.lower() for x in team_abbrev]
        matching_teams = Team.select().where(fn.Lower(Team.abbrev) << lowered_abbrevs)
        all_stats = all_stats.where(BattingStat.team << matching_teams)

    # Player filter: names win over IDs.
    if player_name is not None:
        lowered_names = [x.lower() for x in player_name]
        matching_players = Player.select().where(fn.Lower(Player.name) << lowered_names)
        all_stats = all_stats.where(BattingStat.player << matching_players)
    elif player_id is not None:
        matching_players = Player.select().where(Player.id << player_id)
        all_stats = all_stats.where(BattingStat.player << matching_players)

    total_rows = all_stats.count()
    lines = []

    for row in all_stats:
        # "TOT" marks a dimension that was aggregated away by the grouping.
        player_value = "TOT"
        if "player" in group_by and hasattr(row, "player"):
            if short_output:
                player_value = row.player_id
            else:
                player_value = model_to_dict(row.player, recurse=False)

        team_value = "TOT"
        if "team" in group_by and hasattr(row, "team"):
            if short_output:
                team_value = row.team_id
            else:
                team_value = model_to_dict(row.team, recurse=False)

        line = {"player": player_value, "team": team_value}
        for stat in reported_stats:
            line[stat] = getattr(row, f"sum_{stat}")
        lines.append(line)

    db.close()
    return {"count": total_rows, "stats": lines}
|
|
|
|
|
|
# @router.get('/career/{player_name}')
|
|
# async def get_careerstats(
|
|
# s_type: Literal['regular', 'post', 'total'] = 'regular', player_name: list = Query(default=None)):
|
|
# pass # Keep Career Stats table and recalculate after posting stats
|
|
|
|
|
|
@router.patch("/{stat_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors
async def patch_batstats(
    stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme)
):
    """Overwrite an existing batting stat line with the supplied values.

    Args:
        stat_id: Primary key of the ``BattingStat`` row to update.
        new_stats: Replacement values; every field of the model is written.
        token: Bearer token, checked with ``valid_token``.

    Returns:
        The updated row serialised with ``model_to_dict``.

    Raises:
        HTTPException: 401 when the token is rejected, 404 when no row has
            the given ``stat_id``.
    """
    if not valid_token(token):
        # NOTE(review): the rejected token is written to the log verbatim -
        # confirm that is acceptable for this deployment.
        logger.warning(f"patch_batstats - Bad Token: {token}")
        raise HTTPException(status_code=401, detail="Unauthorized")

    if BattingStat.get_or_none(BattingStat.id == stat_id) is None:
        # The lookup above has already used the connection; release it before
        # bailing out, as the other exit paths in this module do.
        db.close()
        raise HTTPException(status_code=404, detail=f"Stat ID {stat_id} not found")

    BattingStat.update(**new_stats.dict()).where(BattingStat.id == stat_id).execute()
    r_stat = model_to_dict(BattingStat.get_by_id(stat_id))
    db.close()
    return r_stat
|
|
|
|
|
|
@router.post("/", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors
async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)):
    """Bulk-insert batting stat lines.

    Every referenced team and player is validated first (one query per
    table). The lines are then inserted in batches of 15 inside a single
    transaction, ignoring rows that conflict with existing ones.

    Args:
        s_list: Batch of stat lines to insert.
        token: Bearer token, checked with ``valid_token``.

    Returns:
        A message of the form ``"Added <n> batting lines"``, where ``n`` is
        the number of lines submitted (conflicting rows are still counted).

    Raises:
        HTTPException: 401 when the token is rejected, 404 when a line
            references a team or player ID that does not exist.
    """
    if not valid_token(token):
        # NOTE(review): the rejected token is written to the log verbatim -
        # confirm that is acceptable for this deployment.
        logger.warning(f"post_batstats - Bad Token: {token}")
        raise HTTPException(status_code=401, detail="Unauthorized")

    all_team_ids = list({x.team_id for x in s_list.stats})
    all_player_ids = list({x.player_id for x in s_list.stats})

    # Skip the lookups entirely for an empty batch: an empty ID list would
    # otherwise produce an invalid "WHERE id IN ()" clause.
    found_team_ids = (
        set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
        if all_team_ids
        else set()
    )
    found_player_ids = (
        set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids))
        if all_player_ids
        else set()
    )

    all_stats = []
    for x in s_list.stats:
        if x.team_id not in found_team_ids:
            # The validation queries have already used the connection;
            # release it before bailing out, as other exit paths do.
            db.close()
            raise HTTPException(
                status_code=404, detail=f"Team ID {x.team_id} not found"
            )
        if x.player_id not in found_player_ids:
            db.close()
            raise HTTPException(
                status_code=404, detail=f"Player ID {x.player_id} not found"
            )

        all_stats.append(BattingStat(**x.dict()))

    with db.atomic():
        for batch in chunked(all_stats, 15):
            BattingStat.insert_many(batch).on_conflict_ignore().execute()

    # Update career stats

    db.close()
    return f"Added {len(all_stats)} batting lines"
|