Merge pull request 'Merge next-release into main' (#111) from next-release into main
All checks were successful
Build Docker Image / build (push) Successful in 1m2s

Reviewed-on: #111
This commit is contained in:
cal 2026-03-20 17:54:28 +00:00
commit fd24a41422
26 changed files with 1324 additions and 628 deletions

View File

@ -1,9 +1,9 @@
# Gitea Actions: Docker Build, Push, and Notify
#
# CI/CD pipeline for Major Domo Discord Bot:
# - Builds Docker images on every push/PR
# - Builds Docker images on merge to main/next-release
# - Auto-generates CalVer version (YYYY.MM.BUILD) on main branch merges
# - Supports multi-channel releases: stable (main), rc (next-release), dev (PRs)
# - Supports multi-channel releases: stable (main), rc (next-release)
# - Pushes to Docker Hub and creates git tag on main
# - Sends Discord notifications on success/failure
@ -14,9 +14,6 @@ on:
branches:
- main
- next-release
pull_request:
branches:
- main
jobs:
build:

1
.gitignore vendored
View File

@ -218,5 +218,6 @@ __marimo__/
# Project-specific
data/
storage/
production_logs/
*.json

25
bot.py
View File

@ -42,7 +42,9 @@ def setup_logging():
# JSON file handler - structured logging for monitoring and analysis
json_handler = RotatingFileHandler(
"logs/discord_bot_v2.json", maxBytes=5 * 1024 * 1024, backupCount=5 # 5MB
"logs/discord_bot_v2.json",
maxBytes=5 * 1024 * 1024,
backupCount=5, # 5MB
)
json_handler.setFormatter(JSONFormatter())
logger.addHandler(json_handler)
@ -120,28 +122,11 @@ class SBABot(commands.Bot):
self.maintenance_mode: bool = False
self.logger = logging.getLogger("discord_bot_v2")
self.maintenance_mode: bool = False
async def setup_hook(self):
"""Called when the bot is starting up."""
self.logger.info("Setting up bot...")
@self.tree.interaction_check
async def maintenance_check(interaction: discord.Interaction) -> bool:
"""Block non-admin users when maintenance mode is enabled."""
if not self.maintenance_mode:
return True
if (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
):
return True
await interaction.response.send_message(
"🔧 The bot is currently in maintenance mode. Please try again later.",
ephemeral=True,
)
return False
# Load command packages
await self._load_command_packages()
@ -443,7 +428,9 @@ async def health_command(interaction: discord.Interaction):
embed.add_field(name="Bot Status", value="✅ Online", inline=True)
embed.add_field(name="API Status", value=api_status, inline=True)
embed.add_field(name="Guilds", value=str(guild_count), inline=True)
embed.add_field(name="Latency", value=f"{bot.latency*1000:.1f}ms", inline=True)
embed.add_field(
name="Latency", value=f"{bot.latency * 1000:.1f}ms", inline=True
)
if bot.user:
embed.set_footer(

View File

@ -568,14 +568,9 @@ class AdminCommands(commands.Cog):
return
try:
# Clear all messages from the channel
deleted_count = 0
async for message in live_scores_channel.history(limit=100):
try:
await message.delete()
deleted_count += 1
except discord.NotFound:
pass # Message already deleted
# Clear all messages from the channel using bulk delete
deleted_messages = await live_scores_channel.purge(limit=100)
deleted_count = len(deleted_messages)
self.logger.info(f"Cleared {deleted_count} messages from #live-sba-scores")

View File

@ -4,6 +4,7 @@ Scorebug Commands
Implements commands for publishing and displaying live game scorebugs from Google Sheets scorecards.
"""
import asyncio
import discord
from discord.ext import commands
from discord import app_commands
@ -73,12 +74,18 @@ class ScorebugCommands(commands.Cog):
return
# Get team data for display
away_team = None
home_team = None
if scorebug_data.away_team_id:
away_team = await team_service.get_team(scorebug_data.away_team_id)
if scorebug_data.home_team_id:
home_team = await team_service.get_team(scorebug_data.home_team_id)
away_team, home_team = await asyncio.gather(
(
team_service.get_team(scorebug_data.away_team_id)
if scorebug_data.away_team_id
else asyncio.sleep(0)
),
(
team_service.get_team(scorebug_data.home_team_id)
if scorebug_data.home_team_id
else asyncio.sleep(0)
),
)
# Format scorecard link
away_abbrev = away_team.abbrev if away_team else "AWAY"
@ -86,7 +93,7 @@ class ScorebugCommands(commands.Cog):
scorecard_link = f"[{away_abbrev} @ {home_abbrev}]({url})"
# Store the scorecard in the tracker
self.scorecard_tracker.publish_scorecard(
await self.scorecard_tracker.publish_scorecard(
text_channel_id=interaction.channel_id, # type: ignore
sheet_url=url,
publisher_id=interaction.user.id,
@ -157,7 +164,7 @@ class ScorebugCommands(commands.Cog):
await interaction.response.defer(ephemeral=True)
# Check if a scorecard is published in this channel
sheet_url = self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore
sheet_url = await self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore
if not sheet_url:
embed = EmbedTemplate.error(
@ -179,12 +186,18 @@ class ScorebugCommands(commands.Cog):
)
# Get team data
away_team = None
home_team = None
if scorebug_data.away_team_id:
away_team = await team_service.get_team(scorebug_data.away_team_id)
if scorebug_data.home_team_id:
home_team = await team_service.get_team(scorebug_data.home_team_id)
away_team, home_team = await asyncio.gather(
(
team_service.get_team(scorebug_data.away_team_id)
if scorebug_data.away_team_id
else asyncio.sleep(0)
),
(
team_service.get_team(scorebug_data.home_team_id)
if scorebug_data.home_team_id
else asyncio.sleep(0)
),
)
# Create scorebug embed using shared utility
embed = create_scorebug_embed(
@ -194,7 +207,7 @@ class ScorebugCommands(commands.Cog):
await interaction.edit_original_response(content=None, embed=embed)
# Update timestamp in tracker
self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore
await self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore
except SheetsException as e:
embed = EmbedTemplate.error(

View File

@ -24,7 +24,7 @@ class ScorecardTracker:
- Timestamp tracking for monitoring
"""
def __init__(self, data_file: str = "data/scorecards.json"):
def __init__(self, data_file: str = "storage/scorecards.json"):
"""
Initialize the scorecard tracker.

View File

@ -11,6 +11,7 @@ The injury rating format (#p##) encodes both games played and rating:
- Remaining: Injury rating (p70, p65, p60, p50, p40, p30, p20)
"""
import asyncio
import math
import random
import discord
@ -114,16 +115,14 @@ class InjuryGroup(app_commands.Group):
"""Roll for injury using 3d6 dice and injury tables."""
await interaction.response.defer()
# Get current season
current = await league_service.get_current_state()
# Get current season and search for player in parallel
current, players = await asyncio.gather(
league_service.get_current_state(),
player_service.search_players(player_name, limit=10),
)
if not current:
raise BotException("Failed to get current season information")
# Search for player using the search endpoint (more reliable than name param)
players = await player_service.search_players(
player_name, limit=10, season=current.season
)
if not players:
embed = EmbedTemplate.error(
title="Player Not Found",
@ -530,16 +529,14 @@ class InjuryGroup(app_commands.Group):
await interaction.followup.send(embed=embed, ephemeral=True)
return
# Get current season
current = await league_service.get_current_state()
# Get current season and search for player in parallel
current, players = await asyncio.gather(
league_service.get_current_state(),
player_service.search_players(player_name, limit=10),
)
if not current:
raise BotException("Failed to get current season information")
# Search for player using the search endpoint (more reliable than name param)
players = await player_service.search_players(
player_name, limit=10, season=current.season
)
if not players:
embed = EmbedTemplate.error(
title="Player Not Found",
@ -717,16 +714,14 @@ class InjuryGroup(app_commands.Group):
await interaction.response.defer()
# Get current season
current = await league_service.get_current_state()
# Get current season and search for player in parallel
current, players = await asyncio.gather(
league_service.get_current_state(),
player_service.search_players(player_name, limit=10),
)
if not current:
raise BotException("Failed to get current season information")
# Search for player using the search endpoint (more reliable than name param)
players = await player_service.search_players(
player_name, limit=10, season=current.season
)
if not players:
embed = EmbedTemplate.error(
title="Player Not Found",

View File

@ -3,6 +3,7 @@ League Schedule Commands
Implements slash commands for displaying game schedules and results.
"""
from typing import Optional
import asyncio
@ -19,19 +20,16 @@ from views.embeds import EmbedColors, EmbedTemplate
class ScheduleCommands(commands.Cog):
"""League schedule command handlers."""
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.ScheduleCommands')
@discord.app_commands.command(
name="schedule",
description="Display game schedule"
)
self.logger = get_contextual_logger(f"{__name__}.ScheduleCommands")
@discord.app_commands.command(name="schedule", description="Display game schedule")
@discord.app_commands.describe(
season="Season to show schedule for (defaults to current season)",
week="Week number to show (optional)",
team="Team abbreviation to filter by (optional)"
team="Team abbreviation to filter by (optional)",
)
@requires_team()
@logged_command("/schedule")
@ -40,13 +38,13 @@ class ScheduleCommands(commands.Cog):
interaction: discord.Interaction,
season: Optional[int] = None,
week: Optional[int] = None,
team: Optional[str] = None
team: Optional[str] = None,
):
"""Display game schedule for a week or team."""
await interaction.response.defer()
search_season = season or get_config().sba_season
if team:
# Show team schedule
await self._show_team_schedule(interaction, search_season, team, week)
@ -56,7 +54,7 @@ class ScheduleCommands(commands.Cog):
else:
# Show recent/upcoming games
await self._show_current_schedule(interaction, search_season)
# @discord.app_commands.command(
# name="results",
# description="Display recent game results"
@ -74,282 +72,316 @@ class ScheduleCommands(commands.Cog):
# ):
# """Display recent game results."""
# await interaction.response.defer()
# search_season = season or get_config().sba_season
# if week:
# # Show specific week results
# games = await schedule_service.get_week_schedule(search_season, week)
# completed_games = [game for game in games if game.is_completed]
# if not completed_games:
# await interaction.followup.send(
# f"❌ No completed games found for season {search_season}, week {week}.",
# ephemeral=True
# )
# return
# embed = await self._create_week_results_embed(completed_games, search_season, week)
# await interaction.followup.send(embed=embed)
# else:
# # Show recent results
# recent_games = await schedule_service.get_recent_games(search_season)
# if not recent_games:
# await interaction.followup.send(
# f"❌ No recent games found for season {search_season}.",
# ephemeral=True
# )
# return
# embed = await self._create_recent_results_embed(recent_games, search_season)
# await interaction.followup.send(embed=embed)
async def _show_week_schedule(self, interaction: discord.Interaction, season: int, week: int):
async def _show_week_schedule(
self, interaction: discord.Interaction, season: int, week: int
):
"""Show schedule for a specific week."""
self.logger.debug("Fetching week schedule", season=season, week=week)
games = await schedule_service.get_week_schedule(season, week)
if not games:
await interaction.followup.send(
f"❌ No games found for season {season}, week {week}.",
ephemeral=True
f"❌ No games found for season {season}, week {week}.", ephemeral=True
)
return
embed = await self._create_week_schedule_embed(games, season, week)
await interaction.followup.send(embed=embed)
async def _show_team_schedule(self, interaction: discord.Interaction, season: int, team: str, week: Optional[int]):
async def _show_team_schedule(
self,
interaction: discord.Interaction,
season: int,
team: str,
week: Optional[int],
):
"""Show schedule for a specific team."""
self.logger.debug("Fetching team schedule", season=season, team=team, week=week)
if week:
# Show team games for specific week
week_games = await schedule_service.get_week_schedule(season, week)
team_games = [
game for game in week_games
if game.away_team.abbrev.upper() == team.upper() or game.home_team.abbrev.upper() == team.upper()
game
for game in week_games
if game.away_team.abbrev.upper() == team.upper()
or game.home_team.abbrev.upper() == team.upper()
]
else:
# Show team's recent/upcoming games (limited weeks)
team_games = await schedule_service.get_team_schedule(season, team, weeks=4)
if not team_games:
week_text = f" for week {week}" if week else ""
await interaction.followup.send(
f"❌ No games found for team '{team}'{week_text} in season {season}.",
ephemeral=True
ephemeral=True,
)
return
embed = await self._create_team_schedule_embed(team_games, season, team, week)
await interaction.followup.send(embed=embed)
async def _show_current_schedule(self, interaction: discord.Interaction, season: int):
async def _show_current_schedule(
self, interaction: discord.Interaction, season: int
):
"""Show current schedule overview with recent and upcoming games."""
self.logger.debug("Fetching current schedule overview", season=season)
# Get both recent and upcoming games
recent_games, upcoming_games = await asyncio.gather(
schedule_service.get_recent_games(season, weeks_back=1),
schedule_service.get_upcoming_games(season, weeks_ahead=1)
schedule_service.get_upcoming_games(season),
)
if not recent_games and not upcoming_games:
await interaction.followup.send(
f"❌ No recent or upcoming games found for season {season}.",
ephemeral=True
ephemeral=True,
)
return
embed = await self._create_current_schedule_embed(recent_games, upcoming_games, season)
embed = await self._create_current_schedule_embed(
recent_games, upcoming_games, season
)
await interaction.followup.send(embed=embed)
async def _create_week_schedule_embed(self, games, season: int, week: int) -> discord.Embed:
async def _create_week_schedule_embed(
self, games, season: int, week: int
) -> discord.Embed:
"""Create an embed for a week's schedule."""
embed = EmbedTemplate.create_base_embed(
title=f"📅 Week {week} Schedule - Season {season}",
color=EmbedColors.PRIMARY
color=EmbedColors.PRIMARY,
)
# Group games by series
series_games = schedule_service.group_games_by_series(games)
schedule_lines = []
for (team1, team2), series in series_games.items():
series_summary = await self._format_series_summary(series)
schedule_lines.append(f"**{team1} vs {team2}**\n{series_summary}")
if schedule_lines:
embed.add_field(
name="Games",
value="\n\n".join(schedule_lines),
inline=False
name="Games", value="\n\n".join(schedule_lines), inline=False
)
# Add week summary
completed = len([g for g in games if g.is_completed])
total = len(games)
embed.add_field(
name="Week Progress",
value=f"{completed}/{total} games completed",
inline=True
inline=True,
)
embed.set_footer(text=f"Season {season} • Week {week}")
return embed
async def _create_team_schedule_embed(self, games, season: int, team: str, week: Optional[int]) -> discord.Embed:
async def _create_team_schedule_embed(
self, games, season: int, team: str, week: Optional[int]
) -> discord.Embed:
"""Create an embed for a team's schedule."""
week_text = f" - Week {week}" if week else ""
embed = EmbedTemplate.create_base_embed(
title=f"📅 {team.upper()} Schedule{week_text} - Season {season}",
color=EmbedColors.PRIMARY
color=EmbedColors.PRIMARY,
)
# Separate completed and upcoming games
completed_games = [g for g in games if g.is_completed]
upcoming_games = [g for g in games if not g.is_completed]
if completed_games:
recent_lines = []
for game in completed_games[-5:]: # Last 5 games
result = "W" if game.winner and game.winner.abbrev.upper() == team.upper() else "L"
result = (
"W"
if game.winner and game.winner.abbrev.upper() == team.upper()
else "L"
)
if game.home_team.abbrev.upper() == team.upper():
# Team was home
recent_lines.append(f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})")
recent_lines.append(
f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})"
)
else:
# Team was away
recent_lines.append(f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})")
# Team was away
recent_lines.append(
f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})"
)
embed.add_field(
name="Recent Results",
value="\n".join(recent_lines) if recent_lines else "No recent games",
inline=False
inline=False,
)
if upcoming_games:
upcoming_lines = []
for game in upcoming_games[:5]: # Next 5 games
if game.home_team.abbrev.upper() == team.upper():
# Team is home
upcoming_lines.append(f"Week {game.week}: vs {game.away_team.abbrev}")
upcoming_lines.append(
f"Week {game.week}: vs {game.away_team.abbrev}"
)
else:
# Team is away
upcoming_lines.append(f"Week {game.week}: @ {game.home_team.abbrev}")
upcoming_lines.append(
f"Week {game.week}: @ {game.home_team.abbrev}"
)
embed.add_field(
name="Upcoming Games",
value="\n".join(upcoming_lines) if upcoming_lines else "No upcoming games",
inline=False
value=(
"\n".join(upcoming_lines) if upcoming_lines else "No upcoming games"
),
inline=False,
)
embed.set_footer(text=f"Season {season}{team.upper()}")
return embed
async def _create_week_results_embed(self, games, season: int, week: int) -> discord.Embed:
async def _create_week_results_embed(
self, games, season: int, week: int
) -> discord.Embed:
"""Create an embed for week results."""
embed = EmbedTemplate.create_base_embed(
title=f"🏆 Week {week} Results - Season {season}",
color=EmbedColors.SUCCESS
title=f"🏆 Week {week} Results - Season {season}", color=EmbedColors.SUCCESS
)
# Group by series and show results
series_games = schedule_service.group_games_by_series(games)
results_lines = []
for (team1, team2), series in series_games.items():
# Count wins for each team
team1_wins = len([g for g in series if g.winner and g.winner.abbrev == team1])
team2_wins = len([g for g in series if g.winner and g.winner.abbrev == team2])
team1_wins = len(
[g for g in series if g.winner and g.winner.abbrev == team1]
)
team2_wins = len(
[g for g in series if g.winner and g.winner.abbrev == team2]
)
# Series result
series_result = f"**{team1} {team1_wins}-{team2_wins} {team2}**"
# Individual games
game_details = []
for game in series:
if game.series_game_display:
game_details.append(f"{game.series_game_display}: {game.matchup_display}")
game_details.append(
f"{game.series_game_display}: {game.matchup_display}"
)
results_lines.append(f"{series_result}\n" + "\n".join(game_details))
if results_lines:
embed.add_field(
name="Series Results",
value="\n\n".join(results_lines),
inline=False
name="Series Results", value="\n\n".join(results_lines), inline=False
)
embed.set_footer(text=f"Season {season} • Week {week}{len(games)} games completed")
embed.set_footer(
text=f"Season {season} • Week {week}{len(games)} games completed"
)
return embed
async def _create_recent_results_embed(self, games, season: int) -> discord.Embed:
"""Create an embed for recent results."""
embed = EmbedTemplate.create_base_embed(
title=f"🏆 Recent Results - Season {season}",
color=EmbedColors.SUCCESS
title=f"🏆 Recent Results - Season {season}", color=EmbedColors.SUCCESS
)
# Show most recent games
recent_lines = []
for game in games[:10]: # Show last 10 games
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
if recent_lines:
embed.add_field(
name="Latest Games",
value="\n".join(recent_lines),
inline=False
name="Latest Games", value="\n".join(recent_lines), inline=False
)
embed.set_footer(text=f"Season {season} • Last {len(games)} completed games")
return embed
async def _create_current_schedule_embed(self, recent_games, upcoming_games, season: int) -> discord.Embed:
async def _create_current_schedule_embed(
self, recent_games, upcoming_games, season: int
) -> discord.Embed:
"""Create an embed for current schedule overview."""
embed = EmbedTemplate.create_base_embed(
title=f"📅 Current Schedule - Season {season}",
color=EmbedColors.INFO
title=f"📅 Current Schedule - Season {season}", color=EmbedColors.INFO
)
if recent_games:
recent_lines = []
for game in recent_games[:5]:
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
embed.add_field(
name="Recent Results",
value="\n".join(recent_lines),
inline=False
name="Recent Results", value="\n".join(recent_lines), inline=False
)
if upcoming_games:
upcoming_lines = []
for game in upcoming_games[:5]:
upcoming_lines.append(f"Week {game.week}: {game.matchup_display}")
embed.add_field(
name="Upcoming Games",
value="\n".join(upcoming_lines),
inline=False
name="Upcoming Games", value="\n".join(upcoming_lines), inline=False
)
embed.set_footer(text=f"Season {season}")
return embed
async def _format_series_summary(self, series) -> str:
"""Format a series summary."""
lines = []
for game in series:
game_display = f"{game.series_game_display}: {game.matchup_display}" if game.series_game_display else game.matchup_display
game_display = (
f"{game.series_game_display}: {game.matchup_display}"
if game.series_game_display
else game.matchup_display
)
lines.append(game_display)
return "\n".join(lines) if lines else "No games"
async def setup(bot: commands.Bot):
"""Load the schedule commands cog."""
await bot.add_cog(ScheduleCommands(bot))
await bot.add_cog(ScheduleCommands(bot))

