Merge next-release into main #81

Merged
cal merged 34 commits from next-release into main 2026-03-17 16:44:45 +00:00
42 changed files with 3245 additions and 2596 deletions

View File

@ -3,6 +3,7 @@
# CI/CD pipeline for Major Domo Discord Bot:
# - Builds Docker images on every push/PR
# - Auto-generates CalVer version (YYYY.MM.BUILD) on main branch merges
# - Supports multi-channel releases: stable (main), rc (next-release), dev (PRs)
# - Pushes to Docker Hub and creates git tag on main
# - Sends Discord notifications on success/failure
@ -12,6 +13,7 @@ on:
push:
branches:
- main
- next-release
pull_request:
branches:
- main
@ -39,30 +41,20 @@ jobs:
id: calver
uses: cal/gitea-actions/calver@main
# Dev build: push with dev + dev-SHA tags (PR/feature branches)
- name: Build Docker image (dev)
if: github.ref != 'refs/heads/main'
uses: https://github.com/docker/build-push-action@v5
- name: Resolve Docker tags
id: tags
uses: cal/gitea-actions/docker-tags@main
with:
context: .
push: true
tags: |
manticorum67/major-domo-discordapp:dev
manticorum67/major-domo-discordapp:dev-${{ steps.calver.outputs.sha_short }}
cache-from: type=registry,ref=manticorum67/major-domo-discordapp:buildcache
cache-to: type=registry,ref=manticorum67/major-domo-discordapp:buildcache,mode=max
image: manticorum67/major-domo-discordapp
version: ${{ steps.calver.outputs.version }}
sha_short: ${{ steps.calver.outputs.sha_short }}
# Production build: push with latest + CalVer tags (main only)
- name: Build Docker image (production)
if: github.ref == 'refs/heads/main'
- name: Build and push Docker image
uses: https://github.com/docker/build-push-action@v5
with:
context: .
push: true
tags: |
manticorum67/major-domo-discordapp:latest
manticorum67/major-domo-discordapp:${{ steps.calver.outputs.version }}
manticorum67/major-domo-discordapp:${{ steps.calver.outputs.version_sha }}
tags: ${{ steps.tags.outputs.tags }}
cache-from: type=registry,ref=manticorum67/major-domo-discordapp:buildcache
cache-to: type=registry,ref=manticorum67/major-domo-discordapp:buildcache,mode=max
@ -77,38 +69,35 @@ jobs:
run: |
echo "## Docker Build Successful" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Channel:** \`${{ steps.tags.outputs.channel }}\`" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Image Tags:**" >> $GITHUB_STEP_SUMMARY
echo "- \`manticorum67/major-domo-discordapp:latest\`" >> $GITHUB_STEP_SUMMARY
echo "- \`manticorum67/major-domo-discordapp:${{ steps.calver.outputs.version }}\`" >> $GITHUB_STEP_SUMMARY
echo "- \`manticorum67/major-domo-discordapp:${{ steps.calver.outputs.version_sha }}\`" >> $GITHUB_STEP_SUMMARY
IFS=',' read -ra TAG_ARRAY <<< "${{ steps.tags.outputs.tags }}"
for tag in "${TAG_ARRAY[@]}"; do
echo "- \`${tag}\`" >> $GITHUB_STEP_SUMMARY
done
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Build Details:**" >> $GITHUB_STEP_SUMMARY
echo "- Branch: \`${{ steps.calver.outputs.branch }}\`" >> $GITHUB_STEP_SUMMARY
echo "- Commit: \`${{ github.sha }}\`" >> $GITHUB_STEP_SUMMARY
echo "- Timestamp: \`${{ steps.calver.outputs.timestamp }}\`" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "Pushed to Docker Hub!" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Pull with: \`docker pull manticorum67/major-domo-discordapp:latest\`" >> $GITHUB_STEP_SUMMARY
else
echo "_PR build - image not pushed to Docker Hub_" >> $GITHUB_STEP_SUMMARY
fi
echo "Pull with: \`docker pull manticorum67/major-domo-discordapp:${{ steps.tags.outputs.primary_tag }}\`" >> $GITHUB_STEP_SUMMARY
- name: Discord Notification - Success
if: success() && github.ref == 'refs/heads/main'
if: success() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next-release')
uses: cal/gitea-actions/discord-notify@main
with:
webhook_url: ${{ secrets.DISCORD_WEBHOOK }}
title: "Major Domo Bot"
status: success
version: ${{ steps.calver.outputs.version }}
image_tag: ${{ steps.calver.outputs.version_sha }}
image_tag: ${{ steps.tags.outputs.primary_tag }}
commit_sha: ${{ steps.calver.outputs.sha_short }}
timestamp: ${{ steps.calver.outputs.timestamp }}
- name: Discord Notification - Failure
if: failure() && github.ref == 'refs/heads/main'
if: failure() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next-release')
uses: cal/gitea-actions/discord-notify@main
with:
webhook_url: ${{ secrets.DISCORD_WEBHOOK }}

View File

@ -16,10 +16,15 @@ manticorum67/major-domo-discordapp
There is NO DASH between "discord" and "app". Not `discord-app`, not `discordapp-v2`.
### Git Workflow
NEVER commit directly to `main`. Always use feature branches:
NEVER commit directly to `main` or `next-release`. Always use feature branches.
**Branch from `next-release`** for normal work targeting the next release:
```bash
git checkout -b feature/name # or fix/name
git checkout -b feature/name origin/next-release # or fix/name, refactor/name
```
**Branch from `main`** only for urgent hotfixes that bypass the release cycle.
PRs go to `next-release` (staging), then `next-release → main` when releasing.
### Double Emoji in Embeds
`EmbedTemplate.success/error/warning/info/loading()` auto-add emoji prefixes.

124
bot.py
View File

@ -17,14 +17,13 @@ from discord.ext import commands
from config import get_config
from exceptions import BotException
from api.client import get_global_client, cleanup_global_client
from utils.logging import JSONFormatter
from utils.random_gen import STARTUP_WATCHING, random_from_list
from views.embeds import EmbedTemplate
def setup_logging():
"""Configure hybrid logging: human-readable console + structured JSON files."""
from utils.logging import JSONFormatter
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
@ -81,11 +80,28 @@ class SBABot(commands.Bot):
)
self.logger = logging.getLogger("discord_bot_v2")
self.maintenance_mode: bool = False
async def setup_hook(self):
"""Called when the bot is starting up."""
self.logger.info("Setting up bot...")
@self.tree.interaction_check
async def maintenance_check(interaction: discord.Interaction) -> bool:
"""Block non-admin users when maintenance mode is enabled."""
if not self.maintenance_mode:
return True
if (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
):
return True
await interaction.response.send_message(
"🔧 The bot is currently in maintenance mode. Please try again later.",
ephemeral=True,
)
return False
# Load command packages
await self._load_command_packages()
@ -220,43 +236,45 @@ class SBABot(commands.Bot):
f"❌ Failed to initialize background tasks: {e}", exc_info=True
)
def _compute_command_hash(self) -> str:
"""Compute a hash of the current command tree for change detection."""
commands_data = []
for cmd in self.tree.get_commands():
# Handle different command types properly
cmd_dict = {}
cmd_dict["name"] = cmd.name
cmd_dict["type"] = type(cmd).__name__
# Add description if available (most command types have this)
if hasattr(cmd, "description"):
cmd_dict["description"] = cmd.description # type: ignore
# Add parameters for Command objects
if isinstance(cmd, discord.app_commands.Command):
cmd_dict["parameters"] = [
{
"name": param.name,
"description": param.description,
"required": param.required,
"type": str(param.type),
}
for param in cmd.parameters
]
elif isinstance(cmd, discord.app_commands.Group):
# For groups, include subcommands
cmd_dict["subcommands"] = [subcmd.name for subcmd in cmd.commands]
commands_data.append(cmd_dict)
commands_data.sort(key=lambda x: x["name"])
return hashlib.sha256(
json.dumps(commands_data, sort_keys=True).encode()
).hexdigest()
async def _should_sync_commands(self) -> bool:
"""Check if commands have changed since last sync."""
try:
# Create hash of current command tree
commands_data = []
for cmd in self.tree.get_commands():
# Handle different command types properly
cmd_dict = {}
cmd_dict["name"] = cmd.name
cmd_dict["type"] = type(cmd).__name__
# Add description if available (most command types have this)
if hasattr(cmd, "description"):
cmd_dict["description"] = cmd.description # type: ignore
# Add parameters for Command objects
if isinstance(cmd, discord.app_commands.Command):
cmd_dict["parameters"] = [
{
"name": param.name,
"description": param.description,
"required": param.required,
"type": str(param.type),
}
for param in cmd.parameters
]
elif isinstance(cmd, discord.app_commands.Group):
# For groups, include subcommands
cmd_dict["subcommands"] = [subcmd.name for subcmd in cmd.commands]
commands_data.append(cmd_dict)
# Sort for consistent hashing
commands_data.sort(key=lambda x: x["name"])
current_hash = hashlib.sha256(
json.dumps(commands_data, sort_keys=True).encode()
).hexdigest()
current_hash = self._compute_command_hash()
# Compare with stored hash
hash_file = ".last_command_hash"
@ -276,39 +294,7 @@ class SBABot(commands.Bot):
async def _save_command_hash(self):
"""Save current command hash for future comparison."""
try:
# Create hash of current command tree (same logic as _should_sync_commands)
commands_data = []
for cmd in self.tree.get_commands():
# Handle different command types properly
cmd_dict = {}
cmd_dict["name"] = cmd.name
cmd_dict["type"] = type(cmd).__name__
# Add description if available (most command types have this)
if hasattr(cmd, "description"):
cmd_dict["description"] = cmd.description # type: ignore
# Add parameters for Command objects
if isinstance(cmd, discord.app_commands.Command):
cmd_dict["parameters"] = [
{
"name": param.name,
"description": param.description,
"required": param.required,
"type": str(param.type),
}
for param in cmd.parameters
]
elif isinstance(cmd, discord.app_commands.Group):
# For groups, include subcommands
cmd_dict["subcommands"] = [subcmd.name for subcmd in cmd.commands]
commands_data.append(cmd_dict)
commands_data.sort(key=lambda x: x["name"])
current_hash = hashlib.sha256(
json.dumps(commands_data, sort_keys=True).encode()
).hexdigest()
current_hash = self._compute_command_hash()
# Save hash to file
with open(".last_command_hash", "w") as f:

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@ Draft Admin Commands
Admin-only commands for draft management and configuration.
"""
from typing import Optional
import discord
@ -16,6 +17,7 @@ from services.draft_sheet_service import get_draft_sheet_service
from utils.logging import get_contextual_logger
from utils.decorators import logged_command
from utils.permissions import league_admin_only
from utils.draft_helpers import format_pick_display
from views.draft_views import create_admin_draft_info_embed
from views.embeds import EmbedTemplate
@ -25,11 +27,10 @@ class DraftAdminGroup(app_commands.Group):
def __init__(self, bot: commands.Bot):
super().__init__(
name="draft-admin",
description="Admin commands for draft management"
name="draft-admin", description="Admin commands for draft management"
)
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.DraftAdminGroup')
self.logger = get_contextual_logger(f"{__name__}.DraftAdminGroup")
def _ensure_monitor_running(self) -> str:
"""
@ -40,7 +41,7 @@ class DraftAdminGroup(app_commands.Group):
"""
from tasks.draft_monitor import setup_draft_monitor
if not hasattr(self.bot, 'draft_monitor') or self.bot.draft_monitor is None:
if not hasattr(self.bot, "draft_monitor") or self.bot.draft_monitor is None:
self.bot.draft_monitor = setup_draft_monitor(self.bot)
self.logger.info("Draft monitor task started")
return "\n\n🤖 **Draft monitor started** - auto-draft and warnings active"
@ -63,8 +64,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -72,8 +72,7 @@ class DraftAdminGroup(app_commands.Group):
# Get current pick
config = get_config()
current_pick = await draft_pick_service.get_pick(
config.sba_season,
draft_data.currentpick
config.sba_season, draft_data.currentpick
)
# Get sheet URL
@ -86,7 +85,7 @@ class DraftAdminGroup(app_commands.Group):
@app_commands.command(name="timer", description="Enable or disable draft timer")
@app_commands.describe(
enabled="Turn timer on or off",
minutes="Minutes per pick (optional, default uses current setting)"
minutes="Minutes per pick (optional, default uses current setting)",
)
@league_admin_only()
@logged_command("/draft-admin timer")
@ -94,7 +93,7 @@ class DraftAdminGroup(app_commands.Group):
self,
interaction: discord.Interaction,
enabled: bool,
minutes: Optional[int] = None
minutes: Optional[int] = None,
):
"""Enable or disable the draft timer."""
await interaction.response.defer()
@ -103,8 +102,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -114,8 +112,7 @@ class DraftAdminGroup(app_commands.Group):
if not updated:
embed = EmbedTemplate.error(
"Update Failed",
"Failed to update draft timer."
"Update Failed", "Failed to update draft timer."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -148,15 +145,11 @@ class DraftAdminGroup(app_commands.Group):
await interaction.followup.send(embed=embed)
@app_commands.command(name="set-pick", description="Set current pick number")
@app_commands.describe(
pick_number="Overall pick number to jump to (1-512)"
)
@app_commands.describe(pick_number="Overall pick number to jump to (1-512)")
@league_admin_only()
@logged_command("/draft-admin set-pick")
async def draft_admin_set_pick(
self,
interaction: discord.Interaction,
pick_number: int
self, interaction: discord.Interaction, pick_number: int
):
"""Set the current pick number (admin operation)."""
await interaction.response.defer()
@ -167,7 +160,7 @@ class DraftAdminGroup(app_commands.Group):
if pick_number < 1 or pick_number > config.draft_total_picks:
embed = EmbedTemplate.error(
"Invalid Pick Number",
f"Pick number must be between 1 and {config.draft_total_picks}."
f"Pick number must be between 1 and {config.draft_total_picks}.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -176,8 +169,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -186,38 +178,36 @@ class DraftAdminGroup(app_commands.Group):
pick = await draft_pick_service.get_pick(config.sba_season, pick_number)
if not pick:
embed = EmbedTemplate.error(
"Pick Not Found",
f"Pick #{pick_number} does not exist in the database."
"Pick Not Found", f"Pick #{pick_number} does not exist in the database."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Update current pick
updated = await draft_service.set_current_pick(
draft_data.id,
pick_number,
reset_timer=True
draft_data.id, pick_number, reset_timer=True
)
if not updated:
embed = EmbedTemplate.error(
"Update Failed",
"Failed to update current pick."
"Update Failed", "Failed to update current pick."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Success message
from utils.draft_helpers import format_pick_display
description = f"Current pick set to **{format_pick_display(pick_number)}**."
if pick.owner:
description += f"\n\n{pick.owner.abbrev} {pick.owner.sname} is now on the clock."
description += (
f"\n\n{pick.owner.abbrev} {pick.owner.sname} is now on the clock."
)
# Add timer status and ensure monitor is running if timer is active
if updated.timer and updated.pick_deadline:
deadline_timestamp = int(updated.pick_deadline.timestamp())
description += f"\n\n⏱️ **Timer Active** - Deadline <t:{deadline_timestamp}:R>"
description += (
f"\n\n⏱️ **Timer Active** - Deadline <t:{deadline_timestamp}:R>"
)
# Ensure monitor is running
monitor_status = self._ensure_monitor_running()
description += monitor_status
@ -227,10 +217,12 @@ class DraftAdminGroup(app_commands.Group):
embed = EmbedTemplate.success("Pick Updated", description)
await interaction.followup.send(embed=embed)
@app_commands.command(name="channels", description="Configure draft Discord channels")
@app_commands.command(
name="channels", description="Configure draft Discord channels"
)
@app_commands.describe(
ping_channel="Channel for 'on the clock' pings",
result_channel="Channel for draft results"
result_channel="Channel for draft results",
)
@league_admin_only()
@logged_command("/draft-admin channels")
@ -238,15 +230,14 @@ class DraftAdminGroup(app_commands.Group):
self,
interaction: discord.Interaction,
ping_channel: Optional[discord.TextChannel] = None,
result_channel: Optional[discord.TextChannel] = None
result_channel: Optional[discord.TextChannel] = None,
):
"""Configure draft Discord channels."""
await interaction.response.defer()
if not ping_channel and not result_channel:
embed = EmbedTemplate.error(
"No Channels Provided",
"Please specify at least one channel to update."
"No Channels Provided", "Please specify at least one channel to update."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -255,8 +246,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -265,13 +255,12 @@ class DraftAdminGroup(app_commands.Group):
updated = await draft_service.update_channels(
draft_data.id,
ping_channel_id=ping_channel.id if ping_channel else None,
result_channel_id=result_channel.id if result_channel else None
result_channel_id=result_channel.id if result_channel else None,
)
if not updated:
embed = EmbedTemplate.error(
"Update Failed",
"Failed to update draft channels."
"Update Failed", "Failed to update draft channels."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -286,16 +275,14 @@ class DraftAdminGroup(app_commands.Group):
embed = EmbedTemplate.success("Channels Updated", description)
await interaction.followup.send(embed=embed)
@app_commands.command(name="reset-deadline", description="Reset current pick deadline")
@app_commands.describe(
minutes="Minutes to add (uses default if not provided)"
@app_commands.command(
name="reset-deadline", description="Reset current pick deadline"
)
@app_commands.describe(minutes="Minutes to add (uses default if not provided)")
@league_admin_only()
@logged_command("/draft-admin reset-deadline")
async def draft_admin_reset_deadline(
self,
interaction: discord.Interaction,
minutes: Optional[int] = None
self, interaction: discord.Interaction, minutes: Optional[int] = None
):
"""Reset the current pick deadline."""
await interaction.response.defer()
@ -304,8 +291,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -313,7 +299,7 @@ class DraftAdminGroup(app_commands.Group):
if not draft_data.timer:
embed = EmbedTemplate.warning(
"Timer Inactive",
"Draft timer is currently disabled. Enable it with `/draft-admin timer on` first."
"Draft timer is currently disabled. Enable it with `/draft-admin timer on` first.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -323,8 +309,7 @@ class DraftAdminGroup(app_commands.Group):
if not updated:
embed = EmbedTemplate.error(
"Update Failed",
"Failed to reset draft deadline."
"Update Failed", "Failed to reset draft deadline."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -334,7 +319,9 @@ class DraftAdminGroup(app_commands.Group):
minutes_used = minutes if minutes else updated.pick_minutes
description = f"Pick deadline reset: **{minutes_used} minutes** added.\n\n"
description += f"New deadline: <t:{deadline_timestamp}:F> (<t:{deadline_timestamp}:R>)"
description += (
f"New deadline: <t:{deadline_timestamp}:F> (<t:{deadline_timestamp}:R>)"
)
embed = EmbedTemplate.success("Deadline Reset", description)
await interaction.followup.send(embed=embed)
@ -350,8 +337,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -359,8 +345,7 @@ class DraftAdminGroup(app_commands.Group):
# Check if already paused
if draft_data.paused:
embed = EmbedTemplate.warning(
"Already Paused",
"The draft is already paused."
"Already Paused", "The draft is already paused."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -369,10 +354,7 @@ class DraftAdminGroup(app_commands.Group):
updated = await draft_service.pause_draft(draft_data.id)
if not updated:
embed = EmbedTemplate.error(
"Pause Failed",
"Failed to pause the draft."
)
embed = EmbedTemplate.error("Pause Failed", "Failed to pause the draft.")
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -400,8 +382,7 @@ class DraftAdminGroup(app_commands.Group):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -409,8 +390,7 @@ class DraftAdminGroup(app_commands.Group):
# Check if already unpaused
if not draft_data.paused:
embed = EmbedTemplate.warning(
"Not Paused",
"The draft is not currently paused."
"Not Paused", "The draft is not currently paused."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -419,10 +399,7 @@ class DraftAdminGroup(app_commands.Group):
updated = await draft_service.resume_draft(draft_data.id)
if not updated:
embed = EmbedTemplate.error(
"Resume Failed",
"Failed to resume the draft."
)
embed = EmbedTemplate.error("Resume Failed", "Failed to resume the draft.")
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -432,7 +409,9 @@ class DraftAdminGroup(app_commands.Group):
# Add timer info if active
if updated.timer and updated.pick_deadline:
deadline_timestamp = int(updated.pick_deadline.timestamp())
description += f"\n\n⏱️ **Timer Active** - Current deadline <t:{deadline_timestamp}:R>"
description += (
f"\n\n⏱️ **Timer Active** - Current deadline <t:{deadline_timestamp}:R>"
)
# Ensure monitor is running
monitor_status = self._ensure_monitor_running()
@ -441,7 +420,9 @@ class DraftAdminGroup(app_commands.Group):
embed = EmbedTemplate.success("Draft Resumed", description)
await interaction.followup.send(embed=embed)
@app_commands.command(name="resync-sheet", description="Resync all picks to Google Sheet")
@app_commands.command(
name="resync-sheet", description="Resync all picks to Google Sheet"
)
@league_admin_only()
@logged_command("/draft-admin resync-sheet")
async def draft_admin_resync_sheet(self, interaction: discord.Interaction):
@ -458,8 +439,7 @@ class DraftAdminGroup(app_commands.Group):
# Check if sheet integration is enabled
if not config.draft_sheet_enabled:
embed = EmbedTemplate.warning(
"Sheet Disabled",
"Draft sheet integration is currently disabled."
"Sheet Disabled", "Draft sheet integration is currently disabled."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -469,7 +449,7 @@ class DraftAdminGroup(app_commands.Group):
if not sheet_url:
embed = EmbedTemplate.error(
"No Sheet Configured",
f"No draft sheet is configured for season {config.sba_season}."
f"No draft sheet is configured for season {config.sba_season}.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -479,8 +459,7 @@ class DraftAdminGroup(app_commands.Group):
if not all_picks:
embed = EmbedTemplate.warning(
"No Picks Found",
"No draft picks found for the current season."
"No Picks Found", "No draft picks found for the current season."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -490,8 +469,7 @@ class DraftAdminGroup(app_commands.Group):
if not completed_picks:
embed = EmbedTemplate.warning(
"No Completed Picks",
"No draft picks have been made yet."
"No Completed Picks", "No draft picks have been made yet."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -499,40 +477,37 @@ class DraftAdminGroup(app_commands.Group):
# Prepare pick data for batch write
pick_data = []
for pick in completed_picks:
orig_abbrev = pick.origowner.abbrev if pick.origowner else (pick.owner.abbrev if pick.owner else "???")
orig_abbrev = (
pick.origowner.abbrev
if pick.origowner
else (pick.owner.abbrev if pick.owner else "???")
)
owner_abbrev = pick.owner.abbrev if pick.owner else "???"
player_name = pick.player.name if pick.player else "Unknown"
swar = pick.player.wara if pick.player else 0.0
pick_data.append((
pick.overall,
orig_abbrev,
owner_abbrev,
player_name,
swar
))
pick_data.append(
(pick.overall, orig_abbrev, owner_abbrev, player_name, swar)
)
# Get draft sheet service
draft_sheet_service = get_draft_sheet_service()
# Clear existing sheet data first
cleared = await draft_sheet_service.clear_picks_range(
config.sba_season,
start_overall=1,
end_overall=config.draft_total_picks
config.sba_season, start_overall=1, end_overall=config.draft_total_picks
)
if not cleared:
embed = EmbedTemplate.warning(
"Clear Failed",
"Failed to clear existing sheet data. Attempting to write picks anyway..."
"Failed to clear existing sheet data. Attempting to write picks anyway...",
)
# Don't return - try to write anyway
# Write all picks in batch
success_count, failure_count = await draft_sheet_service.write_picks_batch(
config.sba_season,
pick_data
config.sba_season, pick_data
)
# Build result message

View File

@ -16,6 +16,7 @@ from config import get_config
from services.draft_service import draft_service
from services.draft_pick_service import draft_pick_service
from services.draft_sheet_service import get_draft_sheet_service
from services.league_service import league_service
from services.player_service import player_service
from services.team_service import team_service
from services.roster_service import roster_service
@ -290,8 +291,6 @@ class DraftPicksCog(commands.Cog):
return
# Get current league state for dem_week calculation
from services.league_service import league_service
current = await league_service.get_current_state()
# Update player team with dem_week set to current.week + 2 for draft picks

View File

@ -3,12 +3,14 @@ Draft Status Commands
Display current draft state and information.
"""
import discord
from discord.ext import commands
from config import get_config
from services.draft_service import draft_service
from services.draft_pick_service import draft_pick_service
from services.team_service import team_service
from utils.logging import get_contextual_logger
from utils.decorators import logged_command
from utils.permissions import requires_team
@ -21,11 +23,11 @@ class DraftStatusCommands(commands.Cog):
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.DraftStatusCommands')
self.logger = get_contextual_logger(f"{__name__}.DraftStatusCommands")
@discord.app_commands.command(
name="draft-status",
description="View current draft state and timer information"
description="View current draft state and timer information",
)
@requires_team()
@logged_command("/draft-status")
@ -39,34 +41,33 @@ class DraftStatusCommands(commands.Cog):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Get current pick
current_pick = await draft_pick_service.get_pick(
config.sba_season,
draft_data.currentpick
config.sba_season, draft_data.currentpick
)
if not current_pick:
embed = EmbedTemplate.error(
"Pick Not Found",
f"Could not retrieve pick #{draft_data.currentpick}."
"Pick Not Found", f"Could not retrieve pick #{draft_data.currentpick}."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Check pick lock status
draft_picks_cog = self.bot.get_cog('DraftPicksCog')
draft_picks_cog = self.bot.get_cog("DraftPicksCog")
lock_status = "🔓 No pick in progress"
if draft_picks_cog and draft_picks_cog.pick_lock.locked():
if draft_picks_cog.lock_acquired_by:
user = self.bot.get_user(draft_picks_cog.lock_acquired_by)
user_name = user.name if user else f"User {draft_picks_cog.lock_acquired_by}"
user_name = (
user.name if user else f"User {draft_picks_cog.lock_acquired_by}"
)
lock_status = f"🔒 Pick in progress by {user_name}"
else:
lock_status = "🔒 Pick in progress (system)"
@ -75,12 +76,13 @@ class DraftStatusCommands(commands.Cog):
sheet_url = config.get_draft_sheet_url(config.sba_season)
# Create status embed
embed = await create_draft_status_embed(draft_data, current_pick, lock_status, sheet_url)
embed = await create_draft_status_embed(
draft_data, current_pick, lock_status, sheet_url
)
await interaction.followup.send(embed=embed)
@discord.app_commands.command(
name="draft-on-clock",
description="View detailed 'on the clock' information"
name="draft-on-clock", description="View detailed 'on the clock' information"
)
@requires_team()
@logged_command("/draft-on-clock")
@ -94,47 +96,39 @@ class DraftStatusCommands(commands.Cog):
draft_data = await draft_service.get_draft_data()
if not draft_data:
embed = EmbedTemplate.error(
"Draft Not Found",
"Could not retrieve draft configuration."
"Draft Not Found", "Could not retrieve draft configuration."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Get current pick
current_pick = await draft_pick_service.get_pick(
config.sba_season,
draft_data.currentpick
config.sba_season, draft_data.currentpick
)
if not current_pick or not current_pick.owner:
embed = EmbedTemplate.error(
"Pick Not Found",
f"Could not retrieve pick #{draft_data.currentpick}."
"Pick Not Found", f"Could not retrieve pick #{draft_data.currentpick}."
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Get recent picks
recent_picks = await draft_pick_service.get_recent_picks(
config.sba_season,
draft_data.currentpick,
limit=5
config.sba_season, draft_data.currentpick, limit=5
)
# Get upcoming picks
upcoming_picks = await draft_pick_service.get_upcoming_picks(
config.sba_season,
draft_data.currentpick,
limit=5
config.sba_season, draft_data.currentpick, limit=5
)
# Get team roster sWAR (optional)
from services.team_service import team_service
team_roster_swar = None
roster = await team_service.get_team_roster(current_pick.owner.id, 'current')
if roster and roster.get('active'):
team_roster_swar = roster['active'].get('WARa')
roster = await team_service.get_team_roster(current_pick.owner.id, "current")
if roster and roster.get("active"):
team_roster_swar = roster["active"].get("WARa")
# Get sheet URL
sheet_url = config.get_draft_sheet_url(config.sba_season)
@ -146,7 +140,7 @@ class DraftStatusCommands(commands.Cog):
recent_picks,
upcoming_picks,
team_roster_swar,
sheet_url
sheet_url,
)
await interaction.followup.send(embed=embed)

View File

@ -22,6 +22,7 @@ from models.team import RosterType
from services.player_service import player_service
from services.injury_service import injury_service
from services.league_service import league_service
from services.team_service import team_service
from services.giphy_service import GiphyService
from utils import team_utils
from utils.logging import get_contextual_logger
@ -42,6 +43,52 @@ class InjuryGroup(app_commands.Group):
self.logger = get_contextual_logger(f"{__name__}.InjuryGroup")
self.logger.info("InjuryGroup initialized")
async def _verify_team_ownership(
self, interaction: discord.Interaction, player: "Player"
) -> bool:
"""
Verify the invoking user owns the team the player is on.
Returns True if ownership is confirmed, False if denied (sends error embed).
Admins bypass the check.
"""
# Admins can manage any team's injuries
if (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
):
return True
if not player.team_id:
return True # Can't verify without team data, allow through
from services.team_service import team_service
config = get_config()
user_team = await team_service.get_team_by_owner(
owner_id=interaction.user.id,
season=config.sba_season,
)
if user_team is None:
embed = EmbedTemplate.error(
title="No Team Found",
description="You don't appear to own a team this season.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return False
player_team = player.team
if player_team is None or not user_team.is_same_organization(player_team):
embed = EmbedTemplate.error(
title="Not Your Player",
description=f"**{player.name}** is not on your team. You can only manage injuries for your own players.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return False
return True
def has_player_role(self, interaction: discord.Interaction) -> bool:
"""Check if user has the SBA Players role."""
# Cast to Member to access roles (User doesn't have roles attribute)
@ -89,8 +136,6 @@ class InjuryGroup(app_commands.Group):
# Fetch full team data if team is not populated
if player.team_id and not player.team:
from services.team_service import team_service
player.team = await team_service.get_team(player.team_id)
# Check if player already has an active injury
@ -507,14 +552,11 @@ class InjuryGroup(app_commands.Group):
# Fetch full team data if team is not populated
if player.team_id and not player.team:
from services.team_service import team_service
player.team = await team_service.get_team(player.team_id)
# Check if player is on user's team
# Note: This assumes you have a function to get team by owner
# For now, we'll skip this check - you can add it if needed
# TODO: Add team ownership verification
# Verify the invoking user owns this player's team
if not await self._verify_team_ownership(interaction, player):
return
# Check if player already has an active injury
existing_injury = await injury_service.get_active_injury(
@ -697,10 +739,12 @@ class InjuryGroup(app_commands.Group):
# Fetch full team data if team is not populated
if player.team_id and not player.team:
from services.team_service import team_service
player.team = await team_service.get_team(player.team_id)
# Verify the invoking user owns this player's team
if not await self._verify_team_ownership(interaction, player):
return
# Get active injury
injury = await injury_service.get_active_injury(player.id, current.season)

View File

@ -210,17 +210,41 @@ class SubmitScorecardCommands(commands.Cog):
game_id = scheduled_game.id
# Phase 6: Read Scorecard Data
# Phase 6: Read ALL Scorecard Data (before any DB writes)
# Reading everything first prevents partial commits if the
# spreadsheet has formula errors (e.g. #N/A in pitching decisions)
await interaction.edit_original_response(
content="📊 Reading play-by-play data..."
content="📊 Reading scorecard data..."
)
plays_data = await self.sheets_service.read_playtable_data(scorecard)
box_score = await self.sheets_service.read_box_score(scorecard)
decisions_data = await self.sheets_service.read_pitching_decisions(
scorecard
)
# Add game_id to each play
for play in plays_data:
play["game_id"] = game_id
# Add game metadata to each decision
for decision in decisions_data:
decision["game_id"] = game_id
decision["season"] = current.season
decision["week"] = setup_data["week"]
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
if wp is None or lp is None:
await interaction.edit_original_response(
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
)
return
# Phase 7: POST Plays
await interaction.edit_original_response(
content="💾 Submitting plays to database..."
@ -244,10 +268,7 @@ class SubmitScorecardCommands(commands.Cog):
)
return
# Phase 8: Read Box Score
box_score = await self.sheets_service.read_box_score(scorecard)
# Phase 9: PATCH Game
# Phase 8: PATCH Game
await interaction.edit_original_response(
content="⚾ Updating game result..."
)
@ -275,33 +296,7 @@ class SubmitScorecardCommands(commands.Cog):
)
return
# Phase 10: Read Pitching Decisions
decisions_data = await self.sheets_service.read_pitching_decisions(
scorecard
)
# Add game metadata to each decision
for decision in decisions_data:
decision["game_id"] = game_id
decision["season"] = current.season
decision["week"] = setup_data["week"]
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
if wp is None or lp is None:
# Rollback
await game_service.wipe_game_data(game_id)
await play_service.delete_plays_for_game(game_id)
await interaction.edit_original_response(
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
)
return
# Phase 11: POST Decisions
# Phase 9: POST Decisions
await interaction.edit_original_response(
content="🎯 Submitting pitching decisions..."
)
@ -361,6 +356,30 @@ class SubmitScorecardCommands(commands.Cog):
# Success!
await interaction.edit_original_response(content="✅ You are all set!")
except SheetsException as e:
# Spreadsheet reading error - show the detailed message to the user
self.logger.error(
f"Spreadsheet error in scorecard submission: {e}", error=e
)
if rollback_state and game_id:
try:
if rollback_state == "GAME_PATCHED":
await game_service.wipe_game_data(game_id)
await play_service.delete_plays_for_game(game_id)
elif rollback_state == "PLAYS_POSTED":
await play_service.delete_plays_for_game(game_id)
except Exception:
pass # Best effort rollback
await interaction.edit_original_response(
content=(
f"❌ There's a problem with your scorecard:\n\n"
f"{str(e)}\n\n"
f"Please fix the issue in your spreadsheet and resubmit."
)
)
except Exception as e:
# Unexpected error - attempt rollback
self.logger.error(f"Unexpected error in scorecard submission: {e}", error=e)

View File

@ -4,6 +4,7 @@ Player Information Commands
Implements slash commands for displaying player information and statistics.
"""
import asyncio
from typing import Optional, List
import discord
@ -218,8 +219,6 @@ class PlayerInfoCommands(commands.Cog):
)
# Fetch player data and stats concurrently for better performance
import asyncio
player_with_team, (batting_stats, pitching_stats) = await asyncio.gather(
player_service.get_player(player.id),
stats_service.get_player_stats(player.id, search_season),

View File

@ -4,6 +4,7 @@ Player Image Management Commands
Allows users to update player fancy card and headshot images for players
on teams they own. Admins can update any player's images.
"""
from typing import List, Tuple
import asyncio
import aiohttp
@ -15,15 +16,17 @@ from discord.ext import commands
from config import get_config
from services.player_service import player_service
from services.team_service import team_service
from utils.autocomplete import player_autocomplete
from utils.logging import get_contextual_logger
from utils.decorators import logged_command
from views.embeds import EmbedColors, EmbedTemplate
from views.base import BaseView
from models.player import Player
from utils.permissions import is_admin
# URL Validation Functions
def validate_url_format(url: str) -> Tuple[bool, str]:
"""
Validate URL format for image links.
@ -40,17 +43,20 @@ def validate_url_format(url: str) -> Tuple[bool, str]:
return False, "URL too long (max 500 characters)"
# Protocol check
if not url.startswith(('http://', 'https://')):
if not url.startswith(("http://", "https://")):
return False, "URL must start with http:// or https://"
# Image extension check
valid_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp')
valid_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp")
url_lower = url.lower()
# Check if URL ends with valid extension (ignore query params)
base_url = url_lower.split('?')[0] # Remove query parameters
base_url = url_lower.split("?")[0] # Remove query parameters
if not any(base_url.endswith(ext) for ext in valid_extensions):
return False, f"URL must end with a valid image extension: {', '.join(valid_extensions)}"
return (
False,
f"URL must end with a valid image extension: {', '.join(valid_extensions)}",
)
return True, ""
@ -68,14 +74,19 @@ async def check_url_accessibility(url: str) -> Tuple[bool, str]:
"""
try:
async with aiohttp.ClientSession() as session:
async with session.head(url, timeout=aiohttp.ClientTimeout(total=5)) as response:
async with session.head(
url, timeout=aiohttp.ClientTimeout(total=5)
) as response:
if response.status != 200:
return False, f"URL returned status {response.status}"
# Check content-type header
content_type = response.headers.get('content-type', '').lower()
if content_type and not content_type.startswith('image/'):
return False, f"URL does not return an image (content-type: {content_type})"
content_type = response.headers.get("content-type", "").lower()
if content_type and not content_type.startswith("image/"):
return (
False,
f"URL does not return an image (content-type: {content_type})",
)
return True, ""
@ -89,11 +100,9 @@ async def check_url_accessibility(url: str) -> Tuple[bool, str]:
# Permission Checking
async def can_edit_player_image(
interaction: discord.Interaction,
player: Player,
season: int,
logger
interaction: discord.Interaction, player: Player, season: int, logger
) -> Tuple[bool, str]:
"""
Check if user can edit player's image.
@ -109,7 +118,7 @@ async def can_edit_player_image(
If has permission, error_message is empty string
"""
# Admins can edit anyone
if interaction.user.guild_permissions.administrator:
if is_admin(interaction):
logger.debug("User is admin, granting permission", user_id=interaction.user.id)
return True, ""
@ -130,7 +139,7 @@ async def can_edit_player_image(
"User owns organization, granting permission",
user_id=interaction.user.id,
user_team=user_team.abbrev,
player_team=player.team.abbrev
player_team=player.team.abbrev,
)
return True, ""
@ -141,6 +150,7 @@ async def can_edit_player_image(
# Confirmation View
class ImageUpdateConfirmView(BaseView):
"""Confirmation view showing image preview before updating."""
@ -151,27 +161,33 @@ class ImageUpdateConfirmView(BaseView):
self.image_type = image_type
self.confirmed = False
@discord.ui.button(label="Confirm Update", style=discord.ButtonStyle.success, emoji="")
async def confirm_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
label="Confirm Update", style=discord.ButtonStyle.success, emoji=""
)
async def confirm_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Confirm the image update."""
self.confirmed = True
# Disable all buttons
for item in self.children:
if hasattr(item, 'disabled'):
if hasattr(item, "disabled"):
item.disabled = True # type: ignore
await interaction.response.edit_message(view=self)
self.stop()
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.secondary, emoji="")
async def cancel_button(self, interaction: discord.Interaction, button: discord.ui.Button):
async def cancel_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Cancel the image update."""
self.confirmed = False
# Disable all buttons
for item in self.children:
if hasattr(item, 'disabled'):
if hasattr(item, "disabled"):
item.disabled = True # type: ignore
await interaction.response.edit_message(view=self)
@ -180,6 +196,7 @@ class ImageUpdateConfirmView(BaseView):
# Autocomplete
async def player_name_autocomplete(
interaction: discord.Interaction,
current: str,
@ -190,7 +207,6 @@ async def player_name_autocomplete(
try:
# Use the shared autocomplete utility with team prioritization
from utils.autocomplete import player_autocomplete
return await player_autocomplete(interaction, current)
except Exception:
# Return empty list on error to avoid breaking autocomplete
@ -199,27 +215,29 @@ async def player_name_autocomplete(
# Main Command Cog
class ImageCommands(commands.Cog):
"""Player image management command handlers."""
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.ImageCommands')
self.logger = get_contextual_logger(f"{__name__}.ImageCommands")
self.logger.info("ImageCommands cog initialized")
@app_commands.command(
name="set-image",
description="Update a player's fancy card or headshot image"
name="set-image", description="Update a player's fancy card or headshot image"
)
@app_commands.describe(
image_type="Type of image to update",
player_name="Player name (use autocomplete)",
image_url="Direct URL to the image file"
image_url="Direct URL to the image file",
)
@app_commands.choices(
image_type=[
app_commands.Choice(name="Fancy Card", value="fancy-card"),
app_commands.Choice(name="Headshot", value="headshot"),
]
)
@app_commands.choices(image_type=[
app_commands.Choice(name="Fancy Card", value="fancy-card"),
app_commands.Choice(name="Headshot", value="headshot")
])
@app_commands.autocomplete(player_name=player_name_autocomplete)
@logged_command("/set-image")
async def set_image(
@ -227,7 +245,7 @@ class ImageCommands(commands.Cog):
interaction: discord.Interaction,
image_type: app_commands.Choice[str],
player_name: str,
image_url: str
image_url: str,
):
"""Update a player's image (fancy card or headshot)."""
# Defer response for potentially slow operations
@ -242,7 +260,7 @@ class ImageCommands(commands.Cog):
"Image update requested",
user_id=interaction.user.id,
player_name=player_name,
image_type=img_type
image_type=img_type,
)
# Step 1: Validate URL format
@ -252,10 +270,10 @@ class ImageCommands(commands.Cog):
embed = EmbedTemplate.error(
title="Invalid URL Format",
description=f"{format_error}\n\n"
f"**Requirements:**\n"
f"• Must start with `http://` or `https://`\n"
f"• Must end with `.jpg`, `.jpeg`, `.png`, `.gif`, or `.webp`\n"
f"• Maximum 500 characters"
f"**Requirements:**\n"
f"• Must start with `http://` or `https://`\n"
f"• Must end with `.jpg`, `.jpeg`, `.png`, `.gif`, or `.webp`\n"
f"• Maximum 500 characters",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -268,24 +286,26 @@ class ImageCommands(commands.Cog):
embed = EmbedTemplate.error(
title="URL Not Accessible",
description=f"{access_error}\n\n"
f"**Please check:**\n"
f"• URL is correct and not expired\n"
f"• Image host is online\n"
f"• URL points directly to an image file\n"
f"• URL is publicly accessible"
f"**Please check:**\n"
f"• URL is correct and not expired\n"
f"• Image host is online\n"
f"• URL points directly to an image file\n"
f"• URL is publicly accessible",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Step 3: Find player
self.logger.debug("Searching for player", player_name=player_name)
players = await player_service.get_players_by_name(player_name, get_config().sba_season)
players = await player_service.get_players_by_name(
player_name, get_config().sba_season
)
if not players:
self.logger.warning("Player not found", player_name=player_name)
embed = EmbedTemplate.error(
title="Player Not Found",
description=f"❌ No player found matching `{player_name}` in the current season."
description=f"❌ No player found matching `{player_name}` in the current season.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -303,11 +323,13 @@ class ImageCommands(commands.Cog):
if player is None:
# Multiple candidates, ask user to be more specific
player_list = "\n".join([f"{p.name} ({p.primary_position})" for p in players[:10]])
player_list = "\n".join(
[f"{p.name} ({p.primary_position})" for p in players[:10]]
)
embed = EmbedTemplate.info(
title="Multiple Players Found",
description=f"🔍 Multiple players match `{player_name}`:\n\n{player_list}\n\n"
f"Please use the exact name from autocomplete."
f"Please use the exact name from autocomplete.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -324,12 +346,12 @@ class ImageCommands(commands.Cog):
"Permission denied",
user_id=interaction.user.id,
player_id=player.id,
error=permission_error
error=permission_error,
)
embed = EmbedTemplate.error(
title="Permission Denied",
description=f"{permission_error}\n\n"
f"You can only update images for players on teams you own."
f"You can only update images for players on teams you own.",
)
await interaction.followup.send(embed=embed, ephemeral=True)
return
@ -339,52 +361,46 @@ class ImageCommands(commands.Cog):
preview_embed = EmbedTemplate.create_base_embed(
title=f"🖼️ Update {display_name} for {player.name}",
description=f"Preview the new {display_name.lower()} below. Click **Confirm Update** to save this change.",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
# Add current image info
current_image = getattr(player, field_name, None)
if current_image:
preview_embed.add_field(
name="Current Image",
value="Will be replaced",
inline=True
name="Current Image", value="Will be replaced", inline=True
)
else:
preview_embed.add_field(
name="Current Image",
value="None set",
inline=True
)
preview_embed.add_field(name="Current Image", value="None set", inline=True)
# Add player info
preview_embed.add_field(
name="Player",
value=f"{player.name} ({player.primary_position})",
inline=True
inline=True,
)
if hasattr(player, 'team') and player.team:
preview_embed.add_field(
name="Team",
value=player.team.abbrev,
inline=True
)
if hasattr(player, "team") and player.team:
preview_embed.add_field(name="Team", value=player.team.abbrev, inline=True)
# Set the new image as thumbnail for preview
preview_embed.set_thumbnail(url=image_url)
preview_embed.set_footer(text="This preview shows how the image will appear. Confirm to save.")
preview_embed.set_footer(
text="This preview shows how the image will appear. Confirm to save."
)
# Create confirmation view
confirm_view = ImageUpdateConfirmView(
player=player,
image_url=image_url,
image_type=img_type,
user_id=interaction.user.id
user_id=interaction.user.id,
)
await interaction.followup.send(embed=preview_embed, view=confirm_view, ephemeral=True)
await interaction.followup.send(
embed=preview_embed, view=confirm_view, ephemeral=True
)
# Wait for confirmation
await confirm_view.wait()
@ -393,7 +409,7 @@ class ImageCommands(commands.Cog):
self.logger.info("Image update cancelled by user", player_id=player.id)
cancelled_embed = EmbedTemplate.info(
title="Update Cancelled",
description=f"No changes were made to {player.name}'s {display_name.lower()}."
description=f"No changes were made to {player.name}'s {display_name.lower()}.",
)
await interaction.edit_original_response(embed=cancelled_embed, view=None)
return
@ -403,7 +419,7 @@ class ImageCommands(commands.Cog):
"Updating player image",
player_id=player.id,
field=field_name,
url_length=len(image_url)
url_length=len(image_url),
)
update_data = {field_name: image_url}
@ -413,7 +429,7 @@ class ImageCommands(commands.Cog):
self.logger.error("Failed to update player", player_id=player.id)
error_embed = EmbedTemplate.error(
title="Update Failed",
description="❌ An error occurred while updating the player's image. Please try again."
description="❌ An error occurred while updating the player's image. Please try again.",
)
await interaction.edit_original_response(embed=error_embed, view=None)
return
@ -423,32 +439,24 @@ class ImageCommands(commands.Cog):
"Player image updated successfully",
player_id=player.id,
field=field_name,
user_id=interaction.user.id
user_id=interaction.user.id,
)
success_embed = EmbedTemplate.success(
title="Image Updated Successfully!",
description=f"**{display_name}** for **{player.name}** has been updated."
description=f"**{display_name}** for **{player.name}** has been updated.",
)
success_embed.add_field(
name="Player",
value=f"{player.name} ({player.primary_position})",
inline=True
inline=True,
)
if hasattr(player, 'team') and player.team:
success_embed.add_field(
name="Team",
value=player.team.abbrev,
inline=True
)
if hasattr(player, "team") and player.team:
success_embed.add_field(name="Team", value=player.team.abbrev, inline=True)
success_embed.add_field(
name="Image Type",
value=display_name,
inline=True
)
success_embed.add_field(name="Image Type", value=display_name, inline=True)
# Show the new image
success_embed.set_thumbnail(url=image_url)

View File

@ -3,6 +3,9 @@ Soak Info Commands
Provides information about soak mentions without triggering the easter egg.
"""
from datetime import datetime
import discord
from discord import app_commands
from discord.ext import commands
@ -19,11 +22,13 @@ class SoakInfoCommands(commands.Cog):
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.SoakInfoCommands')
self.logger = get_contextual_logger(f"{__name__}.SoakInfoCommands")
self.tracker = SoakTracker()
self.logger.info("SoakInfoCommands cog initialized")
@app_commands.command(name="lastsoak", description="Get information about the last soak mention")
@app_commands.command(
name="lastsoak", description="Get information about the last soak mention"
)
@logged_command("/lastsoak")
async def last_soak(self, interaction: discord.Interaction):
"""Show information about the last soak mention."""
@ -35,13 +40,9 @@ class SoakInfoCommands(commands.Cog):
if not last_soak:
embed = EmbedTemplate.info(
title="Last Soak",
description="No one has said the forbidden word yet. 🤫"
)
embed.add_field(
name="Total Mentions",
value="0",
inline=False
description="No one has said the forbidden word yet. 🤫",
)
embed.add_field(name="Total Mentions", value="0", inline=False)
await interaction.followup.send(embed=embed)
return
@ -50,23 +51,24 @@ class SoakInfoCommands(commands.Cog):
total_count = self.tracker.get_soak_count()
# Determine disappointment tier
tier_key = get_tier_for_seconds(int(time_since.total_seconds()) if time_since else None)
tier_key = get_tier_for_seconds(
int(time_since.total_seconds()) if time_since else None
)
tier_description = get_tier_description(tier_key)
# Create embed
embed = EmbedTemplate.create_base_embed(
title="📊 Last Soak",
description="Information about the most recent soak mention",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
# Parse timestamp for Discord formatting
try:
from datetime import datetime
timestamp_str = last_soak["timestamp"]
if timestamp_str.endswith('Z'):
timestamp_str = timestamp_str[:-1] + '+00:00'
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
if timestamp_str.endswith("Z"):
timestamp_str = timestamp_str[:-1] + "+00:00"
timestamp = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
unix_timestamp = int(timestamp.timestamp())
# Add relative time with warning if very recent
@ -74,54 +76,44 @@ class SoakInfoCommands(commands.Cog):
if time_since and time_since.total_seconds() < 1800: # Less than 30 minutes
time_field_value += "\n\n😤 Way too soon!"
embed.add_field(
name="Last Mentioned",
value=time_field_value,
inline=False
)
embed.add_field(name="Last Mentioned", value=time_field_value, inline=False)
except Exception as e:
self.logger.error(f"Error parsing timestamp: {e}")
embed.add_field(
name="Last Mentioned",
value="Error parsing timestamp",
inline=False
name="Last Mentioned", value="Error parsing timestamp", inline=False
)
# Add user info
user_mention = f"<@{last_soak['user_id']}>"
display_name = last_soak.get('display_name', last_soak.get('username', 'Unknown'))
display_name = last_soak.get(
"display_name", last_soak.get("username", "Unknown")
)
embed.add_field(
name="By",
value=f"{user_mention} ({display_name})",
inline=True
name="By", value=f"{user_mention} ({display_name})", inline=True
)
# Add message link
try:
guild_id = interaction.guild_id
channel_id = last_soak['channel_id']
message_id = last_soak['message_id']
jump_url = f"https://discord.com/channels/{guild_id}/{channel_id}/{message_id}"
channel_id = last_soak["channel_id"]
message_id = last_soak["message_id"]
jump_url = (
f"https://discord.com/channels/{guild_id}/{channel_id}/{message_id}"
)
embed.add_field(
name="Message",
value=f"[Jump to message]({jump_url})",
inline=True
name="Message", value=f"[Jump to message]({jump_url})", inline=True
)
except Exception as e:
self.logger.error(f"Error creating jump URL: {e}")
# Add total count
embed.add_field(
name="Total Mentions",
value=str(total_count),
inline=True
)
embed.add_field(name="Total Mentions", value=str(total_count), inline=True)
# Add disappointment level
embed.add_field(
name="Disappointment Level",
value=f"{tier_key.replace('_', ' ').title()}: {tier_description}",
inline=False
inline=False,
)
await interaction.followup.send(embed=embed)

View File

@ -1,6 +1,7 @@
"""
Team roster commands for Discord Bot v2.0
"""
from typing import Dict, Any, List
import discord
@ -18,144 +19,173 @@ from views.embeds import EmbedTemplate, EmbedColors
class TeamRosterCommands(commands.Cog):
"""Team roster command handlers."""
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.TeamRosterCommands')
self.logger = get_contextual_logger(f"{__name__}.TeamRosterCommands")
self.logger.info("TeamRosterCommands cog initialized")
@discord.app_commands.command(name="roster", description="Display team roster")
@discord.app_commands.describe(
abbrev="Team abbreviation (e.g., BSG, DEN, WV, etc.)",
roster_type="Roster week: current or next (defaults to current)"
roster_type="Roster week: current or next (defaults to current)",
)
@discord.app_commands.choices(
roster_type=[
discord.app_commands.Choice(name="Current Week", value="current"),
discord.app_commands.Choice(name="Next Week", value="next"),
]
)
@discord.app_commands.choices(roster_type=[
discord.app_commands.Choice(name="Current Week", value="current"),
discord.app_commands.Choice(name="Next Week", value="next")
])
@requires_team()
@logged_command("/roster")
async def team_roster(self, interaction: discord.Interaction, abbrev: str,
roster_type: str = "current"):
async def team_roster(
self,
interaction: discord.Interaction,
abbrev: str,
roster_type: str = "current",
):
"""Display team roster with position breakdowns."""
await interaction.response.defer()
# Get team by abbreviation
team = await team_service.get_team_by_abbrev(abbrev, get_config().sba_season)
if team is None:
self.logger.info("Team not found", team_abbrev=abbrev)
embed = EmbedTemplate.error(
title="Team Not Found",
description=f"No team found with abbreviation '{abbrev.upper()}'"
description=f"No team found with abbreviation '{abbrev.upper()}'",
)
await interaction.followup.send(embed=embed)
return
# Get roster data
roster_data = await team_service.get_team_roster(team.id, roster_type)
if not roster_data:
embed = EmbedTemplate.error(
title="Roster Not Available",
description=f"No {roster_type} roster data available for {team.abbrev}"
description=f"No {roster_type} roster data available for {team.abbrev}",
)
await interaction.followup.send(embed=embed)
return
# Create roster embeds
embeds = await self._create_roster_embeds(team, roster_data, roster_type)
# Send first embed and follow up with others if needed
await interaction.followup.send(embed=embeds[0])
for embed in embeds[1:]:
await interaction.followup.send(embed=embed)
async def _create_roster_embeds(self, team: Team, roster_data: Dict[str, Any],
roster_type: str) -> List[discord.Embed]:
async def _create_roster_embeds(
self, team: Team, roster_data: Dict[str, Any], roster_type: str
) -> List[discord.Embed]:
"""Create embeds for team roster data."""
embeds = []
# Main roster embed
embed = EmbedTemplate.create_base_embed(
title=f"{team.abbrev} - {roster_type.title()} Week",
description=f"{team.lname} Roster Breakdown",
color=int(team.color, 16) if team.color else EmbedColors.PRIMARY
color=int(team.color, 16) if team.color else EmbedColors.PRIMARY,
)
# Position counts for active roster
for key in ['active', 'longil', 'shortil']:
if key in roster_data:
roster_titles = {
"active": "Active Roster",
"longil": "Minor League",
"shortil": "Injured List",
}
for key in ["active", "longil", "shortil"]:
if key in roster_data:
this_roster = roster_data[key]
players = this_roster.get('players')
if len(players) > 0:
this_team = players[0].get("team", {"id": "Unknown", "sname": "Unknown"})
embed.add_field(name=roster_titles[key], value="\u200b", inline=False)
embed.add_field(
name='Team (ID)',
value=f'{this_team.get("sname")} ({this_team.get("id")})',
inline=True
players = this_roster.get("players")
if len(players) > 0:
this_team = players[0].get(
"team", {"id": "Unknown", "sname": "Unknown"}
)
embed.add_field(
name='Player Count',
value=f'{len(players)} Players'
name="Team (ID)",
value=f'{this_team.get("sname")} ({this_team.get("id")})',
inline=True,
)
embed.add_field(
name="Player Count", value=f"{len(players)} Players"
)
# Total WAR
total_war = this_roster.get('WARa', 0)
total_war = this_roster.get("WARa", 0)
embed.add_field(
name="Total sWAR",
value=f"{total_war:.2f}" if isinstance(total_war, (int, float)) else str(total_war),
inline=True
name="Total sWAR",
value=(
f"{total_war:.2f}"
if isinstance(total_war, (int, float))
else str(total_war)
),
inline=True,
)
embed.add_field(
name='Position Counts',
name="Position Counts",
value=self._position_code_block(this_roster),
inline=False
inline=False,
)
embeds.append(embed)
# Create detailed player list embeds if there are players
for roster_name, roster_info in roster_data.items():
if roster_name in ['active', 'longil', 'shortil'] and 'players' in roster_info:
players = sorted(roster_info['players'], key=lambda player: player.get('wara', 0), reverse=True)
if (
roster_name in ["active", "longil", "shortil"]
and "players" in roster_info
):
players = sorted(
roster_info["players"],
key=lambda player: player.get("wara", 0),
reverse=True,
)
if players:
player_embed = self._create_player_list_embed(
team, roster_name, players
)
embeds.append(player_embed)
return embeds
def _position_code_block(self, roster_data: dict) -> str:
return f'```\n C 1B 2B 3B SS\n' \
f' {roster_data.get("C", 0)} {roster_data.get("1B", 0)} {roster_data.get("2B", 0)} ' \
f'{roster_data.get("3B", 0)} {roster_data.get("SS", 0)}\n\nLF CF RF SP RP\n' \
f' {roster_data.get("LF", 0)} {roster_data.get("CF", 0)} {roster_data.get("RF", 0)} ' \
f'{roster_data.get("SP", 0)} {roster_data.get("RP", 0)}\n```'
def _create_player_list_embed(self, team: Team, roster_name: str,
players: List[Dict[str, Any]]) -> discord.Embed:
return embeds
def _position_code_block(self, roster_data: dict) -> str:
return (
f"```\n C 1B 2B 3B SS\n"
f' {roster_data.get("C", 0)} {roster_data.get("1B", 0)} {roster_data.get("2B", 0)} '
f'{roster_data.get("3B", 0)} {roster_data.get("SS", 0)}\n\nLF CF RF SP RP\n'
f' {roster_data.get("LF", 0)} {roster_data.get("CF", 0)} {roster_data.get("RF", 0)} '
f'{roster_data.get("SP", 0)} {roster_data.get("RP", 0)}\n```'
)
def _create_player_list_embed(
self, team: Team, roster_name: str, players: List[Dict[str, Any]]
) -> discord.Embed:
"""Create an embed with detailed player list."""
roster_titles = {
'active': 'Active Roster',
'longil': 'Minor League',
'shortil': 'Injured List'
"active": "Active Roster",
"longil": "Minor League",
"shortil": "Injured List",
}
embed = EmbedTemplate.create_base_embed(
title=f"{team.abbrev} - {roster_titles.get(roster_name, roster_name.title())}",
color=int(team.color, 16) if team.color else EmbedColors.PRIMARY
color=int(team.color, 16) if team.color else EmbedColors.PRIMARY,
)
# Group players by position for better organization
batters = []
pitchers = []
for player in players:
try:
this_player = Player.from_api_data(player)
@ -166,8 +196,11 @@ class TeamRosterCommands(commands.Cog):
else:
batters.append(player_line)
except Exception as e:
self.logger.warning(f"Failed to create player from data: {e}", player_id=player.get('id'))
self.logger.warning(
f"Failed to create player from data: {e}",
player_id=player.get("id"),
)
# Add player lists to embed
if batters:
# Split long lists into multiple fields if needed
@ -175,18 +208,18 @@ class TeamRosterCommands(commands.Cog):
for i, chunk in enumerate(batter_chunks):
field_name = "Batters" if i == 0 else f"Batters (cont.)"
embed.add_field(name=field_name, value="\n".join(chunk), inline=True)
embed.add_field(name='', value='', inline=False)
embed.add_field(name="", value="", inline=False)
if pitchers:
pitcher_chunks = self._chunk_list(pitchers, 16)
for i, chunk in enumerate(pitcher_chunks):
field_name = "Pitchers" if i == 0 else f"Pitchers (cont.)"
embed.add_field(name=field_name, value="\n".join(chunk), inline=False)
embed.set_footer(text=f"Total players: {len(players)}")
return embed
def _chunk_list(self, lst: List[str], chunk_size: int) -> List[List[str]]:
"""Split a list into chunks of specified size."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

View File

@ -3,6 +3,7 @@ Trade Commands
Interactive multi-team trade builder with real-time validation and elegant UX.
"""
from typing import Optional
import discord
@ -12,7 +13,11 @@ from discord import app_commands
from config import get_config
from utils.logging import get_contextual_logger
from utils.decorators import logged_command
from utils.autocomplete import player_autocomplete, major_league_team_autocomplete, team_autocomplete
from utils.autocomplete import (
player_autocomplete,
major_league_team_autocomplete,
team_autocomplete,
)
from utils.team_utils import validate_user_has_team, get_team_by_abbrev_with_validation
from services.trade_builder import (
@ -22,6 +27,7 @@ from services.trade_builder import (
clear_trade_builder_by_team,
)
from services.player_service import player_service
from services.team_service import team_service
from models.team import RosterType
from views.trade_embed import TradeEmbedView, create_trade_embed
from commands.transactions.trade_channels import TradeChannelManager
@ -33,16 +39,20 @@ class TradeCommands(commands.Cog):
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.TradeCommands')
self.logger = get_contextual_logger(f"{__name__}.TradeCommands")
# Initialize trade channel management
self.channel_tracker = TradeChannelTracker()
self.channel_manager = TradeChannelManager(self.channel_tracker)
# Create the trade command group
trade_group = app_commands.Group(name="trade", description="Multi-team trade management")
trade_group = app_commands.Group(
name="trade", description="Multi-team trade management"
)
def _get_trade_channel(self, guild: discord.Guild, trade_id: str) -> Optional[discord.TextChannel]:
def _get_trade_channel(
self, guild: discord.Guild, trade_id: str
) -> Optional[discord.TextChannel]:
"""Get the trade channel for a given trade ID."""
channel_data = self.channel_tracker.get_channel_by_trade_id(trade_id)
if not channel_data:
@ -55,7 +65,9 @@ class TradeCommands(commands.Cog):
return channel
return None
def _is_in_trade_channel(self, interaction: discord.Interaction, trade_id: str) -> bool:
def _is_in_trade_channel(
self, interaction: discord.Interaction, trade_id: str
) -> bool:
"""Check if the interaction is happening in the trade's dedicated channel."""
trade_channel = self._get_trade_channel(interaction.guild, trade_id)
if not trade_channel:
@ -68,7 +80,7 @@ class TradeCommands(commands.Cog):
trade_id: str,
embed: discord.Embed,
view: Optional[discord.ui.View] = None,
content: Optional[str] = None
content: Optional[str] = None,
) -> bool:
"""
Post the trade embed to the trade channel.
@ -90,19 +102,12 @@ class TradeCommands(commands.Cog):
return False
@trade_group.command(
name="initiate",
description="Start a new trade with another team"
)
@app_commands.describe(
other_team="Team abbreviation to trade with"
name="initiate", description="Start a new trade with another team"
)
@app_commands.describe(other_team="Team abbreviation to trade with")
@app_commands.autocomplete(other_team=major_league_team_autocomplete)
@logged_command("/trade initiate")
async def trade_initiate(
self,
interaction: discord.Interaction,
other_team: str
):
async def trade_initiate(self, interaction: discord.Interaction, other_team: str):
"""Initiate a new trade with another team."""
await interaction.response.defer(ephemeral=True)
@ -112,15 +117,16 @@ class TradeCommands(commands.Cog):
return
# Get the other team
other_team_obj = await get_team_by_abbrev_with_validation(other_team, interaction)
other_team_obj = await get_team_by_abbrev_with_validation(
other_team, interaction
)
if not other_team_obj:
return
# Check if it's the same team
if user_team.id == other_team_obj.id:
await interaction.followup.send(
"❌ You cannot initiate a trade with yourself.",
ephemeral=True
"❌ You cannot initiate a trade with yourself.", ephemeral=True
)
return
@ -133,7 +139,7 @@ class TradeCommands(commands.Cog):
if not success:
await interaction.followup.send(
f"❌ Failed to add {other_team_obj.abbrev} to trade: {error_msg}",
ephemeral=True
ephemeral=True,
)
return
@ -143,7 +149,7 @@ class TradeCommands(commands.Cog):
trade_id=trade_builder.trade_id,
team1=user_team,
team2=other_team_obj,
creator_id=interaction.user.id
creator_id=interaction.user.id,
)
# Show trade interface
@ -156,31 +162,26 @@ class TradeCommands(commands.Cog):
success_msg += f"\n📝 Discussion channel: {channel.mention}"
else:
success_msg += f"\n⚠️ **Warning:** Failed to create discussion channel. Check bot permissions or contact an admin."
self.logger.warning(f"Failed to create trade channel for trade {trade_builder.trade_id}")
self.logger.warning(
f"Failed to create trade channel for trade {trade_builder.trade_id}"
)
await interaction.followup.send(
content=success_msg,
embed=embed,
view=view,
ephemeral=True
content=success_msg, embed=embed, view=view, ephemeral=True
)
self.logger.info(f"Trade initiated: {user_team.abbrev}{other_team_obj.abbrev}")
self.logger.info(
f"Trade initiated: {user_team.abbrev}{other_team_obj.abbrev}"
)
@trade_group.command(
name="add-team",
description="Add another team to your current trade (for 3+ team trades)"
)
@app_commands.describe(
other_team="Team abbreviation to add to the trade"
description="Add another team to your current trade (for 3+ team trades)",
)
@app_commands.describe(other_team="Team abbreviation to add to the trade")
@app_commands.autocomplete(other_team=major_league_team_autocomplete)
@logged_command("/trade add-team")
async def trade_add_team(
self,
interaction: discord.Interaction,
other_team: str
):
async def trade_add_team(self, interaction: discord.Interaction, other_team: str):
"""Add a team to an existing trade."""
await interaction.response.defer(ephemeral=False)
@ -194,7 +195,7 @@ class TradeCommands(commands.Cog):
if not trade_builder:
await interaction.followup.send(
"❌ Your team is not part of an active trade. Use `/trade initiate` first.",
ephemeral=True
ephemeral=True,
)
return
@ -207,8 +208,7 @@ class TradeCommands(commands.Cog):
success, error_msg = await trade_builder.add_team(team_to_add)
if not success:
await interaction.followup.send(
f"❌ Failed to add {team_to_add.abbrev}: {error_msg}",
ephemeral=True
f"❌ Failed to add {team_to_add.abbrev}: {error_msg}", ephemeral=True
)
return
@ -216,7 +216,7 @@ class TradeCommands(commands.Cog):
channel_updated = await self.channel_manager.add_team_to_channel(
guild=interaction.guild,
trade_id=trade_builder.trade_id,
new_team=team_to_add
new_team=team_to_add,
)
# Show updated trade interface
@ -226,13 +226,12 @@ class TradeCommands(commands.Cog):
# Build success message
success_msg = f"✅ **Added {team_to_add.abbrev} to the trade**"
if channel_updated:
success_msg += f"\n📝 {team_to_add.abbrev} has been added to the discussion channel"
success_msg += (
f"\n📝 {team_to_add.abbrev} has been added to the discussion channel"
)
await interaction.followup.send(
content=success_msg,
embed=embed,
view=view,
ephemeral=True
content=success_msg, embed=embed, view=view, ephemeral=True
)
# If command was executed outside trade channel, post update to trade channel
@ -242,27 +241,23 @@ class TradeCommands(commands.Cog):
trade_id=trade_builder.trade_id,
embed=embed,
view=view,
content=success_msg
content=success_msg,
)
self.logger.info(f"Team added to trade {trade_builder.trade_id}: {team_to_add.abbrev}")
self.logger.info(
f"Team added to trade {trade_builder.trade_id}: {team_to_add.abbrev}"
)
@trade_group.command(
name="add-player",
description="Add a player to the trade"
)
@trade_group.command(name="add-player", description="Add a player to the trade")
@app_commands.describe(
player_name="Player name; begin typing for autocomplete",
destination_team="Team abbreviation where the player will go"
destination_team="Team abbreviation where the player will go",
)
@app_commands.autocomplete(player_name=player_autocomplete)
@app_commands.autocomplete(destination_team=team_autocomplete)
@logged_command("/trade add-player")
async def trade_add_player(
self,
interaction: discord.Interaction,
player_name: str,
destination_team: str
self, interaction: discord.Interaction, player_name: str, destination_team: str
):
"""Add a player move to the trade."""
await interaction.response.defer(ephemeral=False)
@ -277,16 +272,17 @@ class TradeCommands(commands.Cog):
if not trade_builder:
await interaction.followup.send(
"❌ Your team is not part of an active trade. Use `/trade initiate` or ask another GM to add your team.",
ephemeral=True
ephemeral=True,
)
return
# Find the player
players = await player_service.search_players(player_name, limit=10, season=get_config().sba_season)
players = await player_service.search_players(
player_name, limit=10, season=get_config().sba_season
)
if not players:
await interaction.followup.send(
f"❌ Player '{player_name}' not found.",
ephemeral=True
f"❌ Player '{player_name}' not found.", ephemeral=True
)
return
@ -300,15 +296,19 @@ class TradeCommands(commands.Cog):
player = players[0]
# Get destination team
dest_team = await get_team_by_abbrev_with_validation(destination_team, interaction)
dest_team = await get_team_by_abbrev_with_validation(
destination_team, interaction
)
if not dest_team:
return
# Determine source team and roster locations
# For now, assume player comes from user's team and goes to ML of destination
# The service will validate that the player is actually on the user's team organization
from_roster = RosterType.MAJOR_LEAGUE # Default assumption
to_roster = RosterType.MAJOR_LEAGUE # Default destination
# Auto-detect source roster from player's actual team assignment
player_team = await team_service.get_team(player.team_id)
if player_team:
from_roster = player_team.roster_type()
else:
from_roster = RosterType.MAJOR_LEAGUE # Fallback
to_roster = dest_team.roster_type()
# Add the player move (service layer will validate)
success, error_msg = await trade_builder.add_player_move(
@ -316,26 +316,22 @@ class TradeCommands(commands.Cog):
from_team=user_team,
to_team=dest_team,
from_roster=from_roster,
to_roster=to_roster
to_roster=to_roster,
)
if not success:
await interaction.followup.send(
f"{error_msg}",
ephemeral=True
)
await interaction.followup.send(f"{error_msg}", ephemeral=True)
return
# Show updated trade interface
embed = await create_trade_embed(trade_builder)
view = TradeEmbedView(trade_builder, interaction.user.id)
success_msg = f"✅ **Added {player.name}: {user_team.abbrev}{dest_team.abbrev}**"
success_msg = (
f"✅ **Added {player.name}: {user_team.abbrev}{dest_team.abbrev}**"
)
await interaction.followup.send(
content=success_msg,
embed=embed,
view=view,
ephemeral=True
content=success_msg, embed=embed, view=view, ephemeral=True
)
# If command was executed outside trade channel, post update to trade channel
@ -345,31 +341,32 @@ class TradeCommands(commands.Cog):
trade_id=trade_builder.trade_id,
embed=embed,
view=view,
content=success_msg
content=success_msg,
)
self.logger.info(f"Player added to trade {trade_builder.trade_id}: {player.name} to {dest_team.abbrev}")
self.logger.info(
f"Player added to trade {trade_builder.trade_id}: {player.name} to {dest_team.abbrev}"
)
@trade_group.command(
name="supplementary",
description="Add a supplementary move within your organization for roster legality"
description="Add a supplementary move within your organization for roster legality",
)
@app_commands.describe(
player_name="Player name; begin typing for autocomplete",
destination="Where to move the player: Major League, Minor League, or Free Agency"
destination="Where to move the player: Major League, Minor League, or Free Agency",
)
@app_commands.autocomplete(player_name=player_autocomplete)
@app_commands.choices(destination=[
app_commands.Choice(name="Major League", value="ml"),
app_commands.Choice(name="Minor League", value="mil"),
app_commands.Choice(name="Free Agency", value="fa")
])
@app_commands.choices(
destination=[
app_commands.Choice(name="Major League", value="ml"),
app_commands.Choice(name="Minor League", value="mil"),
app_commands.Choice(name="Free Agency", value="fa"),
]
)
@logged_command("/trade supplementary")
async def trade_supplementary(
self,
interaction: discord.Interaction,
player_name: str,
destination: str
self, interaction: discord.Interaction, player_name: str, destination: str
):
"""Add a supplementary (internal organization) move for roster legality."""
await interaction.response.defer(ephemeral=False)
@ -384,16 +381,17 @@ class TradeCommands(commands.Cog):
if not trade_builder:
await interaction.followup.send(
"❌ Your team is not part of an active trade. Use `/trade initiate` or ask another GM to add your team.",
ephemeral=True
ephemeral=True,
)
return
# Find the player
players = await player_service.search_players(player_name, limit=10, season=get_config().sba_season)
players = await player_service.search_players(
player_name, limit=10, season=get_config().sba_season
)
if not players:
await interaction.followup.send(
f"❌ Player '{player_name}' not found.",
ephemeral=True
f"❌ Player '{player_name}' not found.", ephemeral=True
)
return
@ -403,45 +401,47 @@ class TradeCommands(commands.Cog):
destination_map = {
"ml": RosterType.MAJOR_LEAGUE,
"mil": RosterType.MINOR_LEAGUE,
"fa": RosterType.FREE_AGENCY
"fa": RosterType.FREE_AGENCY,
}
to_roster = destination_map.get(destination.lower())
if not to_roster:
await interaction.followup.send(
f"❌ Invalid destination: {destination}",
ephemeral=True
f"❌ Invalid destination: {destination}", ephemeral=True
)
return
# Determine current roster (default assumption)
from_roster = RosterType.MINOR_LEAGUE if to_roster == RosterType.MAJOR_LEAGUE else RosterType.MAJOR_LEAGUE
# Auto-detect source roster from player's actual team assignment
player_team = await team_service.get_team(player.team_id)
if player_team:
from_roster = player_team.roster_type()
else:
from_roster = (
RosterType.MINOR_LEAGUE
if to_roster == RosterType.MAJOR_LEAGUE
else RosterType.MAJOR_LEAGUE
)
# Add supplementary move
success, error_msg = await trade_builder.add_supplementary_move(
team=user_team,
player=player,
from_roster=from_roster,
to_roster=to_roster
team=user_team, player=player, from_roster=from_roster, to_roster=to_roster
)
if not success:
await interaction.followup.send(
f"❌ Failed to add supplementary move: {error_msg}",
ephemeral=True
f"❌ Failed to add supplementary move: {error_msg}", ephemeral=True
)
return
# Show updated trade interface
embed = await create_trade_embed(trade_builder)
view = TradeEmbedView(trade_builder, interaction.user.id)
success_msg = f"✅ **Added supplementary move: {player.name}{destination.upper()}**"
success_msg = (
f"✅ **Added supplementary move: {player.name}{destination.upper()}**"
)
await interaction.followup.send(
content=success_msg,
embed=embed,
view=view,
ephemeral=True
content=success_msg, embed=embed, view=view, ephemeral=True
)
# If command was executed outside trade channel, post update to trade channel
@ -451,15 +451,14 @@ class TradeCommands(commands.Cog):
trade_id=trade_builder.trade_id,
embed=embed,
view=view,
content=success_msg
content=success_msg,
)
self.logger.info(f"Supplementary move added to trade {trade_builder.trade_id}: {player.name} to {destination}")
self.logger.info(
f"Supplementary move added to trade {trade_builder.trade_id}: {player.name} to {destination}"
)
@trade_group.command(
name="view",
description="View your current trade"
)
@trade_group.command(name="view", description="View your current trade")
@logged_command("/trade view")
async def trade_view(self, interaction: discord.Interaction):
"""View the current trade."""
@ -474,8 +473,7 @@ class TradeCommands(commands.Cog):
trade_builder = get_trade_builder_by_team(user_team.id)
if not trade_builder:
await interaction.followup.send(
"❌ Your team is not part of an active trade.",
ephemeral=True
"❌ Your team is not part of an active trade.", ephemeral=True
)
return
@ -483,11 +481,7 @@ class TradeCommands(commands.Cog):
embed = await create_trade_embed(trade_builder)
view = TradeEmbedView(trade_builder, interaction.user.id)
await interaction.followup.send(
embed=embed,
view=view,
ephemeral=True
)
await interaction.followup.send(embed=embed, view=view, ephemeral=True)
# If command was executed outside trade channel, post update to trade channel
if not self._is_in_trade_channel(interaction, trade_builder.trade_id):
@ -495,13 +489,10 @@ class TradeCommands(commands.Cog):
guild=interaction.guild,
trade_id=trade_builder.trade_id,
embed=embed,
view=view
view=view,
)
@trade_group.command(
name="clear",
description="Clear your current trade"
)
@trade_group.command(name="clear", description="Clear your current trade")
@logged_command("/trade clear")
async def trade_clear(self, interaction: discord.Interaction):
"""Clear the current trade."""
@ -516,8 +507,7 @@ class TradeCommands(commands.Cog):
trade_builder = get_trade_builder_by_team(user_team.id)
if not trade_builder:
await interaction.followup.send(
"❌ Your team is not part of an active trade.",
ephemeral=True
"❌ Your team is not part of an active trade.", ephemeral=True
)
return
@ -525,19 +515,17 @@ class TradeCommands(commands.Cog):
# Delete associated trade channel if it exists
await self.channel_manager.delete_trade_channel(
guild=interaction.guild,
trade_id=trade_id
guild=interaction.guild, trade_id=trade_id
)
# Clear the trade builder using team-based function
clear_trade_builder_by_team(user_team.id)
await interaction.followup.send(
"✅ The trade has been cleared.",
ephemeral=True
"✅ The trade has been cleared.", ephemeral=True
)
async def setup(bot):
"""Setup function for the cog."""
await bot.add_cog(TradeCommands(bot))
await bot.add_cog(TradeCommands(bot))

View File

@ -7,6 +7,7 @@ This model matches the database schema at /database/app/routers_v3/stratplay.py.
NOTE: ID fields have corresponding optional model object fields for API-populated nested data.
Future enhancement could add validators to ensure consistency between ID and model fields.
"""
from typing import Optional, Literal
from pydantic import Field, field_validator
from models.base import SBABaseModel
@ -28,9 +29,11 @@ class Play(SBABaseModel):
game: Optional[Game] = Field(None, description="Game object (API-populated)")
play_num: int = Field(..., description="Sequential play number in game")
pitcher_id: Optional[int] = Field(None, description="Pitcher ID")
pitcher: Optional[Player] = Field(None, description="Pitcher object (API-populated)")
pitcher: Optional[Player] = Field(
None, description="Pitcher object (API-populated)"
)
on_base_code: str = Field(..., description="Base runners code (e.g., '100', '011')")
inning_half: Literal['top', 'bot'] = Field(..., description="Inning half")
inning_half: Literal["top", "bot"] = Field(..., description="Inning half")
inning_num: int = Field(..., description="Inning number")
batting_order: int = Field(..., description="Batting order position")
starting_outs: int = Field(..., description="Outs at start of play")
@ -41,21 +44,37 @@ class Play(SBABaseModel):
batter_id: Optional[int] = Field(None, description="Batter ID")
batter: Optional[Player] = Field(None, description="Batter object (API-populated)")
batter_team_id: Optional[int] = Field(None, description="Batter's team ID")
batter_team: Optional[Team] = Field(None, description="Batter's team object (API-populated)")
batter_team: Optional[Team] = Field(
None, description="Batter's team object (API-populated)"
)
pitcher_team_id: Optional[int] = Field(None, description="Pitcher's team ID")
pitcher_team: Optional[Team] = Field(None, description="Pitcher's team object (API-populated)")
pitcher_team: Optional[Team] = Field(
None, description="Pitcher's team object (API-populated)"
)
batter_pos: Optional[str] = Field(None, description="Batter's position")
# Base runner information
on_first_id: Optional[int] = Field(None, description="Runner on first ID")
on_first: Optional[Player] = Field(None, description="Runner on first object (API-populated)")
on_first_final: Optional[int] = Field(None, description="Runner on first final base")
on_first: Optional[Player] = Field(
None, description="Runner on first object (API-populated)"
)
on_first_final: Optional[int] = Field(
None, description="Runner on first final base"
)
on_second_id: Optional[int] = Field(None, description="Runner on second ID")
on_second: Optional[Player] = Field(None, description="Runner on second object (API-populated)")
on_second_final: Optional[int] = Field(None, description="Runner on second final base")
on_second: Optional[Player] = Field(
None, description="Runner on second object (API-populated)"
)
on_second_final: Optional[int] = Field(
None, description="Runner on second final base"
)
on_third_id: Optional[int] = Field(None, description="Runner on third ID")
on_third: Optional[Player] = Field(None, description="Runner on third object (API-populated)")
on_third_final: Optional[int] = Field(None, description="Runner on third final base")
on_third: Optional[Player] = Field(
None, description="Runner on third object (API-populated)"
)
on_third_final: Optional[int] = Field(
None, description="Runner on third final base"
)
batter_final: Optional[int] = Field(None, description="Batter's final base")
# Statistical fields (all default to 0)
@ -96,17 +115,27 @@ class Play(SBABaseModel):
# Defensive players
catcher_id: Optional[int] = Field(None, description="Catcher ID")
catcher: Optional[Player] = Field(None, description="Catcher object (API-populated)")
catcher: Optional[Player] = Field(
None, description="Catcher object (API-populated)"
)
catcher_team_id: Optional[int] = Field(None, description="Catcher's team ID")
catcher_team: Optional[Team] = Field(None, description="Catcher's team object (API-populated)")
catcher_team: Optional[Team] = Field(
None, description="Catcher's team object (API-populated)"
)
defender_id: Optional[int] = Field(None, description="Defender ID")
defender: Optional[Player] = Field(None, description="Defender object (API-populated)")
defender: Optional[Player] = Field(
None, description="Defender object (API-populated)"
)
defender_team_id: Optional[int] = Field(None, description="Defender's team ID")
defender_team: Optional[Team] = Field(None, description="Defender's team object (API-populated)")
defender_team: Optional[Team] = Field(
None, description="Defender's team object (API-populated)"
)
runner_id: Optional[int] = Field(None, description="Runner ID")
runner: Optional[Player] = Field(None, description="Runner object (API-populated)")
runner_team_id: Optional[int] = Field(None, description="Runner's team ID")
runner_team: Optional[Team] = Field(None, description="Runner's team object (API-populated)")
runner_team: Optional[Team] = Field(
None, description="Runner's team object (API-populated)"
)
# Defensive plays
check_pos: Optional[str] = Field(None, description="Position checked")
@ -126,35 +155,35 @@ class Play(SBABaseModel):
hand_pitching: Optional[str] = Field(None, description="Pitcher handedness (L/R)")
# Validators from database model
@field_validator('on_first_final')
@field_validator("on_first_final")
@classmethod
def no_final_if_no_runner_one(cls, v, info):
"""Validate on_first_final is None if no runner on first."""
if info.data.get('on_first_id') is None:
if info.data.get("on_first_id") is None:
return None
return v
@field_validator('on_second_final')
@field_validator("on_second_final")
@classmethod
def no_final_if_no_runner_two(cls, v, info):
"""Validate on_second_final is None if no runner on second."""
if info.data.get('on_second_id') is None:
if info.data.get("on_second_id") is None:
return None
return v
@field_validator('on_third_final')
@field_validator("on_third_final")
@classmethod
def no_final_if_no_runner_three(cls, v, info):
"""Validate on_third_final is None if no runner on third."""
if info.data.get('on_third_id') is None:
if info.data.get("on_third_id") is None:
return None
return v
@field_validator('batter_final')
@field_validator("batter_final")
@classmethod
def no_final_if_no_batter(cls, v, info):
"""Validate batter_final is None if no batter."""
if info.data.get('batter_id') is None:
if info.data.get("batter_id") is None:
return None
return v
@ -170,25 +199,28 @@ class Play(SBABaseModel):
Formatted string like: "Top 3: Player Name (NYY) homers in 2 runs"
"""
# Determine inning text
inning_text = f"{'Top' if self.inning_half == 'top' else 'Bot'} {self.inning_num}"
inning_text = (
f"{'Top' if self.inning_half == 'top' else 'Bot'} {self.inning_num}"
)
# Determine team abbreviation based on inning half
away_score = self.away_score
home_score = self.home_score
if self.inning_half == 'top':
if self.inning_half == "top":
away_score += self.rbi
else:
home_score += self.rbi
score_text = f'tied at {home_score}'
if home_score > away_score:
score_text = f'{home_team.abbrev} up {home_score}-{away_score}'
score_text = f"{home_team.abbrev} up {home_score}-{away_score}"
elif away_score > home_score:
score_text = f"{away_team.abbrev} up {away_score}-{home_score}"
else:
score_text = f'{away_team.abbrev} up {away_score}-{home_score}'
score_text = f"tied at {home_score}"
# Build play description based on play type
description_parts = []
which_player = 'batter'
which_player = "batter"
# Offensive plays
if self.homerun > 0:
@ -199,63 +231,79 @@ class Play(SBABaseModel):
elif self.triple > 0:
description_parts.append("triples")
if self.rbi > 0:
description_parts.append(f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.double > 0:
description_parts.append("doubles")
if self.rbi > 0:
description_parts.append(f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.hit > 0:
description_parts.append("singles")
if self.rbi > 0:
description_parts.append(f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.bb > 0:
if self.ibb > 0:
description_parts.append("intentionally walked")
else:
description_parts.append("walks")
if self.rbi > 0:
description_parts.append(f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.hbp > 0:
description_parts.append("hit by pitch")
if self.rbi > 0:
description_parts.append(f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.sac > 0:
description_parts.append("sacrifice fly")
if self.rbi > 0:
description_parts.append(f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"scoring {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.sb > 0:
description_parts.append("steals a base")
elif self.cs > 0:
which_player = 'catcher'
which_player = "catcher"
description_parts.append("guns down a baserunner")
elif self.gidp > 0:
description_parts.append("grounds into double play")
elif self.so > 0:
which_player = 'pitcher'
which_player = "pitcher"
description_parts.append(f"gets a strikeout")
# Defensive plays
elif self.error > 0:
which_player = 'defender'
which_player = "defender"
description_parts.append("commits an error")
if self.rbi > 0:
description_parts.append(f"allowing {self.rbi} run{'s' if self.rbi > 1 else ''}")
description_parts.append(
f"allowing {self.rbi} run{'s' if self.rbi > 1 else ''}"
)
elif self.wild_pitch > 0:
which_player = 'pitcher'
which_player = "pitcher"
description_parts.append("uncorks a wild pitch")
elif self.passed_ball > 0:
which_player = 'catcher'
which_player = "catcher"
description_parts.append("passed ball")
elif self.pick_off > 0:
which_player = 'runner'
which_player = "runner"
description_parts.append("picked off")
elif self.balk > 0:
which_player = 'pitcher'
which_player = "pitcher"
description_parts.append("balk")
else:
# Generic out
if self.outs > 0:
which_player = 'pitcher'
description_parts.append(f'records out number {self.starting_outs + self.outs}')
which_player = "pitcher"
description_parts.append(
f"records out number {self.starting_outs + self.outs}"
)
# Combine parts
if description_parts:
@ -264,18 +312,18 @@ class Play(SBABaseModel):
play_desc = "makes a play"
player_dict = {
'batter': self.batter,
'pitcher': self.pitcher,
'catcher': self.catcher,
'runner': self.runner,
'defender': self.defender
"batter": self.batter,
"pitcher": self.pitcher,
"catcher": self.catcher,
"runner": self.runner,
"defender": self.defender,
}
team_dict = {
'batter': self.batter_team,
'pitcher': self.pitcher_team,
'catcher': self.catcher_team,
'runner': self.runner_team,
'defender': self.defender_team
"batter": self.batter_team,
"pitcher": self.pitcher_team,
"catcher": self.catcher_team,
"runner": self.runner_team,
"defender": self.defender_team,
}
# Format: "Top 3: Derek Jeter (NYY) homers in 2 runs, NYY up 2-0"

View File

@ -8,7 +8,9 @@ import logging
from typing import Optional, Dict, Any
from datetime import UTC, datetime, timedelta
from config import get_config
from services.base_service import BaseService
from services.draft_pick_service import draft_pick_service
from models.draft_data import DraftData
logger = logging.getLogger(f"{__name__}.DraftService")
@ -162,9 +164,6 @@ class DraftService(BaseService[DraftData]):
Updated DraftData with new currentpick
"""
try:
from services.draft_pick_service import draft_pick_service
from config import get_config
config = get_config()
season = config.sba_season
total_picks = config.draft_total_picks

