diff --git a/commands/transactions/trade.py b/commands/transactions/trade.py index ec29c55..b603c0e 100644 --- a/commands/transactions/trade.py +++ b/commands/transactions/trade.py @@ -18,7 +18,9 @@ from utils.team_utils import validate_user_has_team, get_team_by_abbrev_with_val from services.trade_builder import ( TradeBuilder, get_trade_builder, - clear_trade_builder + get_trade_builder_by_team, + clear_trade_builder, + clear_trade_builder_by_team, ) from services.player_service import player_service from models.team import RosterType @@ -183,18 +185,20 @@ class TradeCommands(commands.Cog): """Add a team to an existing trade.""" await interaction.response.defer(ephemeral=True) - # Check if user has an active trade - trade_key = f"{interaction.user.id}:trade" - from services.trade_builder import _active_trade_builders - if trade_key not in _active_trade_builders: + # Get user's team first + user_team = await validate_user_has_team(interaction) + if not user_team: + return + + # Look up trade by user's team (allows any GM in the trade to participate) + trade_builder = get_trade_builder_by_team(user_team.id) + if not trade_builder: await interaction.followup.send( - "❌ You don't have an active trade. Use `/trade initiate` first.", + "❌ Your team is not part of an active trade. Use `/trade initiate` first.", ephemeral=True ) return - trade_builder = _active_trade_builders[trade_key] - # Get the team to add team_to_add = await get_team_by_abbrev_with_validation(other_team, interaction) if not team_to_add: @@ -264,23 +268,20 @@ class TradeCommands(commands.Cog): """Add a player move to the trade.""" await interaction.response.defer(ephemeral=True) - # Check if user has an active trade - trade_key = f"{interaction.user.id}:trade" - from services.trade_builder import _active_trade_builders - if trade_key not in _active_trade_builders: - await interaction.followup.send( - "❌ You don't have an active trade. Use `/trade initiate` first.", - ephemeral=True - ) - return - - trade_builder = _active_trade_builders[trade_key] - - # Get user's team + # Get user's team first user_team = await validate_user_has_team(interaction) if not user_team: return + # Look up trade by user's team (allows any GM in the trade to participate) + trade_builder = get_trade_builder_by_team(user_team.id) + if not trade_builder: + await interaction.followup.send( + "❌ Your team is not part of an active trade. Use `/trade initiate` or ask another GM to add your team.", + ephemeral=True + ) + return + # Find the player players = await player_service.search_players(player_name, limit=10, season=get_config().sba_season) if not players: @@ -374,23 +375,20 @@ class TradeCommands(commands.Cog): """Add a supplementary (internal organization) move for roster legality.""" await interaction.response.defer(ephemeral=True) - # Check if user has an active trade - trade_key = f"{interaction.user.id}:trade" - from services.trade_builder import _active_trade_builders - if trade_key not in _active_trade_builders: - await interaction.followup.send( - "❌ You don't have an active trade. Use `/trade initiate` first.", - ephemeral=True - ) - return - - trade_builder = _active_trade_builders[trade_key] - - # Get user's team + # Get user's team first user_team = await validate_user_has_team(interaction) if not user_team: return + # Look up trade by user's team (allows any GM in the trade to participate) + trade_builder = get_trade_builder_by_team(user_team.id) + if not trade_builder: + await interaction.followup.send( + "❌ Your team is not part of an active trade. Use `/trade initiate` or ask another GM to add your team.", + ephemeral=True + ) + return + # Find the player players = await player_service.search_players(player_name, limit=10, season=get_config().sba_season) if not players: @@ -468,17 +466,20 @@ class TradeCommands(commands.Cog): """View the current trade.""" await interaction.response.defer(ephemeral=True) - trade_key = f"{interaction.user.id}:trade" - from services.trade_builder import _active_trade_builders - if trade_key not in _active_trade_builders: + # Get user's team first + user_team = await validate_user_has_team(interaction) + if not user_team: + return + + # Look up trade by user's team (allows any GM in the trade to view) + trade_builder = get_trade_builder_by_team(user_team.id) + if not trade_builder: await interaction.followup.send( - "❌ You don't have an active trade.", + "❌ Your team is not part of an active trade.", ephemeral=True ) return - trade_builder = _active_trade_builders[trade_key] - # Show trade interface embed = await create_trade_embed(trade_builder) view = TradeEmbedView(trade_builder, interaction.user.id) @@ -507,25 +508,33 @@ class TradeCommands(commands.Cog): """Clear the current trade.""" await interaction.response.defer(ephemeral=True) - # Get trade_id before clearing (for channel deletion) - trade_key = f"{interaction.user.id}:trade" - from services.trade_builder import _active_trade_builders - trade_id = None - if trade_key in _active_trade_builders: - trade_id = _active_trade_builders[trade_key].trade_id + # Get user's team first + user_team = await validate_user_has_team(interaction) + if not user_team: + return + + # Look up trade by user's team (allows any GM in the trade to clear) + trade_builder = get_trade_builder_by_team(user_team.id) + if not trade_builder: + await interaction.followup.send( + "❌ Your team is not part of an active trade.", + ephemeral=True + ) + return + + trade_id = trade_builder.trade_id # Delete associated trade channel if it exists - if trade_id: - await self.channel_manager.delete_trade_channel( - guild=interaction.guild, - trade_id=trade_id - ) + await self.channel_manager.delete_trade_channel( + guild=interaction.guild, + trade_id=trade_id + ) - # Clear the trade builder - clear_trade_builder(interaction.user.id) + # Clear the trade builder using team-based function + clear_trade_builder_by_team(user_team.id) await interaction.followup.send( - "✅ Your trade has been cleared.", + "✅ The trade has been cleared.", ephemeral=True ) diff --git a/services/trade_builder.py b/services/trade_builder.py index 060b5e1..73d14f7 100644 --- a/services/trade_builder.py +++ b/services/trade_builder.py @@ -137,6 +137,10 @@ class TradeBuilder: # Create transaction builder for this team self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season) + # Register team in secondary index for multi-GM access + trade_key = f"{self.trade.initiated_by}:trade" + _team_to_trade_key[team.id] = trade_key + logger.info(f"Added team {team.abbrev} to trade {self.trade_id}") return True, "" @@ -160,10 +164,12 @@ class TradeBuilder: # Remove team removed = self.trade.remove_participant(team_id) - if removed and team_id in self._team_builders: - del self._team_builders[team_id] - if removed: + if team_id in self._team_builders: + del self._team_builders[team_id] + # Remove from secondary index + if team_id in _team_to_trade_key: + del _team_to_trade_key[team_id] logger.info(f"Removed team {team_id} from trade {self.trade_id}") return removed, "" if removed else "Failed to remove team" @@ -444,6 +450,9 @@ class TradeBuilder: # Global cache for active trade builders _active_trade_builders: Dict[str, TradeBuilder] = {} +# Secondary index: maps team_id -> trade_key for multi-GM access +_team_to_trade_key: Dict[int, str] = {} + def get_trade_builder(user_id: int, initiating_team: Team) -> TradeBuilder: """ @@ -456,23 +465,80 @@ def get_trade_builder(user_id: int, initiating_team: Team) -> TradeBuilder: Returns: TradeBuilder instance """ - # For now, use user_id as the key. In the future, could support multiple concurrent trades trade_key = f"{user_id}:trade" if trade_key not in _active_trade_builders: - _active_trade_builders[trade_key] = TradeBuilder(user_id, initiating_team) + builder = TradeBuilder(user_id, initiating_team) + _active_trade_builders[trade_key] = builder + # Register initiating team in secondary index for multi-GM access + _team_to_trade_key[initiating_team.id] = trade_key return _active_trade_builders[trade_key] +def get_trade_builder_by_team(team_id: int) -> Optional[TradeBuilder]: + """ + Get trade builder that includes a specific team. + + This allows any GM whose team is participating in a trade to access + the trade builder, not just the initiator. + + Args: + team_id: Team ID to look up + + Returns: + TradeBuilder if team is in an active trade, None otherwise + """ + trade_key = _team_to_trade_key.get(team_id) + if trade_key: + return _active_trade_builders.get(trade_key) + return None + + def clear_trade_builder(user_id: int) -> None: - """Clear trade builder for a user.""" + """Clear trade builder for a user and remove all team mappings.""" trade_key = f"{user_id}:trade" if trade_key in _active_trade_builders: + # Remove all team mappings for this trade + builder = _active_trade_builders[trade_key] + for team in builder.participating_teams: + if team.id in _team_to_trade_key: + del _team_to_trade_key[team.id] + del _active_trade_builders[trade_key] logger.info(f"Cleared trade builder for user {user_id}") +def clear_trade_builder_by_team(team_id: int) -> bool: + """ + Clear trade builder that includes a specific team. + + This allows any GM in a trade to clear it, not just the initiator. + + Args: + team_id: Team ID whose trade should be cleared + + Returns: + True if a trade was cleared, False if no trade found + """ + trade_key = _team_to_trade_key.get(team_id) + if not trade_key: + return False + + if trade_key in _active_trade_builders: + builder = _active_trade_builders[trade_key] + # Remove all team mappings + for team in builder.participating_teams: + if team.id in _team_to_trade_key: + del _team_to_trade_key[team.id] + + del _active_trade_builders[trade_key] + logger.info(f"Cleared trade builder via team {team_id}") + return True + + return False + + def get_active_trades() -> Dict[str, TradeBuilder]: """Get all active trade builders (for debugging/admin purposes).""" return _active_trade_builders.copy() \ No newline at end of file diff --git a/tests/test_services_trade_builder.py b/tests/test_services_trade_builder.py index 7d83fac..8c0eec0 100644 --- a/tests/test_services_trade_builder.py +++ b/tests/test_services_trade_builder.py @@ -12,8 +12,11 @@ from services.trade_builder import ( TradeBuilder, TradeValidationResult, get_trade_builder, + get_trade_builder_by_team, clear_trade_builder, - _active_trade_builders + clear_trade_builder_by_team, + _active_trade_builders, + _team_to_trade_key, ) from models.trade import TradeStatus from models.team import RosterType, Team @@ -502,6 +505,7 @@ class TestTradeBuilderCache: def setup_method(self): """Clear cache before each test.""" _active_trade_builders.clear() + _team_to_trade_key.clear() def test_get_trade_builder(self): """Test getting trade builder from cache.""" @@ -534,6 +538,185 @@ class TestTradeBuilderCache: new_builder = get_trade_builder(user_id, team) assert new_builder is not builder + def test_get_trade_builder_registers_initiating_team(self): + """ + Test that get_trade_builder registers the initiating team in the secondary index. + + The secondary index allows any GM in the trade to access the builder by team ID, + enabling multi-GM participation in trades. + """ + user_id = 12345 + team = TeamFactory.west_virginia() + + # Create builder + builder = get_trade_builder(user_id, team) + + # Secondary index should have initiating team mapped + assert team.id in _team_to_trade_key + assert _team_to_trade_key[team.id] == f"{user_id}:trade" + + def test_get_trade_builder_by_team_returns_builder(self): + """ + Test that get_trade_builder_by_team returns the correct builder for a team. + + This is the core function that enables any GM in a trade to access the builder. + """ + user_id = 12345 + team1 = TeamFactory.west_virginia() + team2 = TeamFactory.new_york() + + # Create builder and add second team + builder = get_trade_builder(user_id, team1) + builder.trade.add_participant(team2) + # Manually add to secondary index (simulating add_team) + _team_to_trade_key[team2.id] = f"{user_id}:trade" + + # Both teams should find the same builder + found_by_team1 = get_trade_builder_by_team(team1.id) + found_by_team2 = get_trade_builder_by_team(team2.id) + + assert found_by_team1 is builder + assert found_by_team2 is builder + + def test_get_trade_builder_by_team_returns_none_for_nonparticipant(self): + """ + Test that get_trade_builder_by_team returns None for a team not in any trade. + + This ensures proper error handling when a GM tries to access a trade they're not part of. + """ + user_id = 12345 + team1 = TeamFactory.west_virginia() + team3 = TeamFactory.create(id=999, abbrev="POR", name="Portland") # Non-participant + + # Create builder with team1 + get_trade_builder(user_id, team1) + + # team3 should not find any builder + found = get_trade_builder_by_team(team3.id) + assert found is None + + @pytest.mark.asyncio + async def test_add_team_registers_in_secondary_index(self): + """ + Test that add_team registers the new team in the secondary index. + + This ensures that when a new team joins a trade, their GM can immediately + access the trade builder. + """ + user_id = 12345 + team1 = TeamFactory.west_virginia() + team2 = TeamFactory.new_york() + + # Create builder + builder = get_trade_builder(user_id, team1) + + # Add second team + success, error = await builder.add_team(team2) + assert success + + # Both teams should be in secondary index + assert team1.id in _team_to_trade_key + assert team2.id in _team_to_trade_key + assert _team_to_trade_key[team1.id] == _team_to_trade_key[team2.id] + + # Both teams should find the same builder + assert get_trade_builder_by_team(team1.id) is builder + assert get_trade_builder_by_team(team2.id) is builder + + @pytest.mark.asyncio + async def test_remove_team_clears_from_secondary_index(self): + """ + Test that remove_team clears the team from the secondary index. + + This ensures that when a team is removed from a trade, their GM can no + longer access the trade builder. + """ + user_id = 12345 + team1 = TeamFactory.west_virginia() + team2 = TeamFactory.new_york() + + # Create builder and add team2 + builder = get_trade_builder(user_id, team1) + await builder.add_team(team2) + + # Both teams should be in index + assert team1.id in _team_to_trade_key + assert team2.id in _team_to_trade_key + + # Remove team2 + success, error = await builder.remove_team(team2.id) + assert success + + # team2 should be removed from index, team1 should remain + assert team1.id in _team_to_trade_key + assert team2.id not in _team_to_trade_key + + def test_clear_trade_builder_clears_secondary_index(self): + """ + Test that clear_trade_builder removes all teams from secondary index. + + This ensures that when a trade is cleared, all participating GMs lose access. + """ + user_id = 12345 + team1 = TeamFactory.west_virginia() + team2 = TeamFactory.new_york() + + # Create builder and manually add team2 to secondary index + builder = get_trade_builder(user_id, team1) + builder.trade.add_participant(team2) + _team_to_trade_key[team2.id] = f"{user_id}:trade" + + # Both teams in index + assert team1.id in _team_to_trade_key + assert team2.id in _team_to_trade_key + + # Clear trade builder + clear_trade_builder(user_id) + + # Both teams should be removed from index + assert team1.id not in _team_to_trade_key + assert team2.id not in _team_to_trade_key + + def test_clear_trade_builder_by_team_clears_all_participants(self): + """ + Test that clear_trade_builder_by_team removes all teams from secondary index. + + This allows any GM in the trade to clear it, and ensures all participants + lose access simultaneously. + """ + user_id = 12345 + team1 = TeamFactory.west_virginia() + team2 = TeamFactory.new_york() + + # Create builder and manually add team2 to secondary index + builder = get_trade_builder(user_id, team1) + builder.trade.add_participant(team2) + _team_to_trade_key[team2.id] = f"{user_id}:trade" + + # Both teams in index + assert team1.id in _team_to_trade_key + assert team2.id in _team_to_trade_key + + # Clear using team2's ID (non-initiator) + result = clear_trade_builder_by_team(team2.id) + assert result is True + + # Both teams should be removed from index + assert team1.id not in _team_to_trade_key + assert team2.id not in _team_to_trade_key + assert len(_active_trade_builders) == 0 + + def test_clear_trade_builder_by_team_returns_false_for_nonparticipant(self): + """ + Test that clear_trade_builder_by_team returns False for non-participating team. + + This ensures proper error handling when a GM not in the trade tries to clear it. + """ + team3 = TeamFactory.create(id=999, abbrev="POR", name="Portland") # Non-participant + + result = clear_trade_builder_by_team(team3.id) + assert result is False + class TestTradeValidationResult: """Test TradeValidationResult functionality."""