View File

@ -5,6 +5,7 @@ Implements the /submit-scorecard command for submitting Google Sheets
scorecards with play-by-play data, pitching decisions, and game results.
"""
import asyncio
from typing import Optional, List
import discord
@ -107,11 +108,13 @@ class SubmitScorecardCommands(commands.Cog):
content="🔍 Looking up teams and managers..."
)
away_team = await team_service.get_team_by_abbrev(
setup_data["away_team_abbrev"], current.season
)
home_team = await team_service.get_team_by_abbrev(
setup_data["home_team_abbrev"], current.season
away_team, home_team = await asyncio.gather(
team_service.get_team_by_abbrev(
setup_data["away_team_abbrev"], current.season
),
team_service.get_team_by_abbrev(
setup_data["home_team_abbrev"], current.season
),
)
if not away_team or not home_team:
@ -235,9 +238,13 @@ class SubmitScorecardCommands(commands.Cog):
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
(
wp,
lp,
sv,
holders,
_blown_saves,
) = await decision_service.find_winning_losing_pitchers(decisions_data)
if wp is None or lp is None:
await interaction.edit_original_response(

View File

@ -3,13 +3,14 @@ Soak Tracker
Provides persistent tracking of "soak" mentions using JSON file storage.
"""
import json
import logging
from datetime import datetime, timedelta, UTC
from pathlib import Path
from typing import Dict, List, Optional, Any
logger = logging.getLogger(f'{__name__}.SoakTracker')
logger = logging.getLogger(f"{__name__}.SoakTracker")
class SoakTracker:
@ -22,7 +23,7 @@ class SoakTracker:
- Time-based calculations for disappointment tiers
"""
def __init__(self, data_file: str = "data/soak_data.json"):
def __init__(self, data_file: str = "storage/soak_data.json"):
"""
Initialize the soak tracker.
@ -38,28 +39,22 @@ class SoakTracker:
"""Load soak data from JSON file."""
try:
if self.data_file.exists():
with open(self.data_file, 'r') as f:
with open(self.data_file, "r") as f:
self._data = json.load(f)
logger.debug(f"Loaded soak data: {self._data.get('total_count', 0)} total soaks")
logger.debug(
f"Loaded soak data: {self._data.get('total_count', 0)} total soaks"
)
else:
self._data = {
"last_soak": None,
"total_count": 0,
"history": []
}
self._data = {"last_soak": None, "total_count": 0, "history": []}
logger.info("No existing soak data found, starting fresh")
except Exception as e:
logger.error(f"Failed to load soak data: {e}")
self._data = {
"last_soak": None,
"total_count": 0,
"history": []
}
self._data = {"last_soak": None, "total_count": 0, "history": []}
def save_data(self) -> None:
"""Save soak data to JSON file."""
try:
with open(self.data_file, 'w') as f:
with open(self.data_file, "w") as f:
json.dump(self._data, f, indent=2, default=str)
logger.debug("Soak data saved successfully")
except Exception as e:
@ -71,7 +66,7 @@ class SoakTracker:
username: str,
display_name: str,
channel_id: int,
message_id: int
message_id: int,
) -> None:
"""
Record a new soak mention.
@ -89,7 +84,7 @@ class SoakTracker:
"username": username,
"display_name": display_name,
"channel_id": str(channel_id),
"message_id": str(message_id)
"message_id": str(message_id),
}
# Update last_soak
@ -110,7 +105,9 @@ class SoakTracker:
self.save_data()
logger.info(f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}")
logger.info(
f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}"
)
def get_last_soak(self) -> Optional[Dict[str, Any]]:
"""
@ -135,10 +132,12 @@ class SoakTracker:
try:
# Parse ISO format timestamp
last_timestamp_str = last_soak["timestamp"]
if last_timestamp_str.endswith('Z'):
last_timestamp_str = last_timestamp_str[:-1] + '+00:00'
if last_timestamp_str.endswith("Z"):
last_timestamp_str = last_timestamp_str[:-1] + "+00:00"
last_timestamp = datetime.fromisoformat(last_timestamp_str.replace('Z', '+00:00'))
last_timestamp = datetime.fromisoformat(
last_timestamp_str.replace("Z", "+00:00")
)
# Ensure both times are timezone-aware
if last_timestamp.tzinfo is None:

View File