View File

@ -98,6 +98,13 @@ class GiphyService:
self.api_key = self.config.giphy_api_key
self.translate_url = self.config.giphy_translate_url
self.logger = get_contextual_logger(f"{__name__}.GiphyService")
self._session: Optional[aiohttp.ClientSession] = None
def _get_session(self) -> aiohttp.ClientSession:
"""Return the shared aiohttp session, creating it lazily if needed."""
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
return self._session
def get_tier_for_seconds(self, seconds_elapsed: Optional[int]) -> str:
"""
@ -181,55 +188,53 @@ class GiphyService:
# Shuffle phrases for variety and retry capability
shuffled_phrases = random.sample(phrases, len(phrases))
async with aiohttp.ClientSession() as session:
for phrase in shuffled_phrases:
try:
url = f"{self.translate_url}?s={quote(phrase)}&api_key={quote(self.api_key)}"
session = self._get_session()
for phrase in shuffled_phrases:
try:
url = f"{self.translate_url}?s={quote(phrase)}&api_key={quote(self.api_key)}"
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=5)
) as resp:
if resp.status == 200:
data = await resp.json()
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=5)
) as resp:
if resp.status == 200:
data = await resp.json()
# Filter out Trump GIFs (legacy behavior)
gif_title = data.get("data", {}).get("title", "").lower()
if "trump" in gif_title:
self.logger.debug(
f"Filtered out Trump GIF for phrase: {phrase}"
)
continue
# Get the actual GIF image URL, not the web page URL
gif_url = (
data.get("data", {})
.get("images", {})
.get("original", {})
.get("url")
# Filter out Trump GIFs (legacy behavior)
gif_title = data.get("data", {}).get("title", "").lower()
if "trump" in gif_title:
self.logger.debug(
f"Filtered out Trump GIF for phrase: {phrase}"
)
if gif_url:
self.logger.info(
f"Successfully fetched GIF for phrase: {phrase}",
gif_url=gif_url,
)
return gif_url
else:
self.logger.warning(
f"No GIF URL in response for phrase: {phrase}"
)
continue
# Get the actual GIF image URL, not the web page URL
gif_url = (
data.get("data", {})
.get("images", {})
.get("original", {})
.get("url")
)
if gif_url:
self.logger.info(
f"Successfully fetched GIF for phrase: {phrase}",
gif_url=gif_url,
)
return gif_url
else:
self.logger.warning(
f"Giphy API returned status {resp.status} for phrase: {phrase}"
f"No GIF URL in response for phrase: {phrase}"
)
else:
self.logger.warning(
f"Giphy API returned status {resp.status} for phrase: {phrase}"
)
except aiohttp.ClientError as e:
self.logger.error(
f"HTTP error fetching GIF for phrase '{phrase}': {e}"
)
except Exception as e:
self.logger.error(
f"Unexpected error fetching GIF for phrase '{phrase}': {e}"
)
except aiohttp.ClientError as e:
self.logger.error(f"HTTP error fetching GIF for phrase '{phrase}': {e}")
except Exception as e:
self.logger.error(
f"Unexpected error fetching GIF for phrase '{phrase}': {e}"
)
# All phrases failed
error_msg = f"Failed to fetch any GIF for tier: {tier_key}"
@ -264,58 +269,58 @@ class GiphyService:
elif phrase_options is not None:
search_phrase = random.choice(phrase_options)
async with aiohttp.ClientSession() as session:
attempts = 0
while attempts < 3:
attempts += 1
try:
url = f"{self.translate_url}?s={quote(search_phrase)}&api_key={quote(self.api_key)}"
session = self._get_session()
attempts = 0
while attempts < 3:
attempts += 1
try:
url = f"{self.translate_url}?s={quote(search_phrase)}&api_key={quote(self.api_key)}"
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
) as resp:
if resp.status != 200:
self.logger.warning(
f"Giphy API returned status {resp.status} for phrase: {search_phrase}"
)
continue
data = await resp.json()
# Filter out Trump GIFs (legacy behavior)
gif_title = data.get("data", {}).get("title", "").lower()
if "trump" in gif_title:
self.logger.debug(
f"Filtered out Trump GIF for phrase: {search_phrase}"
)
continue
# Get the actual GIF image URL, not the web page URL
gif_url = (
data.get("data", {})
.get("images", {})
.get("original", {})
.get("url")
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
) as resp:
if resp.status != 200:
self.logger.warning(
f"Giphy API returned status {resp.status} for phrase: {search_phrase}"
)
if gif_url:
self.logger.info(
f"Successfully fetched GIF for phrase: {search_phrase}",
gif_url=gif_url,
)
return gif_url
else:
self.logger.warning(
f"No GIF URL in response for phrase: {search_phrase}"
)
continue
except aiohttp.ClientError as e:
self.logger.error(
f"HTTP error fetching GIF for phrase '{search_phrase}': {e}"
)
except Exception as e:
self.logger.error(
f"Unexpected error fetching GIF for phrase '{search_phrase}': {e}"
data = await resp.json()
# Filter out Trump GIFs (legacy behavior)
gif_title = data.get("data", {}).get("title", "").lower()
if "trump" in gif_title:
self.logger.debug(
f"Filtered out Trump GIF for phrase: {search_phrase}"
)
continue
# Get the actual GIF image URL, not the web page URL
gif_url = (
data.get("data", {})
.get("images", {})
.get("original", {})
.get("url")
)
if gif_url:
self.logger.info(
f"Successfully fetched GIF for phrase: {search_phrase}",
gif_url=gif_url,
)
return gif_url
else:
self.logger.warning(
f"No GIF URL in response for phrase: {search_phrase}"
)
except aiohttp.ClientError as e:
self.logger.error(
f"HTTP error fetching GIF for phrase '{search_phrase}': {e}"
)
except Exception as e:
self.logger.error(
f"Unexpected error fetching GIF for phrase '{search_phrase}': {e}"
)
# All attempts failed
error_msg = f"Failed to fetch any GIF for phrase: {search_phrase}"

View File

@ -7,6 +7,7 @@ Handles roster operations and validation.
import logging
from typing import Optional, List, Dict
from api.client import get_global_client
from models.roster import TeamRoster
from models.player import Player
from models.transaction import RosterValidation
@ -20,8 +21,6 @@ class RosterService:
def __init__(self):
"""Initialize roster service."""
from api.client import get_global_client
self._get_client = get_global_client
logger.debug("RosterService initialized")

View File

@ -3,65 +3,63 @@ Schedule service for Discord Bot v2.0
Handles game schedule and results retrieval and processing.
"""
import logging
from typing import Optional, List, Dict, Tuple
from api.client import get_global_client
from models.game import Game
logger = logging.getLogger(f'{__name__}.ScheduleService')
logger = logging.getLogger(f"{__name__}.ScheduleService")
class ScheduleService:
"""
Service for schedule and game operations.
Features:
- Weekly schedule retrieval
- Team-specific schedules
- Game results and upcoming games
- Series organization
"""
def __init__(self):
"""Initialize schedule service."""
from api.client import get_global_client
self._get_client = get_global_client
logger.debug("ScheduleService initialized")
async def get_client(self):
"""Get the API client."""
return await self._get_client()
async def get_week_schedule(self, season: int, week: int) -> List[Game]:
"""
Get all games for a specific week.
Args:
season: Season number
week: Week number
Returns:
List of Game instances for the week
"""
try:
client = await self.get_client()
params = [
('season', str(season)),
('week', str(week))
]
response = await client.get('games', params=params)
if not response or 'games' not in response:
params = [("season", str(season)), ("week", str(week))]
response = await client.get("games", params=params)
if not response or "games" not in response:
logger.warning(f"No games data found for season {season}, week {week}")
return []
games_list = response['games']
games_list = response["games"]
if not games_list:
logger.warning(f"Empty games list for season {season}, week {week}")
return []
# Convert to Game objects
games = []
for game_data in games_list:
@ -71,185 +69,206 @@ class ScheduleService:
except Exception as e:
logger.error(f"Error parsing game data: {e}")
continue
logger.info(f"Retrieved {len(games)} games for season {season}, week {week}")
logger.info(
f"Retrieved {len(games)} games for season {season}, week {week}"
)
return games
except Exception as e:
logger.error(f"Error getting week schedule for season {season}, week {week}: {e}")
logger.error(
f"Error getting week schedule for season {season}, week {week}: {e}"
)
return []
async def get_team_schedule(self, season: int, team_abbrev: str, weeks: Optional[int] = None) -> List[Game]:
async def get_team_schedule(
self, season: int, team_abbrev: str, weeks: Optional[int] = None
) -> List[Game]:
"""
Get schedule for a specific team.
Args:
season: Season number
team_abbrev: Team abbreviation (e.g., 'NYY')
weeks: Number of weeks to retrieve (None for all weeks)
Returns:
List of Game instances for the team
"""
try:
team_games = []
team_abbrev_upper = team_abbrev.upper()
# If weeks not specified, try a reasonable range (18 weeks typical)
week_range = range(1, (weeks + 1) if weeks else 19)
for week in week_range:
week_games = await self.get_week_schedule(season, week)
# Filter games involving this team
for game in week_games:
if (game.away_team.abbrev.upper() == team_abbrev_upper or
game.home_team.abbrev.upper() == team_abbrev_upper):
if (
game.away_team.abbrev.upper() == team_abbrev_upper
or game.home_team.abbrev.upper() == team_abbrev_upper
):
team_games.append(game)
logger.info(f"Retrieved {len(team_games)} games for team {team_abbrev}")
return team_games
except Exception as e:
logger.error(f"Error getting team schedule for {team_abbrev}: {e}")
return []
async def get_recent_games(self, season: int, weeks_back: int = 2) -> List[Game]:
"""
Get recently completed games.
Args:
season: Season number
weeks_back: Number of weeks back to look
Returns:
List of completed Game instances
"""
try:
recent_games = []
# Get games from recent weeks
for week_offset in range(weeks_back):
# This is simplified - in production you'd want to determine current week
week = 10 - week_offset # Assuming we're around week 10
if week <= 0:
break
week_games = await self.get_week_schedule(season, week)
# Only include completed games
completed_games = [game for game in week_games if game.is_completed]
recent_games.extend(completed_games)
# Sort by week descending (most recent first)
recent_games.sort(key=lambda x: (x.week, x.game_num or 0), reverse=True)
logger.debug(f"Retrieved {len(recent_games)} recent games")
return recent_games
except Exception as e:
logger.error(f"Error getting recent games: {e}")
return []
async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]:
"""
Get upcoming scheduled games by scanning multiple weeks.
Args:
season: Season number
weeks_ahead: Number of weeks to scan ahead (default 6)
Returns:
List of upcoming Game instances
"""
try:
upcoming_games = []
# Scan through weeks to find games without scores
for week in range(1, 19): # Standard season length
week_games = await self.get_week_schedule(season, week)
# Find games without scores (not yet played)
upcoming_games_week = [game for game in week_games if not game.is_completed]
upcoming_games_week = [
game for game in week_games if not game.is_completed
]
upcoming_games.extend(upcoming_games_week)
# If we found upcoming games, we can limit how many more weeks to check
if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit
break
# Sort by week, then game number
upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0))
logger.debug(f"Retrieved {len(upcoming_games)} upcoming games")
return upcoming_games
except Exception as e:
logger.error(f"Error getting upcoming games: {e}")
return []
async def get_series_by_teams(self, season: int, week: int, team1_abbrev: str, team2_abbrev: str) -> List[Game]:
async def get_series_by_teams(
self, season: int, week: int, team1_abbrev: str, team2_abbrev: str
) -> List[Game]:
"""
Get all games in a series between two teams for a specific week.
Args:
season: Season number
week: Week number
team1_abbrev: First team abbreviation
team2_abbrev: Second team abbreviation
Returns:
List of Game instances in the series
"""
try:
week_games = await self.get_week_schedule(season, week)
team1_upper = team1_abbrev.upper()
team2_upper = team2_abbrev.upper()
# Find games between these two teams
series_games = []
for game in week_games:
game_teams = {game.away_team.abbrev.upper(), game.home_team.abbrev.upper()}
game_teams = {
game.away_team.abbrev.upper(),
game.home_team.abbrev.upper(),
}
if game_teams == {team1_upper, team2_upper}:
series_games.append(game)
# Sort by game number
series_games.sort(key=lambda x: x.game_num or 0)
logger.debug(f"Retrieved {len(series_games)} games in series between {team1_abbrev} and {team2_abbrev}")
logger.debug(
f"Retrieved {len(series_games)} games in series between {team1_abbrev} and {team2_abbrev}"
)
return series_games
except Exception as e:
logger.error(f"Error getting series between {team1_abbrev} and {team2_abbrev}: {e}")
logger.error(
f"Error getting series between {team1_abbrev} and {team2_abbrev}: {e}"
)
return []
def group_games_by_series(self, games: List[Game]) -> Dict[Tuple[str, str], List[Game]]:
def group_games_by_series(
self, games: List[Game]
) -> Dict[Tuple[str, str], List[Game]]:
"""
Group games by matchup (series).
Args:
games: List of Game instances
Returns:
Dictionary mapping (team1, team2) tuples to game lists
"""
series_games = {}
for game in games:
# Create consistent team pairing (alphabetical order)
teams = sorted([game.away_team.abbrev, game.home_team.abbrev])
series_key = (teams[0], teams[1])
if series_key not in series_games:
series_games[series_key] = []
series_games[series_key].append(game)
# Sort each series by game number
for series_key in series_games:
series_games[series_key].sort(key=lambda x: x.game_num or 0)
return series_games
# Global service instance
schedule_service = ScheduleService()
schedule_service = ScheduleService()

