Compare commits

..

67 Commits

Author SHA1 Message Date
cal
5b19bd486a Merge pull request 'fix: preserve total_count in get_totalstats instead of overwriting with page length (#101)' (#102) from issue/101-fieldingstats-get-totalstats-total-count-overwritt into main 2026-04-08 04:08:40 +00:00
Cal Corum
718abc0096 fix: preserve total_count in get_totalstats instead of overwriting with page length (#101)
Closes #101

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 23:08:10 -05:00
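The pattern behind this fix (and the related count bug in #100) can be sketched in plain Python. This is a hypothetical stand-in, not the project's actual `get_totalstats` code: a list stands in for the database query, and the field names are illustrative.

```python
# Hypothetical sketch: compute total_count from the full result set BEFORE
# applying the page limit, so the reported count is not overwritten by the
# page length. The list stands in for a database query.
def get_totalstats(rows, limit, offset=0):
    total_count = len(rows)              # count all matching rows first
    page = rows[offset:offset + limit]   # then take the capped page
    return {"total_count": total_count, "stats": page}
```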
cal
52d88ae950 Merge pull request 'fix: add missing indexes on FK columns in stratplay and stratgame (#74)' (#95) from issue/74-add-missing-indexes-on-foreign-key-columns-in-high into main 2026-04-08 04:06:06 +00:00
Cal Corum
9165419ed0 fix: add missing indexes on FK columns in stratplay and stratgame (#74)
Closes #74

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 23:05:25 -05:00
cal
d23d6520c3 Merge pull request 'fix: batch standings updates to eliminate N+1 queries in recalculate (#75)' (#93) from issue/75-fix-n-1-query-pattern-in-standings-recalculation into main 2026-04-08 04:03:39 +00:00
Cal Corum
c23ca9a721 fix: batch standings updates to eliminate N+1 queries in recalculate (#75)
Replace per-game update_standings() calls with pre-fetched dicts and
in-memory accumulation, then a single bulk_update at the end.
Reduces ~1,100+ queries for a full season to ~5 queries.

Closes #75

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 23:03:11 -05:00
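The batching approach described above can be sketched without any ORM: accumulate per-team totals in memory, then write once at the end. Names here (`recalculate_standings`, the tuple shape) are illustrative, not the project's actual models.

```python
# Sketch of the batching pattern: instead of one UPDATE per game, accumulate
# win/loss totals in an in-memory dict and write once at the end. In the real
# fix the dict would feed a single bulk_update(); here we return the rows.
from collections import defaultdict

def recalculate_standings(games):
    """games: iterable of (winner_id, loser_id) tuples."""
    totals = defaultdict(lambda: {"wins": 0, "losses": 0})
    for winner_id, loser_id in games:    # in-memory accumulation, 0 queries
        totals[winner_id]["wins"] += 1
        totals[loser_id]["losses"] += 1
    return dict(totals)
```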
cal
1db06576cc Merge pull request 'fix: replace integer comparisons on boolean fields with True/False (#69)' (#94) from issue/69-boolean-fields-compared-as-integers-sqlite-pattern into main 2026-04-08 03:57:35 +00:00
Cal Corum
7a5327f490 fix: replace integer comparisons on boolean fields with True/False (#69)
Closes #69

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 22:57:01 -05:00
cal
a2889751da Merge pull request 'fix: remove SQLite fallback code from db_engine.py (#70)' (#89) from issue/70-remove-sqlite-fallback-code-from-db-engine-py into main 2026-04-08 03:56:11 +00:00
Cal Corum
eb886a4690 fix: remove SQLite fallback code from db_engine.py (#70)
Removes DATABASE_TYPE conditional entirely. PostgreSQL is now the only
supported backend. Moves PooledPostgresqlDatabase import to top-level
and raises RuntimeError at startup if POSTGRES_PASSWORD is unset,
preventing silent misconnection with misleading errors.

Closes #70

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 22:55:44 -05:00
cal
0ee7367bc0 Merge pull request 'fix: disable autoconnect and set pool timeout on PooledPostgresqlDatabase (#80)' (#87) from issue/80-disable-autoconnect-and-set-pool-timeout-on-pooled into main 2026-04-08 03:55:05 +00:00
Cal Corum
6637f6e9eb fix: disable autoconnect and set pool timeout on PooledPostgresqlDatabase (#80)
- Set timeout=5 so pool exhaustion surfaces as an error instead of hanging forever
- Set autoconnect=False to require explicit connection acquisition
- Add HTTP middleware in main.py to open/close connections per request

Closes #80

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 22:54:27 -05:00
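The per-request lifecycle this commit introduces can be illustrated without FastAPI or Peewee. `FakePool` and `handle_request` are stand-ins, assuming the real middleware opens a pooled connection at request start (since `autoconnect=False`) and always closes it afterward.

```python
# Stand-in for the middleware pattern: acquire a connection explicitly at
# the start of a request and release it in finally, even on errors.
class FakePool:
    def __init__(self):
        self.open_count = 0
    def connect(self):
        self.open_count += 1
    def close(self):
        self.open_count -= 1

def handle_request(db, handler):
    db.connect()          # explicit acquisition (autoconnect disabled)
    try:
        return handler()
    finally:
        db.close()        # always returned to the pool
```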
cal
fa176c9b05 Merge pull request 'fix: enforce Literal validation on sort parameter in GET /api/v3/players (#66)' (#68) from ai/major-domo-database-66 into main 2026-04-08 03:54:02 +00:00
Cal Corum
ece25ec22c fix: enforce Literal validation on sort parameter in GET /api/v3/players (#66)
Closes #66

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 22:53:33 -05:00
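A stdlib-only illustration of the allow-list idea: the real endpoint uses FastAPI's `Query` with a `Literal` type, which rejects unknown values automatically; the field names below are assumptions, not the actual sort options.

```python
# Literal-based validation for a sort parameter: only values in the Literal
# are accepted, everything else raises.
from typing import Literal, get_args

SortField = Literal["name", "team", "position"]  # illustrative fields
ALLOWED = set(get_args(SortField))

def validate_sort(sort: str) -> str:
    if sort not in ALLOWED:
        raise ValueError(f"sort must be one of {sorted(ALLOWED)}")
    return sort
```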
cal
580e8ea031 Merge pull request 'fix: align CustomCommandCreator.discord_id model with BIGINT column (#78)' (#88) from issue/78-fix-discord-id-type-mismatch-between-model-charfie into main 2026-04-08 03:49:26 +00:00
Cal Corum
18394aa74e fix: align CustomCommandCreator.discord_id model with BIGINT column (#78)
Closes #78

Change CharField(max_length=20) to BigIntegerField to match the BIGINT
column created by the migration. Remove the str() workaround in
get_creator_by_discord_id() that was compensating for the type mismatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 22:48:55 -05:00
cal
8e7466abd7 Merge pull request 'fix: remove token value from Bad Token log warnings (#79)' (#85) from issue/79-stop-logging-raw-auth-tokens-in-warning-messages into main 2026-04-08 03:43:13 +00:00
Cal Corum
d61bc31daa fix: remove token value from Bad Token log warnings (#79)
Closes #79

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 22:42:49 -05:00
cal
ca15dfe380 Merge pull request 'fix: replace row-by-row DELETE with bulk DELETE in career recalculation (#77)' (#92) from issue/77-replace-row-by-row-delete-with-bulk-delete-in-care into main 2026-04-08 03:24:58 +00:00
cal
1575be8260 Merge pull request 'fix: update Docker base image from Python 3.11 to 3.12 (#82)' (#91) from issue/82-align-python-version-between-docker-image-3-11-and into main 2026-04-08 03:24:46 +00:00
cal
7c7405cd1d Merge pull request 'feat: add migration tracking system (#81)' (#96) from issue/81-add-migration-tracking-system into main 2026-04-08 03:23:41 +00:00
cal
0cc0cba6a9 Merge pull request 'fix: replace deprecated Pydantic .dict() with .model_dump() (#76)' (#90) from issue/76-replace-deprecated-pydantic-dict-with-model-dump into main 2026-04-08 03:23:30 +00:00
cal
41fe4f6ce2 Merge pull request 'fix: add type annotations to untyped query parameters (#73)' (#86) from issue/73-add-type-annotations-to-untyped-query-parameters into main 2026-04-08 03:22:17 +00:00
cal
14234385fe Merge pull request 'fix: add combined_season classmethod to PitchingStat (#65)' (#67) from ai/major-domo-database-65 into main 2026-04-08 03:22:15 +00:00
cal
07aeaa8f3e Merge pull request 'fix: replace manual db.close() calls with middleware-based connection management (#71)' (#97) from issue/71-refactor-manual-db-close-calls-to-middleware-based into main 2026-04-08 02:42:09 +00:00
Cal Corum
701f790868 ci: retrigger build after transient Docker Hub push failure 2026-04-07 21:30:36 -05:00
Cal Corum
b46d8d33ef fix: remove empty finally clauses in custom_commands and help_commands
After removing db.close() calls, 22 finally: blocks were left empty
(12 in custom_commands.py, 10 in help_commands.py), causing
IndentationError at import time. Removed the finally: clause entirely
since connection lifecycle is now handled by the middleware.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 21:30:36 -05:00
Cal Corum
cfa6da06b7 fix: replace manual db.close() calls with middleware-based connection management (#71)
Closes #71

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 21:30:36 -05:00
cal
40897f1cc8 Merge pull request 'fix: move hardcoded Discord webhook URL to env var' (#83) from fix/remove-hardcoded-webhook into main
Reviewed-on: #83
Reviewed-by: Claude <cal.corum+openclaw@gmail.com>
2026-04-08 02:28:20 +00:00
12a76c2bb5 Merge branch 'main' into fix/remove-hardcoded-webhook 2026-04-08 02:24:10 +00:00
cal
aac4bf50d5 Merge pull request 'chore: switch CI to tag-triggered builds' (#107) from chore/tag-triggered-ci into main
Reviewed-on: #107
2026-04-06 16:59:02 +00:00
Cal Corum
4ad445b0da chore: switch CI to tag-triggered builds
Match the discord bot's CI pattern — trigger on CalVer tag push
instead of branch push/PR. Removes auto-CalVer generation and
simplifies to a single build step.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-06 16:58:45 +00:00
cal
8d9bbdd7a0 Merge pull request 'fix: increase get_games limit to 1000' (#106) from fix/increase-get-game-limit into main
Reviewed-on: #106
2026-04-06 15:30:47 +00:00
cal
c95459fa5d Update app/routers_v3/stratgame.py
2026-04-06 14:58:36 +00:00
cal
d809590f0e Merge pull request 'fix: correct column references in season pitching stats SQL' (#105) from fix/pitching-stats-column-name into main
2026-04-02 16:57:30 +00:00
cal
0d8e666a75 Merge pull request 'fix: let HTTPException pass through @handle_db_errors' (#104) from fix/handle-db-errors-passthrough-http into main
2026-04-02 16:57:12 +00:00
Cal Corum
bd19b7d913 fix: correct column references in season pitching stats view
sp.on_first/on_second/on_third don't exist — the actual columns are
on_first_id/on_second_id/on_third_id. This caused failures when
updating season pitching stats after games.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 11:54:56 -05:00
Cal Corum
c49f91cc19 test: update test_get_nonexistent_play to expect 404 after HTTPException fix
After handle_db_errors no longer catches HTTPException, GET /plays/999999999
correctly returns 404 instead of 500. Update the assertion and docstring
to reflect the fixed behavior.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-02 09:30:39 -05:00
Cal Corum
215085b326 fix: let HTTPException pass through @handle_db_errors unchanged
The decorator was catching all exceptions including intentional
HTTPException (401, 404, etc.) and re-wrapping them as 500 "Database
error". This masked auth failures and other deliberate HTTP errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 08:30:22 -05:00
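The fix described above is a standard decorator pattern: re-raise `HTTPException` untouched, wrap everything else. `HTTPException` is stubbed here for self-containment; the real code uses `fastapi.HTTPException`.

```python
# Pass-through decorator sketch: deliberate HTTP errors (401, 404, ...) are
# re-raised as-is; only unexpected exceptions become a 500 "Database error".
import functools

class HTTPException(Exception):   # stub for fastapi.HTTPException
    def __init__(self, status_code, detail=""):
        self.status_code = status_code
        self.detail = detail

def handle_db_errors(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except HTTPException:
            raise                              # let 401/404 pass through
        except Exception:
            raise HTTPException(500, "Database error")
    return wrapper
```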
cal
c063f5c4ef Merge pull request 'hotfix: remove output caps from GET /players' (#103) from hotfix/remove-players-output-caps into main
2026-04-02 01:19:51 +00:00
Cal Corum
d92f571960 hotfix: remove output caps from GET /players endpoint
The MAX_LIMIT/DEFAULT_LIMIT caps added in 16f3f8d are too restrictive
for the /players endpoint — bot and website consumers need full player
lists without pagination. Reverts limit param to Optional[int] with no
ceiling while keeping caps on all other endpoints.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 20:14:35 -05:00
cal
81baa54681 Merge pull request 'Fix unbounded API queries causing worker timeouts' (#99) from bugfix/limit-caps into main
Reviewed-on: #99
2026-04-01 22:44:38 +00:00
Cal Corum
67e87a893a Fix fieldingstats count computed after limit applied
Capture total_count before .limit() so the response count reflects
all matching rows, not just the capped page size. Resolves #100.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 17:40:02 -05:00
Cal Corum
16f3f8d8de Fix unbounded API queries causing Gunicorn worker timeouts
Add MAX_LIMIT=500 cap across all list endpoints, empty string
stripping middleware, and limit/offset to /transactions. Resolves #98.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 17:23:25 -05:00
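The capping rule named in this commit can be sketched in a few lines; the constants match the values the commits mention (`MAX_LIMIT=500`, `DEFAULT_LIMIT=200`), while the helper name is hypothetical.

```python
# Limit-capping sketch: clients may lower the page size but never exceed
# MAX_LIMIT; an omitted limit falls back to DEFAULT_LIMIT.
MAX_LIMIT = 500
DEFAULT_LIMIT = 200

def effective_limit(requested):
    if requested is None:
        return DEFAULT_LIMIT
    return min(requested, MAX_LIMIT)
```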
Cal Corum
b35b68a88f Merge remote-tracking branch 'origin/main' into fix/remove-hardcoded-webhook
2026-03-28 02:27:14 -05:00
cal
a1fa54c416 Merge pull request 'fix: remove hardcoded fallback password from DB connection' (#84) from fix/remove-default-db-password into main
2026-03-28 07:26:55 +00:00
Cal Corum
eccf4d1441 feat: add migration tracking system (#81)
Adds schema_versions table and migrations.py runner to prevent
double-application and missed migrations across dev/prod environments.

- migrations/2026-03-27_add_schema_versions_table.sql: creates tracking table
- migrations.py: applies pending .sql files in sorted order, records each in schema_versions
- .gitignore: untrack migrations.py (was incorrectly ignored as legacy root file)

First run on an existing DB will apply all migrations (safe — all use IF NOT EXISTS).

Closes #81

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 05:34:13 -05:00
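A minimal runner in the style this commit describes: apply pending `.sql` files in sorted order and record each in `schema_versions`. This sketch uses stdlib `sqlite3` purely for illustration; the project itself targets PostgreSQL, and the real `migrations.py` may differ in details.

```python
# Migration-runner sketch: a schema_versions table prevents double-application;
# files are applied in sorted (date-prefixed) filename order.
import sqlite3
from pathlib import Path

def run_migrations(db, migrations_dir):
    db.execute(
        "CREATE TABLE IF NOT EXISTS schema_versions (filename TEXT PRIMARY KEY)"
    )
    applied = {row[0] for row in db.execute("SELECT filename FROM schema_versions")}
    for path in sorted(Path(migrations_dir).glob("*.sql")):
        if path.name in applied:
            continue                      # already recorded: skip
        db.executescript(path.read_text())
        db.execute("INSERT INTO schema_versions VALUES (?)", (path.name,))
    db.commit()
```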
Cal Corum
d8c6ce2a5e fix: replace row-by-row DELETE with bulk DELETE in career recalculation (#77)
Closes #77

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 03:02:55 -05:00
Cal Corum
665f275546 fix: update Docker base image from Python 3.11 to 3.12 (#82)
Closes #82

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 02:32:01 -05:00
Cal Corum
75a8fc8505 fix: replace deprecated Pydantic .dict() with .model_dump() (#76)
Closes #76

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 02:02:49 -05:00
Cal Corum
dcaf184ad3 fix: add type annotations to untyped query parameters (#73)
Closes #73

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 00:02:57 -05:00
Cal Corum
1bcde424c6 Address PR review feedback for DISCORD_WEBHOOK_URL env var
- Add DISCORD_WEBHOOK_URL to docker-compose.yml api service environment block
- Add empty placeholder entry in .env for discoverability
- Move DISCORD_WEBHOOK_URL constant to the env-var constants section at top of dependencies.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-26 23:23:26 -05:00
Cal Corum
3be4f71e22 fix: move hardcoded Discord webhook URL to environment variable
Replace inline webhook URL+token with DISCORD_WEBHOOK_URL env var.
Logs a warning and returns False gracefully if the var is unset.

The exposed webhook token should be rotated in Discord.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 23:15:21 -05:00
Cal Corum
c451e02c52 fix: remove hardcoded fallback password from PostgreSQL connection
Raise RuntimeError on startup if POSTGRES_PASSWORD env var is not set,
instead of silently falling back to a known password in source code.

Closes #C2 from postgres migration review.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 23:15:07 -05:00
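The fail-fast check this commit adds reduces to a few lines. The helper name is hypothetical; the env-var name matches the commit message.

```python
# Fail-fast sketch: refuse to start when the password env var is unset,
# rather than silently falling back to a value baked into source code.
import os

def require_postgres_password(env=os.environ):
    password = env.get("POSTGRES_PASSWORD")
    if not password:
        raise RuntimeError("POSTGRES_PASSWORD is not set")
    return password
```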
Cal Corum
a21bb2a380 fix: add combined_season classmethod to PitchingStat (#65)
Closes #65

`PitchingStat.combined_season()` was referenced in the `get_pitstats`
handler but never defined, causing a 500 on `s_type=combined/total/all`.

Added `combined_season` as a `@staticmethod` matching the pattern of
`BattingStat.combined_season` — returns all rows for the given season
with no week filter (both regular and postseason).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 17:33:47 -05:00
cal
da679b6d1a Merge pull request 'Release: merge next-release into main' (#64) from next-release into main
Reviewed-on: #64
2026-03-17 21:43:36 +00:00
Cal Corum
697152808b fix: validate sort_by parameter with Literal type in views.py (#36)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:29:58 -05:00
Cal Corum
c40426d175 fix: remove unimplementable skipped caching tests (#33)
The three skipped tests in TestPlayerServiceCache required caching
in get_players() (read-through cache) and cache propagation through
the cls() pattern in write methods — neither is implemented and the
architecture does not support it without significant refactoring.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:29:58 -05:00
Cal Corum
95ff5eeaf9 fix: replace print(req.scope) with logger.debug in /api/docs (#21)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:29:58 -05:00
Cal Corum
a0d5d49724 fix: address review feedback (#52)
Guard bulk ID queries against empty lists to prevent PostgreSQL
syntax error (WHERE id IN ()) when batch POST endpoints receive
empty request bodies.

Affected endpoints:
- POST /api/v3/transactions
- POST /api/v3/results
- POST /api/v3/schedules
- POST /api/v3/battingstats

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:29:58 -05:00
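The guard described above avoids generating the invalid SQL `WHERE id IN ()`. The query builder here is a plain string for demonstration; the project uses Peewee, and the table name is illustrative.

```python
# Empty-batch guard sketch: skip the bulk lookup entirely when there are no
# IDs, instead of emitting "WHERE id IN ()" (a PostgreSQL syntax error).
def build_in_query(ids):
    if not ids:
        return None   # caller treats None as "nothing to look up"
    placeholders = ", ".join("%s" for _ in ids)
    return f"SELECT * FROM team WHERE id IN ({placeholders})"
```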
Cal Corum
b0fd1d89ea fix: eliminate N+1 queries in batch POST endpoints (#25)
Replace per-row Team/Player lookups with bulk IN-list queries before
the validation loop in post_transactions, post_results, post_schedules,
and post_batstats. A 50-move batch now uses 2 queries instead of 150.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:29:58 -05:00
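The bulk-lookup pattern from this commit, sketched with a stand-in fetcher: collect every referenced ID first, fetch them in one query, then validate each row against the resulting dict. `fetch_teams_by_ids` is a hypothetical stand-in for a single Peewee IN-list query.

```python
# N+1 fix sketch: one bulk query before the validation loop, instead of one
# lookup per row in the batch.
def validate_batch(moves, fetch_teams_by_ids):
    team_ids = {m["team_id"] for m in moves}
    teams = fetch_teams_by_ids(team_ids)     # 1 query, not len(moves)
    missing = [m["team_id"] for m in moves if m["team_id"] not in teams]
    return missing
```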
Cal Corum
5ac9cce7f0 fix: replace bare except: with except Exception: (#29)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:29:21 -05:00
Cal Corum
0e132e602f fix: remove unused imports in standings.py and pitchingstats.py (#30)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:28:25 -05:00
Cal Corum
d92bb263f1 fix: invalidate cache after PlayerService write operations (#32)
Add finally blocks to update_player, patch_player, create_players, and
delete_player in PlayerService to call invalidate_related_cache() using
the existing cache_patterns. Matches the pattern already used in
TeamService.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:28:07 -05:00
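The finally-based invalidation pattern can be shown with plain dicts standing in for the store and cache; `update_player` here is an illustrative stand-in, and the real code calls `invalidate_related_cache()` with cache patterns.

```python
# Invalidate-after-write sketch: the cache entry is dropped after every write
# attempt, even if the write itself raises.
def update_player(player_id, data, store, cache):
    try:
        store[player_id] = data
        return store[player_id]
    finally:
        cache.pop(player_id, None)   # stand-in for invalidate_related_cache()
```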
Cal Corum
9558da6ace fix: remove empty WEEK_NUMS dict from db_engine.py (#34)
Dead code - module-level constant defined but never referenced.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:28:07 -05:00
Cal Corum
6d1b0ac747 perf: push limit/offset to DB in PlayerService.get_players (#37)
Apply .offset() and .limit() on the Peewee query before materializing
results, instead of fetching all rows into memory and slicing in Python.
Total count is obtained via query.count() before pagination is applied.
In-memory (mock) queries continue to use Python-level slicing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:28:07 -05:00
Cal Corum
a351010c3c fix: calculate lob_2outs and rbipercent in SeasonPitchingStats (#28)
Both fields were hardcoded to 0.0 in the INSERT. Added SQL expressions
to the pitching_stats CTE to calculate them from stratplay data, using
the same logic as the batting stats endpoint.

- lob_2outs: count of runners stranded when pitcher recorded the 3rd out
- rbipercent: RBI allowed (excluding HR) per runner opportunity

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 16:28:07 -05:00
45 changed files with 2667 additions and 1666 deletions

.env

@@ -6,6 +6,9 @@ SBA_DB_USER_PASSWORD=your_production_password
 # SBa API
 API_TOKEN=Tp3aO3jhYve5NJF1IqOmJTmk

+# Integrations
+DISCORD_WEBHOOK_URL=
+
 # Universal
 TZ=America/Chicago
 LOG_LEVEL=INFO


@@ -1,20 +1,18 @@
 # Gitea Actions: Docker Build, Push, and Notify
 #
 # CI/CD pipeline for Major Domo Database API:
-# - Builds Docker images on every push/PR
-# - Auto-generates CalVer version (YYYY.MM.BUILD) on main branch merges
-# - Pushes to Docker Hub and creates git tag on main
+# - Triggered by pushing a CalVer tag (e.g., 2026.4.5)
+# - Builds Docker image and pushes to Docker Hub with version + latest tags
 # - Sends Discord notifications on success/failure
+#
+# To release: git tag -a 2026.4.5 -m "description" && git push origin 2026.4.5

 name: Build Docker Image

 on:
   push:
-    branches:
-      - main
-  pull_request:
-    branches:
-      - main
+    tags:
+      - '20*'  # matches CalVer tags like 2026.4.5

 jobs:
   build:
jobs: jobs:
build: build:
@@ -24,7 +22,16 @@ jobs:
       - name: Checkout code
        uses: https://github.com/actions/checkout@v4
        with:
-          fetch-depth: 0  # Full history for tag counting
+          fetch-depth: 0
+
+      - name: Extract version from tag
+        id: version
+        run: |
+          VERSION=${GITHUB_REF#refs/tags/}
+          SHA_SHORT=$(git rev-parse --short HEAD)
+          echo "version=$VERSION" >> $GITHUB_OUTPUT
+          echo "sha_short=$SHA_SHORT" >> $GITHUB_OUTPUT
+          echo "timestamp=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_OUTPUT

       - name: Set up Docker Buildx
         uses: https://github.com/docker/setup-buildx-action@v3
@@ -35,80 +42,47 @@ jobs:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}

-      - name: Generate CalVer version
-        id: calver
-        uses: cal/gitea-actions/calver@main
-
-      # Dev build: push with dev + dev-SHA tags (PR/feature branches)
-      - name: Build Docker image (dev)
-        if: github.ref != 'refs/heads/main'
-        uses: https://github.com/docker/build-push-action@v5
-        with:
-          context: .
-          push: true
-          tags: |
-            manticorum67/major-domo-database:dev
-            manticorum67/major-domo-database:dev-${{ steps.calver.outputs.sha_short }}
-          cache-from: type=registry,ref=manticorum67/major-domo-database:buildcache
-          cache-to: type=registry,ref=manticorum67/major-domo-database:buildcache,mode=max
-
-      # Production build: push with latest + CalVer tags (main only)
-      - name: Build Docker image (production)
-        if: github.ref == 'refs/heads/main'
+      - name: Build and push Docker image
         uses: https://github.com/docker/build-push-action@v5
         with:
           context: .
           push: true
           tags: |
+            manticorum67/major-domo-database:${{ steps.version.outputs.version }}
             manticorum67/major-domo-database:latest
-            manticorum67/major-domo-database:${{ steps.calver.outputs.version }}
-            manticorum67/major-domo-database:${{ steps.calver.outputs.version_sha }}
           cache-from: type=registry,ref=manticorum67/major-domo-database:buildcache
           cache-to: type=registry,ref=manticorum67/major-domo-database:buildcache,mode=max

-      - name: Tag release
-        if: success() && github.ref == 'refs/heads/main'
-        uses: cal/gitea-actions/gitea-tag@main
-        with:
-          version: ${{ steps.calver.outputs.version }}
-          token: ${{ github.token }}
-
       - name: Build Summary
         run: |
           echo "## Docker Build Successful" >> $GITHUB_STEP_SUMMARY
           echo "" >> $GITHUB_STEP_SUMMARY
+          echo "**Version:** \`${{ steps.version.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
+          echo "" >> $GITHUB_STEP_SUMMARY
           echo "**Image Tags:**" >> $GITHUB_STEP_SUMMARY
+          echo "- \`manticorum67/major-domo-database:${{ steps.version.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
           echo "- \`manticorum67/major-domo-database:latest\`" >> $GITHUB_STEP_SUMMARY
-          echo "- \`manticorum67/major-domo-database:${{ steps.calver.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
-          echo "- \`manticorum67/major-domo-database:${{ steps.calver.outputs.version_sha }}\`" >> $GITHUB_STEP_SUMMARY
           echo "" >> $GITHUB_STEP_SUMMARY
           echo "**Build Details:**" >> $GITHUB_STEP_SUMMARY
-          echo "- Branch: \`${{ steps.calver.outputs.branch }}\`" >> $GITHUB_STEP_SUMMARY
-          echo "- Commit: \`${{ github.sha }}\`" >> $GITHUB_STEP_SUMMARY
-          echo "- Timestamp: \`${{ steps.calver.outputs.timestamp }}\`" >> $GITHUB_STEP_SUMMARY
+          echo "- Commit: \`${{ steps.version.outputs.sha_short }}\`" >> $GITHUB_STEP_SUMMARY
+          echo "- Timestamp: \`${{ steps.version.outputs.timestamp }}\`" >> $GITHUB_STEP_SUMMARY
           echo "" >> $GITHUB_STEP_SUMMARY
-          if [ "${{ github.ref }}" == "refs/heads/main" ]; then
-            echo "Pushed to Docker Hub!" >> $GITHUB_STEP_SUMMARY
-            echo "" >> $GITHUB_STEP_SUMMARY
-            echo "Pull with: \`docker pull manticorum67/major-domo-database:latest\`" >> $GITHUB_STEP_SUMMARY
-          else
-            echo "_PR build - image not pushed to Docker Hub_" >> $GITHUB_STEP_SUMMARY
-          fi
+          echo "Pull with: \`docker pull manticorum67/major-domo-database:${{ steps.version.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY

       - name: Discord Notification - Success
-        if: success() && github.ref == 'refs/heads/main'
+        if: success()
         uses: cal/gitea-actions/discord-notify@main
         with:
           webhook_url: ${{ secrets.DISCORD_WEBHOOK }}
           title: "Major Domo Database"
           status: success
-          version: ${{ steps.calver.outputs.version }}
-          image_tag: ${{ steps.calver.outputs.version_sha }}
-          commit_sha: ${{ steps.calver.outputs.sha_short }}
-          timestamp: ${{ steps.calver.outputs.timestamp }}
+          version: ${{ steps.version.outputs.version }}
+          image_tag: ${{ steps.version.outputs.version }}
+          commit_sha: ${{ steps.version.outputs.sha_short }}
+          timestamp: ${{ steps.version.outputs.timestamp }}

       - name: Discord Notification - Failure
-        if: failure() && github.ref == 'refs/heads/main'
+        if: failure()
         uses: cal/gitea-actions/discord-notify@main
         with:
           webhook_url: ${{ secrets.DISCORD_WEBHOOK }}

.gitignore

@@ -55,7 +55,6 @@ Include/
 pyvenv.cfg
 db_engine.py
 main.py
-migrations.py
 db_engine.py
 sba_master.db
 db_engine.py


@@ -40,7 +40,7 @@ python migrations.py # Run migrations (SQL files in migrat
 - **Bot container**: `dev_sba_postgres` (PostgreSQL) + `dev_sba_db_api` (API) — check with `docker ps`
 - **Image**: `manticorum67/major-domo-database:dev` (Docker Hub)
-- **CI/CD**: Gitea Actions on PR to `main` — builds Docker image, auto-generates CalVer version (`YYYY.MM.BUILD`) on merge
+- **CI/CD**: Gitea Actions — tag-triggered Docker builds. Push a CalVer tag to release: `git tag -a 2026.4.5 -m "description" && git push origin 2026.4.5`

 ## Important


@@ -1,5 +1,5 @@
 # Use specific version for reproducible builds
-FROM tiangolo/uvicorn-gunicorn-fastapi:python3.11
+FROM tiangolo/uvicorn-gunicorn-fastapi:python3.12

 # Set Python optimizations
 ENV PYTHONUNBUFFERED=1

File diff suppressed because it is too large


@@ -11,8 +11,8 @@ from fastapi import HTTPException, Response
 from fastapi.security import OAuth2PasswordBearer
 from redis import Redis

-date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
-logger = logging.getLogger('discord_app')
+date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}"
+logger = logging.getLogger("discord_app")

 # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
 # log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
@@ -22,11 +22,14 @@ logger = logging.getLogger('discord_app')
 #     level=log_level
 # )

+# Discord integration
+DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
+
 # Redis configuration
-REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost')
-REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379'))
-REDIS_DB = int(os.environ.get('REDIS_DB', '0'))
-CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true'
+REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
+REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
+REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
+CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true"

 # Initialize Redis client with connection error handling
 if not CACHE_ENABLED:
@@ -40,7 +43,7 @@ else:
             db=REDIS_DB,
             decode_responses=True,
             socket_connect_timeout=5,
-            socket_timeout=5
+            socket_timeout=5,
         )
         # Test connection
         redis_client.ping()
@@ -50,12 +53,19 @@ else:
         redis_client = None

 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper()
-PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False
+priv_help = (
+    False
+    if not os.environ.get("PRIVATE_IN_SCHEMA")
+    else os.environ.get("PRIVATE_IN_SCHEMA").upper()
+)
+PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False
+
+MAX_LIMIT = 500
+DEFAULT_LIMIT = 200


 def valid_token(token):
-    return token == os.environ.get('API_TOKEN')
+    return token == os.environ.get("API_TOKEN")


 def update_season_batting_stats(player_ids, season, db_connection):
@@ -72,7 +82,9 @@ def update_season_batting_stats(player_ids, season, db_connection):
     if isinstance(player_ids, int):
         player_ids = [player_ids]

-    logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}")
+    logger.info(
+        f"Updating season batting stats for {len(player_ids)} players in season {season}"
+    )

     try:
         # SQL query to recalculate and upsert batting stats
@@ -221,7 +233,9 @@ def update_season_batting_stats(player_ids, season, db_connection):
         # Execute the query with parameters using the passed database connection
         db_connection.execute_sql(query, [season, player_ids, season, player_ids])

-        logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}")
+        logger.info(
+            f"Successfully updated season batting stats for {len(player_ids)} players in season {season}"
+        )

     except Exception as e:
         logger.error(f"Error updating season batting stats: {e}")
@@ -242,7 +256,9 @@ def update_season_pitching_stats(player_ids, season, db_connection):
    if isinstance(player_ids, int):
        player_ids = [player_ids]

    logger.info(
        f"Updating season pitching stats for {len(player_ids)} players in season {season}"
    )

    try:
        # SQL query to recalculate and upsert pitching stats

@@ -357,7 +373,27 @@ def update_season_pitching_stats(player_ids, season, db_connection):
                WHEN SUM(sp.bb) > 0
                THEN ROUND(SUM(sp.so)::DECIMAL / SUM(sp.bb), 2)
                ELSE 0.0
            END AS kperbb,
            -- Runners left on base when pitcher recorded the 3rd out
            SUM(CASE WHEN sp.on_first_final IS NOT NULL AND sp.on_first_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) +
            SUM(CASE WHEN sp.on_second_final IS NOT NULL AND sp.on_second_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) +
            SUM(CASE WHEN sp.on_third_final IS NOT NULL AND sp.on_third_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) AS lob_2outs,
            -- RBI allowed (excluding HR) per runner opportunity
            CASE
                WHEN (SUM(CASE WHEN sp.on_first_id IS NOT NULL THEN 1 ELSE 0 END) +
                      SUM(CASE WHEN sp.on_second_id IS NOT NULL THEN 1 ELSE 0 END) +
                      SUM(CASE WHEN sp.on_third_id IS NOT NULL THEN 1 ELSE 0 END)) > 0
                THEN ROUND(
                    (SUM(sp.rbi) - SUM(sp.homerun))::DECIMAL /
                    (SUM(CASE WHEN sp.on_first_id IS NOT NULL THEN 1 ELSE 0 END) +
                     SUM(CASE WHEN sp.on_second_id IS NOT NULL THEN 1 ELSE 0 END) +
                     SUM(CASE WHEN sp.on_third_id IS NOT NULL THEN 1 ELSE 0 END)),
                    3
                )
                ELSE 0.000
            END AS rbipercent
        FROM stratplay sp
        JOIN stratgame sg ON sg.id = sp.game_id
@@ -402,7 +438,7 @@ def update_season_pitching_stats(player_ids, season, db_connection):
            ps.bphr, ps.bpfo, ps.bp1b, ps.bplo, ps.wp, ps.balk,
            ps.wpa * -1, ps.era, ps.whip, ps.avg, ps.obp, ps.slg, ps.ops, ps.woba,
            ps.hper9, ps.kper9, ps.bbper9, ps.kperbb,
            ps.lob_2outs, ps.rbipercent, COALESCE(ps.re24 * -1, 0.0)
        FROM pitching_stats ps
        LEFT JOIN decision_stats ds ON ps.player_id = ds.player_id AND ps.season = ds.season
        ON CONFLICT (player_id, season)

@@ -464,7 +500,9 @@ def update_season_pitching_stats(player_ids, season, db_connection):
        # Execute the query with parameters using the passed database connection
        db_connection.execute_sql(query, [season, player_ids, season, player_ids])

        logger.info(
            f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}"
        )
    except Exception as e:
        logger.error(f"Error updating season pitching stats: {e}")
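The new `rbipercent` CASE expression divides non-HR RBI allowed by the number of runner opportunities, guarding against division by zero and rounding to three places. The same arithmetic in plain Python, with illustrative numbers rather than real stats:

```python
def rbi_percent(rbi: int, homeruns: int, runners_on: int) -> float:
    # Mirrors the SQL CASE: zero opportunities yields 0.0,
    # otherwise (RBI - HR) / opportunities rounded to 3 places.
    if runners_on <= 0:
        return 0.0
    return round((rbi - homeruns) / runners_on, 3)

print(rbi_percent(12, 3, 40))  # 9 non-HR RBI over 40 opportunities → 0.225
```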
@@ -481,12 +519,15 @@ def send_webhook_message(message: str) -> bool:
    Returns:
        bool: True if successful, False otherwise
    """
    webhook_url = DISCORD_WEBHOOK_URL
    if not webhook_url:
        logger.warning(
            "DISCORD_WEBHOOK_URL env var is not set — skipping webhook message"
        )
        return False

    try:
        payload = {"content": message}

        response = requests.post(webhook_url, json=payload, timeout=10)
        response.raise_for_status()

@@ -502,7 +543,9 @@ def send_webhook_message(message: str) -> bool:
        return False
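The rewritten function replaces a hardcoded (and leaked) webhook URL with a configured one and bails out early when none is set. A testable sketch of that control flow, with `post` standing in for `requests.post` so the example runs offline; the parameterization is for illustration only, not the diff's actual signature:

```python
def send_webhook_message(message: str, webhook_url, post) -> bool:
    # Guard: with no webhook configured, skip the send and report failure.
    if not webhook_url:
        return False
    try:
        response = post(webhook_url, json={"content": message}, timeout=10)
        response.raise_for_status()  # treat HTTP 4xx/5xx as failure
        return True
    except Exception:
        return False

class _FakeResponse:
    def raise_for_status(self):
        pass

ok = send_webhook_message(
    "hi", "https://example.invalid/hook",
    lambda url, json, timeout: _FakeResponse(),
)
print(ok)  # → True
print(send_webhook_message("hi", None, None))  # → False
```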
def cache_result(
    ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True
):
    """
    Decorator to cache function results in Redis with parameter normalization.

@@ -520,6 +563,7 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
    # These will use the same cache entry when normalize_params=True:
    # get_player_stats(123, None) and get_player_stats(123)
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):

@@ -533,15 +577,16 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
            if normalize_params:
                # Remove None values and empty collections
                normalized_kwargs = {
                    k: v
                    for k, v in kwargs.items()
                    if v is not None and v != [] and v != "" and v != {}
                }

                # Generate more readable cache key
                args_str = "_".join(str(arg) for arg in args if arg is not None)
                kwargs_str = "_".join(
                    [f"{k}={v}" for k, v in sorted(normalized_kwargs.items())]
                )

                # Combine args and kwargs for cache key
                key_parts = [key_prefix, func.__name__]

@@ -572,10 +617,12 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
                    redis_client.setex(
                        cache_key,
                        ttl,
                        json.dumps(result, default=str, ensure_ascii=False),
                    )
                else:
                    logger.debug(
                        f"Skipping cache for Response object from {func.__name__}"
                    )

                return result

@@ -585,6 +632,7 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
                return await func(*args, **kwargs)

        return wrapper

    return decorator
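The normalization branch exists so that `get_player_stats(123, season=None)` and `get_player_stats(123)` hash to the same Redis key. The key-building logic lifted out of the decorator for illustration; the final join separator isn't visible in this hunk, so `":"` is an assumption:

```python
def build_cache_key(key_prefix, func_name, args, kwargs):
    # Drop None values and empty collections so equivalent calls share a key.
    normalized = {
        k: v for k, v in kwargs.items()
        if v is not None and v != [] and v != "" and v != {}
    }
    args_str = "_".join(str(a) for a in args if a is not None)
    kwargs_str = "_".join(f"{k}={v}" for k, v in sorted(normalized.items()))
    # Assumed separator; the real decorator's key_parts join isn't shown above.
    return ":".join(p for p in [key_prefix, func_name, args_str, kwargs_str] if p)

print(build_cache_key("api", "get_player_stats", (123,), {"season": None}))
print(build_cache_key("api", "get_player_stats", (123,), {}))
# both → api:get_player_stats:123
```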
@@ -607,7 +655,9 @@ def invalidate_cache(pattern: str = "*"):
        keys = redis_client.keys(pattern)
        if keys:
            deleted = redis_client.delete(*keys)
            logger.info(
                f"Invalidated {deleted} cache entries matching pattern: {pattern}"
            )
            return deleted
        else:
            logger.debug(f"No cache entries found matching pattern: {pattern}")

@@ -634,7 +684,7 @@ def get_cache_stats() -> dict:
            "memory_used": info.get("used_memory_human", "unknown"),
            "total_keys": redis_client.dbsize(),
            "connected_clients": info.get("connected_clients", 0),
            "uptime_seconds": info.get("uptime_in_seconds", 0),
        }
    except Exception as e:
        logger.error(f"Error getting cache stats: {e}")
@@ -645,7 +695,7 @@ def add_cache_headers(
    max_age: int = 300,
    cache_type: str = "public",
    vary: Optional[str] = None,
    etag: bool = False,
):
    """
    Decorator to add HTTP cache headers to FastAPI responses.

@@ -665,6 +715,7 @@ def add_cache_headers(
    async def get_user_data():
        return {"data": "user specific"}
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):

@@ -677,7 +728,7 @@ def add_cache_headers(
                # Convert to Response with JSON content
                response = Response(
                    content=json.dumps(result, default=str, ensure_ascii=False),
                    media_type="application/json",
                )
            else:
                # Handle other response types

@@ -695,20 +746,23 @@ def add_cache_headers(
                response.headers["Vary"] = vary

            # Add ETag if requested
            if etag and (
                hasattr(result, "__dict__") or isinstance(result, (dict, list))
            ):
                content_hash = hashlib.md5(
                    json.dumps(result, default=str, sort_keys=True).encode()
                ).hexdigest()
                response.headers["ETag"] = f'"{content_hash}"'

            # Add Last-Modified header with current time for dynamic content
            response.headers["Last-Modified"] = datetime.datetime.now(
                datetime.timezone.utc
            ).strftime("%a, %d %b %Y %H:%M:%S GMT")

            return response

        return wrapper

    return decorator
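The ETag branch hashes a canonical JSON serialization of the payload: `sort_keys=True` makes the serialization stable, so identical payloads always produce the same tag regardless of dict insertion order. Shown standalone:

```python
import hashlib
import json

def compute_etag(result) -> str:
    # Canonical serialization: sort_keys=True makes key order irrelevant,
    # default=str handles dates and other non-JSON types.
    content_hash = hashlib.md5(
        json.dumps(result, default=str, sort_keys=True).encode()
    ).hexdigest()
    return f'"{content_hash}"'  # ETag values are quoted per RFC 9110

print(compute_etag({"a": 1, "b": 2}) == compute_etag({"b": 2, "a": 1}))  # → True
```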
@@ -718,6 +772,7 @@ def handle_db_errors(func):
    Ensures proper cleanup of database connections and provides consistent error handling.
    Includes comprehensive logging with function context, timing, and stack traces.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        import time

@@ -734,18 +789,24 @@ def handle_db_errors(func):
        try:
            # Log sanitized arguments (avoid logging tokens, passwords, etc.)
            for arg in args:
                if hasattr(arg, "__dict__") and hasattr(
                    arg, "url"
                ):  # FastAPI Request object
                    safe_args.append(
                        f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})"
                    )
                else:
                    safe_args.append(str(arg)[:100])  # Truncate long values

            for key, value in kwargs.items():
                if key.lower() in ["token", "password", "secret", "key"]:
                    safe_kwargs[key] = "[REDACTED]"
                else:
                    safe_kwargs[key] = str(value)[:100]  # Truncate long values

            logger.info(
                f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}"
            )

            result = await func(*args, **kwargs)

@@ -754,6 +815,10 @@ def handle_db_errors(func):
            return result

        except HTTPException:
            # Let intentional HTTP errors (401, 404, etc.) pass through unchanged
            raise
        except Exception as e:
            elapsed_time = time.time() - start_time
            error_trace = traceback.format_exc()

@@ -775,8 +840,12 @@ def handle_db_errors(func):
                    db.close()
                    logger.info(f"Database connection closed for {func_name}")
                except Exception as close_error:
                    logger.error(
                        f"Error closing database connection in {func_name}: {close_error}"
                    )

            raise HTTPException(
                status_code=500, detail=f"Database error in {func_name}: {str(e)}"
            )

    return wrapper
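The new `except HTTPException: raise` clause matters because Python's `except` blocks match top to bottom: without it, a deliberate 404 or 401 raised inside a handler would be swallowed by the generic `except Exception` and rewrapped as a 500. The pattern reduced to a minimal synchronous decorator, with `HTTPError` as a stand-in class since `fastapi.HTTPException` isn't imported here:

```python
import functools

class HTTPError(Exception):  # stand-in for fastapi.HTTPException
    def __init__(self, status_code):
        self.status_code = status_code

def handle_errors(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except HTTPError:
            raise  # intentional HTTP errors pass through unchanged
        except Exception:
            raise HTTPError(500)  # everything else becomes a 500
    return wrapper

@handle_errors
def not_found():
    raise HTTPError(404)

try:
    not_found()
except HTTPError as e:
    print(e.status_code)  # → 404, not 500
```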
View File
@@ -2,46 +2,112 @@ import datetime
import logging
from logging.handlers import RotatingFileHandler
import os
from urllib.parse import parse_qsl, urlencode

from fastapi import Depends, FastAPI, Request
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi

from .db_engine import db

# from fastapi.openapi.docs import get_swagger_ui_html
# from fastapi.openapi.utils import get_openapi
from .routers_v3 import (
    current,
    players,
    results,
    schedules,
    standings,
    teams,
    transactions,
    battingstats,
    pitchingstats,
    fieldingstats,
    draftpicks,
    draftlist,
    managers,
    awards,
    draftdata,
    keepers,
    stratgame,
    stratplay,
    injuries,
    decisions,
    divisions,
    sbaplayers,
    custom_commands,
    help_commands,
    views,
)
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
log_level = logging.INFO if os.environ.get("LOG_LEVEL") == "INFO" else logging.WARNING
# logging.basicConfig(
#     filename=f'logs/database/{date}.log',
#     format='%(asctime)s - sba-database - %(levelname)s - %(message)s',
#     level=log_level
# )
logger = logging.getLogger("discord_app")
logger.setLevel(log_level)
handler = RotatingFileHandler(
    filename="./logs/sba-database.log",
    # encoding='utf-8',
    maxBytes=8 * 1024 * 1024,  # 8 MiB
    backupCount=5,  # Rotate through 5 files
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

app = FastAPI(
    # root_path='/api',
    responses={404: {"description": "Not found"}},
    docs_url="/api/docs",
    redoc_url="/api/redoc",
)


@app.middleware("http")
async def db_connection_middleware(request: Request, call_next):
    db.connect(reuse_if_open=True)
    try:
        response = await call_next(request)
    finally:
        if not db.is_closed():
            db.close()
    return response


logger.info("Starting up now...")
@app.middleware("http")
async def strip_empty_query_params(request: Request, call_next):
qs = request.scope.get("query_string", b"")
if qs:
pairs = parse_qsl(qs.decode(), keep_blank_values=True)
filtered = [(k, v) for k, v in pairs if v != ""]
new_qs = urlencode(filtered).encode()
request.scope["query_string"] = new_qs
if hasattr(request, "_query_params"):
del request._query_params
return await call_next(request)
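The middleware above rewrites the raw ASGI query string so that blank parameters (e.g. `?season=&team_id=3`) disappear before FastAPI tries to validate them as integers. The core transformation, extracted into a pure function for illustration:

```python
from urllib.parse import parse_qsl, urlencode

def strip_empty_params(query_string: bytes) -> bytes:
    # keep_blank_values=True makes parse_qsl return ("season", "") instead
    # of dropping the pair, so we can filter blanks out explicitly.
    pairs = parse_qsl(query_string.decode(), keep_blank_values=True)
    return urlencode([(k, v) for k, v in pairs if v != ""]).encode()

print(strip_empty_params(b"season=&team_id=3"))  # → b'team_id=3'
```

Deleting `request._query_params` afterward (as the middleware does) forces Starlette to re-parse its cached `QueryParams` from the rewritten scope.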
app.include_router(current.router)

@@ -70,18 +136,20 @@ app.include_router(custom_commands.router)
app.include_router(help_commands.router)
app.include_router(views.router)

logger.info("Loaded all routers.")


@app.get("/api/docs", include_in_schema=False)
async def get_docs(req: Request):
    logger.debug(req.scope)
    return get_swagger_ui_html(
        openapi_url=req.scope.get("root_path") + "/openapi.json", title="Swagger"
    )


@app.get("/api/openapi.json", include_in_schema=False)
async def openapi():
    return get_openapi(title="SBa API Docs", version="0.1.1", routes=app.routes)


# @app.get("/api")
View File
@@ -9,6 +9,8 @@ from ..dependencies import (
    valid_token,
    PRIVATE_IN_SCHEMA,
    handle_db_errors,
    MAX_LIMIT,
    DEFAULT_LIMIT,
)

logger = logging.getLogger("discord_app")

@@ -43,6 +45,8 @@ async def get_awards(
    team_id: list = Query(default=None),
    short_output: Optional[bool] = False,
    player_name: list = Query(default=None),
    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
    offset: int = Query(default=0, ge=0),
):
    all_awards = Award.select()

@@ -67,11 +71,13 @@ async def get_awards(
        all_players = Player.select().where(fn.Lower(Player.name) << pname_list)
        all_awards = all_awards.where(Award.player << all_players)

    total_count = all_awards.count()
    all_awards = all_awards.offset(offset).limit(limit)

    return_awards = {
        "count": total_count,
        "awards": [model_to_dict(x, recurse=not short_output) for x in all_awards],
    }

    return return_awards
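The ordering in this hunk is the whole point: `count()` must run before `offset()`/`limit()` are applied, otherwise `count` reports the page size instead of the full result size (the bug fixed in #101). The same invariant demonstrated with plain list slicing:

```python
def paginate(rows, limit, offset):
    total_count = len(rows)             # count the full result set FIRST
    page = rows[offset:offset + limit]  # then take the requested page
    return {"count": total_count, "rows": page}

result = paginate(list(range(10)), limit=3, offset=6)
print(result)  # → {'count': 10, 'rows': [6, 7, 8]}
```

Counting after slicing would have returned `count: 3` here, which is exactly the overwrite the fix reverts.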
@@ -80,10 +86,8 @@ async def get_awards(
async def get_one_award(award_id: int, short_output: Optional[bool] = False):
    this_award = Award.get_or_none(Award.id == award_id)
    if this_award is None:
        raise HTTPException(status_code=404, detail=f"Award ID {award_id} not found")

    return model_to_dict(this_award, recurse=not short_output)

@@ -102,12 +106,11 @@ async def patch_award(
    token: str = Depends(oauth2_scheme),
):
    if not valid_token(token):
        logger.warning("patch_player - Bad Token")
        raise HTTPException(status_code=401, detail="Unauthorized")

    this_award = Award.get_or_none(Award.id == award_id)
    if this_award is None:
        raise HTTPException(status_code=404, detail=f"Award ID {award_id} not found")

    if name is not None:

@@ -129,10 +132,8 @@ async def patch_award(
    if this_award.save() == 1:
        r_award = model_to_dict(this_award)
        return r_award
    else:
        raise HTTPException(status_code=500, detail=f"Unable to patch award {award_id}")

@@ -140,7 +141,7 @@ async def patch_award(
@handle_db_errors
async def post_award(award_list: AwardList, token: str = Depends(oauth2_scheme)):
    if not valid_token(token):
        logger.warning("patch_player - Bad Token")
        raise HTTPException(status_code=401, detail="Unauthorized")

    new_awards = []

@@ -171,12 +172,11 @@ async def post_award(award_list: AwardList, token: str = Depends(oauth2_scheme))
                status_code=404, detail=f"Team ID {x.team_id} not found"
            )

        new_awards.append(x.model_dump())

    with db.atomic():
        for batch in chunked(new_awards, 15):
            Award.insert_many(batch).on_conflict_ignore().execute()

    return f"Inserted {len(new_awards)} awards"

@@ -185,16 +185,14 @@ async def post_award(award_list: AwardList, token: str = Depends(oauth2_scheme))
@handle_db_errors
async def delete_award(award_id: int, token: str = Depends(oauth2_scheme)):
    if not valid_token(token):
        logger.warning("patch_player - Bad Token")
        raise HTTPException(status_code=401, detail="Unauthorized")

    this_award = Award.get_or_none(Award.id == award_id)
    if this_award is None:
        raise HTTPException(status_code=404, detail=f"Award ID {award_id} not found")

    count = this_award.delete_instance()

    if count == 1:
        return f"Award {award_id} has been deleted"
View File
@@ -19,6 +19,8 @@ from ..dependencies import (
    valid_token,
    PRIVATE_IN_SCHEMA,
    handle_db_errors,
    MAX_LIMIT,
    DEFAULT_LIMIT,
)

logger = logging.getLogger("discord_app")

@@ -84,24 +86,21 @@ async def get_batstats(
    week_end: Optional[int] = None,
    game_num: list = Query(default=None),
    position: list = Query(default=None),
    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
    sort: Optional[str] = None,
    short_output: Optional[bool] = True,
):
    if "post" in s_type.lower():
        all_stats = BattingStat.post_season(season)
        if all_stats.count() == 0:
            return {"count": 0, "stats": []}
    elif s_type.lower() in ["combined", "total", "all"]:
        all_stats = BattingStat.combined_season(season)
        if all_stats.count() == 0:
            return {"count": 0, "stats": []}
    else:
        all_stats = BattingStat.regular_season(season)
        if all_stats.count() == 0:
            return {"count": 0, "stats": []}

    if position is not None:

@@ -127,14 +126,12 @@ async def get_batstats(
    if week_end is not None:
        end = min(week_end, end)
    if start > end:
        raise HTTPException(
            status_code=404,
            detail=f"Start week {start} is after end week {end} - cannot pull stats",
        )
    all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end))

    all_stats = all_stats.limit(limit)
    if sort:
        if sort == "newest":

@@ -146,7 +143,6 @@ async def get_batstats(
        # 'stats': [{'id': x.id} for x in all_stats]
    }

    return return_stats
@@ -168,6 +164,8 @@ async def get_totalstats(
    short_output: Optional[bool] = False,
    min_pa: Optional[int] = 1,
    week: list = Query(default=None),
    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
    offset: int = Query(default=0, ge=0),
):
    if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1:
        raise HTTPException(

@@ -301,7 +299,10 @@ async def get_totalstats(
        all_players = Player.select().where(Player.id << player_id)
        all_stats = all_stats.where(BattingStat.player << all_players)

    total_count = all_stats.count()
    all_stats = all_stats.offset(offset).limit(limit)

    return_stats = {"count": total_count, "stats": []}

    for x in all_stats:
        # Handle player field based on grouping with safe access

@@ -344,7 +345,6 @@ async def get_totalstats(
                "bplo": x.sum_bplo,
            }
        )

    return return_stats

@@ -360,15 +360,16 @@ async def patch_batstats(
    stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme)
):
    if not valid_token(token):
        logger.warning("patch_batstats - Bad Token")
        raise HTTPException(status_code=401, detail="Unauthorized")

    if BattingStat.get_or_none(BattingStat.id == stat_id) is None:
        raise HTTPException(status_code=404, detail=f"Stat ID {stat_id} not found")

    BattingStat.update(**new_stats.model_dump()).where(
        BattingStat.id == stat_id
    ).execute()
    r_stat = model_to_dict(BattingStat.get_by_id(stat_id))

    return r_stat
@@ -376,24 +377,35 @@ async def patch_batstats(
@handle_db_errors
async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)):
    if not valid_token(token):
        logger.warning("post_batstats - Bad Token")
        raise HTTPException(status_code=401, detail="Unauthorized")

    all_stats = []
    all_team_ids = list(set(x.team_id for x in s_list.stats))
    all_player_ids = list(set(x.player_id for x in s_list.stats))
    found_team_ids = (
        set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
        if all_team_ids
        else set()
    )
    found_player_ids = (
        set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids))
        if all_player_ids
        else set()
    )
    for x in s_list.stats:
        if x.team_id not in found_team_ids:
            raise HTTPException(
                status_code=404, detail=f"Team ID {x.team_id} not found"
            )
        if x.player_id not in found_player_ids:
            raise HTTPException(
                status_code=404, detail=f"Player ID {x.player_id} not found"
            )
        all_stats.append(BattingStat(**x.model_dump()))

    with db.atomic():
        for batch in chunked(all_stats, 15):

@@ -401,5 +413,4 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)
    # Update career stats

    return f"Added {len(all_stats)} batting lines"
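Instead of one `get_or_none` per row (2N queries for N stat lines), the rewritten `post_batstats` collects all distinct IDs, fetches the existing ones in two `WHERE id IN (...)` queries, and validates by set membership. The shape of that check, with plain sets standing in for the peewee query results:

```python
def validate_ids(requested_ids, existing_ids):
    # existing_ids would come from one SELECT id ... WHERE id IN (...) query;
    # membership tests against a set are O(1) per row.
    existing = set(existing_ids)
    missing = [i for i in requested_ids if i not in existing]
    if missing:
        raise ValueError(f"IDs not found: {missing}")

validate_ids([1, 2, 3], {1, 2, 3, 4})  # passes silently
try:
    validate_ids([1, 99], {1, 2, 3})
except ValueError as e:
    print(e)  # → IDs not found: [99]
```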
View File
@@ -41,7 +41,6 @@ async def get_current(season: Optional[int] = None):
     if current is not None:
         r_curr = model_to_dict(current)
-        db.close()
         return r_curr
     else:
         return None
@@ -65,7 +64,7 @@ async def patch_current(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logger.warning(f"patch_current - Bad Token: {token}")
+        logger.warning("patch_current - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -100,10 +99,8 @@ async def patch_current(
     if current.save():
         r_curr = model_to_dict(current)
-        db.close()
         return r_curr
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch current {current_id}"
         )
@@ -113,17 +110,15 @@ async def patch_current(
 @handle_db_errors
 async def post_current(new_current: CurrentModel, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"patch_current - Bad Token: {token}")
+        logger.warning("patch_current - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
-    this_current = Current(**new_current.dict())
+    this_current = Current(**new_current.model_dump())
     if this_current.save():
         r_curr = model_to_dict(this_current)
-        db.close()
         return r_curr
     else:
-        db.close()
         raise HTTPException(
             status_code=500,
             detail=f"Unable to post season {new_current.season} current",
@@ -134,7 +129,7 @@ async def post_current(new_current: CurrentModel, token: str = Depends(oauth2_sc
 @handle_db_errors
 async def delete_current(current_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"patch_current - Bad Token: {token}")
+        logger.warning("patch_current - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     if Current.delete_by_id(current_id) == 1:
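The logging changes above stop interpolating raw bearer tokens into warning messages, since log files are a common place for secrets to leak. If some token context were still wanted for debugging, redaction is one option; `redact` below is a hypothetical helper sketching that idea, not code from this repository (the actual fix simply drops the token from the message):

```python
def redact(token: str, keep: int = 4) -> str:
    """Mask a secret, keeping only the last `keep` characters visible.
    Hypothetical illustration; the diff logs no token material at all."""
    if len(token) <= keep:
        return "*" * len(token)
    return "*" * (len(token) - keep) + token[-keep:]

masked = redact("tok_abcdef123456")  # -> "************3456"
```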


@@ -175,7 +175,7 @@ def delete_custom_command(command_id: int):
 def get_creator_by_discord_id(discord_id: int):
     """Get a creator by Discord ID"""
     creator = CustomCommandCreator.get_or_none(
-        CustomCommandCreator.discord_id == str(discord_id)
+        CustomCommandCreator.discord_id == discord_id
     )
     if creator:
         return model_to_dict(creator)
@@ -296,9 +296,8 @@ async def get_custom_commands(
             if command_dict.get("tags"):
                 try:
                     command_dict["tags"] = json.loads(command_dict["tags"])
-                except:
+                except Exception:
                     command_dict["tags"] = []
             # Get full creator information
             creator_id = command_dict["creator_id"]
             creator_cursor = db.execute_sql(
@@ -364,8 +363,6 @@ async def get_custom_commands(
     except Exception as e:
         logger.error(f"Error getting custom commands: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 # Move this route to after the specific string routes
@@ -378,7 +375,7 @@ async def create_custom_command_endpoint(
 ):
     """Create a new custom command"""
     if not valid_token(token):
-        logger.warning(f"create_custom_command - Bad Token: {token}")
+        logger.warning("create_custom_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -406,7 +403,7 @@ async def create_custom_command_endpoint(
         if command_dict.get("tags"):
             try:
                 command_dict["tags"] = json.loads(command_dict["tags"])
-            except:
+            except Exception:
                 command_dict["tags"] = []
         creator_created_at = command_dict.pop("creator_created_at")
@@ -430,8 +427,6 @@ async def create_custom_command_endpoint(
     except Exception as e:
         logger.error(f"Error creating custom command: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.put("/{command_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -441,7 +436,7 @@ async def update_custom_command_endpoint(
 ):
     """Update an existing custom command"""
     if not valid_token(token):
-        logger.warning(f"update_custom_command - Bad Token: {token}")
+        logger.warning("update_custom_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -467,7 +462,7 @@ async def update_custom_command_endpoint(
         if command_dict.get("tags"):
             try:
                 command_dict["tags"] = json.loads(command_dict["tags"])
-            except:
+            except Exception:
                 command_dict["tags"] = []
         creator_created_at = command_dict.pop("creator_created_at")
@@ -491,8 +486,6 @@ async def update_custom_command_endpoint(
     except Exception as e:
         logger.error(f"Error updating custom command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.patch("/{command_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -509,7 +502,7 @@ async def patch_custom_command(
 ):
     """Partially update a custom command"""
     if not valid_token(token):
-        logger.warning(f"patch_custom_command - Bad Token: {token}")
+        logger.warning("patch_custom_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -552,7 +545,7 @@ async def patch_custom_command(
         if command_dict.get("tags"):
             try:
                 command_dict["tags"] = json.loads(command_dict["tags"])
-            except:
+            except Exception:
                 command_dict["tags"] = []
         creator_created_at = command_dict.pop("creator_created_at")
@@ -576,8 +569,6 @@ async def patch_custom_command(
     except Exception as e:
         logger.error(f"Error patching custom command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.delete("/{command_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -587,7 +578,7 @@ async def delete_custom_command_endpoint(
 ):
     """Delete a custom command"""
     if not valid_token(token):
-        logger.warning(f"delete_custom_command - Bad Token: {token}")
+        logger.warning("delete_custom_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -613,8 +604,6 @@ async def delete_custom_command_endpoint(
     except Exception as e:
         logger.error(f"Error deleting custom command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 # Creator endpoints
@@ -684,8 +673,6 @@ async def get_creators(
     except Exception as e:
         logger.error(f"Error getting creators: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.post("/creators", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -695,7 +682,7 @@ async def create_creator_endpoint(
 ):
     """Create a new command creator"""
     if not valid_token(token):
-        logger.warning(f"create_creator - Bad Token: {token}")
+        logger.warning("create_creator - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -729,8 +716,6 @@ async def create_creator_endpoint(
     except Exception as e:
         logger.error(f"Error creating creator: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.get("/stats")
@@ -781,7 +766,7 @@ async def get_custom_command_stats():
             if command_dict.get("tags"):
                 try:
                     command_dict["tags"] = json.loads(command_dict["tags"])
-                except:
+                except Exception:
                     command_dict["tags"] = []
             command_dict["creator"] = {
                 "discord_id": command_dict.pop("creator_discord_id"),
@@ -855,8 +840,6 @@ async def get_custom_command_stats():
     except Exception as e:
         logger.error(f"Error getting custom command stats: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 # Special endpoints for Discord bot integration
@@ -881,7 +864,7 @@ async def get_custom_command_by_name_endpoint(command_name: str):
         if command_dict.get("tags"):
             try:
                 command_dict["tags"] = json.loads(command_dict["tags"])
-            except:
+            except Exception:
                 command_dict["tags"] = []
         # Add creator info - get full creator record
@@ -922,8 +905,6 @@ async def get_custom_command_by_name_endpoint(command_name: str):
     except Exception as e:
         logger.error(f"Error getting custom command by name '{command_name}': {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.patch("/by_name/{command_name}/execute", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -933,7 +914,7 @@ async def execute_custom_command(
 ):
     """Execute a custom command and update usage statistics"""
     if not valid_token(token):
-        logger.warning(f"execute_custom_command - Bad Token: {token}")
+        logger.warning("execute_custom_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     try:
@@ -966,7 +947,7 @@ async def execute_custom_command(
         if updated_dict.get("tags"):
             try:
                 updated_dict["tags"] = json.loads(updated_dict["tags"])
-            except:
+            except Exception:
                 updated_dict["tags"] = []
         # Build creator object from the fields returned by get_custom_command_by_id
@@ -991,8 +972,6 @@ async def execute_custom_command(
     except Exception as e:
         logger.error(f"Error executing custom command '{command_name}': {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.get("/autocomplete")
@@ -1028,8 +1007,6 @@ async def get_command_names_for_autocomplete(
     except Exception as e:
         logger.error(f"Error getting command names for autocomplete: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 @router.get("/{command_id}")
@@ -1053,7 +1030,7 @@ async def get_custom_command(command_id: int):
         if command_dict.get("tags"):
             try:
                 command_dict["tags"] = json.loads(command_dict["tags"])
-            except:
+            except Exception:
                 command_dict["tags"] = []
         creator_created_at = command_dict.pop("creator_created_at")
@@ -1078,5 +1055,3 @@ async def get_custom_command(command_id: int):
     except Exception as e:
         logger.error(f"Error getting custom command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
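Many hunks in this file replace a bare `except:` with `except Exception:` around the JSON decoding of the `tags` column. The distinction matters: a bare `except` also swallows `KeyboardInterrupt` and `SystemExit`, while `except Exception` lets those propagate. A self-contained sketch of the pattern (`parse_tags` is an illustrative name, not a function from the codebase):

```python
import json

def parse_tags(raw):
    """Decode a JSON-encoded tags column, falling back to [] on bad data.
    `except Exception:` keeps KeyboardInterrupt/SystemExit propagating,
    unlike the bare `except:` this diff removes."""
    if not raw:
        return []
    try:
        return json.loads(raw)
    except Exception:
        return []

parse_tags('["fun", "baseball"]')  # -> ["fun", "baseball"]
parse_tags("not json")             # -> []
```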


@@ -19,6 +19,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -73,7 +75,7 @@ async def get_decisions(
     irunners_scored: list = Query(default=None),
     game_id: list = Query(default=None),
     player_id: list = Query(default=None),
-    limit: Optional[int] = None,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     short_output: Optional[bool] = False,
 ):
     all_dec = Decision.select().order_by(
@@ -135,16 +137,12 @@ async def get_decisions(
     if irunners_scored is not None:
         all_dec = all_dec.where(Decision.irunners_scored << irunners_scored)
-    if limit is not None:
-        if limit < 1:
-            limit = 1
-        all_dec = all_dec.limit(limit)
+    all_dec = all_dec.limit(limit)
     return_dec = {
         "count": all_dec.count(),
         "decisions": [model_to_dict(x, recurse=not short_output) for x in all_dec],
     }
-    db.close()
     return return_dec
@@ -164,12 +162,11 @@ async def patch_decision(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logger.warning(f"patch_decision - Bad Token: {token}")
+        logger.warning("patch_decision - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_dec = Decision.get_or_none(Decision.id == decision_id)
     if this_dec is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Decision ID {decision_id} not found"
         )
@@ -195,10 +192,8 @@ async def patch_decision(
     if this_dec.save() == 1:
         d_result = model_to_dict(this_dec)
-        db.close()
         return d_result
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch decision {decision_id}"
         )
@@ -208,7 +203,7 @@ async def patch_decision(
 @handle_db_errors
 async def post_decisions(dec_list: DecisionList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_decisions - Bad Token: {token}")
+        logger.warning("post_decisions - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_dec = []
@@ -222,12 +217,11 @@ async def post_decisions(dec_list: DecisionList, token: str = Depends(oauth2_sch
                 status_code=404, detail=f"Player ID {x.pitcher_id} not found"
             )
-        new_dec.append(x.dict())
+        new_dec.append(x.model_dump())
     with db.atomic():
         for batch in chunked(new_dec, 10):
             Decision.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_dec)} decisions"
@@ -236,18 +230,16 @@ async def post_decisions(dec_list: DecisionList, token: str = Depends(oauth2_sch
 @handle_db_errors
 async def delete_decision(decision_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_decision - Bad Token: {token}")
+        logger.warning("delete_decision - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_dec = Decision.get_or_none(Decision.id == decision_id)
     if this_dec is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Decision ID {decision_id} not found"
         )
     count = this_dec.delete_instance()
-    db.close()
     if count == 1:
         return f"Decision {decision_id} has been deleted"
@@ -261,16 +253,14 @@ async def delete_decision(decision_id: int, token: str = Depends(oauth2_scheme))
 @handle_db_errors
 async def delete_decisions_game(game_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_decisions_game - Bad Token: {token}")
+        logger.warning("delete_decisions_game - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_game = StratGame.get_or_none(StratGame.id == game_id)
     if not this_game:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Game ID {game_id} not found")
     count = Decision.delete().where(Decision.game == this_game).execute()
-    db.close()
     if count > 0:
         return f"Deleted {count} decisions matching Game ID {game_id}"
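The `limit` parameter above moves from manual clamping (`if limit < 1: limit = 1`) to FastAPI's declarative `Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT)`, which rejects out-of-range values with a 422 before the handler runs. A minimal framework-free sketch of the equivalent semantics, with assumed bounds (the real constants live in `..dependencies` and their values are not shown in this diff):

```python
DEFAULT_LIMIT = 100   # assumed value for illustration
MAX_LIMIT = 1000      # assumed value for illustration

def validate_limit(limit=None):
    """Mimic Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT):
    a missing value gets the default; an out-of-range value is
    rejected outright rather than silently clamped."""
    if limit is None:
        return DEFAULT_LIMIT
    if not 1 <= limit <= MAX_LIMIT:
        raise ValueError(f"limit must be between 1 and {MAX_LIMIT}")
    return limit
```

Rejecting rather than clamping surfaces caller bugs instead of hiding them, and the bound on `le=MAX_LIMIT` caps result-set size server-side.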


@@ -9,6 +9,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -32,6 +34,8 @@ async def get_divisions(
     div_abbrev: Optional[str] = None,
     lg_name: Optional[str] = None,
     lg_abbrev: Optional[str] = None,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     all_divisions = Division.select().where(Division.season == season)
@@ -44,11 +48,13 @@ async def get_divisions(
     if lg_abbrev is not None:
         all_divisions = all_divisions.where(Division.league_abbrev == lg_abbrev)
+    total_count = all_divisions.count()
+    all_divisions = all_divisions.offset(offset).limit(limit)
     return_div = {
-        "count": all_divisions.count(),
+        "count": total_count,
         "divisions": [model_to_dict(x) for x in all_divisions],
     }
-    db.close()
     return return_div
@@ -57,13 +63,11 @@ async def get_divisions(
 async def get_one_division(division_id: int):
     this_div = Division.get_or_none(Division.id == division_id)
     if this_div is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Division ID {division_id} not found"
         )
     r_div = model_to_dict(this_div)
-    db.close()
     return r_div
@@ -78,12 +82,11 @@ async def patch_division(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logger.warning(f"patch_division - Bad Token: {token}")
+        logger.warning("patch_division - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_div = Division.get_or_none(Division.id == division_id)
     if this_div is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Division ID {division_id} not found"
         )
@@ -99,10 +102,8 @@ async def patch_division(
     if this_div.save() == 1:
         r_division = model_to_dict(this_div)
-        db.close()
         return r_division
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch division {division_id}"
         )
@@ -114,17 +115,15 @@ async def post_division(
     new_division: DivisionModel, token: str = Depends(oauth2_scheme)
 ):
     if not valid_token(token):
-        logger.warning(f"post_division - Bad Token: {token}")
+        logger.warning("post_division - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
-    this_division = Division(**new_division.dict())
+    this_division = Division(**new_division.model_dump())
     if this_division.save() == 1:
         r_division = model_to_dict(this_division)
-        db.close()
         return r_division
     else:
-        db.close()
         raise HTTPException(status_code=500, detail=f"Unable to post division")
@@ -132,18 +131,16 @@ async def post_division(
 @handle_db_errors
 async def delete_division(division_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_division - Bad Token: {token}")
+        logger.warning("delete_division - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_div = Division.get_or_none(Division.id == division_id)
     if this_div is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Division ID {division_id} not found"
         )
     count = this_div.delete_instance()
-    db.close()
     if count == 1:
         return f"Division {division_id} has been deleted"
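Note the ordering in `get_divisions`: the query is counted into `total_count` before `.offset(offset).limit(limit)` is applied, so `"count"` reports all matching rows rather than just the page returned. Counting after pagination is the same bug fixed for `get_totalstats` in #101. A list-based sketch of the pattern (`paginate` is an illustrative helper, not part of the codebase):

```python
def paginate(rows, limit, offset=0):
    """Count first, then slice: 'count' reflects the full result set,
    not the page that comes back."""
    total_count = len(rows)             # analogous to query.count() before offset/limit
    page = rows[offset:offset + limit]  # analogous to .offset(offset).limit(limit)
    return {"count": total_count, "items": page}

paginate(list(range(25)), limit=10, offset=20)
# -> {"count": 25, "items": [20, 21, 22, 23, 24]}
```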


@@ -32,7 +32,6 @@ async def get_draftdata():
     if draft_data is not None:
         r_data = model_to_dict(draft_data)
-        db.close()
         return r_data
     raise HTTPException(status_code=404, detail=f'No draft data found')
@@ -45,12 +44,11 @@ async def patch_draftdata(
     pick_deadline: Optional[datetime.datetime] = None, result_channel: Optional[int] = None,
     ping_channel: Optional[int] = None, pick_minutes: Optional[int] = None, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f'patch_draftdata - Bad Token: {token}')
+        logger.warning('patch_draftdata - Bad Token')
         raise HTTPException(status_code=401, detail='Unauthorized')
     draft_data = DraftData.get_or_none(DraftData.id == data_id)
     if draft_data is None:
-        db.close()
         raise HTTPException(status_code=404, detail=f'No draft data found')
     if currentpick is not None:
@@ -68,7 +66,6 @@ async def patch_draftdata(
     saved = draft_data.save()
     r_data = model_to_dict(draft_data)
-    db.close()
     if saved == 1:
         return r_data


@@ -9,6 +9,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -34,9 +36,11 @@ async def get_draftlist(
     season: Optional[int],
     team_id: list = Query(default=None),
     token: str = Depends(oauth2_scheme),
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     if not valid_token(token):
-        logger.warning(f"get_draftlist - Bad Token: {token}")
+        logger.warning("get_draftlist - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     all_list = DraftList.select()
@@ -46,9 +50,11 @@ async def get_draftlist(
     if team_id is not None:
         all_list = all_list.where(DraftList.team_id << team_id)
-    r_list = {"count": all_list.count(), "picks": [model_to_dict(x) for x in all_list]}
+    total_count = all_list.count()
+    all_list = all_list.offset(offset).limit(limit)
+    r_list = {"count": total_count, "picks": [model_to_dict(x) for x in all_list]}
-    db.close()
     return r_list
@@ -56,7 +62,7 @@ async def get_draftlist(
 @handle_db_errors
 async def get_team_draftlist(team_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_draftlist - Bad Token: {token}")
+        logger.warning("post_draftlist - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_team = Team.get_or_none(Team.id == team_id)
@@ -69,7 +75,6 @@ async def get_team_draftlist(team_id: int, token: str = Depends(oauth2_scheme)):
         "picks": [model_to_dict(x) for x in this_list],
     }
-    db.close()
     return r_list
@@ -79,7 +84,7 @@ async def post_draftlist(
     draft_list: DraftListList, token: str = Depends(oauth2_scheme)
 ):
     if not valid_token(token):
-        logger.warning(f"post_draftlist - Bad Token: {token}")
+        logger.warning("post_draftlist - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_list = []
@@ -93,13 +98,12 @@ async def post_draftlist(
     DraftList.delete().where(DraftList.team == this_team).execute()
     for x in draft_list.draft_list:
-        new_list.append(x.dict())
+        new_list.append(x.model_dump())
     with db.atomic():
         for batch in chunked(new_list, 15):
             DraftList.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_list)} list values"
@@ -107,9 +111,8 @@ async def post_draftlist(
 @handle_db_errors
 async def delete_draftlist(team_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_draftlist - Bad Token: {token}")
+        logger.warning("delete_draftlist - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     count = DraftList.delete().where(DraftList.team_id == team_id).execute()
-    db.close()
     return f"Deleted {count} list values"


@@ -9,6 +9,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 
 logger = logging.getLogger("discord_app")
@@ -50,7 +52,7 @@ async def get_picks(
     overall_end: Optional[int] = None,
     short_output: Optional[bool] = False,
     sort: Optional[str] = None,
-    limit: Optional[int] = None,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     player_id: list = Query(default=None),
     player_taken: Optional[bool] = None,
 ):
@@ -110,7 +112,6 @@ async def get_picks(
         all_picks = all_picks.where(DraftPick.overall <= overall_end)
     if player_taken is not None:
         all_picks = all_picks.where(DraftPick.player.is_null(not player_taken))
-    if limit is not None:
-        all_picks = all_picks.limit(limit)
+    all_picks = all_picks.limit(limit)
     if sort is not None:
@@ -123,7 +124,6 @@ async def get_picks(
     for line in all_picks:
         return_picks["picks"].append(model_to_dict(line, recurse=not short_output))
-    db.close()
     return return_picks
@@ -135,7 +135,6 @@ async def get_one_pick(pick_id: int, short_output: Optional[bool] = False):
         r_pick = model_to_dict(this_pick, recurse=not short_output)
     else:
         raise HTTPException(status_code=404, detail=f"Pick ID {pick_id} not found")
-    db.close()
     return r_pick
@@ -145,15 +144,14 @@ async def patch_pick(
     pick_id: int, new_pick: DraftPickModel, token: str = Depends(oauth2_scheme)
 ):
     if not valid_token(token):
-        logger.warning(f"patch_pick - Bad Token: {token}")
+        logger.warning("patch_pick - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     if DraftPick.get_or_none(DraftPick.id == pick_id) is None:
         raise HTTPException(status_code=404, detail=f"Pick ID {pick_id} not found")
 
-    DraftPick.update(**new_pick.dict()).where(DraftPick.id == pick_id).execute()
+    DraftPick.update(**new_pick.model_dump()).where(DraftPick.id == pick_id).execute()
     r_pick = model_to_dict(DraftPick.get_by_id(pick_id))
-    db.close()
     return r_pick
@@ -161,7 +159,7 @@ async def patch_pick(
 @handle_db_errors
 async def post_picks(p_list: DraftPickList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_picks - Bad Token: {token}")
+        logger.warning("post_picks - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     new_picks = []
@@ -170,18 +168,16 @@ async def post_picks(p_list: DraftPickList, token: str = Depends(oauth2_scheme))
             DraftPick.season == pick.season, DraftPick.overall == pick.overall
         )
         if dupe:
-            db.close()
             raise HTTPException(
                 status_code=500,
                 detail=f"Pick # {pick.overall} already exists for season {pick.season}",
             )
-        new_picks.append(pick.dict())
+        new_picks.append(pick.model_dump())
 
     with db.atomic():
         for batch in chunked(new_picks, 15):
             DraftPick.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_picks)} picks"
@@ -190,7 +186,7 @@ async def post_picks(p_list: DraftPickList, token: str = Depends(oauth2_scheme))
 @handle_db_errors
 async def delete_pick(pick_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_pick - Bad Token: {token}")
+        logger.warning("delete_pick - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     this_pick = DraftPick.get_or_none(DraftPick.id == pick_id)
@@ -198,7 +194,6 @@ async def delete_pick(pick_id: int, token: str = Depends(oauth2_scheme)):
         raise HTTPException(status_code=404, detail=f"Pick ID {pick_id} not found")
 
     count = this_pick.delete_instance()
-    db.close()
     if count == 1:
         return f"Draft pick {pick_id} has been deleted"
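The hunks above replace an unbounded `limit: Optional[int] = None` with a validated query parameter. A minimal stand-in for what `Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT)` enforces, written without FastAPI; the constant values here are assumptions, since `..dependencies` is not shown in this diff:

```python
# Assumed values -- the real constants are defined in ..dependencies.
DEFAULT_LIMIT = 100
MAX_LIMIT = 1000


def validate_limit(limit=None):
    """Mirror the Query() parameter: default when absent, reject out-of-range."""
    if limit is None:
        return DEFAULT_LIMIT
    if not 1 <= limit <= MAX_LIMIT:
        # FastAPI would return a 422 here; a plain function raises instead.
        raise ValueError(f"limit must be between 1 and {MAX_LIMIT}")
    return limit
```

Because the parameter always has a value, the old `if limit is not None:` guard before `.limit(limit)` becomes dead code, which is why the hunk removes it.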

View File

@@ -3,40 +3,58 @@ from typing import List, Optional, Literal
 import logging
 import pydantic
-from ..db_engine import db, BattingStat, Team, Player, Current, model_to_dict, chunked, fn, per_season_weeks
-from ..dependencies import oauth2_scheme, valid_token, handle_db_errors
+from ..db_engine import (
+    db,
+    BattingStat,
+    Team,
+    Player,
+    Current,
+    model_to_dict,
+    chunked,
+    fn,
+    per_season_weeks,
+)
+from ..dependencies import (
+    oauth2_scheme,
+    valid_token,
+    handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
+)
 
-logger = logging.getLogger('discord_app')
+logger = logging.getLogger("discord_app")
 
-router = APIRouter(
-    prefix='/api/v3/fieldingstats',
-    tags=['fieldingstats']
-)
+router = APIRouter(prefix="/api/v3/fieldingstats", tags=["fieldingstats"])
 
-@router.get('')
+@router.get("")
 @handle_db_errors
 async def get_fieldingstats(
-        season: int, s_type: Optional[str] = 'regular', team_abbrev: list = Query(default=None),
-        player_name: list = Query(default=None), player_id: list = Query(default=None),
-        week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None),
-        position: list = Query(default=None), limit: Optional[int] = None, sort: Optional[str] = None,
-        short_output: Optional[bool] = True):
-    if 'post' in s_type.lower():
+    season: int,
+    s_type: Optional[str] = "regular",
+    team_abbrev: list = Query(default=None),
+    player_name: list = Query(default=None),
+    player_id: list = Query(default=None),
+    week_start: Optional[int] = None,
+    week_end: Optional[int] = None,
+    game_num: list = Query(default=None),
+    position: list = Query(default=None),
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    sort: Optional[str] = None,
+    short_output: Optional[bool] = True,
+):
+    if "post" in s_type.lower():
         all_stats = BattingStat.post_season(season)
         if all_stats.count() == 0:
-            db.close()
-            return {'count': 0, 'stats': []}
-    elif s_type.lower() in ['combined', 'total', 'all']:
+            return {"count": 0, "stats": []}
+    elif s_type.lower() in ["combined", "total", "all"]:
         all_stats = BattingStat.combined_season(season)
         if all_stats.count() == 0:
-            db.close()
-            return {'count': 0, 'stats': []}
+            return {"count": 0, "stats": []}
     else:
         all_stats = BattingStat.regular_season(season)
         if all_stats.count() == 0:
-            db.close()
-            return {'count': 0, 'stats': []}
+            return {"count": 0, "stats": []}
 
     all_stats = all_stats.where(
         (BattingStat.xch > 0) | (BattingStat.pb > 0) | (BattingStat.sbc > 0)
@@ -51,7 +69,9 @@ async def get_fieldingstats(
         if player_id:
             all_stats = all_stats.where(BattingStat.player_id << player_id)
         else:
-            p_query = Player.select_season(season).where(fn.Lower(Player.name) << [x.lower() for x in player_name])
+            p_query = Player.select_season(season).where(
+                fn.Lower(Player.name) << [x.lower() for x in player_name]
+            )
             all_stats = all_stats.where(BattingStat.player << p_query)
     if game_num:
         all_stats = all_stats.where(BattingStat.game == game_num)
@@ -63,73 +83,91 @@ async def get_fieldingstats(
         if week_end is not None:
             end = min(week_end, end)
         if start > end:
-            db.close()
             raise HTTPException(
                 status_code=404,
-                detail=f'Start week {start} is after end week {end} - cannot pull stats'
+                detail=f"Start week {start} is after end week {end} - cannot pull stats",
             )
-        all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end))
+        all_stats = all_stats.where(
+            (BattingStat.week >= start) & (BattingStat.week <= end)
+        )
 
-    if limit:
-        all_stats = all_stats.limit(limit)
+    total_count = all_stats.count()
+    all_stats = all_stats.limit(limit)
     if sort:
-        if sort == 'newest':
+        if sort == "newest":
             all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game)
 
     return_stats = {
-        'count': all_stats.count(),
-        'stats': [{
-            'player': x.player_id if short_output else model_to_dict(x.player, recurse=False),
-            'team': x.team_id if short_output else model_to_dict(x.team, recurse=False),
-            'pos': x.pos,
-            'xch': x.xch,
-            'xhit': x.xhit,
-            'error': x.error,
-            'pb': x.pb,
-            'sbc': x.sbc,
-            'csc': x.csc,
-            'week': x.week,
-            'game': x.game,
-            'season': x.season
-        } for x in all_stats]
+        "count": total_count,
+        "stats": [
+            {
+                "player": x.player_id
+                if short_output
+                else model_to_dict(x.player, recurse=False),
+                "team": x.team_id
+                if short_output
+                else model_to_dict(x.team, recurse=False),
+                "pos": x.pos,
+                "xch": x.xch,
+                "xhit": x.xhit,
+                "error": x.error,
+                "pb": x.pb,
+                "sbc": x.sbc,
+                "csc": x.csc,
+                "week": x.week,
+                "game": x.game,
+                "season": x.season,
+            }
+            for x in all_stats
+        ],
     }
-    db.close()
     return return_stats
 
 
-@router.get('/totals')
+@router.get("/totals")
 @handle_db_errors
 async def get_totalstats(
-        season: int, s_type: Literal['regular', 'post', 'total', None] = None, team_abbrev: list = Query(default=None),
-        team_id: list = Query(default=None), player_name: list = Query(default=None),
-        week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None),
-        position: list = Query(default=None), sort: Optional[str] = None, player_id: list = Query(default=None),
-        group_by: Literal['team', 'player', 'playerteam'] = 'player', short_output: Optional[bool] = False,
-        min_ch: Optional[int] = 1, week: list = Query(default=None)):
+    season: int,
+    s_type: Literal["regular", "post", "total", None] = None,
+    team_abbrev: list = Query(default=None),
+    team_id: list = Query(default=None),
+    player_name: list = Query(default=None),
+    week_start: Optional[int] = None,
+    week_end: Optional[int] = None,
+    game_num: list = Query(default=None),
+    position: list = Query(default=None),
+    sort: Optional[str] = None,
+    player_id: list = Query(default=None),
+    group_by: Literal["team", "player", "playerteam"] = "player",
+    short_output: Optional[bool] = False,
+    min_ch: Optional[int] = 1,
+    week: list = Query(default=None),
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
+):
     # Build SELECT fields conditionally based on group_by to match GROUP BY exactly
     select_fields = []
-    if group_by == 'player':
+    if group_by == "player":
         select_fields = [BattingStat.player, BattingStat.pos]
-    elif group_by == 'team':
+    elif group_by == "team":
         select_fields = [BattingStat.team, BattingStat.pos]
-    elif group_by == 'playerteam':
+    elif group_by == "playerteam":
         select_fields = [BattingStat.player, BattingStat.team, BattingStat.pos]
     else:
         # Default case
         select_fields = [BattingStat.player, BattingStat.pos]
 
     all_stats = (
-        BattingStat
-        .select(*select_fields,
-                fn.SUM(BattingStat.xch).alias('sum_xch'),
-                fn.SUM(BattingStat.xhit).alias('sum_xhit'), fn.SUM(BattingStat.error).alias('sum_error'),
-                fn.SUM(BattingStat.pb).alias('sum_pb'), fn.SUM(BattingStat.sbc).alias('sum_sbc'),
-                fn.SUM(BattingStat.csc).alias('sum_csc'))
+        BattingStat.select(
+            *select_fields,
+            fn.SUM(BattingStat.xch).alias("sum_xch"),
+            fn.SUM(BattingStat.xhit).alias("sum_xhit"),
+            fn.SUM(BattingStat.error).alias("sum_error"),
+            fn.SUM(BattingStat.pb).alias("sum_pb"),
+            fn.SUM(BattingStat.sbc).alias("sum_sbc"),
+            fn.SUM(BattingStat.csc).alias("sum_csc"),
+        )
        .where(BattingStat.season == season)
        .having(fn.SUM(BattingStat.xch) >= min_ch)
     )
@@ -141,16 +179,20 @@ async def get_totalstats(
     elif week_start is not None or week_end is not None:
         if week_start is None or week_end is None:
             raise HTTPException(
-                status_code=400, detail='Both week_start and week_end must be included if either is used.'
+                status_code=400,
+                detail="Both week_start and week_end must be included if either is used.",
             )
-        weeks['start'] = week_start
-        if week_end < weeks['start']:
-            raise HTTPException(status_code=400, detail='week_end must be greater than or equal to week_start')
+        weeks["start"] = week_start
+        if week_end < weeks["start"]:
+            raise HTTPException(
+                status_code=400,
+                detail="week_end must be greater than or equal to week_start",
+            )
         else:
-            weeks['end'] = week_end
+            weeks["end"] = week_end
         all_stats = all_stats.where(
-            (BattingStat.week >= weeks['start']) & (BattingStat.week <= weeks['end'])
+            (BattingStat.week >= weeks["start"]) & (BattingStat.week <= weeks["end"])
         )
     elif week is not None:
@@ -161,14 +203,20 @@ async def get_totalstats(
     if position is not None:
         p_list = [x.upper() for x in position]
         all_players = Player.select().where(
-            (Player.pos_1 << p_list) | (Player.pos_2 << p_list) | (Player.pos_3 << p_list) | (Player.pos_4 << p_list) |
-            (Player.pos_5 << p_list) | (Player.pos_6 << p_list) | (Player.pos_7 << p_list) | (Player.pos_8 << p_list)
+            (Player.pos_1 << p_list)
+            | (Player.pos_2 << p_list)
+            | (Player.pos_3 << p_list)
+            | (Player.pos_4 << p_list)
+            | (Player.pos_5 << p_list)
+            | (Player.pos_6 << p_list)
+            | (Player.pos_7 << p_list)
+            | (Player.pos_8 << p_list)
         )
         all_stats = all_stats.where(BattingStat.player << all_players)
     if sort is not None:
-        if sort == 'player':
+        if sort == "player":
             all_stats = all_stats.order_by(BattingStat.player)
-        elif sort == 'team':
+        elif sort == "team":
             all_stats = all_stats.order_by(BattingStat.team)
     if group_by is not None:
         # Use the same fields for GROUP BY as we used for SELECT
@@ -177,47 +225,55 @@ async def get_totalstats(
         all_teams = Team.select().where(Team.id << team_id)
         all_stats = all_stats.where(BattingStat.team << all_teams)
     elif team_abbrev is not None:
-        all_teams = Team.select().where(fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev])
+        all_teams = Team.select().where(
+            fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev]
+        )
         all_stats = all_stats.where(BattingStat.team << all_teams)
     if player_name is not None:
-        all_players = Player.select().where(fn.Lower(Player.name) << [x.lower() for x in player_name])
+        all_players = Player.select().where(
+            fn.Lower(Player.name) << [x.lower() for x in player_name]
+        )
         all_stats = all_stats.where(BattingStat.player << all_players)
     elif player_id is not None:
         all_players = Player.select().where(Player.id << player_id)
         all_stats = all_stats.where(BattingStat.player << all_players)
 
-    return_stats = {
-        'count': 0,
-        'stats': []
-    }
+    total_count = all_stats.count()
+    all_stats = all_stats.offset(offset).limit(limit)
+
+    return_stats = {"count": total_count, "stats": []}
     for x in all_stats:
         if x.sum_xch + x.sum_sbc <= 0:
             continue
         # Handle player field based on grouping with safe access
-        this_player = 'TOT'
-        if 'player' in group_by and hasattr(x, 'player'):
-            this_player = x.player_id if short_output else model_to_dict(x.player, recurse=False)
+        this_player = "TOT"
+        if "player" in group_by and hasattr(x, "player"):
+            this_player = (
+                x.player_id if short_output else model_to_dict(x.player, recurse=False)
+            )
         # Handle team field based on grouping with safe access
-        this_team = 'TOT'
-        if 'team' in group_by and hasattr(x, 'team'):
-            this_team = x.team_id if short_output else model_to_dict(x.team, recurse=False)
+        this_team = "TOT"
+        if "team" in group_by and hasattr(x, "team"):
+            this_team = (
+                x.team_id if short_output else model_to_dict(x.team, recurse=False)
+            )
-        return_stats['stats'].append({
-            'player': this_player,
-            'team': this_team,
-            'pos': x.pos,
-            'xch': x.sum_xch,
-            'xhit': x.sum_xhit,
-            'error': x.sum_error,
-            'pb': x.sum_pb,
-            'sbc': x.sum_sbc,
-            'csc': x.sum_csc
-        })
+        return_stats["stats"].append(
+            {
+                "player": this_player,
+                "team": this_team,
+                "pos": x.pos,
+                "xch": x.sum_xch,
+                "xhit": x.sum_xhit,
+                "error": x.sum_error,
+                "pb": x.sum_pb,
+                "sbc": x.sum_sbc,
+                "csc": x.sum_csc,
+            }
        )
-    return_stats['count'] = len(return_stats['stats'])
-    db.close()
     return return_stats
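The `total_count = all_stats.count()` lines above are the heart of fix #101: count the filtered query *before* applying `offset`/`limit`, so `count` reports the size of the whole result set instead of the page length. A plain-list sketch of the same idea, with no Peewee involved (`rows` stands in for the filtered query):

```python
def paginate(rows, limit, offset=0):
    # Count BEFORE slicing -- mirrors calling all_stats.count() before
    # all_stats.offset(offset).limit(limit) in the diff above.
    total_count = len(rows)
    page = rows[offset:offset + limit]
    return {"count": total_count, "stats": page}
```

The old code instead set `return_stats['count'] = len(return_stats['stats'])` after limiting, so the reported count could never exceed the page size.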

View File

@@ -138,8 +138,6 @@ async def get_help_commands(
     except Exception as e:
         logger.error(f"Error getting help commands: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.post("/", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -149,7 +147,7 @@ async def create_help_command_endpoint(
 ):
     """Create a new help command"""
     if not valid_token(token):
-        logger.warning(f"create_help_command - Bad Token: {token}")
+        logger.warning("create_help_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     try:
@@ -187,8 +185,6 @@ async def create_help_command_endpoint(
     except Exception as e:
         logger.error(f"Error creating help command: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.put("/{command_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -198,7 +194,7 @@ async def update_help_command_endpoint(
 ):
     """Update an existing help command"""
     if not valid_token(token):
-        logger.warning(f"update_help_command - Bad Token: {token}")
+        logger.warning("update_help_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     try:
@@ -238,8 +234,6 @@ async def update_help_command_endpoint(
     except Exception as e:
         logger.error(f"Error updating help command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.patch("/{command_id}/restore", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -249,7 +243,7 @@ async def restore_help_command_endpoint(
 ):
     """Restore a soft-deleted help command"""
     if not valid_token(token):
-        logger.warning(f"restore_help_command - Bad Token: {token}")
+        logger.warning("restore_help_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     try:
@@ -277,8 +271,6 @@ async def restore_help_command_endpoint(
     except Exception as e:
         logger.error(f"Error restoring help command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.delete("/{command_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -288,7 +280,7 @@ async def delete_help_command_endpoint(
 ):
     """Soft delete a help command"""
     if not valid_token(token):
-        logger.warning(f"delete_help_command - Bad Token: {token}")
+        logger.warning("delete_help_command - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     try:
@@ -309,8 +301,6 @@ async def delete_help_command_endpoint(
     except Exception as e:
         logger.error(f"Error deleting help command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.get("/stats")
@@ -368,8 +358,6 @@ async def get_help_command_stats():
     except Exception as e:
         logger.error(f"Error getting help command stats: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 # Special endpoints for Discord bot integration
@@ -402,8 +390,6 @@ async def get_help_command_by_name_endpoint(
     except Exception as e:
         logger.error(f"Error getting help command by name '{command_name}': {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.patch("/by_name/{command_name}/view", include_in_schema=PRIVATE_IN_SCHEMA)
@@ -411,7 +397,7 @@ async def get_help_command_by_name_endpoint(
 async def increment_view_count(command_name: str, token: str = Depends(oauth2_scheme)):
     """Increment view count for a help command"""
     if not valid_token(token):
-        logger.warning(f"increment_view_count - Bad Token: {token}")
+        logger.warning("increment_view_count - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     try:
@@ -439,8 +425,6 @@ async def increment_view_count(command_name: str, token: str = Depends(oauth2_sc
     except Exception as e:
         logger.error(f"Error incrementing view count for '{command_name}': {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.get("/autocomplete")
@@ -470,8 +454,6 @@ async def get_help_names_for_autocomplete(
     except Exception as e:
         logger.error(f"Error getting help names for autocomplete: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()
 
 
 @router.get("/{command_id}")
@@ -499,5 +481,3 @@ async def get_help_command(command_id: int):
     except Exception as e:
         logger.error(f"Error getting help command {command_id}: {e}")
         raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        db.close()

View File

@@ -9,6 +9,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 
 logger = logging.getLogger("discord_app")
@@ -38,6 +40,8 @@ async def get_injuries(
     is_active: bool = None,
     short_output: bool = False,
     sort: Optional[str] = "start-asc",
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     all_injuries = Injury.select()
@@ -64,11 +68,13 @@ async def get_injuries(
     elif sort == "start-desc":
         all_injuries = all_injuries.order_by(-Injury.start_week, -Injury.start_game)
 
+    total_count = all_injuries.count()
+    all_injuries = all_injuries.offset(offset).limit(limit)
+
     return_injuries = {
-        "count": all_injuries.count(),
+        "count": total_count,
         "injuries": [model_to_dict(x, recurse=not short_output) for x in all_injuries],
     }
-    db.close()
     return return_injuries
@@ -80,12 +86,11 @@ async def patch_injury(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logger.warning(f"patch_injury - Bad Token: {token}")
+        logger.warning("patch_injury - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     this_injury = Injury.get_or_none(Injury.id == injury_id)
     if this_injury is None:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Injury ID {injury_id} not found")
 
     if is_active is not None:
@@ -93,10 +98,8 @@ async def patch_injury(
     if this_injury.save() == 1:
         r_injury = model_to_dict(this_injury)
-        db.close()
         return r_injury
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch injury {injury_id}"
         )
@@ -106,17 +109,15 @@ async def patch_injury(
 @handle_db_errors
 async def post_injury(new_injury: InjuryModel, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_injury - Bad Token: {token}")
+        logger.warning("post_injury - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
-    this_injury = Injury(**new_injury.dict())
+    this_injury = Injury(**new_injury.model_dump())
     if this_injury.save():
         r_injury = model_to_dict(this_injury)
-        db.close()
         return r_injury
     else:
-        db.close()
         raise HTTPException(status_code=500, detail=f"Unable to post injury")
@@ -124,16 +125,14 @@ async def post_injury(new_injury: InjuryModel, token: str = Depends(oauth2_schem
 @handle_db_errors
 async def delete_injury(injury_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_injury - Bad Token: {token}")
+        logger.warning("delete_injury - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
 
     this_injury = Injury.get_or_none(Injury.id == injury_id)
     if this_injury is None:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Injury ID {injury_id} not found")
 
     count = this_injury.delete_instance()
-    db.close()
     if count == 1:
         return f"Injury {injury_id} has been deleted"

View File

@ -9,6 +9,8 @@ from ..dependencies import (
valid_token, valid_token,
PRIVATE_IN_SCHEMA, PRIVATE_IN_SCHEMA,
handle_db_errors, handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
) )
logger = logging.getLogger("discord_app") logger = logging.getLogger("discord_app")
@ -34,6 +36,8 @@ async def get_keepers(
team_id: list = Query(default=None), team_id: list = Query(default=None),
player_id: list = Query(default=None), player_id: list = Query(default=None),
short_output: bool = False, short_output: bool = False,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
): ):
all_keepers = Keeper.select() all_keepers = Keeper.select()
@ -44,11 +48,13 @@ async def get_keepers(
if player_id is not None: if player_id is not None:
all_keepers = all_keepers.where(Keeper.player_id << player_id) all_keepers = all_keepers.where(Keeper.player_id << player_id)
total_count = all_keepers.count()
all_keepers = all_keepers.offset(offset).limit(limit)
return_keepers = { return_keepers = {
"count": all_keepers.count(), "count": total_count,
"keepers": [model_to_dict(x, recurse=not short_output) for x in all_keepers], "keepers": [model_to_dict(x, recurse=not short_output) for x in all_keepers],
} }
db.close()
return return_keepers return return_keepers
@ -62,7 +68,7 @@ async def patch_keeper(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
): ):
if not valid_token(token): if not valid_token(token):
-        logger.warning(f"patch_keeper - Bad Token: {token}")
+        logger.warning("patch_keeper - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_keeper = Keeper.get_or_none(Keeper.id == keeper_id)
@@ -78,10 +84,8 @@ async def patch_keeper(
     if this_keeper.save():
         r_keeper = model_to_dict(this_keeper)
-        db.close()
         return r_keeper
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch keeper {keeper_id}"
         )
@@ -91,17 +95,16 @@ async def patch_keeper(
 @handle_db_errors
 async def post_keepers(k_list: KeeperList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_keepers - Bad Token: {token}")
+        logger.warning("post_keepers - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_keepers = []
     for keeper in k_list.keepers:
-        new_keepers.append(keeper.dict())
+        new_keepers.append(keeper.model_dump())
     with db.atomic():
         for batch in chunked(new_keepers, 14):
             Keeper.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_keepers)} keepers"
@@ -110,7 +113,7 @@ async def post_keepers(k_list: KeeperList, token: str = Depends(oauth2_scheme)):
 @handle_db_errors
 async def delete_keeper(keeper_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_keeper - Bad Token: {token}")
+        logger.warning("delete_keeper - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_keeper = Keeper.get_or_none(Keeper.id == keeper_id)
@@ -118,7 +121,6 @@ async def delete_keeper(keeper_id: int, token: str = Depends(oauth2_scheme)):
         raise HTTPException(status_code=404, detail=f"Keeper ID {keeper_id} not found")
     count = this_keeper.delete_instance()
-    db.close()
     if count == 1:
         return f"Keeper ID {keeper_id} has been deleted"
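The `with db.atomic()` / `chunked` bulk-insert in `post_keepers` above batches rows so each `insert_many` stays within SQL parameter limits. A runnable sketch of the same batching, with a pure-Python `chunked` mirroring peewee's helper and a list standing in for `Keeper.insert_many(...).execute()` (both stand-ins are illustrative, not the project's code):

```python
from itertools import islice

def chunked(iterable, n):
    """Yield successive lists of up to n items, like peewee's chunked()."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, n))
        if not batch:
            return
        yield batch

# Stand-in for Keeper.insert_many(batch).on_conflict_ignore().execute():
# just record each batch that would have been inserted.
inserted_batches = []

def insert_many(batch):
    inserted_batches.append(batch)

new_keepers = [{"id": i} for i in range(30)]
for batch in chunked(new_keepers, 14):  # same batch size as the route
    insert_many(batch)

# 30 rows in batches of 14 -> batch sizes 14, 14, 2
print([len(b) for b in inserted_batches])
```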


@@ -9,6 +9,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -29,6 +31,8 @@ async def get_managers(
     name: list = Query(default=None),
     active: Optional[bool] = None,
     short_output: Optional[bool] = False,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     if active is not None:
         current = Current.latest()
@@ -61,7 +65,9 @@
             i_mgr.append(z)
         final_mgrs = [model_to_dict(y, recurse=not short_output) for y in i_mgr]
-        return_managers = {"count": len(final_mgrs), "managers": final_mgrs}
+        total_count = len(final_mgrs)
+        final_mgrs = final_mgrs[offset : offset + limit]
+        return_managers = {"count": total_count, "managers": final_mgrs}
     else:
         all_managers = Manager.select()
@@ -69,14 +75,15 @@
             name_list = [x.lower() for x in name]
             all_managers = all_managers.where(fn.Lower(Manager.name) << name_list)
+        total_count = all_managers.count()
+        all_managers = all_managers.offset(offset).limit(limit)
         return_managers = {
-            "count": all_managers.count(),
+            "count": total_count,
             "managers": [
                 model_to_dict(x, recurse=not short_output) for x in all_managers
             ],
         }
-    db.close()
     return return_managers
@@ -86,7 +93,6 @@ async def get_one_manager(manager_id: int, short_output: Optional[bool] = False)
     this_manager = Manager.get_or_none(Manager.id == manager_id)
     if this_manager is not None:
         r_manager = model_to_dict(this_manager, recurse=not short_output)
-        db.close()
         return r_manager
     else:
         raise HTTPException(status_code=404, detail=f"Manager {manager_id} not found")
@@ -103,12 +109,11 @@ async def patch_manager(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logger.warning(f"patch_manager - Bad Token: {token}")
+        logger.warning("patch_manager - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_manager = Manager.get_or_none(Manager.id == manager_id)
     if this_manager is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Manager ID {manager_id} not found"
         )
@@ -124,10 +129,8 @@ async def patch_manager(
     if this_manager.save() == 1:
         r_manager = model_to_dict(this_manager)
-        db.close()
         return r_manager
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch manager {this_manager}"
         )
@@ -137,17 +140,15 @@ async def patch_manager(
 @handle_db_errors
 async def post_manager(new_manager: ManagerModel, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_manager - Bad Token: {token}")
+        logger.warning("post_manager - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
-    this_manager = Manager(**new_manager.dict())
+    this_manager = Manager(**new_manager.model_dump())
     if this_manager.save():
         r_manager = model_to_dict(this_manager)
-        db.close()
         return r_manager
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to post manager {this_manager.name}"
         )
@@ -157,18 +158,16 @@ async def post_manager(new_manager: ManagerModel, token: str = Depends(oauth2_sc
 @handle_db_errors
 async def delete_manager(manager_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_manager - Bad Token: {token}")
+        logger.warning("delete_manager - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_manager = Manager.get_or_none(Manager.id == manager_id)
     if this_manager is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"Manager ID {manager_id} not found"
         )
     count = this_manager.delete_instance()
-    db.close()
     if count == 1:
         return f"Manager {manager_id} has been deleted"
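Both branches of `get_managers` now follow the same rule: capture the total before applying `offset`/`limit`, so `count` reports the full result set rather than the page length (the bug fixed in #101). A minimal in-memory sketch of the list-slicing branch, with illustrative names:

```python
def paginate(items, limit, offset=0):
    """Report the full count, then return only the requested page."""
    total_count = len(items)              # capture BEFORE slicing
    page = items[offset : offset + limit]
    return {"count": total_count, "managers": page}

managers = [f"mgr{i}" for i in range(25)]
result = paginate(managers, limit=10, offset=20)
print(result["count"], len(result["managers"]))  # 25 5
```

The query branch is the same idea with `all_managers.count()` captured before `.offset(offset).limit(limit)` is chained on.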


@@ -1,6 +1,3 @@
-import datetime
-import os
 from fastapi import APIRouter, Depends, HTTPException, Query
 from typing import List, Optional, Literal
 import logging
@@ -22,6 +19,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -71,7 +70,7 @@ async def get_pitstats(
     week_start: Optional[int] = None,
     week_end: Optional[int] = None,
     game_num: list = Query(default=None),
-    limit: Optional[int] = None,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     ip_min: Optional[float] = None,
     sort: Optional[str] = None,
     short_output: Optional[bool] = True,
@@ -79,17 +78,14 @@ async def get_pitstats(
     if "post" in s_type.lower():
         all_stats = PitchingStat.post_season(season)
         if all_stats.count() == 0:
-            db.close()
             return {"count": 0, "stats": []}
     elif s_type.lower() in ["combined", "total", "all"]:
         all_stats = PitchingStat.combined_season(season)
         if all_stats.count() == 0:
-            db.close()
             return {"count": 0, "stats": []}
     else:
         all_stats = PitchingStat.regular_season(season)
         if all_stats.count() == 0:
-            db.close()
             return {"count": 0, "stats": []}
     if team_abbrev is not None:
@@ -115,7 +111,6 @@ async def get_pitstats(
         if week_end is not None:
             end = min(week_end, end)
         if start > end:
-            db.close()
             raise HTTPException(
                 status_code=404,
                 detail=f"Start week {start} is after end week {end} - cannot pull stats",
@@ -124,7 +119,6 @@ async def get_pitstats(
             (PitchingStat.week >= start) & (PitchingStat.week <= end)
         )
-    if limit:
-        all_stats = all_stats.limit(limit)
+    all_stats = all_stats.limit(limit)
     if sort:
         if sort == "newest":
@@ -135,7 +129,6 @@ async def get_pitstats(
         "stats": [model_to_dict(x, recurse=not short_output) for x in all_stats],
     }
-    db.close()
     return return_stats
@@ -157,6 +150,8 @@ async def get_totalstats(
     short_output: Optional[bool] = False,
     group_by: Literal["team", "player", "playerteam"] = "player",
     week: list = Query(default=None),
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1:
         raise HTTPException(
@@ -262,7 +257,10 @@ async def get_totalstats(
             all_players = Player.select().where(Player.id << player_id)
             all_stats = all_stats.where(PitchingStat.player << all_players)
-    return_stats = {"count": all_stats.count(), "stats": []}
+    total_count = all_stats.count()
+    all_stats = all_stats.offset(offset).limit(limit)
+    return_stats = {"count": total_count, "stats": []}
     for x in all_stats:
         # Handle player field based on grouping with safe access
@@ -304,7 +302,6 @@ async def get_totalstats(
                 "bsv": x.sum_bsv,
             }
         )
-    db.close()
     return return_stats
@@ -314,15 +311,16 @@ async def patch_pitstats(
     stat_id: int, new_stats: PitStatModel, token: str = Depends(oauth2_scheme)
 ):
     if not valid_token(token):
-        logger.warning(f"patch_pitstats - Bad Token: {token}")
+        logger.warning("patch_pitstats - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     if PitchingStat.get_or_none(PitchingStat.id == stat_id) is None:
         raise HTTPException(status_code=404, detail=f"Stat ID {stat_id} not found")
-    PitchingStat.update(**new_stats.dict()).where(PitchingStat.id == stat_id).execute()
+    PitchingStat.update(**new_stats.model_dump()).where(
+        PitchingStat.id == stat_id
+    ).execute()
     r_stat = model_to_dict(PitchingStat.get_by_id(stat_id))
-    db.close()
     return r_stat
@@ -330,7 +328,7 @@ async def patch_pitstats(
 @handle_db_errors
 async def post_pitstats(s_list: PitStatList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_pitstats - Bad Token: {token}")
+        logger.warning("post_pitstats - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     all_stats = []
@@ -347,11 +345,10 @@ async def post_pitstats(s_list: PitStatList, token: str = Depends(oauth2_scheme)
                 status_code=404, detail=f"Player ID {x.player_id} not found"
             )
-        all_stats.append(PitchingStat(**x.dict()))
+        all_stats.append(PitchingStat(**x.model_dump()))
     with db.atomic():
         for batch in chunked(all_stats, 15):
             PitchingStat.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Added {len(all_stats)} batting lines"


@@ -4,9 +4,13 @@ Thin HTTP layer using PlayerService for business logic.
 """
 from fastapi import APIRouter, Query, Response, Depends
-from typing import Optional, List
-from ..dependencies import oauth2_scheme, cache_result, handle_db_errors
+from typing import Literal, Optional, List
+from ..dependencies import (
+    oauth2_scheme,
+    cache_result,
+    handle_db_errors,
+)
 from ..services.base import BaseService
 from ..services.player_service import PlayerService
@@ -23,7 +27,7 @@ async def get_players(
     pos: list = Query(default=None),
     strat_code: list = Query(default=None),
     is_injured: Optional[bool] = None,
-    sort: Optional[str] = None,
+    sort: Optional[Literal["cost-asc", "cost-desc", "name-asc", "name-desc"]] = None,
     limit: Optional[int] = Query(
         default=None, ge=1, description="Maximum number of results to return"
     ),


@@ -9,6 +9,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -42,6 +44,8 @@ async def get_results(
     away_abbrev: list = Query(default=None),
     home_abbrev: list = Query(default=None),
     short_output: Optional[bool] = False,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     all_results = Result.select_season(season)
@@ -74,11 +78,13 @@ async def get_results(
     if week_end is not None:
         all_results = all_results.where(Result.week <= week_end)
+    total_count = all_results.count()
+    all_results = all_results.offset(offset).limit(limit)
     return_results = {
-        "count": all_results.count(),
+        "count": total_count,
         "results": [model_to_dict(x, recurse=not short_output) for x in all_results],
     }
-    db.close()
     return return_results
@@ -90,7 +96,6 @@ async def get_one_result(result_id: int, short_output: Optional[bool] = False):
         r_result = model_to_dict(this_result, recurse=not short_output)
     else:
         r_result = None
-    db.close()
     return r_result
@@ -109,7 +114,7 @@ async def patch_result(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logger.warning(f"patch_player - Bad Token: {token}")
+        logger.warning("patch_player - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_result = Result.get_or_none(Result.id == result_id)
@@ -142,10 +147,8 @@ async def patch_result(
     if this_result.save() == 1:
         r_result = model_to_dict(this_result)
-        db.close()
         return r_result
     else:
-        db.close()
         raise HTTPException(
             status_code=500, detail=f"Unable to patch result {result_id}"
         )
@@ -155,26 +158,36 @@ async def patch_result(
 @handle_db_errors
 async def post_results(result_list: ResultList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"patch_player - Bad Token: {token}")
+        logger.warning("patch_player - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_results = []
+    all_team_ids = list(
+        set(x.awayteam_id for x in result_list.results)
+        | set(x.hometeam_id for x in result_list.results)
+    )
+    found_team_ids = (
+        set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
+        if all_team_ids
+        else set()
+    )
     for x in result_list.results:
-        if Team.get_or_none(Team.id == x.awayteam_id) is None:
+        if x.awayteam_id not in found_team_ids:
             raise HTTPException(
                 status_code=404, detail=f"Team ID {x.awayteam_id} not found"
             )
-        if Team.get_or_none(Team.id == x.hometeam_id) is None:
+        if x.hometeam_id not in found_team_ids:
             raise HTTPException(
                 status_code=404, detail=f"Team ID {x.hometeam_id} not found"
             )
-        new_results.append(x.dict())
+        new_results.append(x.model_dump())
     with db.atomic():
         for batch in chunked(new_results, 15):
             Result.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_results)} results"
@@ -183,16 +196,14 @@ async def post_results(result_list: ResultList, token: str = Depends(oauth2_sche
 @handle_db_errors
 async def delete_result(result_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_result - Bad Token: {token}")
+        logger.warning("delete_result - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_result = Result.get_or_none(Result.id == result_id)
     if not this_result:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Result ID {result_id} not found")
     count = this_result.delete_instance()
-    db.close()
     if count == 1:
         return f"Result {result_id} has been deleted"
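The `post_results` change above swaps two `Team.get_or_none` round-trips per row for a single `SELECT` over the union of referenced team IDs. The set arithmetic can be sketched without peewee; `ResultRow` and `existing_team_ids` are hypothetical stand-ins for the Pydantic model and the `Team.select(Team.id)` query:

```python
from dataclasses import dataclass

@dataclass
class ResultRow:                 # hypothetical stand-in for the request model
    awayteam_id: int
    hometeam_id: int

existing_team_ids = {1, 2, 3}    # stand-in for Team.select(Team.id) results

def validate(results):
    # Union of every team ID the payload references.
    all_team_ids = set(r.awayteam_id for r in results) | set(
        r.hometeam_id for r in results
    )
    # One lookup instead of 2 * len(results) queries.
    found = existing_team_ids & all_team_ids
    for r in results:
        if r.awayteam_id not in found:
            raise ValueError(f"Team ID {r.awayteam_id} not found")
        if r.hometeam_id not in found:
            raise ValueError(f"Team ID {r.hometeam_id} not found")

validate([ResultRow(1, 2), ResultRow(2, 3)])  # passes
try:
    validate([ResultRow(1, 99)])
except ValueError as e:
    print(e)  # Team ID 99 not found
```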


@@ -12,6 +12,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 logger = logging.getLogger("discord_app")
@@ -44,6 +46,8 @@ async def get_players(
     key_mlbam: list = Query(default=None),
     sort: Optional[str] = None,
     csv: Optional[bool] = False,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     all_players = SbaPlayer.select()
@@ -98,14 +102,15 @@ async def get_players(
     if csv:
         return_val = query_to_csv(all_players)
-        db.close()
         return Response(content=return_val, media_type="text/csv")
+    total_count = all_players.count()
+    all_players = all_players.offset(offset).limit(limit)
     return_val = {
-        "count": all_players.count(),
+        "count": total_count,
         "players": [model_to_dict(x) for x in all_players],
     }
-    db.close()
     return return_val
@@ -114,13 +119,11 @@ async def get_players(
 async def get_one_player(player_id: int):
     this_player = SbaPlayer.get_or_none(SbaPlayer.id == player_id)
     if this_player is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"SbaPlayer id {player_id} not found"
         )
     r_data = model_to_dict(this_player)
-    db.close()
     return r_data
@@ -137,8 +140,7 @@ async def patch_player(
     token: str = Depends(oauth2_scheme),
 ):
     if not valid_token(token):
-        logging.warning(f"Bad Token: {token}")
-        db.close()
+        logging.warning("Bad Token")
         raise HTTPException(
             status_code=401,
             detail="You are not authorized to patch mlb players. This event has been logged.",
@@ -146,7 +148,6 @@ async def patch_player(
     this_player = SbaPlayer.get_or_none(SbaPlayer.id == player_id)
     if this_player is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"SbaPlayer id {player_id} not found"
         )
@@ -166,10 +167,8 @@ async def patch_player(
     if this_player.save() == 1:
         return_val = model_to_dict(this_player)
-        db.close()
         return return_val
     else:
-        db.close()
         raise HTTPException(
             status_code=418,
             detail="Well slap my ass and call me a teapot; I could not save that player",
@@ -180,8 +179,7 @@ async def patch_player(
 @handle_db_errors
 async def post_players(players: PlayerList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logging.warning(f"Bad Token: {token}")
-        db.close()
+        logging.warning("Bad Token")
         raise HTTPException(
             status_code=401,
             detail="You are not authorized to post mlb players. This event has been logged.",
@@ -200,7 +198,6 @@ async def post_players(players: PlayerList, token: str = Depends(oauth2_scheme))
         )
         if dupes.count() > 0:
             logger.error(f"Found a dupe for {x}")
-            db.close()
             raise HTTPException(
                 status_code=400,
                 detail=f"{x.first_name} {x.last_name} has a key already in the database",
@@ -211,7 +208,6 @@ async def post_players(players: PlayerList, token: str = Depends(oauth2_scheme))
     with db.atomic():
         for batch in chunked(new_players, 15):
             SbaPlayer.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_players)} new MLB players"
@@ -220,8 +216,7 @@ async def post_players(players: PlayerList, token: str = Depends(oauth2_scheme))
 @handle_db_errors
 async def post_one_player(player: SbaPlayerModel, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logging.warning(f"Bad Token: {token}")
-        db.close()
+        logging.warning("Bad Token")
         raise HTTPException(
             status_code=401,
             detail="You are not authorized to post mlb players. This event has been logged.",
@@ -236,20 +231,17 @@ async def post_one_player(player: SbaPlayerModel, token: str = Depends(oauth2_sc
         logging.info(f"POST /SbaPlayers/one - dupes found:")
         for x in dupes:
             logging.info(f"{x}")
-        db.close()
         raise HTTPException(
             status_code=400,
             detail=f"{player.first_name} {player.last_name} has a key already in the database",
         )
-    new_player = SbaPlayer(**player.dict())
+    new_player = SbaPlayer(**player.model_dump())
     saved = new_player.save()
     if saved == 1:
         return_val = model_to_dict(new_player)
-        db.close()
         return return_val
     else:
-        db.close()
         raise HTTPException(
             status_code=418,
             detail="Well slap my ass and call me a teapot; I could not save that player",
@@ -260,8 +252,7 @@ async def post_one_player(player: SbaPlayerModel, token: str = Depends(oauth2_sc
 @handle_db_errors
 async def delete_player(player_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logging.warning(f"Bad Token: {token}")
-        db.close()
+        logging.warning("Bad Token")
         raise HTTPException(
             status_code=401,
             detail="You are not authorized to delete mlb players. This event has been logged.",
@@ -269,13 +260,11 @@ async def delete_player(player_id: int, token: str = Depends(oauth2_scheme)):
     this_player = SbaPlayer.get_or_none(SbaPlayer.id == player_id)
     if this_player is None:
-        db.close()
         raise HTTPException(
             status_code=404, detail=f"SbaPlayer id {player_id} not found"
        )
     count = this_player.delete_instance()
-    db.close()
     if count == 1:
         return f"Player {player_id} has been deleted"


@ -9,6 +9,8 @@ from ..dependencies import (
valid_token, valid_token,
PRIVATE_IN_SCHEMA, PRIVATE_IN_SCHEMA,
handle_db_errors, handle_db_errors,
MAX_LIMIT,
DEFAULT_LIMIT,
) )
logger = logging.getLogger("discord_app") logger = logging.getLogger("discord_app")
@ -38,6 +40,8 @@ async def get_schedules(
week_start: Optional[int] = None, week_start: Optional[int] = None,
week_end: Optional[int] = None, week_end: Optional[int] = None,
short_output: Optional[bool] = True, short_output: Optional[bool] = True,
limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
offset: int = Query(default=0, ge=0),
): ):
all_sched = Schedule.select_season(season) all_sched = Schedule.select_season(season)
@ -69,11 +73,13 @@ async def get_schedules(
all_sched = all_sched.order_by(Schedule.id) all_sched = all_sched.order_by(Schedule.id)
total_count = all_sched.count()
all_sched = all_sched.offset(offset).limit(limit)
return_sched = { return_sched = {
"count": all_sched.count(), "count": total_count,
"schedules": [model_to_dict(x, recurse=not short_output) for x in all_sched], "schedules": [model_to_dict(x, recurse=not short_output) for x in all_sched],
} }
db.close()
return return_sched return return_sched
@ -85,7 +91,6 @@ async def get_one_schedule(schedule_id: int):
r_sched = model_to_dict(this_sched) r_sched = model_to_dict(this_sched)
else: else:
r_sched = None r_sched = None
db.close()
return r_sched return r_sched
@ -101,7 +106,7 @@ async def patch_schedule(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
): ):
if not valid_token(token): if not valid_token(token):
logger.warning(f"patch_schedule - Bad Token: {token}") logger.warning("patch_schedule - Bad Token")
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
this_sched = Schedule.get_or_none(Schedule.id == schedule_id) this_sched = Schedule.get_or_none(Schedule.id == schedule_id)
@ -127,10 +132,8 @@ async def patch_schedule(
if this_sched.save() == 1: if this_sched.save() == 1:
r_sched = model_to_dict(this_sched) r_sched = model_to_dict(this_sched)
db.close()
        return r_sched
    else:
-       db.close()
        raise HTTPException(
            status_code=500, detail=f"Unable to patch schedule {schedule_id}"
        )
@@ -140,26 +143,36 @@ async def patch_schedule(
 @handle_db_errors
 async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_schedules - Bad Token: {token}")
+        logger.warning("post_schedules - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_sched = []
+    all_team_ids = list(
+        set(x.awayteam_id for x in sched_list.schedules)
+        | set(x.hometeam_id for x in sched_list.schedules)
+    )
+    found_team_ids = (
+        set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
+        if all_team_ids
+        else set()
+    )
     for x in sched_list.schedules:
-        if Team.get_or_none(Team.id == x.awayteam_id) is None:
+        if x.awayteam_id not in found_team_ids:
             raise HTTPException(
                 status_code=404, detail=f"Team ID {x.awayteam_id} not found"
             )
-        if Team.get_or_none(Team.id == x.hometeam_id) is None:
+        if x.hometeam_id not in found_team_ids:
             raise HTTPException(
                 status_code=404, detail=f"Team ID {x.hometeam_id} not found"
             )
-        new_sched.append(x.dict())
+        new_sched.append(x.model_dump())
     with db.atomic():
         for batch in chunked(new_sched, 15):
             Schedule.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_sched)} schedules"
@@ -168,7 +181,7 @@ async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_s
 @handle_db_errors
 async def delete_schedule(schedule_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_schedule - Bad Token: {token}")
+        logger.warning("delete_schedule - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_sched = Schedule.get_or_none(Schedule.id == schedule_id)
@@ -178,7 +191,6 @@ async def delete_schedule(schedule_id: int, token: str = Depends(oauth2_scheme))
         )
     count = this_sched.delete_instance()
-    db.close()
     if count == 1:
         return f"Schedule {this_sched} has been deleted"
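The validation change above swaps N per-row `Team.get_or_none` lookups for one prefetched ID set. A minimal stdlib sketch of the same pattern (table and column names are illustrative, not the project's schema):

```python
import sqlite3

# Fetch every referenced ID in a single IN query, then validate each row
# against the resulting set instead of issuing one SELECT per reference.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE team (id INTEGER PRIMARY KEY)")
conn.executemany("INSERT INTO team (id) VALUES (?)", [(1,), (2,), (3,)])

schedules = [{"awayteam_id": 1, "hometeam_id": 2}, {"awayteam_id": 3, "hometeam_id": 9}]
all_team_ids = list(
    {s["awayteam_id"] for s in schedules} | {s["hometeam_id"] for s in schedules}
)
placeholders = ",".join("?" * len(all_team_ids))
found = {
    row[0]
    for row in conn.execute(
        f"SELECT id FROM team WHERE id IN ({placeholders})", all_team_ids
    )
}
missing = [
    tid
    for s in schedules
    for tid in (s["awayteam_id"], s["hometeam_id"])
    if tid not in found
]
print(missing)  # → [9]
```

One round trip regardless of batch size, which is the same idea the diff implements with peewee's `Team.id << all_team_ids`.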
@@ -1,24 +1,33 @@
 from fastapi import APIRouter, Depends, HTTPException, Query
 from typing import List, Optional
 import logging
+import pydantic
 from ..db_engine import db, Standings, Team, Division, model_to_dict, chunked, fn
-from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors
+from ..dependencies import (
+    oauth2_scheme,
+    valid_token,
+    PRIVATE_IN_SCHEMA,
+    handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
+)
 
-logger = logging.getLogger('discord_app')
+logger = logging.getLogger("discord_app")
 
-router = APIRouter(
-    prefix='/api/v3/standings',
-    tags=['standings']
-)
+router = APIRouter(prefix="/api/v3/standings", tags=["standings"])
 
-@router.get('')
+@router.get("")
 @handle_db_errors
 async def get_standings(
-        season: int, team_id: list = Query(default=None), league_abbrev: Optional[str] = None,
-        division_abbrev: Optional[str] = None, short_output: Optional[bool] = False):
+    season: int,
+    team_id: list = Query(default=None),
+    league_abbrev: Optional[str] = None,
+    division_abbrev: Optional[str] = None,
+    short_output: Optional[bool] = False,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
+):
     standings = Standings.select_season(season)
     # if standings.count() == 0:
@@ -30,55 +39,67 @@ async def get_standings(
         standings = standings.where(Standings.team << t_query)
     if league_abbrev is not None:
-        l_query = Division.select().where(fn.Lower(Division.league_abbrev) == league_abbrev.lower())
+        l_query = Division.select().where(
+            fn.Lower(Division.league_abbrev) == league_abbrev.lower()
+        )
         standings = standings.where(Standings.team.division << l_query)
     if division_abbrev is not None:
-        d_query = Division.select().where(fn.Lower(Division.division_abbrev) == division_abbrev.lower())
+        d_query = Division.select().where(
+            fn.Lower(Division.division_abbrev) == division_abbrev.lower()
+        )
         standings = standings.where(Standings.team.division << d_query)
 
     def win_pct(this_team_stan):
         if this_team_stan.wins + this_team_stan.losses == 0:
             return 0
         else:
-            return (this_team_stan.wins / (this_team_stan.wins + this_team_stan.losses)) + \
-                (this_team_stan.run_diff * .000001)
+            return (
+                this_team_stan.wins / (this_team_stan.wins + this_team_stan.losses)
+            ) + (this_team_stan.run_diff * 0.000001)
 
     div_teams = [x for x in standings]
     div_teams.sort(key=lambda team: win_pct(team), reverse=True)
+    total_count = len(div_teams)
+    div_teams = div_teams[offset : offset + limit]
     return_standings = {
-        'count': len(div_teams),
-        'standings': [model_to_dict(x, recurse=not short_output) for x in div_teams]
+        "count": total_count,
+        "standings": [model_to_dict(x, recurse=not short_output) for x in div_teams],
    }
-    db.close()
    return return_standings
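The pagination added to `get_standings` slices an already-sorted in-memory list, and the key detail is capturing the total before slicing so `count` reports all matching rows rather than the page length. A small sketch of that shape:

```python
# Compute the total before applying offset/limit so the response's "count"
# reflects every matching row, not just the returned page.
def paginate_in_memory(rows, offset, limit):
    total_count = len(rows)
    page = rows[offset : offset + limit]
    return {"count": total_count, "standings": page}

result = paginate_in_memory(list(range(30)), offset=10, limit=5)
print(result["count"], result["standings"])  # → 30 [10, 11, 12, 13, 14]
```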
 
-@router.get('/team/{team_id}')
+@router.get("/team/{team_id}")
 @handle_db_errors
 async def get_team_standings(team_id: int):
     this_stan = Standings.get_or_none(Standings.team_id == team_id)
     if this_stan is None:
-        raise HTTPException(status_code=404, detail=f'No standings found for team id {team_id}')
+        raise HTTPException(
+            status_code=404, detail=f"No standings found for team id {team_id}"
+        )
     return model_to_dict(this_stan)
 
-@router.patch('/{stan_id}', include_in_schema=PRIVATE_IN_SCHEMA)
+@router.patch("/{stan_id}", include_in_schema=PRIVATE_IN_SCHEMA)
 @handle_db_errors
 async def patch_standings(
-        stan_id, wins: Optional[int] = None, losses: Optional[int] = None, token: str = Depends(oauth2_scheme)):
+    stan_id: int,
+    wins: Optional[int] = None,
+    losses: Optional[int] = None,
+    token: str = Depends(oauth2_scheme),
+):
     if not valid_token(token):
-        logger.warning(f'patch_standings - Bad Token: {token}')
-        raise HTTPException(status_code=401, detail='Unauthorized')
+        logger.warning("patch_standings - Bad Token")
+        raise HTTPException(status_code=401, detail="Unauthorized")
     try:
         this_stan = Standings.get_by_id(stan_id)
     except Exception as e:
-        db.close()
-        raise HTTPException(status_code=404, detail=f'No team found with id {stan_id}')
+        raise HTTPException(status_code=404, detail=f"No team found with id {stan_id}")
     if wins:
         this_stan.wins = wins
@@ -86,40 +107,37 @@ async def patch_standings(
         this_stan.losses = losses
     this_stan.save()
-    db.close()
     return model_to_dict(this_stan)
 
-@router.post('/s{season}/new', include_in_schema=PRIVATE_IN_SCHEMA)
+@router.post("/s{season}/new", include_in_schema=PRIVATE_IN_SCHEMA)
 @handle_db_errors
 async def post_standings(season: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f'post_standings - Bad Token: {token}')
-        raise HTTPException(status_code=401, detail='Unauthorized')
+        logger.warning("post_standings - Bad Token")
+        raise HTTPException(status_code=401, detail="Unauthorized")
     new_teams = []
     all_teams = Team.select().where(Team.season == season)
     for x in all_teams:
-        new_teams.append(Standings({'team_id': x.id}))
+        new_teams.append(Standings({"team_id": x.id}))
     with db.atomic():
         for batch in chunked(new_teams, 16):
             Standings.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
-    return f'Inserted {len(new_teams)} standings'
+    return f"Inserted {len(new_teams)} standings"
 
-@router.post('/s{season}/recalculate', include_in_schema=PRIVATE_IN_SCHEMA)
+@router.post("/s{season}/recalculate", include_in_schema=PRIVATE_IN_SCHEMA)
 @handle_db_errors
 async def recalculate_standings(season: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f'recalculate_standings - Bad Token: {token}')
-        raise HTTPException(status_code=401, detail='Unauthorized')
+        logger.warning("recalculate_standings - Bad Token")
+        raise HTTPException(status_code=401, detail="Unauthorized")
     code = Standings.recalculate(season)
-    db.close()
     if code == 69:
-        raise HTTPException(status_code=500, detail=f'Error recreating Standings rows')
-    return f'Just recalculated standings for season {season}'
+        raise HTTPException(status_code=500, detail=f"Error recreating Standings rows")
+    return f"Just recalculated standings for season {season}"
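Several endpoints here share the same bulk-insert pattern: batch the rows, wrap all batches in one transaction, and let the conflict target silently skip duplicates. A stdlib sketch of the shape (chunk size and table name are illustrative; `INSERT OR IGNORE` plays the role of peewee's `on_conflict_ignore`, the `with conn:` block the role of `db.atomic()`):

```python
import sqlite3

def chunked(seq, n):
    """Yield successive n-sized slices, like peewee's chunked()."""
    for i in range(0, len(seq), n):
        yield seq[i : i + n]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE standings (team_id INTEGER PRIMARY KEY)")
conn.execute("INSERT INTO standings (team_id) VALUES (5)")  # pre-existing row

new_rows = [(i,) for i in range(1, 41)]
with conn:  # one transaction covering every batch
    for batch in chunked(new_rows, 16):
        conn.executemany("INSERT OR IGNORE INTO standings (team_id) VALUES (?)", batch)

total = conn.execute("SELECT COUNT(*) FROM standings").fetchone()[0]
print(total)  # → 40: the duplicate team_id 5 was ignored, not doubled
```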
@@ -13,6 +13,7 @@ from ..dependencies import (
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
     update_season_batting_stats,
+    DEFAULT_LIMIT,
 )
 
 logger = logging.getLogger("discord_app")
@@ -59,6 +60,8 @@ async def get_games(
     division_id: Optional[int] = None,
     short_output: Optional[bool] = False,
     sort: Optional[str] = None,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=1000),
+    offset: int = Query(default=0, ge=0),
 ) -> Any:
     all_games = StratGame.select()
@@ -119,11 +122,13 @@ async def get_games(
         StratGame.season, StratGame.week, StratGame.game_num
     )
+    total_count = all_games.count()
+    all_games = all_games.offset(offset).limit(limit)
     return_games = {
-        "count": all_games.count(),
+        "count": total_count,
         "games": [model_to_dict(x, recurse=not short_output) for x in all_games],
     }
-    db.close()
     return return_games
@@ -132,11 +137,9 @@ async def get_games(
 async def get_one_game(game_id: int) -> Any:
     this_game = StratGame.get_or_none(StratGame.id == game_id)
     if not this_game:
-        db.close()
         raise HTTPException(status_code=404, detail=f"StratGame ID {game_id} not found")
     g_result = model_to_dict(this_game)
-    db.close()
     return g_result
@@ -153,12 +156,11 @@ async def patch_game(
     scorecard_url: Optional[str] = None,
 ) -> Any:
     if not valid_token(token):
-        logger.warning(f"patch_game - Bad Token: {token}")
+        logger.warning("patch_game - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_game = StratGame.get_or_none(StratGame.id == game_id)
     if not this_game:
-        db.close()
         raise HTTPException(status_code=404, detail=f"StratGame ID {game_id} not found")
     if game_num is not None:
@@ -234,7 +236,7 @@ async def patch_game(
 @handle_db_errors
 async def post_games(game_list: GameList, token: str = Depends(oauth2_scheme)) -> Any:
     if not valid_token(token):
-        logger.warning(f"post_games - Bad Token: {token}")
+        logger.warning("post_games - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_games = []
@@ -248,12 +250,11 @@ async def post_games(game_list: GameList, token: str = Depends(oauth2_scheme)) -
                 status_code=404, detail=f"Team ID {x.home_team_id} not found"
             )
-        new_games.append(x.dict())
+        new_games.append(x.model_dump())
     with db.atomic():
         for batch in chunked(new_games, 16):
             StratGame.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_games)} games"
@@ -262,12 +263,11 @@ async def post_games(game_list: GameList, token: str = Depends(oauth2_scheme)) -
 @handle_db_errors
 async def wipe_game(game_id: int, token: str = Depends(oauth2_scheme)) -> Any:
     if not valid_token(token):
-        logger.warning(f"wipe_game - Bad Token: {token}")
+        logger.warning("wipe_game - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_game = StratGame.get_or_none(StratGame.id == game_id)
     if not this_game:
-        db.close()
         raise HTTPException(status_code=404, detail=f"StratGame ID {game_id} not found")
     this_game.away_score = None
@@ -278,10 +278,8 @@ async def wipe_game(game_id: int, token: str = Depends(oauth2_scheme)) -> Any:
     if this_game.save() == 1:
         g_result = model_to_dict(this_game)
-        db.close()
         return g_result
     else:
-        db.close()
         raise HTTPException(status_code=500, detail=f"Unable to wipe game {game_id}")
@@ -289,16 +287,14 @@ async def wipe_game(game_id: int, token: str = Depends(oauth2_scheme)) -> Any:
 @handle_db_errors
 async def delete_game(game_id: int, token: str = Depends(oauth2_scheme)) -> Any:
     if not valid_token(token):
-        logger.warning(f"delete_game - Bad Token: {token}")
+        logger.warning("delete_game - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_game = StratGame.get_or_none(StratGame.id == game_id)
     if not this_game:
-        db.close()
         raise HTTPException(status_code=404, detail=f"StratGame ID {game_id} not found")
     count = this_game.delete_instance()
-    db.close()
     if count == 1:
         return f"StratGame {game_id} has been deleted"
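Unlike the in-memory slicing in `get_standings`, the `get_games` change pushes pagination into SQL: take `count()` on the filtered query first, then apply `OFFSET`/`LIMIT`. A stdlib sketch of that ordering (table and filter are illustrative):

```python
import sqlite3

# Count the full match set before paginating, so the response can report
# both the total and a single page of rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE game (id INTEGER PRIMARY KEY, week INTEGER)")
conn.executemany("INSERT INTO game (week) VALUES (?)", [(w,) for w in range(1, 26)])

where, params = "week <= ?", (20,)
total_count = conn.execute(
    f"SELECT COUNT(*) FROM game WHERE {where}", params
).fetchone()[0]
page = conn.execute(
    f"SELECT id FROM game WHERE {where} ORDER BY id LIMIT ? OFFSET ?", params + (5, 10)
).fetchall()
print(total_count, [r[0] for r in page])  # → 20 [11, 12, 13, 14, 15]
```

This costs one extra `COUNT(*)` query but avoids materializing every row, which matters more as the table grows.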
@@ -13,7 +13,13 @@ from ...db_engine import (
     fn,
     model_to_dict,
 )
-from ...dependencies import add_cache_headers, cache_result, handle_db_errors
+from ...dependencies import (
+    add_cache_headers,
+    cache_result,
+    handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
+)
 from .common import build_season_games
 
 router = APIRouter()
@@ -52,7 +58,7 @@ async def get_batting_totals(
     risp: Optional[bool] = None,
     inning: list = Query(default=None),
     sort: Optional[str] = None,
-    limit: Optional[int] = 200,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     short_output: Optional[bool] = False,
     page_num: Optional[int] = 1,
     week_start: Optional[int] = None,
@@ -423,8 +429,6 @@ async def get_batting_totals(
         run_plays = run_plays.order_by(StratPlay.game.asc())
     # For other group_by values, skip game_id/play_num sorting since they're not in GROUP BY
 
-    if limit < 1:
-        limit = 1
     bat_plays = bat_plays.paginate(page_num, limit)
 
     logger.info(f"bat_plays query: {bat_plays}")
@@ -594,5 +598,4 @@ async def get_batting_totals(
         }
     )
-    db.close()
     return return_stats
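The stats endpoints use peewee's 1-indexed `paginate(page_num, limit)` rather than raw offsets. A rough sketch of the offset arithmetic it performs (an illustration of the semantics, not the library's actual code):

```python
# paginate(page_num, limit) with page_num starting at 1:
def paginate(rows, page_num, limit):
    offset = (page_num - 1) * limit
    return rows[offset : offset + limit]

rows = list(range(1, 21))
print(paginate(rows, 1, 5))  # → [1, 2, 3, 4, 5]
print(paginate(rows, 3, 5))  # → [11, 12, 13, 14, 15]
```

With `Query(ge=1, le=MAX_LIMIT)` rejecting out-of-range values at the framework boundary, the hand-rolled `if limit < 1` clamp removed in this diff becomes dead code.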
@@ -20,10 +20,8 @@ logger = logging.getLogger("discord_app")
 @handle_db_errors
 async def get_one_play(play_id: int):
     if StratPlay.get_or_none(StratPlay.id == play_id) is None:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Play ID {play_id} not found")
     r_play = model_to_dict(StratPlay.get_by_id(play_id))
-    db.close()
     return r_play
@@ -33,16 +31,14 @@ async def patch_play(
     play_id: int, new_play: PlayModel, token: str = Depends(oauth2_scheme)
 ):
     if not valid_token(token):
-        logger.warning(f"patch_play - Bad Token: {token}")
+        logger.warning("patch_play - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     if StratPlay.get_or_none(StratPlay.id == play_id) is None:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Play ID {play_id} not found")
-    StratPlay.update(**new_play.dict()).where(StratPlay.id == play_id).execute()
+    StratPlay.update(**new_play.model_dump()).where(StratPlay.id == play_id).execute()
     r_play = model_to_dict(StratPlay.get_by_id(play_id))
-    db.close()
     return r_play
@@ -50,7 +46,7 @@ async def patch_play(
 @handle_db_errors
 async def post_plays(p_list: PlayList, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_plays - Bad Token: {token}")
+        logger.warning("post_plays - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     new_plays = []
@@ -88,12 +84,11 @@ async def post_plays(p_list: PlayList, token: str = Depends(oauth2_scheme)):
         if this_play.pa == 0:
             this_play.batter_final = None
-        new_plays.append(this_play.dict())
+        new_plays.append(this_play.model_dump())
     with db.atomic():
         for batch in chunked(new_plays, 20):
             StratPlay.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"Inserted {len(new_plays)} plays"
@@ -102,16 +97,14 @@ async def post_plays(p_list: PlayList, token: str = Depends(oauth2_scheme)):
 @handle_db_errors
 async def delete_play(play_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_play - Bad Token: {token}")
+        logger.warning("delete_play - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_play = StratPlay.get_or_none(StratPlay.id == play_id)
     if not this_play:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Play ID {play_id} not found")
     count = this_play.delete_instance()
-    db.close()
     if count == 1:
         return f"Play {play_id} has been deleted"
@@ -125,16 +118,14 @@ async def delete_play(play_id: int, token: str = Depends(oauth2_scheme)):
 @handle_db_errors
 async def delete_plays_game(game_id: int, token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"delete_plays_game - Bad Token: {token}")
+        logger.warning("delete_plays_game - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     this_game = StratGame.get_or_none(StratGame.id == game_id)
     if not this_game:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Game ID {game_id} not found")
     count = StratPlay.delete().where(StratPlay.game == this_game).execute()
-    db.close()
     if count > 0:
         return f"Deleted {count} plays matching Game ID {game_id}"
@@ -148,12 +139,11 @@ async def delete_plays_game(game_id: int, token: str = Depends(oauth2_scheme)):
 @handle_db_errors
 async def post_erun_check(token: str = Depends(oauth2_scheme)):
     if not valid_token(token):
-        logger.warning(f"post_erun_check - Bad Token: {token}")
+        logger.warning("post_erun_check - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     all_plays = StratPlay.update(run=1).where(
         (StratPlay.e_run == 1) & (StratPlay.run == 0)
     )
     count = all_plays.execute()
-    db.close()
     return count
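`post_erun_check` fixes up rows with a single conditional bulk `UPDATE` instead of iterating them in Python. A stdlib sketch of the same statement shape (table and column names are illustrative):

```python
import sqlite3

# One UPDATE touches every row matching the predicate; rowcount reports
# how many rows changed, mirroring the count the endpoint returns.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE play (id INTEGER PRIMARY KEY, e_run INTEGER, run INTEGER)")
conn.executemany(
    "INSERT INTO play (e_run, run) VALUES (?, ?)",
    [(1, 0), (1, 1), (0, 0), (1, 0)],
)
cur = conn.execute("UPDATE play SET run = 1 WHERE e_run = 1 AND run = 0")
print(cur.rowcount)  # → 2: only the rows that needed flipping
```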
@@ -13,7 +13,13 @@ from ...db_engine import (
     fn,
     SQL,
 )
-from ...dependencies import handle_db_errors, add_cache_headers, cache_result
+from ...dependencies import (
+    handle_db_errors,
+    add_cache_headers,
+    cache_result,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
+)
 from .common import build_season_games
 
 logger = logging.getLogger("discord_app")
@@ -51,7 +57,7 @@ async def get_fielding_totals(
     team_id: list = Query(default=None),
     manager_id: list = Query(default=None),
     sort: Optional[str] = None,
-    limit: Optional[int] = 200,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     short_output: Optional[bool] = False,
     page_num: Optional[int] = 1,
 ):
@@ -237,8 +243,6 @@ async def get_fielding_totals(
         def_plays = def_plays.order_by(StratPlay.game.asc())
     # For other group_by values, skip game_id/play_num sorting since they're not in GROUP BY
 
-    if limit < 1:
-        limit = 1
     def_plays = def_plays.paginate(page_num, limit)
 
     logger.info(f"def_plays query: {def_plays}")
@@ -361,5 +365,4 @@ async def get_fielding_totals(
             "week": this_week,
         }
     )
-    db.close()
     return return_stats
@@ -16,7 +16,13 @@ from ...db_engine import (
     SQL,
     complex_data_to_csv,
 )
-from ...dependencies import handle_db_errors, add_cache_headers, cache_result
+from ...dependencies import (
+    handle_db_errors,
+    add_cache_headers,
+    cache_result,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
+)
 from .common import build_season_games
 
 router = APIRouter()
@@ -51,7 +57,7 @@ async def get_pitching_totals(
     risp: Optional[bool] = None,
     inning: list = Query(default=None),
     sort: Optional[str] = None,
-    limit: Optional[int] = 200,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     short_output: Optional[bool] = False,
     csv: Optional[bool] = False,
     page_num: Optional[int] = 1,
@@ -164,8 +170,6 @@ async def get_pitching_totals(
     if group_by in ["playergame", "teamgame"]:
         pitch_plays = pitch_plays.order_by(StratPlay.game.asc())
 
-    if limit < 1:
-        limit = 1
     pitch_plays = pitch_plays.paginate(page_num, limit)
 
     # Execute the Peewee query
@@ -348,7 +352,6 @@ async def get_pitching_totals(
     )
     return_stats["count"] = len(return_stats["stats"])
-    db.close()
 
     if csv:
         return Response(
             content=complex_data_to_csv(return_stats["stats"]), media_type="text/csv"
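The pitching endpoint's `csv=True` branch serializes the stats rows through `complex_data_to_csv`, whose body is not shown in this diff. A hedged stdlib sketch of what such a helper could look like for a flat list of stat dicts (the function name and shape here are assumptions, not the project's actual implementation):

```python
import csv
import io

# Illustrative CSV serializer for list-of-dict stat rows; header order is
# taken from the first row's keys.
def stats_to_csv(rows):
    if not rows:
        return ""
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=list(rows[0].keys()))
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()

text = stats_to_csv([{"player": "a", "ip": 6.1}, {"player": "b", "ip": 3.2}])
print(text.splitlines()[0])  # → player,ip
```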
@@ -16,6 +16,8 @@ from ...dependencies import (
     handle_db_errors,
     add_cache_headers,
     cache_result,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 
 logger = logging.getLogger("discord_app")
@@ -70,7 +72,7 @@ async def get_plays(
     pitcher_team_id: list = Query(default=None),
     short_output: Optional[bool] = False,
     sort: Optional[str] = None,
-    limit: Optional[int] = 200,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     page_num: Optional[int] = 1,
     s_type: Literal["regular", "post", "total", None] = None,
 ):
@@ -185,8 +187,6 @@ async def get_plays(
         season_games = season_games.where(StratGame.week > 18)
     all_plays = all_plays.where(StratPlay.game << season_games)
 
-    if limit < 1:
-        limit = 1
     bat_plays = all_plays.paginate(page_num, limit)
 
     if sort == "wpa-desc":
@@ -210,5 +210,4 @@ async def get_plays(
         "count": all_plays.count(),
         "plays": [model_to_dict(x, recurse=not short_output) for x in all_plays],
     }
-    db.close()
     return return_plays
@@ -11,6 +11,8 @@ from ..dependencies import (
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
     cache_result,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 from ..services.base import BaseService
 from ..services.team_service import TeamService
@@ -10,6 +10,8 @@ from ..dependencies import (
     valid_token,
     PRIVATE_IN_SCHEMA,
     handle_db_errors,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
 )
 
 logger = logging.getLogger("discord_app")
@@ -36,7 +38,7 @@ class TransactionList(pydantic.BaseModel):
 @router.get("")
 @handle_db_errors
 async def get_transactions(
-    season,
+    season: int,
     team_abbrev: list = Query(default=None),
     week_start: Optional[int] = 0,
     week_end: Optional[int] = None,
@@ -45,8 +47,9 @@ async def get_transactions(
     player_name: list = Query(default=None),
     player_id: list = Query(default=None),
     move_id: Optional[str] = None,
-    is_trade: Optional[bool] = None,
     short_output: Optional[bool] = False,
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
+    offset: int = Query(default=0, ge=0),
 ):
     if season:
         transactions = Transaction.select_season(season)
@ -75,30 +78,29 @@ async def get_transactions(
transactions = transactions.where(Transaction.player << these_players) transactions = transactions.where(Transaction.player << these_players)
if cancelled: if cancelled:
transactions = transactions.where(Transaction.cancelled == 1) transactions = transactions.where(Transaction.cancelled == True)
else: else:
transactions = transactions.where(Transaction.cancelled == 0) transactions = transactions.where(Transaction.cancelled == False)
if frozen: if frozen:
transactions = transactions.where(Transaction.frozen == 1) transactions = transactions.where(Transaction.frozen == True)
else: else:
transactions = transactions.where(Transaction.frozen == 0) transactions = transactions.where(Transaction.frozen == False)
if is_trade is not None:
raise HTTPException(
status_code=501, detail="The is_trade parameter is not implemented, yet"
)
transactions = transactions.order_by(-Transaction.week, Transaction.moveid) transactions = transactions.order_by(-Transaction.week, Transaction.moveid)
total_count = transactions.count()
transactions = transactions.offset(offset).limit(limit)
return_trans = { return_trans = {
"count": transactions.count(), "count": total_count,
"limit": limit,
"offset": offset,
"transactions": [ "transactions": [
model_to_dict(x, recurse=not short_output) for x in transactions model_to_dict(x, recurse=not short_output) for x in transactions
], ],
} }
db.close()
return return_trans return return_trans
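The `cancelled == 1` → `cancelled == True` change matters for portability: SQLite stores booleans as 0/1, so the integer comparison happens to work there, but a real boolean column (e.g. on Postgres) expects a boolean operand. A stdlib sketch of the underlying storage behavior (table name is illustrative):

```python
import sqlite3

# Python's sqlite3 adapts True/False to 1/0, so comparing against a real
# boolean in a parameterized query works on SQLite and stays portable.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tx (id INTEGER PRIMARY KEY, cancelled BOOLEAN)")
conn.executemany(
    "INSERT INTO tx (cancelled) VALUES (?)", [(True,), (False,), (False,)]
)
active = conn.execute(
    "SELECT COUNT(*) FROM tx WHERE cancelled = ?", (False,)
).fetchone()[0]
print(active)  # → 2
```

At the ORM level the same reasoning applies: `Transaction.cancelled == False` lets peewee emit the right literal for whichever backend is in use.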
@@ -111,12 +113,11 @@ async def patch_transactions(
     cancelled: Optional[bool] = None,
 ):
     if not valid_token(token):
-        logger.warning(f"patch_transactions - Bad Token: {token}")
+        logger.warning("patch_transactions - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     these_moves = Transaction.select().where(Transaction.moveid == move_id)
     if these_moves.count() == 0:
-        db.close()
         raise HTTPException(status_code=404, detail=f"Move ID {move_id} not found")
 
     if frozen is not None:
@@ -128,7 +129,6 @@ async def patch_transactions(
             x.cancelled = cancelled
         x.save()
-    db.close()
     return f"Updated {these_moves.count()} transactions"
@@ -138,32 +138,46 @@ async def post_transactions(
     moves: TransactionList, token: str = Depends(oauth2_scheme)
 ):
     if not valid_token(token):
-        logger.warning(f"post_transactions - Bad Token: {token}")
+        logger.warning("post_transactions - Bad Token")
         raise HTTPException(status_code=401, detail="Unauthorized")
     all_moves = []
+    all_team_ids = list(
+        set(x.oldteam_id for x in moves.moves) | set(x.newteam_id for x in moves.moves)
+    )
+    all_player_ids = list(set(x.player_id for x in moves.moves))
+    found_team_ids = (
+        set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
+        if all_team_ids
+        else set()
+    )
+    found_player_ids = (
+        set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids))
+        if all_player_ids
+        else set()
+    )
     for x in moves.moves:
-        if Team.get_or_none(Team.id == x.oldteam_id) is None:
+        if x.oldteam_id not in found_team_ids:
             raise HTTPException(
                 status_code=404, detail=f"Team ID {x.oldteam_id} not found"
             )
-        if Team.get_or_none(Team.id == x.newteam_id) is None:
+        if x.newteam_id not in found_team_ids:
             raise HTTPException(
                 status_code=404, detail=f"Team ID {x.newteam_id} not found"
             )
-        if Player.get_or_none(Player.id == x.player_id) is None:
+        if x.player_id not in found_player_ids:
             raise HTTPException(
                 status_code=404, detail=f"Player ID {x.player_id} not found"
             )
-        all_moves.append(x.dict())
+        all_moves.append(x.model_dump())
     with db.atomic():
         for batch in chunked(all_moves, 15):
             Transaction.insert_many(batch).on_conflict_ignore().execute()
-    db.close()
     return f"{len(all_moves)} transactions have been added"
@ -171,13 +185,12 @@ async def post_transactions(
@handle_db_errors @handle_db_errors
async def delete_transactions(move_id, token: str = Depends(oauth2_scheme)): async def delete_transactions(move_id, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f"delete_transactions - Bad Token: {token}") logger.warning("delete_transactions - Bad Token")
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
delete_query = Transaction.delete().where(Transaction.moveid == move_id) delete_query = Transaction.delete().where(Transaction.moveid == move_id)
count = delete_query.execute() count = delete_query.execute()
db.close()
if count > 0: if count > 0:
return f"Removed {count} transactions" return f"Removed {count} transactions"
else: else:
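The batched existence check introduced in `post_transactions` above trades one `get_or_none` query per move for two pre-fetched ID sets, making each per-move check an O(1) membership test. A minimal standalone sketch of the pattern (hypothetical names, plain dicts instead of Peewee models):

```python
def validate_moves(moves, found_team_ids, found_player_ids):
    """Raise ValueError for the first move referencing an unknown ID.

    found_team_ids / found_player_ids are sets fetched up front with
    two bulk queries, so validating N moves costs O(N) set lookups
    instead of up to 3*N database round-trips.
    """
    for move in moves:
        for team_id in (move["oldteam_id"], move["newteam_id"]):
            if team_id not in found_team_ids:
                raise ValueError(f"Team ID {team_id} not found")
        if move["player_id"] not in found_player_ids:
            raise ValueError(f"Player ID {move['player_id']} not found")


moves = [
    {"oldteam_id": 1, "newteam_id": 2, "player_id": 10},
    {"oldteam_id": 2, "newteam_id": 3, "player_id": 11},
]
# All IDs present in the pre-fetched sets: no exception raised.
validate_moves(moves, found_team_ids={1, 2, 3}, found_player_ids={10, 11})
```

In the real endpoint the sets come from `Team.select(Team.id).where(Team.id << all_team_ids)` and the player equivalent; the helper here only illustrates the lookup structure.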


@@ -3,187 +3,282 @@ from typing import List, Literal, Optional
 import logging
 import pydantic
-from ..db_engine import SeasonBattingStats, SeasonPitchingStats, db, Manager, Team, Current, model_to_dict, fn, query_to_csv, StratPlay, StratGame
-from ..dependencies import add_cache_headers, cache_result, oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors, update_season_batting_stats, update_season_pitching_stats, get_cache_stats
-logger = logging.getLogger('discord_app')
-router = APIRouter(
-    prefix='/api/v3/views',
-    tags=['views']
-)
-@router.get('/season-stats/batting')
+from ..db_engine import (
+    SeasonBattingStats,
+    SeasonPitchingStats,
+    db,
+    Manager,
+    Team,
+    Current,
+    model_to_dict,
+    fn,
+    query_to_csv,
+    StratPlay,
+    StratGame,
+)
+from ..dependencies import (
+    add_cache_headers,
+    cache_result,
+    oauth2_scheme,
+    valid_token,
+    PRIVATE_IN_SCHEMA,
+    handle_db_errors,
+    update_season_batting_stats,
+    update_season_pitching_stats,
+    get_cache_stats,
+    MAX_LIMIT,
+    DEFAULT_LIMIT,
+)
+logger = logging.getLogger("discord_app")
+router = APIRouter(prefix="/api/v3/views", tags=["views"])
+@router.get("/season-stats/batting")
 @handle_db_errors
 @add_cache_headers(max_age=10 * 60)
-@cache_result(ttl=5*60, key_prefix='season-batting')
+@cache_result(ttl=5 * 60, key_prefix="season-batting")
 async def get_season_batting_stats(
     season: Optional[int] = None,
     team_id: Optional[int] = None,
     player_id: Optional[int] = None,
     sbaplayer_id: Optional[int] = None,
     min_pa: Optional[int] = None,  # Minimum plate appearances
-    sort_by: str = "woba",  # Default sort field
-    sort_order: Literal['asc', 'desc'] = 'desc',  # asc or desc
-    limit: Optional[int] = 200,
+    sort_by: Literal[
+        "pa",
+        "ab",
+        "run",
+        "hit",
+        "double",
+        "triple",
+        "homerun",
+        "rbi",
+        "bb",
+        "so",
+        "bphr",
+        "bpfo",
+        "bp1b",
+        "bplo",
+        "gidp",
+        "hbp",
+        "sac",
+        "ibb",
+        "avg",
+        "obp",
+        "slg",
+        "ops",
+        "woba",
+        "k_pct",
+        "sb",
+        "cs",
+    ] = "woba",  # Sort field
+    sort_order: Literal["asc", "desc"] = "desc",  # asc or desc
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     offset: int = 0,
-    csv: Optional[bool] = False
+    csv: Optional[bool] = False,
 ):
-    logger.info(f'Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}')
+    logger.info(
+        f"Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}"
+    )
     # Use the enhanced get_top_hitters method
     query = SeasonBattingStats.get_top_hitters(
         season=season,
         stat=sort_by,
         limit=limit if limit != 0 else None,
-        desc=(sort_order.lower() == 'desc'),
+        desc=(sort_order.lower() == "desc"),
         team_id=team_id,
         player_id=player_id,
         sbaplayer_id=sbaplayer_id,
         min_pa=min_pa,
-        offset=offset
+        offset=offset,
     )
     # Build applied filters for response
     applied_filters = {}
     if season is not None:
-        applied_filters['season'] = season
+        applied_filters["season"] = season
     if team_id is not None:
-        applied_filters['team_id'] = team_id
+        applied_filters["team_id"] = team_id
     if player_id is not None:
-        applied_filters['player_id'] = player_id
+        applied_filters["player_id"] = player_id
     if min_pa is not None:
-        applied_filters['min_pa'] = min_pa
+        applied_filters["min_pa"] = min_pa
     if csv:
         return_val = query_to_csv(query)
-        return Response(content=return_val, media_type='text/csv')
+        return Response(content=return_val, media_type="text/csv")
     else:
         stat_list = [model_to_dict(stat) for stat in query]
-        return {
-            'count': len(stat_list),
-            'filters': applied_filters,
-            'stats': stat_list
-        }
+        return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list}
-@router.post('/season-stats/batting/refresh', include_in_schema=PRIVATE_IN_SCHEMA)
+@router.post("/season-stats/batting/refresh", include_in_schema=PRIVATE_IN_SCHEMA)
 @handle_db_errors
 async def refresh_season_batting_stats(
-    season: int,
-    token: str = Depends(oauth2_scheme)
+    season: int, token: str = Depends(oauth2_scheme)
 ) -> dict:
     """
     Refresh batting stats for all players in a specific season.
     Useful for full season updates.
     """
     if not valid_token(token):
-        logger.warning(f'refresh_season_batting_stats - Bad Token: {token}')
-        raise HTTPException(status_code=401, detail='Unauthorized')
-    logger.info(f'Refreshing all batting stats for season {season}')
+        logger.warning("refresh_season_batting_stats - Bad Token")
+        raise HTTPException(status_code=401, detail="Unauthorized")
+    logger.info(f"Refreshing all batting stats for season {season}")
     try:
         # Get all player IDs who have stratplay records in this season
-        batter_ids = [row.batter_id for row in
-                      StratPlay.select(StratPlay.batter_id.distinct())
-                      .join(StratGame).where(StratGame.season == season)]
+        batter_ids = [
+            row.batter_id
+            for row in StratPlay.select(StratPlay.batter_id.distinct())
+            .join(StratGame)
+            .where(StratGame.season == season)
+        ]
         if batter_ids:
             update_season_batting_stats(batter_ids, season, db)
-            logger.info(f'Successfully refreshed {len(batter_ids)} players for season {season}')
+            logger.info(
+                f"Successfully refreshed {len(batter_ids)} players for season {season}"
+            )
             return {
-                'message': f'Season {season} batting stats refreshed',
-                'players_updated': len(batter_ids)
+                "message": f"Season {season} batting stats refreshed",
+                "players_updated": len(batter_ids),
             }
         else:
-            logger.warning(f'No batting data found for season {season}')
+            logger.warning(f"No batting data found for season {season}")
             return {
-                'message': f'No batting data found for season {season}',
-                'players_updated': 0
+                "message": f"No batting data found for season {season}",
+                "players_updated": 0,
             }
     except Exception as e:
-        logger.error(f'Error refreshing season {season}: {e}')
-        raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}')
+        logger.error(f"Error refreshing season {season}: {e}")
+        raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}")
-@router.get('/season-stats/pitching')
+@router.get("/season-stats/pitching")
 @handle_db_errors
 @add_cache_headers(max_age=10 * 60)
-@cache_result(ttl=5*60, key_prefix='season-pitching')
+@cache_result(ttl=5 * 60, key_prefix="season-pitching")
 async def get_season_pitching_stats(
     season: Optional[int] = None,
     team_id: Optional[int] = None,
     player_id: Optional[int] = None,
     sbaplayer_id: Optional[int] = None,
     min_outs: Optional[int] = None,  # Minimum outs pitched
-    sort_by: str = "era",  # Default sort field
-    sort_order: Literal['asc', 'desc'] = 'asc',  # asc or desc (asc default for ERA)
-    limit: Optional[int] = 200,
+    sort_by: Literal[
+        "tbf",
+        "outs",
+        "games",
+        "gs",
+        "win",
+        "loss",
+        "hold",
+        "saves",
+        "bsave",
+        "ir",
+        "irs",
+        "ab",
+        "run",
+        "e_run",
+        "hits",
+        "double",
+        "triple",
+        "homerun",
+        "bb",
+        "so",
+        "hbp",
+        "sac",
+        "ibb",
+        "gidp",
+        "sb",
+        "cs",
+        "bphr",
+        "bpfo",
+        "bp1b",
+        "bplo",
+        "wp",
+        "balk",
+        "wpa",
+        "era",
+        "whip",
+        "avg",
+        "obp",
+        "slg",
+        "ops",
+        "woba",
+        "hper9",
+        "kper9",
+        "bbper9",
+        "kperbb",
+        "lob_2outs",
+        "rbipercent",
+        "re24",
+    ] = "era",  # Sort field
+    sort_order: Literal["asc", "desc"] = "asc",  # asc or desc (asc default for ERA)
+    limit: int = Query(default=DEFAULT_LIMIT, ge=1, le=MAX_LIMIT),
     offset: int = 0,
-    csv: Optional[bool] = False
+    csv: Optional[bool] = False,
 ):
-    logger.info(f'Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}')
+    logger.info(
+        f"Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}"
+    )
     # Use the get_top_pitchers method
     query = SeasonPitchingStats.get_top_pitchers(
         season=season,
         stat=sort_by,
         limit=limit if limit != 0 else None,
-        desc=(sort_order.lower() == 'desc'),
+        desc=(sort_order.lower() == "desc"),
         team_id=team_id,
         player_id=player_id,
         sbaplayer_id=sbaplayer_id,
         min_outs=min_outs,
-        offset=offset
+        offset=offset,
     )
     # Build applied filters for response
     applied_filters = {}
     if season is not None:
-        applied_filters['season'] = season
+        applied_filters["season"] = season
     if team_id is not None:
-        applied_filters['team_id'] = team_id
+        applied_filters["team_id"] = team_id
     if player_id is not None:
-        applied_filters['player_id'] = player_id
+        applied_filters["player_id"] = player_id
     if min_outs is not None:
-        applied_filters['min_outs'] = min_outs
+        applied_filters["min_outs"] = min_outs
     if csv:
         return_val = query_to_csv(query)
-        return Response(content=return_val, media_type='text/csv')
+        return Response(content=return_val, media_type="text/csv")
     else:
         stat_list = [model_to_dict(stat) for stat in query]
-        return {
-            'count': len(stat_list),
-            'filters': applied_filters,
-            'stats': stat_list
-        }
+        return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list}
-@router.post('/season-stats/pitching/refresh', include_in_schema=PRIVATE_IN_SCHEMA)
+@router.post("/season-stats/pitching/refresh", include_in_schema=PRIVATE_IN_SCHEMA)
 @handle_db_errors
 async def refresh_season_pitching_stats(
-    season: int,
-    token: str = Depends(oauth2_scheme)
+    season: int, token: str = Depends(oauth2_scheme)
 ) -> dict:
     """
     Refresh pitching statistics for a specific season by aggregating from individual games.
     Private endpoint - not included in public API documentation.
     """
     if not valid_token(token):
-        logger.warning(f'refresh_season_batting_stats - Bad Token: {token}')
-        raise HTTPException(status_code=401, detail='Unauthorized')
-    logger.info(f'Refreshing season {season} pitching stats')
+        logger.warning("refresh_season_batting_stats - Bad Token")
+        raise HTTPException(status_code=401, detail="Unauthorized")
+    logger.info(f"Refreshing season {season} pitching stats")
     try:
         # Get all pitcher IDs for this season
         pitcher_query = (
-            StratPlay
-            .select(StratPlay.pitcher_id)
+            StratPlay.select(StratPlay.pitcher_id)
             .join(StratGame, on=(StratPlay.game_id == StratGame.id))
             .where((StratGame.season == season) & (StratPlay.pitcher_id.is_null(False)))
             .distinct()
@@ -191,51 +286,50 @@ async def refresh_season_pitching_stats(
         pitcher_ids = [row.pitcher_id for row in pitcher_query]
         if not pitcher_ids:
-            logger.warning(f'No pitchers found for season {season}')
+            logger.warning(f"No pitchers found for season {season}")
             return {
-                'status': 'success',
-                'message': f'No pitchers found for season {season}',
-                'players_updated': 0
+                "status": "success",
+                "message": f"No pitchers found for season {season}",
+                "players_updated": 0,
             }
         # Use the dependency function to update pitching stats
         update_season_pitching_stats(pitcher_ids, season, db)
-        logger.info(f'Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated')
+        logger.info(
+            f"Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated"
+        )
         return {
-            'status': 'success',
-            'message': f'Season {season} pitching stats refreshed',
-            'players_updated': len(pitcher_ids)
+            "status": "success",
+            "message": f"Season {season} pitching stats refreshed",
+            "players_updated": len(pitcher_ids),
        }
     except Exception as e:
-        logger.error(f'Error refreshing season {season} pitching stats: {e}')
-        raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}')
+        logger.error(f"Error refreshing season {season} pitching stats: {e}")
+        raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}")
-@router.get('/admin/cache', include_in_schema=PRIVATE_IN_SCHEMA)
+@router.get("/admin/cache", include_in_schema=PRIVATE_IN_SCHEMA)
 @handle_db_errors
-async def get_admin_cache_stats(
-    token: str = Depends(oauth2_scheme)
-) -> dict:
+async def get_admin_cache_stats(token: str = Depends(oauth2_scheme)) -> dict:
     """
     Get Redis cache statistics and status.
     Private endpoint - requires authentication.
     """
     if not valid_token(token):
-        logger.warning(f'get_admin_cache_stats - Bad Token: {token}')
-        raise HTTPException(status_code=401, detail='Unauthorized')
-    logger.info('Getting cache statistics')
+        logger.warning("get_admin_cache_stats - Bad Token")
+        raise HTTPException(status_code=401, detail="Unauthorized")
+    logger.info("Getting cache statistics")
     try:
         cache_stats = get_cache_stats()
-        logger.info(f'Cache stats retrieved: {cache_stats}')
+        logger.info(f"Cache stats retrieved: {cache_stats}")
-        return {
-            'status': 'success',
-            'cache_info': cache_stats
-        }
+        return {"status": "success", "cache_info": cache_stats}
     except Exception as e:
-        logger.error(f'Error getting cache stats: {e}')
-        raise HTTPException(status_code=500, detail=f'Failed to get cache stats: {str(e)}')
+        logger.error(f"Error getting cache stats: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"Failed to get cache stats: {str(e)}"
+        )


@@ -39,7 +39,7 @@ class PlayerService(BaseService):
     cache_patterns = ["players*", "players-search*", "player*", "team-roster*"]
     # Deprecated fields to exclude from player responses
-    EXCLUDED_FIELDS = ['pitcher_injury']
+    EXCLUDED_FIELDS = ["pitcher_injury"]
     # Class-level repository for dependency injection
     _injected_repo: Optional[AbstractPlayerRepository] = None
@@ -135,17 +135,21 @@ class PlayerService(BaseService):
             # Apply sorting
             query = cls._apply_player_sort(query, sort)
-            # Convert to list of dicts
-            players_data = cls._query_to_player_dicts(query, short_output)
-            # Store total count before pagination
-            total_count = len(players_data)
-            # Apply pagination (offset and limit)
-            if offset is not None:
-                players_data = players_data[offset:]
-            if limit is not None:
-                players_data = players_data[:limit]
+            # Apply pagination at DB level for real queries, Python level for mocks
+            if isinstance(query, InMemoryQueryResult):
+                total_count = len(query)
+                players_data = cls._query_to_player_dicts(query, short_output)
+                if offset is not None:
+                    players_data = players_data[offset:]
+                if limit is not None:
+                    players_data = players_data[:limit]
+            else:
+                total_count = query.count()
+                if offset is not None:
+                    query = query.offset(offset)
+                if limit is not None:
+                    query = query.limit(limit)
+                players_data = cls._query_to_player_dicts(query, short_output)
             # Return format
             if as_csv:
@@ -154,7 +158,7 @@ class PlayerService(BaseService):
             return {
                 "count": len(players_data),
                 "total": total_count,
-                "players": players_data
+                "players": players_data,
             }
         except Exception as e:
@@ -204,9 +208,9 @@ class PlayerService(BaseService):
             p_list = [x.upper() for x in pos]
             # Expand generic "P" to match all pitcher positions
-            pitcher_positions = ['SP', 'RP', 'CP']
-            if 'P' in p_list:
-                p_list.remove('P')
+            pitcher_positions = ["SP", "RP", "CP"]
+            if "P" in p_list:
+                p_list.remove("P")
                 p_list.extend(pitcher_positions)
             pos_conditions = (
@@ -245,9 +249,9 @@ class PlayerService(BaseService):
             p_list = [p.upper() for p in pos]
             # Expand generic "P" to match all pitcher positions
-            pitcher_positions = ['SP', 'RP', 'CP']
-            if 'P' in p_list:
-                p_list.remove('P')
+            pitcher_positions = ["SP", "RP", "CP"]
+            if "P" in p_list:
+                p_list.remove("P")
                 p_list.extend(pitcher_positions)
             player_pos = [
@@ -385,19 +389,23 @@ class PlayerService(BaseService):
         # This filters at the database level instead of loading all players
         if search_all_seasons:
             # Search all seasons, order by season DESC (newest first)
-            query = (Player.select()
-                     .where(fn.Lower(Player.name).contains(query_lower))
-                     .order_by(Player.season.desc(), Player.name)
-                     .limit(limit * 2))  # Get extra for exact match sorting
+            query = (
+                Player.select()
+                .where(fn.Lower(Player.name).contains(query_lower))
+                .order_by(Player.season.desc(), Player.name)
+                .limit(limit * 2)
+            )  # Get extra for exact match sorting
         else:
             # Search specific season
-            query = (Player.select()
-                     .where(
-                         (Player.season == season) &
-                         (fn.Lower(Player.name).contains(query_lower))
-                     )
-                     .order_by(Player.name)
-                     .limit(limit * 2))  # Get extra for exact match sorting
+            query = (
+                Player.select()
+                .where(
+                    (Player.season == season)
+                    & (fn.Lower(Player.name).contains(query_lower))
+                )
+                .order_by(Player.name)
+                .limit(limit * 2)
+            )  # Get extra for exact match sorting
         # Execute query and convert limited results to dicts
         players = list(query)
@@ -468,19 +476,29 @@ class PlayerService(BaseService):
             # Use backrefs=False to avoid circular reference issues
             player_dict = model_to_dict(player, recurse=recurse, backrefs=False)
             # Filter out excluded fields
-            return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS}
+            return {
+                k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS
+            }
         except (ImportError, AttributeError, TypeError) as e:
             # Log the error and fall back to non-recursive serialization
-            logger.warning(f"Error in recursive player serialization: {e}, falling back to non-recursive")
+            logger.warning(
+                f"Error in recursive player serialization: {e}, falling back to non-recursive"
+            )
             try:
                 # Fallback to non-recursive serialization
                 player_dict = model_to_dict(player, recurse=False)
-                return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS}
+                return {
+                    k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS
+                }
             except Exception as fallback_error:
                 # Final fallback to basic dict conversion
-                logger.error(f"Error in non-recursive serialization: {fallback_error}, using basic dict")
+                logger.error(
+                    f"Error in non-recursive serialization: {fallback_error}, using basic dict"
+                )
                 player_dict = dict(player)
-                return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS}
+                return {
+                    k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS
+                }
     @classmethod
     def update_player(
@@ -508,6 +526,8 @@ class PlayerService(BaseService):
             raise HTTPException(
                 status_code=500, detail=f"Error updating player {player_id}: {str(e)}"
             )
+        finally:
+            temp_service.invalidate_related_cache(cls.cache_patterns)
     @classmethod
     def patch_player(
@@ -535,6 +555,8 @@ class PlayerService(BaseService):
             raise HTTPException(
                 status_code=500, detail=f"Error patching player {player_id}: {str(e)}"
             )
+        finally:
+            temp_service.invalidate_related_cache(cls.cache_patterns)
     @classmethod
     def create_players(
@@ -567,6 +589,8 @@ class PlayerService(BaseService):
             raise HTTPException(
                 status_code=500, detail=f"Error creating players: {str(e)}"
            )
+        finally:
+            temp_service.invalidate_related_cache(cls.cache_patterns)
     @classmethod
     def delete_player(cls, player_id: int, token: str) -> Dict[str, str]:
@@ -590,6 +614,8 @@ class PlayerService(BaseService):
             raise HTTPException(
                 status_code=500, detail=f"Error deleting player {player_id}: {str(e)}"
             )
+        finally:
+            temp_service.invalidate_related_cache(cls.cache_patterns)
     @classmethod
     def _format_player_csv(cls, players: List[Dict]) -> str:
@@ -603,12 +629,12 @@ class PlayerService(BaseService):
             flat_player = player.copy()
             # Flatten team object to just abbreviation
-            if isinstance(flat_player.get('team'), dict):
-                flat_player['team'] = flat_player['team'].get('abbrev', '')
+            if isinstance(flat_player.get("team"), dict):
+                flat_player["team"] = flat_player["team"].get("abbrev", "")
             # Flatten sbaplayer object to just ID
-            if isinstance(flat_player.get('sbaplayer'), dict):
-                flat_player['sbaplayer'] = flat_player['sbaplayer'].get('id', '')
+            if isinstance(flat_player.get("sbaplayer"), dict):
+                flat_player["sbaplayer"] = flat_player["sbaplayer"].get("id", "")
             flattened_players.append(flat_player)
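The pagination change in this file records the total before applying offset/limit, so `"total"` reflects the full filtered set while `"count"` is just the returned page length. A minimal standalone sketch of that count-before-slice pattern (the `paginate` helper is illustrative, not the service's actual API):

```python
def paginate(items, offset=0, limit=None):
    """Return a page of items plus counts.

    total is computed before slicing (in the real service, via
    query.count() or len(query)), so it is unaffected by pagination;
    count is the length of the page actually returned.
    """
    total = len(items)
    page = items[offset:]
    if limit is not None:
        page = page[:limit]
    return {"count": len(page), "total": total, "players": page}


result = paginate(list(range(10)), offset=6, limit=3)
# result["count"] == 3 (page length), result["total"] == 10 (full set)
```

This is the invariant the fix restores: computing `total` after slicing would collapse it to the page length whenever a limit is in effect.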


@@ -34,6 +34,7 @@ services:
       - REDIS_HOST=sba_redis
       - REDIS_PORT=6379
       - REDIS_DB=0
+      - DISCORD_WEBHOOK_URL=${DISCORD_WEBHOOK_URL}
     depends_on:
       - postgres
       - redis

migrations.py (new file)

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""Apply pending SQL migrations and record them in schema_versions.
Usage:
python migrations.py
Connects to PostgreSQL using the same environment variables as the API:
POSTGRES_DB (default: sba_master)
POSTGRES_USER (default: sba_admin)
POSTGRES_PASSWORD (required)
POSTGRES_HOST (default: sba_postgres)
POSTGRES_PORT (default: 5432)
On first run against an existing database, all migrations will be applied.
All migration files use IF NOT EXISTS guards so re-applying is safe.
"""
import os
import sys
from pathlib import Path
import psycopg2
MIGRATIONS_DIR = Path(__file__).parent / "migrations"
_CREATE_SCHEMA_VERSIONS = """
CREATE TABLE IF NOT EXISTS schema_versions (
filename VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
);
"""
def _get_connection():
password = os.environ.get("POSTGRES_PASSWORD")
if password is None:
raise RuntimeError("POSTGRES_PASSWORD environment variable is not set")
return psycopg2.connect(
dbname=os.environ.get("POSTGRES_DB", "sba_master"),
user=os.environ.get("POSTGRES_USER", "sba_admin"),
password=password,
host=os.environ.get("POSTGRES_HOST", "sba_postgres"),
port=int(os.environ.get("POSTGRES_PORT", "5432")),
)
def main():
conn = _get_connection()
try:
with conn:
with conn.cursor() as cur:
cur.execute(_CREATE_SCHEMA_VERSIONS)
with conn.cursor() as cur:
cur.execute("SELECT filename FROM schema_versions")
applied = {row[0] for row in cur.fetchall()}
migration_files = sorted(MIGRATIONS_DIR.glob("*.sql"))
pending = [f for f in migration_files if f.name not in applied]
if not pending:
print("No pending migrations.")
return
for migration_file in pending:
print(f"Applying {migration_file.name} ...", end=" ", flush=True)
sql = migration_file.read_text()
with conn:
with conn.cursor() as cur:
cur.execute(sql)
cur.execute(
"INSERT INTO schema_versions (filename) VALUES (%s)",
(migration_file.name,),
)
print("done")
print(f"\nApplied {len(pending)} migration(s).")
finally:
conn.close()
if __name__ == "__main__":
try:
main()
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
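The core of migrations.py is its selection step: sort the `.sql` files lexicographically and skip anything already recorded in `schema_versions`, which makes repeated runs idempotent. That logic in isolation (filenames below are made up for illustration):

```python
def pending_migrations(all_files, applied):
    """Return migration filenames still to run, in lexicographic order.

    applied is the set of filenames already recorded in the
    schema_versions table; anything in it is skipped, so re-running
    the migrator applies nothing twice.
    """
    return [f for f in sorted(all_files) if f not in applied]


files = ["002_add_indexes.sql", "001_schema_versions.sql", "003_webhooks.sql"]
pending = pending_migrations(files, applied={"001_schema_versions.sql"})
# pending == ["002_add_indexes.sql", "003_webhooks.sql"]
```

Sorting by filename is why the migration files carry numeric prefixes: the prefix, not the filesystem order, determines application order.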


@@ -0,0 +1,9 @@
-- Migration: Add schema_versions table for migration tracking
-- Date: 2026-03-27
-- Description: Creates a table to record which SQL migrations have been applied,
-- preventing double-application and missed migrations across environments.
CREATE TABLE IF NOT EXISTS schema_versions (
filename VARCHAR(255) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT NOW()
);


@@ -0,0 +1,24 @@
-- Migration: Add missing indexes on foreign key columns in stratplay and stratgame
-- Created: 2026-03-27
--
-- PostgreSQL does not auto-index foreign key columns. These tables are the
-- highest-volume tables in the schema and are filtered/joined on these columns
-- in batting, pitching, and running stats aggregation and standings recalculation.
-- stratplay: FK join column
CREATE INDEX IF NOT EXISTS idx_stratplay_game_id ON stratplay(game_id);
-- stratplay: filtered in batting stats aggregation
CREATE INDEX IF NOT EXISTS idx_stratplay_batter_id ON stratplay(batter_id);
-- stratplay: filtered in pitching stats aggregation
CREATE INDEX IF NOT EXISTS idx_stratplay_pitcher_id ON stratplay(pitcher_id);
-- stratplay: filtered in running stats
CREATE INDEX IF NOT EXISTS idx_stratplay_runner_id ON stratplay(runner_id);
-- stratgame: heavily filtered by season
CREATE INDEX IF NOT EXISTS idx_stratgame_season ON stratgame(season);
-- stratgame: standings recalculation query ordering
CREATE INDEX IF NOT EXISTS idx_stratgame_season_week_game_num ON stratgame(season, week, game_num);


@ -81,9 +81,9 @@ class TestRouteRegistration:
for route, methods in EXPECTED_PLAY_ROUTES.items(): for route, methods in EXPECTED_PLAY_ROUTES.items():
assert route in paths, f"Route {route} missing from OpenAPI schema" assert route in paths, f"Route {route} missing from OpenAPI schema"
for method in methods: for method in methods:
assert ( assert method in paths[route], (
method in paths[route] f"Method {method.upper()} missing for {route}"
), f"Method {method.upper()} missing for {route}" )
def test_play_routes_have_plays_tag(self, api): def test_play_routes_have_plays_tag(self, api):
"""All play routes should be tagged with 'plays'.""" """All play routes should be tagged with 'plays'."""
@ -96,9 +96,9 @@ class TestRouteRegistration:
for method, spec in paths[route].items(): for method, spec in paths[route].items():
if method in ("get", "post", "patch", "delete"): if method in ("get", "post", "patch", "delete"):
tags = spec.get("tags", []) tags = spec.get("tags", [])
assert ( assert "plays" in tags, (
"plays" in tags f"{method.upper()} {route} missing 'plays' tag, has {tags}"
), f"{method.upper()} {route} missing 'plays' tag, has {tags}" )
@pytest.mark.post_deploy @pytest.mark.post_deploy
@pytest.mark.skip( @pytest.mark.skip(
@ -124,9 +124,9 @@ class TestRouteRegistration:
]: ]:
params = paths[route]["get"].get("parameters", []) params = paths[route]["get"].get("parameters", [])
param_names = [p["name"] for p in params] param_names = [p["name"] for p in params]
assert ( assert "sbaplayer_id" in param_names, (
"sbaplayer_id" in param_names f"sbaplayer_id parameter missing from {route}"
), f"sbaplayer_id parameter missing from {route}" )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -493,10 +493,9 @@ class TestPlayCrud:
         assert result["id"] == play_id

     def test_get_nonexistent_play(self, api):
-        """GET /plays/999999999 returns an error (wrapped by handle_db_errors)."""
+        """GET /plays/999999999 returns 404 Not Found."""
         r = requests.get(f"{api}/api/v3/plays/999999999", timeout=10)
-        # handle_db_errors wraps HTTPException as 500 with detail message
-        assert r.status_code == 500
+        assert r.status_code == 404
         assert "not found" in r.json().get("detail", "").lower()
@@ -575,9 +574,9 @@ class TestGroupBySbaPlayer:
         )
         assert r_seasons.status_code == 200
         season_pas = [s["pa"] for s in r_seasons.json()["stats"]]
-        assert career_pa >= max(
-            season_pas
-        ), f"Career PA ({career_pa}) should be >= max season PA ({max(season_pas)})"
+        assert career_pa >= max(season_pas), (
+            f"Career PA ({career_pa}) should be >= max season PA ({max(season_pas)})"
+        )

     @pytest.mark.post_deploy
     def test_batting_sbaplayer_short_output(self, api):


@@ -7,21 +7,18 @@ import pytest
 from unittest.mock import MagicMock, patch
 import sys
 import os

 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 from app.services.player_service import PlayerService
 from app.services.base import ServiceConfig
-from app.services.mocks import (
-    MockPlayerRepository,
-    MockCacheService,
-    EnhancedMockCache
-)
+from app.services.mocks import MockPlayerRepository, MockCacheService, EnhancedMockCache

 # ============================================================================
 # FIXTURES
 # ============================================================================

 @pytest.fixture
 def cache():
     """Create fresh cache for each test."""
@@ -35,12 +32,65 @@ def repo(cache):
     # Add test players
     players = [
-        {'id': 1, 'name': 'Mike Trout', 'wara': 5.2, 'team_id': 1, 'season': 10, 'pos_1': 'CF', 'pos_2': 'LF', 'strat_code': 'Elite', 'injury_rating': 'A'},
-        {'id': 2, 'name': 'Aaron Judge', 'wara': 4.8, 'team_id': 2, 'season': 10, 'pos_1': 'RF', 'strat_code': 'Power', 'injury_rating': 'B'},
-        {'id': 3, 'name': 'Mookie Betts', 'wara': 5.5, 'team_id': 3, 'season': 10, 'pos_1': 'RF', 'pos_2': '2B', 'strat_code': 'Elite', 'injury_rating': 'A'},
-        {'id': 4, 'name': 'Injured Player', 'wara': 2.0, 'team_id': 1, 'season': 10, 'pos_1': 'P', 'il_return': 'Week 5', 'injury_rating': 'C'},
-        {'id': 5, 'name': 'Old Player', 'wara': 1.0, 'team_id': 1, 'season': 5, 'pos_1': '1B'},
-        {'id': 6, 'name': 'Juan Soto', 'wara': 4.5, 'team_id': 2, 'season': 10, 'pos_1': '1B', 'strat_code': 'Contact'},
+        {
+            "id": 1,
+            "name": "Mike Trout",
+            "wara": 5.2,
+            "team_id": 1,
+            "season": 10,
+            "pos_1": "CF",
+            "pos_2": "LF",
+            "strat_code": "Elite",
+            "injury_rating": "A",
+        },
+        {
+            "id": 2,
+            "name": "Aaron Judge",
+            "wara": 4.8,
+            "team_id": 2,
+            "season": 10,
+            "pos_1": "RF",
+            "strat_code": "Power",
+            "injury_rating": "B",
+        },
+        {
+            "id": 3,
+            "name": "Mookie Betts",
+            "wara": 5.5,
+            "team_id": 3,
+            "season": 10,
+            "pos_1": "RF",
+            "pos_2": "2B",
+            "strat_code": "Elite",
+            "injury_rating": "A",
+        },
+        {
+            "id": 4,
+            "name": "Injured Player",
+            "wara": 2.0,
+            "team_id": 1,
+            "season": 10,
+            "pos_1": "P",
+            "il_return": "Week 5",
+            "injury_rating": "C",
+        },
+        {
+            "id": 5,
+            "name": "Old Player",
+            "wara": 1.0,
+            "team_id": 1,
+            "season": 5,
+            "pos_1": "1B",
+        },
+        {
+            "id": 6,
+            "name": "Juan Soto",
+            "wara": 4.5,
+            "team_id": 2,
+            "season": 10,
+            "pos_1": "1B",
+            "strat_code": "Contact",
+        },
     ]

     for player in players:
@@ -60,6 +110,7 @@ def service(repo, cache):
 # ============================================================================
 # TEST CLASSES
 # ============================================================================
+

 class TestPlayerServiceGetPlayers:
     """Tests for get_players method - 50+ lines covered."""
@@ -67,71 +118,73 @@ class TestPlayerServiceGetPlayers:
         """Get all players for a season."""
         result = service.get_players(season=10)

-        assert result['count'] >= 5  # We have 5 season 10 players
-        assert len(result['players']) >= 5
-        assert all(p.get('season') == 10 for p in result['players'])
+        assert result["count"] >= 5  # We have 5 season 10 players
+        assert len(result["players"]) >= 5
+        assert all(p.get("season") == 10 for p in result["players"])

     def test_filter_by_single_team(self, service):
         """Filter by single team ID."""
         result = service.get_players(season=10, team_id=[1])
-        assert result['count'] >= 1
-        assert all(p.get('team_id') == 1 for p in result['players'])
+        assert result["count"] >= 1
+        assert all(p.get("team_id") == 1 for p in result["players"])

     def test_filter_by_multiple_teams(self, service):
         """Filter by multiple team IDs."""
         result = service.get_players(season=10, team_id=[1, 2])
-        assert result['count'] >= 2
-        assert all(p.get('team_id') in [1, 2] for p in result['players'])
+        assert result["count"] >= 2
+        assert all(p.get("team_id") in [1, 2] for p in result["players"])

     def test_filter_by_position(self, service):
         """Filter by position."""
-        result = service.get_players(season=10, pos=['CF'])
-        assert result['count'] >= 1
-        assert any(p.get('pos_1') == 'CF' or p.get('pos_2') == 'CF' for p in result['players'])
+        result = service.get_players(season=10, pos=["CF"])
+        assert result["count"] >= 1
+        assert any(
+            p.get("pos_1") == "CF" or p.get("pos_2") == "CF" for p in result["players"]
+        )

     def test_filter_by_strat_code(self, service):
         """Filter by strat code."""
-        result = service.get_players(season=10, strat_code=['Elite'])
-        assert result['count'] >= 2  # Trout and Betts
-        assert all('Elite' in str(p.get('strat_code', '')) for p in result['players'])
+        result = service.get_players(season=10, strat_code=["Elite"])
+        assert result["count"] >= 2  # Trout and Betts
+        assert all("Elite" in str(p.get("strat_code", "")) for p in result["players"])

     def test_filter_injured_only(self, service):
         """Filter injured players only."""
         result = service.get_players(season=10, is_injured=True)
-        assert result['count'] >= 1
-        assert all(p.get('il_return') is not None for p in result['players'])
+        assert result["count"] >= 1
+        assert all(p.get("il_return") is not None for p in result["players"])

     def test_sort_cost_ascending(self, service):
         """Sort by WARA ascending."""
-        result = service.get_players(season=10, sort='cost-asc')
-        wara = [p.get('wara', 0) for p in result['players']]
+        result = service.get_players(season=10, sort="cost-asc")
+        wara = [p.get("wara", 0) for p in result["players"]]
         assert wara == sorted(wara)

     def test_sort_cost_descending(self, service):
         """Sort by WARA descending."""
-        result = service.get_players(season=10, sort='cost-desc')
-        wara = [p.get('wara', 0) for p in result['players']]
+        result = service.get_players(season=10, sort="cost-desc")
+        wara = [p.get("wara", 0) for p in result["players"]]
         assert wara == sorted(wara, reverse=True)

     def test_sort_name_ascending(self, service):
         """Sort by name ascending."""
-        result = service.get_players(season=10, sort='name-asc')
-        names = [p.get('name', '') for p in result['players']]
+        result = service.get_players(season=10, sort="name-asc")
+        names = [p.get("name", "") for p in result["players"]]
         assert names == sorted(names)

     def test_sort_name_descending(self, service):
         """Sort by name descending."""
-        result = service.get_players(season=10, sort='name-desc')
-        names = [p.get('name', '') for p in result['players']]
+        result = service.get_players(season=10, sort="name-desc")
+        names = [p.get("name", "") for p in result["players"]]
         assert names == sorted(names, reverse=True)
@@ -140,46 +193,46 @@ class TestPlayerServiceSearch:
     def test_exact_name_match(self, service):
         """Search with exact name match."""
-        result = service.search_players('Mike Trout', season=10)
-        assert result['count'] >= 1
-        names = [p.get('name') for p in result['players']]
-        assert 'Mike Trout' in names
+        result = service.search_players("Mike Trout", season=10)
+        assert result["count"] >= 1
+        names = [p.get("name") for p in result["players"]]
+        assert "Mike Trout" in names

     def test_partial_name_match(self, service):
         """Search with partial name match."""
-        result = service.search_players('Trout', season=10)
-        assert result['count'] >= 1
-        assert any('Trout' in p.get('name', '') for p in result['players'])
+        result = service.search_players("Trout", season=10)
+        assert result["count"] >= 1
+        assert any("Trout" in p.get("name", "") for p in result["players"])

     def test_case_insensitive_search(self, service):
         """Search is case insensitive."""
-        result1 = service.search_players('MIKE', season=10)
-        result2 = service.search_players('mike', season=10)
-        assert result1['count'] == result2['count']
+        result1 = service.search_players("MIKE", season=10)
+        result2 = service.search_players("mike", season=10)
+        assert result1["count"] == result2["count"]

     def test_search_all_seasons(self, service):
         """Search across all seasons."""
-        result = service.search_players('Player', season=None)
+        result = service.search_players("Player", season=None)
         # Should find both current and old players
-        assert result['all_seasons'] == True
-        assert result['count'] >= 2
+        assert result["all_seasons"] == True
+        assert result["count"] >= 2

     def test_search_limit(self, service):
         """Limit search results."""
-        result = service.search_players('a', season=10, limit=2)
-        assert result['count'] <= 2
+        result = service.search_players("a", season=10, limit=2)
+        assert result["count"] <= 2

     def test_search_no_results(self, service):
         """Search returns empty when no matches."""
-        result = service.search_players('XYZ123NotExist', season=10)
-        assert result['count'] == 0
-        assert result['players'] == []
+        result = service.search_players("XYZ123NotExist", season=10)
+        assert result["count"] == 0
+        assert result["players"] == []


 class TestPlayerServiceGetPlayer:
@@ -190,8 +243,8 @@ class TestPlayerServiceGetPlayer:
         result = service.get_player(1)

         assert result is not None
-        assert result.get('id') == 1
-        assert result.get('name') == 'Mike Trout'
+        assert result.get("id") == 1
+        assert result.get("name") == "Mike Trout"

     def test_get_nonexistent_player(self, service):
         """Get player that doesn't exist."""
@@ -204,8 +257,8 @@ class TestPlayerServiceGetPlayer:
         result = service.get_player(1, short_output=True)

         # Should still have basic fields
-        assert result.get('id') == 1
-        assert result.get('name') == 'Mike Trout'
+        assert result.get("id") == 1
+        assert result.get("name") == "Mike Trout"


 class TestPlayerServiceCreate:
@@ -216,24 +269,26 @@ class TestPlayerServiceCreate:
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        new_player = [{
-            'name': 'New Player',
-            'wara': 3.0,
-            'team_id': 1,
-            'season': 10,
-            'pos_1': 'SS'
-        }]
+        new_player = [
+            {
+                "name": "New Player",
+                "wara": 3.0,
+                "team_id": 1,
+                "season": 10,
+                "pos_1": "SS",
+            }
+        ]

         # Mock auth
-        with patch.object(service, 'require_auth', return_value=True):
-            result = service.create_players(new_player, 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            result = service.create_players(new_player, "valid_token")

-        assert 'Inserted' in str(result)
+        assert "Inserted" in str(result)

         # Verify player was added (ID 7 since fixture has players 1-6)
         player = repo.get_by_id(7)  # Next ID after fixture data
         assert player is not None
-        assert player['name'] == 'New Player'
+        assert player["name"] == "New Player"

     def test_create_multiple_players(self, repo, cache):
         """Create multiple new players."""
@@ -241,37 +296,59 @@ class TestPlayerServiceCreate:
         service = PlayerService(config=config)

         new_players = [
-            {'name': 'Player A', 'wara': 2.0, 'team_id': 1, 'season': 10, 'pos_1': '2B'},
-            {'name': 'Player B', 'wara': 2.5, 'team_id': 2, 'season': 10, 'pos_1': '3B'},
+            {
+                "name": "Player A",
+                "wara": 2.0,
+                "team_id": 1,
+                "season": 10,
+                "pos_1": "2B",
+            },
+            {
+                "name": "Player B",
+                "wara": 2.5,
+                "team_id": 2,
+                "season": 10,
+                "pos_1": "3B",
+            },
         ]

-        with patch.object(service, 'require_auth', return_value=True):
-            result = service.create_players(new_players, 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            result = service.create_players(new_players, "valid_token")

-        assert 'Inserted 2 players' in str(result)
+        assert "Inserted 2 players" in str(result)

     def test_create_duplicate_fails(self, repo, cache):
         """Creating duplicate player should fail."""
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        duplicate = [{'name': 'Mike Trout', 'wara': 5.0, 'team_id': 1, 'season': 10, 'pos_1': 'CF'}]
+        duplicate = [
+            {
+                "name": "Mike Trout",
+                "wara": 5.0,
+                "team_id": 1,
+                "season": 10,
+                "pos_1": "CF",
+            }
+        ]

-        with patch.object(service, 'require_auth', return_value=True):
+        with patch.object(service, "require_auth", return_value=True):
             with pytest.raises(Exception) as exc_info:
-                service.create_players(duplicate, 'valid_token')
+                service.create_players(duplicate, "valid_token")

-        assert 'already exists' in str(exc_info.value)
+        assert "already exists" in str(exc_info.value)

     def test_create_requires_auth(self, repo, cache):
         """Creating players requires authentication."""
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        new_player = [{'name': 'Test', 'wara': 1.0, 'team_id': 1, 'season': 10, 'pos_1': 'P'}]
+        new_player = [
+            {"name": "Test", "wara": 1.0, "team_id": 1, "season": 10, "pos_1": "P"}
+        ]

         with pytest.raises(Exception) as exc_info:
-            service.create_players(new_player, 'bad_token')
+            service.create_players(new_player, "bad_token")

         assert exc_info.value.status_code == 401
@@ -284,50 +361,46 @@ class TestPlayerServiceUpdate:
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        with patch.object(service, 'require_auth', return_value=True):
-            result = service.patch_player(1, {'name': 'New Name'}, 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            result = service.patch_player(1, {"name": "New Name"}, "valid_token")

         assert result is not None
-        assert result.get('name') == 'New Name'
+        assert result.get("name") == "New Name"

     def test_patch_player_wara(self, repo, cache):
         """Patch player's WARA."""
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        with patch.object(service, 'require_auth', return_value=True):
-            result = service.patch_player(1, {'wara': 6.0}, 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            result = service.patch_player(1, {"wara": 6.0}, "valid_token")

-        assert result.get('wara') == 6.0
+        assert result.get("wara") == 6.0

     def test_patch_multiple_fields(self, repo, cache):
         """Patch multiple fields at once."""
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        updates = {
-            'name': 'Updated Name',
-            'wara': 7.0,
-            'strat_code': 'Super Elite'
-        }
+        updates = {"name": "Updated Name", "wara": 7.0, "strat_code": "Super Elite"}

-        with patch.object(service, 'require_auth', return_value=True):
-            result = service.patch_player(1, updates, 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            result = service.patch_player(1, updates, "valid_token")

-        assert result.get('name') == 'Updated Name'
-        assert result.get('wara') == 7.0
-        assert result.get('strat_code') == 'Super Elite'
+        assert result.get("name") == "Updated Name"
+        assert result.get("wara") == 7.0
+        assert result.get("strat_code") == "Super Elite"

     def test_patch_nonexistent_player(self, repo, cache):
         """Patch fails for non-existent player."""
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        with patch.object(service, 'require_auth', return_value=True):
+        with patch.object(service, "require_auth", return_value=True):
             with pytest.raises(Exception) as exc_info:
-                service.patch_player(99999, {'name': 'Test'}, 'valid_token')
+                service.patch_player(99999, {"name": "Test"}, "valid_token")

-        assert 'not found' in str(exc_info.value)
+        assert "not found" in str(exc_info.value)

     def test_patch_requires_auth(self, repo, cache):
         """Patching requires authentication."""
@@ -335,7 +408,7 @@ class TestPlayerServiceUpdate:
         service = PlayerService(config=config)

         with pytest.raises(Exception) as exc_info:
-            service.patch_player(1, {'name': 'Test'}, 'bad_token')
+            service.patch_player(1, {"name": "Test"}, "bad_token")

         assert exc_info.value.status_code == 401
@@ -351,10 +424,10 @@ class TestPlayerServiceDelete:
         # Verify player exists
         assert repo.get_by_id(1) is not None

-        with patch.object(service, 'require_auth', return_value=True):
-            result = service.delete_player(1, 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            result = service.delete_player(1, "valid_token")

-        assert 'deleted' in str(result)
+        assert "deleted" in str(result)

         # Verify player is gone
         assert repo.get_by_id(1) is None
@@ -364,11 +437,11 @@ class TestPlayerServiceDelete:
         config = ServiceConfig(player_repo=repo, cache=cache)
         service = PlayerService(config=config)

-        with patch.object(service, 'require_auth', return_value=True):
+        with patch.object(service, "require_auth", return_value=True):
             with pytest.raises(Exception) as exc_info:
-                service.delete_player(99999, 'valid_token')
+                service.delete_player(99999, "valid_token")

-        assert 'not found' in str(exc_info.value)
+        assert "not found" in str(exc_info.value)

     def test_delete_requires_auth(self, repo, cache):
         """Deleting requires authentication."""
@@ -376,56 +449,11 @@ class TestPlayerServiceDelete:
         service = PlayerService(config=config)

         with pytest.raises(Exception) as exc_info:
-            service.delete_player(1, 'bad_token')
+            service.delete_player(1, "bad_token")

         assert exc_info.value.status_code == 401


-class TestPlayerServiceCache:
-    """Tests for cache functionality."""
-
-    @pytest.mark.skip(reason="Caching not yet implemented in service methods")
-    def test_cache_set_on_read(self, service, cache):
-        """Cache is set on player read."""
-        service.get_players(season=10)
-        assert cache.was_called('set')
-
-    @pytest.mark.skip(reason="Caching not yet implemented in service methods")
-    def test_cache_invalidation_on_update(self, repo, cache):
-        """Cache is invalidated on player update."""
-        config = ServiceConfig(player_repo=repo, cache=cache)
-        service = PlayerService(config=config)
-
-        # Read to set cache
-        service.get_players(season=10)
-        initial_calls = len(cache.get_calls('set'))
-
-        # Update should invalidate cache
-        with patch.object(service, 'require_auth', return_value=True):
-            service.patch_player(1, {'name': 'Test'}, 'valid_token')
-
-        # Should have more delete calls after update
-        delete_calls = [c for c in cache.get_calls() if c.get('method') == 'delete']
-        assert len(delete_calls) > 0
-
-    @pytest.mark.skip(reason="Caching not yet implemented in service methods")
-    def test_cache_hit_rate(self, repo, cache):
-        """Test cache hit rate tracking."""
-        config = ServiceConfig(player_repo=repo, cache=cache)
-        service = PlayerService(config=config)
-
-        # First call - cache miss
-        service.get_players(season=10)
-        miss_count = cache._miss_count
-
-        # Second call - cache hit
-        service.get_players(season=10)
-
-        # Hit rate should have improved
-        assert cache.hit_rate > 0
-
-
 class TestPlayerServiceValidation:
     """Tests for input validation and edge cases."""
@@ -433,19 +461,19 @@ class TestPlayerServiceValidation:
         """Invalid season returns empty result."""
         result = service.get_players(season=999)

-        assert result['count'] == 0 or result['players'] == []
+        assert result["count"] == 0 or result["players"] == []

     def test_empty_search_returns_all(self, service):
         """Empty search query returns all players."""
-        result = service.search_players('', season=10)
-        assert result['count'] >= 1
+        result = service.search_players("", season=10)
+        assert result["count"] >= 1

     def test_sort_with_no_results(self, service):
         """Sorting with no results doesn't error."""
-        result = service.get_players(season=999, sort='cost-desc')
-        assert result['count'] == 0 or result['players'] == []
+        result = service.get_players(season=999, sort="cost-desc")
+        assert result["count"] == 0 or result["players"] == []

     def test_cache_clear_on_create(self, repo, cache):
         """Cache is cleared when new players are created."""
@@ -453,16 +481,21 @@ class TestPlayerServiceValidation:
         service = PlayerService(config=config)

         # Set up some cache data
-        cache.set('test:key', 'value', 300)
+        cache.set("test:key", "value", 300)

-        with patch.object(service, 'require_auth', return_value=True):
-            service.create_players([{
-                'name': 'New',
-                'wara': 1.0,
-                'team_id': 1,
-                'season': 10,
-                'pos_1': 'P'
-            }], 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            service.create_players(
+                [
+                    {
+                        "name": "New",
+                        "wara": 1.0,
+                        "team_id": 1,
+                        "season": 10,
+                        "pos_1": "P",
+                    }
+                ],
+                "valid_token",
+            )

         # Should have invalidate calls
         assert len(cache.get_calls()) > 0
@@ -477,30 +510,37 @@ class TestPlayerServiceIntegration:
         service = PlayerService(config=config)

         # CREATE
-        with patch.object(service, 'require_auth', return_value=True):
-            create_result = service.create_players([{
-                'name': 'CRUD Test',
-                'wara': 3.0,
-                'team_id': 1,
-                'season': 10,
-                'pos_1': 'DH'
-            }], 'valid_token')
+        with patch.object(service, "require_auth", return_value=True):
+            create_result = service.create_players(
+                [
+                    {
+                        "name": "CRUD Test",
+                        "wara": 3.0,
+                        "team_id": 1,
+                        "season": 10,
+                        "pos_1": "DH",
+                    }
+                ],
+                "valid_token",
+            )

         # READ
-        search_result = service.search_players('CRUD', season=10)
-        assert search_result['count'] >= 1
-        player_id = search_result['players'][0].get('id')
+        search_result = service.search_players("CRUD", season=10)
+        assert search_result["count"] >= 1
+        player_id = search_result["players"][0].get("id")

         # UPDATE
-        with patch.object(service, 'require_auth', return_value=True):
-            update_result = service.patch_player(player_id, {'wara': 4.0}, 'valid_token')
-            assert update_result.get('wara') == 4.0
+        with patch.object(service, "require_auth", return_value=True):
+            update_result = service.patch_player(
+                player_id, {"wara": 4.0}, "valid_token"
+            )
+            assert update_result.get("wara") == 4.0

         # DELETE
-        with patch.object(service, 'require_auth', return_value=True):
-            delete_result = service.delete_player(player_id, 'valid_token')
-            assert 'deleted' in str(delete_result)
+        with patch.object(service, "require_auth", return_value=True):
+            delete_result = service.delete_player(player_id, "valid_token")
+            assert "deleted" in str(delete_result)

         # VERIFY DELETED
         get_result = service.get_player(player_id)
@@ -510,13 +550,13 @@ class TestPlayerServiceIntegration:
         """Search and then filter operations."""
         # First get all players
         all_result = service.get_players(season=10)
-        initial_count = all_result['count']
+        initial_count = all_result["count"]

         # Then filter by team
         filtered = service.get_players(season=10, team_id=[1])

         # Filtered should be <= all
-        assert filtered['count'] <= initial_count
+        assert filtered["count"] <= initial_count


 # ============================================================================
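The `patch.object(service, "require_auth", return_value=True)` idiom that recurs throughout these tests can be reduced to a self-contained sketch. `MiniService` below is invented for illustration; the real `PlayerService` auth flow may differ in detail:

```python
from unittest.mock import patch

class MiniService:
    """Invented stand-in for a service whose mutating methods gate on auth."""

    def require_auth(self, token):
        # Stand-in for real token checking: always reject.
        raise PermissionError(401)

    def delete_player(self, player_id, token):
        self.require_auth(token)  # raises unless patched out
        return f"player {player_id} deleted"

svc = MiniService()

# Inside the patch block, require_auth on this instance is a MagicMock that
# returns True, so the guarded method runs to completion without a token check.
with patch.object(svc, "require_auth", return_value=True):
    result = svc.delete_player(1, "any_token")
print(result)  # → player 1 deleted
```

Outside the `with` block the original `require_auth` is restored, so the same call raises `PermissionError` again, which is exactly what the `*_requires_auth` tests above rely on.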


@@ -0,0 +1,154 @@
"""
Tests for query limit/offset parameter validation and middleware behavior.
Verifies that:
- FastAPI enforces MAX_LIMIT cap (returns 422 for limit > 500)
- FastAPI enforces ge=1 on limit (returns 422 for limit=0 or limit=-1)
- Transactions endpoint returns limit/offset keys in the response
- strip_empty_query_params middleware treats ?param= as absent
These tests exercise FastAPI parameter validation which fires before any
handler code runs, so most tests don't require a live DB connection.
The app imports redis and psycopg2 at module level, so we mock those
system-level packages before importing app.main.
"""
import sys
import pytest
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# Stub out C-extension / system packages that aren't installed in the test
# environment before any app code is imported.
# ---------------------------------------------------------------------------
_redis_stub = MagicMock()
_redis_stub.Redis = MagicMock(return_value=MagicMock(ping=MagicMock(return_value=True)))
sys.modules.setdefault("redis", _redis_stub)
_psycopg2_stub = MagicMock()
sys.modules.setdefault("psycopg2", _psycopg2_stub)
_playhouse_pool_stub = MagicMock()
sys.modules.setdefault("playhouse.pool", _playhouse_pool_stub)
_playhouse_pool_stub.PooledPostgresqlDatabase = MagicMock()
_pandas_stub = MagicMock()
sys.modules.setdefault("pandas", _pandas_stub)
_pandas_stub.DataFrame = MagicMock()
@pytest.fixture(scope="module")
def client():
"""
TestClient with the Peewee db object mocked so the app can be imported
without a running PostgreSQL instance. FastAPI validates query params
before calling handler code, so 422 responses don't need a real DB.
"""
mock_db = MagicMock()
mock_db.is_closed.return_value = False
mock_db.connect.return_value = None
mock_db.close.return_value = None
with patch("app.db_engine.db", mock_db):
from fastapi.testclient import TestClient
from app.main import app
with TestClient(app, raise_server_exceptions=False) as c:
yield c
def test_limit_exceeds_max_returns_422(client):
"""
GET /api/v3/decisions with limit=1000 should return 422.
MAX_LIMIT is 500; the decisions endpoint declares
limit: int = Query(ge=1, le=MAX_LIMIT), so FastAPI rejects values > 500
before any handler code runs.
"""
response = client.get("/api/v3/decisions?limit=1000")
assert response.status_code == 422
def test_limit_zero_returns_422(client):
"""
GET /api/v3/decisions with limit=0 should return 422.
Query(ge=1) rejects zero values.
"""
response = client.get("/api/v3/decisions?limit=0")
assert response.status_code == 422
def test_limit_negative_returns_422(client):
"""
GET /api/v3/decisions with limit=-1 should return 422.
Query(ge=1) rejects negative values.
"""
response = client.get("/api/v3/decisions?limit=-1")
assert response.status_code == 422
def test_transactions_has_limit_in_response(client):
"""
GET /api/v3/transactions?season=12 should include 'limit' and 'offset'
keys in the JSON response body.
The transactions endpoint was updated to return pagination metadata
alongside results so callers know the applied page size.
"""
mock_qs = MagicMock()
mock_qs.count.return_value = 0
mock_qs.where.return_value = mock_qs
mock_qs.order_by.return_value = mock_qs
mock_qs.offset.return_value = mock_qs
mock_qs.limit.return_value = mock_qs
mock_qs.__iter__ = MagicMock(return_value=iter([]))
with (
patch("app.routers_v3.transactions.Transaction") as mock_txn,
patch("app.routers_v3.transactions.Team") as mock_team,
patch("app.routers_v3.transactions.Player") as mock_player,
):
mock_txn.select_season.return_value = mock_qs
mock_txn.select.return_value = mock_qs
mock_team.select.return_value = mock_qs
mock_player.select.return_value = mock_qs
response = client.get("/api/v3/transactions?season=12")
# If the mock is sufficient the response is 200 with pagination keys;
# if some DB path still fires we at least confirm limit param is accepted.
assert response.status_code != 422
if response.status_code == 200:
data = response.json()
assert "limit" in data, "Response missing 'limit' key"
assert "offset" in data, "Response missing 'offset' key"
def test_empty_string_param_stripped(client):
"""
Query params with an empty string value should be treated as absent.
The strip_empty_query_params middleware rewrites the query string before
FastAPI parses it, so ?league_abbrev= is removed entirely rather than
forwarded as an empty string to the handler.
Expected: the request is accepted (not 422) and the empty param is ignored.
"""
mock_qs = MagicMock()
mock_qs.count.return_value = 0
mock_qs.where.return_value = mock_qs
mock_qs.__iter__ = MagicMock(return_value=iter([]))
with patch("app.routers_v3.standings.Standings") as mock_standings:
mock_standings.select_season.return_value = mock_qs
# ?league_abbrev= should be stripped → treated as absent (None), not ""
response = client.get("/api/v3/standings?season=12&league_abbrev=")
assert response.status_code != 422, (
"Empty string query param caused a 422 — middleware may not be stripping it"
)