@ -3,6 +3,7 @@ Transaction Management Commands
Core transaction commands for roster management and transaction tracking.
"""
from typing import Optional
import asyncio
@ -21,6 +22,7 @@ from views.base import PaginationView
from services.transaction_service import transaction_service
from services.roster_service import roster_service
from services.team_service import team_service
# No longer need TransactionStatus enum
@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView):
all_transactions: list,
user_id: int,
timeout: float = 300.0,
show_page_numbers: bool = True
show_page_numbers: bool = True,
):
super().__init__(
pages=pages,
user_id=user_id,
timeout=timeout,
show_page_numbers=show_page_numbers
show_page_numbers=show_page_numbers,
)
self.all_transactions = all_transactions
@discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1)
async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(
label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1
)
async def show_move_ids(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Show all move IDs in an ephemeral message."""
self.increment_interaction_count()
if not self.all_transactions:
await interaction.response.send_message(
"No transactions to show.",
ephemeral=True
"No transactions to show.", ephemeral=True
)
return
@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView):
# Send the messages
if not messages:
await interaction.response.send_message(
"No transactions to display.",
ephemeral=True
"No transactions to display.", ephemeral=True
)
return
@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView):
class TransactionCommands(commands.Cog):
"""Transaction command handlers for roster management."""
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.TransactionCommands')
self.logger = get_contextual_logger(f"{__name__}.TransactionCommands")
@app_commands.command(
name="mymoves",
description="View your pending and scheduled transactions"
name="mymoves", description="View your pending and scheduled transactions"
)
@app_commands.describe(
show_cancelled="Include cancelled transactions in the display (default: False)"
@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog):
@requires_team()
@logged_command("/mymoves")
async def my_moves(
self,
interaction: discord.Interaction,
show_cancelled: bool = False
self, interaction: discord.Interaction, show_cancelled: bool = False
):
"""Display user's transaction status and history."""
await interaction.response.defer()
# Get user's team
team = await get_user_major_league_team(interaction.user.id, get_config().sba_season)
team = await get_user_major_league_team(
interaction.user.id, get_config().sba_season
)
if not team:
await interaction.followup.send(
"❌ You don't appear to own a team in the current season.",
ephemeral=True
ephemeral=True,
)
return
# Get transactions in parallel
pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season)
frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season)
processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season)
pending_transactions = await pending_task
frozen_transactions = await frozen_task
processed_transactions = await processed_task
(
pending_transactions,
frozen_transactions,
processed_transactions,
) = await asyncio.gather(
transaction_service.get_pending_transactions(
team.abbrev, get_config().sba_season
),
transaction_service.get_frozen_transactions(
team.abbrev, get_config().sba_season
),
transaction_service.get_processed_transactions(
team.abbrev, get_config().sba_season
),
)
# Get cancelled if requested
cancelled_transactions = []
if show_cancelled:
cancelled_transactions = await transaction_service.get_team_transactions(
team.abbrev,
get_config().sba_season,
cancelled=True
team.abbrev, get_config().sba_season, cancelled=True
)
pages = self._create_my_moves_pages(
@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog):
pending_transactions,
frozen_transactions,
processed_transactions,
cancelled_transactions
cancelled_transactions,
)
# Collect all transactions for the "Show Move IDs" button
all_transactions = (
pending_transactions +
frozen_transactions +
processed_transactions +
cancelled_transactions
pending_transactions
+ frozen_transactions
+ processed_transactions
+ cancelled_transactions
)
# If only one page and no transactions, send without any buttons
@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog):
all_transactions=all_transactions,
user_id=interaction.user.id,
timeout=300.0,
show_page_numbers=True
show_page_numbers=True,
)
await interaction.followup.send(embed=view.get_current_embed(), view=view)
@app_commands.command(
name="legal",
description="Check roster legality for current and next week"
)
@app_commands.describe(
team="Team abbreviation to check (defaults to your team)"
name="legal", description="Check roster legality for current and next week"
)
@app_commands.describe(team="Team abbreviation to check (defaults to your team)")
@requires_team()
@logged_command("/legal")
async def legal(
self,
interaction: discord.Interaction,
team: Optional[str] = None
):
async def legal(self, interaction: discord.Interaction, team: Optional[str] = None):
"""Check roster legality and display detailed validation results."""
await interaction.response.defer()
# Get target team
if team:
target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season)
target_team = await team_service.get_team_by_abbrev(
team.upper(), get_config().sba_season
)
if not target_team:
await interaction.followup.send(
f"❌ Could not find team '{team}' in season {get_config().sba_season}.",
ephemeral=True
ephemeral=True,
)
return
else:
# Get user's team
user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season)
user_teams = await team_service.get_teams_by_owner(
interaction.user.id, get_config().sba_season
)
if not user_teams:
await interaction.followup.send(
"❌ You don't appear to own a team. Please specify a team abbreviation.",
ephemeral=True
ephemeral=True,
)
return
target_team = user_teams[0]
# Get rosters in parallel
current_roster, next_roster = await asyncio.gather(
roster_service.get_current_roster(target_team.id),
roster_service.get_next_roster(target_team.id)
roster_service.get_next_roster(target_team.id),
)
if not current_roster and not next_roster:
await interaction.followup.send(
f"❌ Could not retrieve roster data for {target_team.abbrev}.",
ephemeral=True
ephemeral=True,
)
return
# Validate rosters in parallel
validation_tasks = []
if current_roster:
validation_tasks.append(roster_service.validate_roster(current_roster))
else:
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
if next_roster:
validation_tasks.append(roster_service.validate_roster(next_roster))
else:
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
validation_results = await asyncio.gather(*validation_tasks)
current_validation = validation_results[0] if current_roster else None
next_validation = validation_results[1] if next_roster else None
embed = await self._create_legal_embed(
target_team,
current_roster,
next_roster,
next_roster,
current_validation,
next_validation
next_validation,
)
await interaction.followup.send(embed=embed)
def _create_my_moves_pages(
self,
team,
pending_transactions,
frozen_transactions,
processed_transactions,
cancelled_transactions
cancelled_transactions,
) -> list[discord.Embed]:
"""Create paginated embeds showing user's transaction status."""
@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog):
# Page 1: Summary + Pending Transactions
if pending_transactions:
total_pending = len(pending_transactions)
total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page
total_pages = (
total_pending + transactions_per_page - 1
) // transactions_per_page
for page_num in range(total_pages):
start_idx = page_num * transactions_per_page
@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
# Add team thumbnail if available
if hasattr(team, 'thumbnail') and team.thumbnail:
if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail)
# Pending transactions for this page
@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog):
embed.add_field(
name=f"⏳ Pending Transactions ({total_pending} total)",
value="\n".join(pending_lines),
inline=False
inline=False,
)
# Add summary only on first page
@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog):
embed.add_field(
name="Summary",
value=", ".join(status_text) if status_text else "No active transactions",
inline=True
value=(
", ".join(status_text)
if status_text
else "No active transactions"
),
inline=True,
)
pages.append(embed)
@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
if hasattr(team, 'thumbnail') and team.thumbnail:
if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail)
embed.add_field(
name="⏳ Pending Transactions",
value="No pending transactions",
inline=False
inline=False,
)
total_frozen = len(frozen_transactions)
@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog):
embed.add_field(
name="Summary",
value=", ".join(status_text) if status_text else "No active transactions",
inline=True
value=(
", ".join(status_text) if status_text else "No active transactions"
),
inline=True,
)
pages.append(embed)
@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
if hasattr(team, 'thumbnail') and team.thumbnail:
if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail)
frozen_lines = [format_transaction(tx) for tx in frozen_transactions]
@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog):
embed.add_field(
name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)",
value="\n".join(frozen_lines),
inline=False
inline=False,
)
pages.append(embed)
@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
if hasattr(team, 'thumbnail') and team.thumbnail:
if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail)
processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20
processed_lines = [
format_transaction(tx) for tx in processed_transactions[-20:]
] # Last 20
embed.add_field(
name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)",
value="\n".join(processed_lines),
inline=False
inline=False,
)
pages.append(embed)
@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO
color=EmbedColors.INFO,
)
if hasattr(team, 'thumbnail') and team.thumbnail:
if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail)
cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20
cancelled_lines = [
format_transaction(tx) for tx in cancelled_transactions[-20:]
] # Last 20
embed.add_field(
name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)",
value="\n".join(cancelled_lines),
inline=False
inline=False,
)
pages.append(embed)
@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog):
page.set_footer(text="Use /legal to check roster legality")
return pages
async def _create_legal_embed(
self,
team,
current_roster,
next_roster,
current_validation,
next_validation
self, team, current_roster, next_roster, current_validation, next_validation
) -> discord.Embed:
"""Create embed showing roster legality check results."""
# Determine overall status
overall_legal = True
if current_validation and not current_validation.is_legal:
overall_legal = False
if next_validation and not next_validation.is_legal:
overall_legal = False
status_emoji = "" if overall_legal else ""
embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR
embed = EmbedTemplate.create_base_embed(
title=f"{status_emoji} Roster Check - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}",
color=embed_color
color=embed_color,
)
# Add team thumbnail if available
if hasattr(team, 'thumbnail') and team.thumbnail:
if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail)
# Current week roster
if current_roster and current_validation:
current_lines = []
current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL")
current_lines.append(
f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL"
)
current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}")
if current_validation.errors:
current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}")
for error in current_validation.errors[:3]: # Show first 3 errors
current_lines.append(f"{error}")
if current_validation.warnings:
current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}")
current_lines.append(
f"**⚠️ Warnings:** {len(current_validation.warnings)}"
)
for warning in current_validation.warnings[:2]: # Show first 2 warnings
current_lines.append(f"{warning}")
embed.add_field(
name=f"{current_validation.status_emoji} Current Week",
value="\n".join(current_lines),
inline=True
inline=True,
)
else:
embed.add_field(
name="❓ Current Week",
value="Roster data not available",
inline=True
name="❓ Current Week", value="Roster data not available", inline=True
)
# Next week roster
# Next week roster
if next_roster and next_validation:
next_lines = []
next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL")
next_lines.append(
f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL"
)
next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}")
if next_validation.errors:
next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}")
for error in next_validation.errors[:3]: # Show first 3 errors
next_lines.append(f"{error}")
if next_validation.warnings:
next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}")
for warning in next_validation.warnings[:2]: # Show first 2 warnings
next_lines.append(f"{warning}")
embed.add_field(
name=f"{next_validation.status_emoji} Next Week",
value="\n".join(next_lines),
inline=True
inline=True,
)
else:
embed.add_field(
name="❓ Next Week",
value="Roster data not available",
inline=True
name="❓ Next Week", value="Roster data not available", inline=True
)
# Overall status
if overall_legal:
embed.add_field(
name="Overall Status",
value="✅ All rosters are legal",
inline=False
name="Overall Status", value="✅ All rosters are legal", inline=False
)
else:
embed.add_field(
name="Overall Status",
name="Overall Status",
value="❌ Roster violations found - please review and correct",
inline=False
inline=False,
)
embed.set_footer(text="Roster validation based on current league rules")
return embed
async def setup(bot: commands.Bot):
"""Load the transaction commands cog."""
await bot.add_cog(TransactionCommands(bot))
await bot.add_cog(TransactionCommands(bot))

View File

@ -3,6 +3,7 @@ Trade Channel Tracker
Provides persistent tracking of bot-created trade discussion channels using JSON file storage.
"""
import json
from datetime import datetime, UTC
from pathlib import Path
@ -12,7 +13,7 @@ import discord
from utils.logging import get_contextual_logger
logger = get_contextual_logger(f'{__name__}.TradeChannelTracker')
logger = get_contextual_logger(f"{__name__}.TradeChannelTracker")
class TradeChannelTracker:
@ -26,7 +27,7 @@ class TradeChannelTracker:
- Automatic stale entry removal
"""
def __init__(self, data_file: str = "data/trade_channels.json"):
def __init__(self, data_file: str = "storage/trade_channels.json"):
"""
Initialize the trade channel tracker.
@ -42,9 +43,11 @@ class TradeChannelTracker:
"""Load channel data from JSON file."""
try:
if self.data_file.exists():
with open(self.data_file, 'r') as f:
with open(self.data_file, "r") as f:
self._data = json.load(f)
logger.debug(f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels")
logger.debug(
f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels"
)
else:
self._data = {"trade_channels": {}}
logger.info("No existing trade channel data found, starting fresh")
@ -55,7 +58,7 @@ class TradeChannelTracker:
def save_data(self) -> None:
"""Save channel data to JSON file."""
try:
with open(self.data_file, 'w') as f:
with open(self.data_file, "w") as f:
json.dump(self._data, f, indent=2, default=str)
logger.debug("Trade channel data saved successfully")
except Exception as e:
@ -67,7 +70,7 @@ class TradeChannelTracker:
trade_id: str,
team1_abbrev: str,
team2_abbrev: str,
creator_id: int
creator_id: int,
) -> None:
"""
Add a new trade channel to tracking.
@ -87,10 +90,12 @@ class TradeChannelTracker:
"team1_abbrev": team1_abbrev,
"team2_abbrev": team2_abbrev,
"created_at": datetime.now(UTC).isoformat(),
"creator_id": str(creator_id)
"creator_id": str(creator_id),
}
self.save_data()
logger.info(f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})")
logger.info(
f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})"
)
def remove_channel(self, channel_id: int) -> None:
"""
@ -108,7 +113,9 @@ class TradeChannelTracker:
channel_name = channel_data["name"]
del channels[channel_key]
self.save_data()
logger.info(f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})")
logger.info(
f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})"
)
def get_channel_by_trade_id(self, trade_id: str) -> Optional[Dict[str, Any]]:
"""
@ -175,7 +182,9 @@ class TradeChannelTracker:
channel_name = channels[channel_id_str].get("name", "unknown")
trade_id = channels[channel_id_str].get("trade_id", "unknown")
del channels[channel_id_str]
logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})")
logger.info(
f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})"
)
if stale_entries:
self.save_data()

