Merge pull request 'Merge next-release into main' (#111) from next-release into main
All checks were successful
Build Docker Image / build (push) Successful in 1m2s
All checks were successful
Build Docker Image / build (push) Successful in 1m2s
Reviewed-on: #111
This commit is contained in:
commit
fd24a41422
@ -1,9 +1,9 @@
|
||||
# Gitea Actions: Docker Build, Push, and Notify
|
||||
#
|
||||
# CI/CD pipeline for Major Domo Discord Bot:
|
||||
# - Builds Docker images on every push/PR
|
||||
# - Builds Docker images on merge to main/next-release
|
||||
# - Auto-generates CalVer version (YYYY.MM.BUILD) on main branch merges
|
||||
# - Supports multi-channel releases: stable (main), rc (next-release), dev (PRs)
|
||||
# - Supports multi-channel releases: stable (main), rc (next-release)
|
||||
# - Pushes to Docker Hub and creates git tag on main
|
||||
# - Sends Discord notifications on success/failure
|
||||
|
||||
@ -14,9 +14,6 @@ on:
|
||||
branches:
|
||||
- main
|
||||
- next-release
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -218,5 +218,6 @@ __marimo__/
|
||||
|
||||
# Project-specific
|
||||
data/
|
||||
storage/
|
||||
production_logs/
|
||||
*.json
|
||||
|
||||
25
bot.py
25
bot.py
@ -42,7 +42,9 @@ def setup_logging():
|
||||
|
||||
# JSON file handler - structured logging for monitoring and analysis
|
||||
json_handler = RotatingFileHandler(
|
||||
"logs/discord_bot_v2.json", maxBytes=5 * 1024 * 1024, backupCount=5 # 5MB
|
||||
"logs/discord_bot_v2.json",
|
||||
maxBytes=5 * 1024 * 1024,
|
||||
backupCount=5, # 5MB
|
||||
)
|
||||
json_handler.setFormatter(JSONFormatter())
|
||||
logger.addHandler(json_handler)
|
||||
@ -120,28 +122,11 @@ class SBABot(commands.Bot):
|
||||
|
||||
self.maintenance_mode: bool = False
|
||||
self.logger = logging.getLogger("discord_bot_v2")
|
||||
self.maintenance_mode: bool = False
|
||||
|
||||
async def setup_hook(self):
|
||||
"""Called when the bot is starting up."""
|
||||
self.logger.info("Setting up bot...")
|
||||
|
||||
@self.tree.interaction_check
|
||||
async def maintenance_check(interaction: discord.Interaction) -> bool:
|
||||
"""Block non-admin users when maintenance mode is enabled."""
|
||||
if not self.maintenance_mode:
|
||||
return True
|
||||
if (
|
||||
isinstance(interaction.user, discord.Member)
|
||||
and interaction.user.guild_permissions.administrator
|
||||
):
|
||||
return True
|
||||
await interaction.response.send_message(
|
||||
"🔧 The bot is currently in maintenance mode. Please try again later.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return False
|
||||
|
||||
# Load command packages
|
||||
await self._load_command_packages()
|
||||
|
||||
@ -443,7 +428,9 @@ async def health_command(interaction: discord.Interaction):
|
||||
embed.add_field(name="Bot Status", value="✅ Online", inline=True)
|
||||
embed.add_field(name="API Status", value=api_status, inline=True)
|
||||
embed.add_field(name="Guilds", value=str(guild_count), inline=True)
|
||||
embed.add_field(name="Latency", value=f"{bot.latency*1000:.1f}ms", inline=True)
|
||||
embed.add_field(
|
||||
name="Latency", value=f"{bot.latency * 1000:.1f}ms", inline=True
|
||||
)
|
||||
|
||||
if bot.user:
|
||||
embed.set_footer(
|
||||
|
||||
@ -568,14 +568,9 @@ class AdminCommands(commands.Cog):
|
||||
return
|
||||
|
||||
try:
|
||||
# Clear all messages from the channel
|
||||
deleted_count = 0
|
||||
async for message in live_scores_channel.history(limit=100):
|
||||
try:
|
||||
await message.delete()
|
||||
deleted_count += 1
|
||||
except discord.NotFound:
|
||||
pass # Message already deleted
|
||||
# Clear all messages from the channel using bulk delete
|
||||
deleted_messages = await live_scores_channel.purge(limit=100)
|
||||
deleted_count = len(deleted_messages)
|
||||
|
||||
self.logger.info(f"Cleared {deleted_count} messages from #live-sba-scores")
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Scorebug Commands
|
||||
Implements commands for publishing and displaying live game scorebugs from Google Sheets scorecards.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord import app_commands
|
||||
@ -73,12 +74,18 @@ class ScorebugCommands(commands.Cog):
|
||||
return
|
||||
|
||||
# Get team data for display
|
||||
away_team = None
|
||||
home_team = None
|
||||
if scorebug_data.away_team_id:
|
||||
away_team = await team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.home_team_id:
|
||||
home_team = await team_service.get_team(scorebug_data.home_team_id)
|
||||
away_team, home_team = await asyncio.gather(
|
||||
(
|
||||
team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.away_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
(
|
||||
team_service.get_team(scorebug_data.home_team_id)
|
||||
if scorebug_data.home_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
)
|
||||
|
||||
# Format scorecard link
|
||||
away_abbrev = away_team.abbrev if away_team else "AWAY"
|
||||
@ -86,7 +93,7 @@ class ScorebugCommands(commands.Cog):
|
||||
scorecard_link = f"[{away_abbrev} @ {home_abbrev}]({url})"
|
||||
|
||||
# Store the scorecard in the tracker
|
||||
self.scorecard_tracker.publish_scorecard(
|
||||
await self.scorecard_tracker.publish_scorecard(
|
||||
text_channel_id=interaction.channel_id, # type: ignore
|
||||
sheet_url=url,
|
||||
publisher_id=interaction.user.id,
|
||||
@ -157,7 +164,7 @@ class ScorebugCommands(commands.Cog):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Check if a scorecard is published in this channel
|
||||
sheet_url = self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore
|
||||
sheet_url = await self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore
|
||||
|
||||
if not sheet_url:
|
||||
embed = EmbedTemplate.error(
|
||||
@ -179,12 +186,18 @@ class ScorebugCommands(commands.Cog):
|
||||
)
|
||||
|
||||
# Get team data
|
||||
away_team = None
|
||||
home_team = None
|
||||
if scorebug_data.away_team_id:
|
||||
away_team = await team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.home_team_id:
|
||||
home_team = await team_service.get_team(scorebug_data.home_team_id)
|
||||
away_team, home_team = await asyncio.gather(
|
||||
(
|
||||
team_service.get_team(scorebug_data.away_team_id)
|
||||
if scorebug_data.away_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
(
|
||||
team_service.get_team(scorebug_data.home_team_id)
|
||||
if scorebug_data.home_team_id
|
||||
else asyncio.sleep(0)
|
||||
),
|
||||
)
|
||||
|
||||
# Create scorebug embed using shared utility
|
||||
embed = create_scorebug_embed(
|
||||
@ -194,7 +207,7 @@ class ScorebugCommands(commands.Cog):
|
||||
await interaction.edit_original_response(content=None, embed=embed)
|
||||
|
||||
# Update timestamp in tracker
|
||||
self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore
|
||||
await self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore
|
||||
|
||||
except SheetsException as e:
|
||||
embed = EmbedTemplate.error(
|
||||
|
||||
@ -24,7 +24,7 @@ class ScorecardTracker:
|
||||
- Timestamp tracking for monitoring
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/scorecards.json"):
|
||||
def __init__(self, data_file: str = "storage/scorecards.json"):
|
||||
"""
|
||||
Initialize the scorecard tracker.
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ The injury rating format (#p##) encodes both games played and rating:
|
||||
- Remaining: Injury rating (p70, p65, p60, p50, p40, p30, p20)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import random
|
||||
import discord
|
||||
@ -114,16 +115,14 @@ class InjuryGroup(app_commands.Group):
|
||||
"""Roll for injury using 3d6 dice and injury tables."""
|
||||
await interaction.response.defer()
|
||||
|
||||
# Get current season
|
||||
current = await league_service.get_current_state()
|
||||
# Get current season and search for player in parallel
|
||||
current, players = await asyncio.gather(
|
||||
league_service.get_current_state(),
|
||||
player_service.search_players(player_name, limit=10),
|
||||
)
|
||||
if not current:
|
||||
raise BotException("Failed to get current season information")
|
||||
|
||||
# Search for player using the search endpoint (more reliable than name param)
|
||||
players = await player_service.search_players(
|
||||
player_name, limit=10, season=current.season
|
||||
)
|
||||
|
||||
if not players:
|
||||
embed = EmbedTemplate.error(
|
||||
title="Player Not Found",
|
||||
@ -530,16 +529,14 @@ class InjuryGroup(app_commands.Group):
|
||||
await interaction.followup.send(embed=embed, ephemeral=True)
|
||||
return
|
||||
|
||||
# Get current season
|
||||
current = await league_service.get_current_state()
|
||||
# Get current season and search for player in parallel
|
||||
current, players = await asyncio.gather(
|
||||
league_service.get_current_state(),
|
||||
player_service.search_players(player_name, limit=10),
|
||||
)
|
||||
if not current:
|
||||
raise BotException("Failed to get current season information")
|
||||
|
||||
# Search for player using the search endpoint (more reliable than name param)
|
||||
players = await player_service.search_players(
|
||||
player_name, limit=10, season=current.season
|
||||
)
|
||||
|
||||
if not players:
|
||||
embed = EmbedTemplate.error(
|
||||
title="Player Not Found",
|
||||
@ -717,16 +714,14 @@ class InjuryGroup(app_commands.Group):
|
||||
|
||||
await interaction.response.defer()
|
||||
|
||||
# Get current season
|
||||
current = await league_service.get_current_state()
|
||||
# Get current season and search for player in parallel
|
||||
current, players = await asyncio.gather(
|
||||
league_service.get_current_state(),
|
||||
player_service.search_players(player_name, limit=10),
|
||||
)
|
||||
if not current:
|
||||
raise BotException("Failed to get current season information")
|
||||
|
||||
# Search for player using the search endpoint (more reliable than name param)
|
||||
players = await player_service.search_players(
|
||||
player_name, limit=10, season=current.season
|
||||
)
|
||||
|
||||
if not players:
|
||||
embed = EmbedTemplate.error(
|
||||
title="Player Not Found",
|
||||
|
||||
@ -3,6 +3,7 @@ League Schedule Commands
|
||||
|
||||
Implements slash commands for displaying game schedules and results.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
@ -19,19 +20,16 @@ from views.embeds import EmbedColors, EmbedTemplate
|
||||
|
||||
class ScheduleCommands(commands.Cog):
|
||||
"""League schedule command handlers."""
|
||||
|
||||
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.ScheduleCommands')
|
||||
|
||||
@discord.app_commands.command(
|
||||
name="schedule",
|
||||
description="Display game schedule"
|
||||
)
|
||||
self.logger = get_contextual_logger(f"{__name__}.ScheduleCommands")
|
||||
|
||||
@discord.app_commands.command(name="schedule", description="Display game schedule")
|
||||
@discord.app_commands.describe(
|
||||
season="Season to show schedule for (defaults to current season)",
|
||||
week="Week number to show (optional)",
|
||||
team="Team abbreviation to filter by (optional)"
|
||||
team="Team abbreviation to filter by (optional)",
|
||||
)
|
||||
@requires_team()
|
||||
@logged_command("/schedule")
|
||||
@ -40,13 +38,13 @@ class ScheduleCommands(commands.Cog):
|
||||
interaction: discord.Interaction,
|
||||
season: Optional[int] = None,
|
||||
week: Optional[int] = None,
|
||||
team: Optional[str] = None
|
||||
team: Optional[str] = None,
|
||||
):
|
||||
"""Display game schedule for a week or team."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
search_season = season or get_config().sba_season
|
||||
|
||||
|
||||
if team:
|
||||
# Show team schedule
|
||||
await self._show_team_schedule(interaction, search_season, team, week)
|
||||
@ -56,7 +54,7 @@ class ScheduleCommands(commands.Cog):
|
||||
else:
|
||||
# Show recent/upcoming games
|
||||
await self._show_current_schedule(interaction, search_season)
|
||||
|
||||
|
||||
# @discord.app_commands.command(
|
||||
# name="results",
|
||||
# description="Display recent game results"
|
||||
@ -74,282 +72,316 @@ class ScheduleCommands(commands.Cog):
|
||||
# ):
|
||||
# """Display recent game results."""
|
||||
# await interaction.response.defer()
|
||||
|
||||
|
||||
# search_season = season or get_config().sba_season
|
||||
|
||||
|
||||
# if week:
|
||||
# # Show specific week results
|
||||
# games = await schedule_service.get_week_schedule(search_season, week)
|
||||
# completed_games = [game for game in games if game.is_completed]
|
||||
|
||||
|
||||
# if not completed_games:
|
||||
# await interaction.followup.send(
|
||||
# f"❌ No completed games found for season {search_season}, week {week}.",
|
||||
# ephemeral=True
|
||||
# )
|
||||
# return
|
||||
|
||||
|
||||
# embed = await self._create_week_results_embed(completed_games, search_season, week)
|
||||
# await interaction.followup.send(embed=embed)
|
||||
# else:
|
||||
# # Show recent results
|
||||
# recent_games = await schedule_service.get_recent_games(search_season)
|
||||
|
||||
|
||||
# if not recent_games:
|
||||
# await interaction.followup.send(
|
||||
# f"❌ No recent games found for season {search_season}.",
|
||||
# ephemeral=True
|
||||
# )
|
||||
# return
|
||||
|
||||
|
||||
# embed = await self._create_recent_results_embed(recent_games, search_season)
|
||||
# await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _show_week_schedule(self, interaction: discord.Interaction, season: int, week: int):
|
||||
|
||||
async def _show_week_schedule(
|
||||
self, interaction: discord.Interaction, season: int, week: int
|
||||
):
|
||||
"""Show schedule for a specific week."""
|
||||
self.logger.debug("Fetching week schedule", season=season, week=week)
|
||||
|
||||
|
||||
games = await schedule_service.get_week_schedule(season, week)
|
||||
|
||||
|
||||
if not games:
|
||||
await interaction.followup.send(
|
||||
f"❌ No games found for season {season}, week {week}.",
|
||||
ephemeral=True
|
||||
f"❌ No games found for season {season}, week {week}.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
embed = await self._create_week_schedule_embed(games, season, week)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _show_team_schedule(self, interaction: discord.Interaction, season: int, team: str, week: Optional[int]):
|
||||
|
||||
async def _show_team_schedule(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
season: int,
|
||||
team: str,
|
||||
week: Optional[int],
|
||||
):
|
||||
"""Show schedule for a specific team."""
|
||||
self.logger.debug("Fetching team schedule", season=season, team=team, week=week)
|
||||
|
||||
|
||||
if week:
|
||||
# Show team games for specific week
|
||||
week_games = await schedule_service.get_week_schedule(season, week)
|
||||
team_games = [
|
||||
game for game in week_games
|
||||
if game.away_team.abbrev.upper() == team.upper() or game.home_team.abbrev.upper() == team.upper()
|
||||
game
|
||||
for game in week_games
|
||||
if game.away_team.abbrev.upper() == team.upper()
|
||||
or game.home_team.abbrev.upper() == team.upper()
|
||||
]
|
||||
else:
|
||||
# Show team's recent/upcoming games (limited weeks)
|
||||
team_games = await schedule_service.get_team_schedule(season, team, weeks=4)
|
||||
|
||||
|
||||
if not team_games:
|
||||
week_text = f" for week {week}" if week else ""
|
||||
await interaction.followup.send(
|
||||
f"❌ No games found for team '{team}'{week_text} in season {season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
embed = await self._create_team_schedule_embed(team_games, season, team, week)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _show_current_schedule(self, interaction: discord.Interaction, season: int):
|
||||
|
||||
async def _show_current_schedule(
|
||||
self, interaction: discord.Interaction, season: int
|
||||
):
|
||||
"""Show current schedule overview with recent and upcoming games."""
|
||||
self.logger.debug("Fetching current schedule overview", season=season)
|
||||
|
||||
|
||||
# Get both recent and upcoming games
|
||||
recent_games, upcoming_games = await asyncio.gather(
|
||||
schedule_service.get_recent_games(season, weeks_back=1),
|
||||
schedule_service.get_upcoming_games(season, weeks_ahead=1)
|
||||
schedule_service.get_upcoming_games(season),
|
||||
)
|
||||
|
||||
|
||||
if not recent_games and not upcoming_games:
|
||||
await interaction.followup.send(
|
||||
f"❌ No recent or upcoming games found for season {season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
embed = await self._create_current_schedule_embed(recent_games, upcoming_games, season)
|
||||
|
||||
embed = await self._create_current_schedule_embed(
|
||||
recent_games, upcoming_games, season
|
||||
)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
async def _create_week_schedule_embed(self, games, season: int, week: int) -> discord.Embed:
|
||||
|
||||
async def _create_week_schedule_embed(
|
||||
self, games, season: int, week: int
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for a week's schedule."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📅 Week {week} Schedule - Season {season}",
|
||||
color=EmbedColors.PRIMARY
|
||||
color=EmbedColors.PRIMARY,
|
||||
)
|
||||
|
||||
|
||||
# Group games by series
|
||||
series_games = schedule_service.group_games_by_series(games)
|
||||
|
||||
|
||||
schedule_lines = []
|
||||
for (team1, team2), series in series_games.items():
|
||||
series_summary = await self._format_series_summary(series)
|
||||
schedule_lines.append(f"**{team1} vs {team2}**\n{series_summary}")
|
||||
|
||||
|
||||
if schedule_lines:
|
||||
embed.add_field(
|
||||
name="Games",
|
||||
value="\n\n".join(schedule_lines),
|
||||
inline=False
|
||||
name="Games", value="\n\n".join(schedule_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
# Add week summary
|
||||
completed = len([g for g in games if g.is_completed])
|
||||
total = len(games)
|
||||
embed.add_field(
|
||||
name="Week Progress",
|
||||
value=f"{completed}/{total} games completed",
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season} • Week {week}")
|
||||
return embed
|
||||
|
||||
async def _create_team_schedule_embed(self, games, season: int, team: str, week: Optional[int]) -> discord.Embed:
|
||||
|
||||
async def _create_team_schedule_embed(
|
||||
self, games, season: int, team: str, week: Optional[int]
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for a team's schedule."""
|
||||
week_text = f" - Week {week}" if week else ""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📅 {team.upper()} Schedule{week_text} - Season {season}",
|
||||
color=EmbedColors.PRIMARY
|
||||
color=EmbedColors.PRIMARY,
|
||||
)
|
||||
|
||||
|
||||
# Separate completed and upcoming games
|
||||
completed_games = [g for g in games if g.is_completed]
|
||||
upcoming_games = [g for g in games if not g.is_completed]
|
||||
|
||||
|
||||
if completed_games:
|
||||
recent_lines = []
|
||||
for game in completed_games[-5:]: # Last 5 games
|
||||
result = "W" if game.winner and game.winner.abbrev.upper() == team.upper() else "L"
|
||||
result = (
|
||||
"W"
|
||||
if game.winner and game.winner.abbrev.upper() == team.upper()
|
||||
else "L"
|
||||
)
|
||||
if game.home_team.abbrev.upper() == team.upper():
|
||||
# Team was home
|
||||
recent_lines.append(f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})")
|
||||
recent_lines.append(
|
||||
f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})"
|
||||
)
|
||||
else:
|
||||
# Team was away
|
||||
recent_lines.append(f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})")
|
||||
|
||||
# Team was away
|
||||
recent_lines.append(
|
||||
f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})"
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Recent Results",
|
||||
value="\n".join(recent_lines) if recent_lines else "No recent games",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
if upcoming_games:
|
||||
upcoming_lines = []
|
||||
for game in upcoming_games[:5]: # Next 5 games
|
||||
if game.home_team.abbrev.upper() == team.upper():
|
||||
# Team is home
|
||||
upcoming_lines.append(f"Week {game.week}: vs {game.away_team.abbrev}")
|
||||
upcoming_lines.append(
|
||||
f"Week {game.week}: vs {game.away_team.abbrev}"
|
||||
)
|
||||
else:
|
||||
# Team is away
|
||||
upcoming_lines.append(f"Week {game.week}: @ {game.home_team.abbrev}")
|
||||
|
||||
upcoming_lines.append(
|
||||
f"Week {game.week}: @ {game.home_team.abbrev}"
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Upcoming Games",
|
||||
value="\n".join(upcoming_lines) if upcoming_lines else "No upcoming games",
|
||||
inline=False
|
||||
value=(
|
||||
"\n".join(upcoming_lines) if upcoming_lines else "No upcoming games"
|
||||
),
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season} • {team.upper()}")
|
||||
return embed
|
||||
|
||||
async def _create_week_results_embed(self, games, season: int, week: int) -> discord.Embed:
|
||||
|
||||
async def _create_week_results_embed(
|
||||
self, games, season: int, week: int
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for week results."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"🏆 Week {week} Results - Season {season}",
|
||||
color=EmbedColors.SUCCESS
|
||||
title=f"🏆 Week {week} Results - Season {season}", color=EmbedColors.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
# Group by series and show results
|
||||
series_games = schedule_service.group_games_by_series(games)
|
||||
|
||||
|
||||
results_lines = []
|
||||
for (team1, team2), series in series_games.items():
|
||||
# Count wins for each team
|
||||
team1_wins = len([g for g in series if g.winner and g.winner.abbrev == team1])
|
||||
team2_wins = len([g for g in series if g.winner and g.winner.abbrev == team2])
|
||||
|
||||
team1_wins = len(
|
||||
[g for g in series if g.winner and g.winner.abbrev == team1]
|
||||
)
|
||||
team2_wins = len(
|
||||
[g for g in series if g.winner and g.winner.abbrev == team2]
|
||||
)
|
||||
|
||||
# Series result
|
||||
series_result = f"**{team1} {team1_wins}-{team2_wins} {team2}**"
|
||||
|
||||
|
||||
# Individual games
|
||||
game_details = []
|
||||
for game in series:
|
||||
if game.series_game_display:
|
||||
game_details.append(f"{game.series_game_display}: {game.matchup_display}")
|
||||
|
||||
game_details.append(
|
||||
f"{game.series_game_display}: {game.matchup_display}"
|
||||
)
|
||||
|
||||
results_lines.append(f"{series_result}\n" + "\n".join(game_details))
|
||||
|
||||
|
||||
if results_lines:
|
||||
embed.add_field(
|
||||
name="Series Results",
|
||||
value="\n\n".join(results_lines),
|
||||
inline=False
|
||||
name="Series Results", value="\n\n".join(results_lines), inline=False
|
||||
)
|
||||
|
||||
embed.set_footer(text=f"Season {season} • Week {week} • {len(games)} games completed")
|
||||
|
||||
embed.set_footer(
|
||||
text=f"Season {season} • Week {week} • {len(games)} games completed"
|
||||
)
|
||||
return embed
|
||||
|
||||
|
||||
async def _create_recent_results_embed(self, games, season: int) -> discord.Embed:
|
||||
"""Create an embed for recent results."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"🏆 Recent Results - Season {season}",
|
||||
color=EmbedColors.SUCCESS
|
||||
title=f"🏆 Recent Results - Season {season}", color=EmbedColors.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
# Show most recent games
|
||||
recent_lines = []
|
||||
for game in games[:10]: # Show last 10 games
|
||||
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
|
||||
|
||||
|
||||
if recent_lines:
|
||||
embed.add_field(
|
||||
name="Latest Games",
|
||||
value="\n".join(recent_lines),
|
||||
inline=False
|
||||
name="Latest Games", value="\n".join(recent_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season} • Last {len(games)} completed games")
|
||||
return embed
|
||||
|
||||
async def _create_current_schedule_embed(self, recent_games, upcoming_games, season: int) -> discord.Embed:
|
||||
|
||||
async def _create_current_schedule_embed(
|
||||
self, recent_games, upcoming_games, season: int
|
||||
) -> discord.Embed:
|
||||
"""Create an embed for current schedule overview."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📅 Current Schedule - Season {season}",
|
||||
color=EmbedColors.INFO
|
||||
title=f"📅 Current Schedule - Season {season}", color=EmbedColors.INFO
|
||||
)
|
||||
|
||||
|
||||
if recent_games:
|
||||
recent_lines = []
|
||||
for game in recent_games[:5]:
|
||||
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name="Recent Results",
|
||||
value="\n".join(recent_lines),
|
||||
inline=False
|
||||
name="Recent Results", value="\n".join(recent_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
if upcoming_games:
|
||||
upcoming_lines = []
|
||||
for game in upcoming_games[:5]:
|
||||
upcoming_lines.append(f"Week {game.week}: {game.matchup_display}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name="Upcoming Games",
|
||||
value="\n".join(upcoming_lines),
|
||||
inline=False
|
||||
name="Upcoming Games", value="\n".join(upcoming_lines), inline=False
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text=f"Season {season}")
|
||||
return embed
|
||||
|
||||
|
||||
async def _format_series_summary(self, series) -> str:
|
||||
"""Format a series summary."""
|
||||
lines = []
|
||||
for game in series:
|
||||
game_display = f"{game.series_game_display}: {game.matchup_display}" if game.series_game_display else game.matchup_display
|
||||
game_display = (
|
||||
f"{game.series_game_display}: {game.matchup_display}"
|
||||
if game.series_game_display
|
||||
else game.matchup_display
|
||||
)
|
||||
lines.append(game_display)
|
||||
|
||||
|
||||
return "\n".join(lines) if lines else "No games"
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
"""Load the schedule commands cog."""
|
||||
await bot.add_cog(ScheduleCommands(bot))
|
||||
await bot.add_cog(ScheduleCommands(bot))
|
||||
|
||||
@ -5,6 +5,7 @@ Implements the /submit-scorecard command for submitting Google Sheets
|
||||
scorecards with play-by-play data, pitching decisions, and game results.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
|
||||
import discord
|
||||
@ -107,11 +108,13 @@ class SubmitScorecardCommands(commands.Cog):
|
||||
content="🔍 Looking up teams and managers..."
|
||||
)
|
||||
|
||||
away_team = await team_service.get_team_by_abbrev(
|
||||
setup_data["away_team_abbrev"], current.season
|
||||
)
|
||||
home_team = await team_service.get_team_by_abbrev(
|
||||
setup_data["home_team_abbrev"], current.season
|
||||
away_team, home_team = await asyncio.gather(
|
||||
team_service.get_team_by_abbrev(
|
||||
setup_data["away_team_abbrev"], current.season
|
||||
),
|
||||
team_service.get_team_by_abbrev(
|
||||
setup_data["home_team_abbrev"], current.season
|
||||
),
|
||||
)
|
||||
|
||||
if not away_team or not home_team:
|
||||
@ -235,9 +238,13 @@ class SubmitScorecardCommands(commands.Cog):
|
||||
decision["game_num"] = setup_data["game_num"]
|
||||
|
||||
# Validate WP and LP exist and fetch Player objects
|
||||
wp, lp, sv, holders, _blown_saves = (
|
||||
await decision_service.find_winning_losing_pitchers(decisions_data)
|
||||
)
|
||||
(
|
||||
wp,
|
||||
lp,
|
||||
sv,
|
||||
holders,
|
||||
_blown_saves,
|
||||
) = await decision_service.find_winning_losing_pitchers(decisions_data)
|
||||
|
||||
if wp is None or lp is None:
|
||||
await interaction.edit_original_response(
|
||||
|
||||
@ -3,13 +3,14 @@ Soak Tracker
|
||||
|
||||
Provides persistent tracking of "soak" mentions using JSON file storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(f'{__name__}.SoakTracker')
|
||||
logger = logging.getLogger(f"{__name__}.SoakTracker")
|
||||
|
||||
|
||||
class SoakTracker:
|
||||
@ -22,7 +23,7 @@ class SoakTracker:
|
||||
- Time-based calculations for disappointment tiers
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/soak_data.json"):
|
||||
def __init__(self, data_file: str = "storage/soak_data.json"):
|
||||
"""
|
||||
Initialize the soak tracker.
|
||||
|
||||
@ -38,28 +39,22 @@ class SoakTracker:
|
||||
"""Load soak data from JSON file."""
|
||||
try:
|
||||
if self.data_file.exists():
|
||||
with open(self.data_file, 'r') as f:
|
||||
with open(self.data_file, "r") as f:
|
||||
self._data = json.load(f)
|
||||
logger.debug(f"Loaded soak data: {self._data.get('total_count', 0)} total soaks")
|
||||
logger.debug(
|
||||
f"Loaded soak data: {self._data.get('total_count', 0)} total soaks"
|
||||
)
|
||||
else:
|
||||
self._data = {
|
||||
"last_soak": None,
|
||||
"total_count": 0,
|
||||
"history": []
|
||||
}
|
||||
self._data = {"last_soak": None, "total_count": 0, "history": []}
|
||||
logger.info("No existing soak data found, starting fresh")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load soak data: {e}")
|
||||
self._data = {
|
||||
"last_soak": None,
|
||||
"total_count": 0,
|
||||
"history": []
|
||||
}
|
||||
self._data = {"last_soak": None, "total_count": 0, "history": []}
|
||||
|
||||
def save_data(self) -> None:
|
||||
"""Save soak data to JSON file."""
|
||||
try:
|
||||
with open(self.data_file, 'w') as f:
|
||||
with open(self.data_file, "w") as f:
|
||||
json.dump(self._data, f, indent=2, default=str)
|
||||
logger.debug("Soak data saved successfully")
|
||||
except Exception as e:
|
||||
@ -71,7 +66,7 @@ class SoakTracker:
|
||||
username: str,
|
||||
display_name: str,
|
||||
channel_id: int,
|
||||
message_id: int
|
||||
message_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Record a new soak mention.
|
||||
@ -89,7 +84,7 @@ class SoakTracker:
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"channel_id": str(channel_id),
|
||||
"message_id": str(message_id)
|
||||
"message_id": str(message_id),
|
||||
}
|
||||
|
||||
# Update last_soak
|
||||
@ -110,7 +105,9 @@ class SoakTracker:
|
||||
|
||||
self.save_data()
|
||||
|
||||
logger.info(f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}")
|
||||
logger.info(
|
||||
f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}"
|
||||
)
|
||||
|
||||
def get_last_soak(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
@ -135,10 +132,12 @@ class SoakTracker:
|
||||
try:
|
||||
# Parse ISO format timestamp
|
||||
last_timestamp_str = last_soak["timestamp"]
|
||||
if last_timestamp_str.endswith('Z'):
|
||||
last_timestamp_str = last_timestamp_str[:-1] + '+00:00'
|
||||
if last_timestamp_str.endswith("Z"):
|
||||
last_timestamp_str = last_timestamp_str[:-1] + "+00:00"
|
||||
|
||||
last_timestamp = datetime.fromisoformat(last_timestamp_str.replace('Z', '+00:00'))
|
||||
last_timestamp = datetime.fromisoformat(
|
||||
last_timestamp_str.replace("Z", "+00:00")
|
||||
)
|
||||
|
||||
# Ensure both times are timezone-aware
|
||||
if last_timestamp.tzinfo is None:
|
||||
|
||||
@ -3,6 +3,7 @@ Transaction Management Commands
|
||||
|
||||
Core transaction commands for roster management and transaction tracking.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
@ -21,6 +22,7 @@ from views.base import PaginationView
|
||||
from services.transaction_service import transaction_service
|
||||
from services.roster_service import roster_service
|
||||
from services.team_service import team_service
|
||||
|
||||
# No longer need TransactionStatus enum
|
||||
|
||||
|
||||
@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView):
|
||||
all_transactions: list,
|
||||
user_id: int,
|
||||
timeout: float = 300.0,
|
||||
show_page_numbers: bool = True
|
||||
show_page_numbers: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
pages=pages,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
show_page_numbers=show_page_numbers
|
||||
show_page_numbers=show_page_numbers,
|
||||
)
|
||||
self.all_transactions = all_transactions
|
||||
|
||||
@discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1)
|
||||
async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
@discord.ui.button(
|
||||
label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1
|
||||
)
|
||||
async def show_move_ids(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
"""Show all move IDs in an ephemeral message."""
|
||||
self.increment_interaction_count()
|
||||
|
||||
if not self.all_transactions:
|
||||
await interaction.response.send_message(
|
||||
"No transactions to show.",
|
||||
ephemeral=True
|
||||
"No transactions to show.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView):
|
||||
# Send the messages
|
||||
if not messages:
|
||||
await interaction.response.send_message(
|
||||
"No transactions to display.",
|
||||
ephemeral=True
|
||||
"No transactions to display.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView):
|
||||
|
||||
class TransactionCommands(commands.Cog):
|
||||
"""Transaction command handlers for roster management."""
|
||||
|
||||
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.TransactionCommands')
|
||||
|
||||
self.logger = get_contextual_logger(f"{__name__}.TransactionCommands")
|
||||
|
||||
@app_commands.command(
|
||||
name="mymoves",
|
||||
description="View your pending and scheduled transactions"
|
||||
name="mymoves", description="View your pending and scheduled transactions"
|
||||
)
|
||||
@app_commands.describe(
|
||||
show_cancelled="Include cancelled transactions in the display (default: False)"
|
||||
@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog):
|
||||
@requires_team()
|
||||
@logged_command("/mymoves")
|
||||
async def my_moves(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
show_cancelled: bool = False
|
||||
self, interaction: discord.Interaction, show_cancelled: bool = False
|
||||
):
|
||||
"""Display user's transaction status and history."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
# Get user's team
|
||||
team = await get_user_major_league_team(interaction.user.id, get_config().sba_season)
|
||||
|
||||
team = await get_user_major_league_team(
|
||||
interaction.user.id, get_config().sba_season
|
||||
)
|
||||
|
||||
if not team:
|
||||
await interaction.followup.send(
|
||||
"❌ You don't appear to own a team in the current season.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Get transactions in parallel
|
||||
pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season)
|
||||
frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season)
|
||||
processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season)
|
||||
|
||||
pending_transactions = await pending_task
|
||||
frozen_transactions = await frozen_task
|
||||
processed_transactions = await processed_task
|
||||
|
||||
(
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
) = await asyncio.gather(
|
||||
transaction_service.get_pending_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
transaction_service.get_frozen_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
transaction_service.get_processed_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
)
|
||||
|
||||
# Get cancelled if requested
|
||||
cancelled_transactions = []
|
||||
if show_cancelled:
|
||||
cancelled_transactions = await transaction_service.get_team_transactions(
|
||||
team.abbrev,
|
||||
get_config().sba_season,
|
||||
cancelled=True
|
||||
team.abbrev, get_config().sba_season, cancelled=True
|
||||
)
|
||||
|
||||
pages = self._create_my_moves_pages(
|
||||
@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog):
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
cancelled_transactions
|
||||
cancelled_transactions,
|
||||
)
|
||||
|
||||
# Collect all transactions for the "Show Move IDs" button
|
||||
all_transactions = (
|
||||
pending_transactions +
|
||||
frozen_transactions +
|
||||
processed_transactions +
|
||||
cancelled_transactions
|
||||
pending_transactions
|
||||
+ frozen_transactions
|
||||
+ processed_transactions
|
||||
+ cancelled_transactions
|
||||
)
|
||||
|
||||
# If only one page and no transactions, send without any buttons
|
||||
@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog):
|
||||
all_transactions=all_transactions,
|
||||
user_id=interaction.user.id,
|
||||
timeout=300.0,
|
||||
show_page_numbers=True
|
||||
show_page_numbers=True,
|
||||
)
|
||||
await interaction.followup.send(embed=view.get_current_embed(), view=view)
|
||||
|
||||
|
||||
@app_commands.command(
|
||||
name="legal",
|
||||
description="Check roster legality for current and next week"
|
||||
)
|
||||
@app_commands.describe(
|
||||
team="Team abbreviation to check (defaults to your team)"
|
||||
name="legal", description="Check roster legality for current and next week"
|
||||
)
|
||||
@app_commands.describe(team="Team abbreviation to check (defaults to your team)")
|
||||
@requires_team()
|
||||
@logged_command("/legal")
|
||||
async def legal(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
team: Optional[str] = None
|
||||
):
|
||||
async def legal(self, interaction: discord.Interaction, team: Optional[str] = None):
|
||||
"""Check roster legality and display detailed validation results."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
# Get target team
|
||||
if team:
|
||||
target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season)
|
||||
target_team = await team_service.get_team_by_abbrev(
|
||||
team.upper(), get_config().sba_season
|
||||
)
|
||||
if not target_team:
|
||||
await interaction.followup.send(
|
||||
f"❌ Could not find team '{team}' in season {get_config().sba_season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Get user's team
|
||||
user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season)
|
||||
user_teams = await team_service.get_teams_by_owner(
|
||||
interaction.user.id, get_config().sba_season
|
||||
)
|
||||
if not user_teams:
|
||||
await interaction.followup.send(
|
||||
"❌ You don't appear to own a team. Please specify a team abbreviation.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
target_team = user_teams[0]
|
||||
|
||||
|
||||
# Get rosters in parallel
|
||||
current_roster, next_roster = await asyncio.gather(
|
||||
roster_service.get_current_roster(target_team.id),
|
||||
roster_service.get_next_roster(target_team.id)
|
||||
roster_service.get_next_roster(target_team.id),
|
||||
)
|
||||
|
||||
|
||||
if not current_roster and not next_roster:
|
||||
await interaction.followup.send(
|
||||
f"❌ Could not retrieve roster data for {target_team.abbrev}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Validate rosters in parallel
|
||||
validation_tasks = []
|
||||
if current_roster:
|
||||
validation_tasks.append(roster_service.validate_roster(current_roster))
|
||||
else:
|
||||
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
|
||||
|
||||
|
||||
if next_roster:
|
||||
validation_tasks.append(roster_service.validate_roster(next_roster))
|
||||
else:
|
||||
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
|
||||
|
||||
|
||||
validation_results = await asyncio.gather(*validation_tasks)
|
||||
current_validation = validation_results[0] if current_roster else None
|
||||
next_validation = validation_results[1] if next_roster else None
|
||||
|
||||
|
||||
embed = await self._create_legal_embed(
|
||||
target_team,
|
||||
current_roster,
|
||||
next_roster,
|
||||
next_roster,
|
||||
current_validation,
|
||||
next_validation
|
||||
next_validation,
|
||||
)
|
||||
|
||||
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
|
||||
def _create_my_moves_pages(
|
||||
self,
|
||||
team,
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
cancelled_transactions
|
||||
cancelled_transactions,
|
||||
) -> list[discord.Embed]:
|
||||
"""Create paginated embeds showing user's transaction status."""
|
||||
|
||||
@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog):
|
||||
# Page 1: Summary + Pending Transactions
|
||||
if pending_transactions:
|
||||
total_pending = len(pending_transactions)
|
||||
total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page
|
||||
total_pages = (
|
||||
total_pending + transactions_per_page - 1
|
||||
) // transactions_per_page
|
||||
|
||||
for page_num in range(total_pages):
|
||||
start_idx = page_num * transactions_per_page
|
||||
@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
# Add team thumbnail if available
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
# Pending transactions for this page
|
||||
@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog):
|
||||
embed.add_field(
|
||||
name=f"⏳ Pending Transactions ({total_pending} total)",
|
||||
value="\n".join(pending_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
# Add summary only on first page
|
||||
@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog):
|
||||
|
||||
embed.add_field(
|
||||
name="Summary",
|
||||
value=", ".join(status_text) if status_text else "No active transactions",
|
||||
inline=True
|
||||
value=(
|
||||
", ".join(status_text)
|
||||
if status_text
|
||||
else "No active transactions"
|
||||
),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
embed.add_field(
|
||||
name="⏳ Pending Transactions",
|
||||
value="No pending transactions",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
total_frozen = len(frozen_transactions)
|
||||
@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog):
|
||||
|
||||
embed.add_field(
|
||||
name="Summary",
|
||||
value=", ".join(status_text) if status_text else "No active transactions",
|
||||
inline=True
|
||||
value=(
|
||||
", ".join(status_text) if status_text else "No active transactions"
|
||||
),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
frozen_lines = [format_transaction(tx) for tx in frozen_transactions]
|
||||
@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog):
|
||||
embed.add_field(
|
||||
name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)",
|
||||
value="\n".join(frozen_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20
|
||||
processed_lines = [
|
||||
format_transaction(tx) for tx in processed_transactions[-20:]
|
||||
] # Last 20
|
||||
|
||||
embed.add_field(
|
||||
name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)",
|
||||
value="\n".join(processed_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20
|
||||
cancelled_lines = [
|
||||
format_transaction(tx) for tx in cancelled_transactions[-20:]
|
||||
] # Last 20
|
||||
|
||||
embed.add_field(
|
||||
name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)",
|
||||
value="\n".join(cancelled_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog):
|
||||
page.set_footer(text="Use /legal to check roster legality")
|
||||
|
||||
return pages
|
||||
|
||||
|
||||
async def _create_legal_embed(
|
||||
self,
|
||||
team,
|
||||
current_roster,
|
||||
next_roster,
|
||||
current_validation,
|
||||
next_validation
|
||||
self, team, current_roster, next_roster, current_validation, next_validation
|
||||
) -> discord.Embed:
|
||||
"""Create embed showing roster legality check results."""
|
||||
|
||||
|
||||
# Determine overall status
|
||||
overall_legal = True
|
||||
if current_validation and not current_validation.is_legal:
|
||||
overall_legal = False
|
||||
if next_validation and not next_validation.is_legal:
|
||||
overall_legal = False
|
||||
|
||||
|
||||
status_emoji = "✅" if overall_legal else "❌"
|
||||
embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR
|
||||
|
||||
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"{status_emoji} Roster Check - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=embed_color
|
||||
color=embed_color,
|
||||
)
|
||||
|
||||
|
||||
# Add team thumbnail if available
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
|
||||
# Current week roster
|
||||
if current_roster and current_validation:
|
||||
current_lines = []
|
||||
current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL")
|
||||
current_lines.append(
|
||||
f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL"
|
||||
)
|
||||
current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}")
|
||||
|
||||
|
||||
if current_validation.errors:
|
||||
current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}")
|
||||
for error in current_validation.errors[:3]: # Show first 3 errors
|
||||
current_lines.append(f"• {error}")
|
||||
|
||||
|
||||
if current_validation.warnings:
|
||||
current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}")
|
||||
current_lines.append(
|
||||
f"**⚠️ Warnings:** {len(current_validation.warnings)}"
|
||||
)
|
||||
for warning in current_validation.warnings[:2]: # Show first 2 warnings
|
||||
current_lines.append(f"• {warning}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name=f"{current_validation.status_emoji} Current Week",
|
||||
value="\n".join(current_lines),
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="❓ Current Week",
|
||||
value="Roster data not available",
|
||||
inline=True
|
||||
name="❓ Current Week", value="Roster data not available", inline=True
|
||||
)
|
||||
|
||||
# Next week roster
|
||||
|
||||
# Next week roster
|
||||
if next_roster and next_validation:
|
||||
next_lines = []
|
||||
next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL")
|
||||
next_lines.append(
|
||||
f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL"
|
||||
)
|
||||
next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}")
|
||||
|
||||
|
||||
if next_validation.errors:
|
||||
next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}")
|
||||
for error in next_validation.errors[:3]: # Show first 3 errors
|
||||
next_lines.append(f"• {error}")
|
||||
|
||||
|
||||
if next_validation.warnings:
|
||||
next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}")
|
||||
for warning in next_validation.warnings[:2]: # Show first 2 warnings
|
||||
next_lines.append(f"• {warning}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name=f"{next_validation.status_emoji} Next Week",
|
||||
value="\n".join(next_lines),
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="❓ Next Week",
|
||||
value="Roster data not available",
|
||||
inline=True
|
||||
name="❓ Next Week", value="Roster data not available", inline=True
|
||||
)
|
||||
|
||||
|
||||
# Overall status
|
||||
if overall_legal:
|
||||
embed.add_field(
|
||||
name="Overall Status",
|
||||
value="✅ All rosters are legal",
|
||||
inline=False
|
||||
name="Overall Status", value="✅ All rosters are legal", inline=False
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="Overall Status",
|
||||
name="Overall Status",
|
||||
value="❌ Roster violations found - please review and correct",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text="Roster validation based on current league rules")
|
||||
return embed
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
"""Load the transaction commands cog."""
|
||||
await bot.add_cog(TransactionCommands(bot))
|
||||
await bot.add_cog(TransactionCommands(bot))
|
||||
|
||||
@ -3,6 +3,7 @@ Trade Channel Tracker
|
||||
|
||||
Provides persistent tracking of bot-created trade discussion channels using JSON file storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, UTC
|
||||
from pathlib import Path
|
||||
@ -12,7 +13,7 @@ import discord
|
||||
|
||||
from utils.logging import get_contextual_logger
|
||||
|
||||
logger = get_contextual_logger(f'{__name__}.TradeChannelTracker')
|
||||
logger = get_contextual_logger(f"{__name__}.TradeChannelTracker")
|
||||
|
||||
|
||||
class TradeChannelTracker:
|
||||
@ -26,7 +27,7 @@ class TradeChannelTracker:
|
||||
- Automatic stale entry removal
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/trade_channels.json"):
|
||||
def __init__(self, data_file: str = "storage/trade_channels.json"):
|
||||
"""
|
||||
Initialize the trade channel tracker.
|
||||
|
||||
@ -42,9 +43,11 @@ class TradeChannelTracker:
|
||||
"""Load channel data from JSON file."""
|
||||
try:
|
||||
if self.data_file.exists():
|
||||
with open(self.data_file, 'r') as f:
|
||||
with open(self.data_file, "r") as f:
|
||||
self._data = json.load(f)
|
||||
logger.debug(f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels")
|
||||
logger.debug(
|
||||
f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels"
|
||||
)
|
||||
else:
|
||||
self._data = {"trade_channels": {}}
|
||||
logger.info("No existing trade channel data found, starting fresh")
|
||||
@ -55,7 +58,7 @@ class TradeChannelTracker:
|
||||
def save_data(self) -> None:
|
||||
"""Save channel data to JSON file."""
|
||||
try:
|
||||
with open(self.data_file, 'w') as f:
|
||||
with open(self.data_file, "w") as f:
|
||||
json.dump(self._data, f, indent=2, default=str)
|
||||
logger.debug("Trade channel data saved successfully")
|
||||
except Exception as e:
|
||||
@ -67,7 +70,7 @@ class TradeChannelTracker:
|
||||
trade_id: str,
|
||||
team1_abbrev: str,
|
||||
team2_abbrev: str,
|
||||
creator_id: int
|
||||
creator_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Add a new trade channel to tracking.
|
||||
@ -87,10 +90,12 @@ class TradeChannelTracker:
|
||||
"team1_abbrev": team1_abbrev,
|
||||
"team2_abbrev": team2_abbrev,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"creator_id": str(creator_id)
|
||||
"creator_id": str(creator_id),
|
||||
}
|
||||
self.save_data()
|
||||
logger.info(f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})")
|
||||
logger.info(
|
||||
f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})"
|
||||
)
|
||||
|
||||
def remove_channel(self, channel_id: int) -> None:
|
||||
"""
|
||||
@ -108,7 +113,9 @@ class TradeChannelTracker:
|
||||
channel_name = channel_data["name"]
|
||||
del channels[channel_key]
|
||||
self.save_data()
|
||||
logger.info(f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})")
|
||||
logger.info(
|
||||
f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})"
|
||||
)
|
||||
|
||||
def get_channel_by_trade_id(self, trade_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
@ -175,7 +182,9 @@ class TradeChannelTracker:
|
||||
channel_name = channels[channel_id_str].get("name", "unknown")
|
||||
trade_id = channels[channel_id_str].get("trade_id", "unknown")
|
||||
del channels[channel_id_str]
|
||||
logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})")
|
||||
logger.info(
|
||||
f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})"
|
||||
)
|
||||
|
||||
if stale_entries:
|
||||
self.save_data()
|
||||
|
||||
@ -3,6 +3,7 @@ Voice Channel Cleanup Service
|
||||
|
||||
Provides automatic cleanup of empty voice channels with restart resilience.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import discord
|
||||
@ -12,7 +13,7 @@ from .tracker import VoiceChannelTracker
|
||||
from commands.gameplay.scorecard_tracker import ScorecardTracker
|
||||
from utils.logging import get_contextual_logger
|
||||
|
||||
logger = logging.getLogger(f'{__name__}.VoiceChannelCleanupService')
|
||||
logger = logging.getLogger(f"{__name__}.VoiceChannelCleanupService")
|
||||
|
||||
|
||||
class VoiceChannelCleanupService:
|
||||
@ -27,7 +28,9 @@ class VoiceChannelCleanupService:
|
||||
- Automatic scorecard unpublishing when voice channel is cleaned up
|
||||
"""
|
||||
|
||||
def __init__(self, bot: commands.Bot, data_file: str = "data/voice_channels.json"):
|
||||
def __init__(
|
||||
self, bot: commands.Bot, data_file: str = "storage/voice_channels.json"
|
||||
):
|
||||
"""
|
||||
Initialize the cleanup service.
|
||||
|
||||
@ -36,10 +39,10 @@ class VoiceChannelCleanupService:
|
||||
data_file: Path to the JSON data file for persistence
|
||||
"""
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.VoiceChannelCleanupService')
|
||||
self.logger = get_contextual_logger(f"{__name__}.VoiceChannelCleanupService")
|
||||
self.tracker = VoiceChannelTracker(data_file)
|
||||
self.scorecard_tracker = ScorecardTracker()
|
||||
self.empty_threshold = 5 # Delete after 5 minutes empty
|
||||
self.empty_threshold = 5 # Delete after 5 minutes empty
|
||||
|
||||
# Start the cleanup task - @before_loop will wait for bot readiness
|
||||
self.cleanup_loop.start()
|
||||
@ -90,13 +93,17 @@ class VoiceChannelCleanupService:
|
||||
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
self.logger.warning(f"Guild {guild_id} not found, removing channel {channel_data['name']}")
|
||||
self.logger.warning(
|
||||
f"Guild {guild_id} not found, removing channel {channel_data['name']}"
|
||||
)
|
||||
channels_to_remove.append(channel_id)
|
||||
continue
|
||||
|
||||
channel = guild.get_channel(channel_id)
|
||||
if not channel:
|
||||
self.logger.warning(f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists")
|
||||
self.logger.warning(
|
||||
f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists"
|
||||
)
|
||||
channels_to_remove.append(channel_id)
|
||||
continue
|
||||
|
||||
@ -121,18 +128,26 @@ class VoiceChannelCleanupService:
|
||||
if channel_data and channel_data.get("text_channel_id"):
|
||||
try:
|
||||
text_channel_id_int = int(channel_data["text_channel_id"])
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in stale voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in stale voice channel data: {e}"
|
||||
)
|
||||
|
||||
# Also clean up any additional stale entries
|
||||
stale_removed = self.tracker.cleanup_stale_entries(valid_channel_ids)
|
||||
total_removed = len(channels_to_remove) + stale_removed
|
||||
|
||||
if total_removed > 0:
|
||||
self.logger.info(f"Cleaned up {total_removed} stale channel tracking entries")
|
||||
self.logger.info(
|
||||
f"Cleaned up {total_removed} stale channel tracking entries"
|
||||
)
|
||||
|
||||
self.logger.info(f"Verified {len(valid_channel_ids)} valid tracked channels")
|
||||
|
||||
@ -149,10 +164,14 @@ class VoiceChannelCleanupService:
|
||||
await self.update_all_channel_statuses(bot)
|
||||
|
||||
# Get channels ready for cleanup
|
||||
channels_for_cleanup = self.tracker.get_channels_for_cleanup(self.empty_threshold)
|
||||
channels_for_cleanup = self.tracker.get_channels_for_cleanup(
|
||||
self.empty_threshold
|
||||
)
|
||||
|
||||
if channels_for_cleanup:
|
||||
self.logger.info(f"Found {len(channels_for_cleanup)} channels ready for cleanup")
|
||||
self.logger.info(
|
||||
f"Found {len(channels_for_cleanup)} channels ready for cleanup"
|
||||
)
|
||||
|
||||
# Delete empty channels
|
||||
for channel_data in channels_for_cleanup:
|
||||
@ -182,12 +201,16 @@ class VoiceChannelCleanupService:
|
||||
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
self.logger.debug(f"Guild {guild_id} not found for channel {channel_data['name']}")
|
||||
self.logger.debug(
|
||||
f"Guild {guild_id} not found for channel {channel_data['name']}"
|
||||
)
|
||||
return
|
||||
|
||||
channel = guild.get_channel(channel_id)
|
||||
if not channel:
|
||||
self.logger.debug(f"Channel {channel_data['name']} no longer exists, removing from tracking")
|
||||
self.logger.debug(
|
||||
f"Channel {channel_data['name']} no longer exists, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
|
||||
# Unpublish associated scorecard if it exists
|
||||
@ -195,17 +218,25 @@ class VoiceChannelCleanupService:
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in manually deleted voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in manually deleted voice channel data: {e}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Ensure it's a voice channel before checking members
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
self.logger.warning(f"Channel {channel_data['name']} is not a voice channel, removing from tracking")
|
||||
self.logger.warning(
|
||||
f"Channel {channel_data['name']} is not a voice channel, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
|
||||
# Unpublish associated scorecard if it exists
|
||||
@ -213,11 +244,17 @@ class VoiceChannelCleanupService:
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in wrong channel type data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in wrong channel type data: {e}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@ -225,11 +262,15 @@ class VoiceChannelCleanupService:
|
||||
is_empty = len(channel.members) == 0
|
||||
self.tracker.update_channel_status(channel_id, is_empty)
|
||||
|
||||
self.logger.debug(f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} "
|
||||
f"({len(channel.members)} members)")
|
||||
self.logger.debug(
|
||||
f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} "
|
||||
f"({len(channel.members)} members)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}")
|
||||
self.logger.error(
|
||||
f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
|
||||
async def cleanup_channel(self, bot: commands.Bot, channel_data: dict) -> None:
|
||||
"""
|
||||
@ -246,25 +287,33 @@ class VoiceChannelCleanupService:
|
||||
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
self.logger.info(f"Guild {guild_id} not found, removing tracking for {channel_name}")
|
||||
self.logger.info(
|
||||
f"Guild {guild_id} not found, removing tracking for {channel_name}"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
return
|
||||
|
||||
channel = guild.get_channel(channel_id)
|
||||
if not channel:
|
||||
self.logger.info(f"Channel {channel_name} already deleted, removing from tracking")
|
||||
self.logger.info(
|
||||
f"Channel {channel_name} already deleted, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
return
|
||||
|
||||
# Ensure it's a voice channel before checking members
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
self.logger.warning(f"Channel {channel_name} is not a voice channel, removing from tracking")
|
||||
self.logger.warning(
|
||||
f"Channel {channel_name} is not a voice channel, removing from tracking"
|
||||
)
|
||||
self.tracker.remove_channel(channel_id)
|
||||
return
|
||||
|
||||
# Final check: make sure channel is still empty before deleting
|
||||
if len(channel.members) > 0:
|
||||
self.logger.info(f"Channel {channel_name} is no longer empty, skipping cleanup")
|
||||
self.logger.info(
|
||||
f"Channel {channel_name} is no longer empty, skipping cleanup"
|
||||
)
|
||||
self.tracker.update_channel_status(channel_id, False)
|
||||
return
|
||||
|
||||
@ -272,24 +321,36 @@ class VoiceChannelCleanupService:
|
||||
await channel.delete(reason="Automatic cleanup - empty for 5+ minutes")
|
||||
self.tracker.remove_channel(channel_id)
|
||||
|
||||
self.logger.info(f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})")
|
||||
self.logger.info(
|
||||
f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})"
|
||||
)
|
||||
|
||||
# Unpublish associated scorecard if it exists
|
||||
text_channel_id = channel_data.get("text_channel_id")
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"No scorecard found for text channel {text_channel_id_int}")
|
||||
self.logger.debug(
|
||||
f"No scorecard found for text channel {text_channel_id_int}"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in voice channel data: {e}"
|
||||
)
|
||||
|
||||
except discord.NotFound:
|
||||
# Channel was already deleted
|
||||
self.logger.info(f"Channel {channel_data.get('name', 'unknown')} was already deleted")
|
||||
self.logger.info(
|
||||
f"Channel {channel_data.get('name', 'unknown')} was already deleted"
|
||||
)
|
||||
self.tracker.remove_channel(int(channel_data["channel_id"]))
|
||||
|
||||
# Still try to unpublish associated scorecard
|
||||
@ -297,15 +358,25 @@ class VoiceChannelCleanupService:
|
||||
if text_channel_id:
|
||||
try:
|
||||
text_channel_id_int = int(text_channel_id)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int)
|
||||
was_unpublished = self.scorecard_tracker.unpublish_scorecard(
|
||||
text_channel_id_int
|
||||
)
|
||||
if was_unpublished:
|
||||
self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)")
|
||||
self.logger.info(
|
||||
f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}")
|
||||
self.logger.warning(
|
||||
f"Invalid text_channel_id in voice channel data: {e}"
|
||||
)
|
||||
except discord.Forbidden:
|
||||
self.logger.error(f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}")
|
||||
self.logger.error(
|
||||
f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}")
|
||||
self.logger.error(
|
||||
f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
|
||||
def get_tracker(self) -> VoiceChannelTracker:
|
||||
"""
|
||||
@ -330,7 +401,7 @@ class VoiceChannelCleanupService:
|
||||
"running": self.cleanup_loop.is_running(),
|
||||
"total_tracked": len(all_channels),
|
||||
"empty_channels": len(empty_channels),
|
||||
"empty_threshold": self.empty_threshold
|
||||
"empty_threshold": self.empty_threshold,
|
||||
}
|
||||
|
||||
|
||||
@ -344,4 +415,4 @@ def setup_voice_cleanup(bot: commands.Bot) -> VoiceChannelCleanupService:
|
||||
Returns:
|
||||
VoiceChannelCleanupService instance
|
||||
"""
|
||||
return VoiceChannelCleanupService(bot)
|
||||
return VoiceChannelCleanupService(bot)
|
||||
|
||||
@ -3,6 +3,7 @@ Voice Channel Tracker
|
||||
|
||||
Provides persistent tracking of bot-created voice channels using JSON file storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, UTC
|
||||
@ -11,7 +12,7 @@ from typing import Dict, List, Optional, Any
|
||||
|
||||
import discord
|
||||
|
||||
logger = logging.getLogger(f'{__name__}.VoiceChannelTracker')
|
||||
logger = logging.getLogger(f"{__name__}.VoiceChannelTracker")
|
||||
|
||||
|
||||
class VoiceChannelTracker:
|
||||
@ -25,7 +26,7 @@ class VoiceChannelTracker:
|
||||
- Automatic stale entry removal
|
||||
"""
|
||||
|
||||
def __init__(self, data_file: str = "data/voice_channels.json"):
|
||||
def __init__(self, data_file: str = "storage/voice_channels.json"):
|
||||
"""
|
||||
Initialize the voice channel tracker.
|
||||
|
||||
@ -41,9 +42,11 @@ class VoiceChannelTracker:
|
||||
"""Load channel data from JSON file."""
|
||||
try:
|
||||
if self.data_file.exists():
|
||||
with open(self.data_file, 'r') as f:
|
||||
with open(self.data_file, "r") as f:
|
||||
self._data = json.load(f)
|
||||
logger.debug(f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels")
|
||||
logger.debug(
|
||||
f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels"
|
||||
)
|
||||
else:
|
||||
self._data = {"voice_channels": {}}
|
||||
logger.info("No existing voice channel data found, starting fresh")
|
||||
@ -54,7 +57,7 @@ class VoiceChannelTracker:
|
||||
def save_data(self) -> None:
|
||||
"""Save channel data to JSON file."""
|
||||
try:
|
||||
with open(self.data_file, 'w') as f:
|
||||
with open(self.data_file, "w") as f:
|
||||
json.dump(self._data, f, indent=2, default=str)
|
||||
logger.debug("Voice channel data saved successfully")
|
||||
except Exception as e:
|
||||
@ -65,7 +68,7 @@ class VoiceChannelTracker:
|
||||
channel: discord.VoiceChannel,
|
||||
channel_type: str,
|
||||
creator_id: int,
|
||||
text_channel_id: Optional[int] = None
|
||||
text_channel_id: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a new channel to tracking.
|
||||
@ -85,7 +88,7 @@ class VoiceChannelTracker:
|
||||
"last_checked": datetime.now(UTC).isoformat(),
|
||||
"empty_since": None,
|
||||
"creator_id": str(creator_id),
|
||||
"text_channel_id": str(text_channel_id) if text_channel_id else None
|
||||
"text_channel_id": str(text_channel_id) if text_channel_id else None,
|
||||
}
|
||||
self.save_data()
|
||||
logger.info(f"Added channel to tracking: {channel.name} (ID: {channel.id})")
|
||||
@ -130,9 +133,13 @@ class VoiceChannelTracker:
|
||||
channel_name = channels[channel_key]["name"]
|
||||
del channels[channel_key]
|
||||
self.save_data()
|
||||
logger.info(f"Removed channel from tracking: {channel_name} (ID: {channel_id})")
|
||||
logger.info(
|
||||
f"Removed channel from tracking: {channel_name} (ID: {channel_id})"
|
||||
)
|
||||
|
||||
def get_channels_for_cleanup(self, empty_threshold_minutes: int = 15) -> List[Dict[str, Any]]:
|
||||
def get_channels_for_cleanup(
|
||||
self, empty_threshold_minutes: int = 15
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get channels that should be deleted based on empty duration.
|
||||
|
||||
@ -153,10 +160,12 @@ class VoiceChannelTracker:
|
||||
# Parse empty_since timestamp
|
||||
empty_since_str = channel_data["empty_since"]
|
||||
# Handle both with and without timezone info
|
||||
if empty_since_str.endswith('Z'):
|
||||
empty_since_str = empty_since_str[:-1] + '+00:00'
|
||||
if empty_since_str.endswith("Z"):
|
||||
empty_since_str = empty_since_str[:-1] + "+00:00"
|
||||
|
||||
empty_since = datetime.fromisoformat(empty_since_str.replace('Z', '+00:00'))
|
||||
empty_since = datetime.fromisoformat(
|
||||
empty_since_str.replace("Z", "+00:00")
|
||||
)
|
||||
|
||||
# Remove timezone info for comparison (both times are UTC)
|
||||
if empty_since.tzinfo:
|
||||
@ -164,10 +173,14 @@ class VoiceChannelTracker:
|
||||
|
||||
if empty_since <= cutoff_time:
|
||||
cleanup_candidates.append(channel_data)
|
||||
logger.debug(f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})")
|
||||
logger.debug(
|
||||
f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})"
|
||||
)
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}")
|
||||
logger.warning(
|
||||
f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
|
||||
return cleanup_candidates
|
||||
|
||||
@ -242,9 +255,11 @@ class VoiceChannelTracker:
|
||||
for channel_id_str in stale_entries:
|
||||
channel_name = channels[channel_id_str].get("name", "unknown")
|
||||
del channels[channel_id_str]
|
||||
logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})")
|
||||
logger.info(
|
||||
f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})"
|
||||
)
|
||||
|
||||
if stale_entries:
|
||||
self.save_data()
|
||||
|
||||
return len(stale_entries)
|
||||
return len(stale_entries)
|
||||
|
||||
@ -36,8 +36,11 @@ services:
|
||||
|
||||
# Volume mounts
|
||||
volumes:
|
||||
# Google Sheets credentials (required)
|
||||
- ${SHEETS_CREDENTIALS_HOST_PATH:-./data}:/app/data:ro
|
||||
# Google Sheets credentials (read-only, file mount)
|
||||
- ${SHEETS_CREDENTIALS_HOST_PATH:-./data/major-domo-service-creds.json}:/app/data/major-domo-service-creds.json:ro
|
||||
|
||||
# Runtime state files (writable) - scorecards, voice channels, trade channels, soak data
|
||||
- ${STATE_HOST_PATH:-./storage}:/app/storage:rw
|
||||
|
||||
# Logs directory (persistent) - mounted to /app/logs where the application expects it
|
||||
- ${LOGS_HOST_PATH:-./logs}:/app/logs:rw
|
||||
|
||||
@ -4,6 +4,7 @@ Schedule service for Discord Bot v2.0
|
||||
Handles game schedule and results retrieval and processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
|
||||
@ -102,10 +103,10 @@ class ScheduleService:
|
||||
# If weeks not specified, try a reasonable range (18 weeks typical)
|
||||
week_range = range(1, (weeks + 1) if weeks else 19)
|
||||
|
||||
for week in week_range:
|
||||
week_games = await self.get_week_schedule(season, week)
|
||||
|
||||
# Filter games involving this team
|
||||
all_week_games = await asyncio.gather(
|
||||
*[self.get_week_schedule(season, week) for week in week_range]
|
||||
)
|
||||
for week_games in all_week_games:
|
||||
for game in week_games:
|
||||
if (
|
||||
game.away_team.abbrev.upper() == team_abbrev_upper
|
||||
@ -135,15 +136,13 @@ class ScheduleService:
|
||||
recent_games = []
|
||||
|
||||
# Get games from recent weeks
|
||||
for week_offset in range(weeks_back):
|
||||
# This is simplified - in production you'd want to determine current week
|
||||
week = 10 - week_offset # Assuming we're around week 10
|
||||
if week <= 0:
|
||||
break
|
||||
|
||||
week_games = await self.get_week_schedule(season, week)
|
||||
|
||||
# Only include completed games
|
||||
weeks_to_fetch = [
|
||||
(10 - offset) for offset in range(weeks_back) if (10 - offset) > 0
|
||||
]
|
||||
all_week_games = await asyncio.gather(
|
||||
*[self.get_week_schedule(season, week) for week in weeks_to_fetch]
|
||||
)
|
||||
for week_games in all_week_games:
|
||||
completed_games = [game for game in week_games if game.is_completed]
|
||||
recent_games.extend(completed_games)
|
||||
|
||||
@ -157,13 +156,12 @@ class ScheduleService:
|
||||
logger.error(f"Error getting recent games: {e}")
|
||||
return []
|
||||
|
||||
async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]:
|
||||
async def get_upcoming_games(self, season: int) -> List[Game]:
|
||||
"""
|
||||
Get upcoming scheduled games by scanning multiple weeks.
|
||||
Get upcoming scheduled games by scanning all weeks.
|
||||
|
||||
Args:
|
||||
season: Season number
|
||||
weeks_ahead: Number of weeks to scan ahead (default 6)
|
||||
|
||||
Returns:
|
||||
List of upcoming Game instances
|
||||
@ -171,20 +169,16 @@ class ScheduleService:
|
||||
try:
|
||||
upcoming_games = []
|
||||
|
||||
# Scan through weeks to find games without scores
|
||||
for week in range(1, 19): # Standard season length
|
||||
week_games = await self.get_week_schedule(season, week)
|
||||
|
||||
# Find games without scores (not yet played)
|
||||
# Fetch all weeks in parallel and filter for incomplete games
|
||||
all_week_games = await asyncio.gather(
|
||||
*[self.get_week_schedule(season, week) for week in range(1, 19)]
|
||||
)
|
||||
for week_games in all_week_games:
|
||||
upcoming_games_week = [
|
||||
game for game in week_games if not game.is_completed
|
||||
]
|
||||
upcoming_games.extend(upcoming_games_week)
|
||||
|
||||
# If we found upcoming games, we can limit how many more weeks to check
|
||||
if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit
|
||||
break
|
||||
|
||||
# Sort by week, then game number
|
||||
upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0))
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0
|
||||
Handles batting and pitching statistics retrieval and processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@ -144,11 +145,10 @@ class StatsService:
|
||||
"""
|
||||
try:
|
||||
# Get both types of stats concurrently
|
||||
batting_task = self.get_batting_stats(player_id, season)
|
||||
pitching_task = self.get_pitching_stats(player_id, season)
|
||||
|
||||
batting_stats = await batting_task
|
||||
pitching_stats = await pitching_task
|
||||
batting_stats, pitching_stats = await asyncio.gather(
|
||||
self.get_batting_stats(player_id, season),
|
||||
self.get_pitching_stats(player_id, season),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved stats for player {player_id}: "
|
||||
|
||||
@ -4,6 +4,7 @@ Trade Builder Service
|
||||
Extends the TransactionBuilder to support multi-team trades and player exchanges.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime, timezone
|
||||
@ -524,14 +525,22 @@ class TradeBuilder:
|
||||
|
||||
# Validate each team's roster after the trade
|
||||
for participant in self.trade.participants:
|
||||
team_id = participant.team.id
|
||||
result.team_abbrevs[team_id] = participant.team.abbrev
|
||||
if team_id in self._team_builders:
|
||||
builder = self._team_builders[team_id]
|
||||
roster_validation = await builder.validate_transaction(next_week)
|
||||
result.team_abbrevs[participant.team.id] = participant.team.abbrev
|
||||
|
||||
team_ids_to_validate = [
|
||||
participant.team.id
|
||||
for participant in self.trade.participants
|
||||
if participant.team.id in self._team_builders
|
||||
]
|
||||
if team_ids_to_validate:
|
||||
validations = await asyncio.gather(
|
||||
*[
|
||||
self._team_builders[tid].validate_transaction(next_week)
|
||||
for tid in team_ids_to_validate
|
||||
]
|
||||
)
|
||||
for team_id, roster_validation in zip(team_ids_to_validate, validations):
|
||||
result.participant_validations[team_id] = roster_validation
|
||||
|
||||
if not roster_validation.is_legal:
|
||||
result.is_legal = False
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ class LiveScorebugTracker:
|
||||
# Don't return - still update voice channels
|
||||
else:
|
||||
# Get all published scorecards
|
||||
all_scorecards = self.scorecard_tracker.get_all_scorecards()
|
||||
all_scorecards = await self.scorecard_tracker.get_all_scorecards()
|
||||
|
||||
if not all_scorecards:
|
||||
# No active scorebugs - clear the channel and hide it
|
||||
@ -112,17 +112,16 @@ class LiveScorebugTracker:
|
||||
for text_channel_id, sheet_url in all_scorecards:
|
||||
try:
|
||||
scorebug_data = await self.scorebug_service.read_scorebug_data(
|
||||
sheet_url, full_length=False # Compact view for live channel
|
||||
sheet_url,
|
||||
full_length=False, # Compact view for live channel
|
||||
)
|
||||
|
||||
# Only include active (non-final) games
|
||||
if scorebug_data.is_active:
|
||||
# Get team data
|
||||
away_team = await team_service.get_team(
|
||||
scorebug_data.away_team_id
|
||||
)
|
||||
home_team = await team_service.get_team(
|
||||
scorebug_data.home_team_id
|
||||
away_team, home_team = await asyncio.gather(
|
||||
team_service.get_team(scorebug_data.away_team_id),
|
||||
team_service.get_team(scorebug_data.home_team_id),
|
||||
)
|
||||
|
||||
if away_team is None or home_team is None:
|
||||
@ -188,9 +187,8 @@ class LiveScorebugTracker:
|
||||
embeds: List of scorebug embeds
|
||||
"""
|
||||
try:
|
||||
# Clear old messages
|
||||
async for message in channel.history(limit=25):
|
||||
await message.delete()
|
||||
# Clear old messages using bulk delete
|
||||
await channel.purge(limit=25)
|
||||
|
||||
# Post new scorebugs (Discord allows up to 10 embeds per message)
|
||||
if len(embeds) <= 10:
|
||||
@ -216,9 +214,8 @@ class LiveScorebugTracker:
|
||||
channel: Discord text channel
|
||||
"""
|
||||
try:
|
||||
# Clear all messages
|
||||
async for message in channel.history(limit=25):
|
||||
await message.delete()
|
||||
# Clear all messages using bulk delete
|
||||
await channel.purge(limit=25)
|
||||
|
||||
self.logger.info("Cleared live-sba-scores channel (no active games)")
|
||||
|
||||
|
||||
284
tests/test_services_schedule.py
Normal file
284
tests/test_services_schedule.py
Normal file
@ -0,0 +1,284 @@
|
||||
"""
|
||||
Tests for schedule service functionality.
|
||||
|
||||
Covers get_week_schedule, get_team_schedule, get_recent_games,
|
||||
get_upcoming_games, and group_games_by_series — verifying the
|
||||
asyncio.gather parallelization and post-fetch filtering logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from services.schedule_service import ScheduleService
|
||||
from tests.factories import GameFactory, TeamFactory
|
||||
|
||||
|
||||
def _game(game_id, week, away_abbrev, home_abbrev, **kwargs):
|
||||
"""Create a Game with distinct team IDs per matchup."""
|
||||
return GameFactory.create(
|
||||
id=game_id,
|
||||
week=week,
|
||||
away_team=TeamFactory.create(id=game_id * 10, abbrev=away_abbrev),
|
||||
home_team=TeamFactory.create(id=game_id * 10 + 1, abbrev=home_abbrev),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class TestGetWeekSchedule:
|
||||
"""Tests for ScheduleService.get_week_schedule — the HTTP layer."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
svc = ScheduleService()
|
||||
svc.get_client = AsyncMock()
|
||||
return svc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success(self, service):
|
||||
"""get_week_schedule returns parsed Game objects on a normal response."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = {
|
||||
"games": [
|
||||
{
|
||||
"id": 1,
|
||||
"season": 12,
|
||||
"week": 5,
|
||||
"game_num": 1,
|
||||
"season_type": "regular",
|
||||
"away_team": {
|
||||
"id": 10,
|
||||
"abbrev": "NYY",
|
||||
"sname": "NYY",
|
||||
"lname": "New York",
|
||||
"season": 12,
|
||||
},
|
||||
"home_team": {
|
||||
"id": 11,
|
||||
"abbrev": "BOS",
|
||||
"sname": "BOS",
|
||||
"lname": "Boston",
|
||||
"season": 12,
|
||||
},
|
||||
"away_score": 4,
|
||||
"home_score": 2,
|
||||
}
|
||||
]
|
||||
}
|
||||
service.get_client.return_value = mock_client
|
||||
|
||||
games = await service.get_week_schedule(12, 5)
|
||||
|
||||
assert len(games) == 1
|
||||
assert games[0].away_team.abbrev == "NYY"
|
||||
assert games[0].home_team.abbrev == "BOS"
|
||||
assert games[0].is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response(self, service):
|
||||
"""get_week_schedule returns [] when the API has no games."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = {"games": []}
|
||||
service.get_client.return_value = mock_client
|
||||
|
||||
games = await service.get_week_schedule(12, 99)
|
||||
assert games == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_returns_empty(self, service):
|
||||
"""get_week_schedule returns [] on API error (no exception raised)."""
|
||||
service.get_client.side_effect = Exception("connection refused")
|
||||
|
||||
games = await service.get_week_schedule(12, 1)
|
||||
assert games == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_games_key(self, service):
|
||||
"""get_week_schedule returns [] when response lacks 'games' key."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = {"status": "ok"}
|
||||
service.get_client.return_value = mock_client
|
||||
|
||||
games = await service.get_week_schedule(12, 1)
|
||||
assert games == []
|
||||
|
||||
|
||||
class TestGetTeamSchedule:
|
||||
"""Tests for get_team_schedule — gather + team-abbrev filter."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_by_team_case_insensitive(self, service):
|
||||
"""get_team_schedule returns only games involving the requested team,
|
||||
regardless of abbreviation casing."""
|
||||
week1 = [
|
||||
_game(1, 1, "NYY", "BOS", away_score=3, home_score=1),
|
||||
_game(2, 1, "LAD", "CHC", away_score=5, home_score=2),
|
||||
]
|
||||
week2 = [
|
||||
_game(3, 2, "BOS", "NYY", away_score=2, home_score=4),
|
||||
]
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.side_effect = [week1, week2]
|
||||
result = await service.get_team_schedule(12, "nyy", weeks=2)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(
|
||||
g.away_team.abbrev == "NYY" or g.home_team.abbrev == "NYY" for g in result
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_season_fetches_18_weeks(self, service):
|
||||
"""When weeks is None, all 18 weeks are fetched via gather."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_team_schedule(12, "NYY")
|
||||
|
||||
assert mock.call_count == 18
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limited_weeks(self, service):
|
||||
"""When weeks=5, only 5 weeks are fetched."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_team_schedule(12, "NYY", weeks=5)
|
||||
|
||||
assert mock.call_count == 5
|
||||
|
||||
|
||||
class TestGetRecentGames:
|
||||
"""Tests for get_recent_games — gather + completed-only filter."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_only_completed_games(self, service):
|
||||
"""get_recent_games filters out games without scores."""
|
||||
completed = GameFactory.completed(id=1, week=10)
|
||||
incomplete = GameFactory.upcoming(id=2, week=10)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = [completed, incomplete]
|
||||
result = await service.get_recent_games(12, weeks_back=1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sorted_descending_by_week_and_game_num(self, service):
|
||||
"""Recent games are sorted most-recent first."""
|
||||
game_w10 = GameFactory.completed(id=1, week=10, game_num=2)
|
||||
game_w9 = GameFactory.completed(id=2, week=9, game_num=1)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.side_effect = [[game_w10], [game_w9]]
|
||||
result = await service.get_recent_games(12, weeks_back=2)
|
||||
|
||||
assert result[0].week == 10
|
||||
assert result[1].week == 9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_negative_weeks(self, service):
|
||||
"""Weeks that would be <= 0 are excluded from fetch."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_recent_games(12, weeks_back=15)
|
||||
|
||||
# weeks_to_fetch = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] — only 10 valid weeks
|
||||
assert mock.call_count == 10
|
||||
|
||||
|
||||
class TestGetUpcomingGames:
|
||||
"""Tests for get_upcoming_games — gather all 18 weeks + incomplete filter."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_only_incomplete_games(self, service):
|
||||
"""get_upcoming_games filters out completed games."""
|
||||
completed = GameFactory.completed(id=1, week=5)
|
||||
upcoming = GameFactory.upcoming(id=2, week=5)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = [completed, upcoming]
|
||||
result = await service.get_upcoming_games(12)
|
||||
|
||||
assert len(result) == 18 # 1 incomplete game per week × 18 weeks
|
||||
assert all(not g.is_completed for g in result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sorted_ascending_by_week_and_game_num(self, service):
|
||||
"""Upcoming games are sorted earliest first."""
|
||||
game_w3 = GameFactory.upcoming(id=1, week=3, game_num=1)
|
||||
game_w1 = GameFactory.upcoming(id=2, week=1, game_num=2)
|
||||
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
|
||||
def side_effect(season, week):
|
||||
if week == 1:
|
||||
return [game_w1]
|
||||
if week == 3:
|
||||
return [game_w3]
|
||||
return []
|
||||
|
||||
mock.side_effect = side_effect
|
||||
result = await service.get_upcoming_games(12)
|
||||
|
||||
assert result[0].week == 1
|
||||
assert result[1].week == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_all_18_weeks(self, service):
|
||||
"""All 18 weeks are fetched in parallel (no early exit)."""
|
||||
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = []
|
||||
await service.get_upcoming_games(12)
|
||||
|
||||
assert mock.call_count == 18
|
||||
|
||||
|
||||
class TestGroupGamesBySeries:
|
||||
"""Tests for group_games_by_series — synchronous grouping logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return ScheduleService()
|
||||
|
||||
def test_groups_by_alphabetical_pairing(self, service):
|
||||
"""Games between the same two teams are grouped under one key,
|
||||
with the alphabetically-first team first in the tuple."""
|
||||
games = [
|
||||
_game(1, 1, "NYY", "BOS", game_num=1),
|
||||
_game(2, 1, "BOS", "NYY", game_num=2),
|
||||
_game(3, 1, "LAD", "CHC", game_num=1),
|
||||
]
|
||||
|
||||
result = service.group_games_by_series(games)
|
||||
|
||||
assert ("BOS", "NYY") in result
|
||||
assert len(result[("BOS", "NYY")]) == 2
|
||||
assert ("CHC", "LAD") in result
|
||||
assert len(result[("CHC", "LAD")]) == 1
|
||||
|
||||
def test_sorted_by_game_num_within_series(self, service):
|
||||
"""Games within each series are sorted by game_num."""
|
||||
games = [
|
||||
_game(1, 1, "NYY", "BOS", game_num=3),
|
||||
_game(2, 1, "NYY", "BOS", game_num=1),
|
||||
_game(3, 1, "NYY", "BOS", game_num=2),
|
||||
]
|
||||
|
||||
result = service.group_games_by_series(games)
|
||||
series = result[("BOS", "NYY")]
|
||||
assert [g.game_num for g in series] == [1, 2, 3]
|
||||
|
||||
def test_empty_input(self, service):
|
||||
"""Empty games list returns empty dict."""
|
||||
assert service.group_games_by_series([]) == {}
|
||||
111
tests/test_services_stats.py
Normal file
111
tests/test_services_stats.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""
|
||||
Tests for StatsService
|
||||
|
||||
Validates stats service functionality including concurrent stat retrieval
|
||||
and error handling in get_player_stats().
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from services.stats_service import StatsService
|
||||
|
||||
|
||||
class TestStatsServiceGetPlayerStats:
|
||||
"""Test StatsService.get_player_stats() concurrent retrieval."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
"""Create a fresh StatsService instance for testing."""
|
||||
return StatsService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_batting_stats(self):
|
||||
"""Create a mock BattingStats object."""
|
||||
stats = MagicMock()
|
||||
stats.avg = 0.300
|
||||
return stats
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pitching_stats(self):
|
||||
"""Create a mock PitchingStats object."""
|
||||
stats = MagicMock()
|
||||
stats.era = 3.50
|
||||
return stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_stats_returned(
|
||||
self, service, mock_batting_stats, mock_pitching_stats
|
||||
):
|
||||
"""When both batting and pitching stats exist, both are returned.
|
||||
|
||||
Verifies that get_player_stats returns a tuple of (batting, pitching)
|
||||
when both stat types are available for the player.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
|
||||
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=100, season=12)
|
||||
|
||||
assert batting is mock_batting_stats
|
||||
assert pitching is mock_pitching_stats
|
||||
service.get_batting_stats.assert_called_once_with(100, 12)
|
||||
service.get_pitching_stats.assert_called_once_with(100, 12)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batting_only(self, service, mock_batting_stats):
|
||||
"""When only batting stats exist, pitching is None.
|
||||
|
||||
Covers the case of a position player with no pitching record.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=200, season=12)
|
||||
|
||||
assert batting is mock_batting_stats
|
||||
assert pitching is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pitching_only(self, service, mock_pitching_stats):
|
||||
"""When only pitching stats exist, batting is None.
|
||||
|
||||
Covers the case of a pitcher with no batting record.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=None)
|
||||
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=300, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is mock_pitching_stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_stats_found(self, service):
|
||||
"""When no stats exist for the player, both are None.
|
||||
|
||||
Covers the case where a player has no stats for the given season
|
||||
(e.g., didn't play).
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=None)
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=400, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_none_tuple(self, service):
|
||||
"""When an exception occurs, (None, None) is returned.
|
||||
|
||||
The get_player_stats method wraps both calls in a try/except and
|
||||
returns (None, None) on any error, ensuring callers always get a tuple.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down"))
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=500, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is None
|
||||
@ -3,10 +3,16 @@ Tests for shared autocomplete utility functions.
|
||||
|
||||
Validates the shared autocomplete functions used across multiple command modules.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from utils.autocomplete import player_autocomplete, team_autocomplete, major_league_team_autocomplete
|
||||
import utils.autocomplete
|
||||
from utils.autocomplete import (
|
||||
player_autocomplete,
|
||||
team_autocomplete,
|
||||
major_league_team_autocomplete,
|
||||
)
|
||||
from tests.factories import PlayerFactory, TeamFactory
|
||||
from models.team import RosterType
|
||||
|
||||
@ -14,6 +20,13 @@ from models.team import RosterType
|
||||
class TestPlayerAutocomplete:
|
||||
"""Test player autocomplete functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_user_team_cache(self):
|
||||
"""Clear the module-level user team cache before each test to prevent interference."""
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
yield
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interaction(self):
|
||||
"""Create a mock Discord interaction."""
|
||||
@ -26,41 +39,43 @@ class TestPlayerAutocomplete:
|
||||
"""Test successful player autocomplete."""
|
||||
mock_players = [
|
||||
PlayerFactory.mike_trout(id=1),
|
||||
PlayerFactory.ronald_acuna(id=2)
|
||||
PlayerFactory.ronald_acuna(id=2),
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.player_service') as mock_service:
|
||||
with patch("utils.autocomplete.player_service") as mock_service:
|
||||
mock_service.search_players = AsyncMock(return_value=mock_players)
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'Trout')
|
||||
choices = await player_autocomplete(mock_interaction, "Trout")
|
||||
|
||||
assert len(choices) == 2
|
||||
assert choices[0].name == 'Mike Trout (CF)'
|
||||
assert choices[0].value == 'Mike Trout'
|
||||
assert choices[1].name == 'Ronald Acuna Jr. (OF)'
|
||||
assert choices[1].value == 'Ronald Acuna Jr.'
|
||||
assert choices[0].name == "Mike Trout (CF)"
|
||||
assert choices[0].value == "Mike Trout"
|
||||
assert choices[1].name == "Ronald Acuna Jr. (OF)"
|
||||
assert choices[1].value == "Ronald Acuna Jr."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_with_team_info(self, mock_interaction):
|
||||
"""Test player autocomplete with team information."""
|
||||
mock_team = TeamFactory.create(id=499, abbrev='LAA', sname='Angels', lname='Los Angeles Angels')
|
||||
mock_team = TeamFactory.create(
|
||||
id=499, abbrev="LAA", sname="Angels", lname="Los Angeles Angels"
|
||||
)
|
||||
mock_player = PlayerFactory.mike_trout(id=1)
|
||||
mock_player.team = mock_team
|
||||
|
||||
with patch('utils.autocomplete.player_service') as mock_service:
|
||||
with patch("utils.autocomplete.player_service") as mock_service:
|
||||
mock_service.search_players = AsyncMock(return_value=[mock_player])
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'Trout')
|
||||
choices = await player_autocomplete(mock_interaction, "Trout")
|
||||
|
||||
assert len(choices) == 1
|
||||
assert choices[0].name == 'Mike Trout (CF - LAA)'
|
||||
assert choices[0].value == 'Mike Trout'
|
||||
assert choices[0].name == "Mike Trout (CF - LAA)"
|
||||
assert choices[0].value == "Mike Trout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_prioritizes_user_team(self, mock_interaction):
|
||||
"""Test that user's team players are prioritized in autocomplete."""
|
||||
user_team = TeamFactory.create(id=1, abbrev='POR', sname='Loggers')
|
||||
other_team = TeamFactory.create(id=2, abbrev='LAA', sname='Angels')
|
||||
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
|
||||
other_team = TeamFactory.create(id=2, abbrev="LAA", sname="Angels")
|
||||
|
||||
# Create players - one from user's team, one from other team
|
||||
user_player = PlayerFactory.mike_trout(id=1)
|
||||
@ -71,32 +86,35 @@ class TestPlayerAutocomplete:
|
||||
other_player.team = other_team
|
||||
other_player.team_id = other_team.id
|
||||
|
||||
with patch('utils.autocomplete.player_service') as mock_service, \
|
||||
patch('utils.autocomplete.get_user_major_league_team') as mock_get_team:
|
||||
|
||||
mock_service.search_players = AsyncMock(return_value=[other_player, user_player])
|
||||
with (
|
||||
patch("utils.autocomplete.player_service") as mock_service,
|
||||
patch("utils.autocomplete.get_user_major_league_team") as mock_get_team,
|
||||
):
|
||||
mock_service.search_players = AsyncMock(
|
||||
return_value=[other_player, user_player]
|
||||
)
|
||||
mock_get_team.return_value = user_team
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'player')
|
||||
choices = await player_autocomplete(mock_interaction, "player")
|
||||
|
||||
assert len(choices) == 2
|
||||
# User's team player should be first
|
||||
assert choices[0].name == 'Mike Trout (CF - POR)'
|
||||
assert choices[1].name == 'Ronald Acuna Jr. (OF - LAA)'
|
||||
assert choices[0].name == "Mike Trout (CF - POR)"
|
||||
assert choices[1].name == "Ronald Acuna Jr. (OF - LAA)"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_short_input(self, mock_interaction):
|
||||
"""Test player autocomplete with short input returns empty."""
|
||||
choices = await player_autocomplete(mock_interaction, 'T')
|
||||
choices = await player_autocomplete(mock_interaction, "T")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_player_autocomplete_error_handling(self, mock_interaction):
|
||||
"""Test player autocomplete error handling."""
|
||||
with patch('utils.autocomplete.player_service') as mock_service:
|
||||
with patch("utils.autocomplete.player_service") as mock_service:
|
||||
mock_service.search_players.side_effect = Exception("API Error")
|
||||
|
||||
choices = await player_autocomplete(mock_interaction, 'Trout')
|
||||
choices = await player_autocomplete(mock_interaction, "Trout")
|
||||
assert len(choices) == 0
|
||||
|
||||
|
||||
@ -114,35 +132,35 @@ class TestTeamAutocomplete:
|
||||
async def test_team_autocomplete_success(self, mock_interaction):
|
||||
"""Test successful team autocomplete."""
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'),
|
||||
TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'),
|
||||
TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'),
|
||||
TeamFactory.create(id=4, abbrev='POR', sname='Loggers')
|
||||
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"),
|
||||
TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"),
|
||||
TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"),
|
||||
TeamFactory.create(id=4, abbrev="POR", sname="Loggers"),
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
choices = await team_autocomplete(mock_interaction, 'la')
|
||||
choices = await team_autocomplete(mock_interaction, "la")
|
||||
|
||||
assert len(choices) == 3 # All teams with 'la' in abbrev or sname
|
||||
assert any('LAA' in choice.name for choice in choices)
|
||||
assert any('LAAMIL' in choice.name for choice in choices)
|
||||
assert any('LAAAIL' in choice.name for choice in choices)
|
||||
assert any("LAA" in choice.name for choice in choices)
|
||||
assert any("LAAMIL" in choice.name for choice in choices)
|
||||
assert any("LAAAIL" in choice.name for choice in choices)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_autocomplete_short_input(self, mock_interaction):
|
||||
"""Test team autocomplete with very short input."""
|
||||
choices = await team_autocomplete(mock_interaction, '')
|
||||
choices = await team_autocomplete(mock_interaction, "")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_autocomplete_error_handling(self, mock_interaction):
|
||||
"""Test team autocomplete error handling."""
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season.side_effect = Exception("API Error")
|
||||
|
||||
choices = await team_autocomplete(mock_interaction, 'LAA')
|
||||
choices = await team_autocomplete(mock_interaction, "LAA")
|
||||
assert len(choices) == 0
|
||||
|
||||
|
||||
@ -157,101 +175,197 @@ class TestMajorLeagueTeamAutocomplete:
|
||||
return interaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_filters_correctly(self, mock_interaction):
|
||||
async def test_major_league_team_autocomplete_filters_correctly(
|
||||
self, mock_interaction
|
||||
):
|
||||
"""Test that only major league teams are returned."""
|
||||
# Create teams with different roster types
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'), # ML
|
||||
TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'), # MiL
|
||||
TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'), # IL
|
||||
TeamFactory.create(id=4, abbrev='FA', sname='Free Agents'), # FA
|
||||
TeamFactory.create(id=5, abbrev='POR', sname='Loggers'), # ML
|
||||
TeamFactory.create(id=6, abbrev='PORMIL', sname='Portland MiL'), # MiL
|
||||
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"), # ML
|
||||
TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"), # MiL
|
||||
TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"), # IL
|
||||
TeamFactory.create(id=4, abbrev="FA", sname="Free Agents"), # FA
|
||||
TeamFactory.create(id=5, abbrev="POR", sname="Loggers"), # ML
|
||||
TeamFactory.create(id=6, abbrev="PORMIL", sname="Portland MiL"), # MiL
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'l')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "l")
|
||||
|
||||
# Should only return major league teams that match 'l' (LAA, POR)
|
||||
choice_values = [choice.value for choice in choices]
|
||||
assert 'LAA' in choice_values
|
||||
assert 'POR' in choice_values
|
||||
assert "LAA" in choice_values
|
||||
assert "POR" in choice_values
|
||||
assert len(choice_values) == 2
|
||||
# Should NOT include MiL, IL, or FA teams
|
||||
assert 'LAAMIL' not in choice_values
|
||||
assert 'LAAAIL' not in choice_values
|
||||
assert 'FA' not in choice_values
|
||||
assert 'PORMIL' not in choice_values
|
||||
assert "LAAMIL" not in choice_values
|
||||
assert "LAAAIL" not in choice_values
|
||||
assert "FA" not in choice_values
|
||||
assert "PORMIL" not in choice_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_matching(self, mock_interaction):
|
||||
"""Test search matching on abbreviation and short name."""
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='LAA', sname='Angels'),
|
||||
TeamFactory.create(id=2, abbrev='LAD', sname='Dodgers'),
|
||||
TeamFactory.create(id=3, abbrev='POR', sname='Loggers'),
|
||||
TeamFactory.create(id=4, abbrev='BOS', sname='Red Sox'),
|
||||
TeamFactory.create(id=1, abbrev="LAA", sname="Angels"),
|
||||
TeamFactory.create(id=2, abbrev="LAD", sname="Dodgers"),
|
||||
TeamFactory.create(id=3, abbrev="POR", sname="Loggers"),
|
||||
TeamFactory.create(id=4, abbrev="BOS", sname="Red Sox"),
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
# Test abbreviation matching
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'la')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "la")
|
||||
assert len(choices) == 2 # LAA and LAD
|
||||
choice_values = [choice.value for choice in choices]
|
||||
assert 'LAA' in choice_values
|
||||
assert 'LAD' in choice_values
|
||||
assert "LAA" in choice_values
|
||||
assert "LAD" in choice_values
|
||||
|
||||
# Test short name matching
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'red')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "red")
|
||||
assert len(choices) == 1
|
||||
assert choices[0].value == 'BOS'
|
||||
assert choices[0].value == "BOS"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_short_input(self, mock_interaction):
|
||||
"""Test major league team autocomplete with very short input."""
|
||||
choices = await major_league_team_autocomplete(mock_interaction, '')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_error_handling(self, mock_interaction):
|
||||
async def test_major_league_team_autocomplete_error_handling(
|
||||
self, mock_interaction
|
||||
):
|
||||
"""Test major league team autocomplete error handling."""
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season.side_effect = Exception("API Error")
|
||||
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'LAA')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "LAA")
|
||||
assert len(choices) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_major_league_team_autocomplete_roster_type_detection(self, mock_interaction):
|
||||
async def test_major_league_team_autocomplete_roster_type_detection(
|
||||
self, mock_interaction
|
||||
):
|
||||
"""Test that roster type detection works correctly for edge cases."""
|
||||
# Test edge cases like teams whose abbreviation ends in 'M' + 'IL'
|
||||
mock_teams = [
|
||||
TeamFactory.create(id=1, abbrev='BHM', sname='Iron'), # ML team ending in 'M'
|
||||
TeamFactory.create(id=2, abbrev='BHMIL', sname='Iron IL'), # IL team (BHM + IL)
|
||||
TeamFactory.create(id=3, abbrev='NYYMIL', sname='Staten Island RailRiders'), # MiL team (NYY + MIL)
|
||||
TeamFactory.create(id=4, abbrev='NYY', sname='Yankees'), # ML team
|
||||
TeamFactory.create(
|
||||
id=1, abbrev="BHM", sname="Iron"
|
||||
), # ML team ending in 'M'
|
||||
TeamFactory.create(
|
||||
id=2, abbrev="BHMIL", sname="Iron IL"
|
||||
), # IL team (BHM + IL)
|
||||
TeamFactory.create(
|
||||
id=3, abbrev="NYYMIL", sname="Staten Island RailRiders"
|
||||
), # MiL team (NYY + MIL)
|
||||
TeamFactory.create(id=4, abbrev="NYY", sname="Yankees"), # ML team
|
||||
]
|
||||
|
||||
with patch('utils.autocomplete.team_service') as mock_service:
|
||||
with patch("utils.autocomplete.team_service") as mock_service:
|
||||
mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams)
|
||||
|
||||
choices = await major_league_team_autocomplete(mock_interaction, 'b')
|
||||
choices = await major_league_team_autocomplete(mock_interaction, "b")
|
||||
|
||||
# Should only return major league teams
|
||||
choice_values = [choice.value for choice in choices]
|
||||
assert 'BHM' in choice_values # Major league team
|
||||
assert 'BHMIL' not in choice_values # Should be detected as IL, not MiL
|
||||
assert 'NYYMIL' not in choice_values # Minor league team
|
||||
assert "BHM" in choice_values # Major league team
|
||||
assert "BHMIL" not in choice_values # Should be detected as IL, not MiL
|
||||
assert "NYYMIL" not in choice_values # Minor league team
|
||||
|
||||
# Verify the roster type detection is working
|
||||
bhm_team = next(t for t in mock_teams if t.abbrev == 'BHM')
|
||||
bhmil_team = next(t for t in mock_teams if t.abbrev == 'BHMIL')
|
||||
nyymil_team = next(t for t in mock_teams if t.abbrev == 'NYYMIL')
|
||||
bhm_team = next(t for t in mock_teams if t.abbrev == "BHM")
|
||||
bhmil_team = next(t for t in mock_teams if t.abbrev == "BHMIL")
|
||||
nyymil_team = next(t for t in mock_teams if t.abbrev == "NYYMIL")
|
||||
|
||||
assert bhm_team.roster_type() == RosterType.MAJOR_LEAGUE
|
||||
assert bhmil_team.roster_type() == RosterType.INJURED_LIST
|
||||
assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE
|
||||
|
||||
|
||||
class TestGetCachedUserTeam:
|
||||
"""Test the _get_cached_user_team caching helper.
|
||||
|
||||
Verifies that the cache avoids redundant get_user_major_league_team calls
|
||||
on repeated invocations within the TTL window, and that expired entries are
|
||||
re-fetched.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache(self):
|
||||
"""Isolate each test from cache state left by other tests."""
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
yield
|
||||
utils.autocomplete._user_team_cache.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interaction(self):
|
||||
interaction = MagicMock()
|
||||
interaction.user.id = 99999
|
||||
return interaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caches_result_on_first_call(self, mock_interaction):
|
||||
"""First call populates the cache; API function called exactly once."""
|
||||
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
|
||||
|
||||
with patch(
|
||||
"utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = user_team
|
||||
|
||||
from utils.autocomplete import _get_cached_user_team
|
||||
|
||||
result1 = await _get_cached_user_team(mock_interaction)
|
||||
result2 = await _get_cached_user_team(mock_interaction)
|
||||
|
||||
assert result1 is user_team
|
||||
assert result2 is user_team
|
||||
# API called only once despite two invocations
|
||||
mock_get_team.assert_called_once_with(99999)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_fetches_after_ttl_expires(self, mock_interaction):
|
||||
"""Expired cache entries cause a fresh API call."""
|
||||
import time
|
||||
|
||||
user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers")
|
||||
|
||||
with patch(
|
||||
"utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = user_team
|
||||
|
||||
from utils.autocomplete import _get_cached_user_team, _USER_TEAM_CACHE_TTL
|
||||
|
||||
# Seed the cache with a timestamp that is already expired
|
||||
utils.autocomplete._user_team_cache[99999] = (
|
||||
user_team,
|
||||
time.time() - _USER_TEAM_CACHE_TTL - 1,
|
||||
)
|
||||
|
||||
await _get_cached_user_team(mock_interaction)
|
||||
|
||||
# Should have called the API to refresh the stale entry
|
||||
mock_get_team.assert_called_once_with(99999)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caches_none_result(self, mock_interaction):
|
||||
"""None (user has no team) is cached to avoid repeated API calls."""
|
||||
with patch(
|
||||
"utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = None
|
||||
|
||||
from utils.autocomplete import _get_cached_user_team
|
||||
|
||||
result1 = await _get_cached_user_team(mock_interaction)
|
||||
result2 = await _get_cached_user_team(mock_interaction)
|
||||
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
mock_get_team.assert_called_once()
|
||||
|
||||
@ -4,16 +4,33 @@ Autocomplete Utilities
|
||||
Shared autocomplete functions for Discord slash commands.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
from config import get_config
|
||||
from models.team import RosterType
|
||||
from models.team import RosterType, Team
|
||||
from services.player_service import player_service
|
||||
from services.team_service import team_service
|
||||
from utils.team_utils import get_user_major_league_team
|
||||
|
||||
# Cache for user team lookups: user_id -> (team, cached_at)
|
||||
_user_team_cache: Dict[int, Tuple[Optional[Team], float]] = {}
|
||||
_USER_TEAM_CACHE_TTL = 60 # seconds
|
||||
|
||||
|
||||
async def _get_cached_user_team(interaction: discord.Interaction) -> Optional[Team]:
|
||||
"""Return the user's major league team, cached for 60 seconds per user."""
|
||||
user_id = interaction.user.id
|
||||
if user_id in _user_team_cache:
|
||||
team, cached_at = _user_team_cache[user_id]
|
||||
if time.time() - cached_at < _USER_TEAM_CACHE_TTL:
|
||||
return team
|
||||
team = await get_user_major_league_team(user_id)
|
||||
_user_team_cache[user_id] = (team, time.time())
|
||||
return team
|
||||
|
||||
|
||||
async def player_autocomplete(
|
||||
interaction: discord.Interaction, current: str
|
||||
@ -34,12 +51,12 @@ async def player_autocomplete(
|
||||
return []
|
||||
|
||||
try:
|
||||
# Get user's team for prioritization
|
||||
user_team = await get_user_major_league_team(interaction.user.id)
|
||||
# Get user's team for prioritization (cached per user, 60s TTL)
|
||||
user_team = await _get_cached_user_team(interaction)
|
||||
|
||||
# Search for players using the search endpoint
|
||||
players = await player_service.search_players(
|
||||
current, limit=50, season=get_config().sba_season
|
||||
current, limit=25, season=get_config().sba_season
|
||||
)
|
||||
|
||||
# Separate players by team (user's team vs others)
|
||||
|
||||
@ -188,9 +188,11 @@ class CacheManager:
|
||||
|
||||
try:
|
||||
pattern = f"{prefix}:*"
|
||||
keys = await client.keys(pattern)
|
||||
if keys:
|
||||
deleted = await client.delete(*keys)
|
||||
keys_to_delete = []
|
||||
async for key in client.scan_iter(match=pattern):
|
||||
keys_to_delete.append(key)
|
||||
if keys_to_delete:
|
||||
deleted = await client.delete(*keys_to_delete)
|
||||
logger.info(f"Cleared {deleted} cache keys with prefix '{prefix}'")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
|
||||
@ -11,29 +11,29 @@ from functools import wraps
|
||||
from typing import List, Optional, Callable, Any
|
||||
from utils.logging import set_discord_context, get_contextual_logger
|
||||
|
||||
cache_logger = logging.getLogger(f'{__name__}.CacheDecorators')
|
||||
period_check_logger = logging.getLogger(f'{__name__}.PeriodCheckDecorators')
|
||||
cache_logger = logging.getLogger(f"{__name__}.CacheDecorators")
|
||||
period_check_logger = logging.getLogger(f"{__name__}.PeriodCheckDecorators")
|
||||
|
||||
|
||||
def logged_command(
|
||||
command_name: Optional[str] = None,
|
||||
command_name: Optional[str] = None,
|
||||
log_params: bool = True,
|
||||
exclude_params: Optional[List[str]] = None
|
||||
exclude_params: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Decorator for Discord commands that adds comprehensive logging.
|
||||
|
||||
|
||||
This decorator automatically handles:
|
||||
- Setting Discord context with interaction details
|
||||
- Starting/ending operation timing
|
||||
- Logging command start/completion/failure
|
||||
- Preserving function metadata and signature
|
||||
|
||||
|
||||
Args:
|
||||
command_name: Override command name (defaults to function name with slashes)
|
||||
log_params: Whether to log command parameters (default: True)
|
||||
exclude_params: List of parameter names to exclude from logging
|
||||
|
||||
|
||||
Example:
|
||||
@logged_command("/roster", exclude_params=["sensitive_data"])
|
||||
async def team_roster(self, interaction, team_name: str, season: int = None):
|
||||
@ -42,57 +42,65 @@ def logged_command(
|
||||
players = await team_service.get_roster(team.id, season)
|
||||
embed = create_roster_embed(team, players)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
|
||||
Side Effects:
|
||||
- Automatically sets Discord context for all subsequent log entries
|
||||
- Creates trace_id for request correlation
|
||||
- Logs command execution timing and results
|
||||
- Re-raises all exceptions after logging (preserves original behavior)
|
||||
|
||||
|
||||
Requirements:
|
||||
- The decorated class must have a 'logger' attribute, or one will be created
|
||||
- Function must be an async method with (self, interaction, ...) signature
|
||||
- Preserves Discord.py command registration compatibility
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())[2:] # Skip self, interaction
|
||||
exclude_set = set(exclude_params or [])
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, interaction, *args, **kwargs):
|
||||
# Auto-detect command name if not provided
|
||||
cmd_name = command_name or f"/{func.__name__.replace('_', '-')}"
|
||||
|
||||
|
||||
# Build context with safe parameter logging
|
||||
context = {"command": cmd_name}
|
||||
if log_params:
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())[2:] # Skip self, interaction
|
||||
exclude_set = set(exclude_params or [])
|
||||
|
||||
for i, (name, value) in enumerate(zip(param_names, args)):
|
||||
if name not in exclude_set:
|
||||
context[f"param_{name}"] = value
|
||||
|
||||
|
||||
set_discord_context(interaction=interaction, **context)
|
||||
|
||||
|
||||
# Get logger from the class instance or create one
|
||||
logger = getattr(self, 'logger', get_contextual_logger(f'{self.__class__.__module__}.{self.__class__.__name__}'))
|
||||
logger = getattr(
|
||||
self,
|
||||
"logger",
|
||||
get_contextual_logger(
|
||||
f"{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
),
|
||||
)
|
||||
trace_id = logger.start_operation(f"{func.__name__}_command")
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"{cmd_name} command started")
|
||||
result = await func(self, interaction, *args, **kwargs)
|
||||
logger.info(f"{cmd_name} command completed successfully")
|
||||
logger.end_operation(trace_id, "completed")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{cmd_name} command failed", error=e)
|
||||
logger.end_operation(trace_id, "failed")
|
||||
# Re-raise to maintain original exception handling behavior
|
||||
raise
|
||||
|
||||
|
||||
# Preserve signature for Discord.py command registration
|
||||
wrapper.__signature__ = inspect.signature(func) # type: ignore
|
||||
wrapper.__signature__ = sig # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ -122,6 +130,7 @@ def requires_draft_period(func):
|
||||
- Should be placed before @logged_command decorator
|
||||
- league_service must be available via import
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, interaction, *args, **kwargs):
|
||||
# Import here to avoid circular imports
|
||||
@ -133,10 +142,12 @@ def requires_draft_period(func):
|
||||
current = await league_service.get_current_state()
|
||||
|
||||
if not current:
|
||||
period_check_logger.error("Could not retrieve league state for draft period check")
|
||||
period_check_logger.error(
|
||||
"Could not retrieve league state for draft period check"
|
||||
)
|
||||
embed = EmbedTemplate.error(
|
||||
"System Error",
|
||||
"Could not verify draft period status. Please try again later."
|
||||
"Could not verify draft period status. Please try again later.",
|
||||
)
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
return
|
||||
@ -148,12 +159,12 @@ def requires_draft_period(func):
|
||||
extra={
|
||||
"user_id": interaction.user.id,
|
||||
"command": func.__name__,
|
||||
"current_week": current.week
|
||||
}
|
||||
"current_week": current.week,
|
||||
},
|
||||
)
|
||||
embed = EmbedTemplate.error(
|
||||
"Not Available",
|
||||
"Draft commands are only available in the offseason."
|
||||
"Draft commands are only available in the offseason.",
|
||||
)
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
return
|
||||
@ -161,7 +172,7 @@ def requires_draft_period(func):
|
||||
# Week <= 0, allow command to proceed
|
||||
period_check_logger.debug(
|
||||
f"Draft period check passed - week {current.week}",
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__}
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__},
|
||||
)
|
||||
return await func(self, interaction, *args, **kwargs)
|
||||
|
||||
@ -169,7 +180,7 @@ def requires_draft_period(func):
|
||||
period_check_logger.error(
|
||||
f"Error in draft period check: {e}",
|
||||
exc_info=True,
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__}
|
||||
extra={"user_id": interaction.user.id, "command": func.__name__},
|
||||
)
|
||||
# Re-raise to let error handling in logged_command handle it
|
||||
raise
|
||||
@ -182,110 +193,115 @@ def requires_draft_period(func):
|
||||
def cached_api_call(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
"""
|
||||
Decorator to add Redis caching to service methods that return List[T].
|
||||
|
||||
|
||||
This decorator will:
|
||||
1. Check cache for existing data using generated key
|
||||
2. Return cached data if found
|
||||
3. Execute original method if cache miss
|
||||
4. Cache the result for future calls
|
||||
|
||||
|
||||
Args:
|
||||
ttl: Time-to-live override in seconds (uses service default if None)
|
||||
cache_key_suffix: Additional suffix for cache key differentiation
|
||||
|
||||
|
||||
Usage:
|
||||
@cached_api_call(ttl=600, cache_key_suffix="by_season")
|
||||
async def get_teams_by_season(self, season: int) -> List[Team]:
|
||||
# Original method implementation
|
||||
|
||||
|
||||
Requirements:
|
||||
- Method must be async
|
||||
- Method must return List[T] where T is a model
|
||||
- Class must have self.cache (CacheManager instance)
|
||||
- Class must have self._generate_cache_key, self._get_cached_items, self._cache_items methods
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs) -> List[Any]:
|
||||
# Check if caching is available (service has cache manager)
|
||||
if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'):
|
||||
if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"):
|
||||
# No caching available, execute original method
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Generate cache key from method name, args, and kwargs
|
||||
method_name = f"{func.__name__}{cache_key_suffix}"
|
||||
|
||||
|
||||
# Convert args and kwargs to params list for consistent cache key
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
|
||||
# Skip 'self' and convert to params format
|
||||
params = []
|
||||
for param_name, param_value in bound_args.arguments.items():
|
||||
if param_name != 'self' and param_value is not None:
|
||||
if param_name != "self" and param_value is not None:
|
||||
params.append((param_name, param_value))
|
||||
|
||||
|
||||
cache_key = self._generate_cache_key(method_name, params)
|
||||
|
||||
|
||||
# Try to get from cache
|
||||
if hasattr(self, '_get_cached_items'):
|
||||
if hasattr(self, "_get_cached_items"):
|
||||
cached_result = await self._get_cached_items(cache_key)
|
||||
if cached_result is not None:
|
||||
cache_logger.debug(f"Cache hit: {method_name}")
|
||||
return cached_result
|
||||
|
||||
|
||||
# Cache miss - execute original method
|
||||
cache_logger.debug(f"Cache miss: {method_name}")
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Cache the result if we have items and caching methods
|
||||
if result and hasattr(self, '_cache_items'):
|
||||
if result and hasattr(self, "_cache_items"):
|
||||
await self._cache_items(cache_key, result, ttl)
|
||||
cache_logger.debug(f"Cached {len(result)} items for {method_name}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
"""
|
||||
Decorator to add Redis caching to service methods that return Optional[T].
|
||||
|
||||
|
||||
Similar to cached_api_call but for methods returning a single model instance.
|
||||
|
||||
|
||||
Args:
|
||||
ttl: Time-to-live override in seconds
|
||||
cache_key_suffix: Additional suffix for cache key differentiation
|
||||
|
||||
|
||||
Usage:
|
||||
@cached_single_item(ttl=300, cache_key_suffix="by_id")
|
||||
async def get_player(self, player_id: int) -> Optional[Player]:
|
||||
# Original method implementation
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs) -> Optional[Any]:
|
||||
# Check if caching is available
|
||||
if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'):
|
||||
if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"):
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Generate cache key
|
||||
method_name = f"{func.__name__}{cache_key_suffix}"
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
|
||||
params = []
|
||||
for param_name, param_value in bound_args.arguments.items():
|
||||
if param_name != 'self' and param_value is not None:
|
||||
if param_name != "self" and param_value is not None:
|
||||
params.append((param_name, param_value))
|
||||
|
||||
|
||||
cache_key = self._generate_cache_key(method_name, params)
|
||||
|
||||
|
||||
# Try cache first
|
||||
try:
|
||||
cached_data = await self.cache.get(cache_key)
|
||||
@ -293,12 +309,14 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
cache_logger.debug(f"Cache hit: {method_name}")
|
||||
return self.model_class.from_api_data(cached_data)
|
||||
except Exception as e:
|
||||
cache_logger.warning(f"Error reading single item cache for {cache_key}: {e}")
|
||||
|
||||
cache_logger.warning(
|
||||
f"Error reading single item cache for {cache_key}: {e}"
|
||||
)
|
||||
|
||||
# Cache miss - execute original method
|
||||
cache_logger.debug(f"Cache miss: {method_name}")
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Cache the single result
|
||||
if result:
|
||||
try:
|
||||
@ -306,43 +324,54 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""):
|
||||
await self.cache.set(cache_key, cache_data, ttl)
|
||||
cache_logger.debug(f"Cached single item for {method_name}")
|
||||
except Exception as e:
|
||||
cache_logger.warning(f"Error caching single item for {cache_key}: {e}")
|
||||
|
||||
cache_logger.warning(
|
||||
f"Error caching single item for {cache_key}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def cache_invalidate(*cache_patterns: str):
|
||||
"""
|
||||
Decorator to invalidate cache entries when data is modified.
|
||||
|
||||
|
||||
Args:
|
||||
cache_patterns: Cache key patterns to invalidate (supports prefix matching)
|
||||
|
||||
|
||||
Usage:
|
||||
@cache_invalidate("players_by_team", "teams_by_season")
|
||||
async def update_player(self, player_id: int, updates: dict) -> Optional[Player]:
|
||||
# Original method implementation
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
# Execute original method first
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Invalidate specified cache patterns
|
||||
if hasattr(self, 'cache'):
|
||||
if hasattr(self, "cache"):
|
||||
for pattern in cache_patterns:
|
||||
try:
|
||||
cleared = await self.cache.clear_prefix(f"sba:{self.endpoint}_{pattern}")
|
||||
cleared = await self.cache.clear_prefix(
|
||||
f"sba:{self.endpoint}_{pattern}"
|
||||
)
|
||||
if cleared > 0:
|
||||
cache_logger.info(f"Invalidated {cleared} cache entries for pattern: {pattern}")
|
||||
cache_logger.info(
|
||||
f"Invalidated {cleared} cache entries for pattern: {pattern}"
|
||||
)
|
||||
except Exception as e:
|
||||
cache_logger.warning(f"Error invalidating cache pattern {pattern}: {e}")
|
||||
|
||||
cache_logger.warning(
|
||||
f"Error invalidating cache pattern {pattern}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
return decorator
|
||||
|
||||
@ -24,6 +24,8 @@ JSONValue = Union[
|
||||
str, int, float, bool, None, dict[str, Any], list[Any] # nested object # arrays
|
||||
]
|
||||
|
||||
_SERIALIZABLE_TYPES = (str, int, float, bool, type(None))
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Custom JSON formatter for structured file logging."""
|
||||
@ -93,11 +95,11 @@ class JSONFormatter(logging.Formatter):
|
||||
extra_data = {}
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in excluded_keys:
|
||||
# Ensure JSON serializable
|
||||
try:
|
||||
json.dumps(value)
|
||||
if isinstance(value, _SERIALIZABLE_TYPES) or isinstance(
|
||||
value, (list, dict)
|
||||
):
|
||||
extra_data[key] = value
|
||||
except (TypeError, ValueError):
|
||||
else:
|
||||
extra_data[key] = str(value)
|
||||
|
||||
if extra_data:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user