View File

@ -8,6 +8,7 @@ import asyncio
from typing import Dict, List, Any, Optional
import pygsheets
from config import get_config
from utils.logging import get_contextual_logger
from exceptions import SheetsException
@ -24,8 +25,6 @@ class SheetsService:
If None, will use path from config
"""
if credentials_path is None:
from config import get_config
credentials_path = get_config().sheets_credentials_path
self.credentials_path = credentials_path
@ -416,6 +415,8 @@ class SheetsService:
self.logger.info(f"Read {len(pit_data)} valid pitching decisions")
return pit_data
except SheetsException:
raise
except Exception as e:
self.logger.error(f"Failed to read pitching decisions: {e}")
raise SheetsException("Unable to read pitching decisions") from e
@ -458,6 +459,8 @@ class SheetsService:
"home": [int(x) for x in score_table[1]], # [R, H, E]
}
except SheetsException:
raise
except Exception as e:
self.logger.error(f"Failed to read box score: {e}")
raise SheetsException("Unable to read box score") from e

View File

@ -3,61 +3,62 @@ Standings service for Discord Bot v2.0
Handles team standings retrieval and processing.
"""
import logging
from typing import Optional, List, Dict
from api.client import get_global_client
from models.standings import TeamStandings
from exceptions import APIException
logger = logging.getLogger(f'{__name__}.StandingsService')
logger = logging.getLogger(f"{__name__}.StandingsService")
class StandingsService:
"""
Service for team standings operations.
Features:
- League standings retrieval
- Division-based filtering
- Season-specific data
- Playoff positioning
"""
def __init__(self):
"""Initialize standings service."""
from api.client import get_global_client
self._get_client = get_global_client
logger.debug("StandingsService initialized")
async def get_client(self):
"""Get the API client."""
return await self._get_client()
async def get_league_standings(self, season: int) -> List[TeamStandings]:
"""
Get complete league standings for a season.
Args:
season: Season number
Returns:
List of TeamStandings ordered by record
"""
try:
client = await self.get_client()
params = [('season', str(season))]
response = await client.get('standings', params=params)
if not response or 'standings' not in response:
params = [("season", str(season))]
response = await client.get("standings", params=params)
if not response or "standings" not in response:
logger.warning(f"No standings data found for season {season}")
return []
standings_list = response['standings']
standings_list = response["standings"]
if not standings_list:
logger.warning(f"Empty standings for season {season}")
return []
# Convert to model objects
standings = []
for standings_data in standings_list:
@ -67,34 +68,41 @@ class StandingsService:
except Exception as e:
logger.error(f"Error parsing standings data for team: {e}")
continue
logger.info(f"Retrieved standings for {len(standings)} teams in season {season}")
logger.info(
f"Retrieved standings for {len(standings)} teams in season {season}"
)
return standings
except Exception as e:
logger.error(f"Error getting league standings for season {season}: {e}")
return []
async def get_standings_by_division(self, season: int) -> Dict[str, List[TeamStandings]]:
async def get_standings_by_division(
self, season: int
) -> Dict[str, List[TeamStandings]]:
"""
Get standings grouped by division.
Args:
season: Season number
Returns:
Dictionary mapping division names to team standings
"""
try:
all_standings = await self.get_league_standings(season)
if not all_standings:
return {}
# Group by division
divisions = {}
for team_standings in all_standings:
if hasattr(team_standings.team, 'division') and team_standings.team.division:
if (
hasattr(team_standings.team, "division")
and team_standings.team.division
):
div_name = team_standings.team.division.division_name
if div_name not in divisions:
divisions[div_name] = []
@ -104,95 +112,99 @@ class StandingsService:
if "No Division" not in divisions:
divisions["No Division"] = []
divisions["No Division"].append(team_standings)
# Sort each division by record (wins descending, then by winning percentage)
for div_name in divisions:
divisions[div_name].sort(
key=lambda x: (x.wins, x.winning_percentage),
reverse=True
key=lambda x: (x.wins, x.winning_percentage), reverse=True
)
logger.debug(f"Grouped standings into {len(divisions)} divisions")
return divisions
except Exception as e:
logger.error(f"Error grouping standings by division: {e}")
return {}
async def get_team_standings(self, team_abbrev: str, season: int) -> Optional[TeamStandings]:
async def get_team_standings(
self, team_abbrev: str, season: int
) -> Optional[TeamStandings]:
"""
Get standings for a specific team.
Args:
team_abbrev: Team abbreviation (e.g., 'NYY')
season: Season number
Returns:
TeamStandings instance or None if not found
"""
try:
all_standings = await self.get_league_standings(season)
# Find team by abbreviation
team_abbrev_upper = team_abbrev.upper()
for team_standings in all_standings:
if team_standings.team.abbrev.upper() == team_abbrev_upper:
logger.debug(f"Found standings for {team_abbrev}: {team_standings}")
return team_standings
logger.warning(f"No standings found for team {team_abbrev} in season {season}")
logger.warning(
f"No standings found for team {team_abbrev} in season {season}"
)
return None
except Exception as e:
logger.error(f"Error getting standings for team {team_abbrev}: {e}")
return None
async def get_playoff_picture(self, season: int) -> Dict[str, List[TeamStandings]]:
"""
Get playoff picture with division leaders and wild card contenders.
Args:
season: Season number
Returns:
Dictionary with 'division_leaders' and 'wild_card' lists
"""
try:
divisions = await self.get_standings_by_division(season)
if not divisions:
return {"division_leaders": [], "wild_card": []}
# Get division leaders (first place in each division)
division_leaders = []
wild_card_candidates = []
for div_name, teams in divisions.items():
if teams: # Division has teams
# First team is division leader
division_leaders.append(teams[0])
# Rest are potential wild card candidates
for team in teams[1:]:
wild_card_candidates.append(team)
# Sort wild card candidates by record
wild_card_candidates.sort(
key=lambda x: (x.wins, x.winning_percentage),
reverse=True
key=lambda x: (x.wins, x.winning_percentage), reverse=True
)
# Take top wild card contenders (typically top 6-8 teams)
wild_card_contenders = wild_card_candidates[:8]
logger.debug(f"Playoff picture: {len(division_leaders)} division leaders, "
f"{len(wild_card_contenders)} wild card contenders")
logger.debug(
f"Playoff picture: {len(division_leaders)} division leaders, "
f"{len(wild_card_contenders)} wild card contenders"
)
return {
"division_leaders": division_leaders,
"wild_card": wild_card_contenders
"wild_card": wild_card_contenders,
}
except Exception as e:
logger.error(f"Error generating playoff picture: {e}")
return {"division_leaders": [], "wild_card": []}
@ -217,9 +229,7 @@ class StandingsService:
# Use 8 second timeout for this potentially slow operation
response = await client.post(
f'standings/s{season}/recalculate',
{},
timeout=8.0
f"standings/s{season}/recalculate", {}, timeout=8.0
)
logger.info(f"Recalculated standings for season {season}")
@ -231,4 +241,4 @@ class StandingsService:
# Global service instance
standings_service = StandingsService()
standings_service = StandingsService()