View File

@ -3,6 +3,7 @@ Voice Channel Cleanup Service
Provides automatic cleanup of empty voice channels with restart resilience.
"""
import logging
import discord
@ -12,7 +13,7 @@ from .tracker import VoiceChannelTracker
from commands.gameplay.scorecard_tracker import ScorecardTracker
from utils.logging import get_contextual_logger
logger = logging.getLogger(f'{__name__}.VoiceChannelCleanupService')
logger = logging.getLogger(f"{__name__}.VoiceChannelCleanupService")
class VoiceChannelCleanupService:
@ -27,7 +28,9 @@ class VoiceChannelCleanupService:
- Automatic scorecard unpublishing when voice channel is cleaned up
"""
def __init__(self, bot: commands.Bot, data_file: str = "data/voice_channels.json"):
def __init__(
self, bot: commands.Bot, data_file: str = "storage/voice_channels.json"
):
"""
Initialize the cleanup service.
@ -36,10 +39,10 @@ class VoiceChannelCleanupService:
data_file: Path to the JSON data file for persistence
"""
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.VoiceChannelCleanupService')
self.logger = get_contextual_logger(f"{__name__}.VoiceChannelCleanupService")
self.tracker = VoiceChannelTracker(data_file)
self.scorecard_tracker = ScorecardTracker()
self.empty_threshold = 5 # Delete after 5 minutes empty
self.empty_threshold = 5 # Delete after 5 minutes empty
# Start the cleanup task - @before_loop will wait for bot readiness
self.cleanup_loop.start()
@ -90,13 +93,17 @@ class VoiceChannelCleanupService:
guild = bot.get_guild(guild_id)
if not guild:
self.logger.warning(f"Guild {guild_id} not found, removing channel {channel_data['name']}")
self.logger.warning(
f"Guild {guild_id} not found, removing channel {channel_data['name']}"
)
channels_to_remove.append(channel_id)
continue
channel = guild.get_channel(channel_id)
if not channel:
self.logger.warning(f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists")
self.logger.warning(
f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists"
)
channels_to_remove.append(channel_id)
continue
@ -121,18 +128,26 @@ class VoiceChannelCleanupService:
if channel_data and channel_data.get("text_channel_id"):
try:
text_channel_id_int = int(channel_data["text_channel_id"])
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
text_channel_id_int
)
if was_unpublished:
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)")
self.logger.info(
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)"
)
except (ValueError, TypeError) as e:
self.logger.warning(f"Invalid text_channel_id in stale voice channel data: {e}")
self.logger.warning(
f"Invalid text_channel_id in stale voice channel data: {e}"
)
# Also clean up any additional stale entries
stale_removed = self.tracker.cleanup_stale_entries(valid_channel_ids)
total_removed = len(channels_to_remove) + stale_removed
if total_removed > 0:
self.logger.info(f"Cleaned up {total_removed} stale channel tracking entries")
self.logger.info(
f"Cleaned up {total_removed} stale channel tracking entries"
)
self.logger.info(f"Verified {len(valid_channel_ids)} valid tracked channels")
@ -149,10 +164,14 @@ class VoiceChannelCleanupService:
await self.update_all_channel_statuses(bot)
# Get channels ready for cleanup
channels_for_cleanup = self.tracker.get_channels_for_cleanup(self.empty_threshold)
channels_for_cleanup = self.tracker.get_channels_for_cleanup(
self.empty_threshold
)
if channels_for_cleanup:
self.logger.info(f"Found {len(channels_for_cleanup)} channels ready for cleanup")
self.logger.info(
f"Found {len(channels_for_cleanup)} channels ready for cleanup"
)
# Delete empty channels
for channel_data in channels_for_cleanup:
@ -182,12 +201,16 @@ class VoiceChannelCleanupService:
guild = bot.get_guild(guild_id)
if not guild:
self.logger.debug(f"Guild {guild_id} not found for channel {channel_data['name']}")
self.logger.debug(
f"Guild {guild_id} not found for channel {channel_data['name']}"
)
return
channel = guild.get_channel(channel_id)
if not channel:
self.logger.debug(f"Channel {channel_data['name']} no longer exists, removing from tracking")
self.logger.debug(
f"Channel {channel_data['name']} no longer exists, removing from tracking"
)
self.tracker.remove_channel(channel_id)
# Unpublish associated scorecard if it exists
@ -195,17 +218,25 @@ class VoiceChannelCleanupService:
if text_channel_id:
try:
text_channel_id_int = int(text_channel_id)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
text_channel_id_int
)
if was_unpublished:
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)")
self.logger.info(
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)"
)
except (ValueError, TypeError) as e:
self.logger.warning(f"Invalid text_channel_id in manually deleted voice channel data: {e}")
self.logger.warning(
f"Invalid text_channel_id in manually deleted voice channel data: {e}"
)
return
# Ensure it's a voice channel before checking members
if not isinstance(channel, discord.VoiceChannel):
self.logger.warning(f"Channel {channel_data['name']} is not a voice channel, removing from tracking")
self.logger.warning(
f"Channel {channel_data['name']} is not a voice channel, removing from tracking"
)
self.tracker.remove_channel(channel_id)
# Unpublish associated scorecard if it exists
@ -213,11 +244,17 @@ class VoiceChannelCleanupService:
if text_channel_id:
try:
text_channel_id_int = int(text_channel_id)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
text_channel_id_int
)
if was_unpublished:
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)")
self.logger.info(
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)"
)
except (ValueError, TypeError) as e:
self.logger.warning(f"Invalid text_channel_id in wrong channel type data: {e}")
self.logger.warning(
f"Invalid text_channel_id in wrong channel type data: {e}"
)
return
@ -225,11 +262,15 @@ class VoiceChannelCleanupService:
is_empty = len(channel.members) == 0
self.tracker.update_channel_status(channel_id, is_empty)
self.logger.debug(f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} "
f"({len(channel.members)} members)")
self.logger.debug(
f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} "
f"({len(channel.members)} members)"
)
except Exception as e:
self.logger.error(f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}")
self.logger.error(
f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}"
)
async def cleanup_channel(self, bot: commands.Bot, channel_data: dict) -> None:
"""
@ -246,25 +287,33 @@ class VoiceChannelCleanupService:
guild = bot.get_guild(guild_id)
if not guild:
self.logger.info(f"Guild {guild_id} not found, removing tracking for {channel_name}")
self.logger.info(
f"Guild {guild_id} not found, removing tracking for {channel_name}"
)
self.tracker.remove_channel(channel_id)
return
channel = guild.get_channel(channel_id)
if not channel:
self.logger.info(f"Channel {channel_name} already deleted, removing from tracking")
self.logger.info(
f"Channel {channel_name} already deleted, removing from tracking"
)
self.tracker.remove_channel(channel_id)
return
# Ensure it's a voice channel before checking members
if not isinstance(channel, discord.VoiceChannel):
self.logger.warning(f"Channel {channel_name} is not a voice channel, removing from tracking")
self.logger.warning(
f"Channel {channel_name} is not a voice channel, removing from tracking"
)
self.tracker.remove_channel(channel_id)
return
# Final check: make sure channel is still empty before deleting
if len(channel.members) > 0:
self.logger.info(f"Channel {channel_name} is no longer empty, skipping cleanup")
self.logger.info(
f"Channel {channel_name} is no longer empty, skipping cleanup"
)
self.tracker.update_channel_status(channel_id, False)
return
@ -272,24 +321,36 @@ class VoiceChannelCleanupService:
await channel.delete(reason="Automatic cleanup - empty for 5+ minutes")
self.tracker.remove_channel(channel_id)
self.logger.info(f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})")
self.logger.info(
f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})"
)
# Unpublish associated scorecard if it exists
text_channel_id = channel_data.get("text_channel_id")
if text_channel_id:
try:
text_channel_id_int = int(text_channel_id)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
text_channel_id_int
)
if was_unpublished:
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)")
self.logger.info(
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)"
)
else:
self.logger.debug(f"No scorecard found for text channel {text_channel_id_int}")
self.logger.debug(
f"No scorecard found for text channel {text_channel_id_int}"
)
except (ValueError, TypeError) as e:
self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}")
self.logger.warning(
f"Invalid text_channel_id in voice channel data: {e}"
)
except discord.NotFound:
# Channel was already deleted
self.logger.info(f"Channel {channel_data.get('name', 'unknown')} was already deleted")
self.logger.info(
f"Channel {channel_data.get('name', 'unknown')} was already deleted"
)
self.tracker.remove_channel(int(channel_data["channel_id"]))
# Still try to unpublish associated scorecard
@ -297,15 +358,25 @@ class VoiceChannelCleanupService:
if text_channel_id:
try:
text_channel_id_int = int(text_channel_id)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
text_channel_id_int
)
if was_unpublished:
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)")
self.logger.info(
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)"
)
except (ValueError, TypeError) as e:
self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}")
self.logger.warning(
f"Invalid text_channel_id in voice channel data: {e}"
)
except discord.Forbidden:
self.logger.error(f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}")
self.logger.error(
f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}"
)
except Exception as e:
self.logger.error(f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}")
self.logger.error(
f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}"
)
def get_tracker(self) -> VoiceChannelTracker:
"""
@ -330,7 +401,7 @@ class VoiceChannelCleanupService:
"running": self.cleanup_loop.is_running(),
"total_tracked": len(all_channels),
"empty_channels": len(empty_channels),
"empty_threshold": self.empty_threshold
"empty_threshold": self.empty_threshold,
}
@ -344,4 +415,4 @@ def setup_voice_cleanup(bot: commands.Bot) -> VoiceChannelCleanupService:
Returns:
VoiceChannelCleanupService instance
"""
return VoiceChannelCleanupService(bot)
return VoiceChannelCleanupService(bot)

View File

@ -3,6 +3,7 @@ Voice Channel Tracker
Provides persistent tracking of bot-created voice channels using JSON file storage.
"""
import json
import logging
from datetime import datetime, timedelta, UTC
@ -11,7 +12,7 @@ from typing import Dict, List, Optional, Any
import discord
logger = logging.getLogger(f'{__name__}.VoiceChannelTracker')
logger = logging.getLogger(f"{__name__}.VoiceChannelTracker")
class VoiceChannelTracker:
@ -25,7 +26,7 @@ class VoiceChannelTracker:
- Automatic stale entry removal
"""
def __init__(self, data_file: str = "data/voice_channels.json"):
def __init__(self, data_file: str = "storage/voice_channels.json"):
"""
Initialize the voice channel tracker.
@ -41,9 +42,11 @@ class VoiceChannelTracker:
"""Load channel data from JSON file."""
try:
if self.data_file.exists():
with open(self.data_file, 'r') as f:
with open(self.data_file, "r") as f:
self._data = json.load(f)
logger.debug(f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels")
logger.debug(
f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels"
)
else:
self._data = {"voice_channels": {}}
logger.info("No existing voice channel data found, starting fresh")
@ -54,7 +57,7 @@ class VoiceChannelTracker:
def save_data(self) -> None:
"""Save channel data to JSON file."""
try:
with open(self.data_file, 'w') as f:
with open(self.data_file, "w") as f:
json.dump(self._data, f, indent=2, default=str)
logger.debug("Voice channel data saved successfully")
except Exception as e:
@ -65,7 +68,7 @@ class VoiceChannelTracker:
channel: discord.VoiceChannel,
channel_type: str,
creator_id: int,
text_channel_id: Optional[int] = None
text_channel_id: Optional[int] = None,
) -> None:
"""
Add a new channel to tracking.
@ -85,7 +88,7 @@ class VoiceChannelTracker:
"last_checked": datetime.now(UTC).isoformat(),
"empty_since": None,
"creator_id": str(creator_id),
"text_channel_id": str(text_channel_id) if text_channel_id else None
"text_channel_id": str(text_channel_id) if text_channel_id else None,
}
self.save_data()
logger.info(f"Added channel to tracking: {channel.name} (ID: {channel.id})")
@ -130,9 +133,13 @@ class VoiceChannelTracker:
channel_name = channels[channel_key]["name"]
del channels[channel_key]
self.save_data()
logger.info(f"Removed channel from tracking: {channel_name} (ID: {channel_id})")
logger.info(
f"Removed channel from tracking: {channel_name} (ID: {channel_id})"
)
def get_channels_for_cleanup(self, empty_threshold_minutes: int = 15) -> List[Dict[str, Any]]:
def get_channels_for_cleanup(
self, empty_threshold_minutes: int = 15
) -> List[Dict[str, Any]]:
"""
Get channels that should be deleted based on empty duration.
@ -153,10 +160,12 @@ class VoiceChannelTracker:
# Parse empty_since timestamp
empty_since_str = channel_data["empty_since"]
# Handle both with and without timezone info
if empty_since_str.endswith('Z'):
empty_since_str = empty_since_str[:-1] + '+00:00'
if empty_since_str.endswith("Z"):
empty_since_str = empty_since_str[:-1] + "+00:00"
empty_since = datetime.fromisoformat(empty_since_str.replace('Z', '+00:00'))
empty_since = datetime.fromisoformat(
empty_since_str.replace("Z", "+00:00")
)
# Remove timezone info for comparison (both times are UTC)
if empty_since.tzinfo:
@ -164,10 +173,14 @@ class VoiceChannelTracker:
if empty_since <= cutoff_time:
cleanup_candidates.append(channel_data)
logger.debug(f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})")
logger.debug(
f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})"
)
except (ValueError, TypeError) as e:
logger.warning(f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}")
logger.warning(
f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}"
)
return cleanup_candidates
@ -242,9 +255,11 @@ class VoiceChannelTracker:
for channel_id_str in stale_entries:
channel_name = channels[channel_id_str].get("name", "unknown")
del channels[channel_id_str]
logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})")
logger.info(
f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})"
)
if stale_entries:
self.save_data()
return len(stale_entries)
return len(stale_entries)

