Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dd253dab09 | |||
|
|
0ea1c1d633 | ||
| c8cb80c5f3 | |||
|
|
6016afb999 | ||
| f95c857363 | |||
|
|
174ce4474d | ||
|
|
2091302b8a | ||
| 27a272b813 | |||
|
|
95010bfd5d | ||
| deb40476a4 | |||
|
|
65d3099a7c | ||
| 8e02889fd4 | |||
|
|
b872a05397 | ||
| 6889499fff | |||
|
|
3c453c89ce | ||
| be4213aab6 | |||
|
|
4e75656225 | ||
| c30e0ad321 | |||
|
|
b57f91833b | ||
| 04efc46382 | |||
|
|
7e7aa46a73 | ||
| 91b367af93 | |||
|
|
3c24e03a0c | ||
| fd24a41422 | |||
| daa3366b60 | |||
| ee2387a385 | |||
|
|
6f3339a42e | ||
| 498fcdfe51 | |||
|
|
a3e63f730f | ||
| f0934937cb | |||
| 4775c175c5 | |||
|
|
ce2c47ca0c | ||
| 0c041bce99 | |||
|
|
70c4555a74 | ||
|
|
8878ce85f7 | ||
|
|
008d6be86c | ||
| 18ab1393c0 | |||
| 8862850c59 | |||
| 8d97e1dd17 | |||
| 52fa56cb69 | |||
|
|
d4e7246166 | ||
|
|
0992acf718 | ||
|
|
b480120731 | ||
| 6d3c7305ce | |||
|
|
9df8d77fa0 | ||
|
|
df9e9bedbe | ||
|
|
c8ed4dee38 | ||
|
|
03dd449551 | ||
| 6c49233392 | |||
|
|
9a4ecda564 | ||
|
|
d295f27afe |
@ -1,22 +1,18 @@
|
||||
# Gitea Actions: Docker Build, Push, and Notify
|
||||
#
|
||||
# CI/CD pipeline for Major Domo Discord Bot:
|
||||
# - Builds Docker images on every push/PR
|
||||
# - Auto-generates CalVer version (YYYY.MM.BUILD) on main branch merges
|
||||
# - Supports multi-channel releases: stable (main), rc (next-release), dev (PRs)
|
||||
# - Pushes to Docker Hub and creates git tag on main
|
||||
# - Triggered by pushing a CalVer tag (e.g., 2026.3.11)
|
||||
# - Builds Docker image and pushes to Docker Hub with version + production tags
|
||||
# - Sends Discord notifications on success/failure
|
||||
#
|
||||
# To release: git tag 2026.3.11 && git push --tags
|
||||
|
||||
name: Build Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- next-release
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- '20*' # matches CalVer tags like 2026.3.11
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@ -26,7 +22,16 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: https://github.com/actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Full history for tag counting
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Extract version from tag
|
||||
id: version
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/}
|
||||
SHA_SHORT=$(git rev-parse --short HEAD)
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "sha_short=$SHA_SHORT" >> $GITHUB_OUTPUT
|
||||
echo "timestamp=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: https://github.com/docker/setup-buildx-action@v3
|
||||
@ -37,67 +42,47 @@ jobs:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Generate CalVer version
|
||||
id: calver
|
||||
uses: cal/gitea-actions/calver@main
|
||||
|
||||
- name: Resolve Docker tags
|
||||
id: tags
|
||||
uses: cal/gitea-actions/docker-tags@main
|
||||
with:
|
||||
image: manticorum67/major-domo-discordapp
|
||||
version: ${{ steps.calver.outputs.version }}
|
||||
sha_short: ${{ steps.calver.outputs.sha_short }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: https://github.com/docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.tags.outputs.tags }}
|
||||
tags: |
|
||||
manticorum67/major-domo-discordapp:${{ steps.version.outputs.version }}
|
||||
manticorum67/major-domo-discordapp:production
|
||||
cache-from: type=registry,ref=manticorum67/major-domo-discordapp:buildcache
|
||||
cache-to: type=registry,ref=manticorum67/major-domo-discordapp:buildcache,mode=max
|
||||
|
||||
- name: Tag release
|
||||
if: success() && github.ref == 'refs/heads/main'
|
||||
uses: cal/gitea-actions/gitea-tag@main
|
||||
with:
|
||||
version: ${{ steps.calver.outputs.version }}
|
||||
token: ${{ github.token }}
|
||||
|
||||
- name: Build Summary
|
||||
run: |
|
||||
echo "## Docker Build Successful" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Channel:** \`${{ steps.tags.outputs.channel }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Version:** \`${{ steps.version.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Image Tags:**" >> $GITHUB_STEP_SUMMARY
|
||||
IFS=',' read -ra TAG_ARRAY <<< "${{ steps.tags.outputs.tags }}"
|
||||
for tag in "${TAG_ARRAY[@]}"; do
|
||||
echo "- \`${tag}\`" >> $GITHUB_STEP_SUMMARY
|
||||
done
|
||||
echo "- \`manticorum67/major-domo-discordapp:${{ steps.version.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- \`manticorum67/major-domo-discordapp:production\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Build Details:**" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Branch: \`${{ steps.calver.outputs.branch }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Commit: \`${{ github.sha }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Timestamp: \`${{ steps.calver.outputs.timestamp }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Commit: \`${{ steps.version.outputs.sha_short }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Timestamp: \`${{ steps.version.outputs.timestamp }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Pull with: \`docker pull manticorum67/major-domo-discordapp:${{ steps.tags.outputs.primary_tag }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Pull with: \`docker pull manticorum67/major-domo-discordapp:${{ steps.version.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Discord Notification - Success
|
||||
if: success() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next-release')
|
||||
if: success()
|
||||
uses: cal/gitea-actions/discord-notify@main
|
||||
with:
|
||||
webhook_url: ${{ secrets.DISCORD_WEBHOOK }}
|
||||
title: "Major Domo Bot"
|
||||
status: success
|
||||
version: ${{ steps.calver.outputs.version }}
|
||||
image_tag: ${{ steps.tags.outputs.primary_tag }}
|
||||
commit_sha: ${{ steps.calver.outputs.sha_short }}
|
||||
timestamp: ${{ steps.calver.outputs.timestamp }}
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
image_tag: ${{ steps.version.outputs.version }}
|
||||
commit_sha: ${{ steps.version.outputs.sha_short }}
|
||||
timestamp: ${{ steps.version.outputs.timestamp }}
|
||||
|
||||
- name: Discord Notification - Failure
|
||||
if: failure() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next-release')
|
||||
if: failure()
|
||||
uses: cal/gitea-actions/discord-notify@main
|
||||
with:
|
||||
webhook_url: ${{ secrets.DISCORD_WEBHOOK }}
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -218,5 +218,6 @@ __marimo__/
|
||||
|
||||
# Project-specific
|
||||
data/
|
||||
storage/
|
||||
production_logs/
|
||||
*.json
|
||||
|
||||
@ -7,11 +7,11 @@
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SSH_CMD="ssh -i ~/.ssh/cloud_servers_rsa root@akamai"
|
||||
SSH_CMD="ssh akamai"
|
||||
REMOTE_DIR="/root/container-data/major-domo"
|
||||
SERVICE="discord-app"
|
||||
CONTAINER="major-domo-discord-app-1"
|
||||
IMAGE="manticorum67/major-domo-discordapp:latest"
|
||||
IMAGE="manticorum67/major-domo-discordapp:production"
|
||||
|
||||
SKIP_CONFIRM=false
|
||||
[[ "${1:-}" == "-y" ]] && SKIP_CONFIRM=true
|
||||
@ -19,9 +19,9 @@ SKIP_CONFIRM=false
|
||||
# --- Pre-deploy checks ---
|
||||
|
||||
if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then
|
||||
echo "WARNING: You have uncommitted changes."
|
||||
git status --short
|
||||
echo ""
|
||||
echo "WARNING: You have uncommitted changes."
|
||||
git status --short
|
||||
echo ""
|
||||
fi
|
||||
|
||||
BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
|
||||
@ -32,9 +32,12 @@ echo "Target: akamai (${IMAGE})"
|
||||
echo ""
|
||||
|
||||
if [[ "$SKIP_CONFIRM" != true ]]; then
|
||||
read -rp "Deploy to production? [y/N] " answer
|
||||
[[ "$answer" =~ ^[Yy]$ ]] || { echo "Aborted."; exit 0; }
|
||||
echo ""
|
||||
read -rp "Deploy to production? [y/N] " answer
|
||||
[[ "$answer" =~ ^[Yy]$ ]] || {
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
}
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# --- Save previous image for rollback ---
|
||||
@ -64,16 +67,16 @@ echo ""
|
||||
echo "==> Image digest: ${NEW_DIGEST}"
|
||||
|
||||
if [[ "$PREV_DIGEST" == "$NEW_DIGEST" ]]; then
|
||||
echo " (unchanged from previous deploy)"
|
||||
echo " (unchanged from previous deploy)"
|
||||
fi
|
||||
|
||||
# --- Rollback command ---
|
||||
|
||||
if [[ "$PREV_DIGEST" != "unknown" && "$PREV_DIGEST" != "$NEW_DIGEST" ]]; then
|
||||
echo ""
|
||||
echo "==> To rollback:"
|
||||
echo " ssh -i ~/.ssh/cloud_servers_rsa root@akamai \\"
|
||||
echo " \"cd ${REMOTE_DIR} && docker pull ${PREV_DIGEST} && docker tag ${PREV_DIGEST} ${IMAGE} && docker compose up -d ${SERVICE}\""
|
||||
echo ""
|
||||
echo "==> To rollback:"
|
||||
echo " ssh akamai \\"
|
||||
echo " \"cd ${REMOTE_DIR} && docker pull ${PREV_DIGEST} && docker tag ${PREV_DIGEST} ${IMAGE} && docker compose up -d ${SERVICE}\""
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
99
.scripts/release.sh
Executable file
99
.scripts/release.sh
Executable file
@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env bash
|
||||
# Create a CalVer release tag and push to trigger CI.
|
||||
#
|
||||
# Usage:
|
||||
# .scripts/release.sh # auto-generates next version (YYYY.M.BUILD)
|
||||
# .scripts/release.sh 2026.3.11 # explicit version
|
||||
# .scripts/release.sh -y # auto-generate + skip confirmation
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SKIP_CONFIRM=false
|
||||
VERSION=""
|
||||
|
||||
for arg in "$@"; do
|
||||
case "$arg" in
|
||||
-y) SKIP_CONFIRM=true ;;
|
||||
*) VERSION="$arg" ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# --- Ensure we're on main and up to date ---
|
||||
|
||||
BRANCH=$(git rev-parse --abbrev-ref HEAD)
|
||||
if [[ "$BRANCH" != "main" ]]; then
|
||||
echo "ERROR: Must be on main branch (currently on ${BRANCH})"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "==> Fetching latest..."
|
||||
git fetch origin main --tags --quiet
|
||||
LOCAL=$(git rev-parse HEAD)
|
||||
REMOTE=$(git rev-parse origin/main)
|
||||
if [[ "$LOCAL" != "$REMOTE" ]]; then
|
||||
echo "ERROR: Local main is not up to date with origin. Run: git pull"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Determine version ---
|
||||
|
||||
YEAR=$(date +%Y)
|
||||
MONTH=$(date +%-m) # no leading zero
|
||||
|
||||
if [[ -z "$VERSION" ]]; then
|
||||
# Find the highest build number for this year.month
|
||||
LAST_BUILD=$(git tag --list "${YEAR}.${MONTH}.*" --sort=-v:refname | head -1 | awk -F. '{print $3}')
|
||||
NEXT_BUILD=$((${LAST_BUILD:-0} + 1))
|
||||
VERSION="${YEAR}.${MONTH}.${NEXT_BUILD}"
|
||||
fi
|
||||
|
||||
# Validate format
|
||||
if [[ ! "$VERSION" =~ ^20[0-9]{2}\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "ERROR: Invalid version format '${VERSION}'. Expected YYYY.M.BUILD (e.g., 2026.3.11)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check tag doesn't already exist
|
||||
if git rev-parse "refs/tags/${VERSION}" &>/dev/null; then
|
||||
echo "ERROR: Tag ${VERSION} already exists"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Show what's being released ---
|
||||
|
||||
LAST_TAG=$(git tag --sort=-v:refname | head -1)
|
||||
echo ""
|
||||
echo "Version: ${VERSION}"
|
||||
echo "Previous: ${LAST_TAG:-none}"
|
||||
echo "Commit: $(git log -1 --format='%h %s')"
|
||||
echo ""
|
||||
|
||||
if [[ -n "$LAST_TAG" ]]; then
|
||||
COMMIT_COUNT=$(git rev-list "${LAST_TAG}..HEAD" --count)
|
||||
echo "Changes since ${LAST_TAG} (${COMMIT_COUNT} commits):"
|
||||
git log "${LAST_TAG}..HEAD" --oneline --no-merges
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# --- Confirm ---
|
||||
|
||||
if [[ "$SKIP_CONFIRM" != true ]]; then
|
||||
read -rp "Create tag ${VERSION} and trigger release? [y/N] " answer
|
||||
[[ "$answer" =~ ^[Yy]$ ]] || {
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
}
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# --- Tag and push ---
|
||||
|
||||
git tag "$VERSION"
|
||||
git push origin tag "$VERSION"
|
||||
|
||||
echo ""
|
||||
echo "==> Tag ${VERSION} pushed. CI will build:"
|
||||
echo " - manticorum67/major-domo-discordapp:${VERSION}"
|
||||
echo " - manticorum67/major-domo-discordapp:production"
|
||||
echo ""
|
||||
echo "Deploy with: .scripts/deploy.sh"
|
||||
18
CLAUDE.md
18
CLAUDE.md
@ -16,15 +16,13 @@ manticorum67/major-domo-discordapp
|
||||
There is NO DASH between "discord" and "app". Not `discord-app`, not `discordapp-v2`.
|
||||
|
||||
### Git Workflow
|
||||
NEVER commit directly to `main` or `next-release`. Always use feature branches.
|
||||
NEVER commit directly to `main`. Always use feature branches.
|
||||
|
||||
**Branch from `next-release`** for normal work targeting the next release:
|
||||
```bash
|
||||
git checkout -b feature/name origin/next-release # or fix/name, refactor/name
|
||||
git checkout -b feature/name origin/main # or fix/name, refactor/name
|
||||
```
|
||||
**Branch from `main`** only for urgent hotfixes that bypass the release cycle.
|
||||
|
||||
PRs go to `next-release` (staging), then `next-release → main` when releasing.
|
||||
PRs go to `main`. CI builds the Docker image and creates a CalVer tag on merge.
|
||||
|
||||
### Double Emoji in Embeds
|
||||
`EmbedTemplate.success/error/warning/info/loading()` auto-add emoji prefixes.
|
||||
@ -63,13 +61,13 @@ class MyCog(commands.Cog):
|
||||
- **Container**: `major-domo-discord-app-1`
|
||||
- **Image**: `manticorum67/major-domo-discordapp` (no dash between discord and app)
|
||||
- **Health**: Process liveness only (no HTTP endpoint)
|
||||
- **CI/CD**: Gitea Actions on PR to `main` — builds Docker image, auto-generates CalVer version (`YYYY.MM.BUILD`) on merge
|
||||
- **CI/CD**: Gitea Actions — tag-triggered Docker builds (push a CalVer tag to release)
|
||||
|
||||
### Release Workflow
|
||||
1. Create feature/fix branches off `next-release` (e.g., `fix/scorebug-bugs`)
|
||||
2. When done, merge the branch into `next-release` — this is the staging branch where changes accumulate
|
||||
3. When ready to release, open a PR from `next-release` → `main`
|
||||
4. CI builds Docker image on PR; CalVer tag is created on merge
|
||||
1. Create feature/fix branches off `main` (e.g., `fix/scorebug-bugs`)
|
||||
2. Open a PR to `main` when ready — merging does NOT trigger a build
|
||||
3. When ready to release: `git tag YYYY.M.BUILD && git push --tags`
|
||||
4. CI builds Docker image, tags it with the version + `production`, notifies Discord
|
||||
5. Deploy the new image to production (see `/deploy` skill)
|
||||
- **Other services on same host**: `sba_db_api`, `sba_postgres`, `sba_redis`, `sba-website-sba-web-1`, `pd_api`
|
||||
|
||||
|
||||
65
bot.py
65
bot.py
@ -42,7 +42,9 @@ def setup_logging():
|
||||
|
||||
# JSON file handler - structured logging for monitoring and analysis
|
||||
json_handler = RotatingFileHandler(
|
||||
"logs/discord_bot_v2.json", maxBytes=5 * 1024 * 1024, backupCount=5 # 5MB
|
||||
"logs/discord_bot_v2.json",
|
||||
maxBytes=5 * 1024 * 1024,
|
||||
backupCount=5, # 5MB
|
||||
)
|
||||
json_handler.setFormatter(JSONFormatter())
|
||||
logger.addHandler(json_handler)
|
||||
@ -64,6 +66,44 @@ def setup_logging():
|
||||
return logger
|
||||
|
||||
|
||||
class MaintenanceAwareTree(discord.app_commands.CommandTree):
|
||||
"""
|
||||
CommandTree subclass that gates all interactions behind a maintenance mode check.
|
||||
|
||||
When bot.maintenance_mode is True, non-administrator users receive an ephemeral
|
||||
error message and the interaction is blocked. Administrators are always allowed
|
||||
through. When maintenance_mode is False the check is a no-op and every
|
||||
interaction proceeds normally.
|
||||
|
||||
This is the correct way to register a global interaction_check for app commands
|
||||
in discord.py — overriding the method on a CommandTree subclass passed via
|
||||
tree_cls rather than attempting to assign a decorator to self.tree inside
|
||||
setup_hook.
|
||||
"""
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
||||
"""Allow admins through; block everyone else when maintenance mode is active."""
|
||||
bot = interaction.client # type: ignore[assignment]
|
||||
|
||||
# If maintenance mode is off, always allow.
|
||||
if not getattr(bot, "maintenance_mode", False):
|
||||
return True
|
||||
|
||||
# Maintenance mode is on — let administrators through unconditionally.
|
||||
if (
|
||||
isinstance(interaction.user, discord.Member)
|
||||
and interaction.user.guild_permissions.administrator
|
||||
):
|
||||
return True
|
||||
|
||||
# Block non-admin users with an ephemeral notice.
|
||||
await interaction.response.send_message(
|
||||
"The bot is currently in maintenance mode. Please try again later.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class SBABot(commands.Bot):
|
||||
"""Custom bot class for SBA league management."""
|
||||
|
||||
@ -77,31 +117,16 @@ class SBABot(commands.Bot):
|
||||
command_prefix="!", # Legacy prefix, primarily using slash commands
|
||||
intents=intents,
|
||||
description="Major Domo v2.0",
|
||||
tree_cls=MaintenanceAwareTree,
|
||||
)
|
||||
|
||||
self.logger = logging.getLogger("discord_bot_v2")
|
||||
self.maintenance_mode: bool = False
|
||||
self.logger = logging.getLogger("discord_bot_v2")
|
||||
|
||||
async def setup_hook(self):
|
||||
"""Called when the bot is starting up."""
|
||||
self.logger.info("Setting up bot...")
|
||||
|
||||
@self.tree.interaction_check
|
||||
async def maintenance_check(interaction: discord.Interaction) -> bool:
|
||||
"""Block non-admin users when maintenance mode is enabled."""
|
||||
if not self.maintenance_mode:
|
||||
return True
|
||||
if (
|
||||
isinstance(interaction.user, discord.Member)
|
||||
and interaction.user.guild_permissions.administrator
|
||||
):
|
||||
return True
|
||||
await interaction.response.send_message(
|
||||
"🔧 The bot is currently in maintenance mode. Please try again later.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return False
|
||||
|
||||
# Load command packages
|
||||
await self._load_command_packages()
|
||||
|
||||
@ -403,7 +428,9 @@ async def health_command(interaction: discord.Interaction):
|
||||
embed.add_field(name="Bot Status", value="✅ Online", inline=True)
|
||||
embed.add_field(name="API Status", value=api_status, inline=True)
|
||||
embed.add_field(name="Guilds", value=str(guild_count), inline=True)
|
||||
embed.add_field(name="Latency", value=f"{bot.latency*1000:.1f}ms", inline=True)
|
||||
embed.add_field(
|
||||
name="Latency", value=f"{bot.latency * 1000:.1f}ms", inline=True
|
||||
)
|
||||
|
||||
if bot.user:
|
||||
embed.set_footer(
|
||||
|
||||
@ -490,7 +490,7 @@ class AdminCommands(commands.Cog):
|
||||
await interaction.response.defer()
|
||||
|
||||
is_enabling = mode.lower() == "on"
|
||||
self.bot.maintenance_mode = is_enabling
|
||||
self.bot.maintenance_mode = is_enabling # type: ignore[attr-defined]
|
||||
self.logger.info(
|
||||
f"Maintenance mode {'enabled' if is_enabling else 'disabled'} by {interaction.user} (id={interaction.user.id})"
|
||||
)
|
||||
@ -568,14 +568,9 @@ class AdminCommands(commands.Cog):
|
||||
return
|
||||
|
||||
try:
|
||||
# Clear all messages from the channel
|
||||
deleted_count = 0
|
||||
async for message in live_scores_channel.history(limit=100):
|
||||
try:
|
||||
await message.delete()
|
||||
deleted_count += 1
|
||||
except discord.NotFound:
|
||||
pass # Message already deleted
|
||||
# Clear all messages from the channel using bulk delete
|
||||
deleted_messages = await live_scores_channel.purge(limit=100)
|
||||
deleted_count = len(deleted_messages)
|
||||
|
||||
self.logger.info(f"Cleared {deleted_count} messages from #live-sba-scores")
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Scorebug Commands
|
||||
Implements commands for publishing and displaying live game scorebugs from Google Sheets scorecards.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord import app_commands
|
||||
@ -73,12 +74,18 @@ class ScorebugCommands(commands.Cog):
|
||||
return
|
||||
|
||||
# Get team data for display
|
||||
away_team = None
|
||||
home_team = None
|
||||
if scorebug_data.away_team_id:
|
||||
away_team = await team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.home_team_id:
|
||||
home_team = await team_service.get_team(scorebug_data.home_team_id)
|
||||
away_team, home_team = await asyncio.gather(
|
||||
(
|
||||
team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.away_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
(
|
||||
team_service.get_team(scorebug_data.home_team_id)
|
||||
if scorebug_data.home_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
)
|
||||
|
||||
# Format scorecard link
|
||||
away_abbrev = away_team.abbrev if away_team else "AWAY"
|
||||
@ -86,7 +93,7 @@ class ScorebugCommands(commands.Cog):
|
||||
scorecard_link = f"[{away_abbrev} @ {home_abbrev}]({url})"
|
||||
|
||||
# Store the scorecard in the tracker
|
||||
self.scorecard_tracker.publish_scorecard(
|
||||
await self.scorecard_tracker.publish_scorecard(
|
||||
text_channel_id=interaction.channel_id, # type: ignore
|
||||
sheet_url=url,
|
||||
publisher_id=interaction.user.id,
|
||||
@ -157,7 +164,7 @@ class ScorebugCommands(commands.Cog):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Check if a scorecard is published in this channel
|
||||
sheet_url = self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore
|
||||
sheet_url = await self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore
|
||||
|
||||
if not sheet_url:
|
||||
embed = EmbedTemplate.error(
|
||||
@ -179,12 +186,18 @@ class ScorebugCommands(commands.Cog):
|
||||
)
|
||||
|
||||
# Get team data
|
||||
away_team = None
|
||||
home_team = None
|
||||
if scorebug_data.away_team_id:
|
||||
away_team = await team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.home_team_id:
|
||||
home_team = await team_service.get_team(scorebug_data.home_team_id)
|
||||
away_team, home_team = await asyncio.gather(
|
||||
(
|
||||
team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.away_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
(
|
||||
team_service.get_team(scorebug_data.home_team_id)
|
||||
if scorebug_data.home_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
)
|
||||
|
||||
# Create scorebug embed using shared utility
|
||||
embed = create_scorebug_embed(
|
||||
@ -194,7 +207,7 @@ class ScorebugCommands(commands.Cog):
|
||||
await interaction.edit_original_response(content=None, embed=embed)
|
||||
|
||||
# Update timestamp in tracker
|
||||
self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore
|
||||
await self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore
|
||||
|
||||
except SheetsException as e:
|
||||
embed = EmbedTemplate.error(
|
||||
|
||||
@ -24,7 +24,7 @@ class ScorecardTracker:
|
||||
- Timestamp tracking for monitoring
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/scorecards.json"):
|
||||
def __init__(self, data_file: str = "storage/scorecards.json"):
|
||||
"""
|
||||
Initialize the scorecard tracker.
|
||||
|
||||
@ -61,7 +61,7 @@ class ScorecardTracker:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save scorecard data: {e}")
|
||||
|
||||
def publish_scorecard(
|
||||
async def publish_scorecard(
|
||||
self, text_channel_id: int, sheet_url: str, publisher_id: int
|
||||
) -> None:
|
||||
"""
|
||||
@ -82,7 +82,7 @@ class ScorecardTracker:
|
||||
self.save_data()
|
||||
logger.info(f"Published scorecard to channel {text_channel_id}: {sheet_url}")
|
||||
|
||||
def unpublish_scorecard(self, text_channel_id: int) -> bool:
|
||||
async def unpublish_scorecard(self, text_channel_id: int) -> bool:
|
||||
"""
|
||||
Remove scorecard from a text channel.
|
||||
|
||||
@ -103,7 +103,7 @@ class ScorecardTracker:
|
||||
|
||||
return False
|
||||
|
||||
def get_scorecard(self, text_channel_id: int) -> Optional[str]:
|
||||
async def get_scorecard(self, text_channel_id: int) -> Optional[str]:
|
||||
"""
|
||||
Get scorecard URL for a text channel.
|
||||
|
||||
@ -118,7 +118,7 @@ class ScorecardTracker:
|
||||
scorecard_data = scorecards.get(str(text_channel_id))
|
||||
return scorecard_data["sheet_url"] if scorecard_data else None
|
||||
|
||||
def get_all_scorecards(self) -> List[Tuple[int, str]]:
|
||||
async def get_all_scorecards(self) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
Get all published scorecards.
|
||||
|
||||
@ -132,7 +132,7 @@ class ScorecardTracker:
|
||||
for channel_id, data in scorecards.items()
|
||||
]
|
||||
|
||||
def update_timestamp(self, text_channel_id: int) -> None:
|
||||
async def update_timestamp(self, text_channel_id: int) -> None:
|
||||
"""
|
||||
Update the last_updated timestamp for a scorecard.
|
||||
|
||||
@ -146,7 +146,7 @@ class ScorecardTracker:
|
||||
scorecards[channel_key]["last_updated"] = datetime.now(UTC).isoformat()
|
||||
self.save_data()
|
||||
|
||||
def cleanup_stale_entries(self, valid_channel_ids: List[int]) -> int:
|
||||
async def cleanup_stale_entries(self, valid_channel_ids: List[int]) -> int:
|
||||
"""
|
||||
Remove tracking entries for text channels that no longer exist.
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ The injury rating format (#p##) encodes both games played and rating:
|
||||
- Remaining: Injury rating (p70, p65, p60, p50, p40, p30, p20)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import random
|
||||
import discord
|
||||
@ -114,16 +115,14 @@ class InjuryGroup(app_commands.Group):
|
||||
"""Roll for injury using 3d6 dice and injury tables."""
|
||||
await interaction.response.defer()
|
||||
|
||||
# Get current season
|
||||
current = await league_service.get_current_state()
|
||||
# Get current season and search for player in parallel
|
||||
current, players = await asyncio.gather(
|
||||
league_service.get_current_state(),
|
||||
player_service.search_players(player_name, limit=10),
|
||||
)
|
||||
if not current:
|
||||
raise BotException("Failed to get current season information")
|
||||
|
||||
# Search for player using the search endpoint (more reliable than name param)
|
||||
players = await player_service.search_players(
|
||||
player_name, limit=10, season=current.season
|
||||
)
|
||||
|
||||
if not players:
|
||||
embed = EmbedTemplate.error(
|
||||
title="Player Not Found",
|
||||
@ -530,16 +529,14 @@ class InjuryGroup(app_commands.Group):
|
||||
await interaction.followup.send(embed=embed, ephemeral=True)
|
||||
return
|
||||
|
||||
# Get current season
|
||||
current = await league_service.get_current_state()
|
||||
# Get current season and search for player in parallel
|
||||
current, players = await asyncio.gather(
|
||||
league_service.get_current_state(),
|
||||
player_service.search_players(player_name, limit=10),
|
||||
)
|
||||
if not current:
|
||||
raise BotException("Failed to get current season information")
|
||||
|
||||
# Search for player using the search endpoint (more reliable than name param)
|
||||
players = await player_service.search_players(
|
||||
player_name, limit=10, season=current.season
|
||||
)
|
||||
|
||||
if not players:
|
||||
embed = EmbedTemplate.error(
|
||||
title="Player Not Found",
|
||||
@ -717,16 +714,14 @@ class InjuryGroup(app_commands.Group):
|
||||
|
||||
await interaction.response.defer()
|
||||
|
||||
# Get current season
|
||||
current = await league_service.get_current_state()
|
||||
# Get current season and search for player in parallel
|
||||
current, players = await asyncio.gather(
|
||||
league_service.get_current_state(),
|
||||
player_service.search_players(player_name, limit=10),
|
||||
)
|
||||
if not current:
|
||||
raise BotException("Failed to get current season information")
|
||||
|
||||
# Search for player using the search endpoint (more reliable than name param)
|
||||
players = await player_service.search_players(
|
||||
player_name, limit=10, season=current.season
|
||||
)
|
||||
|
||||
if not players:
|
||||
embed = EmbedTemplate.error(
|
||||
title="Player Not Found",
|
||||
|
||||
@ -3,6 +3,7 @@ League Schedule Commands
|
||||
|
||||
Implements slash commands for displaying game schedules and results.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
@ -19,19 +20,16 @@ from views.embeds import EmbedColors, EmbedTemplate
|
||||
|
||||
class ScheduleCommands(commands.Cog):
|
||||
"""League schedule command handlers."""
|
||||
|
||||
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.ScheduleCommands')
|
||||
|
||||
@discord.app_commands.command(
|
||||
name="schedule",
|
||||
description="Display game schedule"
|
||||
)
|
||||
self.logger = get_contextual_logger(f"{__name__}.ScheduleCommands")
|
||||
|
||||
@discord.app_commands.command(name="schedule", description="Display game schedule")
|
||||
@discord.app_commands.describe(
|
||||
season="Season to show schedule for (defaults to current season)",
|
||||
week="Week number to show (optional)",
|
||||
team="Team abbreviation to filter by (optional)"
|
||||
team="Team abbreviation to filter by (optional)",
|
||||
)
|
||||
@requires_team()
|
||||
@logged_command("/schedule")
|
||||
@ -40,13 +38,13 @@ class ScheduleCommands(commands.Cog):
|
||||
interaction: discord.Interaction,
|
||||
season: Optional[int] = None,
|
||||
week: Optional[int] = None,
|
||||
team: Optional[str] = None
|
||||
team: Optional[str] = None,
|
||||
):
|
||||
"""Display game schedule for a week or team."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
search_season = season or get_config().sba_season
|
||||
|
||||
|
||||
if team:
|
||||
# Show team schedule
|
||||
await self._show_team_schedule(interaction, search_season, team, week)
|
||||
@ -56,7 +54,7 @@ class ScheduleCommands(commands.Cog):
|
||||
else:
|
||||
# Show recent/upcoming games
|
||||
await self._show_current_schedule(interaction, search_season)
|
||||
|
||||
|
||||
# @discord.app_commands.command(
|
||||
# name="results",
|
||||
# description="Display recent game results"
|
||||
@ -74,282 +72,316 @@ class ScheduleCommands(commands.Cog):
|
||||
# ):
|
||||
# """Display recent game results."""
|
||||
# await interaction.response.defer()
|
||||
|
||||
|
||||
# search_season = season or get_config().sba_season
|
||||
|
||||
|
||||
# if week:
|
||||
# # Show specific week results
|
||||
# games = await schedule_service.get_week_schedule(search_season, week)
|
||||
# completed_games = [game for game in games if game.is_completed]
|
||||
|
||||
|
||||
# if not completed_games:
|
||||
# await interaction.followup.send(
|
||||
# f"❌ No completed games found for season {search_season}, week {week}.",
|
||||
# ephemeral=True
|
||||
# )
|
||||
# return
|
||||
|
||||
|
||||
# embed = await self._create_week_results_embed(completed_games, search_season, week)
|
||||
# await interaction.followup.send(embed=embed)
|
||||
# else:
|
||||
# # Show recent results
|
||||
# recent_games = await schedule_service.get_recent_games(search_season)
|
||||
|
||||
|
||||
# if not recent_games:
|
||||
# await interaction.followup.send(
|
||||
# f"❌ No recent games found for season {search_season}.",
|
||||
# ephemeral=True
|
||||
# )
|
||||
# return
|
||||
|
||||
|
||||
# embed = await self._create_recent_results_embed(recent_games, search_season)
|
||||
# await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _show_week_schedule(self, interaction: discord.Interaction, season: int, week: int):
|
||||
|
||||
async def _show_week_schedule(
|
||||
self, interaction: discord.Interaction, season: int, week: int
|
||||
):
|
||||
"""Show schedule for a specific week."""
|
||||
self.logger.debug("Fetching week schedule", season=season, week=week)
|
||||
|
||||
|
||||
games = await schedule_service.get_week_schedule(season, week)
|
||||
|
||||
|
||||
if not games:
|
||||
await interaction.followup.send(
|
||||
f"❌ No games found for season {season}, week {week}.",
|
||||
ephemeral=True
|
||||
f"❌ No games found for season {season}, week {week}.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
embed = await self._create_week_schedule_embed(games, season, week)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _show_team_schedule(self, interaction: discord.Interaction, season: int, team: str, week: Optional[int]):
|
||||
|
||||
async def _show_team_schedule(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
season: int,
|
||||
team: str,
|
||||
week: Optional[int],
|
||||
):
|
||||
"""Show schedule for a specific team."""
|
||||
self.logger.debug("Fetching team schedule", season=season, team=team, week=week)
|
||||
|
||||
|
||||
if week:
|
||||
# Show team games for specific week
|
||||
week_games = await schedule_service.get_week_schedule(season, week)
|
||||
team_games = [
|
||||
game for game in week_games
|
||||
if game.away_team.abbrev.upper() == team.upper() or game.home_team.abbrev.upper() == team.upper()
|
||||
game
|
||||
for game in week_games
|
||||
if game.away_team.abbrev.upper() == team.upper()
|
||||
or game.home_team.abbrev.upper() == team.upper()
|
||||
]
|
||||
else:
|
||||
# Show team's recent/upcoming games (limited weeks)
|
||||
team_games = await schedule_service.get_team_schedule(season, team, weeks=4)
|
||||
|
||||
|
||||
if not team_games:
|
||||
week_text = f" for week {week}" if week else ""
|
||||
await interaction.followup.send(
|
||||
f"❌ No games found for team '{team}'{week_text} in season {season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
embed = await self._create_team_schedule_embed(team_games, season, team, week)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _show_current_schedule(self, interaction: discord.Interaction, season: int):
|
||||
|
||||
async def _show_current_schedule(
|
||||
self, interaction: discord.Interaction, season: int
|
||||
):
|
||||
"""Show current schedule overview with recent and upcoming games."""
|
||||
self.logger.debug("Fetching current schedule overview", season=season)
|
||||
|
||||
|
||||
# Get both recent and upcoming games
|
||||
recent_games, upcoming_games = await asyncio.gather(
|
||||
schedule_service.get_recent_games(season, weeks_back=1),
|
||||
schedule_service.get_upcoming_games(season, weeks_ahead=1)
|
||||
schedule_service.get_upcoming_games(season),
|
||||
)
|
||||
|
||||
|
||||
if not recent_games and not upcoming_games:
|
||||
await interaction.followup.send(
|
||||
f"❌ No recent or upcoming games found for season {season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
embed = await self._create_current_schedule_embed(recent_games, upcoming_games, season)
|
||||
|
||||
embed = await self._create_current_schedule_embed(
|
||||
recent_games, upcoming_games, season
|
||||
)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _create_week_schedule_embed(self, games, season: int, week: int) -> discord.Embed:
|
||||
|
||||
async def _create_week_schedule_embed(
|
||||
self, games, season: int, week: int
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for a week's schedule."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📅 Week {week} Schedule - Season {season}",
|
||||
color=EmbedColors.PRIMARY
|
||||
color=EmbedColors.PRIMARY,
|
||||
)
|
||||
|
||||
|
||||
# Group games by series
|
||||
series_games = schedule_service.group_games_by_series(games)
|
||||
|
||||
|
||||
schedule_lines = []
|
||||
for (team1, team2), series in series_games.items():
|
||||
series_summary = await self._format_series_summary(series)
|
||||
schedule_lines.append(f"**{team1} vs {team2}**\n{series_summary}")
|
||||
|
||||
|
||||
if schedule_lines:
|
||||
embed.add_field(
|
||||
name="Games",
|
||||
value="\n\n".join(schedule_lines),
|
||||
inline=False
|
||||
name="Games", value="\n\n".join(schedule_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
# Add week summary
|
||||
completed = len([g for g in games if g.is_completed])
|
||||
total = len(games)
|
||||
embed.add_field(
|
||||
name="Week Progress",
|
||||
value=f"{completed}/{total} games completed",
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season} • Week {week}")
|
||||
return embed
|
||||
|
||||
async def _create_team_schedule_embed(self, games, season: int, team: str, week: Optional[int]) -> discord.Embed:
|
||||
|
||||
async def _create_team_schedule_embed(
|
||||
self, games, season: int, team: str, week: Optional[int]
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for a team's schedule."""
|
||||
week_text = f" - Week {week}" if week else ""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📅 {team.upper()} Schedule{week_text} - Season {season}",
|
||||
color=EmbedColors.PRIMARY
|
||||
color=EmbedColors.PRIMARY,
|
||||
)
|
||||
|
||||
|
||||
# Separate completed and upcoming games
|
||||
completed_games = [g for g in games if g.is_completed]
|
||||
upcoming_games = [g for g in games if not g.is_completed]
|
||||
|
||||
|
||||
if completed_games:
|
||||
recent_lines = []
|
||||
for game in completed_games[-5:]: # Last 5 games
|
||||
result = "W" if game.winner and game.winner.abbrev.upper() == team.upper() else "L"
|
||||
result = (
|
||||
"W"
|
||||
if game.winner and game.winner.abbrev.upper() == team.upper()
|
||||
else "L"
|
||||
)
|
||||
if game.home_team.abbrev.upper() == team.upper():
|
||||
# Team was home
|
||||
recent_lines.append(f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})")
|
||||
recent_lines.append(
|
||||
f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})"
|
||||
)
|
||||
else:
|
||||
# Team was away
|
||||
recent_lines.append(f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})")
|
||||
|
||||
# Team was away
|
||||
recent_lines.append(
|
||||
f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})"
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Recent Results",
|
||||
value="\n".join(recent_lines) if recent_lines else "No recent games",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
if upcoming_games:
|
||||
upcoming_lines = []
|
||||
for game in upcoming_games[:5]: # Next 5 games
|
||||
if game.home_team.abbrev.upper() == team.upper():
|
||||
# Team is home
|
||||
upcoming_lines.append(f"Week {game.week}: vs {game.away_team.abbrev}")
|
||||
upcoming_lines.append(
|
||||
f"Week {game.week}: vs {game.away_team.abbrev}"
|
||||
)
|
||||
else:
|
||||
# Team is away
|
||||
upcoming_lines.append(f"Week {game.week}: @ {game.home_team.abbrev}")
|
||||
|
||||
upcoming_lines.append(
|
||||
f"Week {game.week}: @ {game.home_team.abbrev}"
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Upcoming Games",
|
||||
value="\n".join(upcoming_lines) if upcoming_lines else "No upcoming games",
|
||||
inline=False
|
||||
value=(
|
||||
"\n".join(upcoming_lines) if upcoming_lines else "No upcoming games"
|
||||
),
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season} • {team.upper()}")
|
||||
return embed
|
||||
|
||||
async def _create_week_results_embed(self, games, season: int, week: int) -> discord.Embed:
|
||||
|
||||
async def _create_week_results_embed(
|
||||
self, games, season: int, week: int
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for week results."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"🏆 Week {week} Results - Season {season}",
|
||||
color=EmbedColors.SUCCESS
|
||||
title=f"🏆 Week {week} Results - Season {season}", color=EmbedColors.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
# Group by series and show results
|
||||
series_games = schedule_service.group_games_by_series(games)
|
||||
|
||||
|
||||
results_lines = []
|
||||
for (team1, team2), series in series_games.items():
|
||||
# Count wins for each team
|
||||
team1_wins = len([g for g in series if g.winner and g.winner.abbrev == team1])
|
||||
team2_wins = len([g for g in series if g.winner and g.winner.abbrev == team2])
|
||||
|
||||
team1_wins = len(
|
||||
[g for g in series if g.winner and g.winner.abbrev == team1]
|
||||
)
|
||||
team2_wins = len(
|
||||
[g for g in series if g.winner and g.winner.abbrev == team2]
|
||||
)
|
||||
|
||||
# Series result
|
||||
series_result = f"**{team1} {team1_wins}-{team2_wins} {team2}**"
|
||||
|
||||
|
||||
# Individual games
|
||||
game_details = []
|
||||
for game in series:
|
||||
if game.series_game_display:
|
||||
game_details.append(f"{game.series_game_display}: {game.matchup_display}")
|
||||
|
||||
game_details.append(
|
||||
f"{game.series_game_display}: {game.matchup_display}"
|
||||
)
|
||||
|
||||
results_lines.append(f"{series_result}\n" + "\n".join(game_details))
|
||||
|
||||
|
||||
if results_lines:
|
||||
embed.add_field(
|
||||
name="Series Results",
|
||||
value="\n\n".join(results_lines),
|
||||
inline=False
|
||||
name="Series Results", value="\n\n".join(results_lines), inline=False
|
||||
)
|
||||
|
||||
embed.set_footer(text=f"Season {season} • Week {week} • {len(games)} games completed")
|
||||
|
||||
embed.set_footer(
|
||||
text=f"Season {season} • Week {week} • {len(games)} games completed"
|
||||
)
|
||||
return embed
|
||||
|
||||
|
||||
async def _create_recent_results_embed(self, games, season: int) -> discord.Embed:
|
||||
"""Create an embed for recent results."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"🏆 Recent Results - Season {season}",
|
||||
color=EmbedColors.SUCCESS
|
||||
title=f"🏆 Recent Results - Season {season}", color=EmbedColors.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
# Show most recent games
|
||||
recent_lines = []
|
||||
for game in games[:10]: # Show last 10 games
|
||||
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
|
||||
|
||||
|
||||
if recent_lines:
|
||||
embed.add_field(
|
||||
name="Latest Games",
|
||||
value="\n".join(recent_lines),
|
||||
inline=False
|
||||
name="Latest Games", value="\n".join(recent_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season} • Last {len(games)} completed games")
|
||||
return embed
|
||||
|
||||
async def _create_current_schedule_embed(self, recent_games, upcoming_games, season: int) -> discord.Embed:
|
||||
|
||||
async def _create_current_schedule_embed(
|
||||
self, recent_games, upcoming_games, season: int
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for current schedule overview."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📅 Current Schedule - Season {season}",
|
||||
color=EmbedColors.INFO
|
||||
title=f"📅 Current Schedule - Season {season}", color=EmbedColors.INFO
|
||||
)
|
||||
|
||||
|
||||
if recent_games:
|
||||
recent_lines = []
|
||||
for game in recent_games[:5]:
|
||||
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name="Recent Results",
|
||||
value="\n".join(recent_lines),
|
||||
inline=False
|
||||
name="Recent Results", value="\n".join(recent_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
if upcoming_games:
|
||||
upcoming_lines = []
|
||||
for game in upcoming_games[:5]:
|
||||
upcoming_lines.append(f"Week {game.week}: {game.matchup_display}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name="Upcoming Games",
|
||||
value="\n".join(upcoming_lines),
|
||||
inline=False
|
||||
name="Upcoming Games", value="\n".join(upcoming_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season}")
|
||||
return embed
|
||||
|
||||
|
||||
async def _format_series_summary(self, series) -> str:
|
||||
"""Format a series summary."""
|
||||
lines = []
|
||||
for game in series:
|
||||
game_display = f"{game.series_game_display}: {game.matchup_display}" if game.series_game_display else game.matchup_display
|
||||
game_display = (
|
||||
f"{game.series_game_display}: {game.matchup_display}"
|
||||
if game.series_game_display
|
||||
else game.matchup_display
|
||||
)
|
||||
lines.append(game_display)
|
||||
|
||||
|
||||
return "\n".join(lines) if lines else "No games"
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
"""Load the schedule commands cog."""
|
||||
await bot.add_cog(ScheduleCommands(bot))
|
||||
await bot.add_cog(ScheduleCommands(bot))
|
||||
|
||||
@ -5,6 +5,7 @@ Implements the /submit-scorecard command for submitting Google Sheets
|
||||
scorecards with play-by-play data, pitching decisions, and game results.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
|
||||
import discord
|
||||
@ -107,11 +108,13 @@ class SubmitScorecardCommands(commands.Cog):
|
||||
content="🔍 Looking up teams and managers..."
|
||||
)
|
||||
|
||||
away_team = await team_service.get_team_by_abbrev(
|
||||
setup_data["away_team_abbrev"], current.season
|
||||
)
|
||||
home_team = await team_service.get_team_by_abbrev(
|
||||
setup_data["home_team_abbrev"], current.season
|
||||
away_team, home_team = await asyncio.gather(
|
||||
team_service.get_team_by_abbrev(
|
||||
setup_data["away_team_abbrev"], current.season
|
||||
),
|
||||
team_service.get_team_by_abbrev(
|
||||
setup_data["home_team_abbrev"], current.season
|
||||
),
|
||||
)
|
||||
|
||||
if not away_team or not home_team:
|
||||
@ -235,9 +238,13 @@ class SubmitScorecardCommands(commands.Cog):
|
||||
decision["game_num"] = setup_data["game_num"]
|
||||
|
||||
# Validate WP and LP exist and fetch Player objects
|
||||
wp, lp, sv, holders, _blown_saves = (
|
||||
await decision_service.find_winning_losing_pitchers(decisions_data)
|
||||
)
|
||||
(
|
||||
wp,
|
||||
lp,
|
||||
sv,
|
||||
holders,
|
||||
_blown_saves,
|
||||
) = await decision_service.find_winning_losing_pitchers(decisions_data)
|
||||
|
||||
if wp is None or lp is None:
|
||||
await interaction.edit_original_response(
|
||||
|
||||
@ -3,13 +3,14 @@ Soak Tracker
|
||||
|
||||
Provides persistent tracking of "soak" mentions using JSON file storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(f'{__name__}.SoakTracker')
|
||||
logger = logging.getLogger(f"{__name__}.SoakTracker")
|
||||
|
||||
|
||||
class SoakTracker:
|
||||
@ -22,7 +23,7 @@ class SoakTracker:
|
||||
- Time-based calculations for disappointment tiers
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/soak_data.json"):
|
||||
def __init__(self, data_file: str = "storage/soak_data.json"):
|
||||
"""
|
||||
Initialize the soak tracker.
|
||||
|
||||
@ -38,28 +39,22 @@ class SoakTracker:
|
||||
"""Load soak data from JSON file."""
|
||||
try:
|
||||
if self.data_file.exists():
|
||||
with open(self.data_file, 'r') as f:
|
||||
with open(self.data_file, "r") as f:
|
||||
self._data = json.load(f)
|
||||
logger.debug(f"Loaded soak data: {self._data.get('total_count', 0)} total soaks")
|
||||
logger.debug(
|
||||
f"Loaded soak data: {self._data.get('total_count', 0)} total soaks"
|
||||
)
|
||||
else:
|
||||
self._data = {
|
||||
"last_soak": None,
|
||||
"total_count": 0,
|
||||
"history": []
|
||||
}
|
||||
self._data = {"last_soak": None, "total_count": 0, "history": []}
|
||||
logger.info("No existing soak data found, starting fresh")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load soak data: {e}")
|
||||
self._data = {
|
||||
"last_soak": None,
|
||||
"total_count": 0,
|
||||
"history": []
|
||||
}
|
||||
self._data = {"last_soak": None, "total_count": 0, "history": []}
|
||||
|
||||
def save_data(self) -> None:
|
||||
"""Save soak data to JSON file."""
|
||||
try:
|
||||
with open(self.data_file, 'w') as f:
|
||||
with open(self.data_file, "w") as f:
|
||||
json.dump(self._data, f, indent=2, default=str)
|
||||
logger.debug("Soak data saved successfully")
|
||||
except Exception as e:
|
||||
@ -71,7 +66,7 @@ class SoakTracker:
|
||||
username: str,
|
||||
display_name: str,
|
||||
channel_id: int,
|
||||
message_id: int
|
||||
message_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Record a new soak mention.
|
||||
@ -89,7 +84,7 @@ class SoakTracker:
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"channel_id": str(channel_id),
|
||||
"message_id": str(message_id)
|
||||
"message_id": str(message_id),
|
||||
}
|
||||
|
||||
# Update last_soak
|
||||
@ -110,7 +105,9 @@ class SoakTracker:
|
||||
|
||||
self.save_data()
|
||||
|
||||
logger.info(f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}")
|
||||
logger.info(
|
||||
f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}"
|
||||
)
|
||||
|
||||
def get_last_soak(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
@ -135,10 +132,12 @@ class SoakTracker:
|
||||
try:
|
||||
# Parse ISO format timestamp
|
||||
last_timestamp_str = last_soak["timestamp"]
|
||||
if last_timestamp_str.endswith('Z'):
|
||||
last_timestamp_str = last_timestamp_str[:-1] + '+00:00'
|
||||
if last_timestamp_str.endswith("Z"):
|
||||
last_timestamp_str = last_timestamp_str[:-1] + "+00:00"
|
||||
|
||||
last_timestamp = datetime.fromisoformat(last_timestamp_str.replace('Z', '+00:00'))
|
||||
last_timestamp = datetime.fromisoformat(
|
||||
last_timestamp_str.replace("Z", "+00:00")
|
||||
)
|
||||
|
||||
# Ensure both times are timezone-aware
|
||||
if last_timestamp.tzinfo is None:
|
||||
|
||||
@ -3,6 +3,7 @@ Transaction Management Commands
|
||||
|
||||
Core transaction commands for roster management and transaction tracking.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
@ -21,6 +22,7 @@ from views.base import PaginationView
|
||||
from services.transaction_service import transaction_service
|
||||
from services.roster_service import roster_service
|
||||
from services.team_service import team_service
|
||||
|
||||
# No longer need TransactionStatus enum
|
||||
|
||||
|
||||
@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView):
|
||||
all_transactions: list,
|
||||
user_id: int,
|
||||
timeout: float = 300.0,
|
||||
show_page_numbers: bool = True
|
||||
show_page_numbers: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
pages=pages,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
show_page_numbers=show_page_numbers
|
||||
show_page_numbers=show_page_numbers,
|
||||
)
|
||||
self.all_transactions = all_transactions
|
||||
|
||||
@discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1)
|
||||
async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
@discord.ui.button(
|
||||
label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1
|
||||
)
|
||||
async def show_move_ids(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
"""Show all move IDs in an ephemeral message."""
|
||||
self.increment_interaction_count()
|
||||
|
||||
if not self.all_transactions:
|
||||
await interaction.response.send_message(
|
||||
"No transactions to show.",
|
||||
ephemeral=True
|
||||
"No transactions to show.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView):
|
||||
# Send the messages
|
||||
if not messages:
|
||||
await interaction.response.send_message(
|
||||
"No transactions to display.",
|
||||
ephemeral=True
|
||||
"No transactions to display.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView):
|
||||
|
||||
class TransactionCommands(commands.Cog):
|
||||
"""Transaction command handlers for roster management."""
|
||||
|
||||
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.TransactionCommands')
|
||||
|
||||
self.logger = get_contextual_logger(f"{__name__}.TransactionCommands")
|
||||
|
||||
@app_commands.command(
|
||||
name="mymoves",
|
||||
description="View your pending and scheduled transactions"
|
||||
name="mymoves", description="View your pending and scheduled transactions"
|
||||
)
|
||||
@app_commands.describe(
|
||||
show_cancelled="Include cancelled transactions in the display (default: False)"
|
||||
@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog):
|
||||
@requires_team()
|
||||
@logged_command("/mymoves")
|
||||
async def my_moves(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
show_cancelled: bool = False
|
||||
self, interaction: discord.Interaction, show_cancelled: bool = False
|
||||
):
|
||||
"""Display user's transaction status and history."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
# Get user's team
|
||||
team = await get_user_major_league_team(interaction.user.id, get_config().sba_season)
|
||||
|
||||
team = await get_user_major_league_team(
|
||||
interaction.user.id, get_config().sba_season
|
||||
)
|
||||
|
||||
if not team:
|
||||
await interaction.followup.send(
|
||||
"❌ You don't appear to own a team in the current season.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Get transactions in parallel
|
||||
pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season)
|
||||
frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season)
|
||||
processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season)
|
||||
|
||||
pending_transactions = await pending_task
|
||||
frozen_transactions = await frozen_task
|
||||
processed_transactions = await processed_task
|
||||
|
||||
(
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
) = await asyncio.gather(
|
||||
transaction_service.get_pending_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
transaction_service.get_frozen_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
transaction_service.get_processed_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
)
|
||||
|
||||
# Get cancelled if requested
|
||||
cancelled_transactions = []
|
||||
if show_cancelled:
|
||||
cancelled_transactions = await transaction_service.get_team_transactions(
|
||||
team.abbrev,
|
||||
get_config().sba_season,
|
||||
cancelled=True
|
||||
team.abbrev, get_config().sba_season, cancelled=True
|
||||
)
|
||||
|
||||
pages = self._create_my_moves_pages(
|
||||
@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog):
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
cancelled_transactions
|
||||
cancelled_transactions,
|
||||
)
|
||||
|
||||
# Collect all transactions for the "Show Move IDs" button
|
||||
all_transactions = (
|
||||
pending_transactions +
|
||||
frozen_transactions +
|
||||
processed_transactions +
|
||||
cancelled_transactions
|
||||
pending_transactions
|
||||
+ frozen_transactions
|
||||
+ processed_transactions
|
||||
+ cancelled_transactions
|
||||
)
|
||||
|
||||
# If only one page and no transactions, send without any buttons
|
||||
@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog):
|
||||
all_transactions=all_transactions,
|
||||
user_id=interaction.user.id,
|
||||
timeout=300.0,
|
||||
show_page_numbers=True
|
||||
show_page_numbers=True,
|
||||
)
|
||||
await interaction.followup.send(embed=view.get_current_embed(), view=view)
|
||||
|
||||
|
||||
@app_commands.command(
|
||||
name="legal",
|
||||
description="Check roster legality for current and next week"
|
||||
)
|
||||
@app_commands.describe(
|
||||
team="Team abbreviation to check (defaults to your team)"
|
||||
name="legal", description="Check roster legality for current and next week"
|
||||
)
|
||||
@app_commands.describe(team="Team abbreviation to check (defaults to your team)")
|
||||
@requires_team()
|
||||
@logged_command("/legal")
|
||||
async def legal(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
team: Optional[str] = None
|
||||
):
|
||||
async def legal(self, interaction: discord.Interaction, team: Optional[str] = None):
|
||||
"""Check roster legality and display detailed validation results."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
# Get target team
|
||||
if team:
|
||||
target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season)
|
||||
target_team = await team_service.get_team_by_abbrev(
|
||||
team.upper(), get_config().sba_season
|
||||
)
|
||||
if not target_team:
|
||||
await interaction.followup.send(
|
||||
f"❌ Could not find team '{team}' in season {get_config().sba_season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Get user's team
|
||||
user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season)
|
||||
user_teams = await team_service.get_teams_by_owner(
|
||||
interaction.user.id, get_config().sba_season
|
||||
)
|
||||
if not user_teams:
|
||||
await interaction.followup.send(
|
||||
"❌ You don't appear to own a team. Please specify a team abbreviation.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
target_team = user_teams[0]
|
||||
|
||||
|
||||
# Get rosters in parallel
|
||||
current_roster, next_roster = await asyncio.gather(
|
||||
roster_service.get_current_roster(target_team.id),
|
||||
roster_service.get_next_roster(target_team.id)
|
||||
roster_service.get_next_roster(target_team.id),
|
||||
)
|
||||
|
||||
|
||||
if not current_roster and not next_roster:
|
||||
await interaction.followup.send(
|
||||
f"❌ Could not retrieve roster data for {target_team.abbrev}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Validate rosters in parallel
|
||||
validation_tasks = []
|
||||
if current_roster:
|
||||
validation_tasks.append(roster_service.validate_roster(current_roster))
|
||||
else:
|
||||
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
|
||||
|
||||
|
||||
if next_roster:
|
||||
validation_tasks.append(roster_service.validate_roster(next_roster))
|
||||
else:
|
||||
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
|
||||
|
||||
|
||||
validation_results = await asyncio.gather(*validation_tasks)
|
||||
current_validation = validation_results[0] if current_roster else None
|
||||
next_validation = validation_results[1] if next_roster else None
|
||||
|
||||
|
||||
embed = await self._create_legal_embed(
|
||||
target_team,
|
||||
current_roster,
|
||||
next_roster,
|
||||
next_roster,
|
||||
current_validation,
|
||||
next_validation
|
||||
next_validation,
|
||||
)
|
||||
|
||||
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
|
||||
def _create_my_moves_pages(
|
||||
self,
|
||||
team,
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
cancelled_transactions
|
||||
cancelled_transactions,
|
||||
) -> list[discord.Embed]:
|
||||
"""Create paginated embeds showing user's transaction status."""
|
||||
|
||||
@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog):
|
||||
# Page 1: Summary + Pending Transactions
|
||||
if pending_transactions:
|
||||
total_pending = len(pending_transactions)
|
||||
total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page
|
||||
total_pages = (
|
||||
total_pending + transactions_per_page - 1
|
||||
) // transactions_per_page
|
||||
|
||||
for page_num in range(total_pages):
|
||||
start_idx = page_num * transactions_per_page
|
||||
@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
# Add team thumbnail if available
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
# Pending transactions for this page
|
||||
@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog):
|
||||
embed.add_field(
|
||||
name=f"⏳ Pending Transactions ({total_pending} total)",
|
||||
value="\n".join(pending_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
# Add summary only on first page
|
||||
@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog):
|
||||
|
||||
embed.add_field(
|
||||
name="Summary",
|
||||
value=", ".join(status_text) if status_text else "No active transactions",
|
||||
inline=True
|
||||
value=(
|
||||
", ".join(status_text)
|
||||
if status_text
|
||||
else "No active transactions"
|
||||
),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
embed.add_field(
|
||||
name="⏳ Pending Transactions",
|
||||
value="No pending transactions",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
total_frozen = len(frozen_transactions)
|
||||
@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog):
|
||||
|
||||
embed.add_field(
|
||||
name="Summary",
|
||||
value=", ".join(status_text) if status_text else "No active transactions",
|
||||
inline=True
|
||||
value=(
|
||||
", ".join(status_text) if status_text else "No active transactions"
|
||||
),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
frozen_lines = [format_transaction(tx) for tx in frozen_transactions]
|
||||
@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog):
|
||||
embed.add_field(
|
||||
name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)",
|
||||
value="\n".join(frozen_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20
|
||||
processed_lines = [
|
||||
format_transaction(tx) for tx in processed_transactions[-20:]
|
||||
] # Last 20
|
||||
|
||||
embed.add_field(
|
||||
name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)",
|
||||
value="\n".join(processed_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20
|
||||
cancelled_lines = [
|
||||
format_transaction(tx) for tx in cancelled_transactions[-20:]
|
||||
] # Last 20
|
||||
|
||||
embed.add_field(
|
||||
name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)",
|
||||
value="\n".join(cancelled_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog):
|
||||
page.set_footer(text="Use /legal to check roster legality")
|
||||
|
||||
return pages
|
||||
|
||||
|
||||
async def _create_legal_embed(
|
||||
self,
|
||||
team,
|
||||
current_roster,
|
||||
next_roster,
|
||||
current_validation,
|
||||
next_validation
|
||||
self, team, current_roster, next_roster, current_validation, next_validation
|
||||
) -> discord.Embed:
|
||||
"""Create embed showing roster legality check results."""
|
||||
|
||||
|
||||
# Determine overall status
|
||||
overall_legal = True
|
||||
if current_validation and not current_validation.is_legal:
|
||||
overall_legal = False
|
||||
if next_validation and not next_validation.is_legal:
|
||||
overall_legal = False
|
||||
|
||||
|
||||
status_emoji = "✅" if overall_legal else "❌"
|
||||
embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR
|
||||
|
||||
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"{status_emoji} Roster Check - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=embed_color
|
||||
color=embed_color,
|
||||
)
|
||||
|
||||
|
||||
# Add team thumbnail if available
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
|
||||
# Current week roster
|
||||
if current_roster and current_validation:
|
||||
current_lines = []
|
||||
current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL")
|
||||
current_lines.append(
|
||||
f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL"
|
||||
)
|
||||
current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}")
|
||||
|
||||
|
||||
if current_validation.errors:
|
||||
current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}")
|
||||
for error in current_validation.errors[:3]: # Show first 3 errors
|
||||
current_lines.append(f"• {error}")
|
||||
|
||||
|
||||
if current_validation.warnings:
|
||||
current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}")
|
||||
current_lines.append(
|
||||
f"**⚠️ Warnings:** {len(current_validation.warnings)}"
|
||||
)
|
||||
for warning in current_validation.warnings[:2]: # Show first 2 warnings
|
||||
current_lines.append(f"• {warning}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name=f"{current_validation.status_emoji} Current Week",
|
||||
value="\n".join(current_lines),
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="❓ Current Week",
|
||||
value="Roster data not available",
|
||||
inline=True
|
||||
name="❓ Current Week", value="Roster data not available", inline=True
|
||||
)
|
||||
|
||||
# Next week roster
|
||||
|
||||
# Next week roster
|
||||
if next_roster and next_validation:
|
||||
next_lines = []
|
||||
next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL")
|
||||
next_lines.append(
|
||||
f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL"
|
||||
)
|
||||
next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}")
|
||||
|
||||
|
||||
if next_validation.errors:
|
||||
next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}")
|
||||
for error in next_validation.errors[:3]: # Show first 3 errors
|
||||
next_lines.append(f"• {error}")
|
||||
|
||||
|
||||
if next_validation.warnings:
|
||||
next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}")
|
||||
for warning in next_validation.warnings[:2]: # Show first 2 warnings
|
||||
next_lines.append(f"• {warning}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name=f"{next_validation.status_emoji} Next Week",
|
||||
value="\n".join(next_lines),
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="❓ Next Week",
|
||||
value="Roster data not available",
|
||||
inline=True
|
||||
name="❓ Next Week", value="Roster data not available", inline=True
|
||||
)
|
||||
|
||||
|
||||
# Overall status
|
||||
if overall_legal:
|
||||
embed.add_field(
|
||||
name="Overall Status",
|
||||
value="✅ All rosters are legal",
|
||||
inline=False
|
||||
name="Overall Status", value="✅ All rosters are legal", inline=False
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="Overall Status",
|
||||
name="Overall Status",
|
||||
value="❌ Roster violations found - please review and correct",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text="Roster validation based on current league rules")
|
||||
return embed
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
"""Load the transaction commands cog."""
|
||||
await bot.add_cog(TransactionCommands(bot))
|
||||
await bot.add_cog(TransactionCommands(bot))
|
||||
|
||||
@ -26,6 +26,7 @@ from services.trade_builder import (
|
||||
clear_trade_builder,
|
||||
clear_trade_builder_by_team,
|
||||
)
|
||||
from services.league_service import league_service
|
||||
from services.player_service import player_service
|
||||
from services.team_service import team_service
|
||||
from models.team import RosterType
|
||||
@ -130,6 +131,22 @@ class TradeCommands(commands.Cog):
|
||||
)
|
||||
return
|
||||
|
||||
# Check trade deadline
|
||||
current = await league_service.get_current_state()
|
||||
if not current:
|
||||
await interaction.followup.send(
|
||||
"❌ Could not retrieve league state. Please try again later.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
if current.is_past_trade_deadline:
|
||||
await interaction.followup.send(
|
||||
f"❌ **The trade deadline has passed.** The deadline was Week {current.trade_deadline} "
|
||||
f"and we are currently in Week {current.week}. No new trades can be initiated.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Clear any existing trade and create new one
|
||||
clear_trade_builder(interaction.user.id)
|
||||
trade_builder = get_trade_builder(interaction.user.id, user_team)
|
||||
|
||||
@ -3,6 +3,7 @@ Trade Channel Tracker
|
||||
|
||||
Provides persistent tracking of bot-created trade discussion channels using JSON file storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, UTC
|
||||
from pathlib import Path
|
||||
@ -12,7 +13,7 @@ import discord
|
||||
|
||||
from utils.logging import get_contextual_logger
|
||||
|
||||
logger = get_contextual_logger(f'{__name__}.TradeChannelTracker')
|
||||
logger = get_contextual_logger(f"{__name__}.TradeChannelTracker")
|
||||
|
||||
|
||||
class TradeChannelTracker:
|
||||
@ -26,7 +27,7 @@ class TradeChannelTracker:
|
||||
- Automatic stale entry removal
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/trade_channels.json"):
|
||||
def __init__(self, data_file: str = "storage/trade_channels.json"):
|
||||
"""
|
||||
Initialize the trade channel tracker.
|
||||
|
||||
@ -42,9 +43,11 @@ class TradeChannelTracker:
|
||||
"""Load channel data from JSON file."""
|
||||
try:
|
||||
if self.data_file.exists():
|
||||
with open(self.data_file, 'r') as f:
|
||||
with open(self.data_file, "r") as f:
|
||||
self._data = json.load(f)
|
||||
logger.debug(f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels")
|
||||
logger.debug(
|
||||
f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels"
|
||||
)
|
||||
else:
|
||||
self._data = {"trade_channels": {}}
|
||||
logger.info("No existing trade channel data found, starting fresh")
|
||||
@ -55,7 +58,7 @@ class TradeChannelTracker:
|
||||
def save_data(self) -> None:
|
||||
"""Save channel data to JSON file."""
|
||||
try:
|
||||
with open(self.data_file, 'w') as f:
|
||||
with open(self.data_file, "w") as f:
|
||||
json.dump(self._data, f, indent=2, default=str)
|
||||
logger.debug("Trade channel data saved successfully")
|
||||
except Exception as e:
|
||||
@ -67,7 +70,7 @@ class TradeChannelTracker:
|
||||
trade_id: str,
|
||||
team1_abbrev: str,
|
||||
team2_abbrev: str,
|
||||
creator_id: int
|
||||
creator_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Add a new trade channel to tracking.
|
||||
@ -87,10 +90,12 @@ class TradeChannelTracker:
|
||||
"team1_abbrev": team1_abbrev,
|
||||
"team2_abbrev": team2_abbrev,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"creator_id": str(creator_id)
|
||||
"creator_id": str(creator_id),
|
||||
}
|
||||
self.save_data()
|
||||
logger.info(f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})")
|
||||
logger.info(
|
||||
f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})"
|
||||
)
|
||||
|
||||
def remove_channel(self, channel_id: int) -> None:
|
||||
"""
|
||||
@ -108,7 +113,9 @@ class TradeChannelTracker:
|
||||
channel_name = channel_data["name"]
|
||||
del channels[channel_key]
|
||||
self.save_data()
|
||||
logger.info(f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})")
|
||||
logger.info(
|
||||
f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})"
|
||||
)
|
||||
|
||||
def get_channel_by_trade_id(self, trade_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
@ -175,7 +182,9 @@ class TradeChannelTracker:
|
||||
channel_name = channels[channel_id_str].get("name", "unknown")
|
||||
trade_id = channels[channel_id_str].get("trade_id", "unknown")
|
||||
del channels[channel_id_str]
|
||||
logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})")
|
||||
logger.info(
|
||||
f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})"
|
||||
)
|
||||
|
||||
if stale_entries:
|
||||
self.save_data()
|
||||
|
||||
@ -3,6 +3,7 @@ Voice Channel Cleanup Service
|
||||
|
||||
Provides automatic cleanup of empty voice channels with restart resilience.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import discord
|
||||
@ -12,7 +13,7 @@ from .tracker import VoiceChannelTracker
|
||||
from commands.gameplay.scorecard_tracker import ScorecardTracker
|
||||
from utils.logging import get_contextual_logger
|
||||
|
||||
logger = logging.getLogger(f'{__name__}.VoiceChannelCleanupService')
|
||||
logger = logging.getLogger(f"{__name__}.VoiceChannelCleanupService")
|
||||
|
||||
|
||||
class VoiceChannelCleanupService:
|
||||
@ -27,7 +28,9 @@ class VoiceChannelCleanupService:
|
||||
- Automatic scorecard unpublishing when voice channel is cleaned up
|
||||
"""
|
||||
|
||||
def __init__(self, bot: commands.Bot, data_file: str = "data/voice_channels.json"):
|
||||
def __init__(
|
||||
self, bot: commands.Bot, data_file: str = "storage/voice_channels.json"
|
||||
):
|
||||
"""
|
||||
Initialize the cleanup service.
|
||||
|
||||
@ -36,10 +39,10 @@ class VoiceChannelCleanupService:
|
||||
data_file: Path to the JSON data file for persistence
|
||||
"""
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.VoiceChannelCleanupService')
|
||||
self.logger = get_contextual_logger(f"{__name__}.VoiceChannelCleanupService")
|
||||
self.tracker = VoiceChannelTracker(data_file)
|
||||
self.scorecard_tracker = ScorecardTracker()
|
||||
self.empty_threshold = 5 # Delete after 5 minutes empty
|
||||
self.empty_threshold = 5 # Delete after 5 minutes empty
|
||||
|
||||
# Start the cleanup task - @before_loop will wait for bot readiness
|
||||
self.cleanup_loop.start()
|
||||
@ -90,13 +93,17 @@ class VoiceChannelCleanupService:
|
||||
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
self.logger.warning(f"Guild {guild_id} not found, removing channel {channel_data['name']}")
|
||||
self.logger.warning(
|
||||
f"Guild {guild_id} not found, removing channel {channel_data['name']}"
|
||||
)
|
||||
channels_to_remove.append(channel_id)
|
||||
continue
|
||||
|
||||
channel = guild.get_channel(channel_id)
|
||||
if not channel:
|
||||
self.logger.warning(f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists")
|
||||
self.logger.warning(
|
||||
f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists"
|
||||
)
|
||||
channels_to_remove.append(channel_id)
|
||||
continue
|
||||
|
||||
@ -121,18 +128,26 @@ class VoiceChannelCleanupService:
|
||||
if channel_data and channel_data.get("text_channel_id"):
|
||||
try:
|
||||
text_channel_id_int = int(channel_data["text_channel_id"])
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = await self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in stale voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in stale voice channel data: {e}"
|
||||
)
|
||||
|
||||
# Also clean up any additional stale entries
|
||||
stale_removed = self.tracker.cleanup_stale_entries(valid_channel_ids)
|
||||
total_removed = len(channels_to_remove) + stale_removed
|
||||
|
||||
if total_removed > 0:
|
||||
self.logger.info(f"Cleaned up {total_removed} stale channel tracking entries")
|
||||
self.logger.info(
|
||||
f"Cleaned up {total_removed} stale channel tracking entries"
|
||||
)
|
||||
|
||||
self.logger.info(f"Verified {len(valid_channel_ids)} valid tracked channels")
|
||||
|
||||
@ -149,10 +164,14 @@ class VoiceChannelCleanupService:
|
||||
await self.update_all_channel_statuses(bot)
|
||||
|
||||
# Get channels ready for cleanup
|
||||
channels_for_cleanup = self.tracker.get_channels_for_cleanup(self.empty_threshold)
|
||||
channels_for_cleanup = self.tracker.get_channels_for_cleanup(
|
||||
self.empty_threshold
|
||||
)
|
||||
|
||||
if channels_for_cleanup:
|
||||
self.logger.info(f"Found {len(channels_for_cleanup)} channels ready for cleanup")
|
||||
self.logger.info(
|
||||
f"Found {len(channels_for_cleanup)} channels ready for cleanup"
|
||||
)
|
||||
|
||||
# Delete empty channels
|
||||
for channel_data in channels_for_cleanup:
|
||||
@ -182,12 +201,16 @@ class VoiceChannelCleanupService:
|
||||
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
self.logger.debug(f"Guild {guild_id} not found for channel {channel_data['name']}")
|
||||
self.logger.debug(
|
||||
f"Guild {guild_id} not found for channel {channel_data['name']}"
|
||||
)
|
||||
return
|
||||
|
||||
channel = guild.get_channel(channel_id)
|
||||
if not channel:
|
||||
self.logger.debug(f"Channel {channel_data['name']} no longer exists, removing from tracking")
|
||||
self.logger.debug(
|
||||
f"Channel {channel_data['name']} no longer exists, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
|
||||
# Unpublish associated scorecard if it exists
|
||||
@ -195,17 +218,27 @@ class VoiceChannelCleanupService:
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = (
|
||||
await self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in manually deleted voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in manually deleted voice channel data: {e}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Ensure it's a voice channel before checking members
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
self.logger.warning(f"Channel {channel_data['name']} is not a voice channel, removing from tracking")
|
||||
self.logger.warning(
|
||||
f"Channel {channel_data['name']} is not a voice channel, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
|
||||
# Unpublish associated scorecard if it exists
|
||||
@ -213,11 +246,19 @@ class VoiceChannelCleanupService:
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = (
|
||||
await self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in wrong channel type data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in wrong channel type data: {e}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@ -225,11 +266,15 @@ class VoiceChannelCleanupService:
|
||||
is_empty = len(channel.members) == 0
|
||||
self.tracker.update_channel_status(channel_id, is_empty)
|
||||
|
||||
self.logger.debug(f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} "
|
||||
f"({len(channel.members)} members)")
|
||||
self.logger.debug(
|
||||
f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} "
|
||||
f"({len(channel.members)} members)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}")
|
||||
self.logger.error(
|
||||
f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
|
||||
async def cleanup_channel(self, bot: commands.Bot, channel_data: dict) -> None:
|
||||
"""
|
||||
@ -246,25 +291,33 @@ class VoiceChannelCleanupService:
|
||||
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
self.logger.info(f"Guild {guild_id} not found, removing tracking for {channel_name}")
|
||||
self.logger.info(
|
||||
f"Guild {guild_id} not found, removing tracking for {channel_name}"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
return
|
||||
|
||||
channel = guild.get_channel(channel_id)
|
||||
if not channel:
|
||||
self.logger.info(f"Channel {channel_name} already deleted, removing from tracking")
|
||||
self.logger.info(
|
||||
f"Channel {channel_name} already deleted, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
return
|
||||
|
||||
# Ensure it's a voice channel before checking members
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
self.logger.warning(f"Channel {channel_name} is not a voice channel, removing from tracking")
|
||||
self.logger.warning(
|
||||
f"Channel {channel_name} is not a voice channel, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
return
|
||||
|
||||
# Final check: make sure channel is still empty before deleting
|
||||
if len(channel.members) > 0:
|
||||
self.logger.info(f"Channel {channel_name} is no longer empty, skipping cleanup")
|
||||
self.logger.info(
|
||||
f"Channel {channel_name} is no longer empty, skipping cleanup"
|
||||
)
|
||||
self.tracker.update_channel_status(channel_id, False)
|
||||
return
|
||||
|
||||
@ -272,24 +325,36 @@ class VoiceChannelCleanupService:
|
||||
await channel.delete(reason="Automatic cleanup - empty for 5+ minutes")
|
||||
self.tracker.remove_channel(channel_id)
|
||||
|
||||
self.logger.info(f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})")
|
||||
self.logger.info(
|
||||
f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})"
|
||||
)
|
||||
|
||||
# Unpublish associated scorecard if it exists
|
||||
text_channel_id = channel_data.get("text_channel_id")
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = await self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"No scorecard found for text channel {text_channel_id_int}")
|
||||
self.logger.debug(
|
||||
f"No scorecard found for text channel {text_channel_id_int}"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in voice channel data: {e}"
|
||||
)
|
||||
|
||||
except discord.NotFound:
|
||||
# Channel was already deleted
|
||||
self.logger.info(f"Channel {channel_data.get('name', 'unknown')} was already deleted")
|
||||
self.logger.info(
|
||||
f"Channel {channel_data.get('name', 'unknown')} was already deleted"
|
||||
)
|
||||
self.tracker.remove_channel(int(channel_data["channel_id"]))
|
||||
|
||||
# Still try to unpublish associated scorecard
|
||||
@ -297,15 +362,25 @@ class VoiceChannelCleanupService:
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = await self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in voice channel data: {e}"
|
||||
)
|
||||
except discord.Forbidden:
|
||||
self.logger.error(f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}")
|
||||
self.logger.error(
|
||||
f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}")
|
||||
self.logger.error(
|
||||
f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
|
||||
def get_tracker(self) -> VoiceChannelTracker:
|
||||
"""
|
||||
@ -330,7 +405,7 @@ class VoiceChannelCleanupService:
|
||||
"running": self.cleanup_loop.is_running(),
|
||||
"total_tracked": len(all_channels),
|
||||
"empty_channels": len(empty_channels),
|
||||
"empty_threshold": self.empty_threshold
|
||||
"empty_threshold": self.empty_threshold,
|
||||
}
|
||||
|
||||
|
||||
@ -344,4 +419,4 @@ def setup_voice_cleanup(bot: commands.Bot) -> VoiceChannelCleanupService:
|
||||
Returns:
|
||||
VoiceChannelCleanupService instance
|
||||
"""
|
||||
return VoiceChannelCleanupService(bot)
|
||||
return VoiceChannelCleanupService(bot)
|
||||
|
||||
@ -3,6 +3,7 @@ Voice Channel Tracker
|
||||
|
||||
Provides persistent tracking of bot-created voice channels using JSON file storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, UTC
|
||||
@ -11,7 +12,7 @@ from typing import Dict, List, Optional, Any
|
||||
|
||||
import discord
|
||||
|
||||
logger = logging.getLogger(f'{__name__}.VoiceChannelTracker')
|
||||
logger = logging.getLogger(f"{__name__}.VoiceChannelTracker")
|
||||
|
||||
|
||||
class VoiceChannelTracker:
|
||||
@ -25,7 +26,7 @@ class VoiceChannelTracker:
|
||||
- Automatic stale entry removal
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/voice_channels.json"):
|
||||
def __init__(self, data_file: str = "storage/voice_channels.json"):
|
||||
"""
|
||||
Initialize the voice channel tracker.
|
||||
|
||||
@ -41,9 +42,11 @@ class VoiceChannelTracker:
|
||||
"""Load channel data from JSON file."""
|
||||
try:
|
||||
if self.data_file.exists():
|
||||
with open(self.data_file, 'r') as f:
|
||||
with open(self.data_file, "r") as f:
|
||||
self._data = json.load(f)
|
||||
logger.debug(f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels")
|
||||
logger.debug(
|
||||
f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels"
|
||||
)
|
||||
else:
|
||||
self._data = {"voice_channels": {}}
|
||||
logger.info("No existing voice channel data found, starting fresh")
|
||||
@ -54,7 +57,7 @@ class VoiceChannelTracker:
|
||||
def save_data(self) -> None:
|
||||
"""Save channel data to JSON file."""
|
||||
try:
|
||||
with open(self.data_file, 'w') as f:
|
||||
with open(self.data_file, "w") as f:
|
||||
json.dump(self._data, f, indent=2, default=str)
|
||||
logger.debug("Voice channel data saved successfully")
|
||||
except Exception as e:
|
||||
@ -65,7 +68,7 @@ class VoiceChannelTracker:
|
||||
channel: discord.VoiceChannel,
|
||||
channel_type: str,
|
||||
creator_id: int,
|
||||
text_channel_id: Optional[int] = None
|
||||
text_channel_id: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a new channel to tracking.
|
||||
@ -85,7 +88,7 @@ class VoiceChannelTracker:
|
||||
"last_checked": datetime.now(UTC).isoformat(),
|
||||
"empty_since": None,
|
||||
"creator_id": str(creator_id),
|
||||
"text_channel_id": str(text_channel_id) if text_channel_id else None
|
||||
"text_channel_id": str(text_channel_id) if text_channel_id else None,
|
||||
}
|
||||
self.save_data()
|
||||
logger.info(f"Added channel to tracking: {channel.name} (ID: {channel.id})")
|
||||
@ -130,9 +133,13 @@ class VoiceChannelTracker:
|
||||
channel_name = channels[channel_key]["name"]
|
||||
del channels[channel_key]
|
||||
self.save_data()
|
||||
logger.info(f"Removed channel from tracking: {channel_name} (ID: {channel_id})")
|
||||
logger.info(
|
||||
f"Removed channel from tracking: {channel_name} (ID: {channel_id})"
|
||||
)
|
||||
|
||||
def get_channels_for_cleanup(self, empty_threshold_minutes: int = 15) -> List[Dict[str, Any]]:
|
||||
def get_channels_for_cleanup(
|
||||
self, empty_threshold_minutes: int = 15
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get channels that should be deleted based on empty duration.
|
||||
|
||||
@ -153,10 +160,12 @@ class VoiceChannelTracker:
|
||||
# Parse empty_since timestamp
|
||||
empty_since_str = channel_data["empty_since"]
|
||||
# Handle both with and without timezone info
|
||||
if empty_since_str.endswith('Z'):
|
||||
empty_since_str = empty_since_str[:-1] + '+00:00'
|
||||
if empty_since_str.endswith("Z"):
|
||||
empty_since_str = empty_since_str[:-1] + "+00:00"
|
||||
|
||||
empty_since = datetime.fromisoformat(empty_since_str.replace('Z', '+00:00'))
|
||||
empty_since = datetime.fromisoformat(
|
||||
empty_since_str.replace("Z", "+00:00")
|
||||
)
|
||||
|
||||
# Remove timezone info for comparison (both times are UTC)
|
||||
if empty_since.tzinfo:
|
||||
@ -164,10 +173,14 @@ class VoiceChannelTracker:
|
||||
|
||||
if empty_since <= cutoff_time:
|
||||
cleanup_candidates.append(channel_data)
|
||||
logger.debug(f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})")
|
||||
logger.debug(
|
||||
f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})"
|
||||
)
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}")
|
||||
logger.warning(
|
||||
f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
|
||||
return cleanup_candidates
|
||||
|
||||
@ -242,9 +255,11 @@ class VoiceChannelTracker:
|
||||
for channel_id_str in stale_entries:
|
||||
channel_name = channels[channel_id_str].get("name", "unknown")
|
||||
del channels[channel_id_str]
|
||||
logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})")
|
||||
logger.info(
|
||||
f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})"
|
||||
)
|
||||
|
||||
if stale_entries:
|
||||
self.save_data()
|
||||
|
||||
return len(stale_entries)
|
||||
return len(stale_entries)
|
||||
|
||||
@ -36,8 +36,11 @@ services:
|
||||
|
||||
# Volume mounts
|
||||
volumes:
|
||||
# Google Sheets credentials (required)
|
||||
- ${SHEETS_CREDENTIALS_HOST_PATH:-./data}:/app/data:ro
|
||||
# Google Sheets credentials (read-only, file mount)
|
||||
- ${SHEETS_CREDENTIALS_HOST_PATH:-./data/major-domo-service-creds.json}:/app/data/major-domo-service-creds.json:ro
|
||||
|
||||
# Runtime state files (writable) - scorecards, voice channels, trade channels, soak data
|
||||
- ${STATE_HOST_PATH:-./storage}:/app/storage:rw
|
||||
|
||||
# Logs directory (persistent) - mounted to /app/logs where the application expects it
|
||||
- ${LOGS_HOST_PATH:-./logs}:/app/logs:rw
|
||||
|
||||
@ -3,6 +3,7 @@ Current league state model
|
||||
|
||||
Represents the current state of the league including week, season, and settings.
|
||||
"""
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from models.base import SBABaseModel
|
||||
@ -10,38 +11,45 @@ from models.base import SBABaseModel
|
||||
|
||||
class Current(SBABaseModel):
|
||||
"""Model representing current league state and settings."""
|
||||
|
||||
|
||||
week: int = Field(69, description="Current week number")
|
||||
season: int = Field(69, description="Current season number")
|
||||
freeze: bool = Field(True, description="Whether league is frozen")
|
||||
bet_week: str = Field('sheets', description="Betting week identifier")
|
||||
bet_week: str = Field("sheets", description="Betting week identifier")
|
||||
trade_deadline: int = Field(1, description="Trade deadline week")
|
||||
pick_trade_start: int = Field(69, description="Draft pick trading start week")
|
||||
pick_trade_end: int = Field(420, description="Draft pick trading end week")
|
||||
playoffs_begin: int = Field(420, description="Week when playoffs begin")
|
||||
|
||||
|
||||
@field_validator("bet_week", mode="before")
|
||||
@classmethod
|
||||
def cast_bet_week_to_string(cls, v):
|
||||
"""Ensure bet_week is always a string."""
|
||||
return str(v) if v is not None else 'sheets'
|
||||
|
||||
return str(v) if v is not None else "sheets"
|
||||
|
||||
@property
|
||||
def is_offseason(self) -> bool:
|
||||
"""Check if league is currently in offseason."""
|
||||
return self.week > 18
|
||||
|
||||
|
||||
@property
|
||||
def is_playoffs(self) -> bool:
|
||||
"""Check if league is currently in playoffs."""
|
||||
return self.week >= self.playoffs_begin
|
||||
|
||||
|
||||
@property
|
||||
def can_trade_picks(self) -> bool:
|
||||
"""Check if draft pick trading is currently allowed."""
|
||||
return self.pick_trade_start <= self.week <= self.pick_trade_end
|
||||
|
||||
|
||||
@property
|
||||
def ever_trade_picks(self) -> bool:
|
||||
"""Check if draft pick trading is allowed this season at all"""
|
||||
return self.pick_trade_start <= self.playoffs_begin + 4
|
||||
return self.pick_trade_start <= self.playoffs_begin + 4
|
||||
|
||||
@property
|
||||
def is_past_trade_deadline(self) -> bool:
|
||||
"""Check if the trade deadline has passed."""
|
||||
if self.is_offseason:
|
||||
return False
|
||||
return self.week > self.trade_deadline
|
||||
|
||||
@ -4,6 +4,7 @@ Chart Service for managing gameplay charts and infographics.
|
||||
This service handles loading, saving, and managing chart definitions
|
||||
from the JSON configuration file.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@ -18,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class Chart:
|
||||
"""Represents a gameplay chart or infographic."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
category: str
|
||||
@ -27,17 +29,17 @@ class Chart:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert chart to dictionary (excluding key)."""
|
||||
return {
|
||||
'name': self.name,
|
||||
'category': self.category,
|
||||
'description': self.description,
|
||||
'urls': self.urls
|
||||
"name": self.name,
|
||||
"category": self.category,
|
||||
"description": self.description,
|
||||
"urls": self.urls,
|
||||
}
|
||||
|
||||
|
||||
class ChartService:
|
||||
"""Service for managing gameplay charts and infographics."""
|
||||
|
||||
CHARTS_FILE = Path(__file__).parent.parent / 'data' / 'charts.json'
|
||||
CHARTS_FILE = Path(__file__).parent.parent / "storage" / "charts.json"
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the chart service."""
|
||||
@ -54,21 +56,21 @@ class ChartService:
|
||||
self._categories = {}
|
||||
return
|
||||
|
||||
with open(self.CHARTS_FILE, 'r') as f:
|
||||
with open(self.CHARTS_FILE, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Load categories
|
||||
self._categories = data.get('categories', {})
|
||||
self._categories = data.get("categories", {})
|
||||
|
||||
# Load charts
|
||||
charts_data = data.get('charts', {})
|
||||
charts_data = data.get("charts", {})
|
||||
for key, chart_data in charts_data.items():
|
||||
self._charts[key] = Chart(
|
||||
key=key,
|
||||
name=chart_data['name'],
|
||||
category=chart_data['category'],
|
||||
description=chart_data.get('description', ''),
|
||||
urls=chart_data.get('urls', [])
|
||||
name=chart_data["name"],
|
||||
category=chart_data["category"],
|
||||
description=chart_data.get("description", ""),
|
||||
urls=chart_data.get("urls", []),
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(self._charts)} charts from {self.CHARTS_FILE}")
|
||||
@ -81,20 +83,17 @@ class ChartService:
|
||||
def _save_charts(self) -> None:
|
||||
"""Save charts to JSON file."""
|
||||
try:
|
||||
# Ensure data directory exists
|
||||
# Ensure storage directory exists
|
||||
self.CHARTS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build data structure
|
||||
data = {
|
||||
'charts': {
|
||||
key: chart.to_dict()
|
||||
for key, chart in self._charts.items()
|
||||
},
|
||||
'categories': self._categories
|
||||
"charts": {key: chart.to_dict() for key, chart in self._charts.items()},
|
||||
"categories": self._categories,
|
||||
}
|
||||
|
||||
# Write to file
|
||||
with open(self.CHARTS_FILE, 'w') as f:
|
||||
with open(self.CHARTS_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
logger.info(f"Saved {len(self._charts)} charts to {self.CHARTS_FILE}")
|
||||
@ -134,10 +133,7 @@ class ChartService:
|
||||
Returns:
|
||||
List of charts in the specified category
|
||||
"""
|
||||
return [
|
||||
chart for chart in self._charts.values()
|
||||
if chart.category == category
|
||||
]
|
||||
return [chart for chart in self._charts.values() if chart.category == category]
|
||||
|
||||
def get_chart_keys(self) -> List[str]:
|
||||
"""
|
||||
@ -157,8 +153,9 @@ class ChartService:
|
||||
"""
|
||||
return self._categories.copy()
|
||||
|
||||
def add_chart(self, key: str, name: str, category: str,
|
||||
urls: List[str], description: str = "") -> None:
|
||||
def add_chart(
|
||||
self, key: str, name: str, category: str, urls: List[str], description: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Add a new chart.
|
||||
|
||||
@ -176,18 +173,19 @@ class ChartService:
|
||||
raise BotException(f"Chart '{key}' already exists")
|
||||
|
||||
self._charts[key] = Chart(
|
||||
key=key,
|
||||
name=name,
|
||||
category=category,
|
||||
description=description,
|
||||
urls=urls
|
||||
key=key, name=name, category=category, description=description, urls=urls
|
||||
)
|
||||
self._save_charts()
|
||||
logger.info(f"Added chart: {key}")
|
||||
|
||||
def update_chart(self, key: str, name: Optional[str] = None,
|
||||
category: Optional[str] = None, urls: Optional[List[str]] = None,
|
||||
description: Optional[str] = None) -> None:
|
||||
def update_chart(
|
||||
self,
|
||||
key: str,
|
||||
name: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
urls: Optional[List[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update an existing chart.
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Custom Commands Service for Discord Bot v2.0
|
||||
Modern async service layer for managing custom commands with full type safety.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional, List, Any, Tuple
|
||||
@ -119,8 +120,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
content_length=len(content),
|
||||
)
|
||||
|
||||
# Return full command with creator info
|
||||
return await self.get_command_by_name(name)
|
||||
# Return command with creator info (use POST response directly)
|
||||
return result.model_copy(update={"creator": creator})
|
||||
|
||||
async def get_command_by_name(self, name: str) -> CustomCommand:
|
||||
"""
|
||||
@ -217,7 +218,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
new_content_length=len(new_content),
|
||||
)
|
||||
|
||||
return await self.get_command_by_name(name)
|
||||
# Return updated command with creator info (use PUT response directly)
|
||||
return result.model_copy(update={"creator": command.creator})
|
||||
|
||||
async def delete_command(
|
||||
self, name: str, deleter_discord_id: int, force: bool = False
|
||||
@ -466,21 +468,28 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
|
||||
commands_data = await self.get_items_with_params(params)
|
||||
|
||||
creators = await asyncio.gather(
|
||||
*[
|
||||
self.get_creator_by_id(cmd_data.creator_id)
|
||||
for cmd_data in commands_data
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
commands = []
|
||||
for cmd_data in commands_data:
|
||||
try:
|
||||
creator = await self.get_creator_by_id(cmd_data.creator_id)
|
||||
commands.append(CustomCommand(**cmd_data.model_dump(), creator=creator))
|
||||
except BotException as e:
|
||||
# Handle missing creator gracefully
|
||||
for cmd_data, creator in zip(commands_data, creators):
|
||||
if isinstance(creator, BotException):
|
||||
self.logger.warning(
|
||||
"Skipping popular command with missing creator",
|
||||
command_id=cmd_data.id,
|
||||
command_name=cmd_data.name,
|
||||
creator_id=cmd_data.creator_id,
|
||||
error=str(e),
|
||||
error=str(creator),
|
||||
)
|
||||
continue
|
||||
if isinstance(creator, BaseException):
|
||||
raise creator
|
||||
commands.append(CustomCommand(**cmd_data.model_dump(), creator=creator))
|
||||
|
||||
return commands
|
||||
|
||||
@ -536,7 +545,9 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
# Update username if it changed
|
||||
if creator.username != username or creator.display_name != display_name:
|
||||
await self._update_creator_info(creator.id, username, display_name)
|
||||
creator = await self.get_creator_by_discord_id(discord_id)
|
||||
creator = creator.model_copy(
|
||||
update={"username": username, "display_name": display_name}
|
||||
)
|
||||
return creator
|
||||
except BotException:
|
||||
# Creator doesn't exist, create new one
|
||||
@ -557,7 +568,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
if not result:
|
||||
raise BotException("Failed to create command creator")
|
||||
|
||||
return await self.get_creator_by_discord_id(discord_id)
|
||||
# Return created creator directly from POST response
|
||||
return CustomCommandCreator(**result)
|
||||
|
||||
async def get_creator_by_discord_id(self, discord_id: int) -> CustomCommandCreator:
|
||||
"""Get creator by Discord ID.
|
||||
@ -610,31 +622,34 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
|
||||
async def get_statistics(self) -> CustomCommandStats:
|
||||
"""Get comprehensive statistics about custom commands."""
|
||||
# Get basic counts
|
||||
total_commands = await self._get_search_count([])
|
||||
active_commands = await self._get_search_count([("is_active", True)])
|
||||
total_creators = await self._get_creator_count()
|
||||
|
||||
# Get total uses
|
||||
all_commands = await self.get_items_with_params([("is_active", True)])
|
||||
total_uses = sum(cmd.use_count for cmd in all_commands)
|
||||
|
||||
# Get most popular command
|
||||
popular_commands = await self.get_popular_commands(limit=1)
|
||||
most_popular = popular_commands[0] if popular_commands else None
|
||||
|
||||
# Get most active creator
|
||||
most_active_creator = await self._get_most_active_creator()
|
||||
|
||||
# Get recent commands count
|
||||
week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
recent_count = await self._get_search_count(
|
||||
[("created_at__gte", week_ago.isoformat()), ("is_active", True)]
|
||||
|
||||
(
|
||||
total_commands,
|
||||
active_commands,
|
||||
total_creators,
|
||||
all_commands,
|
||||
popular_commands,
|
||||
most_active_creator,
|
||||
recent_count,
|
||||
warning_count,
|
||||
deletion_count,
|
||||
) = await asyncio.gather(
|
||||
self._get_search_count([]),
|
||||
self._get_search_count([("is_active", True)]),
|
||||
self._get_creator_count(),
|
||||
self.get_items_with_params([("is_active", True)]),
|
||||
self.get_popular_commands(limit=1),
|
||||
self._get_most_active_creator(),
|
||||
self._get_search_count(
|
||||
[("created_at__gte", week_ago.isoformat()), ("is_active", True)]
|
||||
),
|
||||
self._get_commands_needing_warning_count(),
|
||||
self._get_commands_eligible_for_deletion_count(),
|
||||
)
|
||||
|
||||
# Get cleanup statistics
|
||||
warning_count = await self._get_commands_needing_warning_count()
|
||||
deletion_count = await self._get_commands_eligible_for_deletion_count()
|
||||
total_uses = sum(cmd.use_count for cmd in all_commands)
|
||||
most_popular = popular_commands[0] if popular_commands else None
|
||||
|
||||
return CustomCommandStats(
|
||||
total_commands=total_commands,
|
||||
@ -662,21 +677,28 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
|
||||
commands_data = await self.get_items_with_params(params)
|
||||
|
||||
creators = await asyncio.gather(
|
||||
*[
|
||||
self.get_creator_by_id(cmd_data.creator_id)
|
||||
for cmd_data in commands_data
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
commands = []
|
||||
for cmd_data in commands_data:
|
||||
try:
|
||||
creator = await self.get_creator_by_id(cmd_data.creator_id)
|
||||
commands.append(CustomCommand(**cmd_data.model_dump(), creator=creator))
|
||||
except BotException as e:
|
||||
# Handle missing creator gracefully
|
||||
for cmd_data, creator in zip(commands_data, creators):
|
||||
if isinstance(creator, BotException):
|
||||
self.logger.warning(
|
||||
"Skipping command with missing creator",
|
||||
command_id=cmd_data.id,
|
||||
command_name=cmd_data.name,
|
||||
creator_id=cmd_data.creator_id,
|
||||
error=str(e),
|
||||
error=str(creator),
|
||||
)
|
||||
continue
|
||||
if isinstance(creator, BaseException):
|
||||
raise creator
|
||||
commands.append(CustomCommand(**cmd_data.model_dump(), creator=creator))
|
||||
|
||||
return commands
|
||||
|
||||
@ -688,21 +710,28 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
||||
|
||||
commands_data = await self.get_items_with_params(params)
|
||||
|
||||
creators = await asyncio.gather(
|
||||
*[
|
||||
self.get_creator_by_id(cmd_data.creator_id)
|
||||
for cmd_data in commands_data
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
commands = []
|
||||
for cmd_data in commands_data:
|
||||
try:
|
||||
creator = await self.get_creator_by_id(cmd_data.creator_id)
|
||||
commands.append(CustomCommand(**cmd_data.model_dump(), creator=creator))
|
||||
except BotException as e:
|
||||
# Handle missing creator gracefully
|
||||
for cmd_data, creator in zip(commands_data, creators):
|
||||
if isinstance(creator, BotException):
|
||||
self.logger.warning(
|
||||
"Skipping command with missing creator",
|
||||
command_id=cmd_data.id,
|
||||
command_name=cmd_data.name,
|
||||
creator_id=cmd_data.creator_id,
|
||||
error=str(e),
|
||||
error=str(creator),
|
||||
)
|
||||
continue
|
||||
if isinstance(creator, BaseException):
|
||||
raise creator
|
||||
commands.append(CustomCommand(**cmd_data.model_dump(), creator=creator))
|
||||
|
||||
return commands
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Decision Service
|
||||
Manages pitching decision operations for game submission.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
|
||||
from utils.logging import get_contextual_logger
|
||||
@ -124,22 +125,19 @@ class DecisionService:
|
||||
if int(decision.get("b_save", 0)) == 1:
|
||||
bsv_ids.append(pitcher_id)
|
||||
|
||||
# Second pass: Fetch Player objects
|
||||
wp = await player_service.get_player(wp_id) if wp_id else None
|
||||
lp = await player_service.get_player(lp_id) if lp_id else None
|
||||
sv = await player_service.get_player(sv_id) if sv_id else None
|
||||
# Second pass: Fetch all Player objects in parallel
|
||||
# Order: [wp_id, lp_id, sv_id, *hold_ids, *bsv_ids]; None IDs resolve immediately
|
||||
ordered_ids = [wp_id, lp_id, sv_id] + hold_ids + bsv_ids
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
player_service.get_player(pid) if pid else asyncio.sleep(0, result=None)
|
||||
for pid in ordered_ids
|
||||
]
|
||||
)
|
||||
|
||||
holders = []
|
||||
for hold_id in hold_ids:
|
||||
holder = await player_service.get_player(hold_id)
|
||||
if holder:
|
||||
holders.append(holder)
|
||||
|
||||
blown_saves = []
|
||||
for bsv_id in bsv_ids:
|
||||
bsv = await player_service.get_player(bsv_id)
|
||||
if bsv:
|
||||
blown_saves.append(bsv)
|
||||
wp, lp, sv = results[0], results[1], results[2]
|
||||
holders = [p for p in results[3 : 3 + len(hold_ids)] if p]
|
||||
blown_saves = [p for p in results[3 + len(hold_ids) :] if p]
|
||||
|
||||
return wp, lp, sv, holders, blown_saves
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ Modern async service layer for managing help commands with full type safety.
|
||||
Allows admins and help editors to create custom help topics for league documentation,
|
||||
resources, FAQs, links, and guides.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from utils.logging import get_contextual_logger
|
||||
|
||||
@ -12,7 +13,7 @@ from models.help_command import (
|
||||
HelpCommand,
|
||||
HelpCommandSearchFilters,
|
||||
HelpCommandSearchResult,
|
||||
HelpCommandStats
|
||||
HelpCommandStats,
|
||||
)
|
||||
from services.base_service import BaseService
|
||||
from exceptions import BotException
|
||||
@ -20,16 +21,19 @@ from exceptions import BotException
|
||||
|
||||
class HelpCommandNotFoundError(BotException):
|
||||
"""Raised when a help command is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class HelpCommandExistsError(BotException):
|
||||
"""Raised when trying to create a help command that already exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class HelpCommandPermissionError(BotException):
|
||||
"""Raised when user lacks permission for help command operation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -37,8 +41,8 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
"""Service for managing help commands."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(HelpCommand, 'help_commands')
|
||||
self.logger = get_contextual_logger(f'{__name__}.HelpCommandsService')
|
||||
super().__init__(HelpCommand, "help_commands")
|
||||
self.logger = get_contextual_logger(f"{__name__}.HelpCommandsService")
|
||||
self.logger.info("HelpCommandsService initialized")
|
||||
|
||||
# === Command CRUD Operations ===
|
||||
@ -50,7 +54,7 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
content: str,
|
||||
creator_discord_id: str,
|
||||
category: Optional[str] = None,
|
||||
display_order: int = 0
|
||||
display_order: int = 0,
|
||||
) -> HelpCommand:
|
||||
"""
|
||||
Create a new help command.
|
||||
@ -80,14 +84,16 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
|
||||
# Create help command data
|
||||
help_data = {
|
||||
'name': name.lower().strip(),
|
||||
'title': title.strip(),
|
||||
'content': content.strip(),
|
||||
'category': category.lower().strip() if category else None,
|
||||
'created_by_discord_id': str(creator_discord_id), # Convert to string for safe storage
|
||||
'display_order': display_order,
|
||||
'is_active': True,
|
||||
'view_count': 0
|
||||
"name": name.lower().strip(),
|
||||
"title": title.strip(),
|
||||
"content": content.strip(),
|
||||
"category": category.lower().strip() if category else None,
|
||||
"created_by_discord_id": str(
|
||||
creator_discord_id
|
||||
), # Convert to string for safe storage
|
||||
"display_order": display_order,
|
||||
"is_active": True,
|
||||
"view_count": 0,
|
||||
}
|
||||
|
||||
# Create via API
|
||||
@ -95,18 +101,18 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
if not result:
|
||||
raise BotException("Failed to create help command")
|
||||
|
||||
self.logger.info("Help command created",
|
||||
help_name=name,
|
||||
creator_id=creator_discord_id,
|
||||
category=category)
|
||||
self.logger.info(
|
||||
"Help command created",
|
||||
help_name=name,
|
||||
creator_id=creator_discord_id,
|
||||
category=category,
|
||||
)
|
||||
|
||||
# Return full help command
|
||||
return await self.get_help_by_name(name)
|
||||
# Return help command directly from POST response
|
||||
return result
|
||||
|
||||
async def get_help_by_name(
|
||||
self,
|
||||
name: str,
|
||||
include_inactive: bool = False
|
||||
self, name: str, include_inactive: bool = False
|
||||
) -> HelpCommand:
|
||||
"""
|
||||
Get a help command by name.
|
||||
@ -126,8 +132,12 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
try:
|
||||
# Use the dedicated by_name endpoint for exact lookup
|
||||
client = await self.get_client()
|
||||
params = [('include_inactive', include_inactive)] if include_inactive else []
|
||||
data = await client.get(f'help_commands/by_name/{normalized_name}', params=params)
|
||||
params = (
|
||||
[("include_inactive", include_inactive)] if include_inactive else []
|
||||
)
|
||||
data = await client.get(
|
||||
f"help_commands/by_name/{normalized_name}", params=params
|
||||
)
|
||||
|
||||
if not data:
|
||||
raise HelpCommandNotFoundError(f"Help topic '{name}' not found")
|
||||
@ -139,9 +149,9 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
if "404" in str(e) or "not found" in str(e).lower():
|
||||
raise HelpCommandNotFoundError(f"Help topic '{name}' not found")
|
||||
else:
|
||||
self.logger.error("Failed to get help command by name",
|
||||
help_name=name,
|
||||
error=e)
|
||||
self.logger.error(
|
||||
"Failed to get help command by name", help_name=name, error=e
|
||||
)
|
||||
raise BotException(f"Failed to retrieve help topic '{name}': {e}")
|
||||
|
||||
async def update_help(
|
||||
@ -151,7 +161,7 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
new_content: Optional[str] = None,
|
||||
updater_discord_id: Optional[str] = None,
|
||||
new_category: Optional[str] = None,
|
||||
new_display_order: Optional[int] = None
|
||||
new_display_order: Optional[int] = None,
|
||||
) -> HelpCommand:
|
||||
"""
|
||||
Update an existing help command.
|
||||
@ -176,35 +186,42 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
update_data = {}
|
||||
|
||||
if new_title is not None:
|
||||
update_data['title'] = new_title.strip()
|
||||
update_data["title"] = new_title.strip()
|
||||
|
||||
if new_content is not None:
|
||||
update_data['content'] = new_content.strip()
|
||||
update_data["content"] = new_content.strip()
|
||||
|
||||
if new_category is not None:
|
||||
update_data['category'] = new_category.lower().strip() if new_category else None
|
||||
update_data["category"] = (
|
||||
new_category.lower().strip() if new_category else None
|
||||
)
|
||||
|
||||
if new_display_order is not None:
|
||||
update_data['display_order'] = new_display_order
|
||||
update_data["display_order"] = new_display_order
|
||||
|
||||
if updater_discord_id is not None:
|
||||
update_data['last_modified_by'] = str(updater_discord_id) # Convert to string for safe storage
|
||||
update_data["last_modified_by"] = str(
|
||||
updater_discord_id
|
||||
) # Convert to string for safe storage
|
||||
|
||||
if not update_data:
|
||||
raise BotException("No fields to update")
|
||||
|
||||
# Update via API
|
||||
client = await self.get_client()
|
||||
result = await client.put(f'help_commands/{help_cmd.id}', update_data)
|
||||
result = await client.put(f"help_commands/{help_cmd.id}", update_data)
|
||||
if not result:
|
||||
raise BotException("Failed to update help command")
|
||||
|
||||
self.logger.info("Help command updated",
|
||||
help_name=name,
|
||||
updater_id=updater_discord_id,
|
||||
fields_updated=list(update_data.keys()))
|
||||
self.logger.info(
|
||||
"Help command updated",
|
||||
help_name=name,
|
||||
updater_id=updater_discord_id,
|
||||
fields_updated=list(update_data.keys()),
|
||||
)
|
||||
|
||||
return await self.get_help_by_name(name)
|
||||
# Return updated help command directly from PUT response
|
||||
return self.model_class.from_api_data(result)
|
||||
|
||||
async def delete_help(self, name: str) -> bool:
|
||||
"""
|
||||
@ -223,11 +240,11 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
|
||||
# Soft delete via API
|
||||
client = await self.get_client()
|
||||
await client.delete(f'help_commands/{help_cmd.id}')
|
||||
await client.delete(f"help_commands/{help_cmd.id}")
|
||||
|
||||
self.logger.info("Help command soft deleted",
|
||||
help_name=name,
|
||||
help_id=help_cmd.id)
|
||||
self.logger.info(
|
||||
"Help command soft deleted", help_name=name, help_id=help_cmd.id
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@ -252,13 +269,11 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
|
||||
# Restore via API
|
||||
client = await self.get_client()
|
||||
result = await client.patch(f'help_commands/{help_cmd.id}/restore')
|
||||
result = await client.patch(f"help_commands/{help_cmd.id}/restore")
|
||||
if not result:
|
||||
raise BotException("Failed to restore help command")
|
||||
|
||||
self.logger.info("Help command restored",
|
||||
help_name=name,
|
||||
help_id=help_cmd.id)
|
||||
self.logger.info("Help command restored", help_name=name, help_id=help_cmd.id)
|
||||
|
||||
return self.model_class.from_api_data(result)
|
||||
|
||||
@ -279,10 +294,9 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
|
||||
try:
|
||||
client = await self.get_client()
|
||||
await client.patch(f'help_commands/by_name/{normalized_name}/view')
|
||||
await client.patch(f"help_commands/by_name/{normalized_name}/view")
|
||||
|
||||
self.logger.debug("Help command view count incremented",
|
||||
help_name=name)
|
||||
self.logger.debug("Help command view count incremented", help_name=name)
|
||||
|
||||
# Return updated command
|
||||
return await self.get_help_by_name(name)
|
||||
@ -291,16 +305,15 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
if "404" in str(e) or "not found" in str(e).lower():
|
||||
raise HelpCommandNotFoundError(f"Help topic '{name}' not found")
|
||||
else:
|
||||
self.logger.error("Failed to increment view count",
|
||||
help_name=name,
|
||||
error=e)
|
||||
self.logger.error(
|
||||
"Failed to increment view count", help_name=name, error=e
|
||||
)
|
||||
raise BotException(f"Failed to increment view count for '{name}': {e}")
|
||||
|
||||
# === Search and Listing ===
|
||||
|
||||
async def search_help_commands(
|
||||
self,
|
||||
filters: HelpCommandSearchFilters
|
||||
self, filters: HelpCommandSearchFilters
|
||||
) -> HelpCommandSearchResult:
|
||||
"""
|
||||
Search for help commands with filtering and pagination.
|
||||
@ -316,23 +329,23 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
|
||||
# Apply filters
|
||||
if filters.name_contains:
|
||||
params.append(('name', filters.name_contains)) # API will do ILIKE search
|
||||
params.append(("name", filters.name_contains)) # API will do ILIKE search
|
||||
|
||||
if filters.category:
|
||||
params.append(('category', filters.category))
|
||||
params.append(("category", filters.category))
|
||||
|
||||
params.append(('is_active', filters.is_active))
|
||||
params.append(("is_active", filters.is_active))
|
||||
|
||||
# Add sorting
|
||||
params.append(('sort', filters.sort_by))
|
||||
params.append(("sort", filters.sort_by))
|
||||
|
||||
# Add pagination
|
||||
params.append(('page', filters.page))
|
||||
params.append(('page_size', filters.page_size))
|
||||
params.append(("page", filters.page))
|
||||
params.append(("page_size", filters.page_size))
|
||||
|
||||
# Execute search via API
|
||||
client = await self.get_client()
|
||||
data = await client.get('help_commands', params=params)
|
||||
data = await client.get("help_commands", params=params)
|
||||
|
||||
if not data:
|
||||
return HelpCommandSearchResult(
|
||||
@ -341,14 +354,14 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
page=filters.page,
|
||||
page_size=filters.page_size,
|
||||
total_pages=0,
|
||||
has_more=False
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
# Extract response data
|
||||
help_commands_data = data.get('help_commands', [])
|
||||
total_count = data.get('total_count', 0)
|
||||
total_pages = data.get('total_pages', 0)
|
||||
has_more = data.get('has_more', False)
|
||||
help_commands_data = data.get("help_commands", [])
|
||||
total_count = data.get("total_count", 0)
|
||||
total_pages = data.get("total_pages", 0)
|
||||
has_more = data.get("has_more", False)
|
||||
|
||||
# Convert to HelpCommand objects
|
||||
help_commands = []
|
||||
@ -356,15 +369,21 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
try:
|
||||
help_commands.append(self.model_class.from_api_data(cmd_data))
|
||||
except Exception as e:
|
||||
self.logger.warning("Failed to create HelpCommand from API data",
|
||||
help_id=cmd_data.get('id'),
|
||||
error=e)
|
||||
self.logger.warning(
|
||||
"Failed to create HelpCommand from API data",
|
||||
help_id=cmd_data.get("id"),
|
||||
error=e,
|
||||
)
|
||||
continue
|
||||
|
||||
self.logger.debug("Help commands search completed",
|
||||
total_results=total_count,
|
||||
page=filters.page,
|
||||
filters_applied=len([p for p in params if p[0] not in ['sort', 'page', 'page_size']]))
|
||||
self.logger.debug(
|
||||
"Help commands search completed",
|
||||
total_results=total_count,
|
||||
page=filters.page,
|
||||
filters_applied=len(
|
||||
[p for p in params if p[0] not in ["sort", "page", "page_size"]]
|
||||
),
|
||||
)
|
||||
|
||||
return HelpCommandSearchResult(
|
||||
help_commands=help_commands,
|
||||
@ -372,13 +391,11 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
page=filters.page,
|
||||
page_size=filters.page_size,
|
||||
total_pages=total_pages,
|
||||
has_more=has_more
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def get_all_help_topics(
|
||||
self,
|
||||
category: Optional[str] = None,
|
||||
include_inactive: bool = False
|
||||
self, category: Optional[str] = None, include_inactive: bool = False
|
||||
) -> List[HelpCommand]:
|
||||
"""
|
||||
Get all help topics, optionally filtered by category.
|
||||
@ -393,37 +410,36 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
params = []
|
||||
|
||||
if category:
|
||||
params.append(('category', category))
|
||||
params.append(("category", category))
|
||||
|
||||
params.append(('is_active', not include_inactive))
|
||||
params.append(('sort', 'display_order'))
|
||||
params.append(('page_size', 100)) # Get all
|
||||
params.append(("is_active", not include_inactive))
|
||||
params.append(("sort", "display_order"))
|
||||
params.append(("page_size", 100)) # Get all
|
||||
|
||||
client = await self.get_client()
|
||||
data = await client.get('help_commands', params=params)
|
||||
data = await client.get("help_commands", params=params)
|
||||
|
||||
if not data:
|
||||
return []
|
||||
|
||||
help_commands_data = data.get('help_commands', [])
|
||||
help_commands_data = data.get("help_commands", [])
|
||||
|
||||
help_commands = []
|
||||
for cmd_data in help_commands_data:
|
||||
try:
|
||||
help_commands.append(self.model_class.from_api_data(cmd_data))
|
||||
except Exception as e:
|
||||
self.logger.warning("Failed to create HelpCommand from API data",
|
||||
help_id=cmd_data.get('id'),
|
||||
error=e)
|
||||
self.logger.warning(
|
||||
"Failed to create HelpCommand from API data",
|
||||
help_id=cmd_data.get("id"),
|
||||
error=e,
|
||||
)
|
||||
continue
|
||||
|
||||
return help_commands
|
||||
|
||||
async def get_help_names_for_autocomplete(
|
||||
self,
|
||||
partial_name: str = "",
|
||||
limit: int = 25,
|
||||
include_inactive: bool = False
|
||||
self, partial_name: str = "", limit: int = 25, include_inactive: bool = False
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get help command names for Discord autocomplete.
|
||||
@ -439,25 +455,28 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
try:
|
||||
# Use the dedicated autocomplete endpoint
|
||||
client = await self.get_client()
|
||||
params = [('limit', limit)]
|
||||
params = [("limit", limit)]
|
||||
|
||||
if partial_name:
|
||||
params.append(('q', partial_name.lower()))
|
||||
params.append(("q", partial_name.lower()))
|
||||
|
||||
result = await client.get('help_commands/autocomplete', params=params)
|
||||
result = await client.get("help_commands/autocomplete", params=params)
|
||||
|
||||
# The autocomplete endpoint returns results with name, title, category
|
||||
if isinstance(result, dict) and 'results' in result:
|
||||
return [item['name'] for item in result['results']]
|
||||
if isinstance(result, dict) and "results" in result:
|
||||
return [item["name"] for item in result["results"]]
|
||||
else:
|
||||
self.logger.warning("Unexpected autocomplete response format",
|
||||
response=result)
|
||||
self.logger.warning(
|
||||
"Unexpected autocomplete response format", response=result
|
||||
)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to get help names for autocomplete",
|
||||
partial_name=partial_name,
|
||||
error=e)
|
||||
self.logger.error(
|
||||
"Failed to get help names for autocomplete",
|
||||
partial_name=partial_name,
|
||||
error=e,
|
||||
)
|
||||
# Return empty list on error to not break Discord autocomplete
|
||||
return []
|
||||
|
||||
@ -467,7 +486,7 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
"""Get comprehensive statistics about help commands."""
|
||||
try:
|
||||
client = await self.get_client()
|
||||
data = await client.get('help_commands/stats')
|
||||
data = await client.get("help_commands/stats")
|
||||
|
||||
if not data:
|
||||
return HelpCommandStats(
|
||||
@ -475,23 +494,25 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
active_commands=0,
|
||||
total_views=0,
|
||||
most_viewed_command=None,
|
||||
recent_commands_count=0
|
||||
recent_commands_count=0,
|
||||
)
|
||||
|
||||
# Convert most_viewed_command if present
|
||||
most_viewed = None
|
||||
if data.get('most_viewed_command'):
|
||||
if data.get("most_viewed_command"):
|
||||
try:
|
||||
most_viewed = self.model_class.from_api_data(data['most_viewed_command'])
|
||||
most_viewed = self.model_class.from_api_data(
|
||||
data["most_viewed_command"]
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning("Failed to parse most viewed command", error=e)
|
||||
|
||||
return HelpCommandStats(
|
||||
total_commands=data.get('total_commands', 0),
|
||||
active_commands=data.get('active_commands', 0),
|
||||
total_views=data.get('total_views', 0),
|
||||
total_commands=data.get("total_commands", 0),
|
||||
active_commands=data.get("active_commands", 0),
|
||||
total_views=data.get("total_views", 0),
|
||||
most_viewed_command=most_viewed,
|
||||
recent_commands_count=data.get('recent_commands_count', 0)
|
||||
recent_commands_count=data.get("recent_commands_count", 0),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -502,7 +523,7 @@ class HelpCommandsService(BaseService[HelpCommand]):
|
||||
active_commands=0,
|
||||
total_views=0,
|
||||
most_viewed_command=None,
|
||||
recent_commands_count=0
|
||||
recent_commands_count=0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Schedule service for Discord Bot v2.0
|
||||
Handles game schedule and results retrieval and processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
|
||||
@ -102,10 +103,10 @@ class ScheduleService:
|
||||
# If weeks not specified, try a reasonable range (18 weeks typical)
|
||||
week_range = range(1, (weeks + 1) if weeks else 19)
|
||||
|
||||
for week in week_range:
|
||||
week_games = await self.get_week_schedule(season, week)
|
||||
|
||||
# Filter games involving this team
|
||||
all_week_games = await asyncio.gather(
|
||||
*[self.get_week_schedule(season, week) for week in week_range]
|
||||
)
|
||||
for week_games in all_week_games:
|
||||
for game in week_games:
|
||||
if (
|
||||
game.away_team.abbrev.upper() == team_abbrev_upper
|
||||
@ -135,15 +136,13 @@ class ScheduleService:
|
||||
recent_games = []
|
||||
|
||||
# Get games from recent weeks
|
||||
for week_offset in range(weeks_back):
|
||||
# This is simplified - in production you'd want to determine current week
|
||||
week = 10 - week_offset # Assuming we're around week 10
|
||||
if week <= 0:
|
||||
break
|
||||
|
||||
week_games = await self.get_week_schedule(season, week)
|
||||
|
||||
# Only include completed games
|
||||
weeks_to_fetch = [
|
||||
(10 - offset) for offset in range(weeks_back) if (10 - offset) > 0
|
||||
]
|
||||
all_week_games = await asyncio.gather(
|
||||
*[self.get_week_schedule(season, week) for week in weeks_to_fetch]
|
||||
)
|
||||
for week_games in all_week_games:
|
||||
completed_games = [game for game in week_games if game.is_completed]
|
||||
recent_games.extend(completed_games)
|
||||
|
||||
@ -157,13 +156,12 @@ class ScheduleService:
|
||||
logger.error(f"Error getting recent games: {e}")
|
||||
return []
|
||||
|
||||
async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]:
|
||||
async def get_upcoming_games(self, season: int) -> List[Game]:
|
||||
"""
|
||||
Get upcoming scheduled games by scanning multiple weeks.
|
||||
Get upcoming scheduled games by scanning all weeks.
|
||||
|
||||
Args:
|
||||
season: Season number
|
||||
weeks_ahead: Number of weeks to scan ahead (default 6)
|
||||
|
||||
Returns:
|
||||
List of upcoming Game instances
|
||||
@ -171,20 +169,16 @@ class ScheduleService:
|
||||
try:
|
||||
upcoming_games = []
|
||||
|
||||
# Scan through weeks to find games without scores
|
||||
for week in range(1, 19): # Standard season length
|
||||
week_games = await self.get_week_schedule(season, week)
|
||||
|
||||
# Find games without scores (not yet played)
|
||||
# Fetch all weeks in parallel and filter for incomplete games
|
||||
all_week_games = await asyncio.gather(
|
||||
*[self.get_week_schedule(season, week) for week in range(1, 19)]
|
||||
)
|
||||
for week_games in all_week_games:
|
||||
upcoming_games_week = [
|
||||
game for game in week_games if not game.is_completed
|
||||
]
|
||||
upcoming_games.extend(upcoming_games_week)
|
||||
|
||||
# If we found upcoming games, we can limit how many more weeks to check
|
||||
if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit
|
||||
break
|
||||
|
||||
# Sort by week, then game number
|
||||
upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0))
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0
|
||||
Handles batting and pitching statistics retrieval and processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@ -144,11 +145,10 @@ class StatsService:
|
||||
"""
|
||||
try:
|
||||
# Get both types of stats concurrently
|
||||
batting_task = self.get_batting_stats(player_id, season)
|
||||
pitching_task = self.get_pitching_stats(player_id, season)
|
||||
|
||||
batting_stats = await batting_task
|
||||
pitching_stats = await pitching_task
|
||||
batting_stats, pitching_stats = await asyncio.gather(
|
||||
self.get_batting_stats(player_id, season),
|
||||
self.get_pitching_stats(player_id, season),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved stats for player {player_id}: "
|
||||
|
||||
@ -4,6 +4,7 @@ Trade Builder Service
|
||||
Extends the TransactionBuilder to support multi-team trades and player exchanges.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime, timezone
|
||||
@ -524,14 +525,22 @@ class TradeBuilder:
|
||||
|
||||
# Validate each team's roster after the trade
|
||||
for participant in self.trade.participants:
|
||||
team_id = participant.team.id
|
||||
result.team_abbrevs[team_id] = participant.team.abbrev
|
||||
if team_id in self._team_builders:
|
||||
builder = self._team_builders[team_id]
|
||||
roster_validation = await builder.validate_transaction(next_week)
|
||||
result.team_abbrevs[participant.team.id] = participant.team.abbrev
|
||||
|
||||
team_ids_to_validate = [
|
||||
participant.team.id
|
||||
for participant in self.trade.participants
|
||||
if participant.team.id in self._team_builders
|
||||
]
|
||||
if team_ids_to_validate:
|
||||
validations = await asyncio.gather(
|
||||
*[
|
||||
self._team_builders[tid].validate_transaction(next_week)
|
||||
for tid in team_ids_to_validate
|
||||
]
|
||||
)
|
||||
for team_id, roster_validation in zip(team_ids_to_validate, validations):
|
||||
result.participant_validations[team_id] = roster_validation
|
||||
|
||||
if not roster_validation.is_legal:
|
||||
result.is_legal = False
|
||||
|
||||
|
||||
@ -277,6 +277,35 @@ class TransactionBuilder:
|
||||
Returns:
|
||||
Tuple of (success: bool, error_message: str). If success is True, error_message is empty.
|
||||
"""
|
||||
# Fetch current state once if needed by FA lock or pending-transaction check
|
||||
is_fa_pickup = (
|
||||
move.from_roster == RosterType.FREE_AGENCY
|
||||
and move.to_roster != RosterType.FREE_AGENCY
|
||||
)
|
||||
needs_current_state = is_fa_pickup or (
|
||||
check_pending_transactions and next_week is None
|
||||
)
|
||||
|
||||
current_week: Optional[int] = None
|
||||
if needs_current_state:
|
||||
try:
|
||||
current_state = await league_service.get_current_state()
|
||||
current_week = current_state.week if current_state else 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get current week: {e}")
|
||||
current_week = 1
|
||||
|
||||
# Block adding players FROM Free Agency after the FA lock deadline
|
||||
if is_fa_pickup and current_week is not None:
|
||||
config = get_config()
|
||||
if current_week >= config.fa_lock_week:
|
||||
error_msg = (
|
||||
f"Free agency is closed (week {current_week}, deadline was week {config.fa_lock_week}). "
|
||||
f"Cannot add {move.player.name} from FA."
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# Check if player is already in a move in this transaction builder
|
||||
existing_move = self.get_move_for_player(move.player.id)
|
||||
if existing_move:
|
||||
@ -299,23 +328,15 @@ class TransactionBuilder:
|
||||
return False, error_msg
|
||||
|
||||
# Check if player is already in another team's pending transaction for next week
|
||||
# This prevents duplicate claims that would need to be resolved at freeze time
|
||||
# Only applies to /dropadd (scheduled moves), not /ilmove (immediate moves)
|
||||
if check_pending_transactions:
|
||||
if next_week is None:
|
||||
try:
|
||||
current_state = await league_service.get_current_state()
|
||||
next_week = (current_state.week + 1) if current_state else 1
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not get current week for pending transaction check: {e}"
|
||||
)
|
||||
next_week = 1
|
||||
next_week = (current_week + 1) if current_week else 1
|
||||
|
||||
is_pending, claiming_team = (
|
||||
await transaction_service.is_player_in_pending_transaction(
|
||||
player_id=move.player.id, week=next_week, season=self.season
|
||||
)
|
||||
(
|
||||
is_pending,
|
||||
claiming_team,
|
||||
) = await transaction_service.is_player_in_pending_transaction(
|
||||
player_id=move.player.id, week=next_week, season=self.season
|
||||
)
|
||||
|
||||
if is_pending:
|
||||
|
||||
@ -95,7 +95,7 @@ class LiveScorebugTracker:
|
||||
# Don't return - still update voice channels
|
||||
else:
|
||||
# Get all published scorecards
|
||||
all_scorecards = self.scorecard_tracker.get_all_scorecards()
|
||||
all_scorecards = await self.scorecard_tracker.get_all_scorecards()
|
||||
|
||||
if not all_scorecards:
|
||||
# No active scorebugs - clear the channel and hide it
|
||||
@ -112,17 +112,16 @@ class LiveScorebugTracker:
|
||||
for text_channel_id, sheet_url in all_scorecards:
|
||||
try:
|
||||
scorebug_data = await self.scorebug_service.read_scorebug_data(
|
||||
sheet_url, full_length=False # Compact view for live channel
|
||||
sheet_url,
|
||||
full_length=False, # Compact view for live channel
|
||||
)
|
||||
|
||||
# Only include active (non-final) games
|
||||
if scorebug_data.is_active:
|
||||
# Get team data
|
||||
away_team = await team_service.get_team(
|
||||
scorebug_data.away_team_id
|
||||
)
|
||||
home_team = await team_service.get_team(
|
||||
scorebug_data.home_team_id
|
||||
away_team, home_team = await asyncio.gather(
|
||||
team_service.get_team(scorebug_data.away_team_id),
|
||||
team_service.get_team(scorebug_data.home_team_id),
|
||||
)
|
||||
|
||||
if away_team is None or home_team is None:
|
||||
@ -188,9 +187,8 @@ class LiveScorebugTracker:
|
||||
embeds: List of scorebug embeds
|
||||
"""
|
||||
try:
|
||||
# Clear old messages
|
||||
async for message in channel.history(limit=25):
|
||||
await message.delete()
|
||||
# Clear old messages using bulk delete
|
||||
await channel.purge(limit=25)
|
||||
|
||||
# Post new scorebugs (Discord allows up to 10 embeds per message)
|
||||
if len(embeds) <= 10:
|
||||
@ -216,9 +214,8 @@ class LiveScorebugTracker:
|
||||
channel: Discord text channel
|
||||
"""
|
||||
try:
|
||||
# Clear all messages
|
||||
async for message in channel.history(limit=25):
|
||||
await message.delete()
|
||||
# Clear all messages using bulk delete
|
||||
await channel.purge(limit=25)
|
||||
|
||||
self.logger.info("Cleared live-sba-scores channel (no active games)")
|
||||
|
||||
|
||||
282
tests/test_bot_maintenance_tree.py
Normal file
282
tests/test_bot_maintenance_tree.py
Normal file
@ -0,0 +1,282 @@
|
||||
"""
|
||||
Tests for MaintenanceAwareTree and the maintenance_mode attribute on SBABot.
|
||||
|
||||
What:
|
||||
Verifies that the CommandTree subclass correctly gates interactions behind
|
||||
bot.maintenance_mode. When maintenance mode is off every interaction is
|
||||
allowed through unconditionally. When maintenance mode is on, non-admin
|
||||
users receive an ephemeral error message and the check returns False, while
|
||||
administrators are always allowed through.
|
||||
|
||||
Why:
|
||||
The original code attempted to register an interaction_check via a decorator
|
||||
on self.tree inside setup_hook. That is not a valid pattern in discord.py —
|
||||
interaction_check is an overridable async method on CommandTree, not a
|
||||
decorator. The broken assignment caused a RuntimeWarning (unawaited
|
||||
coroutine) and silently made maintenance mode a no-op. These tests confirm
|
||||
the correct subclass-based implementation behaves as specified.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import discord
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bot(maintenance_mode: bool = False) -> MagicMock:
|
||||
"""Return a minimal mock bot with a maintenance_mode attribute."""
|
||||
bot = MagicMock()
|
||||
bot.maintenance_mode = maintenance_mode
|
||||
return bot
|
||||
|
||||
|
||||
def _make_interaction(is_admin: bool, bot: MagicMock) -> AsyncMock:
|
||||
"""
|
||||
Build a mock discord.Interaction.
|
||||
|
||||
The interaction's .client is set to *bot* so that MaintenanceAwareTree
|
||||
can read bot.maintenance_mode via interaction.client, mirroring how
|
||||
discord.py wires things at runtime.
|
||||
"""
|
||||
interaction = AsyncMock(spec=discord.Interaction)
|
||||
interaction.client = bot
|
||||
|
||||
# Mock the user as a guild Member so that guild_permissions is accessible.
|
||||
user = MagicMock(spec=discord.Member)
|
||||
user.guild_permissions = MagicMock()
|
||||
user.guild_permissions.administrator = is_admin
|
||||
interaction.user = user
|
||||
|
||||
# response.send_message must be awaitable.
|
||||
interaction.response = AsyncMock()
|
||||
interaction.response.send_message = AsyncMock()
|
||||
|
||||
return interaction
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import the class under test after mocks are available.
|
||||
# We import here (not at module level) so that the conftest env-vars are set
|
||||
# before any discord_bot_v2 modules are touched.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_discord_app_commands(monkeypatch):
|
||||
"""
|
||||
Prevent MaintenanceAwareTree.__init__ from calling discord internals that
|
||||
need a real event loop / Discord connection. We test only the logic of
|
||||
interaction_check, so we stub out the parent __init__.
|
||||
"""
|
||||
# Nothing extra to patch for the interaction_check itself; the parent
|
||||
# CommandTree.__init__ is only called when constructing SBABot, which we
|
||||
# don't do in these unit tests.
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for MaintenanceAwareTree.interaction_check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaintenanceAwareTree:
|
||||
"""Unit tests for MaintenanceAwareTree.interaction_check."""
|
||||
|
||||
@pytest.fixture
|
||||
def tree_cls(self):
|
||||
"""Import and return the MaintenanceAwareTree class."""
|
||||
from bot import MaintenanceAwareTree
|
||||
|
||||
return MaintenanceAwareTree
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Maintenance OFF
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_off_allows_non_admin(self, tree_cls):
|
||||
"""
|
||||
When maintenance_mode is False, non-admin users are always allowed.
|
||||
The check must return True without sending any message.
|
||||
"""
|
||||
bot = _make_bot(maintenance_mode=False)
|
||||
interaction = _make_interaction(is_admin=False, bot=bot)
|
||||
|
||||
# Instantiate tree without calling parent __init__ by testing the method
|
||||
# directly on an unbound basis.
|
||||
result = await tree_cls.interaction_check(
|
||||
MagicMock(), # placeholder 'self' for the tree instance
|
||||
interaction,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
interaction.response.send_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_off_allows_admin(self, tree_cls):
|
||||
"""
|
||||
When maintenance_mode is False, admin users are also always allowed.
|
||||
"""
|
||||
bot = _make_bot(maintenance_mode=False)
|
||||
interaction = _make_interaction(is_admin=True, bot=bot)
|
||||
|
||||
result = await tree_cls.interaction_check(MagicMock(), interaction)
|
||||
|
||||
assert result is True
|
||||
interaction.response.send_message.assert_not_called()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Maintenance ON — non-admin
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_on_blocks_non_admin(self, tree_cls):
|
||||
"""
|
||||
When maintenance_mode is True, non-admin users must be blocked.
|
||||
The check must return False and send an ephemeral message.
|
||||
"""
|
||||
bot = _make_bot(maintenance_mode=True)
|
||||
interaction = _make_interaction(is_admin=False, bot=bot)
|
||||
|
||||
result = await tree_cls.interaction_check(MagicMock(), interaction)
|
||||
|
||||
assert result is False
|
||||
interaction.response.send_message.assert_called_once()
|
||||
|
||||
# Confirm the call used ephemeral=True
|
||||
_, kwargs = interaction.response.send_message.call_args
|
||||
assert kwargs.get("ephemeral") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_on_message_has_no_emoji(self, tree_cls):
|
||||
"""
|
||||
The maintenance block message must not contain emoji characters.
|
||||
The project style deliberately strips emoji from user-facing strings.
|
||||
"""
|
||||
import unicodedata
|
||||
|
||||
bot = _make_bot(maintenance_mode=True)
|
||||
interaction = _make_interaction(is_admin=False, bot=bot)
|
||||
|
||||
await tree_cls.interaction_check(MagicMock(), interaction)
|
||||
|
||||
args, _ = interaction.response.send_message.call_args
|
||||
message_text = args[0] if args else ""
|
||||
|
||||
for ch in message_text:
|
||||
category = unicodedata.category(ch)
|
||||
assert category != "So", (
|
||||
f"Unexpected emoji/symbol character {ch!r} (category {category!r}) "
|
||||
f"found in maintenance message: {message_text!r}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Maintenance ON — admin
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_on_allows_admin(self, tree_cls):
|
||||
"""
|
||||
When maintenance_mode is True, administrator users must still be
|
||||
allowed through. Admins should never be locked out of bot commands.
|
||||
"""
|
||||
bot = _make_bot(maintenance_mode=True)
|
||||
interaction = _make_interaction(is_admin=True, bot=bot)
|
||||
|
||||
result = await tree_cls.interaction_check(MagicMock(), interaction)
|
||||
|
||||
assert result is True
|
||||
interaction.response.send_message.assert_not_called()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Edge case: non-Member user during maintenance
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_on_blocks_non_member_user(self, tree_cls):
|
||||
"""
|
||||
When maintenance_mode is True and the user is not a guild Member
|
||||
(e.g. interaction from a DM context), the check must still block them
|
||||
because we cannot verify administrator status.
|
||||
"""
|
||||
bot = _make_bot(maintenance_mode=True)
|
||||
interaction = AsyncMock(spec=discord.Interaction)
|
||||
interaction.client = bot
|
||||
|
||||
# Simulate a non-Member user (e.g. discord.User from DM context)
|
||||
user = MagicMock(spec=discord.User)
|
||||
# discord.User has no guild_permissions attribute
|
||||
interaction.user = user
|
||||
interaction.response = AsyncMock()
|
||||
interaction.response.send_message = AsyncMock()
|
||||
|
||||
result = await tree_cls.interaction_check(MagicMock(), interaction)
|
||||
|
||||
assert result is False
|
||||
interaction.response.send_message.assert_called_once()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Missing attribute safety: bot without maintenance_mode attr
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_maintenance_mode_attr_defaults_to_allowed(self, tree_cls):
|
||||
"""
|
||||
If the bot object does not have a maintenance_mode attribute (e.g.
|
||||
during testing or unusual startup), getattr fallback must treat it as
|
||||
False and allow the interaction.
|
||||
"""
|
||||
bot = MagicMock()
|
||||
# Deliberately do NOT set bot.maintenance_mode
|
||||
del bot.maintenance_mode
|
||||
|
||||
interaction = _make_interaction(is_admin=False, bot=bot)
|
||||
|
||||
result = await tree_cls.interaction_check(MagicMock(), interaction)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for SBABot.maintenance_mode attribute
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSBABotMaintenanceModeAttribute:
|
||||
"""
|
||||
Confirms that SBABot.__init__ always sets maintenance_mode = False.
|
||||
|
||||
We avoid constructing a real SBABot (which requires a Discord event loop
|
||||
and valid token infrastructure) by patching the entire commands.Bot.__init__
|
||||
and then calling SBABot.__init__ directly on a bare instance so that only
|
||||
the SBABot-specific attribute assignments execute.
|
||||
"""
|
||||
|
||||
def test_maintenance_mode_default_is_false(self, monkeypatch):
|
||||
"""
|
||||
SBABot.__init__ must set self.maintenance_mode = False so that the
|
||||
MaintenanceAwareTree has the attribute available from the very first
|
||||
interaction, even before /admin-maintenance is ever called.
|
||||
|
||||
Strategy: patch commands.Bot.__init__ to be a no-op so super().__init__
|
||||
succeeds without a real Discord connection, then call SBABot.__init__
|
||||
and assert the attribute is present with the correct default value.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
from discord.ext import commands
|
||||
from bot import SBABot
|
||||
|
||||
with patch.object(commands.Bot, "__init__", return_value=None):
|
||||
bot = SBABot.__new__(SBABot)
|
||||
SBABot.__init__(bot)
|
||||
|
||||
assert hasattr(
|
||||
bot, "maintenance_mode"
|
||||
), "SBABot must define self.maintenance_mode in __init__"
|
||||
assert (
|
||||
bot.maintenance_mode is False
|
||||
), "SBABot.maintenance_mode must default to False"
|
||||
143
tests/test_commands_trade_deadline.py
Normal file
143
tests/test_commands_trade_deadline.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""
|
||||
Tests for trade deadline enforcement in /trade commands.
|
||||
|
||||
Validates that trades are blocked after the trade deadline and allowed during/before it.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from tests.factories import CurrentFactory, TeamFactory
|
||||
|
||||
|
||||
class TestTradeInitiateDeadlineGuard:
|
||||
"""Test trade deadline enforcement in /trade initiate command."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interaction(self):
|
||||
"""Create mock Discord interaction with deferred response."""
|
||||
interaction = AsyncMock()
|
||||
interaction.user = MagicMock()
|
||||
interaction.user.id = 258104532423147520
|
||||
interaction.response = AsyncMock()
|
||||
interaction.followup = AsyncMock()
|
||||
interaction.guild = MagicMock()
|
||||
interaction.guild.id = 669356687294988350
|
||||
return interaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_initiate_blocked_past_deadline(self, mock_interaction):
|
||||
"""After the trade deadline, /trade initiate should return a friendly error."""
|
||||
user_team = TeamFactory.west_virginia()
|
||||
other_team = TeamFactory.new_york()
|
||||
past_deadline = CurrentFactory.create(week=15, trade_deadline=14)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"commands.transactions.trade.validate_user_has_team",
|
||||
new_callable=AsyncMock,
|
||||
return_value=user_team,
|
||||
),
|
||||
patch(
|
||||
"commands.transactions.trade.get_team_by_abbrev_with_validation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=other_team,
|
||||
),
|
||||
patch("commands.transactions.trade.league_service") as mock_league,
|
||||
):
|
||||
mock_league.get_current_state = AsyncMock(return_value=past_deadline)
|
||||
|
||||
from commands.transactions.trade import TradeCommands
|
||||
|
||||
bot = MagicMock()
|
||||
cog = TradeCommands(bot)
|
||||
await cog.trade_initiate.callback(cog, mock_interaction, "NY")
|
||||
|
||||
mock_interaction.followup.send.assert_called_once()
|
||||
call_kwargs = mock_interaction.followup.send.call_args
|
||||
msg = (
|
||||
call_kwargs[0][0]
|
||||
if call_kwargs[0]
|
||||
else call_kwargs[1].get("content", "")
|
||||
)
|
||||
assert "trade deadline has passed" in msg.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_initiate_allowed_at_deadline_week(self, mock_interaction):
|
||||
"""During the deadline week itself, /trade initiate should proceed."""
|
||||
user_team = TeamFactory.west_virginia()
|
||||
other_team = TeamFactory.new_york()
|
||||
at_deadline = CurrentFactory.create(week=14, trade_deadline=14)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"commands.transactions.trade.validate_user_has_team",
|
||||
new_callable=AsyncMock,
|
||||
return_value=user_team,
|
||||
),
|
||||
patch(
|
||||
"commands.transactions.trade.get_team_by_abbrev_with_validation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=other_team,
|
||||
),
|
||||
patch("commands.transactions.trade.league_service") as mock_league,
|
||||
patch("commands.transactions.trade.clear_trade_builder") as mock_clear,
|
||||
patch("commands.transactions.trade.get_trade_builder") as mock_get_builder,
|
||||
patch(
|
||||
"commands.transactions.trade.create_trade_embed",
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
mock_league.get_current_state = AsyncMock(return_value=at_deadline)
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.add_team = AsyncMock(return_value=(True, None))
|
||||
mock_builder.trade_id = "test-123"
|
||||
mock_get_builder.return_value = mock_builder
|
||||
|
||||
from commands.transactions.trade import TradeCommands
|
||||
|
||||
bot = MagicMock()
|
||||
cog = TradeCommands(bot)
|
||||
cog.channel_manager = MagicMock()
|
||||
cog.channel_manager.create_trade_channel = AsyncMock(return_value=None)
|
||||
await cog.trade_initiate.callback(cog, mock_interaction, "NY")
|
||||
|
||||
# Should have proceeded past deadline check to clear/create trade
|
||||
mock_clear.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_initiate_blocked_when_current_none(self, mock_interaction):
|
||||
"""When league state can't be fetched, /trade initiate should fail closed."""
|
||||
user_team = TeamFactory.west_virginia()
|
||||
other_team = TeamFactory.new_york()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"commands.transactions.trade.validate_user_has_team",
|
||||
new_callable=AsyncMock,
|
||||
return_value=user_team,
|
||||
),
|
||||
patch(
|
||||
"commands.transactions.trade.get_team_by_abbrev_with_validation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=other_team,
|
||||
),
|
||||
patch("commands.transactions.trade.league_service") as mock_league,
|
||||
):
|
||||
mock_league.get_current_state = AsyncMock(return_value=None)
|
||||
|
||||
from commands.transactions.trade import TradeCommands
|
||||
|
||||
bot = MagicMock()
|
||||
cog = TradeCommands(bot)
|
||||
await cog.trade_initiate.callback(cog, mock_interaction, "NY")
|
||||
|
||||
mock_interaction.followup.send.assert_called_once()
|
||||
call_kwargs = mock_interaction.followup.send.call_args
|
||||
msg = (
|
||||
call_kwargs[0][0]
|
||||
if call_kwargs[0]
|
||||
else call_kwargs[1].get("content", "")
|
||||
)
|
||||
assert "could not retrieve league state" in msg.lower()
|
||||
@ -3,6 +3,7 @@ Tests for SBA data models
|
||||
|
||||
Validates model creation, validation, and business logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Team, Player, Current, DraftPick, DraftData, DraftList
|
||||
@ -10,94 +11,102 @@ from models import Team, Player, Current, DraftPick, DraftData, DraftList
|
||||
|
||||
class TestSBABaseModel:
|
||||
"""Test base model functionality."""
|
||||
|
||||
|
||||
def test_model_creation_with_api_data(self):
|
||||
"""Test creating models from API data."""
|
||||
team_data = {
|
||||
'id': 1,
|
||||
'abbrev': 'NYY',
|
||||
'sname': 'Yankees',
|
||||
'lname': 'New York Yankees',
|
||||
'season': 12
|
||||
"id": 1,
|
||||
"abbrev": "NYY",
|
||||
"sname": "Yankees",
|
||||
"lname": "New York Yankees",
|
||||
"season": 12,
|
||||
}
|
||||
|
||||
|
||||
team = Team.from_api_data(team_data)
|
||||
assert team.id == 1
|
||||
assert team.abbrev == 'NYY'
|
||||
assert team.lname == 'New York Yankees'
|
||||
|
||||
assert team.abbrev == "NYY"
|
||||
assert team.lname == "New York Yankees"
|
||||
|
||||
def test_to_dict_functionality(self):
|
||||
"""Test model to dictionary conversion."""
|
||||
team = Team(id=1, abbrev='LAA', sname='Angels', lname='Los Angeles Angels', season=12)
|
||||
|
||||
team = Team(
|
||||
id=1, abbrev="LAA", sname="Angels", lname="Los Angeles Angels", season=12
|
||||
)
|
||||
|
||||
team_dict = team.to_dict()
|
||||
assert 'abbrev' in team_dict
|
||||
assert team_dict['abbrev'] == 'LAA'
|
||||
assert team_dict['lname'] == 'Los Angeles Angels'
|
||||
|
||||
assert "abbrev" in team_dict
|
||||
assert team_dict["abbrev"] == "LAA"
|
||||
assert team_dict["lname"] == "Los Angeles Angels"
|
||||
|
||||
def test_model_repr(self):
|
||||
"""Test model string representation."""
|
||||
team = Team(id=2, abbrev='BOS', sname='Red Sox', lname='Boston Red Sox', season=12)
|
||||
team = Team(
|
||||
id=2, abbrev="BOS", sname="Red Sox", lname="Boston Red Sox", season=12
|
||||
)
|
||||
repr_str = repr(team)
|
||||
assert 'Team(' in repr_str
|
||||
assert 'abbrev=BOS' in repr_str
|
||||
assert "Team(" in repr_str
|
||||
assert "abbrev=BOS" in repr_str
|
||||
|
||||
|
||||
class TestTeamModel:
|
||||
"""Test Team model functionality."""
|
||||
|
||||
|
||||
def test_team_creation_minimal(self):
|
||||
"""Test team creation with minimal required fields."""
|
||||
team = Team(
|
||||
id=4,
|
||||
abbrev='HOU',
|
||||
sname='Astros',
|
||||
lname='Houston Astros',
|
||||
season=12
|
||||
id=4, abbrev="HOU", sname="Astros", lname="Houston Astros", season=12
|
||||
)
|
||||
|
||||
assert team.abbrev == 'HOU'
|
||||
assert team.sname == 'Astros'
|
||||
assert team.lname == 'Houston Astros'
|
||||
|
||||
assert team.abbrev == "HOU"
|
||||
assert team.sname == "Astros"
|
||||
assert team.lname == "Houston Astros"
|
||||
assert team.season == 12
|
||||
|
||||
|
||||
def test_team_creation_with_optional_fields(self):
|
||||
"""Test team creation with optional fields."""
|
||||
team = Team(
|
||||
id=5,
|
||||
abbrev='SF',
|
||||
sname='Giants',
|
||||
lname='San Francisco Giants',
|
||||
abbrev="SF",
|
||||
sname="Giants",
|
||||
lname="San Francisco Giants",
|
||||
season=12,
|
||||
gmid=100,
|
||||
division_id=1,
|
||||
stadium='Oracle Park',
|
||||
color='FF8C00'
|
||||
stadium="Oracle Park",
|
||||
color="FF8C00",
|
||||
)
|
||||
|
||||
|
||||
assert team.gmid == 100
|
||||
assert team.division_id == 1
|
||||
assert team.stadium == 'Oracle Park'
|
||||
assert team.color == 'FF8C00'
|
||||
|
||||
assert team.stadium == "Oracle Park"
|
||||
assert team.color == "FF8C00"
|
||||
|
||||
def test_team_str_representation(self):
|
||||
"""Test team string representation."""
|
||||
team = Team(id=3, abbrev='SD', sname='Padres', lname='San Diego Padres', season=12)
|
||||
assert str(team) == 'SD - San Diego Padres'
|
||||
team = Team(
|
||||
id=3, abbrev="SD", sname="Padres", lname="San Diego Padres", season=12
|
||||
)
|
||||
assert str(team) == "SD - San Diego Padres"
|
||||
|
||||
def test_team_roster_type_major_league(self):
|
||||
"""Test roster type detection for Major League teams."""
|
||||
from models.team import RosterType
|
||||
|
||||
# 3 chars or less → Major League
|
||||
team = Team(id=1, abbrev='NYY', sname='Yankees', lname='New York Yankees', season=12)
|
||||
team = Team(
|
||||
id=1, abbrev="NYY", sname="Yankees", lname="New York Yankees", season=12
|
||||
)
|
||||
assert team.roster_type() == RosterType.MAJOR_LEAGUE
|
||||
|
||||
team = Team(id=2, abbrev='BOS', sname='Red Sox', lname='Boston Red Sox', season=12)
|
||||
team = Team(
|
||||
id=2, abbrev="BOS", sname="Red Sox", lname="Boston Red Sox", season=12
|
||||
)
|
||||
assert team.roster_type() == RosterType.MAJOR_LEAGUE
|
||||
|
||||
# Even "BHM" (ends in M) should be Major League
|
||||
team = Team(id=3, abbrev='BHM', sname='Iron', lname='Birmingham Iron', season=12)
|
||||
team = Team(
|
||||
id=3, abbrev="BHM", sname="Iron", lname="Birmingham Iron", season=12
|
||||
)
|
||||
assert team.roster_type() == RosterType.MAJOR_LEAGUE
|
||||
|
||||
def test_team_roster_type_minor_league(self):
|
||||
@ -105,14 +114,28 @@ class TestTeamModel:
|
||||
from models.team import RosterType
|
||||
|
||||
# Standard Minor League: [Team] + "MIL"
|
||||
team = Team(id=4, abbrev='NYYMIL', sname='RailRiders', lname='Staten Island RailRiders', season=12)
|
||||
team = Team(
|
||||
id=4,
|
||||
abbrev="NYYMIL",
|
||||
sname="RailRiders",
|
||||
lname="Staten Island RailRiders",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
team = Team(id=5, abbrev='PORMIL', sname='Portland MiL', lname='Portland Minor League', season=12)
|
||||
team = Team(
|
||||
id=5,
|
||||
abbrev="PORMIL",
|
||||
sname="Portland MiL",
|
||||
lname="Portland Minor League",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
# Case insensitive
|
||||
team = Team(id=6, abbrev='LAAmil', sname='Bees', lname='Salt Lake Bees', season=12)
|
||||
team = Team(
|
||||
id=6, abbrev="LAAmil", sname="Bees", lname="Salt Lake Bees", season=12
|
||||
)
|
||||
assert team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
def test_team_roster_type_injured_list(self):
|
||||
@ -120,14 +143,32 @@ class TestTeamModel:
|
||||
from models.team import RosterType
|
||||
|
||||
# Standard Injured List: [Team] + "IL"
|
||||
team = Team(id=7, abbrev='NYYIL', sname='Yankees IL', lname='New York Yankees IL', season=12)
|
||||
team = Team(
|
||||
id=7,
|
||||
abbrev="NYYIL",
|
||||
sname="Yankees IL",
|
||||
lname="New York Yankees IL",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.INJURED_LIST
|
||||
|
||||
team = Team(id=8, abbrev='PORIL', sname='Loggers IL', lname='Portland Loggers IL', season=12)
|
||||
team = Team(
|
||||
id=8,
|
||||
abbrev="PORIL",
|
||||
sname="Loggers IL",
|
||||
lname="Portland Loggers IL",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.INJURED_LIST
|
||||
|
||||
# Case insensitive
|
||||
team = Team(id=9, abbrev='LAAil', sname='Angels IL', lname='Los Angeles Angels IL', season=12)
|
||||
team = Team(
|
||||
id=9,
|
||||
abbrev="LAAil",
|
||||
sname="Angels IL",
|
||||
lname="Los Angeles Angels IL",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.INJURED_LIST
|
||||
|
||||
def test_team_roster_type_edge_case_bhmil(self):
|
||||
@ -143,16 +184,30 @@ class TestTeamModel:
|
||||
from models.team import RosterType
|
||||
|
||||
# "BHMIL" = "BHM" + "IL" → sname contains "IL" → INJURED_LIST
|
||||
team = Team(id=10, abbrev='BHMIL', sname='Iron IL', lname='Birmingham Iron IL', season=12)
|
||||
team = Team(
|
||||
id=10,
|
||||
abbrev="BHMIL",
|
||||
sname="Iron IL",
|
||||
lname="Birmingham Iron IL",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.INJURED_LIST
|
||||
|
||||
# Compare with a real Minor League team that has "Island" in name
|
||||
# "NYYMIL" = "NYY" + "MIL", even though sname has "Island" → MINOR_LEAGUE
|
||||
team = Team(id=11, abbrev='NYYMIL', sname='Staten Island RailRiders', lname='Staten Island RailRiders', season=12)
|
||||
team = Team(
|
||||
id=11,
|
||||
abbrev="NYYMIL",
|
||||
sname="Staten Island RailRiders",
|
||||
lname="Staten Island RailRiders",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
# Another IL edge case with sname containing "IL" at word boundary
|
||||
team = Team(id=12, abbrev='WVMIL', sname='WV IL', lname='West Virginia IL', season=12)
|
||||
team = Team(
|
||||
id=12, abbrev="WVMIL", sname="WV IL", lname="West Virginia IL", season=12
|
||||
)
|
||||
assert team.roster_type() == RosterType.INJURED_LIST
|
||||
|
||||
def test_team_roster_type_sname_disambiguation(self):
|
||||
@ -160,221 +215,231 @@ class TestTeamModel:
|
||||
from models.team import RosterType
|
||||
|
||||
# MiL team - sname does NOT have "IL" as a word
|
||||
team = Team(id=13, abbrev='WVMIL', sname='Miners', lname='West Virginia Miners', season=12)
|
||||
team = Team(
|
||||
id=13,
|
||||
abbrev="WVMIL",
|
||||
sname="Miners",
|
||||
lname="West Virginia Miners",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
# IL team - sname has "IL" at word boundary
|
||||
team = Team(id=14, abbrev='WVMIL', sname='Miners IL', lname='West Virginia Miners IL', season=12)
|
||||
team = Team(
|
||||
id=14,
|
||||
abbrev="WVMIL",
|
||||
sname="Miners IL",
|
||||
lname="West Virginia Miners IL",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.INJURED_LIST
|
||||
|
||||
# MiL team - sname has "IL" but only in "Island" (substring, not word boundary)
|
||||
team = Team(id=15, abbrev='CHIMIL', sname='Island Hoppers', lname='Chicago Island Hoppers', season=12)
|
||||
team = Team(
|
||||
id=15,
|
||||
abbrev="CHIMIL",
|
||||
sname="Island Hoppers",
|
||||
lname="Chicago Island Hoppers",
|
||||
season=12,
|
||||
)
|
||||
assert team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
|
||||
class TestPlayerModel:
|
||||
"""Test Player model functionality."""
|
||||
|
||||
|
||||
def test_player_creation(self):
|
||||
"""Test player creation with required fields."""
|
||||
player = Player(
|
||||
id=101,
|
||||
name='Mike Trout',
|
||||
name="Mike Trout",
|
||||
wara=8.5,
|
||||
season=12,
|
||||
team_id=1,
|
||||
image='trout.jpg',
|
||||
pos_1='CF'
|
||||
image="trout.jpg",
|
||||
pos_1="CF",
|
||||
)
|
||||
|
||||
assert player.name == 'Mike Trout'
|
||||
|
||||
assert player.name == "Mike Trout"
|
||||
assert player.wara == 8.5
|
||||
assert player.team_id == 1
|
||||
assert player.pos_1 == 'CF'
|
||||
|
||||
assert player.pos_1 == "CF"
|
||||
|
||||
def test_player_positions_property(self):
|
||||
"""Test player positions property."""
|
||||
player = Player(
|
||||
id=102,
|
||||
name='Shohei Ohtani',
|
||||
name="Shohei Ohtani",
|
||||
wara=9.0,
|
||||
season=12,
|
||||
team_id=1,
|
||||
image='ohtani.jpg',
|
||||
pos_1='SP',
|
||||
pos_2='DH',
|
||||
pos_3='RF'
|
||||
image="ohtani.jpg",
|
||||
pos_1="SP",
|
||||
pos_2="DH",
|
||||
pos_3="RF",
|
||||
)
|
||||
|
||||
|
||||
positions = player.positions
|
||||
assert len(positions) == 3
|
||||
assert 'SP' in positions
|
||||
assert 'DH' in positions
|
||||
assert 'RF' in positions
|
||||
|
||||
assert "SP" in positions
|
||||
assert "DH" in positions
|
||||
assert "RF" in positions
|
||||
|
||||
def test_player_primary_position(self):
|
||||
"""Test primary position property."""
|
||||
player = Player(
|
||||
id=103,
|
||||
name='Mookie Betts',
|
||||
name="Mookie Betts",
|
||||
wara=7.2,
|
||||
season=12,
|
||||
team_id=1,
|
||||
image='betts.jpg',
|
||||
pos_1='RF',
|
||||
pos_2='2B'
|
||||
image="betts.jpg",
|
||||
pos_1="RF",
|
||||
pos_2="2B",
|
||||
)
|
||||
|
||||
assert player.primary_position == 'RF'
|
||||
|
||||
|
||||
assert player.primary_position == "RF"
|
||||
|
||||
def test_player_is_pitcher(self):
|
||||
"""Test is_pitcher property."""
|
||||
pitcher = Player(
|
||||
id=104,
|
||||
name='Gerrit Cole',
|
||||
name="Gerrit Cole",
|
||||
wara=6.8,
|
||||
season=12,
|
||||
team_id=1,
|
||||
image='cole.jpg',
|
||||
pos_1='SP'
|
||||
image="cole.jpg",
|
||||
pos_1="SP",
|
||||
)
|
||||
|
||||
|
||||
position_player = Player(
|
||||
id=105,
|
||||
name='Aaron Judge',
|
||||
name="Aaron Judge",
|
||||
wara=8.1,
|
||||
season=12,
|
||||
team_id=1,
|
||||
image='judge.jpg',
|
||||
pos_1='RF'
|
||||
image="judge.jpg",
|
||||
pos_1="RF",
|
||||
)
|
||||
|
||||
|
||||
assert pitcher.is_pitcher is True
|
||||
assert position_player.is_pitcher is False
|
||||
|
||||
|
||||
def test_player_str_representation(self):
|
||||
"""Test player string representation."""
|
||||
player = Player(
|
||||
id=106,
|
||||
name='Ronald Acuna Jr.',
|
||||
name="Ronald Acuna Jr.",
|
||||
wara=8.8,
|
||||
season=12,
|
||||
team_id=1,
|
||||
image='acuna.jpg',
|
||||
pos_1='OF'
|
||||
image="acuna.jpg",
|
||||
pos_1="OF",
|
||||
)
|
||||
|
||||
assert str(player) == 'Ronald Acuna Jr. (OF)'
|
||||
|
||||
assert str(player) == "Ronald Acuna Jr. (OF)"
|
||||
|
||||
|
||||
class TestCurrentModel:
|
||||
"""Test Current league state model."""
|
||||
|
||||
|
||||
def test_current_default_values(self):
|
||||
"""Test current model with default values."""
|
||||
current = Current()
|
||||
|
||||
|
||||
assert current.week == 69
|
||||
assert current.season == 69
|
||||
assert current.freeze is True
|
||||
assert current.bet_week == 'sheets'
|
||||
|
||||
assert current.bet_week == "sheets"
|
||||
|
||||
def test_current_with_custom_values(self):
|
||||
"""Test current model with custom values."""
|
||||
current = Current(
|
||||
week=15,
|
||||
season=12,
|
||||
freeze=False,
|
||||
trade_deadline=14,
|
||||
playoffs_begin=19
|
||||
week=15, season=12, freeze=False, trade_deadline=14, playoffs_begin=19
|
||||
)
|
||||
|
||||
|
||||
assert current.week == 15
|
||||
assert current.season == 12
|
||||
assert current.freeze is False
|
||||
|
||||
|
||||
def test_current_properties(self):
|
||||
"""Test current model properties."""
|
||||
# Regular season
|
||||
current = Current(week=10, playoffs_begin=19)
|
||||
assert current.is_offseason is False
|
||||
assert current.is_playoffs is False
|
||||
|
||||
|
||||
# Playoffs
|
||||
current = Current(week=20, playoffs_begin=19)
|
||||
assert current.is_offseason is True
|
||||
assert current.is_playoffs is True
|
||||
|
||||
|
||||
# Pick trading
|
||||
current = Current(week=15, pick_trade_start=10, pick_trade_end=20)
|
||||
assert current.can_trade_picks is True
|
||||
|
||||
def test_is_past_trade_deadline(self):
|
||||
"""Test trade deadline property — trades allowed during deadline week, blocked after."""
|
||||
# Before deadline
|
||||
current = Current(week=10, trade_deadline=14)
|
||||
assert current.is_past_trade_deadline is False
|
||||
|
||||
# At deadline week (still allowed)
|
||||
current = Current(week=14, trade_deadline=14)
|
||||
assert current.is_past_trade_deadline is False
|
||||
|
||||
# One week past deadline
|
||||
current = Current(week=15, trade_deadline=14)
|
||||
assert current.is_past_trade_deadline is True
|
||||
|
||||
# Offseason bypasses deadline (week > 18)
|
||||
current = Current(week=20, trade_deadline=14)
|
||||
assert current.is_offseason is True
|
||||
assert current.is_past_trade_deadline is False
|
||||
|
||||
|
||||
class TestDraftPickModel:
|
||||
"""Test DraftPick model functionality."""
|
||||
|
||||
|
||||
def test_draft_pick_creation(self):
|
||||
"""Test draft pick creation."""
|
||||
pick = DraftPick(
|
||||
season=12,
|
||||
overall=1,
|
||||
round=1,
|
||||
origowner_id=1,
|
||||
owner_id=1
|
||||
)
|
||||
|
||||
pick = DraftPick(season=12, overall=1, round=1, origowner_id=1, owner_id=1)
|
||||
|
||||
assert pick.season == 12
|
||||
assert pick.overall == 1
|
||||
assert pick.origowner_id == 1
|
||||
assert pick.owner_id == 1
|
||||
|
||||
|
||||
def test_draft_pick_properties(self):
|
||||
"""Test draft pick properties."""
|
||||
# Not traded, not selected
|
||||
pick = DraftPick(
|
||||
season=12,
|
||||
overall=5,
|
||||
round=1,
|
||||
origowner_id=1,
|
||||
owner_id=1
|
||||
)
|
||||
|
||||
pick = DraftPick(season=12, overall=5, round=1, origowner_id=1, owner_id=1)
|
||||
|
||||
assert pick.is_traded is False
|
||||
assert pick.is_selected is False
|
||||
|
||||
|
||||
# Traded pick
|
||||
traded_pick = DraftPick(
|
||||
season=12,
|
||||
overall=10,
|
||||
round=1,
|
||||
origowner_id=1,
|
||||
owner_id=2
|
||||
season=12, overall=10, round=1, origowner_id=1, owner_id=2
|
||||
)
|
||||
|
||||
|
||||
assert traded_pick.is_traded is True
|
||||
|
||||
|
||||
# Selected pick
|
||||
selected_pick = DraftPick(
|
||||
season=12,
|
||||
overall=15,
|
||||
round=1,
|
||||
origowner_id=1,
|
||||
owner_id=1,
|
||||
player_id=100
|
||||
season=12, overall=15, round=1, origowner_id=1, owner_id=1, player_id=100
|
||||
)
|
||||
|
||||
|
||||
assert selected_pick.is_selected is True
|
||||
|
||||
|
||||
class TestDraftDataModel:
|
||||
"""Test DraftData model functionality."""
|
||||
|
||||
|
||||
def test_draft_data_creation(self):
|
||||
"""Test draft data creation."""
|
||||
draft_data = DraftData(
|
||||
result_channel=123456789,
|
||||
ping_channel=987654321,
|
||||
pick_minutes=10
|
||||
result_channel=123456789, ping_channel=987654321, pick_minutes=10
|
||||
)
|
||||
|
||||
assert draft_data.result_channel == 123456789
|
||||
@ -384,20 +449,12 @@ class TestDraftDataModel:
|
||||
def test_draft_data_properties(self):
|
||||
"""Test draft data properties."""
|
||||
# Inactive draft
|
||||
draft_data = DraftData(
|
||||
result_channel=123,
|
||||
ping_channel=456,
|
||||
timer=False
|
||||
)
|
||||
draft_data = DraftData(result_channel=123, ping_channel=456, timer=False)
|
||||
|
||||
assert draft_data.is_draft_active is False
|
||||
|
||||
# Active draft
|
||||
active_draft = DraftData(
|
||||
result_channel=123,
|
||||
ping_channel=456,
|
||||
timer=True
|
||||
)
|
||||
active_draft = DraftData(result_channel=123, ping_channel=456, timer=True)
|
||||
|
||||
assert active_draft.is_draft_active is True
|
||||
|
||||
@ -409,17 +466,13 @@ class TestDraftListModel:
|
||||
not just IDs. The API returns these objects populated.
|
||||
"""
|
||||
|
||||
def _create_mock_team(self, team_id: int = 1) -> 'Team':
|
||||
def _create_mock_team(self, team_id: int = 1) -> "Team":
|
||||
"""Create a mock team for testing."""
|
||||
return Team(
|
||||
id=team_id,
|
||||
abbrev="TST",
|
||||
sname="Test",
|
||||
lname="Test Team",
|
||||
season=12
|
||||
id=team_id, abbrev="TST", sname="Test", lname="Test Team", season=12
|
||||
)
|
||||
|
||||
def _create_mock_player(self, player_id: int = 100) -> 'Player':
|
||||
def _create_mock_player(self, player_id: int = 100) -> "Player":
|
||||
"""Create a mock player for testing."""
|
||||
return Player(
|
||||
id=player_id,
|
||||
@ -430,7 +483,7 @@ class TestDraftListModel:
|
||||
team_id=1,
|
||||
season=12,
|
||||
wara=2.5,
|
||||
image="https://example.com/test.jpg"
|
||||
image="https://example.com/test.jpg",
|
||||
)
|
||||
|
||||
def test_draft_list_creation(self):
|
||||
@ -438,12 +491,7 @@ class TestDraftListModel:
|
||||
mock_team = self._create_mock_team(team_id=1)
|
||||
mock_player = self._create_mock_player(player_id=100)
|
||||
|
||||
draft_entry = DraftList(
|
||||
season=12,
|
||||
team=mock_team,
|
||||
rank=1,
|
||||
player=mock_player
|
||||
)
|
||||
draft_entry = DraftList(season=12, team=mock_team, rank=1, player=mock_player)
|
||||
|
||||
assert draft_entry.season == 12
|
||||
assert draft_entry.team_id == 1
|
||||
@ -456,18 +504,10 @@ class TestDraftListModel:
|
||||
mock_player_top = self._create_mock_player(player_id=100)
|
||||
mock_player_lower = self._create_mock_player(player_id=200)
|
||||
|
||||
top_pick = DraftList(
|
||||
season=12,
|
||||
team=mock_team,
|
||||
rank=1,
|
||||
player=mock_player_top
|
||||
)
|
||||
top_pick = DraftList(season=12, team=mock_team, rank=1, player=mock_player_top)
|
||||
|
||||
lower_pick = DraftList(
|
||||
season=12,
|
||||
team=mock_team,
|
||||
rank=5,
|
||||
player=mock_player_lower
|
||||
season=12, team=mock_team, rank=5, player=mock_player_lower
|
||||
)
|
||||
|
||||
assert top_pick.is_top_ranked is True
|
||||
@ -486,32 +526,32 @@ class TestDraftListModel:
|
||||
"""
|
||||
# Simulate API response format - nested objects, NOT flat IDs
|
||||
api_response = {
|
||||
'id': 303,
|
||||
'season': 13,
|
||||
'rank': 1,
|
||||
'team': {
|
||||
'id': 548,
|
||||
'abbrev': 'WV',
|
||||
'sname': 'Black Bears',
|
||||
'lname': 'West Virginia Black Bears',
|
||||
'season': 13
|
||||
"id": 303,
|
||||
"season": 13,
|
||||
"rank": 1,
|
||||
"team": {
|
||||
"id": 548,
|
||||
"abbrev": "WV",
|
||||
"sname": "Black Bears",
|
||||
"lname": "West Virginia Black Bears",
|
||||
"season": 13,
|
||||
},
|
||||
'player': {
|
||||
'id': 12843,
|
||||
'name': 'George Springer',
|
||||
'wara': 0.31,
|
||||
'image': 'https://example.com/springer.png',
|
||||
'season': 13,
|
||||
'pos_1': 'CF',
|
||||
"player": {
|
||||
"id": 12843,
|
||||
"name": "George Springer",
|
||||
"wara": 0.31,
|
||||
"image": "https://example.com/springer.png",
|
||||
"season": 13,
|
||||
"pos_1": "CF",
|
||||
# Note: NO flat team_id here - it's nested in 'team' below
|
||||
'team': {
|
||||
'id': 547, # Free Agent team
|
||||
'abbrev': 'FA',
|
||||
'sname': 'Free Agents',
|
||||
'lname': 'Free Agents',
|
||||
'season': 13
|
||||
}
|
||||
}
|
||||
"team": {
|
||||
"id": 547, # Free Agent team
|
||||
"abbrev": "FA",
|
||||
"sname": "Free Agents",
|
||||
"lname": "Free Agents",
|
||||
"season": 13,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Create DraftList using from_api_data (what BaseService calls)
|
||||
@ -522,87 +562,94 @@ class TestDraftListModel:
|
||||
assert draft_entry.player is not None
|
||||
|
||||
# CRITICAL: player.team_id must be extracted from nested team object
|
||||
assert draft_entry.player.team_id == 547, \
|
||||
assert draft_entry.player.team_id == 547, (
|
||||
f"player.team_id should be 547 (FA), got {draft_entry.player.team_id}"
|
||||
)
|
||||
|
||||
# Verify the nested team object is also populated
|
||||
assert draft_entry.player.team is not None
|
||||
assert draft_entry.player.team.id == 547
|
||||
assert draft_entry.player.team.abbrev == 'FA'
|
||||
assert draft_entry.player.team.abbrev == "FA"
|
||||
|
||||
# Verify DraftList's own team data
|
||||
assert draft_entry.team.id == 548
|
||||
assert draft_entry.team.abbrev == 'WV'
|
||||
assert draft_entry.team.abbrev == "WV"
|
||||
assert draft_entry.team_id == 548 # Property from nested team
|
||||
|
||||
|
||||
class TestModelCoverageExtras:
|
||||
"""Additional model coverage tests."""
|
||||
|
||||
|
||||
def test_base_model_from_api_data_validation(self):
|
||||
"""Test from_api_data with various edge cases."""
|
||||
from models.base import SBABaseModel
|
||||
|
||||
|
||||
# Test with empty data raises ValueError
|
||||
with pytest.raises(ValueError, match="Cannot create SBABaseModel from empty data"):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot create SBABaseModel from empty data"
|
||||
):
|
||||
SBABaseModel.from_api_data({})
|
||||
|
||||
|
||||
# Test with None raises ValueError
|
||||
with pytest.raises(ValueError, match="Cannot create SBABaseModel from empty data"):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot create SBABaseModel from empty data"
|
||||
):
|
||||
SBABaseModel.from_api_data(None)
|
||||
|
||||
|
||||
def test_player_positions_comprehensive(self):
|
||||
"""Test player positions property with all position variations."""
|
||||
player_data = {
|
||||
'id': 201,
|
||||
'name': 'Multi-Position Player',
|
||||
'wara': 3.0,
|
||||
'season': 12,
|
||||
'team_id': 5,
|
||||
'image': 'https://example.com/player.jpg',
|
||||
'pos_1': 'C',
|
||||
'pos_2': '1B',
|
||||
'pos_3': '3B',
|
||||
'pos_4': None, # Test None handling
|
||||
'pos_5': 'DH',
|
||||
'pos_6': 'OF',
|
||||
'pos_7': None, # Another None
|
||||
'pos_8': 'SS'
|
||||
"id": 201,
|
||||
"name": "Multi-Position Player",
|
||||
"wara": 3.0,
|
||||
"season": 12,
|
||||
"team_id": 5,
|
||||
"image": "https://example.com/player.jpg",
|
||||
"pos_1": "C",
|
||||
"pos_2": "1B",
|
||||
"pos_3": "3B",
|
||||
"pos_4": None, # Test None handling
|
||||
"pos_5": "DH",
|
||||
"pos_6": "OF",
|
||||
"pos_7": None, # Another None
|
||||
"pos_8": "SS",
|
||||
}
|
||||
player = Player.from_api_data(player_data)
|
||||
|
||||
|
||||
positions = player.positions
|
||||
assert 'C' in positions
|
||||
assert '1B' in positions
|
||||
assert '3B' in positions
|
||||
assert 'DH' in positions
|
||||
assert 'OF' in positions
|
||||
assert 'SS' in positions
|
||||
assert "C" in positions
|
||||
assert "1B" in positions
|
||||
assert "3B" in positions
|
||||
assert "DH" in positions
|
||||
assert "OF" in positions
|
||||
assert "SS" in positions
|
||||
assert len(positions) == 6 # Should exclude None values
|
||||
assert None not in positions
|
||||
|
||||
|
||||
def test_player_is_pitcher_variations(self):
|
||||
"""Test is_pitcher property with different positions."""
|
||||
test_cases = [
|
||||
('SP', True), # Starting pitcher
|
||||
('RP', True), # Relief pitcher
|
||||
('P', True), # Generic pitcher
|
||||
('C', False), # Catcher
|
||||
('1B', False), # First base
|
||||
('OF', False), # Outfield
|
||||
('DH', False), # Designated hitter
|
||||
("SP", True), # Starting pitcher
|
||||
("RP", True), # Relief pitcher
|
||||
("P", True), # Generic pitcher
|
||||
("C", False), # Catcher
|
||||
("1B", False), # First base
|
||||
("OF", False), # Outfield
|
||||
("DH", False), # Designated hitter
|
||||
]
|
||||
|
||||
|
||||
for position, expected in test_cases:
|
||||
player_data = {
|
||||
'id': 300 + ord(position[0]), # Generate unique IDs based on position
|
||||
'name': f'Test {position}',
|
||||
'wara': 2.0,
|
||||
'season': 12,
|
||||
'team_id': 5,
|
||||
'image': 'https://example.com/player.jpg',
|
||||
'pos_1': position,
|
||||
"id": 300 + ord(position[0]), # Generate unique IDs based on position
|
||||
"name": f"Test {position}",
|
||||
"wara": 2.0,
|
||||
"season": 12,
|
||||
"team_id": 5,
|
||||
"image": "https://example.com/player.jpg",
|
||||
"pos_1": position,
|
||||
}
|
||||
player = Player.from_api_data(player_data)
|
||||
assert player.is_pitcher == expected, f"Position {position} should return {expected}"
|
||||
assert player.primary_position == position
|
||||
assert player.is_pitcher == expected, (
|
||||
f"Position {position} should return {expected}"
|
||||
)
|
||||
assert player.primary_position == position
|
||||
|
||||
@ -24,7 +24,8 @@ from utils.scorebug_helpers import create_scorebug_embed, create_team_progress_b
|
||||
class TestScorecardTrackerFreshReads:
|
||||
"""Tests that ScorecardTracker reads fresh data from disk (fix for #40)."""
|
||||
|
||||
def test_get_all_scorecards_reads_fresh_data(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_scorecards_reads_fresh_data(self, tmp_path):
|
||||
"""get_all_scorecards() should pick up scorecards written by another process.
|
||||
|
||||
Simulates the background task having a stale tracker instance while
|
||||
@ -34,7 +35,7 @@ class TestScorecardTrackerFreshReads:
|
||||
data_file.write_text(json.dumps({"scorecards": {}}))
|
||||
|
||||
tracker = ScorecardTracker(data_file=str(data_file))
|
||||
assert tracker.get_all_scorecards() == []
|
||||
assert await tracker.get_all_scorecards() == []
|
||||
|
||||
# Another process writes a scorecard to the same file
|
||||
new_data = {
|
||||
@ -51,17 +52,18 @@ class TestScorecardTrackerFreshReads:
|
||||
data_file.write_text(json.dumps(new_data))
|
||||
|
||||
# Should see the new scorecard without restart
|
||||
result = tracker.get_all_scorecards()
|
||||
result = await tracker.get_all_scorecards()
|
||||
assert len(result) == 1
|
||||
assert result[0] == (111, "https://docs.google.com/spreadsheets/d/abc123")
|
||||
|
||||
def test_get_scorecard_reads_fresh_data(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_scorecard_reads_fresh_data(self, tmp_path):
|
||||
"""get_scorecard() should pick up a scorecard written by another process."""
|
||||
data_file = tmp_path / "scorecards.json"
|
||||
data_file.write_text(json.dumps({"scorecards": {}}))
|
||||
|
||||
tracker = ScorecardTracker(data_file=str(data_file))
|
||||
assert tracker.get_scorecard(222) is None
|
||||
assert await tracker.get_scorecard(222) is None
|
||||
|
||||
# Another process writes a scorecard
|
||||
new_data = {
|
||||
@ -79,7 +81,7 @@ class TestScorecardTrackerFreshReads:
|
||||
|
||||
# Should see the new scorecard
|
||||
assert (
|
||||
tracker.get_scorecard(222)
|
||||
await tracker.get_scorecard(222)
|
||||
== "https://docs.google.com/spreadsheets/d/xyz789"
|
||||
)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ Tests for Help Commands Service in Discord Bot v2.0
|
||||
|
||||
Comprehensive tests for help commands CRUD operations and business logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
@ -10,13 +11,13 @@ from unittest.mock import AsyncMock
|
||||
from services.help_commands_service import (
|
||||
HelpCommandsService,
|
||||
HelpCommandNotFoundError,
|
||||
HelpCommandExistsError
|
||||
HelpCommandExistsError,
|
||||
)
|
||||
from models.help_command import (
|
||||
HelpCommand,
|
||||
HelpCommandSearchFilters,
|
||||
HelpCommandSearchResult,
|
||||
HelpCommandStats
|
||||
HelpCommandStats,
|
||||
)
|
||||
|
||||
|
||||
@ -26,17 +27,17 @@ def sample_help_command() -> HelpCommand:
|
||||
now = datetime.now(timezone.utc)
|
||||
return HelpCommand(
|
||||
id=1,
|
||||
name='trading-rules',
|
||||
title='Trading Rules & Guidelines',
|
||||
content='Complete trading rules for the league...',
|
||||
category='rules',
|
||||
created_by_discord_id='123456789',
|
||||
name="trading-rules",
|
||||
title="Trading Rules & Guidelines",
|
||||
content="Complete trading rules for the league...",
|
||||
category="rules",
|
||||
created_by_discord_id="123456789",
|
||||
created_at=now,
|
||||
updated_at=None,
|
||||
last_modified_by=None,
|
||||
is_active=True,
|
||||
view_count=100,
|
||||
display_order=10
|
||||
display_order=10,
|
||||
)
|
||||
|
||||
|
||||
@ -64,6 +65,7 @@ class TestHelpCommandsServiceInit:
|
||||
|
||||
# Multiple imports should return the same instance
|
||||
from services.help_commands_service import help_commands_service as service2
|
||||
|
||||
assert help_commands_service is service2
|
||||
|
||||
def test_service_has_required_methods(self):
|
||||
@ -71,22 +73,22 @@ class TestHelpCommandsServiceInit:
|
||||
from services.help_commands_service import help_commands_service
|
||||
|
||||
# Core CRUD operations
|
||||
assert hasattr(help_commands_service, 'create_help')
|
||||
assert hasattr(help_commands_service, 'get_help_by_name')
|
||||
assert hasattr(help_commands_service, 'update_help')
|
||||
assert hasattr(help_commands_service, 'delete_help')
|
||||
assert hasattr(help_commands_service, 'restore_help')
|
||||
assert hasattr(help_commands_service, "create_help")
|
||||
assert hasattr(help_commands_service, "get_help_by_name")
|
||||
assert hasattr(help_commands_service, "update_help")
|
||||
assert hasattr(help_commands_service, "delete_help")
|
||||
assert hasattr(help_commands_service, "restore_help")
|
||||
|
||||
# Search and listing
|
||||
assert hasattr(help_commands_service, 'search_help_commands')
|
||||
assert hasattr(help_commands_service, 'get_all_help_topics')
|
||||
assert hasattr(help_commands_service, 'get_help_names_for_autocomplete')
|
||||
assert hasattr(help_commands_service, "search_help_commands")
|
||||
assert hasattr(help_commands_service, "get_all_help_topics")
|
||||
assert hasattr(help_commands_service, "get_help_names_for_autocomplete")
|
||||
|
||||
# View tracking
|
||||
assert hasattr(help_commands_service, 'increment_view_count')
|
||||
assert hasattr(help_commands_service, "increment_view_count")
|
||||
|
||||
# Statistics
|
||||
assert hasattr(help_commands_service, 'get_statistics')
|
||||
assert hasattr(help_commands_service, "get_statistics")
|
||||
|
||||
|
||||
class TestHelpCommandsServiceCRUD:
|
||||
@ -118,7 +120,7 @@ class TestHelpCommandsServiceCRUD:
|
||||
last_modified_by=None,
|
||||
is_active=True,
|
||||
view_count=0,
|
||||
display_order=data.get("display_order", 0)
|
||||
display_order=data.get("display_order", 0),
|
||||
)
|
||||
return created_help
|
||||
|
||||
@ -130,8 +132,8 @@ class TestHelpCommandsServiceCRUD:
|
||||
name="test-topic",
|
||||
title="Test Topic",
|
||||
content="This is test content for the help topic.",
|
||||
creator_discord_id='123456789',
|
||||
category="info"
|
||||
creator_discord_id="123456789",
|
||||
category="info",
|
||||
)
|
||||
|
||||
assert isinstance(result, HelpCommand)
|
||||
@ -141,39 +143,48 @@ class TestHelpCommandsServiceCRUD:
|
||||
assert result.view_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_help_already_exists(self, help_commands_service_instance, sample_help_command):
|
||||
async def test_create_help_already_exists(
|
||||
self, help_commands_service_instance, sample_help_command
|
||||
):
|
||||
"""Test help command creation when topic already exists."""
|
||||
|
||||
# Mock topic already exists
|
||||
async def mock_get_help_by_name(*args, **kwargs):
|
||||
return sample_help_command
|
||||
|
||||
help_commands_service_instance.get_help_by_name = mock_get_help_by_name
|
||||
|
||||
with pytest.raises(HelpCommandExistsError, match="Help topic 'trading-rules' already exists"):
|
||||
with pytest.raises(
|
||||
HelpCommandExistsError, match="Help topic 'trading-rules' already exists"
|
||||
):
|
||||
await help_commands_service_instance.create_help(
|
||||
name="trading-rules",
|
||||
title="Trading Rules",
|
||||
content="Rules content",
|
||||
creator_discord_id='123456789'
|
||||
creator_discord_id="123456789",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_help_by_name_success(self, help_commands_service_instance, sample_help_command):
|
||||
async def test_get_help_by_name_success(
|
||||
self, help_commands_service_instance, sample_help_command
|
||||
):
|
||||
"""Test successful help command retrieval."""
|
||||
# Mock the API client to return proper data structure
|
||||
help_data = {
|
||||
'id': sample_help_command.id,
|
||||
'name': sample_help_command.name,
|
||||
'title': sample_help_command.title,
|
||||
'content': sample_help_command.content,
|
||||
'category': sample_help_command.category,
|
||||
'created_by_discord_id': sample_help_command.created_by_discord_id,
|
||||
'created_at': sample_help_command.created_at.isoformat(),
|
||||
'updated_at': sample_help_command.updated_at.isoformat() if sample_help_command.updated_at else None,
|
||||
'last_modified_by': sample_help_command.last_modified_by,
|
||||
'is_active': sample_help_command.is_active,
|
||||
'view_count': sample_help_command.view_count,
|
||||
'display_order': sample_help_command.display_order
|
||||
"id": sample_help_command.id,
|
||||
"name": sample_help_command.name,
|
||||
"title": sample_help_command.title,
|
||||
"content": sample_help_command.content,
|
||||
"category": sample_help_command.category,
|
||||
"created_by_discord_id": sample_help_command.created_by_discord_id,
|
||||
"created_at": sample_help_command.created_at.isoformat(),
|
||||
"updated_at": sample_help_command.updated_at.isoformat()
|
||||
if sample_help_command.updated_at
|
||||
else None,
|
||||
"last_modified_by": sample_help_command.last_modified_by,
|
||||
"is_active": sample_help_command.is_active,
|
||||
"view_count": sample_help_command.view_count,
|
||||
"display_order": sample_help_command.display_order,
|
||||
}
|
||||
|
||||
help_commands_service_instance._client.get.return_value = help_data
|
||||
@ -191,66 +202,61 @@ class TestHelpCommandsServiceCRUD:
|
||||
# Mock the API client to return None (not found)
|
||||
help_commands_service_instance._client.get.return_value = None
|
||||
|
||||
with pytest.raises(HelpCommandNotFoundError, match="Help topic 'nonexistent' not found"):
|
||||
with pytest.raises(
|
||||
HelpCommandNotFoundError, match="Help topic 'nonexistent' not found"
|
||||
):
|
||||
await help_commands_service_instance.get_help_by_name("nonexistent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_help_success(self, help_commands_service_instance, sample_help_command):
|
||||
async def test_update_help_success(
|
||||
self, help_commands_service_instance, sample_help_command
|
||||
):
|
||||
"""Test successful help command update."""
|
||||
|
||||
# Mock getting the existing help command
|
||||
async def mock_get_help_by_name(name, include_inactive=False):
|
||||
if name == "trading-rules":
|
||||
return sample_help_command
|
||||
raise HelpCommandNotFoundError(f"Help topic '{name}' not found")
|
||||
|
||||
# Mock the API update call
|
||||
# Mock the API update call returning the updated help command data directly
|
||||
updated_data = {
|
||||
"id": sample_help_command.id,
|
||||
"name": sample_help_command.name,
|
||||
"title": "Updated Trading Rules",
|
||||
"content": "Updated content",
|
||||
"category": sample_help_command.category,
|
||||
"created_by_discord_id": sample_help_command.created_by_discord_id,
|
||||
"created_at": sample_help_command.created_at.isoformat(),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_modified_by": "987654321",
|
||||
"is_active": sample_help_command.is_active,
|
||||
"view_count": sample_help_command.view_count,
|
||||
"display_order": sample_help_command.display_order,
|
||||
}
|
||||
|
||||
async def mock_put(*args, **kwargs):
|
||||
return True
|
||||
return updated_data
|
||||
|
||||
help_commands_service_instance.get_help_by_name = mock_get_help_by_name
|
||||
help_commands_service_instance._client.put = mock_put
|
||||
|
||||
# Update should call get_help_by_name again at the end, so mock it to return updated version
|
||||
updated_help = HelpCommand(
|
||||
id=sample_help_command.id,
|
||||
name=sample_help_command.name,
|
||||
title="Updated Trading Rules",
|
||||
content="Updated content",
|
||||
category=sample_help_command.category,
|
||||
created_by_discord_id=sample_help_command.created_by_discord_id,
|
||||
created_at=sample_help_command.created_at,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
last_modified_by='987654321',
|
||||
is_active=sample_help_command.is_active,
|
||||
view_count=sample_help_command.view_count,
|
||||
display_order=sample_help_command.display_order
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_with_counter(name, include_inactive=False):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return sample_help_command
|
||||
else:
|
||||
return updated_help
|
||||
|
||||
help_commands_service_instance.get_help_by_name = mock_get_with_counter
|
||||
|
||||
result = await help_commands_service_instance.update_help(
|
||||
name="trading-rules",
|
||||
new_title="Updated Trading Rules",
|
||||
new_content="Updated content",
|
||||
updater_discord_id='987654321'
|
||||
updater_discord_id="987654321",
|
||||
)
|
||||
|
||||
assert isinstance(result, HelpCommand)
|
||||
assert result.title == "Updated Trading Rules"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_help_success(self, help_commands_service_instance, sample_help_command):
|
||||
async def test_delete_help_success(
|
||||
self, help_commands_service_instance, sample_help_command
|
||||
):
|
||||
"""Test successful help command deletion (soft delete)."""
|
||||
|
||||
# Mock getting the help command
|
||||
async def mock_get_help_by_name(name, include_inactive=False):
|
||||
return sample_help_command
|
||||
@ -272,12 +278,12 @@ class TestHelpCommandsServiceCRUD:
|
||||
# Mock getting a deleted help command
|
||||
deleted_help = HelpCommand(
|
||||
id=1,
|
||||
name='deleted-topic',
|
||||
title='Deleted Topic',
|
||||
content='Content',
|
||||
created_by_discord_id='123456789',
|
||||
name="deleted-topic",
|
||||
title="Deleted Topic",
|
||||
content="Content",
|
||||
created_by_discord_id="123456789",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
is_active=False
|
||||
is_active=False,
|
||||
)
|
||||
|
||||
async def mock_get_help_by_name(name, include_inactive=False):
|
||||
@ -285,15 +291,15 @@ class TestHelpCommandsServiceCRUD:
|
||||
|
||||
# Mock the API restore call
|
||||
restored_data = {
|
||||
'id': deleted_help.id,
|
||||
'name': deleted_help.name,
|
||||
'title': deleted_help.title,
|
||||
'content': deleted_help.content,
|
||||
'created_by_discord_id': deleted_help.created_by_discord_id,
|
||||
'created_at': deleted_help.created_at.isoformat(),
|
||||
'is_active': True,
|
||||
'view_count': 0,
|
||||
'display_order': 0
|
||||
"id": deleted_help.id,
|
||||
"name": deleted_help.name,
|
||||
"title": deleted_help.title,
|
||||
"content": deleted_help.content,
|
||||
"created_by_discord_id": deleted_help.created_by_discord_id,
|
||||
"created_at": deleted_help.created_at.isoformat(),
|
||||
"is_active": True,
|
||||
"view_count": 0,
|
||||
"display_order": 0,
|
||||
}
|
||||
|
||||
help_commands_service_instance.get_help_by_name = mock_get_help_by_name
|
||||
@ -312,33 +318,30 @@ class TestHelpCommandsServiceSearch:
|
||||
async def test_search_help_commands(self, help_commands_service_instance):
|
||||
"""Test searching for help commands with filters."""
|
||||
filters = HelpCommandSearchFilters(
|
||||
name_contains='trading',
|
||||
category='rules',
|
||||
page=1,
|
||||
page_size=10
|
||||
name_contains="trading", category="rules", page=1, page_size=10
|
||||
)
|
||||
|
||||
# Mock API response
|
||||
api_response = {
|
||||
'help_commands': [
|
||||
"help_commands": [
|
||||
{
|
||||
'id': 1,
|
||||
'name': 'trading-rules',
|
||||
'title': 'Trading Rules',
|
||||
'content': 'Content',
|
||||
'category': 'rules',
|
||||
'created_by_discord_id': '123',
|
||||
'created_at': datetime.now(timezone.utc).isoformat(),
|
||||
'is_active': True,
|
||||
'view_count': 100,
|
||||
'display_order': 0
|
||||
"id": 1,
|
||||
"name": "trading-rules",
|
||||
"title": "Trading Rules",
|
||||
"content": "Content",
|
||||
"category": "rules",
|
||||
"created_by_discord_id": "123",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"is_active": True,
|
||||
"view_count": 100,
|
||||
"display_order": 0,
|
||||
}
|
||||
],
|
||||
'total_count': 1,
|
||||
'page': 1,
|
||||
'page_size': 10,
|
||||
'total_pages': 1,
|
||||
'has_more': False
|
||||
"total_count": 1,
|
||||
"page": 1,
|
||||
"page_size": 10,
|
||||
"total_pages": 1,
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
help_commands_service_instance._client.get.return_value = api_response
|
||||
@ -348,33 +351,33 @@ class TestHelpCommandsServiceSearch:
|
||||
assert isinstance(result, HelpCommandSearchResult)
|
||||
assert len(result.help_commands) == 1
|
||||
assert result.total_count == 1
|
||||
assert result.help_commands[0].name == 'trading-rules'
|
||||
assert result.help_commands[0].name == "trading-rules"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_help_topics(self, help_commands_service_instance):
|
||||
"""Test getting all help topics."""
|
||||
# Mock API response
|
||||
api_response = {
|
||||
'help_commands': [
|
||||
"help_commands": [
|
||||
{
|
||||
'id': i,
|
||||
'name': f'topic-{i}',
|
||||
'title': f'Topic {i}',
|
||||
'content': f'Content {i}',
|
||||
'category': 'rules' if i % 2 == 0 else 'guides',
|
||||
'created_by_discord_id': '123',
|
||||
'created_at': datetime.now(timezone.utc).isoformat(),
|
||||
'is_active': True,
|
||||
'view_count': i * 10,
|
||||
'display_order': i
|
||||
"id": i,
|
||||
"name": f"topic-{i}",
|
||||
"title": f"Topic {i}",
|
||||
"content": f"Content {i}",
|
||||
"category": "rules" if i % 2 == 0 else "guides",
|
||||
"created_by_discord_id": "123",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"is_active": True,
|
||||
"view_count": i * 10,
|
||||
"display_order": i,
|
||||
}
|
||||
for i in range(1, 6)
|
||||
],
|
||||
'total_count': 5,
|
||||
'page': 1,
|
||||
'page_size': 100,
|
||||
'total_pages': 1,
|
||||
'has_more': False
|
||||
"total_count": 5,
|
||||
"page": 1,
|
||||
"page_size": 100,
|
||||
"total_pages": 1,
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
help_commands_service_instance._client.get.return_value = api_response
|
||||
@ -386,42 +389,45 @@ class TestHelpCommandsServiceSearch:
|
||||
assert all(isinstance(cmd, HelpCommand) for cmd in result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_help_names_for_autocomplete(self, help_commands_service_instance):
|
||||
async def test_get_help_names_for_autocomplete(
|
||||
self, help_commands_service_instance
|
||||
):
|
||||
"""Test getting help names for autocomplete."""
|
||||
# Mock API response
|
||||
api_response = {
|
||||
'results': [
|
||||
"results": [
|
||||
{
|
||||
'name': 'trading-rules',
|
||||
'title': 'Trading Rules',
|
||||
'category': 'rules'
|
||||
"name": "trading-rules",
|
||||
"title": "Trading Rules",
|
||||
"category": "rules",
|
||||
},
|
||||
{
|
||||
'name': 'trading-deadline',
|
||||
'title': 'Trading Deadline',
|
||||
'category': 'info'
|
||||
}
|
||||
"name": "trading-deadline",
|
||||
"title": "Trading Deadline",
|
||||
"category": "info",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
help_commands_service_instance._client.get.return_value = api_response
|
||||
|
||||
result = await help_commands_service_instance.get_help_names_for_autocomplete(
|
||||
partial_name='trading',
|
||||
limit=25
|
||||
partial_name="trading", limit=25
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert 'trading-rules' in result
|
||||
assert 'trading-deadline' in result
|
||||
assert "trading-rules" in result
|
||||
assert "trading-deadline" in result
|
||||
|
||||
|
||||
class TestHelpCommandsServiceViewTracking:
|
||||
"""Test view count tracking."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_view_count(self, help_commands_service_instance, sample_help_command):
|
||||
async def test_increment_view_count(
|
||||
self, help_commands_service_instance, sample_help_command
|
||||
):
|
||||
"""Test incrementing view count."""
|
||||
# Mock the API patch call
|
||||
help_commands_service_instance._client.patch = AsyncMock()
|
||||
@ -437,7 +443,7 @@ class TestHelpCommandsServiceViewTracking:
|
||||
created_at=sample_help_command.created_at,
|
||||
is_active=sample_help_command.is_active,
|
||||
view_count=sample_help_command.view_count + 1,
|
||||
display_order=sample_help_command.display_order
|
||||
display_order=sample_help_command.display_order,
|
||||
)
|
||||
|
||||
async def mock_get_help_by_name(name, include_inactive=False):
|
||||
@ -445,7 +451,9 @@ class TestHelpCommandsServiceViewTracking:
|
||||
|
||||
help_commands_service_instance.get_help_by_name = mock_get_help_by_name
|
||||
|
||||
result = await help_commands_service_instance.increment_view_count("trading-rules")
|
||||
result = await help_commands_service_instance.increment_view_count(
|
||||
"trading-rules"
|
||||
)
|
||||
|
||||
assert isinstance(result, HelpCommand)
|
||||
assert result.view_count == 101
|
||||
@ -459,21 +467,21 @@ class TestHelpCommandsServiceStatistics:
|
||||
"""Test getting help command statistics."""
|
||||
# Mock API response
|
||||
api_response = {
|
||||
'total_commands': 50,
|
||||
'active_commands': 45,
|
||||
'total_views': 5000,
|
||||
'most_viewed_command': {
|
||||
'id': 1,
|
||||
'name': 'popular-topic',
|
||||
'title': 'Popular Topic',
|
||||
'content': 'Content',
|
||||
'created_by_discord_id': '123',
|
||||
'created_at': datetime.now(timezone.utc).isoformat(),
|
||||
'is_active': True,
|
||||
'view_count': 500,
|
||||
'display_order': 0
|
||||
"total_commands": 50,
|
||||
"active_commands": 45,
|
||||
"total_views": 5000,
|
||||
"most_viewed_command": {
|
||||
"id": 1,
|
||||
"name": "popular-topic",
|
||||
"title": "Popular Topic",
|
||||
"content": "Content",
|
||||
"created_by_discord_id": "123",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"is_active": True,
|
||||
"view_count": 500,
|
||||
"display_order": 0,
|
||||
},
|
||||
'recent_commands_count': 5
|
||||
"recent_commands_count": 5,
|
||||
}
|
||||
|
||||
help_commands_service_instance._client.get.return_value = api_response
|
||||
@ -485,7 +493,7 @@ class TestHelpCommandsServiceStatistics:
|
||||
assert result.active_commands == 45
|
||||
assert result.total_views == 5000
|
||||
assert result.most_viewed_command is not None
|
||||
assert result.most_viewed_command.name == 'popular-topic'
|
||||
assert result.most_viewed_command.name == "popular-topic"
|
||||
assert result.recent_commands_count == 5
|
||||
|
||||
|
||||
@ -498,7 +506,9 @@ class TestHelpCommandsServiceErrorHandling:
|
||||
from exceptions import APIException, BotException
|
||||
|
||||
# Mock the API client to raise an APIException
|
||||
help_commands_service_instance._client.get.side_effect = APIException("Connection error")
|
||||
help_commands_service_instance._client.get.side_effect = APIException(
|
||||
"Connection error"
|
||||
)
|
||||
|
||||
with pytest.raises(BotException, match="Failed to retrieve help topic 'test'"):
|
||||
await help_commands_service_instance.get_help_by_name("test")
|
||||
|
||||
284
tests/test_services_schedule.py
Normal file
284
tests/test_services_schedule.py
Normal file
@ -0,0 +1,284 @@
|
||||
"""
|
||||
Tests for schedule service functionality.
|
||||
|
||||
Covers get_week_schedule, get_team_schedule, get_recent_games,
|
||||
get_upcoming_games, and group_games_by_series — verifying the
|
||||
asyncio.gather parallelization and post-fetch filtering logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from services.schedule_service import ScheduleService
|
||||
from tests.factories import GameFactory, TeamFactory
|
||||
|
||||
|
||||
def _game(game_id, week, away_abbrev, home_abbrev, **kwargs):
|
||||
"""Create a Game with distinct team IDs per matchup."""
|
||||
return GameFactory.create(
|
||||
id=game_id,
|
||||
week=week,
|
||||
away_team=TeamFactory.create(id=game_id * 10, abbrev=away_abbrev),
|
||||
home_team=TeamFactory.create(id=game_id * 10 + 1, abbrev=home_abbrev),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class TestGetWeekSchedule:
|
||||
"""Tests for ScheduleService.get_week_schedule — the HTTP layer."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
svc = ScheduleService()
|
||||
svc.get_client = AsyncMock()
|
||||
return svc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success(self, service):
|
||||
"""get_week_schedule returns parsed Game objects on a normal response."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = {
|
||||
"games": [
|
||||
{
|
||||
"id": 1,
|
||||
"season": 12,
|
||||
"week": 5,
|
||||
"game_num": 1,
|
||||
"season_type": "regular",
|
||||
"away_team": {
|
||||
"id": 10,
|
||||
"abbrev": "NYY",
|
||||
"sname": "NYY",
|
||||
"lname": "New York",
|
||||
"season": 12,
|
||||
},
|
||||
"home_team": {
|
||||
"id": 11,
|
||||
"abbrev": "BOS",
|
||||
"sname": "BOS",
|
||||
"lname": "Boston",
|
||||
"season": 12,
|
||||
},
|
||||
"away_score": 4,
|
||||
"home_score": 2,
|
||||
}
|
||||
]
|
||||
}
|
||||
service.get_client.return_value = mock_client
|
||||
|
||||
games = await service.get_week_schedule(12, 5)
|
||||
|
||||
assert len(games) == 1
|
||||
assert games[0].away_team.abbrev == "NYY"
|
||||
assert games[0].home_team.abbrev == "BOS"
|
||||
assert games[0].is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response(self, service):
|
||||
"""get_week_schedule returns [] when the API has no games."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = {"games": []}
|
||||
service.get_client.return_value = mock_client
|
||||
|
||||
games = await service.get_week_schedule(12, 99)
|
||||
assert games == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_returns_empty(self, service):
|
||||
"""get_week_schedule returns [] on API error (no exception raised)."""
|
||||
service.get_client.side_effect = Exception("connection refused")
|
||||
|
||||
games = await service.get_week_schedule(12, 1)
|
||||
assert games == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_games_key(self, service):
|
||||
"""get_week_schedule returns [] when response lacks 'games' key."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = {"status": "ok"}
|
||||
service.get_client.return_value = mock_client
|
||||
|
||||
games = await service.get_week_schedule(12, 1)
|
||||
assert games == []
|
||||
|
||||
|
||||
class TestGetTeamSchedule:
|
||||
"""Tests for get_team_schedule — gather + team-abbrev filter."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_by_team_case_insensitive(self, service):
|
||||
"""get_team_schedule returns only games involving the requested team,
|
||||
regardless of abbreviation casing."""
|
||||
week1 = [
|
||||
_game(1, 1, "NYY", "BOS", away_score=3, home_score=1),
|
||||
_game(2, 1, "LAD", "CHC", away_score=5, home_score=2),
|
||||
]
|
||||
week2 = [
|
||||
_game(3, 2, "BOS", "NYY", away_score=2, home_score=4),
|
||||
]
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.side_effect = [week1, week2]
|
||||
result = await service.get_team_schedule(12, "nyy", weeks=2)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(
|
||||
g.away_team.abbrev == "NYY" or g.home_team.abbrev == "NYY" for g in result
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_season_fetches_18_weeks(self, service):
|
||||
"""When weeks is None, all 18 weeks are fetched via gather."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_team_schedule(12, "NYY")
|
||||
|
||||
assert mock.call_count == 18
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limited_weeks(self, service):
|
||||
"""When weeks=5, only 5 weeks are fetched."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_team_schedule(12, "NYY", weeks=5)
|
||||
|
||||
assert mock.call_count == 5
|
||||
|
||||
|
||||
class TestGetRecentGames:
|
||||
"""Tests for get_recent_games — gather + completed-only filter."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_only_completed_games(self, service):
|
||||
"""get_recent_games filters out games without scores."""
|
||||
completed = GameFactory.completed(id=1, week=10)
|
||||
incomplete = GameFactory.upcoming(id=2, week=10)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = [completed, incomplete]
|
||||
result = await service.get_recent_games(12, weeks_back=1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sorted_descending_by_week_and_game_num(self, service):
|
||||
"""Recent games are sorted most-recent first."""
|
||||
game_w10 = GameFactory.completed(id=1, week=10, game_num=2)
|
||||
game_w9 = GameFactory.completed(id=2, week=9, game_num=1)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.side_effect = [[game_w10], [game_w9]]
|
||||
result = await service.get_recent_games(12, weeks_back=2)
|
||||
|
||||
assert result[0].week == 10
|
||||
assert result[1].week == 9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_negative_weeks(self, service):
|
||||
"""Weeks that would be <= 0 are excluded from fetch."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_recent_games(12, weeks_back=15)
|
||||
|
||||
# weeks_to_fetch = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] — only 10 valid weeks
|
||||
assert mock.call_count == 10
|
||||
|
||||
|
||||
class TestGetUpcomingGames:
|
||||
"""Tests for get_upcoming_games — gather all 18 weeks + incomplete filter."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_only_incomplete_games(self, service):
|
||||
"""get_upcoming_games filters out completed games."""
|
||||
completed = GameFactory.completed(id=1, week=5)
|
||||
upcoming = GameFactory.upcoming(id=2, week=5)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = [completed, upcoming]
|
||||
result = await service.get_upcoming_games(12)
|
||||
|
||||
assert len(result) == 18 # 1 incomplete game per week × 18 weeks
|
||||
assert all(not g.is_completed for g in result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sorted_ascending_by_week_and_game_num(self, service):
|
||||
"""Upcoming games are sorted earliest first."""
|
||||
game_w3 = GameFactory.upcoming(id=1, week=3, game_num=1)
|
||||
game_w1 = GameFactory.upcoming(id=2, week=1, game_num=2)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
|
||||
def side_effect(season, week):
|
||||
if week == 1:
|
||||
return [game_w1]
|
||||
if week == 3:
|
||||
return [game_w3]
|
||||
return []
|
||||
|
||||
mock.side_effect = side_effect
|
||||
result = await service.get_upcoming_games(12)
|
||||
|
||||
assert result[0].week == 1
|
||||
assert result[1].week == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_all_18_weeks(self, service):
|
||||
"""All 18 weeks are fetched in parallel (no early exit)."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_upcoming_games(12)
|
||||
|
||||
assert mock.call_count == 18
|
||||
|
||||
|
||||
class TestGroupGamesBySeries:
|
||||
"""Tests for group_games_by_series — synchronous grouping logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
def test_groups_by_alphabetical_pairing(self, service):
|
||||
"""Games between the same two teams are grouped under one key,
|
||||
with the alphabetically-first team first in the tuple."""
|
||||
games = [
|
||||
_game(1, 1, "NYY", "BOS", game_num=1),
|
||||
_game(2, 1, "BOS", "NYY", game_num=2),
|
||||
_game(3, 1, "LAD", "CHC", game_num=1),
|
||||
]
|
||||
|
||||
result = service.group_games_by_series(games)
|
||||
|
||||
assert ("BOS", "NYY") in result
|
||||
assert len(result[("BOS", "NYY")]) == 2
|
||||
assert ("CHC", "LAD") in result
|
||||
assert len(result[("CHC", "LAD")]) == 1
|
||||
|
||||
def test_sorted_by_game_num_within_series(self, service):
|
||||
"""Games within each series are sorted by game_num."""
|
||||
games = [
|
||||
_game(1, 1, "NYY", "BOS", game_num=3),
|
||||
_game(2, 1, "NYY", "BOS", game_num=1),
|
||||
_game(3, 1, "NYY", "BOS", game_num=2),
|
||||
]
|
||||
|
||||
result = service.group_games_by_series(games)
|
||||
series = result[("BOS", "NYY")]
|
||||
assert [g.game_num for g in series] == [1, 2, 3]
|
||||
|
||||
def test_empty_input(self, service):
|
||||
"""Empty games list returns empty dict."""
|
||||
assert service.group_games_by_series([]) == {}
|
||||
111
tests/test_services_stats.py
Normal file
111
tests/test_services_stats.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""
|
||||
Tests for StatsService
|
||||
|
||||
Validates stats service functionality including concurrent stat retrieval
|
||||
and error handling in get_player_stats().
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from services.stats_service import StatsService
|
||||
|
||||
|
||||
class TestStatsServiceGetPlayerStats:
|
||||
"""Test StatsService.get_player_stats() concurrent retrieval."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
"""Create a fresh StatsService instance for testing."""
|
||||
return StatsService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_batting_stats(self):
|
||||
"""Create a mock BattingStats object."""
|
||||
stats = MagicMock()
|
||||
stats.avg = 0.300
|
||||
return stats
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pitching_stats(self):
|
||||
"""Create a mock PitchingStats object."""
|
||||
stats = MagicMock()
|
||||
stats.era = 3.50
|
||||
return stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_stats_returned(
|
||||
self, service, mock_batting_stats, mock_pitching_stats
|
||||
):
|
||||
"""When both batting and pitching stats exist, both are returned.
|
||||
|
||||
Verifies that get_player_stats returns a tuple of (batting, pitching)
|
||||
when both stat types are available for the player.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
|
||||
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=100, season=12)
|
||||
|
||||
assert batting is mock_batting_stats
|
||||
assert pitching is mock_pitching_stats
|
||||
service.get_batting_stats.assert_called_once_with(100, 12)
|
||||
service.get_pitching_stats.assert_called_once_with(100, 12)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batting_only(self, service, mock_batting_stats):
|
||||
"""When only batting stats exist, pitching is None.
|
||||
|
||||
Covers the case of a position player with no pitching record.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=200, season=12)
|
||||
|
||||
assert batting is mock_batting_stats
|
||||
assert pitching is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pitching_only(self, service, mock_pitching_stats):
|
||||
"""When only pitching stats exist, batting is None.
|
||||
|
||||
Covers the case of a pitcher with no batting record.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=None)
|
||||
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=300, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is mock_pitching_stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_stats_found(self, service):
|
||||
"""When no stats exist for the player, both are None.
|
||||
|
||||
Covers the case where a player has no stats for the given season
|
||||
(e.g., didn't play).
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=None)
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=400, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_none_tuple(self, service):
|
||||
"""When an exception occurs, (None, None) is returned.
|
||||
|
||||
The get_player_stats method wraps both calls in a try/except and
|
||||
returns (None, None) on any error, ensuring callers always get a tuple.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down"))
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=500, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is None
|
||||
@ -115,6 +115,13 @@ class TestTransactionBuilder:
|
||||
svc.get_current_roster.return_value = mock_roster
|
||||
return svc
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_league_service(self):
|
||||
"""Patch league_service for all tests so FA lock check uses week 10 (before deadline)."""
|
||||
with patch("services.transaction_builder.league_service") as mock_ls:
|
||||
mock_ls.get_current_state = AsyncMock(return_value=MagicMock(week=10))
|
||||
yield mock_ls
|
||||
|
||||
@pytest.fixture
|
||||
def builder(self, mock_team, mock_roster_service):
|
||||
"""Create a TransactionBuilder for testing with injected roster service."""
|
||||
@ -152,6 +159,50 @@ class TestTransactionBuilder:
|
||||
assert builder.is_empty is False
|
||||
assert move in builder.moves
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_move_from_fa_blocked_after_deadline(self, builder, mock_player):
|
||||
"""Test that adding a player FROM Free Agency is blocked after fa_lock_week."""
|
||||
move = TransactionMove(
|
||||
player=mock_player,
|
||||
from_roster=RosterType.FREE_AGENCY,
|
||||
to_roster=RosterType.MAJOR_LEAGUE,
|
||||
to_team=builder.team,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"services.transaction_builder.league_service"
|
||||
) as mock_league_service:
|
||||
mock_league_service.get_current_state = AsyncMock(
|
||||
return_value=MagicMock(week=15)
|
||||
)
|
||||
|
||||
success, error_message = await builder.add_move(
|
||||
move, check_pending_transactions=False
|
||||
)
|
||||
|
||||
assert success is False
|
||||
assert "Free agency is closed" in error_message
|
||||
assert builder.move_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drop_to_fa_allowed_after_deadline(self, builder, mock_player):
|
||||
"""Test that dropping a player TO Free Agency is still allowed after fa_lock_week."""
|
||||
move = TransactionMove(
|
||||
player=mock_player,
|
||||
from_roster=RosterType.MAJOR_LEAGUE,
|
||||
to_roster=RosterType.FREE_AGENCY,
|
||||
from_team=builder.team,
|
||||
)
|
||||
|
||||
# Drop to FA doesn't trigger the FA lock check (autouse fixture provides week 10)
|
||||
success, error_message = await builder.add_move(
|
||||
move, check_pending_transactions=False
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert error_message == ""
|
||||
assert builder.move_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_duplicate_move_fails(self, builder, mock_player):
|
||||
"""Test that adding duplicate moves for same player fails."""
|
||||
@ -809,6 +860,13 @@ class TestPendingTransactionValidation:
|
||||
"""Create a mock player for testing."""
|
||||
return Player(id=12472, name="Test Player", wara=2.5, season=12, pos_1="OF")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_league_service(self):
|
||||
"""Patch league_service so FA lock check and week resolution use week 10."""
|
||||
with patch("services.transaction_builder.league_service") as mock_ls:
|
||||
mock_ls.get_current_state = AsyncMock(return_value=MagicMock(week=10))
|
||||
yield mock_ls
|
||||
|
||||
@pytest.fixture
|
||||
def builder(self, mock_team):
|
||||
"""Create a TransactionBuilder for testing."""
|
||||
|
||||
@ -3,10 +3,16 @@ Tests for shared autocomplete utility functions.
|
||||
|
||||
Validates the shared autocomplete functions used across multiple command modules.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from utils.autocomplete import player_autocomplete, team_autocomplete, major_league_team_autocomplete
|
||||
import utils.autocomplete
|
||||
from utils.autocomplete import (
|
||||
player_autocomplete,
|
||||
team_autocomplete,
|
||||
major_league_team_autocomplete,
|
||||
)
|
||||
from tests.factories import PlayerFactory, TeamFactory
|
||||
from models.team import RosterType
|
||||
|
||||
@ -14,6 +20,13 @@ from models.team import RosterType
|
||||
class TestPlayerAutocomplete:
|
||||
"""Test player autocomplete functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_user_team_cache(self):
|
||||
"""Clear the module-level user team cache before each test to prevent interference."""
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
yield
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interaction(self):
|
||||
"""Create a mock Discord interaction."""
|
||||
@ -26,41 +39,43 @@ class TestPlayerAutocomplete:
|
||||
"""Test successful player autocomplete."""
|
||||
mock_players = [
|
||||
PlayerFactory.mike_trout(id=1),
|
||||
PlayerFactory.ronald_acuna(id=2)
|
||||
PlayerFactory.ronald_acuna(id=2),
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.player_service') as mock_service:
|
||||
with patch("utils.autocomplete.player_service") as mock_service:
|
||||
mock_service.search_players = AsyncMock(return_value=mock_players)
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'Trout')
|
||||
choices = await player_autocomplete(mock_interaction, "Trout")
|
||||
|
||||
assert len(choices) == 2
|
||||
assert choices[0].name == 'Mike Trout (CF)'
|
||||
assert choices[0].value == 'Mike Trout'
|
||||
assert choices[1].name == 'Ronald Acuna Jr. (OF)'
|
||||
assert choices[1].value == 'Ronald Acuna Jr.'
|
||||
assert choices[0].name == "Mike Trout (CF)"
|
||||
assert choices[0].value == "Mike Trout"
|
||||
assert choices[1].name == "Ronald Acuna Jr. (OF)"
|
||||
assert choices[1].value == "Ronald Acuna Jr."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_with_team_info(self, mock_interaction):
|
||||
"""Test player autocomplete with team information."""
|
||||
mock_team = TeamFactory.create(id=499, abbrev='LAA', sname='Angels', lname='Los Angeles Angels')
|
||||
mock_team = TeamFactory.create(
|
||||
id=499, abbrev="LAA", sname="Angels", lname="Los Angeles Angels"
|
||||
)
|
||||
mock_player = PlayerFactory.mike_trout(id=1)
|
||||
mock_player.team = mock_team
|
||||
|
||||
with patch('utils.autocomplete.player_service') as mock_service:
|
||||
with patch("utils.autocomplete.player_service") as mock_service:
|
||||
mock_service.search_players = AsyncMock(return_value=[mock_player])
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'Trout')
|
||||
choices = await player_autocomplete(mock_interaction, "Trout")
|
||||
|
||||
assert len(choices) == 1
|
||||
assert choices[0].name == 'Mike Trout (CF - LAA)'
|
||||
assert choices[0].value == 'Mike Trout'
|
||||
assert choices[0].name == "Mike Trout (CF - LAA)"
|
||||
assert choices[0].value == "Mike Trout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_prioritizes_user_team(self, mock_interaction):
|
||||
"""Test that user's team players are prioritized in autocomplete."""
|
||||
user_team = TeamFactory.create(id=1, abbrev='POR', sname='Loggers')
|
||||
other_team = TeamFactory.create(id=2, abbrev='LAA', sname='Angels')
|
||||
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
|
||||
other_team = TeamFactory.create(id=2, abbrev="LAA", sname="Angels")
|
||||
|
||||
# Create players - one from user's team, one from other team
|
||||
user_player = PlayerFactory.mike_trout(id=1)
|
||||
@ -71,32 +86,35 @@ class TestPlayerAutocomplete:
|
||||
other_player.team = other_team
|
||||
other_player.team_id = other_team.id
|
||||
|
||||
with patch('utils.autocomplete.player_service') as mock_service, \
|
||||
patch('utils.autocomplete.get_user_major_league_team') as mock_get_team:
|
||||
|
||||
mock_service.search_players = AsyncMock(return_value=[other_player, user_player])
|
||||
with (
|
||||
patch("utils.autocomplete.player_service") as mock_service,
|
||||
patch("utils.autocomplete.get_user_major_league_team") as mock_get_team,
|
||||
):
|
||||
mock_service.search_players = AsyncMock(
|
||||
return_value=[other_player, user_player]
|
||||
)
|
||||
mock_get_team.return_value = user_team
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'player')
|
||||
choices = await player_autocomplete(mock_interaction, "player")
|
||||
|
||||
assert len(choices) == 2
|
||||
# User's team player should be first
|
||||
assert choices[0].name == 'Mike Trout (CF - POR)'
|
||||
assert choices[1].name == 'Ronald Acuna Jr. (OF - LAA)'
|
||||
assert choices[0].name == "Mike Trout (CF - POR)"
|
||||
assert choices[1].name == "Ronald Acuna Jr. (OF - LAA)"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_short_input(self, mock_interaction):
|
||||
"""Test player autocomplete with short input returns empty."""
|
||||
choices = await player_autocomplete(mock_interaction, 'T')
|
||||
choices = await player_autocomplete(mock_interaction, "T")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_error_handling(self, mock_interaction):
|
||||
"""Test player autocomplete error handling."""
|
||||
with patch('utils.autocomplete.player_service') as mock_service:
|
||||
with patch("utils.autocomplete.player_service") as mock_service:
|
||||
mock_service.search_players.side_effect = Exception("API Error")
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'Trout')
|
||||
choices = await player_autocomplete(mock_interaction, "Trout")
|
||||
assert len(choices) == 0
|
||||
|
||||
|
||||
@ -114,35 +132,35 @@ class TestTeamAutocomplete:
|
||||
async def test_team_autocomplete_success(self, mock_interaction):
|
||||
"""Test successful team autocomplete."""
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'),
|
||||
TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'),
|
||||
TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'),
|
||||
TeamFactory.create(id=4, abbrev='POR', sname='Loggers')
|
||||
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"),
|
||||
TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"),
|
||||
TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"),
|
||||
TeamFactory.create(id=4, abbrev="POR", sname="Loggers"),
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
choices = await team_autocomplete(mock_interaction, 'la')
|
||||
choices = await team_autocomplete(mock_interaction, "la")
|
||||
|
||||
assert len(choices) == 3 # All teams with 'la' in abbrev or sname
|
||||
assert any('LAA' in choice.name for choice in choices)
|
||||
assert any('LAAMIL' in choice.name for choice in choices)
|
||||
assert any('LAAAIL' in choice.name for choice in choices)
|
||||
assert any("LAA" in choice.name for choice in choices)
|
||||
assert any("LAAMIL" in choice.name for choice in choices)
|
||||
assert any("LAAAIL" in choice.name for choice in choices)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_autocomplete_short_input(self, mock_interaction):
|
||||
"""Test team autocomplete with very short input."""
|
||||
choices = await team_autocomplete(mock_interaction, '')
|
||||
choices = await team_autocomplete(mock_interaction, "")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_autocomplete_error_handling(self, mock_interaction):
|
||||
"""Test team autocomplete error handling."""
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season.side_effect = Exception("API Error")
|
||||
|
||||
choices = await team_autocomplete(mock_interaction, 'LAA')
|
||||
choices = await team_autocomplete(mock_interaction, "LAA")
|
||||
assert len(choices) == 0
|
||||
|
||||
|
||||
@ -157,101 +175,197 @@ class TestMajorLeagueTeamAutocomplete:
|
||||
return interaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_filters_correctly(self, mock_interaction):
|
||||
async def test_major_league_team_autocomplete_filters_correctly(
|
||||
self, mock_interaction
|
||||
):
|
||||
"""Test that only major league teams are returned."""
|
||||
# Create teams with different roster types
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'), # ML
|
||||
TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'), # MiL
|
||||
TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'), # IL
|
||||
TeamFactory.create(id=4, abbrev='FA', sname='Free Agents'), # FA
|
||||
TeamFactory.create(id=5, abbrev='POR', sname='Loggers'), # ML
|
||||
TeamFactory.create(id=6, abbrev='PORMIL', sname='Portland MiL'), # MiL
|
||||
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"), # ML
|
||||
TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"), # MiL
|
||||
TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"), # IL
|
||||
TeamFactory.create(id=4, abbrev="FA", sname="Free Agents"), # FA
|
||||
TeamFactory.create(id=5, abbrev="POR", sname="Loggers"), # ML
|
||||
TeamFactory.create(id=6, abbrev="PORMIL", sname="Portland MiL"), # MiL
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'l')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "l")
|
||||
|
||||
# Should only return major league teams that match 'l' (LAA, POR)
|
||||
choice_values = [choice.value for choice in choices]
|
||||
assert 'LAA' in choice_values
|
||||
assert 'POR' in choice_values
|
||||
assert "LAA" in choice_values
|
||||
assert "POR" in choice_values
|
||||
assert len(choice_values) == 2
|
||||
# Should NOT include MiL, IL, or FA teams
|
||||
assert 'LAAMIL' not in choice_values
|
||||
assert 'LAAAIL' not in choice_values
|
||||
assert 'FA' not in choice_values
|
||||
assert 'PORMIL' not in choice_values
|
||||
assert "LAAMIL" not in choice_values
|
||||
assert "LAAAIL" not in choice_values
|
||||
assert "FA" not in choice_values
|
||||
assert "PORMIL" not in choice_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_matching(self, mock_interaction):
|
||||
"""Test search matching on abbreviation and short name."""
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'),
|
||||
TeamFactory.create(id=2, abbrev='LAD', sname='Dodgers'),
|
||||
TeamFactory.create(id=3, abbrev='POR', sname='Loggers'),
|
||||
TeamFactory.create(id=4, abbrev='BOS', sname='Red Sox'),
|
||||
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"),
|
||||
TeamFactory.create(id=2, abbrev="LAD", sname="Dodgers"),
|
||||
TeamFactory.create(id=3, abbrev="POR", sname="Loggers"),
|
||||
TeamFactory.create(id=4, abbrev="BOS", sname="Red Sox"),
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
# Test abbreviation matching
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'la')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "la")
|
||||
assert len(choices) == 2 # LAA and LAD
|
||||
choice_values = [choice.value for choice in choices]
|
||||
assert 'LAA' in choice_values
|
||||
assert 'LAD' in choice_values
|
||||
assert "LAA" in choice_values
|
||||
assert "LAD" in choice_values
|
||||
|
||||
# Test short name matching
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'red')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "red")
|
||||
assert len(choices) == 1
|
||||
assert choices[0].value == 'BOS'
|
||||
assert choices[0].value == "BOS"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_short_input(self, mock_interaction):
|
||||
"""Test major league team autocomplete with very short input."""
|
||||
choices = await major_league_team_autocomplete(mock_interaction, '')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_error_handling(self, mock_interaction):
|
||||
async def test_major_league_team_autocomplete_error_handling(
|
||||
self, mock_interaction
|
||||
):
|
||||
"""Test major league team autocomplete error handling."""
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season.side_effect = Exception("API Error")
|
||||
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'LAA')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "LAA")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_roster_type_detection(self, mock_interaction):
|
||||
async def test_major_league_team_autocomplete_roster_type_detection(
|
||||
self, mock_interaction
|
||||
):
|
||||
"""Test that roster type detection works correctly for edge cases."""
|
||||
# Test edge cases like teams whose abbreviation ends in 'M' + 'IL'
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='BHM', sname='Iron'), # ML team ending in 'M'
|
||||
TeamFactory.create(id=2, abbrev='BHMIL', sname='Iron IL'), # IL team (BHM + IL)
|
||||
TeamFactory.create(id=3, abbrev='NYYMIL', sname='Staten Island RailRiders'), # MiL team (NYY + MIL)
|
||||
TeamFactory.create(id=4, abbrev='NYY', sname='Yankees'), # ML team
|
||||
TeamFactory.create(
|
||||
id=1, abbrev="BHM", sname="Iron"
|
||||
), # ML team ending in 'M'
|
||||
TeamFactory.create(
|
||||
id=2, abbrev="BHMIL", sname="Iron IL"
|
||||
), # IL team (BHM + IL)
|
||||
TeamFactory.create(
|
||||
id=3, abbrev="NYYMIL", sname="Staten Island RailRiders"
|
||||
), # MiL team (NYY + MIL)
|
||||
TeamFactory.create(id=4, abbrev="NYY", sname="Yankees"), # ML team
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'b')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "b")
|
||||
|
||||
# Should only return major league teams
|
||||
choice_values = [choice.value for choice in choices]
|
||||
assert 'BHM' in choice_values # Major league team
|
||||
assert 'BHMIL' not in choice_values # Should be detected as IL, not MiL
|
||||
assert 'NYYMIL' not in choice_values # Minor league team
|
||||
assert "BHM" in choice_values # Major league team
|
||||
assert "BHMIL" not in choice_values # Should be detected as IL, not MiL
|
||||
assert "NYYMIL" not in choice_values # Minor league team
|
||||
|
||||
# Verify the roster type detection is working
|
||||
bhm_team = next(t for t in mock_teams if t.abbrev == 'BHM')
|
||||
bhmil_team = next(t for t in mock_teams if t.abbrev == 'BHMIL')
|
||||
nyymil_team = next(t for t in mock_teams if t.abbrev == 'NYYMIL')
|
||||
bhm_team = next(t for t in mock_teams if t.abbrev == "BHM")
|
||||
bhmil_team = next(t for t in mock_teams if t.abbrev == "BHMIL")
|
||||
nyymil_team = next(t for t in mock_teams if t.abbrev == "NYYMIL")
|
||||
|
||||
assert bhm_team.roster_type() == RosterType.MAJOR_LEAGUE
|
||||
assert bhmil_team.roster_type() == RosterType.INJURED_LIST
|
||||
assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
|
||||
class TestGetCachedUserTeam:
|
||||
"""Test the _get_cached_user_team caching helper.
|
||||
|
||||
Verifies that the cache avoids redundant get_user_major_league_team calls
|
||||
on repeated invocations within the TTL window, and that expired entries are
|
||||
re-fetched.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache(self):
|
||||
"""Isolate each test from cache state left by other tests."""
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
yield
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interaction(self):
|
||||
interaction = MagicMock()
|
||||
interaction.user.id = 99999
|
||||
return interaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caches_result_on_first_call(self, mock_interaction):
|
||||
"""First call populates the cache; API function called exactly once."""
|
||||
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
|
||||
|
||||
with patch(
|
||||
"utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = user_team
|
||||
|
||||
from utils.autocomplete import _get_cached_user_team
|
||||
|
||||
result1 = await _get_cached_user_team(mock_interaction)
|
||||
result2 = await _get_cached_user_team(mock_interaction)
|
||||
|
||||
assert result1 is user_team
|
||||
assert result2 is user_team
|
||||
# API called only once despite two invocations
|
||||
mock_get_team.assert_called_once_with(99999)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_fetches_after_ttl_expires(self, mock_interaction):
|
||||
"""Expired cache entries cause a fresh API call."""
|
||||
import time
|
||||
|
||||
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
|
||||
|
||||
with patch(
|
||||
"utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = user_team
|
||||
|
||||
from utils.autocomplete import _get_cached_user_team, _USER_TEAM_CACHE_TTL
|
||||
|
||||
# Seed the cache with a timestamp that is already expired
|
||||
utils.autocomplete._user_team_cache[99999] = (
|
||||
user_team,
|
||||
time.time() - _USER_TEAM_CACHE_TTL - 1,
|
||||
)
|
||||
|
||||
await _get_cached_user_team(mock_interaction)
|
||||
|
||||
# Should have called the API to refresh the stale entry
|
||||
mock_get_team.assert_called_once_with(99999)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caches_none_result(self, mock_interaction):
|
||||
"""None (user has no team) is cached to avoid repeated API calls."""
|
||||
with patch(
|
||||
"utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = None
|
||||
|
||||
from utils.autocomplete import _get_cached_user_team
|
||||
|
||||
result1 = await _get_cached_user_team(mock_interaction)
|
||||
result2 = await _get_cached_user_team(mock_interaction)
|
||||
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
mock_get_team.assert_called_once()
|
||||
|
||||
@ -4,16 +4,33 @@ Autocomplete Utilities
|
||||
Shared autocomplete functions for Discord slash commands.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
from config import get_config
|
||||
from models.team import RosterType
|
||||
from models.team import RosterType, Team
|
||||
from services.player_service import player_service
|
||||
from services.team_service import team_service
|
||||
from utils.team_utils import get_user_major_league_team
|
||||
|
||||
# Cache for user team lookups: user_id -> (team, cached_at)
|
||||
_user_team_cache: Dict[int, Tuple[Optional[Team], float]] = {}
|
||||
_USER_TEAM_CACHE_TTL = 60 # seconds
|
||||
|
||||
|
||||
async def _get_cached_user_team(interaction: discord.Interaction) -> Optional[Team]:
|
||||
"""Return the user's major league team, cached for 60 seconds per user."""
|
||||
user_id = interaction.user.id
|
||||
if user_id in _user_team_cache:
|
||||
team, cached_at = _user_team_cache[user_id]
|
||||
if time.time() - cached_at < _USER_TEAM_CACHE_TTL:
|
||||
return team
|
||||
team = await get_user_major_league_team(user_id)
|
||||
_user_team_cache[user_id] = (team, time.time())
|
||||
return team
|
||||
|
||||
|
||||
async def player_autocomplete(
|
||||
interaction: discord.Interaction, current: str
|
||||
@ -34,12 +51,12 @@ async def player_autocomplete(
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get user's team for prioritization
|
||||
user_team = await get_user_major_league_team(interaction.user.id)
|
||||
# Get user's team for prioritization (cached per user, 60s TTL)
|
||||
user_team = await _get_cached_user_team(interaction)
|
||||
|
||||
# Search for players using the search endpoint
|
||||
players = await player_service.search_players(
|
||||
current, limit=50, season=get_config().sba_season
|
||||
current, limit=25, season=get_config().sba_season
|
||||
)
|
||||
|
||||
# Separate players by team (user's team vs others)
|
||||
|
||||
@ -188,9 +188,11 @@ class CacheManager:
|
||||
|
||||
try:
|
||||
pattern = f"{prefix}:*"
|
||||
keys = await client.keys(pattern)
|
||||
if keys:
|
||||
deleted = await client.delete(*keys)
|
||||
keys_to_delete = []
|
||||
async for key in client.scan_iter(match=pattern):
|
||||
keys_to_delete.append(key)
|
||||
if keys_to_delete:
|
||||
deleted = await client.delete(*keys_to_delete)
|
||||
logger.info(f"Cleared {deleted} cache keys with prefix '{prefix}'")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
|
||||
@ -11,29 +11,29 @@ from functools import wraps
|
||||
from typing import List, Optional, Callable, Any
|
||||
from utils.logging import set_discord_context, get_contextual_logger
|
||||
|
||||
cache_logger = logging.getLogger(f'{__name__}.CacheDecorators')
|
||||
period_check_logger = logging.getLogger(f'{__name__}.PeriodCheckDecorators')
|
||||
cache_logger = logging.getLogger(f"{__name__}.CacheDecorators")
|
||||
period_check_logger = logging.getLogger(f"{__name__}.PeriodCheckDecorators")
|
||||
|
||||
|
||||
def logged_command(
|
||||
command_name: Optional[str] = None,
|
||||
command_name: Optional[str] = None,
|
||||
log_params: bool = True,
|
||||
exclude_params: Optional[List[str]] = None
|
||||
exclude_params: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Decorator for Discord commands that adds comprehensive logging.
|
||||
|
||||
|
||||
This decorator automatically handles:
|
||||
- Setting Discord context with interaction details
|
||||
- Starting/ending operation timing
|
||||
- Logging command start/completion/failure
|
||||
- Preserving function metadata and signature
|
||||
|
||||
|
||||
Args:
|
||||
command_name: Override command name (defaults to function name with slashes)
|
||||
log_params: Whether to log command parameters (default: True)
|
||||
exclude_params: List of parameter names to exclude from logging
|
||||
|
||||
|
||||
Example:
|
||||
@logged_command("/roster", exclude_params=["sensitive_data"])
|
||||
async def team_roster(self, interaction, team_name: str, season: int = None):
|
||||
@ -42,57 +42,65 @@ def logged_command(
|
||||
players = await team_service.get_roster(team.id, season)
|
||||
embed = create_roster_embed(team, players)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
|
||||
Side Effects:
|
||||
- Automatically sets Discord context for all subsequent log entries
|
||||
- Creates trace_id for request correlation
|
||||
- Logs command execution timing and results
|
||||
- Re-raises all exceptions after logging (preserves original behavior)
|
||||
|
||||
|
||||
Requirements:
|
||||
- The decorated class must have a 'logger' attribute, or one will be created
|
||||
- Function must be an async method with (self, interaction, ...) signature
|
||||
- Preserves Discord.py command registration compatibility
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())[2:] # Skip self, interaction
|
||||
exclude_set = set(exclude_params or [])
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, interaction, *args, **kwargs):
|
||||
# Auto-detect command name if not provided
|
||||
cmd_name = command_name or f"/{func.__name__.replace('_', '-')}"
|
||||
|
||||
|
||||
# Build context with safe parameter logging
|
||||
context = {"command": cmd_name}
|
||||
if log_params:
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())[2:] # Skip self, interaction
|
||||
exclude_set = set(exclude_params or [])
|
||||
|
||||
for i, (name, value) in enumerate(zip(param_names, args)):
|
||||
if name not in exclude_set:
|
||||
context[f"param_{name}"] = value
|
||||
|
||||
|
||||
set_discord_context(interaction=interaction, **context)
|
||||
|
||||
|
||||
# Get logger from the class instance or create one
|
||||
logger = getattr(self, 'logger', get_contextual_logger(f'{self.__class__.__module__}.{self.__class__.__name__}'))
|
||||
logger = getattr(
|
||||
self,
|
||||
"logger",
|
||||
get_contextual_logger(
|
||||
f"{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
),
|
||||
)
|
||||
trace_id = logger.start_operation(f"{func.__name__}_command")
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"{cmd_name} command started")
|
||||
result = await func(self, interaction, *args, **kwargs)
|
||||
logger.info(f"{cmd_name} command completed successfully")
|
||||
logger.end_operation(trace_id, "completed")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{cmd_name} command failed", error=e)
|
||||
logger.end_operation(trace_id, "failed")
|
||||
# Re-raise to maintain original exception handling behavior
|
||||
raise
|
||||
|
||||
|
||||
# Preserve signature for Discord.py command registration
|
||||
wrapper.__signature__ = inspect.signature(func) # type: ignore
|
||||
wrapper.__signature__ = sig # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ -122,6 +130,7 @@ def requires_draft_period(func):
|
||||
- Should be placed before @logged_command decorator
|
||||
- league_service must be available via import
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, interaction, *args, **kwargs):
|
||||
# Import here to avoid circular imports
|
||||
@ -133,10 +142,12 @@ def requires_draft_period(func):
|
||||
current = await league_service.get_current_state()
|
||||
|
||||
if not current:
|
||||
period_check_logger.error("Could not retrieve league state for draft period check")
|
||||
period_check_logger.error(
|
||||
"Could not retrieve league state for draft period check"
|
||||
)
|
||||
embed = EmbedTemplate.error(
|
||||
"System Error",
|
||||
"Could not verify draft period status. Please try again later."
|
||||
"Could not verify draft period status. Please try again later.",
|
||||
)
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
return
|
||||
@ -148,12 +159,12 @@ def requires_draft_period(func):
|
||||
extra={
|
||||
"user_id": interaction.user.id,
|
||||
"command": func.__name__,
|
||||
"current_week": current.week
|
||||
}
|
||||
"current_week": current.week,
|
||||
},
|
||||
)
|
||||
embed = EmbedTemplate.error(
|
||||
"Not Available",
|
||||
"Draft commands are only available in the offseason."
|
||||
"Draft commands are only available in the offseason.",
|
||||
)
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
return
|
||||
@ -161,7 +172,7 @@ def requires_draft_period(func):
|
||||
# Week <= 0, allow command to proceed
|
||||
period_check_logger.debug(
|
||||
f"Draft period check passed - week {current.week}",
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__}
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__},
|
||||
)
|
||||
return await func(self, interaction, *args, **kwargs)
|
||||
|
||||
@ -169,7 +180,7 @@ def requires_draft_period(func):
|
||||
period_check_logger.error(
|
||||
f"Error in draft period check: {e}",
|
||||
exc_info=True,
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__}
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__},
|
||||
)
|
||||
# Re-raise to let error handling in logged_command handle it
|
||||
raise
|
||||
@ -182,110 +193,115 @@ def requires_draft_period(func):
|
||||
def cached_api_call(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
"""
|
||||
Decorator to add Redis caching to service methods that return List[T].
|
||||
|
||||
|
||||
This decorator will:
|
||||
1. Check cache for existing data using generated key
|
||||
2. Return cached data if found
|
||||
3. Execute original method if cache miss
|
||||
4. Cache the result for future calls
|
||||
|
||||
|
||||
Args:
|
||||
ttl: Time-to-live override in seconds (uses service default if None)
|
||||
cache_key_suffix: Additional suffix for cache key differentiation
|
||||
|
||||
|
||||
Usage:
|
||||
@cached_api_call(ttl=600, cache_key_suffix="by_season")
|
||||
async def get_teams_by_season(self, season: int) -> List[Team]:
|
||||
# Original method implementation
|
||||
|
||||
|
||||
Requirements:
|
||||
- Method must be async
|
||||
- Method must return List[T] where T is a model
|
||||
- Class must have self.cache (CacheManager instance)
|
||||
- Class must have self._generate_cache_key, self._get_cached_items, self._cache_items methods
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs) -> List[Any]:
|
||||
# Check if caching is available (service has cache manager)
|
||||
if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'):
|
||||
if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"):
|
||||
# No caching available, execute original method
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Generate cache key from method name, args, and kwargs
|
||||
method_name = f"{func.__name__}{cache_key_suffix}"
|
||||
|
||||
|
||||
# Convert args and kwargs to params list for consistent cache key
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
|
||||
# Skip 'self' and convert to params format
|
||||
params = []
|
||||
for param_name, param_value in bound_args.arguments.items():
|
||||
if param_name != 'self' and param_value is not None:
|
||||
if param_name != "self" and param_value is not None:
|
||||
params.append((param_name, param_value))
|
||||
|
||||
|
||||
cache_key = self._generate_cache_key(method_name, params)
|
||||
|
||||
|
||||
# Try to get from cache
|
||||
if hasattr(self, '_get_cached_items'):
|
||||
if hasattr(self, "_get_cached_items"):
|
||||
cached_result = await self._get_cached_items(cache_key)
|
||||
if cached_result is not None:
|
||||
cache_logger.debug(f"Cache hit: {method_name}")
|
||||
return cached_result
|
||||
|
||||
|
||||
# Cache miss - execute original method
|
||||
cache_logger.debug(f"Cache miss: {method_name}")
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Cache the result if we have items and caching methods
|
||||
if result and hasattr(self, '_cache_items'):
|
||||
if result and hasattr(self, "_cache_items"):
|
||||
await self._cache_items(cache_key, result, ttl)
|
||||
cache_logger.debug(f"Cached {len(result)} items for {method_name}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
"""
|
||||
Decorator to add Redis caching to service methods that return Optional[T].
|
||||
|
||||
|
||||
Similar to cached_api_call but for methods returning a single model instance.
|
||||
|
||||
|
||||
Args:
|
||||
ttl: Time-to-live override in seconds
|
||||
cache_key_suffix: Additional suffix for cache key differentiation
|
||||
|
||||
|
||||
Usage:
|
||||
@cached_single_item(ttl=300, cache_key_suffix="by_id")
|
||||
async def get_player(self, player_id: int) -> Optional[Player]:
|
||||
# Original method implementation
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs) -> Optional[Any]:
|
||||
# Check if caching is available
|
||||
if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'):
|
||||
if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"):
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Generate cache key
|
||||
method_name = f"{func.__name__}{cache_key_suffix}"
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
|
||||
params = []
|
||||
for param_name, param_value in bound_args.arguments.items():
|
||||
if param_name != 'self' and param_value is not None:
|
||||
if param_name != "self" and param_value is not None:
|
||||
params.append((param_name, param_value))
|
||||
|
||||
|
||||
cache_key = self._generate_cache_key(method_name, params)
|
||||
|
||||
|
||||
# Try cache first
|
||||
try:
|
||||
cached_data = await self.cache.get(cache_key)
|
||||
@ -293,12 +309,14 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
cache_logger.debug(f"Cache hit: {method_name}")
|
||||
return self.model_class.from_api_data(cached_data)
|
||||
except Exception as e:
|
||||
cache_logger.warning(f"Error reading single item cache for {cache_key}: {e}")
|
||||
|
||||
cache_logger.warning(
|
||||
f"Error reading single item cache for {cache_key}: {e}"
|
||||
)
|
||||
|
||||
# Cache miss - execute original method
|
||||
cache_logger.debug(f"Cache miss: {method_name}")
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Cache the single result
|
||||
if result:
|
||||
try:
|
||||
@ -306,43 +324,54 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
await self.cache.set(cache_key, cache_data, ttl)
|
||||
cache_logger.debug(f"Cached single item for {method_name}")
|
||||
except Exception as e:
|
||||
cache_logger.warning(f"Error caching single item for {cache_key}: {e}")
|
||||
|
||||
cache_logger.warning(
|
||||
f"Error caching single item for {cache_key}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cache_invalidate(*cache_patterns: str):
|
||||
"""
|
||||
Decorator to invalidate cache entries when data is modified.
|
||||
|
||||
|
||||
Args:
|
||||
cache_patterns: Cache key patterns to invalidate (supports prefix matching)
|
||||
|
||||
|
||||
Usage:
|
||||
@cache_invalidate("players_by_team", "teams_by_season")
|
||||
async def update_player(self, player_id: int, updates: dict) -> Optional[Player]:
|
||||
# Original method implementation
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
# Execute original method first
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Invalidate specified cache patterns
|
||||
if hasattr(self, 'cache'):
|
||||
if hasattr(self, "cache"):
|
||||
for pattern in cache_patterns:
|
||||
try:
|
||||
cleared = await self.cache.clear_prefix(f"sba:{self.endpoint}_{pattern}")
|
||||
cleared = await self.cache.clear_prefix(
|
||||
f"sba:{self.endpoint}_{pattern}"
|
||||
)
|
||||
if cleared > 0:
|
||||
cache_logger.info(f"Invalidated {cleared} cache entries for pattern: {pattern}")
|
||||
cache_logger.info(
|
||||
f"Invalidated {cleared} cache entries for pattern: {pattern}"
|
||||
)
|
||||
except Exception as e:
|
||||
cache_logger.warning(f"Error invalidating cache pattern {pattern}: {e}")
|
||||
|
||||
cache_logger.warning(
|
||||
f"Error invalidating cache pattern {pattern}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
return decorator
|
||||
|
||||
@ -24,6 +24,8 @@ JSONValue = Union[
|
||||
str, int, float, bool, None, dict[str, Any], list[Any] # nested object # arrays
|
||||
]
|
||||
|
||||
_SERIALIZABLE_TYPES = (str, int, float, bool, type(None))
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Custom JSON formatter for structured file logging."""
|
||||
@ -93,11 +95,11 @@ class JSONFormatter(logging.Formatter):
|
||||
extra_data = {}
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in excluded_keys:
|
||||
# Ensure JSON serializable
|
||||
try:
|
||||
json.dumps(value)
|
||||
if isinstance(value, _SERIALIZABLE_TYPES) or isinstance(
|
||||
value, (list, dict)
|
||||
):
|
||||
extra_data[key] = value
|
||||
except (TypeError, ValueError):
|
||||
else:
|
||||
extra_data[key] = str(value)
|
||||
|
||||
if extra_data:
|
||||
|
||||
@ -124,6 +124,22 @@ class TradeEmbedView(discord.ui.View):
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
"""Handle submit trade button click."""
|
||||
# Check trade deadline
|
||||
current = await league_service.get_current_state()
|
||||
if not current:
|
||||
await interaction.response.send_message(
|
||||
"❌ Could not retrieve league state. Please try again later.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
if current.is_past_trade_deadline:
|
||||
await interaction.response.send_message(
|
||||
f"❌ **The trade deadline has passed** (Week {current.trade_deadline}). "
|
||||
f"This trade can no longer be submitted.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
if self.builder.is_empty:
|
||||
await interaction.response.send_message(
|
||||
"Cannot submit empty trade. Add some moves first!", ephemeral=True
|
||||
@ -328,6 +344,7 @@ class TradeAcceptanceView(discord.ui.View):
|
||||
def __init__(self, builder: TradeBuilder):
|
||||
super().__init__(timeout=3600.0) # 1 hour timeout
|
||||
self.builder = builder
|
||||
self._checked_teams: dict[int, Team] = {}
|
||||
|
||||
async def _get_user_team(self, interaction: discord.Interaction) -> Optional[Team]:
|
||||
"""Get the team owned by the interacting user."""
|
||||
@ -353,6 +370,7 @@ class TradeAcceptanceView(discord.ui.View):
|
||||
)
|
||||
return False
|
||||
|
||||
self._checked_teams[interaction.user.id] = user_team
|
||||
return True
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
@ -366,7 +384,7 @@ class TradeAcceptanceView(discord.ui.View):
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
"""Handle accept button click."""
|
||||
user_team = await self._get_user_team(interaction)
|
||||
user_team = self._checked_teams.get(interaction.user.id)
|
||||
if not user_team:
|
||||
return
|
||||
|
||||
@ -401,7 +419,7 @@ class TradeAcceptanceView(discord.ui.View):
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
"""Handle reject button click - moves trade back to DRAFT."""
|
||||
user_team = await self._get_user_team(interaction)
|
||||
user_team = self._checked_teams.get(interaction.user.id)
|
||||
if not user_team:
|
||||
return
|
||||
|
||||
@ -433,7 +451,16 @@ class TradeAcceptanceView(discord.ui.View):
|
||||
config = get_config()
|
||||
|
||||
current = await league_service.get_current_state()
|
||||
next_week = current.week + 1 if current else 1
|
||||
if not current or current.is_past_trade_deadline:
|
||||
deadline_msg = (
|
||||
f"❌ **The trade deadline has passed** (Week {current.trade_deadline}). "
|
||||
f"This trade cannot be finalized."
|
||||
if current
|
||||
else "❌ Could not retrieve league state. Please try again later."
|
||||
)
|
||||
await interaction.followup.send(deadline_msg, ephemeral=True)
|
||||
return
|
||||
next_week = current.week + 1
|
||||
|
||||
fa_team = Team(
|
||||
id=config.free_agent_team_id,
|
||||
@ -708,10 +735,10 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
|
||||
Returns:
|
||||
Discord embed with current trade state
|
||||
"""
|
||||
validation = await builder.validate_trade()
|
||||
if builder.is_empty:
|
||||
color = EmbedColors.SECONDARY
|
||||
else:
|
||||
validation = await builder.validate_trade()
|
||||
color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING
|
||||
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
@ -766,7 +793,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
|
||||
inline=False,
|
||||
)
|
||||
|
||||
validation = await builder.validate_trade()
|
||||
if validation.is_legal:
|
||||
status_text = "Trade appears legal"
|
||||
else:
|
||||
|
||||
@ -6,6 +6,8 @@ Handles the Discord embed and button interfaces for the transaction builder.
|
||||
|
||||
import discord
|
||||
|
||||
from utils.logging import get_contextual_logger
|
||||
|
||||
from services.transaction_builder import (
|
||||
TransactionBuilder,
|
||||
clear_transaction_builder,
|
||||
@ -235,6 +237,7 @@ class SubmitConfirmationModal(discord.ui.Modal):
|
||||
super().__init__(title="Confirm Transaction Submission")
|
||||
self.builder = builder
|
||||
self.submission_handler = submission_handler
|
||||
self.logger = get_contextual_logger(f"{__name__}.SubmitConfirmationModal")
|
||||
|
||||
self.confirmation = discord.ui.TextInput(
|
||||
label="Type 'CONFIRM' to submit",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user