View File

@ -3,129 +3,142 @@ Statistics service for Discord Bot v2.0
Handles batting and pitching statistics retrieval and processing.
"""
import logging
from typing import Optional
from api.client import get_global_client
from models.batting_stats import BattingStats
from models.pitching_stats import PitchingStats
logger = logging.getLogger(f'{__name__}.StatsService')
logger = logging.getLogger(f"{__name__}.StatsService")
class StatsService:
"""
Service for player statistics operations.
Features:
- Batting statistics retrieval
- Pitching statistics retrieval
- Season-specific filtering
- Error handling and logging
"""
def __init__(self):
"""Initialize stats service."""
# We don't inherit from BaseService since we need custom endpoints
from api.client import get_global_client
self._get_client = get_global_client
logger.debug("StatsService initialized")
async def get_client(self):
"""Get the API client."""
return await self._get_client()
async def get_batting_stats(self, player_id: int, season: int) -> Optional[BattingStats]:
async def get_batting_stats(
self, player_id: int, season: int
) -> Optional[BattingStats]:
"""
Get batting statistics for a player in a specific season.
Args:
player_id: Player ID
season: Season number
Returns:
BattingStats instance or None if not found
"""
try:
client = await self.get_client()
# Call the batting stats view endpoint
params = [
('player_id', str(player_id)),
('season', str(season))
]
response = await client.get('views/season-stats/batting', params=params)
if not response or 'stats' not in response:
logger.debug(f"No batting stats found for player {player_id}, season {season}")
params = [("player_id", str(player_id)), ("season", str(season))]
response = await client.get("views/season-stats/batting", params=params)
if not response or "stats" not in response:
logger.debug(
f"No batting stats found for player {player_id}, season {season}"
)
return None
stats_list = response['stats']
stats_list = response["stats"]
if not stats_list:
logger.debug(f"Empty batting stats for player {player_id}, season {season}")
logger.debug(
f"Empty batting stats for player {player_id}, season {season}"
)
return None
# Take the first (should be only) result
stats_data = stats_list[0]
batting_stats = BattingStats.from_api_data(stats_data)
logger.debug(f"Retrieved batting stats for player {player_id}: {batting_stats.avg:.3f} AVG")
logger.debug(
f"Retrieved batting stats for player {player_id}: {batting_stats.avg:.3f} AVG"
)
return batting_stats
except Exception as e:
logger.error(f"Error getting batting stats for player {player_id}: {e}")
return None
async def get_pitching_stats(self, player_id: int, season: int) -> Optional[PitchingStats]:
async def get_pitching_stats(
self, player_id: int, season: int
) -> Optional[PitchingStats]:
"""
Get pitching statistics for a player in a specific season.
Args:
player_id: Player ID
season: Season number
Returns:
PitchingStats instance or None if not found
"""
try:
client = await self.get_client()
# Call the pitching stats view endpoint
params = [
('player_id', str(player_id)),
('season', str(season))
]
response = await client.get('views/season-stats/pitching', params=params)
if not response or 'stats' not in response:
logger.debug(f"No pitching stats found for player {player_id}, season {season}")
params = [("player_id", str(player_id)), ("season", str(season))]
response = await client.get("views/season-stats/pitching", params=params)
if not response or "stats" not in response:
logger.debug(
f"No pitching stats found for player {player_id}, season {season}"
)
return None
stats_list = response['stats']
stats_list = response["stats"]
if not stats_list:
logger.debug(f"Empty pitching stats for player {player_id}, season {season}")
logger.debug(
f"Empty pitching stats for player {player_id}, season {season}"
)
return None
# Take the first (should be only) result
stats_data = stats_list[0]
pitching_stats = PitchingStats.from_api_data(stats_data)
logger.debug(f"Retrieved pitching stats for player {player_id}: {pitching_stats.era:.2f} ERA")
logger.debug(
f"Retrieved pitching stats for player {player_id}: {pitching_stats.era:.2f} ERA"
)
return pitching_stats
except Exception as e:
logger.error(f"Error getting pitching stats for player {player_id}: {e}")
return None
async def get_player_stats(self, player_id: int, season: int) -> tuple[Optional[BattingStats], Optional[PitchingStats]]:
async def get_player_stats(
self, player_id: int, season: int
) -> tuple[Optional[BattingStats], Optional[PitchingStats]]:
"""
Get both batting and pitching statistics for a player.
Args:
player_id: Player ID
season: Season number
Returns:
Tuple of (batting_stats, pitching_stats) - either can be None
"""
@ -133,20 +146,22 @@ class StatsService:
# Get both types of stats concurrently
batting_task = self.get_batting_stats(player_id, season)
pitching_task = self.get_pitching_stats(player_id, season)
batting_stats = await batting_task
pitching_stats = await pitching_task
logger.debug(f"Retrieved stats for player {player_id}: "
f"batting={'yes' if batting_stats else 'no'}, "
f"pitching={'yes' if pitching_stats else 'no'}")
logger.debug(
f"Retrieved stats for player {player_id}: "
f"batting={'yes' if batting_stats else 'no'}, "
f"pitching={'yes' if pitching_stats else 'no'}"
)
return batting_stats, pitching_stats
except Exception as e:
logger.error(f"Error getting player stats for {player_id}: {e}")
return None, None
# Global service instance
stats_service = StatsService()
stats_service = StatsService()

View File

@ -3,6 +3,7 @@ Team service for Discord Bot v2.0
Handles team-related operations with roster management and league queries.
"""
import logging
from typing import Optional, List, Dict, Any
@ -12,13 +13,13 @@ from models.team import Team, RosterType
from exceptions import APIException
from utils.decorators import cached_single_item
logger = logging.getLogger(f'{__name__}.TeamService')
logger = logging.getLogger(f"{__name__}.TeamService")
class TeamService(BaseService[Team]):
"""
Service for team-related operations.
Features:
- Team retrieval by ID, abbreviation, and season
- Manager-based team queries
@ -27,12 +28,12 @@ class TeamService(BaseService[Team]):
- Season-specific team data
- Standings integration
"""
def __init__(self):
"""Initialize team service."""
super().__init__(Team, 'teams')
super().__init__(Team, "teams")
logger.debug("TeamService initialized")
@cached_single_item(ttl=1800) # 30-minute cache
async def get_team(self, team_id: int) -> Optional[Team]:
"""
@ -57,12 +58,12 @@ class TeamService(BaseService[Team]):
except Exception as e:
logger.error(f"Unexpected error getting team {team_id}: {e}")
return None
async def get_teams_by_owner(
self,
owner_id: int,
season: Optional[int] = None,
roster_type: Optional[str] = None
roster_type: Optional[str] = None,
) -> List[Team]:
"""
Get teams owned by a specific Discord user.
@ -80,10 +81,7 @@ class TeamService(BaseService[Team]):
Allows caller to distinguish between "no teams" vs "error occurred"
"""
season = season or get_config().sba_season
params = [
('owner_id', str(owner_id)),
('season', str(season))
]
params = [("owner_id", str(owner_id)), ("season", str(season))]
teams = await self.get_all_items(params=params)
@ -92,19 +90,27 @@ class TeamService(BaseService[Team]):
try:
target_type = RosterType(roster_type)
teams = [team for team in teams if team.roster_type() == target_type]
logger.debug(f"Filtered to {len(teams)} {roster_type} teams for owner {owner_id}")
logger.debug(
f"Filtered to {len(teams)} {roster_type} teams for owner {owner_id}"
)
except ValueError:
logger.warning(f"Invalid roster_type '{roster_type}' - returning all teams")
logger.warning(
f"Invalid roster_type '{roster_type}' - returning all teams"
)
if teams:
logger.debug(f"Found {len(teams)} teams for owner {owner_id} in season {season}")
logger.debug(
f"Found {len(teams)} teams for owner {owner_id} in season {season}"
)
return teams
logger.debug(f"No teams found for owner {owner_id} in season {season}")
return []
@cached_single_item(ttl=1800) # 30-minute cache
async def get_team_by_owner(self, owner_id: int, season: Optional[int] = None) -> Optional[Team]:
async def get_team_by_owner(
self, owner_id: int, season: Optional[int] = None
) -> Optional[Team]:
"""
Get the primary (Major League) team owned by a Discord user.
@ -124,125 +130,129 @@ class TeamService(BaseService[Team]):
Returns:
Team instance or None if not found
"""
teams = await self.get_teams_by_owner(owner_id, season, roster_type='ml')
teams = await self.get_teams_by_owner(owner_id, season, roster_type="ml")
return teams[0] if teams else None
async def get_team_by_abbrev(self, abbrev: str, season: Optional[int] = None) -> Optional[Team]:
async def get_team_by_abbrev(
self, abbrev: str, season: Optional[int] = None
) -> Optional[Team]:
"""
Get team by abbreviation for a specific season.
Args:
abbrev: Team abbreviation (e.g., 'NYY', 'BOS')
season: Season number (defaults to current season)
Returns:
Team instance or None if not found
"""
try:
season = season or get_config().sba_season
params = [
('team_abbrev', abbrev.upper()),
('season', str(season))
]
params = [("team_abbrev", abbrev.upper()), ("season", str(season))]
teams = await self.get_all_items(params=params)
if teams:
team = teams[0] # Should be unique per season
logger.debug(f"Found team {abbrev} for season {season}: {team.lname}")
return team
logger.debug(f"No team found for abbreviation '{abbrev}' in season {season}")
logger.debug(
f"No team found for abbreviation '{abbrev}' in season {season}"
)
return None
except Exception as e:
logger.error(f"Error getting team by abbreviation '{abbrev}': {e}")
return None
async def get_teams_by_season(self, season: int) -> List[Team]:
"""
Get all teams for a specific season.
Args:
season: Season number
Returns:
List of teams in the season
"""
try:
params = [('season', str(season))]
params = [("season", str(season))]
teams = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(teams)} teams for season {season}")
return teams
except Exception as e:
logger.error(f"Failed to get teams for season {season}: {e}")
return []
async def get_teams_by_manager(self, manager_id: int, season: Optional[int] = None) -> List[Team]:
async def get_teams_by_manager(
self, manager_id: int, season: Optional[int] = None
) -> List[Team]:
"""
Get teams managed by a specific manager.
Uses 'manager_id' query parameter which supports multiple manager matching.
Args:
manager_id: Manager identifier
season: Season number (optional)
Returns:
List of teams managed by the manager
"""
try:
params = [('manager_id', str(manager_id))]
params = [("manager_id", str(manager_id))]
if season:
params.append(('season', str(season)))
params.append(("season", str(season)))
teams = await self.get_all_items(params=params)
logger.debug(f"Found {len(teams)} teams for manager {manager_id}")
return teams
except Exception as e:
logger.error(f"Failed to get teams for manager {manager_id}: {e}")
return []
async def get_teams_by_division(self, division_id: int, season: int) -> List[Team]:
"""
Get teams in a specific division for a season.
Args:
division_id: Division identifier
season: Season number
Returns:
List of teams in the division
"""
try:
params = [
('division_id', str(division_id)),
('season', str(season))
]
params = [("division_id", str(division_id)), ("season", str(season))]
teams = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(teams)} teams for division {division_id} in season {season}")
logger.debug(
f"Retrieved {len(teams)} teams for division {division_id} in season {season}"
)
return teams
except Exception as e:
logger.error(f"Failed to get teams for division {division_id}: {e}")
return []
async def get_team_roster(self, team_id: int, roster_type: str = 'current') -> Optional[Dict[str, Any]]:
async def get_team_roster(
self, team_id: int, roster_type: str = "current"
) -> Optional[Dict[str, Any]]:
"""
Get the roster for a team with position counts and player lists.
Returns roster data with active, shortil (minor league), and longil (injured list)
Returns roster data with active, shortil (injured list), and longil (minor league)
rosters. Each roster contains position counts and players sorted by descending WARa.
Args:
team_id: Team identifier
roster_type: 'current' or 'next' roster
Returns:
Dictionary with roster structure:
{
@ -257,19 +267,19 @@ class TeamService(BaseService[Team]):
"""
try:
client = await self.get_client()
data = await client.get(f'teams/{team_id}/roster/{roster_type}')
data = await client.get(f"teams/{team_id}/roster/{roster_type}")
if data:
logger.debug(f"Retrieved {roster_type} roster for team {team_id}")
return data
logger.debug(f"No roster data found for team {team_id}")
return None
except Exception as e:
logger.error(f"Failed to get roster for team {team_id}: {e}")
return None
async def update_team(self, team_id: int, updates: dict) -> Optional[Team]:
"""
Update team information.
@ -287,52 +297,58 @@ class TeamService(BaseService[Team]):
except Exception as e:
logger.error(f"Failed to update team {team_id}: {e}")
return None
async def get_team_standings_position(self, team_id: int, season: int) -> Optional[dict]:
async def get_team_standings_position(
self, team_id: int, season: int
) -> Optional[dict]:
"""
Get team's standings information.
Calls /standings/team/{team_id} endpoint which returns a Standings object.
Args:
team_id: Team identifier
season: Season number
Returns:
Standings object data for the team
"""
try:
client = await self.get_client()
data = await client.get(f'standings/team/{team_id}', params=[('season', str(season))])
data = await client.get(
f"standings/team/{team_id}", params=[("season", str(season))]
)
if data:
logger.debug(f"Retrieved standings for team {team_id}")
return data
return None
except Exception as e:
logger.error(f"Failed to get standings for team {team_id}: {e}")
return None
async def is_valid_team_abbrev(self, abbrev: str, season: Optional[int] = None) -> bool:
async def is_valid_team_abbrev(
self, abbrev: str, season: Optional[int] = None
) -> bool:
"""
Check if a team abbreviation is valid for a season.
Args:
abbrev: Team abbreviation to validate
season: Season number (defaults to current)
Returns:
True if the abbreviation is valid
"""
team = await self.get_team_by_abbrev(abbrev, season)
return team is not None
async def get_current_season_teams(self) -> List[Team]:
"""
Get all teams for the current season.
Returns:
List of teams in current season
"""
@ -340,4 +356,4 @@ class TeamService(BaseService[Team]):
# Global service instance
team_service = TeamService()
team_service = TeamService()

View File

@ -29,6 +29,7 @@ class TradeValidationResult:
def __init__(self):
self.is_legal: bool = True
self.participant_validations: Dict[int, RosterValidationResult] = {}
self.team_abbrevs: Dict[int, str] = {} # team_id -> abbreviation
self.trade_errors: List[str] = []
self.trade_warnings: List[str] = []
self.trade_suggestions: List[str] = []
@ -37,24 +38,30 @@ class TradeValidationResult:
def all_errors(self) -> List[str]:
"""Get all errors including trade-level and roster-level errors."""
errors = self.trade_errors.copy()
for validation in self.participant_validations.values():
errors.extend(validation.errors)
for team_id, validation in self.participant_validations.items():
abbrev = self.team_abbrevs.get(team_id, "???")
for error in validation.errors:
errors.append(f"[{abbrev}] {error}")
return errors
@property
def all_warnings(self) -> List[str]:
"""Get all warnings across trade and roster levels."""
warnings = self.trade_warnings.copy()
for validation in self.participant_validations.values():
warnings.extend(validation.warnings)
for team_id, validation in self.participant_validations.items():
abbrev = self.team_abbrevs.get(team_id, "???")
for warning in validation.warnings:
warnings.append(f"[{abbrev}] {warning}")
return warnings
@property
def all_suggestions(self) -> List[str]:
"""Get all suggestions across trade and roster levels."""
suggestions = self.trade_suggestions.copy()
for validation in self.participant_validations.values():
suggestions.extend(validation.suggestions)
for team_id, validation in self.participant_validations.items():
abbrev = self.team_abbrevs.get(team_id, "???")
for suggestion in validation.suggestions:
suggestions.append(f"[{abbrev}] {suggestion}")
return suggestions
def get_participant_validation(
@ -518,6 +525,7 @@ class TradeBuilder:
# Validate each team's roster after the trade
for participant in self.trade.participants:
team_id = participant.team.id
result.team_abbrevs[team_id] = participant.team.abbrev
if team_id in self._team_builders:
builder = self._team_builders[team_id]
roster_validation = await builder.validate_transaction(next_week)

View File

@ -14,7 +14,7 @@ from models.transaction import Transaction
from models.team import Team
from models.player import Player
from models.roster import TeamRoster
from services.roster_service import roster_service
from services.roster_service import RosterService, roster_service
from services.transaction_service import transaction_service
from services.league_service import league_service
from models.team import RosterType
@ -174,7 +174,13 @@ class RosterValidationResult:
class TransactionBuilder:
"""Interactive transaction builder for complex multi-move transactions."""
def __init__(self, team: Team, user_id: int, season: int = get_config().sba_season):
def __init__(
self,
team: Team,
user_id: int,
season: int = get_config().sba_season,
roster_svc: Optional[RosterService] = None,
):
"""
Initialize transaction builder.
@ -182,32 +188,39 @@ class TransactionBuilder:
team: Team making the transaction
user_id: Discord user ID of the GM
season: Season number
roster_svc: RosterService instance (defaults to global roster_service)
"""
self.team = team
self.user_id = user_id
self.season = season
self.moves: List[TransactionMove] = []
self.created_at = datetime.now(timezone.utc)
self._roster_svc = roster_svc or roster_service
# Cache for roster data
self._current_roster: Optional[TeamRoster] = None
self._roster_loaded = False
# Cache for pre-existing transactions
# Pre-existing transactions (re-fetched on each validation)
self._existing_transactions: Optional[List[Transaction]] = None
self._existing_transactions_loaded = False
logger.info(
f"TransactionBuilder initialized for {team.abbrev} by user {user_id}"
)
async def load_roster_data(self) -> None:
"""Load current roster data for the team."""
if self._roster_loaded:
async def load_roster_data(self, force_refresh: bool = False) -> None:
"""Load current roster data for the team.
Args:
force_refresh: If True, bypass cache and fetch fresh data from API.
"""
if self._roster_loaded and not force_refresh:
return
try:
self._current_roster = await roster_service.get_current_roster(self.team.id)
self._current_roster = await self._roster_svc.get_current_roster(
self.team.id
)
self._roster_loaded = True
logger.debug(f"Loaded roster data for team {self.team.abbrev}")
except Exception as e:
@ -219,11 +232,12 @@ class TransactionBuilder:
"""
Load pre-existing transactions for next week.
Always re-fetches from the API to capture transactions submitted
by other users or sessions since the builder was initialized.
Queries for all organizational affiliates (ML, MiL, IL) to ensure
trades involving affiliate teams are included in roster projections.
"""
if self._existing_transactions_loaded:
return
try:
# Include all org affiliates so trades involving MiL/IL teams are captured
@ -238,14 +252,12 @@ class TransactionBuilder:
week_start=next_week,
)
)
self._existing_transactions_loaded = True
logger.debug(
f"Loaded {len(self._existing_transactions or [])} existing transactions for {self.team.abbrev} org ({org_abbrevs}) week {next_week}"
)
except Exception as e:
logger.error(f"Failed to load existing transactions: {e}")
self._existing_transactions = []
self._existing_transactions_loaded = True
async def add_move(
self,
@ -678,8 +690,17 @@ class TransactionBuilder:
logger.info(
f"Created {len(transactions)} transactions for submission with move_id {move_id}"
)
# Invalidate roster cache so subsequent operations fetch fresh data
self.invalidate_roster_cache()
return transactions
def invalidate_roster_cache(self) -> None:
"""Invalidate cached roster data so next load fetches fresh data."""
self._roster_loaded = False
self._current_roster = None
def clear_moves(self) -> None:
"""Clear all moves from the transaction builder."""
self.moves.clear()

View File

@ -14,10 +14,17 @@ from services.draft_service import draft_service
from services.draft_pick_service import draft_pick_service
from services.draft_list_service import draft_list_service
from services.draft_sheet_service import get_draft_sheet_service
from services.league_service import league_service
from services.player_service import player_service
from services.roster_service import roster_service
from services.team_service import team_service
from utils.draft_helpers import validate_cap_space
from utils.logging import get_contextual_logger
from utils.helpers import get_team_salary_cap
from views.draft_views import create_on_clock_announcement_embed
from views.draft_views import (
create_on_clock_announcement_embed,
create_player_draft_card,
)
from config import get_config
@ -303,9 +310,6 @@ class DraftMonitorTask:
True if draft succeeded
"""
try:
from utils.draft_helpers import validate_cap_space
from services.team_service import team_service
# Get team roster for cap validation
roster = await team_service.get_team_roster(draft_pick.owner.id, "current")
@ -337,9 +341,6 @@ class DraftMonitorTask:
return False
# Get current league state for dem_week calculation
from services.player_service import player_service
from services.league_service import league_service
current = await league_service.get_current_state()
# Update player team with dem_week set to current.week + 2 for draft picks
@ -366,8 +367,6 @@ class DraftMonitorTask:
if draft_data.result_channel:
result_channel = guild.get_channel(draft_data.result_channel)
if result_channel:
from views.draft_views import create_player_draft_card
draft_card = await create_player_draft_card(player, draft_pick)
draft_card.set_footer(text="🤖 Auto-drafted from draft list")
await result_channel.send(embed=draft_card)

View File

@ -3,17 +3,19 @@ Tests for player image management commands.
Covers URL validation, permission checking, and command execution.
"""
import pytest
import asyncio
from unittest.mock import MagicMock, patch
import aiohttp
import discord
from aioresponses import aioresponses
from commands.profile.images import (
validate_url_format,
check_url_accessibility,
can_edit_player_image,
ImageCommands
ImageCommands,
)
from tests.factories import PlayerFactory, TeamFactory
@ -94,7 +96,7 @@ class TestURLAccessibility:
url = "https://example.com/image.jpg"
with aioresponses() as m:
m.head(url, status=200, headers={'content-type': 'image/jpeg'})
m.head(url, status=200, headers={"content-type": "image/jpeg"})
is_accessible, error = await check_url_accessibility(url)
@ -118,7 +120,7 @@ class TestURLAccessibility:
url = "https://example.com/page.html"
with aioresponses() as m:
m.head(url, status=200, headers={'content-type': 'text/html'})
m.head(url, status=200, headers={"content-type": "text/html"})
is_accessible, error = await check_url_accessibility(url)
@ -157,6 +159,7 @@ class TestPermissionChecking:
async def test_admin_can_edit_any_player(self):
"""Test administrator can edit any player's images."""
mock_interaction = MagicMock()
mock_interaction.user = MagicMock(spec=discord.Member)
mock_interaction.user.id = 12345
mock_interaction.user.guild_permissions.administrator = True
@ -186,7 +189,9 @@ class TestPermissionChecking:
mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams:
with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [user_team]
has_permission, error = await can_edit_player_image(
@ -211,7 +216,9 @@ class TestPermissionChecking:
mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams:
with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [user_team]
has_permission, error = await can_edit_player_image(
@ -236,7 +243,9 @@ class TestPermissionChecking:
mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams:
with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [user_team]
has_permission, error = await can_edit_player_image(
@ -258,7 +267,9 @@ class TestPermissionChecking:
mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams:
with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = []
has_permission, error = await can_edit_player_image(
@ -299,7 +310,7 @@ class TestImageCommandsIntegration:
async def test_set_image_command_structure(self, commands_cog):
"""Test that set_image command is properly configured."""
assert hasattr(commands_cog, 'set_image')
assert hasattr(commands_cog, "set_image")
assert commands_cog.set_image.name == "set-image"
async def test_fancy_card_updates_vanity_card_field(self, commands_cog):

View File

@ -0,0 +1,245 @@
"""Tests for injury command team ownership verification (issue #18).
Ensures /injury set-new and /injury clear only allow users to manage
injuries for players on their own team (or organizational affiliates).
Admins bypass the check.
"""
import discord
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from commands.injuries.management import InjuryGroup
from models.player import Player
from models.team import Team
def _make_team(team_id: int, abbrev: str, sname: str | None = None) -> Team:
"""Create a Team via model_construct to skip validation.
For MiL teams (e.g. PORMIL), pass sname explicitly to avoid the IL
disambiguation logic in _get_base_abbrev treating them as IL teams.
"""
return Team.model_construct(
id=team_id,
abbrev=abbrev,
sname=sname or abbrev,
lname=f"Team {abbrev}",
season=13,
)
def _make_player(player_id: int, name: str, team: Team) -> Player:
"""Create a Player via model_construct to skip validation."""
return Player.model_construct(
id=player_id,
name=name,
wara=2.0,
season=13,
team_id=team.id,
team=team,
)
def _make_interaction(is_admin: bool = False) -> MagicMock:
"""Create a mock Discord interaction with configurable admin status."""
interaction = MagicMock()
interaction.user = MagicMock()
interaction.user.id = 12345
interaction.user.guild_permissions = MagicMock()
interaction.user.guild_permissions.administrator = is_admin
# Make isinstance(interaction.user, discord.Member) return True
interaction.user.__class__ = discord.Member
interaction.followup = MagicMock()
interaction.followup.send = AsyncMock()
return interaction
@pytest.fixture
def injury_group():
return InjuryGroup()
class TestVerifyTeamOwnership:
"""Tests for InjuryGroup._verify_team_ownership (issue #18)."""
@pytest.mark.asyncio
async def test_admin_bypasses_check(self, injury_group):
"""Admins should always pass the ownership check."""
interaction = _make_interaction(is_admin=True)
por_team = _make_team(1, "POR")
player = _make_player(100, "Mike Trout", por_team)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is True
@pytest.mark.asyncio
async def test_owner_passes_check(self, injury_group):
"""User who owns the player's team should pass."""
interaction = _make_interaction(is_admin=False)
por_team = _make_team(1, "POR")
player = _make_player(100, "Mike Trout", por_team)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=por_team)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is True
@pytest.mark.asyncio
async def test_org_affiliate_passes_check(self, injury_group):
"""User who owns the ML team should pass for MiL/IL affiliate players."""
interaction = _make_interaction(is_admin=False)
por_ml = _make_team(1, "POR")
por_mil = _make_team(2, "PORMIL", sname="POR MiL")
player = _make_player(100, "Minor Leaguer", por_mil)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=por_ml)
mock_ts.get_team = AsyncMock(return_value=por_mil)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is True
@pytest.mark.asyncio
async def test_different_team_fails(self, injury_group):
"""User who owns a different team should be denied."""
interaction = _make_interaction(is_admin=False)
por_team = _make_team(1, "POR")
nyy_team = _make_team(2, "NYY")
player = _make_player(100, "Mike Trout", nyy_team)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=por_team)
mock_ts.get_team = AsyncMock(return_value=nyy_team)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is False
interaction.followup.send.assert_called_once()
call_kwargs = interaction.followup.send.call_args
embed = call_kwargs.kwargs.get("embed") or call_kwargs.args[0]
assert "Not Your Player" in embed.title
@pytest.mark.asyncio
async def test_no_team_owned_fails(self, injury_group):
"""User who owns no team should be denied."""
interaction = _make_interaction(is_admin=False)
nyy_team = _make_team(2, "NYY")
player = _make_player(100, "Mike Trout", nyy_team)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=None)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is False
interaction.followup.send.assert_called_once()
call_kwargs = interaction.followup.send.call_args
embed = call_kwargs.kwargs.get("embed") or call_kwargs.args[0]
assert "No Team Found" in embed.title
@pytest.mark.asyncio
async def test_il_affiliate_passes_check(self, injury_group):
"""User who owns the ML team should pass for IL (injured list) players."""
interaction = _make_interaction(is_admin=False)
por_ml = _make_team(1, "POR")
por_il = _make_team(3, "PORIL", sname="POR IL")
player = _make_player(100, "IL Stash", por_il)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=por_ml)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is True
@pytest.mark.asyncio
async def test_player_team_not_populated_fails(self, injury_group):
"""Player with team_id but unpopulated team object should be denied.
Callers are expected to populate player.team before calling
_verify_team_ownership. If they don't, the method treats the missing
team as a failed check rather than silently allowing access.
"""
interaction = _make_interaction(is_admin=False)
por_team = _make_team(1, "POR")
player = Player.model_construct(
id=100,
name="Orphan Player",
wara=2.0,
season=13,
team_id=99,
team=None,
)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=por_team)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is False
@pytest.mark.asyncio
async def test_error_embeds_are_ephemeral(self, injury_group):
"""Error responses should be ephemeral so only the invoking user sees them."""
interaction = _make_interaction(is_admin=False)
nyy_team = _make_team(2, "NYY")
player = _make_player(100, "Mike Trout", nyy_team)
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
# Test "No Team Found" path
mock_ts.get_team_by_owner = AsyncMock(return_value=None)
await injury_group._verify_team_ownership(interaction, player)
call_kwargs = interaction.followup.send.call_args
assert call_kwargs.kwargs.get("ephemeral") is True
# Reset and test "Not Your Player" path
interaction = _make_interaction(is_admin=False)
por_team = _make_team(1, "POR")
with patch("services.team_service.team_service") as mock_ts, patch(
"commands.injuries.management.get_config"
) as mock_config:
mock_config.return_value.sba_season = 13
mock_ts.get_team_by_owner = AsyncMock(return_value=por_team)
await injury_group._verify_team_ownership(interaction, player)
call_kwargs = interaction.followup.send.call_args
assert call_kwargs.kwargs.get("ephemeral") is True
@pytest.mark.asyncio
async def test_player_without_team_id_passes(self, injury_group):
"""Players with no team_id should pass (can't verify, allow through)."""
interaction = _make_interaction(is_admin=False)
player = Player.model_construct(
id=100,
name="Free Agent",
wara=0.0,
season=13,
team_id=None,
team=None,
)
result = await injury_group._verify_team_ownership(interaction, player)
assert result is True