View File

@ -36,8 +36,11 @@ services:
# Volume mounts
volumes:
# Google Sheets credentials (required)
- ${SHEETS_CREDENTIALS_HOST_PATH:-./data}:/app/data:ro
# Google Sheets credentials (read-only, file mount)
- ${SHEETS_CREDENTIALS_HOST_PATH:-./data/major-domo-service-creds.json}:/app/data/major-domo-service-creds.json:ro
# Runtime state files (writable) - scorecards, voice channels, trade channels, soak data
- ${STATE_HOST_PATH:-./storage}:/app/storage:rw
# Logs directory (persistent) - mounted to /app/logs where the application expects it
- ${LOGS_HOST_PATH:-./logs}:/app/logs:rw

View File

@ -4,6 +4,7 @@ Schedule service for Discord Bot v2.0
Handles game schedule and results retrieval and processing.
"""
import asyncio
import logging
from typing import Optional, List, Dict, Tuple
@ -102,10 +103,10 @@ class ScheduleService:
# If weeks not specified, try a reasonable range (18 weeks typical)
week_range = range(1, (weeks + 1) if weeks else 19)
for week in week_range:
week_games = await self.get_week_schedule(season, week)
# Filter games involving this team
all_week_games = await asyncio.gather(
*[self.get_week_schedule(season, week) for week in week_range]
)
for week_games in all_week_games:
for game in week_games:
if (
game.away_team.abbrev.upper() == team_abbrev_upper
@ -135,15 +136,13 @@ class ScheduleService:
recent_games = []
# Get games from recent weeks
for week_offset in range(weeks_back):
# This is simplified - in production you'd want to determine current week
week = 10 - week_offset # Assuming we're around week 10
if week <= 0:
break
week_games = await self.get_week_schedule(season, week)
# Only include completed games
weeks_to_fetch = [
(10 - offset) for offset in range(weeks_back) if (10 - offset) > 0
]
all_week_games = await asyncio.gather(
*[self.get_week_schedule(season, week) for week in weeks_to_fetch]
)
for week_games in all_week_games:
completed_games = [game for game in week_games if game.is_completed]
recent_games.extend(completed_games)
@ -157,13 +156,12 @@ class ScheduleService:
logger.error(f"Error getting recent games: {e}")
return []
async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]:
async def get_upcoming_games(self, season: int) -> List[Game]:
"""
Get upcoming scheduled games by scanning multiple weeks.
Get upcoming scheduled games by scanning all weeks.
Args:
season: Season number
weeks_ahead: Number of weeks to scan ahead (default 6)
Returns:
List of upcoming Game instances
@ -171,20 +169,16 @@ class ScheduleService:
try:
upcoming_games = []
# Scan through weeks to find games without scores
for week in range(1, 19): # Standard season length
week_games = await self.get_week_schedule(season, week)
# Find games without scores (not yet played)
# Fetch all weeks in parallel and filter for incomplete games
all_week_games = await asyncio.gather(
*[self.get_week_schedule(season, week) for week in range(1, 19)]
)
for week_games in all_week_games:
upcoming_games_week = [
game for game in week_games if not game.is_completed
]
upcoming_games.extend(upcoming_games_week)
# If we found upcoming games, we can limit how many more weeks to check
if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit
break
# Sort by week, then game number
upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0))

View File