129
tests/test_models_play.py Normal file
View File

@ -0,0 +1,129 @@
"""Tests for Play model descriptive_text method.
Covers score text generation for key plays display, specifically
ensuring tied games show 'tied at X' instead of 'Team up X-X'.
"""
from models.play import Play
from models.player import Player
from models.team import Team
def _make_team(abbrev: str) -> Team:
"""Create a minimal Team for descriptive_text tests."""
return Team.model_construct(
id=1,
abbrev=abbrev,
sname=abbrev,
lname=f"Team {abbrev}",
season=13,
)
def _make_player(name: str, team: Team) -> Player:
"""Create a minimal Player for descriptive_text tests."""
return Player.model_construct(id=1, name=name, wara=0.0, season=13, team_id=team.id)
def _make_play(**overrides) -> Play:
"""Create a Play with sensible defaults for descriptive_text tests."""
tst_team = _make_team("TST")
opp_team = _make_team("OPP")
defaults = dict(
id=1,
game_id=1,
play_num=1,
on_base_code="000",
inning_half="top",
inning_num=7,
batting_order=1,
starting_outs=2,
away_score=0,
home_score=0,
outs=1,
batter_id=10,
batter=_make_player("Test Batter", tst_team),
batter_team=tst_team,
pitcher_id=20,
pitcher=_make_player("Test Pitcher", opp_team),
pitcher_team=opp_team,
)
defaults.update(overrides)
return Play.model_construct(**defaults)
class TestDescriptiveTextScoreText:
"""Tests for score text in Play.descriptive_text (issue #48)."""
def test_tied_score_shows_tied_at(self):
"""When scores are equal after the play, should show 'tied at X' not 'Team up X-X'."""
away = _make_team("BSG")
home = _make_team("DEN")
# Top 7: away scores 1 RBI, making it 2-2
play = _make_play(
inning_half="top",
inning_num=7,
away_score=1,
home_score=2,
rbi=1,
hit=1,
)
result = play.descriptive_text(away, home)
assert "tied at 2" in result
assert "up" not in result
def test_home_team_leading(self):
"""When home team leads, should show 'HOME up X-Y'."""
away = _make_team("BSG")
home = _make_team("DEN")
play = _make_play(
inning_half="top",
away_score=0,
home_score=3,
outs=1,
)
result = play.descriptive_text(away, home)
assert "DEN up 3-0" in result
def test_away_team_leading(self):
"""When away team leads, should show 'AWAY up X-Y'."""
away = _make_team("BSG")
home = _make_team("DEN")
play = _make_play(
inning_half="bot",
away_score=5,
home_score=2,
outs=1,
)
result = play.descriptive_text(away, home)
assert "BSG up 5-2" in result
def test_tied_at_zero(self):
"""Tied at 0-0 should show 'tied at 0'."""
away = _make_team("BSG")
home = _make_team("DEN")
play = _make_play(
inning_half="top",
away_score=0,
home_score=0,
outs=1,
)
result = play.descriptive_text(away, home)
assert "tied at 0" in result
def test_rbi_creates_tie_bottom_inning(self):
"""Bottom inning RBI that ties the game should show 'tied at X'."""
away = _make_team("BSG")
home = _make_team("DEN")
# Bot 5: home scores 2 RBI, tying at 4-4
play = _make_play(
inning_half="bot",
inning_num=5,
away_score=4,
home_score=2,
rbi=2,
hit=1,
)
result = play.descriptive_text(away, home)
assert "tied at 4" in result
assert "up" not in result

View File

@ -364,7 +364,7 @@ class TestDraftService:
# Mock draft_pick_service at the module level
with patch(
"services.draft_pick_service.draft_pick_service"
"services.draft_service.draft_pick_service"
) as mock_pick_service:
unfilled_pick = DraftPick(
**create_draft_pick_data(
@ -402,7 +402,7 @@ class TestDraftService:
mock_config.return_value = config
with patch(
"services.draft_pick_service.draft_pick_service"
"services.draft_service.draft_pick_service"
) as mock_pick_service:
# Picks 26-28 are filled, 29 is empty
async def get_pick_side_effect(season, overall):

View File

@ -4,6 +4,7 @@ Tests for trade builder service.
Tests the TradeBuilder service functionality including multi-team management,
move validation, and trade validation logic.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
@ -109,7 +110,7 @@ class TestTradeBuilder:
self.player1.team_id = self.team1.id
# Mock team_service to return team1 for this player
with patch('services.trade_builder.team_service') as mock_team_service:
with patch("services.trade_builder.team_service") as mock_team_service:
mock_team_service.get_team = AsyncMock(return_value=self.team1)
# Don't mock is_same_organization - let the real method work
@ -119,7 +120,7 @@ class TestTradeBuilder:
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert success
@ -141,7 +142,7 @@ class TestTradeBuilder:
from_team=self.team2,
to_team=self.team1,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert not success
@ -156,9 +157,7 @@ class TestTradeBuilder:
# Create a player on Free Agency
fa_player = PlayerFactory.create(
id=100,
name="FA Player",
team_id=get_config().free_agent_team_id
id=100, name="FA Player", team_id=get_config().free_agent_team_id
)
# Try to add player from FA (should fail)
@ -167,7 +166,7 @@ class TestTradeBuilder:
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert not success
@ -182,9 +181,7 @@ class TestTradeBuilder:
# Create a player without a team
no_team_player = PlayerFactory.create(
id=101,
name="No Team Player",
team_id=None
id=101, name="No Team Player", team_id=None
)
# Try to add player without team (should fail)
@ -193,7 +190,7 @@ class TestTradeBuilder:
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert not success
@ -208,24 +205,22 @@ class TestTradeBuilder:
# Create a player on team3 (not in trade)
player_on_team3 = PlayerFactory.create(
id=102,
name="Team3 Player",
team_id=self.team3.id
id=102, name="Team3 Player", team_id=self.team3.id
)
# Mock team_service to return team3 for this player
with patch('services.trade_builder.team_service') as mock_team_service:
with patch("services.trade_builder.team_service") as mock_team_service:
mock_team_service.get_team = AsyncMock(return_value=self.team3)
# Mock is_same_organization to return False (different organization, sync method)
with patch('models.team.Team.is_same_organization', return_value=False):
with patch("models.team.Team.is_same_organization", return_value=False):
# Try to add player from team3 claiming it's from team1 (should fail)
success, error = await builder.add_player_move(
player=player_on_team3,
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert not success
@ -241,26 +236,24 @@ class TestTradeBuilder:
# Create a player on team1's minor league affiliate
player_on_team1_mil = PlayerFactory.create(
id=103,
name="Team1 MiL Player",
team_id=999 # Some MiL team ID
id=103, name="Team1 MiL Player", team_id=999 # Some MiL team ID
)
# Mock team_service to return the MiL team
mil_team = TeamFactory.create(id=999, abbrev="WVMiL", sname="West Virginia MiL")
with patch('services.trade_builder.team_service') as mock_team_service:
with patch("services.trade_builder.team_service") as mock_team_service:
mock_team_service.get_team = AsyncMock(return_value=mil_team)
# Mock is_same_organization to return True (same organization, sync method)
with patch('models.team.Team.is_same_organization', return_value=True):
with patch("models.team.Team.is_same_organization", return_value=True):
# Add player from WVMiL (should succeed because it's same organization as WV)
success, error = await builder.add_player_move(
player=player_on_team1_mil,
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MINOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert success
@ -278,7 +271,7 @@ class TestTradeBuilder:
team=self.team1,
player=self.player1,
from_roster=RosterType.MINOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert success
@ -293,7 +286,7 @@ class TestTradeBuilder:
team=self.team3,
player=self.player2,
from_roster=RosterType.MINOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert not success
@ -309,7 +302,7 @@ class TestTradeBuilder:
self.player1.team_id = self.team1.id
# Mock team_service for adding the player
with patch('services.trade_builder.team_service') as mock_team_service:
with patch("services.trade_builder.team_service") as mock_team_service:
mock_team_service.get_team = AsyncMock(return_value=self.team1)
# Add a player move
@ -318,7 +311,7 @@ class TestTradeBuilder:
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
assert not builder.is_empty
@ -347,7 +340,7 @@ class TestTradeBuilder:
await builder.add_team(self.team2)
# Mock the transaction builders
with patch.object(builder, '_get_or_create_builder') as mock_get_builder:
with patch.object(builder, "_get_or_create_builder") as mock_get_builder:
mock_builder1 = MagicMock()
mock_builder2 = MagicMock()
@ -360,7 +353,7 @@ class TestTradeBuilder:
minor_league_count=5,
warnings=[],
errors=[],
suggestions=[]
suggestions=[],
)
mock_builder1.validate_transaction = AsyncMock(return_value=valid_result)
@ -391,7 +384,7 @@ class TestTradeBuilder:
await builder.add_team(self.team2)
# Mock the transaction builders
with patch.object(builder, '_get_or_create_builder') as mock_get_builder:
with patch.object(builder, "_get_or_create_builder") as mock_get_builder:
mock_builder1 = MagicMock()
mock_builder2 = MagicMock()
@ -404,7 +397,7 @@ class TestTradeBuilder:
minor_league_count=5,
warnings=[],
errors=[],
suggestions=[]
suggestions=[],
)
mock_builder1.validate_transaction = AsyncMock(return_value=valid_result)
@ -440,7 +433,7 @@ class TestTradeBuilder:
return self.team2
return None
with patch('services.trade_builder.team_service') as mock_team_service:
with patch("services.trade_builder.team_service") as mock_team_service:
mock_team_service.get_team = AsyncMock(side_effect=get_team_side_effect)
# Add balanced moves - no need to mock is_same_organization
@ -449,7 +442,7 @@ class TestTradeBuilder:
from_team=self.team1,
to_team=self.team2,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
await builder.add_player_move(
@ -457,7 +450,7 @@ class TestTradeBuilder:
from_team=self.team2,
to_team=self.team1,
from_roster=RosterType.MAJOR_LEAGUE,
to_roster=RosterType.MAJOR_LEAGUE
to_roster=RosterType.MAJOR_LEAGUE,
)
# Validate balanced trade
@ -829,7 +822,9 @@ class TestTradeBuilderCache:
"""
user_id = 12345
team1 = TeamFactory.west_virginia()
team3 = TeamFactory.create(id=999, abbrev="POR", name="Portland") # Non-participant
team3 = TeamFactory.create(
id=999, abbrev="POR", name="Portland"
) # Non-participant
# Create builder with team1
get_trade_builder(user_id, team1)
@ -955,7 +950,9 @@ class TestTradeBuilderCache:
This ensures proper error handling when a GM not in the trade tries to clear it.
"""
team3 = TeamFactory.create(id=999, abbrev="POR", name="Portland") # Non-participant
team3 = TeamFactory.create(
id=999, abbrev="POR", name="Portland"
) # Non-participant
result = clear_trade_builder_by_team(team3.id)
assert result is False
@ -982,7 +979,7 @@ class TestTradeValidationResult:
minor_league_count=5,
warnings=["Team1 warning"],
errors=["Team1 error"],
suggestions=["Team1 suggestion"]
suggestions=["Team1 suggestion"],
)
team2_validation = RosterValidationResult(
@ -991,28 +988,30 @@ class TestTradeValidationResult:
minor_league_count=4,
warnings=[],
errors=[],
suggestions=[]
suggestions=[],
)
result.participant_validations[1] = team1_validation
result.participant_validations[2] = team2_validation
result.team_abbrevs[1] = "TM1"
result.team_abbrevs[2] = "TM2"
result.is_legal = False # One team has errors
# Test aggregated results
# Test aggregated results - roster errors are prefixed with team abbrev
all_errors = result.all_errors
assert len(all_errors) == 3 # 2 trade + 1 team
assert "Trade error 1" in all_errors
assert "Team1 error" in all_errors
assert "[TM1] Team1 error" in all_errors
all_warnings = result.all_warnings
assert len(all_warnings) == 2 # 1 trade + 1 team
assert "Trade warning 1" in all_warnings
assert "Team1 warning" in all_warnings
assert "[TM1] Team1 warning" in all_warnings
all_suggestions = result.all_suggestions
assert len(all_suggestions) == 2 # 1 trade + 1 team
assert "Trade suggestion 1" in all_suggestions
assert "Team1 suggestion" in all_suggestions
assert "[TM1] Team1 suggestion" in all_suggestions
# Test participant validation lookup
team1_val = result.get_participant_validation(1)
@ -1029,4 +1028,4 @@ class TestTradeValidationResult:
assert len(result.all_errors) == 0
assert len(result.all_warnings) == 0
assert len(result.all_suggestions) == 0
assert len(result.participant_validations) == 0
assert len(result.participant_validations) == 0

View File

@ -3,6 +3,7 @@ Tests for TransactionService
Validates transaction service functionality, API interaction, and business logic.
"""
import pytest
from unittest.mock import AsyncMock, patch
@ -13,117 +14,131 @@ from exceptions import APIException
class TestTransactionService:
"""Test TransactionService functionality."""
@pytest.fixture
def service(self):
"""Create a fresh TransactionService instance for testing."""
return TransactionService()
@pytest.fixture
def mock_transaction_data(self):
"""Create mock transaction data for testing."""
return {
'id': 27787,
'week': 10,
'season': 12,
'moveid': 'Season-012-Week-10-19-13:04:41',
'player': {
'id': 12472,
'name': 'Test Player',
'wara': 2.47,
'season': 12,
'pos_1': 'LF'
"id": 27787,
"week": 10,
"season": 12,
"moveid": "Season-012-Week-10-19-13:04:41",
"player": {
"id": 12472,
"name": "Test Player",
"wara": 2.47,
"season": 12,
"pos_1": "LF",
},
'oldteam': {
'id': 508,
'abbrev': 'NYD',
'sname': 'Diamonds',
'lname': 'New York Diamonds',
'season': 12
"oldteam": {
"id": 508,
"abbrev": "NYD",
"sname": "Diamonds",
"lname": "New York Diamonds",
"season": 12,
},
'newteam': {
'id': 499,
'abbrev': 'WV',
'sname': 'Black Bears',
'lname': 'West Virginia Black Bears',
'season': 12
"newteam": {
"id": 499,
"abbrev": "WV",
"sname": "Black Bears",
"lname": "West Virginia Black Bears",
"season": 12,
},
'cancelled': False,
'frozen': False
"cancelled": False,
"frozen": False,
}
@pytest.fixture
@pytest.fixture
def mock_api_response(self, mock_transaction_data):
"""Create mock API response with multiple transactions."""
return {
'count': 3,
'transactions': [
"count": 3,
"transactions": [
mock_transaction_data,
{**mock_transaction_data, 'id': 27788, 'frozen': True},
{**mock_transaction_data, 'id': 27789, 'cancelled': True}
]
{**mock_transaction_data, "id": 27788, "frozen": True},
{**mock_transaction_data, "id": 27789, "cancelled": True},
],
}
@pytest.mark.asyncio
async def test_service_initialization(self, service):
"""Test service initialization."""
assert service.model_class == Transaction
assert service.endpoint == 'transactions'
assert service.endpoint == "transactions"
@pytest.mark.asyncio
async def test_get_team_transactions_basic(self, service, mock_api_response):
"""Test getting team transactions with basic parameters."""
with patch.object(service, 'get_all_items', new_callable=AsyncMock) as mock_get:
with patch.object(service, "get_all_items", new_callable=AsyncMock) as mock_get:
mock_get.return_value = [
Transaction.from_api_data(tx) for tx in mock_api_response['transactions']
Transaction.from_api_data(tx)
for tx in mock_api_response["transactions"]
]
result = await service.get_team_transactions('WV', 12)
result = await service.get_team_transactions("WV", 12)
assert len(result) == 3
assert all(isinstance(tx, Transaction) for tx in result)
# Verify API call was made
mock_get.assert_called_once()
@pytest.mark.asyncio
async def test_get_team_transactions_with_filters(self, service, mock_api_response):
"""Test getting team transactions with status filters."""
with patch.object(service, 'get_all_items', new_callable=AsyncMock) as mock_get:
with patch.object(service, "get_all_items", new_callable=AsyncMock) as mock_get:
mock_get.return_value = []
await service.get_team_transactions(
'WV', 12,
cancelled=True,
frozen=False,
week_start=5,
week_end=15
"WV", 12, cancelled=True, frozen=False, week_start=5, week_end=15
)
# Verify API call was made
mock_get.assert_called_once()
@pytest.mark.asyncio
async def test_get_team_transactions_sorting(self, service, mock_transaction_data):
"""Test transaction sorting by week and moveid."""
# Create transactions with different weeks and moveids
transactions_data = [
{**mock_transaction_data, 'id': 1, 'week': 10, 'moveid': 'Season-012-Week-10-19-13:04:41'},
{**mock_transaction_data, 'id': 2, 'week': 8, 'moveid': 'Season-012-Week-08-12-10:30:15'},
{**mock_transaction_data, 'id': 3, 'week': 10, 'moveid': 'Season-012-Week-10-15-09:22:33'},
{
**mock_transaction_data,
"id": 1,
"week": 10,
"moveid": "Season-012-Week-10-19-13:04:41",
},
{
**mock_transaction_data,
"id": 2,
"week": 8,
"moveid": "Season-012-Week-08-12-10:30:15",
},
{
**mock_transaction_data,
"id": 3,
"week": 10,
"moveid": "Season-012-Week-10-15-09:22:33",
},
]
with patch.object(service, 'get_all_items', new_callable=AsyncMock) as mock_get:
mock_get.return_value = [Transaction.from_api_data(tx) for tx in transactions_data]
result = await service.get_team_transactions('WV', 12)
with patch.object(service, "get_all_items", new_callable=AsyncMock) as mock_get:
mock_get.return_value = [
Transaction.from_api_data(tx) for tx in transactions_data
]
result = await service.get_team_transactions("WV", 12)
# Verify sorting: week 8 first, then week 10 sorted by moveid
assert result[0].week == 8
assert result[1].week == 10
assert result[2].week == 10
assert result[1].moveid < result[2].moveid # Alphabetical order
@pytest.mark.asyncio
async def test_get_pending_transactions(self, service):
"""Test getting pending transactions.
@ -131,115 +146,166 @@ class TestTransactionService:
The method first fetches the current week, then calls get_team_transactions
with week_start set to the current week.
"""
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_get_client:
with patch.object(
service, "get_client", new_callable=AsyncMock
) as mock_get_client:
mock_client = AsyncMock()
mock_client.get.return_value = {'week': 17} # Simulate current week
mock_client.get.return_value = {"week": 17} # Simulate current week
mock_get_client.return_value = mock_client
with patch.object(service, 'get_team_transactions', new_callable=AsyncMock) as mock_get:
with patch.object(
service, "get_team_transactions", new_callable=AsyncMock
) as mock_get:
mock_get.return_value = []
await service.get_pending_transactions('WV', 12)
await service.get_pending_transactions("WV", 12)
mock_get.assert_called_once_with(
"WV", 12, cancelled=False, frozen=False, week_start=17
)
mock_get.assert_called_once_with('WV', 12, cancelled=False, frozen=False, week_start=17)
@pytest.mark.asyncio
async def test_get_frozen_transactions(self, service):
"""Test getting frozen transactions."""
with patch.object(service, 'get_team_transactions', new_callable=AsyncMock) as mock_get:
"""Test getting frozen transactions."""
with patch.object(
service, "get_team_transactions", new_callable=AsyncMock
) as mock_get:
mock_get.return_value = []
await service.get_frozen_transactions('WV', 12)
mock_get.assert_called_once_with('WV', 12, frozen=True)
await service.get_frozen_transactions("WV", 12)
mock_get.assert_called_once_with("WV", 12, frozen=True)
@pytest.mark.asyncio
async def test_get_processed_transactions_success(self, service, mock_transaction_data):
async def test_get_processed_transactions_success(
self, service, mock_transaction_data
):
"""Test getting processed transactions with current week lookup."""
# Mock current week response
current_response = {'week': 12}
current_response = {"week": 12}
# Create test transactions with different statuses
all_transactions = [
Transaction.from_api_data({**mock_transaction_data, 'id': 1, 'cancelled': False, 'frozen': False}), # pending
Transaction.from_api_data({**mock_transaction_data, 'id': 2, 'cancelled': False, 'frozen': True}), # frozen
Transaction.from_api_data({**mock_transaction_data, 'id': 3, 'cancelled': True, 'frozen': False}), # cancelled
Transaction.from_api_data({**mock_transaction_data, 'id': 4, 'cancelled': False, 'frozen': False}), # pending
Transaction.from_api_data(
{**mock_transaction_data, "id": 1, "cancelled": False, "frozen": False}
), # pending
Transaction.from_api_data(
{**mock_transaction_data, "id": 2, "cancelled": False, "frozen": True}
), # frozen
Transaction.from_api_data(
{**mock_transaction_data, "id": 3, "cancelled": True, "frozen": False}
), # cancelled
Transaction.from_api_data(
{**mock_transaction_data, "id": 4, "cancelled": False, "frozen": False}
), # pending
]
# Mock the service methods
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_client:
with patch.object(service, "get_client", new_callable=AsyncMock) as mock_client:
mock_api_client = AsyncMock()
mock_api_client.get.return_value = current_response
mock_client.return_value = mock_api_client
with patch.object(service, 'get_team_transactions', new_callable=AsyncMock) as mock_get_team:
with patch.object(
service, "get_team_transactions", new_callable=AsyncMock
) as mock_get_team:
mock_get_team.return_value = all_transactions
result = await service.get_processed_transactions('WV', 12)
result = await service.get_processed_transactions("WV", 12)
# Should return empty list since all test transactions are either pending, frozen, or cancelled
# (none are processed - not pending, not frozen, not cancelled)
assert len(result) == 0
# Verify current week API call
mock_api_client.get.assert_called_once_with('current')
mock_api_client.get.assert_called_once_with("current")
# Verify team transactions call with week range
mock_get_team.assert_called_once_with('WV', 12, week_start=8) # 12 - 4 = 8
mock_get_team.assert_called_once_with(
"WV", 12, week_start=8
) # 12 - 4 = 8
@pytest.mark.asyncio
async def test_get_processed_transactions_fallback(self, service):
"""Test processed transactions fallback when current week fails."""
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_client:
with patch.object(service, "get_client", new_callable=AsyncMock) as mock_client:
# Mock client to raise exception
mock_client.side_effect = Exception("API Error")
with patch.object(service, 'get_team_transactions', new_callable=AsyncMock) as mock_get_team:
with patch.object(
service, "get_team_transactions", new_callable=AsyncMock
) as mock_get_team:
mock_get_team.return_value = []
result = await service.get_processed_transactions('WV', 12)
result = await service.get_processed_transactions("WV", 12)
assert result == []
# Verify fallback call without week range
mock_get_team.assert_called_with('WV', 12)
mock_get_team.assert_called_with("WV", 12)
@pytest.mark.asyncio
async def test_validate_transaction_success(self, service, mock_transaction_data):
"""Test successful transaction validation."""
transaction = Transaction.from_api_data(mock_transaction_data)
result = await service.validate_transaction(transaction)
assert isinstance(result, RosterValidation)
assert result.is_legal is True
assert len(result.errors) == 0
@pytest.mark.asyncio
async def test_validate_transaction_no_moves(self, service, mock_transaction_data):
"""Test transaction validation with no moves (edge case)."""
# For single-move transactions, this test simulates validation logic
transaction = Transaction.from_api_data(mock_transaction_data)
# Mock validation that would fail for complex business rules
with patch.object(service, 'validate_transaction') as mock_validate:
with patch.object(service, "validate_transaction") as mock_validate:
validation_result = RosterValidation(
is_legal=False,
errors=['Transaction validation failed']
is_legal=False, errors=["Transaction validation failed"]
)
mock_validate.return_value = validation_result
result = await service.validate_transaction(transaction)
assert result.is_legal is False
assert 'Transaction validation failed' in result.errors
@pytest.mark.skip(reason="Exception handling test needs refactoring for new patterns")
assert "Transaction validation failed" in result.errors
@pytest.mark.asyncio
async def test_validate_transaction_exception_handling(self, service, mock_transaction_data):
"""Test transaction validation exception handling."""
pass
async def test_validate_transaction_exception_handling(
self, service, mock_transaction_data
):
"""Test transaction validation exception handling.
When an unexpected exception occurs inside validate_transaction (e.g., the
RosterValidation constructor raises), the method's except clause catches it
and returns a failed RosterValidation containing the error message rather
than propagating the exception to the caller.
Covers the critical except path at services/transaction_service.py:187-192.
"""
transaction = Transaction.from_api_data(mock_transaction_data)
_real = RosterValidation
call_count = [0]
def patched_rv(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
raise RuntimeError("Simulated validation failure")
return _real(*args, **kwargs)
with patch(
"services.transaction_service.RosterValidation", side_effect=patched_rv
):
result = await service.validate_transaction(transaction)
assert isinstance(result, RosterValidation)
assert result.is_legal is False
assert len(result.errors) == 1
assert result.errors[0] == "Validation error: Simulated validation failure"
@pytest.mark.asyncio
async def test_cancel_transaction_success(self, service, mock_transaction_data):
"""Test successful transaction cancellation.
@ -248,51 +314,57 @@ class TestTransactionService:
returning a success message for bulk updates. We mock the client.patch()
method to simulate successful cancellation.
"""
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_get_client:
with patch.object(
service, "get_client", new_callable=AsyncMock
) as mock_get_client:
mock_client = AsyncMock()
# cancel_transaction expects a string response for success
mock_client.patch.return_value = "Updated 1 transactions"
mock_get_client.return_value = mock_client
result = await service.cancel_transaction('27787')
result = await service.cancel_transaction("27787")
assert result is True
mock_client.patch.assert_called_once()
call_args = mock_client.patch.call_args
assert call_args[0][0] == 'transactions' # endpoint
assert 'cancelled' in call_args[0][1] # update_data contains 'cancelled'
assert call_args[1]['object_id'] == '27787' # transaction_id
assert call_args[0][0] == "transactions" # endpoint
assert "cancelled" in call_args[0][1] # update_data contains 'cancelled'
assert call_args[1]["object_id"] == "27787" # transaction_id
@pytest.mark.asyncio
async def test_cancel_transaction_not_found(self, service):
"""Test cancelling non-existent transaction.
When the API returns None (no response), cancel_transaction returns False.
"""
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_get_client:
with patch.object(
service, "get_client", new_callable=AsyncMock
) as mock_get_client:
mock_client = AsyncMock()
mock_client.patch.return_value = None # No response = failure
mock_get_client.return_value = mock_client
result = await service.cancel_transaction('99999')
result = await service.cancel_transaction("99999")
assert result is False
@pytest.mark.asyncio
async def test_cancel_transaction_not_pending(self, service, mock_transaction_data):
"""Test cancelling already processed transaction.
The API handles validation - we just need to simulate a failure response.
"""
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_get_client:
with patch.object(
service, "get_client", new_callable=AsyncMock
) as mock_get_client:
mock_client = AsyncMock()
mock_client.patch.return_value = None # API returns None on failure
mock_get_client.return_value = mock_client
result = await service.cancel_transaction('27787')
result = await service.cancel_transaction("27787")
assert result is False
@pytest.mark.asyncio
async def test_cancel_transaction_exception_handling(self, service):
"""Test transaction cancellation exception handling.
@ -300,106 +372,134 @@ class TestTransactionService:
When the API call raises an exception, cancel_transaction catches it
and returns False.
"""
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_get_client:
with patch.object(
service, "get_client", new_callable=AsyncMock
) as mock_get_client:
mock_client = AsyncMock()
mock_client.patch.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
result = await service.cancel_transaction('27787')
result = await service.cancel_transaction("27787")
assert result is False
@pytest.mark.asyncio
async def test_get_contested_transactions(self, service, mock_transaction_data):
"""Test getting contested transactions."""
# Create transactions where multiple teams want the same player
contested_data = [
{**mock_transaction_data, 'id': 1, 'newteam': {'id': 499, 'abbrev': 'WV', 'sname': 'Black Bears', 'lname': 'West Virginia Black Bears', 'season': 12}},
{**mock_transaction_data, 'id': 2, 'newteam': {'id': 502, 'abbrev': 'LAA', 'sname': 'Angels', 'lname': 'Los Angeles Angels', 'season': 12}}, # Same player, different team
{
**mock_transaction_data,
"id": 1,
"newteam": {
"id": 499,
"abbrev": "WV",
"sname": "Black Bears",
"lname": "West Virginia Black Bears",
"season": 12,
},
},
{
**mock_transaction_data,
"id": 2,
"newteam": {
"id": 502,
"abbrev": "LAA",
"sname": "Angels",
"lname": "Los Angeles Angels",
"season": 12,
},
}, # Same player, different team
]
with patch.object(service, 'get_all_items', new_callable=AsyncMock) as mock_get:
mock_get.return_value = [Transaction.from_api_data(tx) for tx in contested_data]
with patch.object(service, "get_all_items", new_callable=AsyncMock) as mock_get:
mock_get.return_value = [
Transaction.from_api_data(tx) for tx in contested_data
]
result = await service.get_contested_transactions(12, 10)
# Should return both transactions since they're for the same player
assert len(result) == 2
# Verify API call was made
mock_get.assert_called_once()
# Note: This test might need adjustment based on actual contested transaction logic
@pytest.mark.asyncio
async def test_api_exception_handling(self, service):
"""Test API exception handling in service methods."""
with patch.object(service, 'get_all_items', new_callable=AsyncMock) as mock_get:
with patch.object(service, "get_all_items", new_callable=AsyncMock) as mock_get:
mock_get.side_effect = APIException("API unavailable")
with pytest.raises(APIException):
await service.get_team_transactions('WV', 12)
await service.get_team_transactions("WV", 12)
def test_global_service_instance(self):
"""Test that global service instance is properly initialized."""
assert isinstance(transaction_service, TransactionService)
assert transaction_service.model_class == Transaction
assert transaction_service.endpoint == 'transactions'
assert transaction_service.endpoint == "transactions"
class TestTransactionServiceIntegration:
"""Integration tests for TransactionService with real-like scenarios."""
@pytest.mark.asyncio
async def test_full_transaction_workflow(self):
"""Test complete transaction workflow simulation."""
service = TransactionService()
# Mock data for a complete workflow
mock_data = {
'id': 27787,
'week': 10,
'season': 12,
'moveid': 'Season-012-Week-10-19-13:04:41',
'player': {
'id': 12472,
'name': 'Test Player',
'wara': 2.47,
'season': 12,
'pos_1': 'LF'
"id": 27787,
"week": 10,
"season": 12,
"moveid": "Season-012-Week-10-19-13:04:41",
"player": {
"id": 12472,
"name": "Test Player",
"wara": 2.47,
"season": 12,
"pos_1": "LF",
},
'oldteam': {
'id': 508,
'abbrev': 'NYD',
'sname': 'Diamonds',
'lname': 'New York Diamonds',
'season': 12
"oldteam": {
"id": 508,
"abbrev": "NYD",
"sname": "Diamonds",
"lname": "New York Diamonds",
"season": 12,
},
'newteam': {
'id': 499,
'abbrev': 'WV',
'sname': 'Black Bears',
'lname': 'West Virginia Black Bears',
'season': 12
"newteam": {
"id": 499,
"abbrev": "WV",
"sname": "Black Bears",
"lname": "West Virginia Black Bears",
"season": 12,
},
'cancelled': False,
'frozen': False
"cancelled": False,
"frozen": False,
}
# Mock the full workflow properly
transaction = Transaction.from_api_data(mock_data)
# Mock get_pending_transactions to return our test transaction
with patch.object(service, 'get_pending_transactions', new_callable=AsyncMock) as mock_get_pending:
with patch.object(
service, "get_pending_transactions", new_callable=AsyncMock
) as mock_get_pending:
mock_get_pending.return_value = [transaction]
# Mock cancel_transaction
with patch.object(service, 'get_client', new_callable=AsyncMock) as mock_get_client:
with patch.object(
service, "get_client", new_callable=AsyncMock
) as mock_get_client:
mock_client = AsyncMock()
mock_client.patch.return_value = "Updated 1 transactions"
mock_get_client.return_value = mock_client
# Test workflow: get pending -> validate -> cancel
pending = await service.get_pending_transactions('WV', 12)
pending = await service.get_pending_transactions("WV", 12)
assert len(pending) == 1
validation = await service.validate_transaction(pending[0])
@ -407,61 +507,62 @@ class TestTransactionServiceIntegration:
cancelled = await service.cancel_transaction(str(pending[0].id))
assert cancelled is True
@pytest.mark.asyncio
async def test_performance_with_large_dataset(self):
"""Test service performance with large transaction dataset."""
service = TransactionService()
# Create 100 mock transactions
large_dataset = []
for i in range(100):
tx_data = {
'id': i,
'week': (i % 18) + 1, # Weeks 1-18
'season': 12,
'moveid': f'Season-012-Week-{(i % 18) + 1:02d}-{i}',
'player': {
'id': i + 1000,
'name': f'Player {i}',
'wara': round(1.0 + (i % 50) * 0.1, 2),
'season': 12,
'pos_1': 'LF'
"id": i,
"week": (i % 18) + 1, # Weeks 1-18
"season": 12,
"moveid": f"Season-012-Week-{(i % 18) + 1:02d}-{i}",
"player": {
"id": i + 1000,
"name": f"Player {i}",
"wara": round(1.0 + (i % 50) * 0.1, 2),
"season": 12,
"pos_1": "LF",
},
'oldteam': {
'id': 508,
'abbrev': 'NYD',
'sname': 'Diamonds',
'lname': 'New York Diamonds',
'season': 12
"oldteam": {
"id": 508,
"abbrev": "NYD",
"sname": "Diamonds",
"lname": "New York Diamonds",
"season": 12,
},
'newteam': {
'id': 499,
'abbrev': 'WV',
'sname': 'Black Bears',
'lname': 'West Virginia Black Bears',
'season': 12
"newteam": {
"id": 499,
"abbrev": "WV",
"sname": "Black Bears",
"lname": "West Virginia Black Bears",
"season": 12,
},
'cancelled': i % 10 == 0, # Every 10th transaction is cancelled
'frozen': i % 7 == 0 # Every 7th transaction is frozen
"cancelled": i % 10 == 0, # Every 10th transaction is cancelled
"frozen": i % 7 == 0, # Every 7th transaction is frozen
}
large_dataset.append(Transaction.from_api_data(tx_data))
with patch.object(service, 'get_all_items', new_callable=AsyncMock) as mock_get:
with patch.object(service, "get_all_items", new_callable=AsyncMock) as mock_get:
mock_get.return_value = large_dataset
# Test that service handles large datasets efficiently
import time
start_time = time.time()
result = await service.get_team_transactions('WV', 12)
result = await service.get_team_transactions("WV", 12)
end_time = time.time()
processing_time = end_time - start_time
assert len(result) == 100
assert processing_time < 1.0 # Should process quickly
# Verify sorting worked correctly
for i in range(len(result) - 1):
assert result[i].week <= result[i + 1].week
assert result[i].week <= result[i + 1].week

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@ Tests for Injury Modal Validation in Discord Bot v2.0
Tests week and game validation for BatterInjuryModal and PitcherRestModal,
including regular season and playoff round validation.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
@ -36,7 +37,7 @@ def sample_player():
season=12,
team_id=1,
image="https://example.com/player.jpg",
pos_1="1B"
pos_1="1B",
)
@ -60,21 +61,21 @@ class TestBatterInjuryModalWeekValidation:
"""Test week validation in BatterInjuryModal."""
@pytest.mark.asyncio
async def test_regular_season_week_valid(self, sample_player, mock_interaction, mock_config):
async def test_regular_season_week_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that regular season weeks (1-18) are accepted."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
# Mock the TextInput values
modal.current_week = create_mock_text_input("10")
modal.current_game = create_mock_text_input("2")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
# Mock successful injury creation
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
@ -84,26 +85,25 @@ class TestBatterInjuryModalWeekValidation:
# Should not send error message
assert not any(
call[1].get('embed') and
'Invalid Week' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Week" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_playoff_week_19_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_week_19_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that playoff week 19 (round 1) is accepted."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("19")
modal.current_game = create_mock_text_input("3")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -112,26 +112,25 @@ class TestBatterInjuryModalWeekValidation:
# Should not send error message
assert not any(
call[1].get('embed') and
'Invalid Week' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Week" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_playoff_week_21_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_week_21_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that playoff week 21 (round 3) is accepted."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("21")
modal.current_game = create_mock_text_input("5")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -140,73 +139,68 @@ class TestBatterInjuryModalWeekValidation:
# Should not send error message
assert not any(
call[1].get('embed') and
'Invalid Week' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Week" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_week_too_high_rejected(self, sample_player, mock_interaction, mock_config):
async def test_week_too_high_rejected(
self, sample_player, mock_interaction, mock_config
):
"""Test that week > 21 is rejected."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("22")
modal.current_game = create_mock_text_input("2")
with patch('config.get_config', return_value=mock_config):
with patch("config.get_config", return_value=mock_config):
await modal.on_submit(mock_interaction)
# Should send error message
mock_interaction.response.send_message.assert_called_once()
call_kwargs = mock_interaction.response.send_message.call_args[1]
assert 'embed' in call_kwargs
assert 'Invalid Week' in call_kwargs['embed'].title
assert '21 (including playoffs)' in call_kwargs['embed'].description
assert "embed" in call_kwargs
assert "Invalid Week" in call_kwargs["embed"].title
assert "21 (including playoffs)" in call_kwargs["embed"].description
@pytest.mark.asyncio
async def test_week_zero_rejected(self, sample_player, mock_interaction, mock_config):
async def test_week_zero_rejected(
self, sample_player, mock_interaction, mock_config
):
"""Test that week 0 is rejected."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("0")
modal.current_game = create_mock_text_input("2")
with patch('config.get_config', return_value=mock_config):
with patch("config.get_config", return_value=mock_config):
await modal.on_submit(mock_interaction)
# Should send error message
mock_interaction.response.send_message.assert_called_once()
call_kwargs = mock_interaction.response.send_message.call_args[1]
assert 'embed' in call_kwargs
assert 'Invalid Week' in call_kwargs['embed'].title
assert "embed" in call_kwargs
assert "Invalid Week" in call_kwargs["embed"].title
class TestBatterInjuryModalGameValidation:
"""Test game validation in BatterInjuryModal."""
@pytest.mark.asyncio
async def test_regular_season_game_4_valid(self, sample_player, mock_interaction, mock_config):
async def test_regular_season_game_4_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 4 is accepted in regular season."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("10")
modal.current_game = create_mock_text_input("4")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -215,48 +209,45 @@ class TestBatterInjuryModalGameValidation:
# Should not send error about invalid game
assert not any(
call[1].get('embed') and
'Invalid Game' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Game" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_regular_season_game_5_rejected(self, sample_player, mock_interaction, mock_config):
async def test_regular_season_game_5_rejected(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 5 is rejected in regular season (only 4 games)."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("10")
modal.current_game = create_mock_text_input("5")
with patch('config.get_config', return_value=mock_config):
with patch("config.get_config", return_value=mock_config):
await modal.on_submit(mock_interaction)
# Should send error message
mock_interaction.response.send_message.assert_called_once()
call_kwargs = mock_interaction.response.send_message.call_args[1]
assert 'embed' in call_kwargs
assert 'Invalid Game' in call_kwargs['embed'].title
assert 'between 1 and 4' in call_kwargs['embed'].description
assert "embed" in call_kwargs
assert "Invalid Game" in call_kwargs["embed"].title
assert "between 1 and 4" in call_kwargs["embed"].description
@pytest.mark.asyncio
async def test_playoff_round_1_game_5_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_round_1_game_5_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 5 is accepted in playoff round 1 (week 19)."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("19")
modal.current_game = create_mock_text_input("5")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -265,48 +256,45 @@ class TestBatterInjuryModalGameValidation:
# Should not send error about invalid game
assert not any(
call[1].get('embed') and
'Invalid Game' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Game" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_playoff_round_1_game_6_rejected(self, sample_player, mock_interaction, mock_config):
async def test_playoff_round_1_game_6_rejected(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 6 is rejected in playoff round 1 (only 5 games)."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("19")
modal.current_game = create_mock_text_input("6")
with patch('config.get_config', return_value=mock_config):
with patch("config.get_config", return_value=mock_config):
await modal.on_submit(mock_interaction)
# Should send error message
mock_interaction.response.send_message.assert_called_once()
call_kwargs = mock_interaction.response.send_message.call_args[1]
assert 'embed' in call_kwargs
assert 'Invalid Game' in call_kwargs['embed'].title
assert 'between 1 and 5' in call_kwargs['embed'].description
assert "embed" in call_kwargs
assert "Invalid Game" in call_kwargs["embed"].title
assert "between 1 and 5" in call_kwargs["embed"].description
@pytest.mark.asyncio
async def test_playoff_round_2_game_7_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_round_2_game_7_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 7 is accepted in playoff round 2 (week 20)."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("20")
modal.current_game = create_mock_text_input("7")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -315,26 +303,25 @@ class TestBatterInjuryModalGameValidation:
# Should not send error about invalid game
assert not any(
call[1].get('embed') and
'Invalid Game' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Game" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_playoff_round_3_game_7_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_round_3_game_7_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 7 is accepted in playoff round 3 (week 21)."""
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("21")
modal.current_game = create_mock_text_input("7")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -343,8 +330,7 @@ class TestBatterInjuryModalGameValidation:
# Should not send error about invalid game
assert not any(
call[1].get('embed') and
'Invalid Game' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Game" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@ -353,21 +339,21 @@ class TestPitcherRestModalValidation:
"""Test week and game validation in PitcherRestModal (should match BatterInjuryModal)."""
@pytest.mark.asyncio
async def test_playoff_week_19_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_week_19_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that playoff week 19 is accepted for pitchers."""
modal = PitcherRestModal(
player=sample_player,
injury_games=4,
season=12
)
modal = PitcherRestModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("19")
modal.current_game = create_mock_text_input("3")
modal.rest_games = create_mock_text_input("2")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -376,50 +362,45 @@ class TestPitcherRestModalValidation:
# Should not send error about invalid week
assert not any(
call[1].get('embed') and
'Invalid Week' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Week" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_week_22_rejected(self, sample_player, mock_interaction, mock_config):
"""Test that week 22 is rejected for pitchers."""
modal = PitcherRestModal(
player=sample_player,
injury_games=4,
season=12
)
modal = PitcherRestModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("22")
modal.current_game = create_mock_text_input("2")
modal.rest_games = create_mock_text_input("2")
with patch('config.get_config', return_value=mock_config):
with patch("config.get_config", return_value=mock_config):
await modal.on_submit(mock_interaction)
# Should send error message
mock_interaction.response.send_message.assert_called_once()
call_kwargs = mock_interaction.response.send_message.call_args[1]
assert 'embed' in call_kwargs
assert 'Invalid Week' in call_kwargs['embed'].title
assert '21 (including playoffs)' in call_kwargs['embed'].description
assert "embed" in call_kwargs
assert "Invalid Week" in call_kwargs["embed"].title
assert "21 (including playoffs)" in call_kwargs["embed"].description
@pytest.mark.asyncio
async def test_playoff_round_2_game_7_valid(self, sample_player, mock_interaction, mock_config):
async def test_playoff_round_2_game_7_valid(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 7 is accepted in playoff round 2 for pitchers."""
modal = PitcherRestModal(
player=sample_player,
injury_games=4,
season=12
)
modal = PitcherRestModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("20")
modal.current_game = create_mock_text_input("7")
modal.rest_games = create_mock_text_input("3")
with patch('config.get_config', return_value=mock_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("config.get_config", return_value=mock_config), patch(
"services.player_service.player_service"
) as mock_player_service, patch(
"services.injury_service.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -428,40 +409,39 @@ class TestPitcherRestModalValidation:
# Should not send error about invalid game
assert not any(
call[1].get('embed') and
'Invalid Game' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Game" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)
@pytest.mark.asyncio
async def test_playoff_round_1_game_6_rejected(self, sample_player, mock_interaction, mock_config):
async def test_playoff_round_1_game_6_rejected(
self, sample_player, mock_interaction, mock_config
):
"""Test that game 6 is rejected in playoff round 1 for pitchers (only 5 games)."""
modal = PitcherRestModal(
player=sample_player,
injury_games=4,
season=12
)
modal = PitcherRestModal(player=sample_player, injury_games=4, season=12)
modal.current_week = create_mock_text_input("19")
modal.current_game = create_mock_text_input("6")
modal.rest_games = create_mock_text_input("2")
with patch('config.get_config', return_value=mock_config):
with patch("config.get_config", return_value=mock_config):
await modal.on_submit(mock_interaction)
# Should send error message
mock_interaction.response.send_message.assert_called_once()
call_kwargs = mock_interaction.response.send_message.call_args[1]
assert 'embed' in call_kwargs
assert 'Invalid Game' in call_kwargs['embed'].title
assert 'between 1 and 5' in call_kwargs['embed'].description
assert "embed" in call_kwargs
assert "Invalid Game" in call_kwargs["embed"].title
assert "between 1 and 5" in call_kwargs["embed"].description
class TestConfigDrivenValidation:
"""Test that validation correctly uses config values."""
@pytest.mark.asyncio
async def test_custom_config_values_respected(self, sample_player, mock_interaction):
async def test_custom_config_values_respected(
self, sample_player, mock_interaction
):
"""Test that custom config values change validation behavior."""
# Create config with different values
custom_config = MagicMock()
@ -472,19 +452,17 @@ class TestConfigDrivenValidation:
custom_config.playoff_round_two_games = 7
custom_config.playoff_round_three_games = 7
modal = BatterInjuryModal(
player=sample_player,
injury_games=4,
season=12
)
modal = BatterInjuryModal(player=sample_player, injury_games=4, season=12)
# Week 22 should be valid with this config (20 + 2 = 22)
modal.current_week = create_mock_text_input("22")
modal.current_game = create_mock_text_input("3")
with patch('config.get_config', return_value=custom_config), \
patch('services.player_service.player_service') as mock_player_service, \
patch('services.injury_service.injury_service') as mock_injury_service:
with patch("views.modals.get_config", return_value=custom_config), patch(
"views.modals.player_service"
) as mock_player_service, patch(
"views.modals.injury_service"
) as mock_injury_service:
mock_injury_service.create_injury = AsyncMock(return_value=MagicMock(id=1))
mock_player_service.update_player = AsyncMock()
@ -493,7 +471,6 @@ class TestConfigDrivenValidation:
# Should not send error about invalid week
assert not any(
call[1].get('embed') and
'Invalid Week' in str(call[1]['embed'].title)
call[1].get("embed") and "Invalid Week" in str(call[1]["embed"].title)
for call in mock_interaction.response.send_message.call_args_list
)

View File

@ -3,19 +3,20 @@ Autocomplete Utilities
Shared autocomplete functions for Discord slash commands.
"""
from typing import List
import discord
from discord import app_commands
from config import get_config
from models.team import RosterType
from services.player_service import player_service
from services.team_service import team_service
from utils.team_utils import get_user_major_league_team
async def player_autocomplete(
interaction: discord.Interaction,
current: str
interaction: discord.Interaction, current: str
) -> List[app_commands.Choice[str]]:
"""
Autocomplete for player names with team context prioritization.
@ -37,7 +38,9 @@ async def player_autocomplete(
user_team = await get_user_major_league_team(interaction.user.id)
# Search for players using the search endpoint
players = await player_service.search_players(current, limit=50, season=get_config().sba_season)
players = await player_service.search_players(
current, limit=50, season=get_config().sba_season
)
# Separate players by team (user's team vs others)
user_team_players = []
@ -46,10 +49,11 @@ async def player_autocomplete(
for player in players:
# Check if player belongs to user's team (any roster section)
is_users_player = False
if user_team and hasattr(player, 'team') and player.team:
if user_team and hasattr(player, "team") and player.team:
# Check if player is from user's major league team or has same base team
if (player.team.id == user_team.id or
(hasattr(player, 'team_id') and player.team_id == user_team.id)):
if player.team.id == user_team.id or (
hasattr(player, "team_id") and player.team_id == user_team.id
):
is_users_player = True
if is_users_player:
@ -63,7 +67,7 @@ async def player_autocomplete(
# Add user's team players first (prioritized)
for player in user_team_players[:15]: # Limit user team players
team_info = f"{player.primary_position}"
if hasattr(player, 'team') and player.team:
if hasattr(player, "team") and player.team:
team_info += f" - {player.team.abbrev}"
choice_name = f"{player.name} ({team_info})"
@ -73,7 +77,7 @@ async def player_autocomplete(
remaining_slots = 25 - len(choices)
for player in other_players[:remaining_slots]:
team_info = f"{player.primary_position}"
if hasattr(player, 'team') and player.team:
if hasattr(player, "team") and player.team:
team_info += f" - {player.team.abbrev}"
choice_name = f"{player.name} ({team_info})"
@ -87,8 +91,7 @@ async def player_autocomplete(
async def team_autocomplete(
interaction: discord.Interaction,
current: str
interaction: discord.Interaction, current: str
) -> List[app_commands.Choice[str]]:
"""
Autocomplete for team abbreviations.
@ -109,8 +112,10 @@ async def team_autocomplete(
# Filter teams by current input and limit to 25
matching_teams = [
team for team in teams
if current.lower() in team.abbrev.lower() or current.lower() in team.sname.lower()
team
for team in teams
if current.lower() in team.abbrev.lower()
or current.lower() in team.sname.lower()
][:25]
choices = []
@ -126,8 +131,7 @@ async def team_autocomplete(
async def major_league_team_autocomplete(
interaction: discord.Interaction,
current: str
interaction: discord.Interaction, current: str
) -> List[app_commands.Choice[str]]:
"""
Autocomplete for Major League team abbreviations only.
@ -149,16 +153,16 @@ async def major_league_team_autocomplete(
all_teams = await team_service.get_teams_by_season(get_config().sba_season)
# Filter to only Major League teams using the model's helper method
from models.team import RosterType
ml_teams = [
team for team in all_teams
if team.roster_type() == RosterType.MAJOR_LEAGUE
team for team in all_teams if team.roster_type() == RosterType.MAJOR_LEAGUE
]
# Filter teams by current input and limit to 25
matching_teams = [
team for team in ml_teams
if current.lower() in team.abbrev.lower() or current.lower() in team.sname.lower()
team
for team in ml_teams
if current.lower() in team.abbrev.lower()
or current.lower() in team.sname.lower()
][:25]
choices = []
@ -170,4 +174,4 @@ async def major_league_team_autocomplete(
except Exception:
# Silently fail on autocomplete errors
return []
return []

View File

@ -4,10 +4,12 @@ Discord Helper Utilities
Common Discord-related helper functions for channel lookups,
message sending, and formatting.
"""
from typing import Optional, List
import discord
from discord.ext import commands
from config import get_config
from models.play import Play
from models.team import Team
from utils.logging import get_contextual_logger
@ -16,8 +18,7 @@ logger = get_contextual_logger(__name__)
async def get_channel_by_name(
bot: commands.Bot,
channel_name: str
bot: commands.Bot, channel_name: str
) -> Optional[discord.TextChannel]:
"""
Get a text channel by name from the configured guild.
@ -29,8 +30,6 @@ async def get_channel_by_name(
Returns:
TextChannel if found, None otherwise
"""
from config import get_config
config = get_config()
guild_id = config.guild_id
@ -56,7 +55,7 @@ async def send_to_channel(
bot: commands.Bot,
channel_name: str,
content: Optional[str] = None,
embed: Optional[discord.Embed] = None
embed: Optional[discord.Embed] = None,
) -> bool:
"""
Send a message to a channel by name.
@ -80,9 +79,9 @@ async def send_to_channel(
# Build kwargs to avoid passing None for embed
kwargs = {}
if content is not None:
kwargs['content'] = content
kwargs["content"] = content
if embed is not None:
kwargs['embed'] = embed
kwargs["embed"] = embed
await channel.send(**kwargs)
logger.info(f"Sent message to #{channel_name}")
@ -92,11 +91,7 @@ async def send_to_channel(
return False
def format_key_plays(
plays: List[Play],
away_team: Team,
home_team: Team
) -> str:
def format_key_plays(plays: List[Play], away_team: Team, home_team: Team) -> str:
"""
Format top plays into embed field text.
@ -122,9 +117,7 @@ def format_key_plays(
async def set_channel_visibility(
channel: discord.TextChannel,
visible: bool,
reason: Optional[str] = None
channel: discord.TextChannel, visible: bool, reason: Optional[str] = None
) -> bool:
"""
Set channel visibility for @everyone.
@ -148,18 +141,14 @@ async def set_channel_visibility(
# Grant @everyone permission to view channel
default_reason = "Channel made visible to all members"
await channel.set_permissions(
everyone_role,
view_channel=True,
reason=reason or default_reason
everyone_role, view_channel=True, reason=reason or default_reason
)
logger.info(f"Set #{channel.name} to VISIBLE for @everyone")
else:
# Remove @everyone view permission
default_reason = "Channel hidden from members"
await channel.set_permissions(
everyone_role,
view_channel=False,
reason=reason or default_reason
everyone_role, view_channel=False, reason=reason or default_reason
)
logger.info(f"Set #{channel.name} to HIDDEN for @everyone")

View File

@ -3,8 +3,10 @@ Draft utility functions for Discord Bot v2.0
Provides helper functions for draft order calculation and cap space validation.
"""
import math
from typing import Tuple
from utils.helpers import get_team_salary_cap, SALARY_CAP_TOLERANCE
from utils.logging import get_contextual_logger
from config import get_config
@ -109,9 +111,7 @@ def calculate_overall_from_round_position(round_num: int, position: int) -> int:
async def validate_cap_space(
roster: dict,
new_player_wara: float,
team=None
roster: dict, new_player_wara: float, team=None
) -> Tuple[bool, float, float]:
"""
Validate team has cap space to draft player.
@ -138,17 +138,15 @@ async def validate_cap_space(
Raises:
ValueError: If roster structure is invalid
"""
from utils.helpers import get_team_salary_cap, SALARY_CAP_TOLERANCE
config = get_config()
cap_limit = get_team_salary_cap(team)
cap_player_count = config.cap_player_count
if not roster or not roster.get('active'):
if not roster or not roster.get("active"):
raise ValueError("Invalid roster structure - missing 'active' key")
active_roster = roster['active']
current_players = active_roster.get('players', [])
active_roster = roster["active"]
current_players = active_roster.get("players", [])
# Calculate how many players count toward cap after adding new player
current_roster_size = len(current_players)
@ -172,7 +170,7 @@ async def validate_cap_space(
players_counted = max(0, cap_player_count - max_zeroes)
# Sort all players (including new) by sWAR ASCENDING (cheapest first)
all_players_wara = [p['wara'] for p in current_players] + [new_player_wara]
all_players_wara = [p["wara"] for p in current_players] + [new_player_wara]
sorted_wara = sorted(all_players_wara) # Ascending order
# Sum bottom N players (the cheapest ones that count toward cap)

View File

@ -7,6 +7,7 @@ servers and user types:
- @league_only: Only available in the league server
- @requires_team: User must have a team (works with global commands)
"""
import logging
from functools import wraps
from typing import Optional, Callable
@ -21,6 +22,7 @@ logger = logging.getLogger(__name__)
class PermissionError(Exception):
"""Raised when a user doesn't have permission to use a command."""
pass
@ -54,17 +56,16 @@ async def get_user_team(user_id: int) -> Optional[dict]:
# This call is automatically cached by TeamService
config = get_config()
team = await team_service.get_team_by_owner(
owner_id=user_id,
season=config.sba_season
owner_id=user_id, season=config.sba_season
)
if team:
logger.debug(f"User {user_id} has team: {team.lname}")
return {
'id': team.id,
'name': team.lname,
'abbrev': team.abbrev,
'season': team.season
"id": team.id,
"name": team.lname,
"abbrev": team.abbrev,
"season": team.season,
}
logger.debug(f"User {user_id} does not have a team")
@ -77,6 +78,18 @@ def is_league_server(guild_id: int) -> bool:
return guild_id == config.guild_id
def is_admin(interaction: discord.Interaction) -> bool:
"""Check if the interaction user is a server administrator.
Includes an isinstance guard for discord.Member so this is safe
to call from DM contexts (where guild_permissions is unavailable).
"""
return (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
)
def league_only():
"""
Decorator to restrict a command to the league server only.
@ -87,14 +100,14 @@ def league_only():
async def team_command(self, interaction: discord.Interaction):
# Only executes in league server
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
# Check if in a guild
if not interaction.guild:
await interaction.response.send_message(
"❌ This command can only be used in a server.",
ephemeral=True
"❌ This command can only be used in a server.", ephemeral=True
)
return
@ -102,13 +115,14 @@ def league_only():
if not is_league_server(interaction.guild.id):
await interaction.response.send_message(
"❌ This command is only available in the SBa league server.",
ephemeral=True
ephemeral=True,
)
return
return await func(self, interaction, *args, **kwargs)
return wrapper
return decorator
@ -123,6 +137,7 @@ def requires_team():
async def mymoves_command(self, interaction: discord.Interaction):
# Only executes if user has a team
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
@ -133,29 +148,33 @@ def requires_team():
if team is None:
await interaction.response.send_message(
"❌ This command requires you to have a team in the SBa league. Contact an admin if you believe this is an error.",
ephemeral=True
ephemeral=True,
)
return
# Store team info in interaction for command to use
# This allows commands to access the team without another lookup
interaction.extras['user_team'] = team
interaction.extras["user_team"] = team
return await func(self, interaction, *args, **kwargs)
except Exception as e:
# Log the error for debugging
logger.error(f"Error checking team ownership for user {interaction.user.id}: {e}", exc_info=True)
logger.error(
f"Error checking team ownership for user {interaction.user.id}: {e}",
exc_info=True,
)
# Provide helpful error message to user
await interaction.response.send_message(
"❌ Unable to verify team ownership due to a temporary error. Please try again in a moment. "
"If this persists, contact an admin.",
ephemeral=True
ephemeral=True,
)
return
return wrapper
return decorator
@ -170,12 +189,14 @@ def global_command():
async def roll_command(self, interaction: discord.Interaction):
# Available in all servers
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
return await func(self, interaction, *args, **kwargs)
return wrapper
return decorator
@ -190,35 +211,35 @@ def admin_only():
async def sync_command(self, interaction: discord.Interaction):
# Only executes for admins
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
# Check if user is guild admin
if not interaction.guild:
await interaction.response.send_message(
"❌ This command can only be used in a server.",
ephemeral=True
"❌ This command can only be used in a server.", ephemeral=True
)
return
# Check if user has admin permissions
if not isinstance(interaction.user, discord.Member):
await interaction.response.send_message(
"❌ Unable to verify permissions.",
ephemeral=True
"❌ Unable to verify permissions.", ephemeral=True
)
return
if not interaction.user.guild_permissions.administrator:
await interaction.response.send_message(
"❌ This command requires administrator permissions.",
ephemeral=True
ephemeral=True,
)
return
return await func(self, interaction, *args, **kwargs)
return wrapper
return decorator
@ -241,6 +262,7 @@ def league_admin_only():
async def admin_sync_prefix(self, ctx: commands.Context):
# Only league server admins can use this
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, ctx_or_interaction, *args, **kwargs):
@ -254,6 +276,7 @@ def league_admin_only():
async def send_error(msg: str):
await ctx.send(msg)
else:
interaction = ctx_or_interaction
guild = interaction.guild
@ -269,7 +292,9 @@ def league_admin_only():
# Check if league server
if not is_league_server(guild.id):
await send_error("❌ This command is only available in the SBa league server.")
await send_error(
"❌ This command is only available in the SBa league server."
)
return
# Check admin permissions
@ -284,4 +309,5 @@ def league_admin_only():
return await func(self, ctx_or_interaction, *args, **kwargs)
return wrapper
return decorator

View File

@ -3,6 +3,7 @@ Draft Views for Discord Bot v2.0
Provides embeds and UI components for draft system.
"""
from typing import Optional, List
import discord
@ -14,6 +15,7 @@ from models.player import Player
from models.draft_list import DraftList
from views.embeds import EmbedTemplate, EmbedColors
from utils.draft_helpers import format_pick_display, get_round_name
from utils.helpers import get_team_salary_cap
async def create_on_the_clock_embed(
@ -22,7 +24,7 @@ async def create_on_the_clock_embed(
recent_picks: List[DraftPick],
upcoming_picks: List[DraftPick],
team_roster_swar: Optional[float] = None,
sheet_url: Optional[str] = None
sheet_url: Optional[str] = None,
) -> discord.Embed:
"""
Create "on the clock" embed showing current pick info.
@ -45,7 +47,7 @@ async def create_on_the_clock_embed(
embed = EmbedTemplate.create_base_embed(
title=f"{current_pick.owner.lname} On The Clock",
description=format_pick_display(current_pick.overall),
color=EmbedColors.PRIMARY
color=EmbedColors.PRIMARY,
)
# Add team info
@ -53,26 +55,23 @@ async def create_on_the_clock_embed(
embed.add_field(
name="Team",
value=f"{current_pick.owner.abbrev} {current_pick.owner.sname}",
inline=True
inline=True,
)
# Add timer info
if draft_data.pick_deadline:
deadline_timestamp = int(draft_data.pick_deadline.timestamp())
embed.add_field(
name="Deadline",
value=f"<t:{deadline_timestamp}:R>",
inline=True
name="Deadline", value=f"<t:{deadline_timestamp}:R>", inline=True
)
# Add team sWAR if provided
if team_roster_swar is not None:
from utils.helpers import get_team_salary_cap
cap_limit = get_team_salary_cap(current_pick.owner)
embed.add_field(
name="Current sWAR",
value=f"{team_roster_swar:.2f} / {cap_limit:.2f}",
inline=True
inline=True,
)
# Add recent picks
@ -83,9 +82,7 @@ async def create_on_the_clock_embed(
recent_str += f"**#{pick.overall}** - {pick.player.name}\n"
if recent_str:
embed.add_field(
name="📋 Last 5 Picks",
value=recent_str or "None",
inline=False
name="📋 Last 5 Picks", value=recent_str or "None", inline=False
)
# Add upcoming picks
@ -94,18 +91,12 @@ async def create_on_the_clock_embed(
for pick in upcoming_picks[:5]:
upcoming_str += f"**#{pick.overall}** - {pick.owner.sname if pick.owner else 'Unknown'}\n"
if upcoming_str:
embed.add_field(
name="🔜 Next 5 Picks",
value=upcoming_str,
inline=False
)
embed.add_field(name="🔜 Next 5 Picks", value=upcoming_str, inline=False)
# Draft Sheet link
if sheet_url:
embed.add_field(
name="📊 Draft Sheet",
value=f"[View Full Board]({sheet_url})",
inline=False
name="📊 Draft Sheet", value=f"[View Full Board]({sheet_url})", inline=False
)
# Add footer
@ -119,7 +110,7 @@ async def create_draft_status_embed(
draft_data: DraftData,
current_pick: DraftPick,
lock_status: str = "🔓 No pick in progress",
sheet_url: Optional[str] = None
sheet_url: Optional[str] = None,
) -> discord.Embed:
"""
Create draft status embed showing current state.
@ -137,12 +128,12 @@ async def create_draft_status_embed(
if draft_data.paused:
embed = EmbedTemplate.warning(
title="Draft Status - PAUSED",
description=f"Currently on {format_pick_display(draft_data.currentpick)}"
description=f"Currently on {format_pick_display(draft_data.currentpick)}",
)
else:
embed = EmbedTemplate.info(
title="Draft Status",
description=f"Currently on {format_pick_display(draft_data.currentpick)}"
description=f"Currently on {format_pick_display(draft_data.currentpick)}",
)
# On the clock
@ -150,7 +141,7 @@ async def create_draft_status_embed(
embed.add_field(
name="On the Clock",
value=f"{current_pick.owner.abbrev} {current_pick.owner.sname}",
inline=True
inline=True,
)
# Timer status (show paused state prominently)
@ -163,53 +154,40 @@ async def create_draft_status_embed(
embed.add_field(
name="Timer",
value=f"{timer_status} ({draft_data.pick_minutes} min)",
inline=True
inline=True,
)
# Deadline
if draft_data.pick_deadline:
deadline_timestamp = int(draft_data.pick_deadline.timestamp())
embed.add_field(
name="Deadline",
value=f"<t:{deadline_timestamp}:R>",
inline=True
name="Deadline", value=f"<t:{deadline_timestamp}:R>", inline=True
)
else:
embed.add_field(
name="Deadline",
value="None",
inline=True
)
embed.add_field(name="Deadline", value="None", inline=True)
# Pause status (if paused, show prominent warning)
if draft_data.paused:
embed.add_field(
name="Pause Status",
value="🚫 **Draft is paused** - No picks allowed until admin resumes",
inline=False
inline=False,
)
# Lock status
embed.add_field(
name="Lock Status",
value=lock_status,
inline=False
)
embed.add_field(name="Lock Status", value=lock_status, inline=False)
# Draft Sheet link
if sheet_url:
embed.add_field(
name="Draft Sheet",
value=f"[View Sheet]({sheet_url})",
inline=False
name="Draft Sheet", value=f"[View Sheet]({sheet_url})", inline=False
)
return embed
async def create_player_draft_card(
player: Player,
draft_pick: DraftPick
player: Player, draft_pick: DraftPick
) -> discord.Embed:
"""
Create player draft card embed.
@ -226,41 +204,32 @@ async def create_player_draft_card(
embed = EmbedTemplate.success(
title=f"{player.name} Drafted!",
description=format_pick_display(draft_pick.overall)
description=format_pick_display(draft_pick.overall),
)
# Team info
embed.add_field(
name="Selected By",
value=f"{draft_pick.owner.abbrev} {draft_pick.owner.sname}",
inline=True
inline=True,
)
# Player info
if hasattr(player, 'pos_1') and player.pos_1:
embed.add_field(
name="Position",
value=player.pos_1,
inline=True
)
if hasattr(player, "pos_1") and player.pos_1:
embed.add_field(name="Position", value=player.pos_1, inline=True)
if hasattr(player, 'wara') and player.wara is not None:
embed.add_field(
name="sWAR",
value=f"{player.wara:.2f}",
inline=True
)
if hasattr(player, "wara") and player.wara is not None:
embed.add_field(name="sWAR", value=f"{player.wara:.2f}", inline=True)
# Add player image if available
if hasattr(player, 'image') and player.image:
if hasattr(player, "image") and player.image:
embed.set_thumbnail(url=player.image)
return embed
async def create_draft_list_embed(
team: Team,
draft_list: List[DraftList]
team: Team, draft_list: List[DraftList]
) -> discord.Embed:
"""
Create draft list embed showing team's auto-draft queue.
@ -274,38 +243,40 @@ async def create_draft_list_embed(
"""
embed = EmbedTemplate.info(
title=f"{team.sname} Draft List",
description=f"Auto-draft queue for {team.abbrev}"
description=f"Auto-draft queue for {team.abbrev}",
)
if not draft_list:
embed.add_field(
name="Queue Empty",
value="No players in auto-draft queue",
inline=False
name="Queue Empty", value="No players in auto-draft queue", inline=False
)
else:
# Group players by rank
list_str = ""
for entry in draft_list[:25]: # Limit to 25 for embed size
player_name = entry.player.name if entry.player else f"Player {entry.player_id}"
player_swar = f" ({entry.player.wara:.2f})" if entry.player and hasattr(entry.player, 'wara') else ""
player_name = (
entry.player.name if entry.player else f"Player {entry.player_id}"
)
player_swar = (
f" ({entry.player.wara:.2f})"
if entry.player and hasattr(entry.player, "wara")
else ""
)
list_str += f"**{entry.rank}.** {player_name}{player_swar}\n"
embed.add_field(
name=f"Queue ({len(draft_list)} players)",
value=list_str,
inline=False
name=f"Queue ({len(draft_list)} players)", value=list_str, inline=False
)
embed.set_footer(text="Commands: /draft-list-add, /draft-list-remove, /draft-list-clear")
embed.set_footer(
text="Commands: /draft-list-add, /draft-list-remove, /draft-list-clear"
)
return embed
async def create_draft_board_embed(
round_num: int,
picks: List[DraftPick],
sheet_url: Optional[str] = None
round_num: int, picks: List[DraftPick], sheet_url: Optional[str] = None
) -> discord.Embed:
"""
Create draft board embed showing all picks in a round.
@ -321,14 +292,12 @@ async def create_draft_board_embed(
embed = EmbedTemplate.create_base_embed(
title=f"📋 {get_round_name(round_num)}",
description=f"Draft board for round {round_num}",
color=EmbedColors.PRIMARY
color=EmbedColors.PRIMARY,
)
if not picks:
embed.add_field(
name="No Picks",
value="No picks found for this round",
inline=False
name="No Picks", value="No picks found for this round", inline=False
)
else:
# Create picks display
@ -345,18 +314,12 @@ async def create_draft_board_embed(
pick_info = f"{round_num:>2}.{round_pick:<2} (#{pick.overall:>3})"
picks_str += f"`{pick_info}` {team_display} - {player_display}\n"
embed.add_field(
name="Picks",
value=picks_str,
inline=False
)
embed.add_field(name="Picks", value=picks_str, inline=False)
# Draft Sheet link
if sheet_url:
embed.add_field(
name="Draft Sheet",
value=f"[View Full Board]({sheet_url})",
inline=False
name="Draft Sheet", value=f"[View Full Board]({sheet_url})", inline=False
)
embed.set_footer(text="Use /draft-board [round] to view different rounds")
@ -365,8 +328,7 @@ async def create_draft_board_embed(
async def create_pick_illegal_embed(
reason: str,
details: Optional[str] = None
reason: str, details: Optional[str] = None
) -> discord.Embed:
"""
Create embed for illegal pick attempt.
@ -378,17 +340,10 @@ async def create_pick_illegal_embed(
Returns:
Discord error embed
"""
embed = EmbedTemplate.error(
title="Invalid Pick",
description=reason
)
embed = EmbedTemplate.error(title="Invalid Pick", description=reason)
if details:
embed.add_field(
name="Details",
value=details,
inline=False
)
embed.add_field(name="Details", value=details, inline=False)
return embed
@ -398,7 +353,7 @@ async def create_pick_success_embed(
team: Team,
pick_overall: int,
projected_swar: float,
cap_limit: float | None = None
cap_limit: float | None = None,
) -> discord.Embed:
"""
Create embed for successful pick.
@ -413,30 +368,20 @@ async def create_pick_success_embed(
Returns:
Discord success embed
"""
from utils.helpers import get_team_salary_cap
embed = EmbedTemplate.success(
title=f"{team.sname} select **{player.name}**",
description=format_pick_display(pick_overall)
description=format_pick_display(pick_overall),
)
if team.thumbnail is not None:
embed.set_thumbnail(url=team.thumbnail)
embed.set_image(url=player.image)
embed.add_field(
name="Player ID",
value=f"{player.id}",
inline=True
)
embed.add_field(name="Player ID", value=f"{player.id}", inline=True)
if hasattr(player, 'wara') and player.wara is not None:
embed.add_field(
name="sWAR",
value=f"{player.wara:.2f}",
inline=True
)
if hasattr(player, "wara") and player.wara is not None:
embed.add_field(name="sWAR", value=f"{player.wara:.2f}", inline=True)
# Use provided cap_limit or get from team
if cap_limit is None:
@ -445,7 +390,7 @@ async def create_pick_success_embed(
embed.add_field(
name="Projected Team sWAR",
value=f"{projected_swar:.2f} / {cap_limit:.2f}",
inline=False
inline=False,
)
return embed
@ -454,7 +399,7 @@ async def create_pick_success_embed(
async def create_admin_draft_info_embed(
draft_data: DraftData,
current_pick: Optional[DraftPick] = None,
sheet_url: Optional[str] = None
sheet_url: Optional[str] = None,
) -> discord.Embed:
"""
Create detailed admin view of draft status.
@ -472,21 +417,17 @@ async def create_admin_draft_info_embed(
embed = EmbedTemplate.create_base_embed(
title="⚙️ Draft Administration - PAUSED",
description="Current draft configuration and state",
color=EmbedColors.WARNING
color=EmbedColors.WARNING,
)
else:
embed = EmbedTemplate.create_base_embed(
title="⚙️ Draft Administration",
description="Current draft configuration and state",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
# Current pick
embed.add_field(
name="Current Pick",
value=str(draft_data.currentpick),
inline=True
)
embed.add_field(name="Current Pick", value=str(draft_data.currentpick), inline=True)
# Timer status (show paused prominently)
if draft_data.paused:
@ -500,16 +441,12 @@ async def create_admin_draft_info_embed(
timer_text = "Inactive"
embed.add_field(
name="Timer Status",
value=f"{timer_emoji} {timer_text}",
inline=True
name="Timer Status", value=f"{timer_emoji} {timer_text}", inline=True
)
# Timer duration
embed.add_field(
name="Pick Duration",
value=f"{draft_data.pick_minutes} minutes",
inline=True
name="Pick Duration", value=f"{draft_data.pick_minutes} minutes", inline=True
)
# Pause status (prominent if paused)
@ -517,31 +454,27 @@ async def create_admin_draft_info_embed(
embed.add_field(
name="Pause Status",
value="🚫 **PAUSED** - No picks allowed\nUse `/draft-admin resume` to allow picks",
inline=False
inline=False,
)
# Channels
ping_channel_value = f"<#{draft_data.ping_channel}>" if draft_data.ping_channel else "Not configured"
embed.add_field(
name="Ping Channel",
value=ping_channel_value,
inline=True
ping_channel_value = (
f"<#{draft_data.ping_channel}>" if draft_data.ping_channel else "Not configured"
)
embed.add_field(name="Ping Channel", value=ping_channel_value, inline=True)
result_channel_value = f"<#{draft_data.result_channel}>" if draft_data.result_channel else "Not configured"
embed.add_field(
name="Result Channel",
value=result_channel_value,
inline=True
result_channel_value = (
f"<#{draft_data.result_channel}>"
if draft_data.result_channel
else "Not configured"
)
embed.add_field(name="Result Channel", value=result_channel_value, inline=True)
# Deadline
if draft_data.pick_deadline:
deadline_timestamp = int(draft_data.pick_deadline.timestamp())
embed.add_field(
name="Current Deadline",
value=f"<t:{deadline_timestamp}:F>",
inline=True
name="Current Deadline", value=f"<t:{deadline_timestamp}:F>", inline=True
)
# Current pick owner
@ -549,15 +482,13 @@ async def create_admin_draft_info_embed(
embed.add_field(
name="On The Clock",
value=f"{current_pick.owner.abbrev} {current_pick.owner.sname}",
inline=False
inline=False,
)
# Draft Sheet link
if sheet_url:
embed.add_field(
name="Draft Sheet",
value=f"[View Sheet]({sheet_url})",
inline=False
name="Draft Sheet", value=f"[View Sheet]({sheet_url})", inline=False
)
embed.set_footer(text="Use /draft-admin to modify draft settings")
@ -572,7 +503,7 @@ async def create_on_clock_announcement_embed(
roster_swar: float,
cap_limit: float,
top_roster_players: List[Player],
sheet_url: Optional[str] = None
sheet_url: Optional[str] = None,
) -> discord.Embed:
"""
Create announcement embed for when a team is on the clock.
@ -604,7 +535,7 @@ async def create_on_clock_announcement_embed(
embed = EmbedTemplate.create_base_embed(
title=f"{team.lname} On The Clock",
description=format_pick_display(current_pick.overall),
color=team_color
color=team_color,
)
# Set team thumbnail
@ -617,55 +548,41 @@ async def create_on_clock_announcement_embed(
embed.add_field(
name="⏱️ Deadline",
value=f"<t:{deadline_timestamp}:T> (<t:{deadline_timestamp}:R>)",
inline=True
inline=True,
)
# Team sWAR
embed.add_field(
name="💰 Team sWAR",
value=f"{roster_swar:.2f} / {cap_limit:.2f}",
inline=True
name="💰 Team sWAR", value=f"{roster_swar:.2f} / {cap_limit:.2f}", inline=True
)
# Cap space remaining
cap_remaining = cap_limit - roster_swar
embed.add_field(
name="📊 Cap Space",
value=f"{cap_remaining:.2f}",
inline=True
)
embed.add_field(name="📊 Cap Space", value=f"{cap_remaining:.2f}", inline=True)
# Last 5 picks
if recent_picks:
recent_str = ""
for pick in recent_picks[:5]:
if pick.player and pick.owner:
recent_str += f"**#{pick.overall}** {pick.owner.abbrev} - {pick.player.name}\n"
recent_str += (
f"**#{pick.overall}** {pick.owner.abbrev} - {pick.player.name}\n"
)
if recent_str:
embed.add_field(
name="📋 Last 5 Picks",
value=recent_str,
inline=False
)
embed.add_field(name="📋 Last 5 Picks", value=recent_str, inline=False)
# Top 5 most expensive players on team roster
if top_roster_players:
expensive_str = ""
for player in top_roster_players[:5]:
pos = player.pos_1 if hasattr(player, 'pos_1') and player.pos_1 else "?"
pos = player.pos_1 if hasattr(player, "pos_1") and player.pos_1 else "?"
expensive_str += f"**{player.name}** ({pos}) - {player.wara:.2f}\n"
embed.add_field(
name="🌟 Top Roster sWAR",
value=expensive_str,
inline=False
)
embed.add_field(name="🌟 Top Roster sWAR", value=expensive_str, inline=False)
# Draft Sheet link
if sheet_url:
embed.add_field(
name="📊 Draft Sheet",
value=f"[View Full Board]({sheet_url})",
inline=False
name="📊 Draft Sheet", value=f"[View Full Board]({sheet_url})", inline=False
)
# Footer with pick info

View File

@ -3,6 +3,8 @@ Help Command Views for Discord Bot v2.0
Interactive views and modals for the custom help system.
"""
import re
from typing import Optional, List
import discord
@ -23,7 +25,7 @@ class HelpCommandCreateModal(BaseModal):
placeholder="e.g., trading-rules (2-32 chars, letters/numbers/dashes)",
required=True,
min_length=2,
max_length=32
max_length=32,
)
self.topic_title = discord.ui.TextInput(
@ -31,14 +33,14 @@ class HelpCommandCreateModal(BaseModal):
placeholder="e.g., Trading Rules & Guidelines",
required=True,
min_length=1,
max_length=200
max_length=200,
)
self.topic_category = discord.ui.TextInput(
label="Category (Optional)",
placeholder="e.g., rules, guides, resources, info, faq",
required=False,
max_length=50
max_length=50,
)
self.topic_content = discord.ui.TextInput(
@ -47,7 +49,7 @@ class HelpCommandCreateModal(BaseModal):
style=discord.TextStyle.paragraph,
required=True,
min_length=1,
max_length=4000
max_length=4000,
)
self.add_item(self.topic_name)
@ -57,11 +59,9 @@ class HelpCommandCreateModal(BaseModal):
async def on_submit(self, interaction: discord.Interaction):
"""Handle form submission."""
import re
# Validate topic name format
name = self.topic_name.value.strip().lower()
if not re.match(r'^[a-z0-9_-]+$', name):
if not re.match(r"^[a-z0-9_-]+$", name):
embed = EmbedTemplate.error(
title="Invalid Topic Name",
description=(
@ -69,14 +69,18 @@ class HelpCommandCreateModal(BaseModal):
"**Allowed:** lowercase letters, numbers, dashes, and underscores only.\n"
"**Examples:** `trading-rules`, `how_to_draft`, `faq1`\n\n"
"Please try again with a valid name."
)
),
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
# Validate category format if provided
category = self.topic_category.value.strip().lower() if self.topic_category.value else None
if category and not re.match(r'^[a-z0-9_-]+$', category):
category = (
self.topic_category.value.strip().lower()
if self.topic_category.value
else None
)
if category and not re.match(r"^[a-z0-9_-]+$", category):
embed = EmbedTemplate.error(
title="Invalid Category",
description=(
@ -84,17 +88,17 @@ class HelpCommandCreateModal(BaseModal):
"**Allowed:** lowercase letters, numbers, dashes, and underscores only.\n"
"**Examples:** `rules`, `guides`, `faq`\n\n"
"Please try again with a valid category."
)
),
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
# Store results
self.result = {
'name': name,
'title': self.topic_title.value.strip(),
'content': self.topic_content.value.strip(),
'category': category
"name": name,
"title": self.topic_title.value.strip(),
"content": self.topic_content.value.strip(),
"category": category,
}
self.is_submitted = True
@ -102,36 +106,28 @@ class HelpCommandCreateModal(BaseModal):
# Create preview embed
embed = EmbedTemplate.info(
title="Help Topic Preview",
description="Here's how your help topic will look:"
description="Here's how your help topic will look:",
)
embed.add_field(
name="Name",
value=f"`/help {self.result['name']}`",
inline=True
name="Name", value=f"`/help {self.result['name']}`", inline=True
)
embed.add_field(
name="Category",
value=self.result['category'] or "None",
inline=True
name="Category", value=self.result["category"] or "None", inline=True
)
embed.add_field(
name="Title",
value=self.result['title'],
inline=False
)
embed.add_field(name="Title", value=self.result["title"], inline=False)
# Show content preview (truncated if too long)
content_preview = self.result['content'][:500] + ('...' if len(self.result['content']) > 500 else '')
embed.add_field(
name="Content",
value=content_preview,
inline=False
content_preview = self.result["content"][:500] + (
"..." if len(self.result["content"]) > 500 else ""
)
embed.add_field(name="Content", value=content_preview, inline=False)
embed.set_footer(text="Creating this help topic will make it available to all server members")
embed.set_footer(
text="Creating this help topic will make it available to all server members"
)
await interaction.response.send_message(embed=embed, ephemeral=True)
@ -149,15 +145,15 @@ class HelpCommandEditModal(BaseModal):
default=help_command.title,
required=True,
min_length=1,
max_length=200
max_length=200,
)
self.topic_category = discord.ui.TextInput(
label="Category (Optional)",
placeholder="e.g., rules, guides, resources, info, faq",
default=help_command.category or '',
default=help_command.category or "",
required=False,
max_length=50
max_length=50,
)
self.topic_content = discord.ui.TextInput(
@ -167,7 +163,7 @@ class HelpCommandEditModal(BaseModal):
default=help_command.content,
required=True,
min_length=1,
max_length=4000
max_length=4000,
)
self.add_item(self.topic_title)
@ -178,10 +174,12 @@ class HelpCommandEditModal(BaseModal):
"""Handle form submission."""
# Store results
self.result = {
'name': self.original_help.name,
'title': self.topic_title.value.strip(),
'content': self.topic_content.value.strip(),
'category': self.topic_category.value.strip() if self.topic_category.value else None
"name": self.original_help.name,
"title": self.topic_title.value.strip(),
"content": self.topic_content.value.strip(),
"category": (
self.topic_category.value.strip() if self.topic_category.value else None
),
}
self.is_submitted = True
@ -189,38 +187,36 @@ class HelpCommandEditModal(BaseModal):
# Create preview embed showing changes
embed = EmbedTemplate.info(
title="Help Topic Edit Preview",
description=f"Changes to `/help {self.original_help.name}`:"
description=f"Changes to `/help {self.original_help.name}`:",
)
# Show title changes if different
if self.original_help.title != self.result['title']:
embed.add_field(name="Old Title", value=self.original_help.title, inline=True)
embed.add_field(name="New Title", value=self.result['title'], inline=True)
if self.original_help.title != self.result["title"]:
embed.add_field(
name="Old Title", value=self.original_help.title, inline=True
)
embed.add_field(name="New Title", value=self.result["title"], inline=True)
embed.add_field(name="\u200b", value="\u200b", inline=True) # Spacer
# Show category changes
old_cat = self.original_help.category or "None"
new_cat = self.result['category'] or "None"
new_cat = self.result["category"] or "None"
if old_cat != new_cat:
embed.add_field(name="Old Category", value=old_cat, inline=True)
embed.add_field(name="New Category", value=new_cat, inline=True)
embed.add_field(name="\u200b", value="\u200b", inline=True) # Spacer
# Show content preview (always show since it's the main field)
old_content = self.original_help.content[:300] + ('...' if len(self.original_help.content) > 300 else '')
new_content = self.result['content'][:300] + ('...' if len(self.result['content']) > 300 else '')
embed.add_field(
name="Old Content",
value=old_content,
inline=False
old_content = self.original_help.content[:300] + (
"..." if len(self.original_help.content) > 300 else ""
)
new_content = self.result["content"][:300] + (
"..." if len(self.result["content"]) > 300 else ""
)
embed.add_field(
name="New Content",
value=new_content,
inline=False
)
embed.add_field(name="Old Content", value=old_content, inline=False)
embed.add_field(name="New Content", value=new_content, inline=False)
embed.set_footer(text="Changes will be visible to all server members")
@ -230,48 +226,58 @@ class HelpCommandEditModal(BaseModal):
class HelpCommandDeleteConfirmView(BaseView):
"""Confirmation view for deleting a help topic."""
def __init__(self, help_command: HelpCommand, *, user_id: int, timeout: float = 180.0):
def __init__(
self, help_command: HelpCommand, *, user_id: int, timeout: float = 180.0
):
super().__init__(timeout=timeout, user_id=user_id)
self.help_command = help_command
self.result = None
@discord.ui.button(label="Delete Topic", emoji="🗑️", style=discord.ButtonStyle.danger, row=0)
async def confirm_delete(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
label="Delete Topic", emoji="🗑️", style=discord.ButtonStyle.danger, row=0
)
async def confirm_delete(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Confirm the topic deletion."""
self.result = True
embed = EmbedTemplate.success(
title="Help Topic Deleted",
description=f"The help topic `/help {self.help_command.name}` has been deleted (soft delete)."
description=f"The help topic `/help {self.help_command.name}` has been deleted (soft delete).",
)
embed.add_field(
name="Note",
value="This topic can be restored later if needed using admin commands.",
inline=False
inline=False,
)
# Disable all buttons
for item in self.children:
if hasattr(item, 'disabled'):
if hasattr(item, "disabled"):
item.disabled = True # type: ignore
await interaction.response.edit_message(embed=embed, view=self)
self.stop()
@discord.ui.button(label="Cancel", emoji="", style=discord.ButtonStyle.secondary, row=0)
async def cancel_delete(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
label="Cancel", emoji="", style=discord.ButtonStyle.secondary, row=0
)
async def cancel_delete(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Cancel the topic deletion."""
self.result = False
embed = EmbedTemplate.info(
title="Deletion Cancelled",
description=f"The help topic `/help {self.help_command.name}` was not deleted."
description=f"The help topic `/help {self.help_command.name}` was not deleted.",
)
# Disable all buttons
for item in self.children:
if hasattr(item, 'disabled'):
if hasattr(item, "disabled"):
item.disabled = True # type: ignore
await interaction.response.edit_message(embed=embed, view=self)
@ -287,7 +293,7 @@ class HelpCommandListView(BaseView):
user_id: Optional[int] = None,
category_filter: Optional[str] = None,
*,
timeout: float = 300.0
timeout: float = 300.0,
):
super().__init__(timeout=timeout, user_id=user_id)
self.help_commands = help_commands
@ -299,7 +305,11 @@ class HelpCommandListView(BaseView):
def _update_buttons(self):
"""Update button states based on current page."""
total_pages = max(1, (len(self.help_commands) + self.topics_per_page - 1) // self.topics_per_page)
total_pages = max(
1,
(len(self.help_commands) + self.topics_per_page - 1)
// self.topics_per_page,
)
self.previous_page.disabled = self.current_page == 0
self.next_page.disabled = self.current_page >= total_pages - 1
@ -324,16 +334,14 @@ class HelpCommandListView(BaseView):
description = f"Found {len(self.help_commands)} help topic{'s' if len(self.help_commands) != 1 else ''}"
embed = EmbedTemplate.create_base_embed(
title=title,
description=description,
color=EmbedColors.INFO
title=title, description=description, color=EmbedColors.INFO
)
if not current_topics:
embed.add_field(
name="No Topics",
value="No help topics found. Admins can create topics using `/help-create`.",
inline=False
inline=False,
)
else:
# Group by category for better organization
@ -347,13 +355,15 @@ class HelpCommandListView(BaseView):
for category, topics in sorted(by_category.items()):
topic_list = []
for topic in topics:
views_text = f"{topic.view_count} views" if topic.view_count > 0 else ""
topic_list.append(f"• `/help {topic.name}` - {topic.title}{views_text}")
views_text = (
f"{topic.view_count} views" if topic.view_count > 0 else ""
)
topic_list.append(
f"• `/help {topic.name}` - {topic.title}{views_text}"
)
embed.add_field(
name=f"📂 {category}",
value='\n'.join(topic_list),
inline=False
name=f"📂 {category}", value="\n".join(topic_list), inline=False
)
embed.set_footer(text="Use /help <topic-name> to view a specific topic")
@ -361,7 +371,9 @@ class HelpCommandListView(BaseView):
return embed
@discord.ui.button(emoji="◀️", style=discord.ButtonStyle.secondary, row=0)
async def previous_page(self, interaction: discord.Interaction, button: discord.ui.Button):
async def previous_page(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Go to previous page."""
self.current_page = max(0, self.current_page - 1)
self._update_buttons()
@ -369,15 +381,25 @@ class HelpCommandListView(BaseView):
embed = self._create_embed()
await interaction.response.edit_message(embed=embed, view=self)
@discord.ui.button(label="1/1", style=discord.ButtonStyle.secondary, disabled=True, row=0)
async def page_info(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
label="1/1", style=discord.ButtonStyle.secondary, disabled=True, row=0
)
async def page_info(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Page info (disabled button)."""
pass
@discord.ui.button(emoji="▶️", style=discord.ButtonStyle.secondary, row=0)
async def next_page(self, interaction: discord.Interaction, button: discord.ui.Button):
async def next_page(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Go to next page."""
total_pages = max(1, (len(self.help_commands) + self.topics_per_page - 1) // self.topics_per_page)
total_pages = max(
1,
(len(self.help_commands) + self.topics_per_page - 1)
// self.topics_per_page,
)
self.current_page = min(total_pages - 1, self.current_page + 1)
self._update_buttons()
@ -387,7 +409,7 @@ class HelpCommandListView(BaseView):
async def on_timeout(self):
"""Handle view timeout."""
for item in self.children:
if hasattr(item, 'disabled'):
if hasattr(item, "disabled"):
item.disabled = True # type: ignore
def get_embed(self) -> discord.Embed:
@ -408,7 +430,7 @@ def create_help_topic_embed(help_command: HelpCommand) -> discord.Embed:
embed = EmbedTemplate.create_base_embed(
title=help_command.title,
description=help_command.content,
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
# Add metadata footer

View File

@ -3,58 +3,73 @@ Modal Components for Discord Bot v2.0
Interactive forms and input dialogs for collecting user data.
"""
from typing import Optional, Callable, Awaitable, Dict, Any, List
import math
import re
import discord
from config import get_config
from .embeds import EmbedTemplate
from services.injury_service import injury_service
from services.player_service import player_service
from utils.injury_log import post_injury_and_update_log
from utils.logging import get_contextual_logger
class BaseModal(discord.ui.Modal):
"""Base modal class with consistent error handling and validation."""
def __init__(
self,
*,
title: str,
timeout: Optional[float] = 300.0,
custom_id: Optional[str] = None
custom_id: Optional[str] = None,
):
kwargs = {"title": title, "timeout": timeout}
if custom_id is not None:
kwargs["custom_id"] = custom_id
super().__init__(**kwargs)
self.logger = get_contextual_logger(f'{__name__}.{self.__class__.__name__}')
self.logger = get_contextual_logger(f"{__name__}.{self.__class__.__name__}")
self.result: Optional[Dict[str, Any]] = None
self.is_submitted = False
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None:
async def on_error(
self, interaction: discord.Interaction, error: Exception
) -> None:
"""Handle modal errors."""
self.logger.error("Modal error occurred",
error=error,
modal_title=self.title,
user_id=interaction.user.id)
self.logger.error(
"Modal error occurred",
error=error,
modal_title=self.title,
user_id=interaction.user.id,
)
try:
embed = EmbedTemplate.error(
title="Form Error",
description="An error occurred while processing your form. Please try again."
description="An error occurred while processing your form. Please try again.",
)
if not interaction.response.is_done():
await interaction.response.send_message(embed=embed, ephemeral=True)
else:
await interaction.followup.send(embed=embed, ephemeral=True)
except Exception as e:
self.logger.error("Failed to send error message", error=e)
def validate_input(self, field_name: str, value: str, validators: Optional[List[Callable[[str], bool]]] = None) -> tuple[bool, str]:
def validate_input(
self,
field_name: str,
value: str,
validators: Optional[List[Callable[[str], bool]]] = None,
) -> tuple[bool, str]:
"""Validate input field with optional custom validators."""
if not value.strip():
return False, f"{field_name} cannot be empty."
if validators:
for validator in validators:
try:
@ -62,49 +77,49 @@ class BaseModal(discord.ui.Modal):
return False, f"Invalid {field_name} format."
except Exception:
return False, f"Validation error for {field_name}."
return True, ""
class PlayerSearchModal(BaseModal):
"""Modal for collecting detailed player search criteria."""
def __init__(self, *, timeout: Optional[float] = 300.0):
super().__init__(title="Player Search", timeout=timeout)
self.player_name = discord.ui.TextInput(
label="Player Name",
placeholder="Enter player name (required)",
required=True,
max_length=100
max_length=100,
)
self.position = discord.ui.TextInput(
label="Position",
placeholder="e.g., SS, OF, P (optional)",
required=False,
max_length=10
max_length=10,
)
self.team = discord.ui.TextInput(
label="Team",
placeholder="Team abbreviation (optional)",
required=False,
max_length=5
max_length=5,
)
self.season = discord.ui.TextInput(
label="Season",
placeholder="Season number (optional)",
required=False,
max_length=4
max_length=4,
)
self.add_item(self.player_name)
self.add_item(self.position)
self.add_item(self.team)
self.add_item(self.season)
async def on_submit(self, interaction: discord.Interaction):
"""Handle form submission."""
# Validate season if provided
@ -117,60 +132,62 @@ class PlayerSearchModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Season",
description="Season must be a valid number between 1 and 50."
description="Season must be a valid number between 1 and 50.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
# Store results
self.result = {
'name': self.player_name.value.strip(),
'position': self.position.value.strip() if self.position.value else None,
'team': self.team.value.strip().upper() if self.team.value else None,
'season': season_value
"name": self.player_name.value.strip(),
"position": self.position.value.strip() if self.position.value else None,
"team": self.team.value.strip().upper() if self.team.value else None,
"season": season_value,
}
self.is_submitted = True
# Acknowledge submission
embed = EmbedTemplate.info(
title="Search Submitted",
description=f"Searching for player: **{self.result['name']}**"
description=f"Searching for player: **{self.result['name']}**",
)
if self.result['position']:
embed.add_field(name="Position", value=self.result['position'], inline=True)
if self.result['team']:
embed.add_field(name="Team", value=self.result['team'], inline=True)
if self.result['season']:
embed.add_field(name="Season", value=str(self.result['season']), inline=True)
if self.result["position"]:
embed.add_field(name="Position", value=self.result["position"], inline=True)
if self.result["team"]:
embed.add_field(name="Team", value=self.result["team"], inline=True)
if self.result["season"]:
embed.add_field(
name="Season", value=str(self.result["season"]), inline=True
)
await interaction.response.send_message(embed=embed, ephemeral=True)
class TeamSearchModal(BaseModal):
"""Modal for collecting team search criteria."""
def __init__(self, *, timeout: Optional[float] = 300.0):
super().__init__(title="Team Search", timeout=timeout)
self.team_input = discord.ui.TextInput(
label="Team Name or Abbreviation",
placeholder="Enter team name or abbreviation",
required=True,
max_length=50
max_length=50,
)
self.season = discord.ui.TextInput(
label="Season",
placeholder="Season number (optional)",
required=False,
max_length=4
max_length=4,
)
self.add_item(self.team_input)
self.add_item(self.season)
async def on_submit(self, interaction: discord.Interaction):
"""Handle form submission."""
# Validate season if provided
@ -183,267 +200,267 @@ class TeamSearchModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Season",
description="Season must be a valid number between 1 and 50."
description="Season must be a valid number between 1 and 50.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
# Store results
self.result = {
'team': self.team_input.value.strip(),
'season': season_value
}
self.result = {"team": self.team_input.value.strip(), "season": season_value}
self.is_submitted = True
# Acknowledge submission
embed = EmbedTemplate.info(
title="Search Submitted",
description=f"Searching for team: **{self.result['team']}**"
description=f"Searching for team: **{self.result['team']}**",
)
if self.result['season']:
embed.add_field(name="Season", value=str(self.result['season']), inline=True)
if self.result["season"]:
embed.add_field(
name="Season", value=str(self.result["season"]), inline=True
)
await interaction.response.send_message(embed=embed, ephemeral=True)
class FeedbackModal(BaseModal):
"""Modal for collecting user feedback."""
def __init__(
self,
*,
self,
*,
timeout: Optional[float] = 600.0,
submit_callback: Optional[Callable[[Dict[str, Any]], Awaitable[bool]]] = None
submit_callback: Optional[Callable[[Dict[str, Any]], Awaitable[bool]]] = None,
):
super().__init__(title="Submit Feedback", timeout=timeout)
self.submit_callback = submit_callback
self.feedback_type = discord.ui.TextInput(
label="Feedback Type",
placeholder="e.g., Bug Report, Feature Request, General",
required=True,
max_length=50
max_length=50,
)
self.subject = discord.ui.TextInput(
label="Subject",
placeholder="Brief description of your feedback",
required=True,
max_length=100
max_length=100,
)
self.description = discord.ui.TextInput(
label="Description",
placeholder="Detailed description of your feedback",
style=discord.TextStyle.paragraph,
required=True,
max_length=2000
max_length=2000,
)
self.contact = discord.ui.TextInput(
label="Contact Info (Optional)",
placeholder="How to reach you for follow-up",
required=False,
max_length=100
max_length=100,
)
self.add_item(self.feedback_type)
self.add_item(self.subject)
self.add_item(self.description)
self.add_item(self.contact)
async def on_submit(self, interaction: discord.Interaction):
"""Handle feedback submission."""
# Store results
self.result = {
'type': self.feedback_type.value.strip(),
'subject': self.subject.value.strip(),
'description': self.description.value.strip(),
'contact': self.contact.value.strip() if self.contact.value else None,
'user_id': interaction.user.id,
'username': str(interaction.user),
'submitted_at': discord.utils.utcnow()
"type": self.feedback_type.value.strip(),
"subject": self.subject.value.strip(),
"description": self.description.value.strip(),
"contact": self.contact.value.strip() if self.contact.value else None,
"user_id": interaction.user.id,
"username": str(interaction.user),
"submitted_at": discord.utils.utcnow(),
}
self.is_submitted = True
# Process feedback
if self.submit_callback:
try:
success = await self.submit_callback(self.result)
if success:
embed = EmbedTemplate.success(
title="Feedback Submitted",
description="Thank you for your feedback! We'll review it shortly."
description="Thank you for your feedback! We'll review it shortly.",
)
else:
embed = EmbedTemplate.error(
title="Submission Failed",
description="Failed to submit feedback. Please try again later."
description="Failed to submit feedback. Please try again later.",
)
except Exception as e:
self.logger.error("Feedback submission error", error=e)
embed = EmbedTemplate.error(
title="Submission Error",
description="An error occurred while submitting feedback."
description="An error occurred while submitting feedback.",
)
else:
embed = EmbedTemplate.success(
title="Feedback Received",
description="Your feedback has been recorded."
description="Your feedback has been recorded.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
class ConfigurationModal(BaseModal):
"""Modal for configuration settings with validation."""
def __init__(
self,
current_config: Dict[str, Any],
*,
timeout: Optional[float] = 300.0,
save_callback: Optional[Callable[[Dict[str, Any]], Awaitable[bool]]] = None
save_callback: Optional[Callable[[Dict[str, Any]], Awaitable[bool]]] = None,
):
super().__init__(title="Configuration Settings", timeout=timeout)
self.current_config = current_config
self.save_callback = save_callback
# Add configuration fields (customize based on needs)
self.setting1 = discord.ui.TextInput(
label="Setting 1",
placeholder="Enter value for setting 1",
default=str(current_config.get('setting1', '')),
default=str(current_config.get("setting1", "")),
required=False,
max_length=100
max_length=100,
)
self.setting2 = discord.ui.TextInput(
label="Setting 2",
placeholder="Enter value for setting 2",
default=str(current_config.get('setting2', '')),
default=str(current_config.get("setting2", "")),
required=False,
max_length=100
max_length=100,
)
self.add_item(self.setting1)
self.add_item(self.setting2)
async def on_submit(self, interaction: discord.Interaction):
"""Handle configuration submission."""
# Validate and store new configuration
new_config = self.current_config.copy()
if self.setting1.value:
new_config['setting1'] = self.setting1.value.strip()
new_config["setting1"] = self.setting1.value.strip()
if self.setting2.value:
new_config['setting2'] = self.setting2.value.strip()
new_config["setting2"] = self.setting2.value.strip()
self.result = new_config
self.is_submitted = True
# Save configuration
if self.save_callback:
try:
success = await self.save_callback(new_config)
if success:
embed = EmbedTemplate.success(
title="Configuration Saved",
description="Your configuration has been updated successfully."
description="Your configuration has been updated successfully.",
)
else:
embed = EmbedTemplate.error(
title="Save Failed",
description="Failed to save configuration. Please try again."
description="Failed to save configuration. Please try again.",
)
except Exception as e:
self.logger.error("Configuration save error", error=e)
embed = EmbedTemplate.error(
title="Save Error",
description="An error occurred while saving configuration."
description="An error occurred while saving configuration.",
)
else:
embed = EmbedTemplate.success(
title="Configuration Updated",
description="Configuration has been updated."
description="Configuration has been updated.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
class CustomInputModal(BaseModal):
"""Flexible modal for custom input collection."""
def __init__(
self,
title: str,
fields: List[Dict[str, Any]],
*,
timeout: Optional[float] = 300.0,
submit_callback: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None
submit_callback: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
):
super().__init__(title=title, timeout=timeout)
self.submit_callback = submit_callback
self.fields_config = fields
# Add text inputs based on field configuration
for field in fields[:5]: # Discord limit of 5 text inputs
text_input = discord.ui.TextInput(
label=field.get('label', 'Field'),
placeholder=field.get('placeholder', ''),
default=field.get('default', ''),
required=field.get('required', False),
max_length=field.get('max_length', 4000),
style=getattr(discord.TextStyle, field.get('style', 'short'))
label=field.get("label", "Field"),
placeholder=field.get("placeholder", ""),
default=field.get("default", ""),
required=field.get("required", False),
max_length=field.get("max_length", 4000),
style=getattr(discord.TextStyle, field.get("style", "short")),
)
self.add_item(text_input)
async def on_submit(self, interaction: discord.Interaction):
"""Handle custom form submission."""
# Collect all input values
results = {}
for i, item in enumerate(self.children):
if isinstance(item, discord.ui.TextInput):
field_config = self.fields_config[i] if i < len(self.fields_config) else {}
field_key = field_config.get('key', f'field_{i}')
field_config = (
self.fields_config[i] if i < len(self.fields_config) else {}
)
field_key = field_config.get("key", f"field_{i}")
# Apply validation if specified
validators = field_config.get('validators', [])
validators = field_config.get("validators", [])
if validators:
is_valid, error_msg = self.validate_input(
field_config.get('label', 'Field'),
item.value,
validators
field_config.get("label", "Field"), item.value, validators
)
if not is_valid:
embed = EmbedTemplate.error(
title="Validation Error",
description=error_msg
title="Validation Error", description=error_msg
)
await interaction.response.send_message(
embed=embed, ephemeral=True
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
results[field_key] = item.value.strip() if item.value else None
self.result = results
self.is_submitted = True
# Execute callback if provided
if self.submit_callback:
await self.submit_callback(results)
else:
embed = EmbedTemplate.success(
title="Form Submitted",
description="Your form has been submitted successfully."
description="Your form has been submitted successfully.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
@ -451,7 +468,7 @@ class CustomInputModal(BaseModal):
# Validation helper functions
def validate_email(email: str) -> bool:
"""Validate email format."""
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
return bool(re.match(pattern, email))
@ -492,11 +509,11 @@ class BatterInjuryModal(BaseModal):
def __init__(
self,
player: 'Player',
player: "Player",
injury_games: int,
season: int,
*,
timeout: Optional[float] = 300.0
timeout: Optional[float] = 300.0,
):
"""
Initialize batter injury modal.
@ -519,7 +536,7 @@ class BatterInjuryModal(BaseModal):
placeholder="Enter current week number (e.g., 5)",
required=True,
max_length=2,
style=discord.TextStyle.short
style=discord.TextStyle.short,
)
# Current game input
@ -528,7 +545,7 @@ class BatterInjuryModal(BaseModal):
placeholder="Enter current game number (1-4)",
required=True,
max_length=1,
style=discord.TextStyle.short
style=discord.TextStyle.short,
)
self.add_item(self.current_week)
@ -536,11 +553,6 @@ class BatterInjuryModal(BaseModal):
async def on_submit(self, interaction: discord.Interaction):
"""Handle batter injury input and log injury."""
from services.player_service import player_service
from services.injury_service import injury_service
from config import get_config
import math
config = get_config()
max_week = config.weeks_per_season + config.playoff_weeks_per_season
@ -552,7 +564,7 @@ class BatterInjuryModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Week",
description=f"Current week must be a number between 1 and {max_week} (including playoffs)."
description=f"Current week must be a number between 1 and {max_week} (including playoffs).",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -577,7 +589,7 @@ class BatterInjuryModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Game",
description=f"Current game must be a number between 1 and {max_game}."
description=f"Current game must be a number between 1 and {max_game}.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -597,7 +609,7 @@ class BatterInjuryModal(BaseModal):
start_week = week if game != config.games_per_week else week + 1
start_game = game + 1 if game != config.games_per_week else 1
return_date = f'w{return_week:02d}g{return_game}'
return_date = f"w{return_week:02d}g{return_game}"
# Create injury record
try:
@ -608,70 +620,69 @@ class BatterInjuryModal(BaseModal):
start_week=start_week,
start_game=start_game,
end_week=return_week,
end_game=return_game
end_game=return_game,
)
if not injury:
raise ValueError("Failed to create injury record")
# Update player's il_return field
await player_service.update_player(self.player.id, {'il_return': return_date})
await player_service.update_player(
self.player.id, {"il_return": return_date}
)
# Success response
embed = EmbedTemplate.success(
title="Injury Logged",
description=f"{self.player.name}'s injury has been logged."
description=f"{self.player.name}'s injury has been logged.",
)
embed.add_field(
name="Duration",
value=f"{self.injury_games} game{'s' if self.injury_games > 1 else ''}",
inline=True
inline=True,
)
embed.add_field(
name="Return Date",
value=return_date,
inline=True
)
embed.add_field(name="Return Date", value=return_date, inline=True)
if self.player.team:
embed.add_field(
name="Team",
value=f"{self.player.team.lname} ({self.player.team.abbrev})",
inline=False
inline=False,
)
self.is_submitted = True
self.result = {
'injury_id': injury.id,
'total_games': self.injury_games,
'return_date': return_date
"injury_id": injury.id,
"total_games": self.injury_games,
"return_date": return_date,
}
await interaction.response.send_message(embed=embed)
# Post injury news and update injury log channel
try:
from utils.injury_log import post_injury_and_update_log
await post_injury_and_update_log(
bot=interaction.client,
player=self.player,
injury_games=self.injury_games,
return_date=return_date,
season=self.season
season=self.season,
)
except Exception as log_error:
self.logger.warning(
f"Failed to post injury to channels (injury was still logged): {log_error}",
player_id=self.player.id
player_id=self.player.id,
)
except Exception as e:
self.logger.error("Failed to create batter injury", error=e, player_id=self.player.id)
self.logger.error(
"Failed to create batter injury", error=e, player_id=self.player.id
)
embed = EmbedTemplate.error(
title="Error",
description="Failed to log the injury. Please try again or contact an administrator."
description="Failed to log the injury. Please try again or contact an administrator.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
@ -681,11 +692,11 @@ class PitcherRestModal(BaseModal):
def __init__(
self,
player: 'Player',
player: "Player",
injury_games: int,
season: int,
*,
timeout: Optional[float] = 300.0
timeout: Optional[float] = 300.0,
):
"""
Initialize pitcher rest modal.
@ -708,7 +719,7 @@ class PitcherRestModal(BaseModal):
placeholder="Enter current week number (e.g., 5)",
required=True,
max_length=2,
style=discord.TextStyle.short
style=discord.TextStyle.short,
)
# Current game input
@ -717,7 +728,7 @@ class PitcherRestModal(BaseModal):
placeholder="Enter current game number (1-4)",
required=True,
max_length=1,
style=discord.TextStyle.short
style=discord.TextStyle.short,
)
# Rest games input
@ -726,7 +737,7 @@ class PitcherRestModal(BaseModal):
placeholder="Enter number of rest games (0 or more)",
required=True,
max_length=2,
style=discord.TextStyle.short
style=discord.TextStyle.short,
)
self.add_item(self.current_week)
@ -735,11 +746,6 @@ class PitcherRestModal(BaseModal):
async def on_submit(self, interaction: discord.Interaction):
"""Handle pitcher rest input and log injury."""
from services.player_service import player_service
from services.injury_service import injury_service
from config import get_config
import math
config = get_config()
max_week = config.weeks_per_season + config.playoff_weeks_per_season
@ -751,7 +757,7 @@ class PitcherRestModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Week",
description=f"Current week must be a number between 1 and {max_week} (including playoffs)."
description=f"Current week must be a number between 1 and {max_week} (including playoffs).",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -776,7 +782,7 @@ class PitcherRestModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Game",
description=f"Current game must be a number between 1 and {max_game}."
description=f"Current game must be a number between 1 and {max_game}.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -789,7 +795,7 @@ class PitcherRestModal(BaseModal):
except ValueError:
embed = EmbedTemplate.error(
title="Invalid Rest Games",
description="Rest games must be a non-negative number."
description="Rest games must be a non-negative number.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -812,7 +818,7 @@ class PitcherRestModal(BaseModal):
start_week = week if game != 4 else week + 1
start_game = game + 1 if game != 4 else 1
return_date = f'w{return_week:02d}g{return_game}'
return_date = f"w{return_week:02d}g{return_game}"
# Create injury record
try:
@ -823,81 +829,80 @@ class PitcherRestModal(BaseModal):
start_week=start_week,
start_game=start_game,
end_week=return_week,
end_game=return_game
end_game=return_game,
)
if not injury:
raise ValueError("Failed to create injury record")
# Update player's il_return field
await player_service.update_player(self.player.id, {'il_return': return_date})
await player_service.update_player(
self.player.id, {"il_return": return_date}
)
# Success response
embed = EmbedTemplate.success(
title="Injury Logged",
description=f"{self.player.name}'s injury has been logged."
description=f"{self.player.name}'s injury has been logged.",
)
embed.add_field(
name="Base Injury",
value=f"{self.injury_games} game{'s' if self.injury_games > 1 else ''}",
inline=True
inline=True,
)
embed.add_field(
name="Rest Requirement",
value=f"{rest} game{'s' if rest > 1 else ''}",
inline=True
inline=True,
)
embed.add_field(
name="Total Duration",
value=f"{total_injury_games} game{'s' if total_injury_games > 1 else ''}",
inline=True
inline=True,
)
embed.add_field(
name="Return Date",
value=return_date,
inline=True
)
embed.add_field(name="Return Date", value=return_date, inline=True)
if self.player.team:
embed.add_field(
name="Team",
value=f"{self.player.team.lname} ({self.player.team.abbrev})",
inline=False
inline=False,
)
self.is_submitted = True
self.result = {
'injury_id': injury.id,
'total_games': total_injury_games,
'return_date': return_date
"injury_id": injury.id,
"total_games": total_injury_games,
"return_date": return_date,
}
await interaction.response.send_message(embed=embed)
# Post injury news and update injury log channel
try:
from utils.injury_log import post_injury_and_update_log
await post_injury_and_update_log(
bot=interaction.client,
player=self.player,
injury_games=total_injury_games,
return_date=return_date,
season=self.season
season=self.season,
)
except Exception as log_error:
self.logger.warning(
f"Failed to post injury to channels (injury was still logged): {log_error}",
player_id=self.player.id
player_id=self.player.id,
)
except Exception as e:
self.logger.error("Failed to create pitcher injury", error=e, player_id=self.player.id)
self.logger.error(
"Failed to create pitcher injury", error=e, player_id=self.player.id
)
embed = EmbedTemplate.error(
title="Error",
description="Failed to log the injury. Please try again or contact an administrator."
description="Failed to log the injury. Please try again or contact an administrator.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
await interaction.response.send_message(embed=embed, ephemeral=True)

View File

@ -8,9 +8,16 @@ import discord
from typing import Optional, List
from datetime import datetime, timezone
from services.trade_builder import TradeBuilder
from services.trade_builder import TradeBuilder, clear_trade_builder_by_team
from services.team_service import team_service
from services.league_service import league_service
from services.transaction_service import transaction_service
from models.team import Team, RosterType
from models.trade import TradeStatus
from models.transaction import Transaction
from views.embeds import EmbedColors, EmbedTemplate
from utils.transaction_logging import post_trade_to_log
from config import get_config
class TradeEmbedView(discord.ui.View):
@ -276,8 +283,6 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
await interaction.response.defer(ephemeral=True)
try:
from models.trade import TradeStatus
self.builder.trade.status = TradeStatus.PROPOSED
acceptance_embed = await create_trade_acceptance_embed(self.builder)
@ -326,9 +331,6 @@ class TradeAcceptanceView(discord.ui.View):
async def _get_user_team(self, interaction: discord.Interaction) -> Optional[Team]:
"""Get the team owned by the interacting user."""
from services.team_service import team_service
from config import get_config
config = get_config()
return await team_service.get_team_by_owner(
interaction.user.id, config.sba_season
@ -425,14 +427,6 @@ class TradeAcceptanceView(discord.ui.View):
async def _finalize_trade(self, interaction: discord.Interaction) -> None:
"""Finalize the trade - create transactions and complete."""
from services.league_service import league_service
from services.transaction_service import transaction_service
from services.trade_builder import clear_trade_builder_by_team
from models.transaction import Transaction
from models.trade import TradeStatus
from utils.transaction_logging import post_trade_to_log
from config import get_config
try:
await interaction.response.defer()