@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0
Handles batting and pitching statistics retrieval and processing.
"""
import asyncio
import logging
from typing import Optional
@ -144,11 +145,10 @@ class StatsService:
"""
try:
# Get both types of stats concurrently
batting_task = self.get_batting_stats(player_id, season)
pitching_task = self.get_pitching_stats(player_id, season)
batting_stats = await batting_task
pitching_stats = await pitching_task
batting_stats, pitching_stats = await asyncio.gather(
self.get_batting_stats(player_id, season),
self.get_pitching_stats(player_id, season),
)
logger.debug(
f"Retrieved stats for player {player_id}: "

View File

@ -4,6 +4,7 @@ Trade Builder Service
Extends the TransactionBuilder to support multi-team trades and player exchanges.
"""
import asyncio
import logging
from typing import Dict, List, Optional, Set
from datetime import datetime, timezone
@ -524,14 +525,22 @@ class TradeBuilder:
# Validate each team's roster after the trade
for participant in self.trade.participants:
team_id = participant.team.id
result.team_abbrevs[team_id] = participant.team.abbrev
if team_id in self._team_builders:
builder = self._team_builders[team_id]
roster_validation = await builder.validate_transaction(next_week)
result.team_abbrevs[participant.team.id] = participant.team.abbrev
team_ids_to_validate = [
participant.team.id
for participant in self.trade.participants
if participant.team.id in self._team_builders
]
if team_ids_to_validate:
validations = await asyncio.gather(
*[
self._team_builders[tid].validate_transaction(next_week)
for tid in team_ids_to_validate
]
)
for team_id, roster_validation in zip(team_ids_to_validate, validations):
result.participant_validations[team_id] = roster_validation
if not roster_validation.is_legal:
result.is_legal = False

View File

@ -95,7 +95,7 @@ class LiveScorebugTracker:
# Don't return - still update voice channels
else:
# Get all published scorecards
all_scorecards = self.scorecard_tracker.get_all_scorecards()
all_scorecards = await self.scorecard_tracker.get_all_scorecards()
if not all_scorecards:
# No active scorebugs - clear the channel and hide it
@ -112,17 +112,16 @@ class LiveScorebugTracker:
for text_channel_id, sheet_url in all_scorecards:
try:
scorebug_data = await self.scorebug_service.read_scorebug_data(
sheet_url, full_length=False # Compact view for live channel
sheet_url,
full_length=False, # Compact view for live channel
)
# Only include active (non-final) games
if scorebug_data.is_active:
# Get team data
away_team = await team_service.get_team(
scorebug_data.away_team_id
)
home_team = await team_service.get_team(
scorebug_data.home_team_id
away_team, home_team = await asyncio.gather(
team_service.get_team(scorebug_data.away_team_id),
team_service.get_team(scorebug_data.home_team_id),
)
if away_team is None or home_team is None:
@ -188,9 +187,8 @@ class LiveScorebugTracker:
embeds: List of scorebug embeds
"""
try:
# Clear old messages
async for message in channel.history(limit=25):
await message.delete()
# Clear old messages using bulk delete
await channel.purge(limit=25)
# Post new scorebugs (Discord allows up to 10 embeds per message)
if len(embeds) <= 10:
@ -216,9 +214,8 @@ class LiveScorebugTracker:
channel: Discord text channel
"""
try:
# Clear all messages
async for message in channel.history(limit=25):
await message.delete()
# Clear all messages using bulk delete
await channel.purge(limit=25)
self.logger.info("Cleared live-sba-scores channel (no active games)")

View File

@ -0,0 +1,284 @@
"""
Tests for schedule service functionality.
Covers get_week_schedule, get_team_schedule, get_recent_games,
get_upcoming_games, and group_games_by_series verifying the
asyncio.gather parallelization and post-fetch filtering logic.
"""
import pytest
from unittest.mock import AsyncMock, patch
from services.schedule_service import ScheduleService
from tests.factories import GameFactory, TeamFactory
def _game(game_id, week, away_abbrev, home_abbrev, **kwargs):
    """Build a Game whose two teams get unique, non-colliding IDs."""
    base_team_id = game_id * 10
    away = TeamFactory.create(id=base_team_id, abbrev=away_abbrev)
    home = TeamFactory.create(id=base_team_id + 1, abbrev=home_abbrev)
    return GameFactory.create(
        id=game_id, week=week, away_team=away, home_team=home, **kwargs
    )
class TestGetWeekSchedule:
    """Tests for ScheduleService.get_week_schedule — the HTTP layer."""

    @pytest.fixture
    def service(self):
        """ScheduleService with a mocked client factory."""
        instance = ScheduleService()
        instance.get_client = AsyncMock()
        return instance

    @pytest.mark.asyncio
    async def test_success(self, service):
        """get_week_schedule returns parsed Game objects on a normal response."""
        away_payload = {
            "id": 10,
            "abbrev": "NYY",
            "sname": "NYY",
            "lname": "New York",
            "season": 12,
        }
        home_payload = {
            "id": 11,
            "abbrev": "BOS",
            "sname": "BOS",
            "lname": "Boston",
            "season": 12,
        }
        game_payload = {
            "id": 1,
            "season": 12,
            "week": 5,
            "game_num": 1,
            "season_type": "regular",
            "away_team": away_payload,
            "home_team": home_payload,
            "away_score": 4,
            "home_score": 2,
        }
        client = AsyncMock()
        client.get.return_value = {"games": [game_payload]}
        service.get_client.return_value = client

        games = await service.get_week_schedule(12, 5)

        assert len(games) == 1
        parsed = games[0]
        assert parsed.away_team.abbrev == "NYY"
        assert parsed.home_team.abbrev == "BOS"
        assert parsed.is_completed

    @pytest.mark.asyncio
    async def test_empty_response(self, service):
        """get_week_schedule returns [] when the API has no games."""
        client = AsyncMock()
        client.get.return_value = {"games": []}
        service.get_client.return_value = client

        assert await service.get_week_schedule(12, 99) == []

    @pytest.mark.asyncio
    async def test_api_error_returns_empty(self, service):
        """get_week_schedule returns [] on API error (no exception raised)."""
        service.get_client.side_effect = Exception("connection refused")

        assert await service.get_week_schedule(12, 1) == []

    @pytest.mark.asyncio
    async def test_missing_games_key(self, service):
        """get_week_schedule returns [] when response lacks 'games' key."""
        client = AsyncMock()
        client.get.return_value = {"status": "ok"}
        service.get_client.return_value = client

        assert await service.get_week_schedule(12, 1) == []
class TestGetTeamSchedule:
    """Tests for get_team_schedule — gather + team-abbrev filter."""

    @pytest.fixture
    def service(self):
        return ScheduleService()

    @pytest.mark.asyncio
    async def test_filters_by_team_case_insensitive(self, service):
        """get_team_schedule returns only games involving the requested team,
        regardless of abbreviation casing."""
        first_week = [
            _game(1, 1, "NYY", "BOS", away_score=3, home_score=1),
            _game(2, 1, "LAD", "CHC", away_score=5, home_score=2),
        ]
        second_week = [_game(3, 2, "BOS", "NYY", away_score=2, home_score=4)]
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.side_effect = [first_week, second_week]
            result = await service.get_team_schedule(12, "nyy", weeks=2)

        assert len(result) == 2
        # Every returned game must feature NYY on one side or the other.
        for game in result:
            assert "NYY" in (game.away_team.abbrev, game.home_team.abbrev)

    @pytest.mark.asyncio
    async def test_full_season_fetches_18_weeks(self, service):
        """When weeks is None, all 18 weeks are fetched via gather."""
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.return_value = []
            await service.get_team_schedule(12, "NYY")
            assert fetch.call_count == 18

    @pytest.mark.asyncio
    async def test_limited_weeks(self, service):
        """When weeks=5, only 5 weeks are fetched."""
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.return_value = []
            await service.get_team_schedule(12, "NYY", weeks=5)
            assert fetch.call_count == 5
class TestGetRecentGames:
    """Tests for get_recent_games — gather + completed-only filter."""

    @pytest.fixture
    def service(self):
        return ScheduleService()

    @pytest.mark.asyncio
    async def test_returns_only_completed_games(self, service):
        """get_recent_games filters out games without scores."""
        finished = GameFactory.completed(id=1, week=10)
        unplayed = GameFactory.upcoming(id=2, week=10)
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.return_value = [finished, unplayed]
            result = await service.get_recent_games(12, weeks_back=1)

        assert len(result) == 1
        assert result[0].is_completed

    @pytest.mark.asyncio
    async def test_sorted_descending_by_week_and_game_num(self, service):
        """Recent games are sorted most-recent first."""
        newer = GameFactory.completed(id=1, week=10, game_num=2)
        older = GameFactory.completed(id=2, week=9, game_num=1)
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.side_effect = [[newer], [older]]
            result = await service.get_recent_games(12, weeks_back=2)

        assert [g.week for g in result] == [10, 9]

    @pytest.mark.asyncio
    async def test_skips_negative_weeks(self, service):
        """Weeks that would be <= 0 are excluded from fetch."""
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.return_value = []
            await service.get_recent_games(12, weeks_back=15)
            # weeks_to_fetch = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] — only 10 valid weeks
            assert fetch.call_count == 10
class TestGetUpcomingGames:
    """Tests for get_upcoming_games — gather all 18 weeks + incomplete filter."""

    @pytest.fixture
    def service(self):
        return ScheduleService()

    @pytest.mark.asyncio
    async def test_returns_only_incomplete_games(self, service):
        """get_upcoming_games filters out completed games."""
        finished = GameFactory.completed(id=1, week=5)
        pending = GameFactory.upcoming(id=2, week=5)
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.return_value = [finished, pending]
            result = await service.get_upcoming_games(12)

        assert len(result) == 18  # 1 incomplete game per week × 18 weeks
        for game in result:
            assert not game.is_completed

    @pytest.mark.asyncio
    async def test_sorted_ascending_by_week_and_game_num(self, service):
        """Upcoming games are sorted earliest first."""
        later = GameFactory.upcoming(id=1, week=3, game_num=1)
        earlier = GameFactory.upcoming(id=2, week=1, game_num=2)
        games_by_week = {1: [earlier], 3: [later]}
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:

            def fetch_week(season, week):
                return games_by_week.get(week, [])

            fetch.side_effect = fetch_week
            result = await service.get_upcoming_games(12)

        assert result[0].week == 1
        assert result[1].week == 3

    @pytest.mark.asyncio
    async def test_fetches_all_18_weeks(self, service):
        """All 18 weeks are fetched in parallel (no early exit)."""
        with patch.object(
            service, "get_week_schedule", new_callable=AsyncMock
        ) as fetch:
            fetch.return_value = []
            await service.get_upcoming_games(12)
            assert fetch.call_count == 18
class TestGroupGamesBySeries:
    """Tests for group_games_by_series — synchronous grouping logic."""

    @pytest.fixture
    def service(self):
        return ScheduleService()

    def test_groups_by_alphabetical_pairing(self, service):
        """Games between the same two teams are grouped under one key,
        with the alphabetically-first team first in the tuple."""
        games = [
            _game(1, 1, "NYY", "BOS", game_num=1),
            _game(2, 1, "BOS", "NYY", game_num=2),
            _game(3, 1, "LAD", "CHC", game_num=1),
        ]
        grouped = service.group_games_by_series(games)

        assert ("BOS", "NYY") in grouped
        assert ("CHC", "LAD") in grouped
        assert len(grouped[("BOS", "NYY")]) == 2
        assert len(grouped[("CHC", "LAD")]) == 1

    def test_sorted_by_game_num_within_series(self, service):
        """Games within each series are sorted by game_num."""
        games = [
            _game(1, 1, "NYY", "BOS", game_num=3),
            _game(2, 1, "NYY", "BOS", game_num=1),
            _game(3, 1, "NYY", "BOS", game_num=2),
        ]
        grouped = service.group_games_by_series(games)

        assert [g.game_num for g in grouped[("BOS", "NYY")]] == [1, 2, 3]

    def test_empty_input(self, service):
        """Empty games list returns empty dict."""
        assert service.group_games_by_series([]) == {}

View File

@ -0,0 +1,111 @@
"""
Tests for StatsService
Validates stats service functionality including concurrent stat retrieval
and error handling in get_player_stats().
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from services.stats_service import StatsService
class TestStatsServiceGetPlayerStats:
    """Test StatsService.get_player_stats() concurrent retrieval."""

    @staticmethod
    def _stub_fetchers(service, batting, pitching):
        """Replace both stat fetchers with AsyncMocks returning the given values."""
        service.get_batting_stats = AsyncMock(return_value=batting)
        service.get_pitching_stats = AsyncMock(return_value=pitching)

    @pytest.fixture
    def service(self):
        """Create a fresh StatsService instance for testing."""
        return StatsService()

    @pytest.fixture
    def mock_batting_stats(self):
        """Create a mock BattingStats object."""
        batting = MagicMock()
        batting.avg = 0.300
        return batting

    @pytest.fixture
    def mock_pitching_stats(self):
        """Create a mock PitchingStats object."""
        pitching = MagicMock()
        pitching.era = 3.50
        return pitching

    @pytest.mark.asyncio
    async def test_both_stats_returned(
        self, service, mock_batting_stats, mock_pitching_stats
    ):
        """When both batting and pitching stats exist, both are returned.

        Verifies that get_player_stats returns a tuple of (batting, pitching)
        when both stat types are available for the player.
        """
        self._stub_fetchers(service, mock_batting_stats, mock_pitching_stats)

        batting, pitching = await service.get_player_stats(player_id=100, season=12)

        assert batting is mock_batting_stats
        assert pitching is mock_pitching_stats
        service.get_batting_stats.assert_called_once_with(100, 12)
        service.get_pitching_stats.assert_called_once_with(100, 12)

    @pytest.mark.asyncio
    async def test_batting_only(self, service, mock_batting_stats):
        """When only batting stats exist, pitching is None.

        Covers the case of a position player with no pitching record.
        """
        self._stub_fetchers(service, mock_batting_stats, None)

        batting, pitching = await service.get_player_stats(player_id=200, season=12)

        assert batting is mock_batting_stats
        assert pitching is None

    @pytest.mark.asyncio
    async def test_pitching_only(self, service, mock_pitching_stats):
        """When only pitching stats exist, batting is None.

        Covers the case of a pitcher with no batting record.
        """
        self._stub_fetchers(service, None, mock_pitching_stats)

        batting, pitching = await service.get_player_stats(player_id=300, season=12)

        assert batting is None
        assert pitching is mock_pitching_stats

    @pytest.mark.asyncio
    async def test_no_stats_found(self, service):
        """When no stats exist for the player, both are None.

        Covers the case where a player has no stats for the given season
        (e.g., didn't play).
        """
        self._stub_fetchers(service, None, None)

        batting, pitching = await service.get_player_stats(player_id=400, season=12)

        assert batting is None
        assert pitching is None

    @pytest.mark.asyncio
    async def test_exception_returns_none_tuple(self, service):
        """When an exception occurs, (None, None) is returned.

        The get_player_stats method wraps both calls in a try/except and
        returns (None, None) on any error, ensuring callers always get a tuple.
        """
        service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down"))
        service.get_pitching_stats = AsyncMock(return_value=None)

        batting, pitching = await service.get_player_stats(player_id=500, season=12)

        assert batting is None
        assert pitching is None

View File

@ -3,10 +3,16 @@ Tests for shared autocomplete utility functions.
Validates the shared autocomplete functions used across multiple command modules.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from utils.autocomplete import player_autocomplete, team_autocomplete, major_league_team_autocomplete
import utils.autocomplete
from utils.autocomplete import (
player_autocomplete,
team_autocomplete,
major_league_team_autocomplete,
)
from tests.factories import PlayerFactory, TeamFactory
from models.team import RosterType
@ -14,6 +20,13 @@ from models.team import RosterType
class TestPlayerAutocomplete:
"""Test player autocomplete functionality."""
@pytest.fixture(autouse=True)
def clear_user_team_cache(self):
"""Clear the module-level user team cache before each test to prevent interference."""
utils.autocomplete._user_team_cache.clear()
yield
utils.autocomplete._user_team_cache.clear()
@pytest.fixture
def mock_interaction(self):
"""Create a mock Discord interaction."""
@ -26,41 +39,43 @@ class TestPlayerAutocomplete:
"""Test successful player autocomplete."""
mock_players = [
PlayerFactory.mike_trout(id=1),
PlayerFactory.ronald_acuna(id=2)
PlayerFactory.ronald_acuna(id=2),
]
with patch('utils.autocomplete.player_service') as mock_service:
with patch("utils.autocomplete.player_service") as mock_service:
mock_service.search_players = AsyncMock(return_value=mock_players)
choices = await player_autocomplete(mock_interaction, 'Trout')
choices = await player_autocomplete(mock_interaction, "Trout")
assert len(choices) == 2
assert choices[0].name == 'Mike Trout (CF)'
assert choices[0].value == 'Mike Trout'
assert choices[1].name == 'Ronald Acuna Jr. (OF)'
assert choices[1].value == 'Ronald Acuna Jr.'
assert choices[0].name == "Mike Trout (CF)"
assert choices[0].value == "Mike Trout"
assert choices[1].name == "Ronald Acuna Jr. (OF)"
assert choices[1].value == "Ronald Acuna Jr."
@pytest.mark.asyncio
async def test_player_autocomplete_with_team_info(self, mock_interaction):
"""Test player autocomplete with team information."""
mock_team = TeamFactory.create(id=499, abbrev='LAA', sname='Angels', lname='Los Angeles Angels')
mock_team = TeamFactory.create(
id=499, abbrev="LAA", sname="Angels", lname="Los Angeles Angels"
)
mock_player = PlayerFactory.mike_trout(id=1)
mock_player.team = mock_team
with patch('utils.autocomplete.player_service') as mock_service:
with patch("utils.autocomplete.player_service") as mock_service:
mock_service.search_players = AsyncMock(return_value=[mock_player])
choices = await player_autocomplete(mock_interaction, 'Trout')
choices = await player_autocomplete(mock_interaction, "Trout")
assert len(choices) == 1
assert choices[0].name == 'Mike Trout (CF - LAA)'
assert choices[0].value == 'Mike Trout'
assert choices[0].name == "Mike Trout (CF - LAA)"
assert choices[0].value == "Mike Trout"
@pytest.mark.asyncio
async def test_player_autocomplete_prioritizes_user_team(self, mock_interaction):
"""Test that user's team players are prioritized in autocomplete."""
user_team = TeamFactory.create(id=1, abbrev='POR', sname='Loggers')
other_team = TeamFactory.create(id=2, abbrev='LAA', sname='Angels')
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
other_team = TeamFactory.create(id=2, abbrev="LAA", sname="Angels")
# Create players - one from user's team, one from other team
user_player = PlayerFactory.mike_trout(id=1)
@ -71,32 +86,35 @@ class TestPlayerAutocomplete:
other_player.team = other_team
other_player.team_id = other_team.id
with patch('utils.autocomplete.player_service') as mock_service, \
patch('utils.autocomplete.get_user_major_league_team') as mock_get_team:
mock_service.search_players = AsyncMock(return_value=[other_player, user_player])
with (
patch("utils.autocomplete.player_service") as mock_service,
patch("utils.autocomplete.get_user_major_league_team") as mock_get_team,
):
mock_service.search_players = AsyncMock(
return_value=[other_player, user_player]
)
mock_get_team.return_value = user_team
choices = await player_autocomplete(mock_interaction, 'player')
choices = await player_autocomplete(mock_interaction, "player")
assert len(choices) == 2
# User's team player should be first
assert choices[0].name == 'Mike Trout (CF - POR)'
assert choices[1].name == 'Ronald Acuna Jr. (OF - LAA)'
assert choices[0].name == "Mike Trout (CF - POR)"
assert choices[1].name == "Ronald Acuna Jr. (OF - LAA)"
@pytest.mark.asyncio
async def test_player_autocomplete_short_input(self, mock_interaction):
"""Test player autocomplete with short input returns empty."""
choices = await player_autocomplete(mock_interaction, 'T')
choices = await player_autocomplete(mock_interaction, "T")
assert len(choices) == 0
@pytest.mark.asyncio
async def test_player_autocomplete_error_handling(self, mock_interaction):
"""Test player autocomplete error handling."""
with patch('utils.autocomplete.player_service') as mock_service:
with patch("utils.autocomplete.player_service") as mock_service:
mock_service.search_players.side_effect = Exception("API Error")
choices = await player_autocomplete(mock_interaction, 'Trout')
choices = await player_autocomplete(mock_interaction, "Trout")
assert len(choices) == 0
@ -114,35 +132,35 @@ class TestTeamAutocomplete:
async def test_team_autocomplete_success(self, mock_interaction):
"""Test successful team autocomplete."""
mock_teams = [
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'),
TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'),
TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'),
TeamFactory.create(id=4, abbrev='POR', sname='Loggers')
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"),
TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"),
TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"),
TeamFactory.create(id=4, abbrev="POR", sname="Loggers"),
]
with patch('utils.autocomplete.team_service') as mock_service:
with patch("utils.autocomplete.team_service") as mock_service:
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
choices = await team_autocomplete(mock_interaction, 'la')
choices = await team_autocomplete(mock_interaction, "la")
assert len(choices) == 3 # All teams with 'la' in abbrev or sname
assert any('LAA' in choice.name for choice in choices)
assert any('LAAMIL' in choice.name for choice in choices)
assert any('LAAAIL' in choice.name for choice in choices)
assert any("LAA" in choice.name for choice in choices)
assert any("LAAMIL" in choice.name for choice in choices)
assert any("LAAAIL" in choice.name for choice in choices)
@pytest.mark.asyncio
async def test_team_autocomplete_short_input(self, mock_interaction):
"""Test team autocomplete with very short input."""
choices = await team_autocomplete(mock_interaction, '')
choices = await team_autocomplete(mock_interaction, "")
assert len(choices) == 0
@pytest.mark.asyncio
async def test_team_autocomplete_error_handling(self, mock_interaction):
"""Test team autocomplete error handling."""
with patch('utils.autocomplete.team_service') as mock_service:
with patch("utils.autocomplete.team_service") as mock_service:
mock_service.get_teams_by_season.side_effect = Exception("API Error")
choices = await team_autocomplete(mock_interaction, 'LAA')
choices = await team_autocomplete(mock_interaction, "LAA")
assert len(choices) == 0
@ -157,101 +175,197 @@ class TestMajorLeagueTeamAutocomplete:
return interaction
@pytest.mark.asyncio
async def test_major_league_team_autocomplete_filters_correctly(self, mock_interaction):
async def test_major_league_team_autocomplete_filters_correctly(
self, mock_interaction
):
"""Test that only major league teams are returned."""
# Create teams with different roster types
mock_teams = [
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'), # ML
TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'), # MiL
TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'), # IL
TeamFactory.create(id=4, abbrev='FA', sname='Free Agents'), # FA
TeamFactory.create(id=5, abbrev='POR', sname='Loggers'), # ML
TeamFactory.create(id=6, abbrev='PORMIL', sname='Portland MiL'), # MiL
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"), # ML
TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"), # MiL
TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"), # IL
TeamFactory.create(id=4, abbrev="FA", sname="Free Agents"), # FA
TeamFactory.create(id=5, abbrev="POR", sname="Loggers"), # ML
TeamFactory.create(id=6, abbrev="PORMIL", sname="Portland MiL"), # MiL
]
with patch('utils.autocomplete.team_service') as mock_service:
with patch("utils.autocomplete.team_service") as mock_service:
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
choices = await major_league_team_autocomplete(mock_interaction, 'l')
choices = await major_league_team_autocomplete(mock_interaction, "l")
# Should only return major league teams that match 'l' (LAA, POR)
choice_values = [choice.value for choice in choices]
assert 'LAA' in choice_values
assert 'POR' in choice_values
assert "LAA" in choice_values
assert "POR" in choice_values
assert len(choice_values) == 2
# Should NOT include MiL, IL, or FA teams
assert 'LAAMIL' not in choice_values
assert 'LAAAIL' not in choice_values
assert 'FA' not in choice_values
assert 'PORMIL' not in choice_values
assert "LAAMIL" not in choice_values
assert "LAAAIL" not in choice_values
assert "FA" not in choice_values
assert "PORMIL" not in choice_values
@pytest.mark.asyncio
async def test_major_league_team_autocomplete_matching(self, mock_interaction):
"""Test search matching on abbreviation and short name."""
mock_teams = [
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'),
TeamFactory.create(id=2, abbrev='LAD', sname='Dodgers'),
TeamFactory.create(id=3, abbrev='POR', sname='Loggers'),
TeamFactory.create(id=4, abbrev='BOS', sname='Red Sox'),
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"),
TeamFactory.create(id=2, abbrev="LAD", sname="Dodgers"),
TeamFactory.create(id=3, abbrev="POR", sname="Loggers"),
TeamFactory.create(id=4, abbrev="BOS", sname="Red Sox"),
]
with patch('utils.autocomplete.team_service') as mock_service:
with patch("utils.autocomplete.team_service") as mock_service:
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
# Test abbreviation matching
choices = await major_league_team_autocomplete(mock_interaction, 'la')
choices = await major_league_team_autocomplete(mock_interaction, "la")
assert len(choices) == 2 # LAA and LAD
choice_values = [choice.value for choice in choices]
assert 'LAA' in choice_values
assert 'LAD' in choice_values
assert "LAA" in choice_values
assert "LAD" in choice_values
# Test short name matching
choices = await major_league_team_autocomplete(mock_interaction, 'red')
choices = await major_league_team_autocomplete(mock_interaction, "red")
assert len(choices) == 1
assert choices[0].value == 'BOS'
assert choices[0].value == "BOS"
@pytest.mark.asyncio
async def test_major_league_team_autocomplete_short_input(self, mock_interaction):
"""Test major league team autocomplete with very short input."""
choices = await major_league_team_autocomplete(mock_interaction, '')
choices = await major_league_team_autocomplete(mock_interaction, "")
assert len(choices) == 0
@pytest.mark.asyncio
async def test_major_league_team_autocomplete_error_handling(self, mock_interaction):
async def test_major_league_team_autocomplete_error_handling(
self, mock_interaction
):
"""Test major league team autocomplete error handling."""
with patch('utils.autocomplete.team_service') as mock_service:
with patch("utils.autocomplete.team_service") as mock_service:
mock_service.get_teams_by_season.side_effect = Exception("API Error")
choices = await major_league_team_autocomplete(mock_interaction, 'LAA')
choices = await major_league_team_autocomplete(mock_interaction, "LAA")
assert len(choices) == 0
@pytest.mark.asyncio
async def test_major_league_team_autocomplete_roster_type_detection(self, mock_interaction):
async def test_major_league_team_autocomplete_roster_type_detection(
self, mock_interaction
):
"""Test that roster type detection works correctly for edge cases."""
# Test edge cases like teams whose abbreviation ends in 'M' + 'IL'
mock_teams = [
TeamFactory.create(id=1, abbrev='BHM', sname='Iron'), # ML team ending in 'M'
TeamFactory.create(id=2, abbrev='BHMIL', sname='Iron IL'), # IL team (BHM + IL)
TeamFactory.create(id=3, abbrev='NYYMIL', sname='Staten Island RailRiders'), # MiL team (NYY + MIL)
TeamFactory.create(id=4, abbrev='NYY', sname='Yankees'), # ML team
TeamFactory.create(
id=1, abbrev="BHM", sname="Iron"
), # ML team ending in 'M'
TeamFactory.create(
id=2, abbrev="BHMIL", sname="Iron IL"
), # IL team (BHM + IL)
TeamFactory.create(
id=3, abbrev="NYYMIL", sname="Staten Island RailRiders"
), # MiL team (NYY + MIL)
TeamFactory.create(id=4, abbrev="NYY", sname="Yankees"), # ML team
]
with patch('utils.autocomplete.team_service') as mock_service:
with patch("utils.autocomplete.team_service") as mock_service:
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
choices = await major_league_team_autocomplete(mock_interaction, 'b')
choices = await major_league_team_autocomplete(mock_interaction, "b")
# Should only return major league teams
choice_values = [choice.value for choice in choices]
assert 'BHM' in choice_values # Major league team
assert 'BHMIL' not in choice_values # Should be detected as IL, not MiL
assert 'NYYMIL' not in choice_values # Minor league team
assert "BHM" in choice_values # Major league team
assert "BHMIL" not in choice_values # Should be detected as IL, not MiL
assert "NYYMIL" not in choice_values # Minor league team
# Verify the roster type detection is working
bhm_team = next(t for t in mock_teams if t.abbrev == 'BHM')
bhmil_team = next(t for t in mock_teams if t.abbrev == 'BHMIL')
nyymil_team = next(t for t in mock_teams if t.abbrev == 'NYYMIL')
bhm_team = next(t for t in mock_teams if t.abbrev == "BHM")
bhmil_team = next(t for t in mock_teams if t.abbrev == "BHMIL")
nyymil_team = next(t for t in mock_teams if t.abbrev == "NYYMIL")
assert bhm_team.roster_type() == RosterType.MAJOR_LEAGUE
assert bhmil_team.roster_type() == RosterType.INJURED_LIST
assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE
assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE
class TestGetCachedUserTeam:
    """Test the _get_cached_user_team caching helper.

    Verifies that the cache avoids redundant get_user_major_league_team calls
    on repeated invocations within the TTL window, and that expired entries are
    re-fetched.
    """

    @pytest.fixture(autouse=True)
    def clear_cache(self):
        """Isolate each test from cache state left by other tests."""
        utils.autocomplete._user_team_cache.clear()
        yield
        utils.autocomplete._user_team_cache.clear()

    @pytest.fixture
    def mock_interaction(self):
        fake = MagicMock()
        fake.user.id = 99999
        return fake

    @pytest.mark.asyncio
    async def test_caches_result_on_first_call(self, mock_interaction):
        """First call populates the cache; API function called exactly once."""
        from utils.autocomplete import _get_cached_user_team

        team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
        with patch(
            "utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
        ) as api:
            api.return_value = team
            first = await _get_cached_user_team(mock_interaction)
            second = await _get_cached_user_team(mock_interaction)

            assert first is team
            assert second is team
            # API called only once despite two invocations
            api.assert_called_once_with(99999)

    @pytest.mark.asyncio
    async def test_re_fetches_after_ttl_expires(self, mock_interaction):
        """Expired cache entries cause a fresh API call."""
        import time

        from utils.autocomplete import _USER_TEAM_CACHE_TTL, _get_cached_user_team

        team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
        with patch(
            "utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
        ) as api:
            api.return_value = team
            # Seed the cache with a timestamp that is already expired
            stale_at = time.time() - _USER_TEAM_CACHE_TTL - 1
            utils.autocomplete._user_team_cache[99999] = (team, stale_at)

            await _get_cached_user_team(mock_interaction)

            # Should have called the API to refresh the stale entry
            api.assert_called_once_with(99999)

    @pytest.mark.asyncio
    async def test_caches_none_result(self, mock_interaction):
        """None (user has no team) is cached to avoid repeated API calls."""
        from utils.autocomplete import _get_cached_user_team

        with patch(
            "utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
        ) as api:
            api.return_value = None

            assert await _get_cached_user_team(mock_interaction) is None
            assert await _get_cached_user_team(mock_interaction) is None
            api.assert_called_once()

View File

@ -4,16 +4,33 @@ Autocomplete Utilities
Shared autocomplete functions for Discord slash commands.
"""
from typing import List
import time
from typing import Dict, List, Optional, Tuple
import discord
from discord import app_commands
from config import get_config
from models.team import RosterType
from models.team import RosterType, Team
from services.player_service import player_service
from services.team_service import team_service
from utils.team_utils import get_user_major_league_team
# Cache for user team lookups: user_id -> (team, cached_at)
# NOTE: shared module-level state; entries are pruned opportunistically in
# _get_cached_user_team so the dict cannot grow without bound.
_user_team_cache: Dict[int, Tuple[Optional[Team], float]] = {}
_USER_TEAM_CACHE_TTL = 60  # seconds


async def _get_cached_user_team(interaction: discord.Interaction) -> Optional[Team]:
    """Return the user's major league team, cached for 60 seconds per user.

    A ``None`` result (user has no team) is cached as well, so repeated
    autocomplete keystrokes from team-less users do not hammer the API.

    Args:
        interaction: The Discord interaction; only ``interaction.user.id``
            is read.

    Returns:
        The user's major league ``Team``, or ``None`` if they have none.
    """
    user_id = interaction.user.id
    now = time.time()

    entry = _user_team_cache.get(user_id)
    if entry is not None:
        team, cached_at = entry
        if now - cached_at < _USER_TEAM_CACHE_TTL:
            return team

    # Cache miss (or stale hit): evict every expired entry so the cache does
    # not accumulate one tuple per user forever in a long-running process.
    expired = [
        uid
        for uid, (_, cached_at) in _user_team_cache.items()
        if now - cached_at >= _USER_TEAM_CACHE_TTL
    ]
    for uid in expired:
        del _user_team_cache[uid]

    team = await get_user_major_league_team(user_id)
    _user_team_cache[user_id] = (team, now)
    return team
async def player_autocomplete(
interaction: discord.Interaction, current: str
@ -34,12 +51,12 @@ async def player_autocomplete(
return []
try:
# Get user's team for prioritization
user_team = await get_user_major_league_team(interaction.user.id)
# Get user's team for prioritization (cached per user, 60s TTL)
user_team = await _get_cached_user_team(interaction)
# Search for players using the search endpoint
players = await player_service.search_players(
current, limit=50, season=get_config().sba_season
current, limit=25, season=get_config().sba_season
)
# Separate players by team (user's team vs others)

View File

@ -188,9 +188,11 @@ class CacheManager:
try:
pattern = f"{prefix}:*"
keys = await client.keys(pattern)
if keys:
deleted = await client.delete(*keys)
keys_to_delete = []
async for key in client.scan_iter(match=pattern):
keys_to_delete.append(key)
if keys_to_delete:
deleted = await client.delete(*keys_to_delete)
logger.info(f"Cleared {deleted} cache keys with prefix '{prefix}'")
return deleted
except Exception as e:

View File

@ -11,29 +11,29 @@ from functools import wraps
from typing import List, Optional, Callable, Any
from utils.logging import set_discord_context, get_contextual_logger
cache_logger = logging.getLogger(f'{__name__}.CacheDecorators')
period_check_logger = logging.getLogger(f'{__name__}.PeriodCheckDecorators')
cache_logger = logging.getLogger(f"{__name__}.CacheDecorators")
period_check_logger = logging.getLogger(f"{__name__}.PeriodCheckDecorators")
def logged_command(
command_name: Optional[str] = None,
command_name: Optional[str] = None,
log_params: bool = True,
exclude_params: Optional[List[str]] = None
exclude_params: Optional[List[str]] = None,
):
"""
Decorator for Discord commands that adds comprehensive logging.
This decorator automatically handles:
- Setting Discord context with interaction details
- Starting/ending operation timing
- Logging command start/completion/failure
- Preserving function metadata and signature
Args:
command_name: Override command name (defaults to function name with slashes)
log_params: Whether to log command parameters (default: True)
exclude_params: List of parameter names to exclude from logging
Example:
@logged_command("/roster", exclude_params=["sensitive_data"])
async def team_roster(self, interaction, team_name: str, season: int = None):
@ -42,57 +42,65 @@ def logged_command(
players = await team_service.get_roster(team.id, season)
embed = create_roster_embed(team, players)
await interaction.followup.send(embed=embed)
Side Effects:
- Automatically sets Discord context for all subsequent log entries
- Creates trace_id for request correlation
- Logs command execution timing and results
- Re-raises all exceptions after logging (preserves original behavior)
Requirements:
- The decorated class must have a 'logger' attribute, or one will be created
- Function must be an async method with (self, interaction, ...) signature
- Preserves Discord.py command registration compatibility
"""
def decorator(func):
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())[2:] # Skip self, interaction
exclude_set = set(exclude_params or [])
@wraps(func)
async def wrapper(self, interaction, *args, **kwargs):
# Auto-detect command name if not provided
cmd_name = command_name or f"/{func.__name__.replace('_', '-')}"
# Build context with safe parameter logging
context = {"command": cmd_name}
if log_params:
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())[2:] # Skip self, interaction
exclude_set = set(exclude_params or [])
for i, (name, value) in enumerate(zip(param_names, args)):
if name not in exclude_set:
context[f"param_{name}"] = value
set_discord_context(interaction=interaction, **context)
# Get logger from the class instance or create one
logger = getattr(self, 'logger', get_contextual_logger(f'{self.__class__.__module__}.{self.__class__.__name__}'))
logger = getattr(
self,
"logger",
get_contextual_logger(
f"{self.__class__.__module__}.{self.__class__.__name__}"
),
)
trace_id = logger.start_operation(f"{func.__name__}_command")
try:
logger.info(f"{cmd_name} command started")
result = await func(self, interaction, *args, **kwargs)
logger.info(f"{cmd_name} command completed successfully")
logger.end_operation(trace_id, "completed")
return result
except Exception as e:
logger.error(f"{cmd_name} command failed", error=e)
logger.end_operation(trace_id, "failed")
# Re-raise to maintain original exception handling behavior
raise
# Preserve signature for Discord.py command registration
wrapper.__signature__ = inspect.signature(func) # type: ignore
wrapper.__signature__ = sig # type: ignore
return wrapper
return decorator
@ -122,6 +130,7 @@ def requires_draft_period(func):
- Should be placed before @logged_command decorator
- league_service must be available via import
"""
@wraps(func)
async def wrapper(self, interaction, *args, **kwargs):
# Import here to avoid circular imports
@ -133,10 +142,12 @@ def requires_draft_period(func):
current = await league_service.get_current_state()
if not current:
period_check_logger.error("Could not retrieve league state for draft period check")
period_check_logger.error(
"Could not retrieve league state for draft period check"
)
embed = EmbedTemplate.error(
"System Error",
"Could not verify draft period status. Please try again later."
"Could not verify draft period status. Please try again later.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -148,12 +159,12 @@ def requires_draft_period(func):
extra={
"user_id": interaction.user.id,
"command": func.__name__,
"current_week": current.week
}
"current_week": current.week,
},
)
embed = EmbedTemplate.error(
"Not Available",
"Draft commands are only available in the offseason."
"Draft commands are only available in the offseason.",
)
await interaction.response.send_message(embed=embed, ephemeral=True)
return
@ -161,7 +172,7 @@ def requires_draft_period(func):
# Week <= 0, allow command to proceed
period_check_logger.debug(
f"Draft period check passed - week {current.week}",
extra={"user_id": interaction.user.id, "command": func.__name__}
extra={"user_id": interaction.user.id, "command": func.__name__},
)
return await func(self, interaction, *args, **kwargs)
@ -169,7 +180,7 @@ def requires_draft_period(func):
period_check_logger.error(
f"Error in draft period check: {e}",
exc_info=True,
extra={"user_id": interaction.user.id, "command": func.__name__}
extra={"user_id": interaction.user.id, "command": func.__name__},
)
# Re-raise to let error handling in logged_command handle it
raise
@ -182,110 +193,115 @@ def requires_draft_period(func):
def cached_api_call(ttl: Optional[int] = None, cache_key_suffix: str = ""):
"""
Decorator to add Redis caching to service methods that return List[T].
This decorator will:
1. Check cache for existing data using generated key
2. Return cached data if found
3. Execute original method if cache miss
4. Cache the result for future calls
Args:
ttl: Time-to-live override in seconds (uses service default if None)
cache_key_suffix: Additional suffix for cache key differentiation
Usage:
@cached_api_call(ttl=600, cache_key_suffix="by_season")
async def get_teams_by_season(self, season: int) -> List[Team]:
# Original method implementation
Requirements:
- Method must be async
- Method must return List[T] where T is a model
- Class must have self.cache (CacheManager instance)
- Class must have self._generate_cache_key, self._get_cached_items, self._cache_items methods
"""
def decorator(func: Callable) -> Callable:
sig = inspect.signature(func)
@wraps(func)
async def wrapper(self, *args, **kwargs) -> List[Any]:
# Check if caching is available (service has cache manager)
if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'):
if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"):
# No caching available, execute original method
return await func(self, *args, **kwargs)
# Generate cache key from method name, args, and kwargs
method_name = f"{func.__name__}{cache_key_suffix}"
# Convert args and kwargs to params list for consistent cache key
sig = inspect.signature(func)
bound_args = sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
# Skip 'self' and convert to params format
params = []
for param_name, param_value in bound_args.arguments.items():
if param_name != 'self' and param_value is not None:
if param_name != "self" and param_value is not None:
params.append((param_name, param_value))
cache_key = self._generate_cache_key(method_name, params)
# Try to get from cache
if hasattr(self, '_get_cached_items'):
if hasattr(self, "_get_cached_items"):
cached_result = await self._get_cached_items(cache_key)
if cached_result is not None:
cache_logger.debug(f"Cache hit: {method_name}")
return cached_result
# Cache miss - execute original method
cache_logger.debug(f"Cache miss: {method_name}")
result = await func(self, *args, **kwargs)
# Cache the result if we have items and caching methods
if result and hasattr(self, '_cache_items'):
if result and hasattr(self, "_cache_items"):
await self._cache_items(cache_key, result, ttl)
cache_logger.debug(f"Cached {len(result)} items for {method_name}")
return result
return wrapper
return decorator
def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
"""
Decorator to add Redis caching to service methods that return Optional[T].
Similar to cached_api_call but for methods returning a single model instance.
Args:
ttl: Time-to-live override in seconds
cache_key_suffix: Additional suffix for cache key differentiation
Usage:
@cached_single_item(ttl=300, cache_key_suffix="by_id")
async def get_player(self, player_id: int) -> Optional[Player]:
# Original method implementation
"""
def decorator(func: Callable) -> Callable:
sig = inspect.signature(func)
@wraps(func)
async def wrapper(self, *args, **kwargs) -> Optional[Any]:
# Check if caching is available
if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'):
if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"):
return await func(self, *args, **kwargs)
# Generate cache key
method_name = f"{func.__name__}{cache_key_suffix}"
sig = inspect.signature(func)
bound_args = sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
params = []
for param_name, param_value in bound_args.arguments.items():
if param_name != 'self' and param_value is not None:
if param_name != "self" and param_value is not None:
params.append((param_name, param_value))
cache_key = self._generate_cache_key(method_name, params)
# Try cache first
try:
cached_data = await self.cache.get(cache_key)
@ -293,12 +309,14 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
cache_logger.debug(f"Cache hit: {method_name}")
return self.model_class.from_api_data(cached_data)
except Exception as e:
cache_logger.warning(f"Error reading single item cache for {cache_key}: {e}")
cache_logger.warning(
f"Error reading single item cache for {cache_key}: {e}"
)
# Cache miss - execute original method
cache_logger.debug(f"Cache miss: {method_name}")
result = await func(self, *args, **kwargs)
# Cache the single result
if result:
try:
@ -306,43 +324,54 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
await self.cache.set(cache_key, cache_data, ttl)
cache_logger.debug(f"Cached single item for {method_name}")
except Exception as e:
cache_logger.warning(f"Error caching single item for {cache_key}: {e}")
cache_logger.warning(
f"Error caching single item for {cache_key}: {e}"
)
return result
return wrapper
return decorator
def cache_invalidate(*cache_patterns: str):
"""
Decorator to invalidate cache entries when data is modified.
Args:
cache_patterns: Cache key patterns to invalidate (supports prefix matching)
Usage:
@cache_invalidate("players_by_team", "teams_by_season")
async def update_player(self, player_id: int, updates: dict) -> Optional[Player]:
# Original method implementation
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, *args, **kwargs):
# Execute original method first
result = await func(self, *args, **kwargs)
# Invalidate specified cache patterns
if hasattr(self, 'cache'):
if hasattr(self, "cache"):
for pattern in cache_patterns:
try:
cleared = await self.cache.clear_prefix(f"sba:{self.endpoint}_{pattern}")
cleared = await self.cache.clear_prefix(
f"sba:{self.endpoint}_{pattern}"
)
if cleared > 0:
cache_logger.info(f"Invalidated {cleared} cache entries for pattern: {pattern}")
cache_logger.info(
f"Invalidated {cleared} cache entries for pattern: {pattern}"
)
except Exception as e:
cache_logger.warning(f"Error invalidating cache pattern {pattern}: {e}")
cache_logger.warning(
f"Error invalidating cache pattern {pattern}: {e}"
)
return result
return wrapper
return decorator
return decorator

View File

@ -24,6 +24,8 @@ JSONValue = Union[
str, int, float, bool, None, dict[str, Any], list[Any] # nested object # arrays
]
_SERIALIZABLE_TYPES = (str, int, float, bool, type(None))
class JSONFormatter(logging.Formatter):
"""Custom JSON formatter for structured file logging."""
@ -93,11 +95,11 @@ class JSONFormatter(logging.Formatter):
extra_data = {}
for key, value in record.__dict__.items():
if key not in excluded_keys:
# Ensure JSON serializable
try:
json.dumps(value)
if isinstance(value, _SERIALIZABLE_TYPES) or isinstance(
value, (list, dict)
):
extra_data[key] = value
except (TypeError, ValueError):
else:
extra_data[key] = str(value)
if extra_data: