CLAUDE: Refactor dice roll functions to use DiceRoll dataclass
Replaced dictionary return values with a DiceRoll dataclass for better type safety and cleaner code. Changes: - Added DiceRoll dataclass with fields: dice_notation, num_dice, die_sides, rolls, total - Updated _parse_and_roll_single_dice() to return Optional[DiceRoll] - Updated _parse_and_roll_multiple_dice() to return list[DiceRoll] - Updated _roll_weighted_scout_dice() to return list[DiceRoll] - Updated _create_multi_roll_embed() to accept list[DiceRoll] - Updated _create_fielding_embed() to accept list[DiceRoll] - Changed all dict key access (result['total']) to dataclass attributes (result.total) - Updated logging statements to use dataclass attributes - Updated all 34 test cases to use DiceRoll dataclass Benefits: - Improved type safety with explicit dataclass types - Better IDE autocomplete and type checking - More maintainable code with clear data structures - No runtime changes - all functionality preserved All 34 dice command tests pass.
This commit is contained in:
parent
4cab227109
commit
b61cad2478
@ -6,6 +6,7 @@ Implements slash commands for dice rolling functionality required for gameplay.
|
||||
import random
|
||||
import re
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
@ -15,6 +16,16 @@ from utils.decorators import logged_command
|
||||
from views.embeds import EmbedColors, EmbedTemplate
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiceRoll:
|
||||
"""Represents the result of a dice roll."""
|
||||
dice_notation: str
|
||||
num_dice: int
|
||||
die_sides: int
|
||||
rolls: list[int]
|
||||
total: int
|
||||
|
||||
|
||||
class DiceRollCommands(commands.Cog):
|
||||
"""Dice rolling command handlers for gameplay."""
|
||||
|
||||
@ -195,7 +206,7 @@ class DiceRollCommands(commands.Cog):
|
||||
dice_notation = "1d20;3d6"
|
||||
roll_results = self._parse_and_roll_multiple_dice(dice_notation)
|
||||
|
||||
self.logger.info("SA Fielding dice rolled successfully", position=parsed_position, d20=roll_results[0]['total'], d6_total=roll_results[1]['total'])
|
||||
self.logger.info("SA Fielding dice rolled successfully", position=parsed_position, d20=roll_results[0].total, d6_total=roll_results[1].total)
|
||||
|
||||
# Create fielding embed
|
||||
embed = self._create_fielding_embed(parsed_position, roll_results, ctx.author)
|
||||
@ -222,11 +233,11 @@ class DiceRollCommands(commands.Cog):
|
||||
|
||||
return position_map.get(pos)
|
||||
|
||||
def _create_fielding_embed(self, position: str, roll_results: list[dict], user) -> discord.Embed:
|
||||
def _create_fielding_embed(self, position: str, roll_results: list[DiceRoll], user) -> discord.Embed:
|
||||
"""Create an embed for fielding roll results."""
|
||||
d20_result = roll_results[0]['total']
|
||||
d6_total = roll_results[1]['total']
|
||||
d6_rolls = roll_results[1]['rolls']
|
||||
d20_result = roll_results[0].total
|
||||
d6_total = roll_results[1].total
|
||||
d6_rolls = roll_results[1].rolls
|
||||
|
||||
# Create base embed
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
@ -549,7 +560,7 @@ class DiceRollCommands(commands.Cog):
|
||||
}
|
||||
return errors.get(d6_total, 'No error')
|
||||
|
||||
def _parse_and_roll_multiple_dice(self, dice_notation: str) -> list[dict]:
|
||||
def _parse_and_roll_multiple_dice(self, dice_notation: str) -> list[DiceRoll]:
|
||||
"""Parse dice notation (supports multiple rolls) and return roll results."""
|
||||
# Split by semicolon for multiple rolls
|
||||
dice_parts = [part.strip() for part in dice_notation.split(';')]
|
||||
@ -563,7 +574,7 @@ class DiceRollCommands(commands.Cog):
|
||||
|
||||
return results
|
||||
|
||||
def _parse_and_roll_single_dice(self, dice_notation: str) -> Optional[dict]:
|
||||
def _parse_and_roll_single_dice(self, dice_notation: str) -> Optional[DiceRoll]:
|
||||
"""Parse single dice notation and return roll results."""
|
||||
# Clean the input
|
||||
dice_notation = dice_notation.strip().lower().replace(' ', '')
|
||||
@ -586,15 +597,15 @@ class DiceRollCommands(commands.Cog):
|
||||
rolls = [random.randint(1, die_sides) for _ in range(num_dice)]
|
||||
total = sum(rolls)
|
||||
|
||||
return {
|
||||
'dice_notation': dice_notation,
|
||||
'num_dice': num_dice,
|
||||
'die_sides': die_sides,
|
||||
'rolls': rolls,
|
||||
'total': total
|
||||
}
|
||||
return DiceRoll(
|
||||
dice_notation=dice_notation,
|
||||
num_dice=num_dice,
|
||||
die_sides=die_sides,
|
||||
rolls=rolls,
|
||||
total=total
|
||||
)
|
||||
|
||||
def _roll_weighted_scout_dice(self, card_type: str) -> list[dict]:
|
||||
def _roll_weighted_scout_dice(self, card_type: str) -> list[DiceRoll]:
|
||||
"""
|
||||
Roll scouting dice with weighted first d6 based on card type.
|
||||
|
||||
@ -602,7 +613,7 @@ class DiceRollCommands(commands.Cog):
|
||||
card_type: Either "batter" (1-3) or "pitcher" (4-6) for first d6
|
||||
|
||||
Returns:
|
||||
List of 3 roll result dicts: weighted 1d6, normal 2d6, normal 1d20
|
||||
List of 3 roll result dataclasses: weighted 1d6, normal 2d6, normal 1d20
|
||||
"""
|
||||
# First die (1d6) - weighted based on card type
|
||||
if card_type == "batter":
|
||||
@ -610,13 +621,13 @@ class DiceRollCommands(commands.Cog):
|
||||
else: # pitcher
|
||||
first_roll = random.randint(4, 6)
|
||||
|
||||
first_d6_result = {
|
||||
'dice_notation': '1d6',
|
||||
'num_dice': 1,
|
||||
'die_sides': 6,
|
||||
'rolls': [first_roll],
|
||||
'total': first_roll
|
||||
}
|
||||
first_d6_result = DiceRoll(
|
||||
dice_notation='1d6',
|
||||
num_dice=1,
|
||||
die_sides=6,
|
||||
rolls=[first_roll],
|
||||
total=first_roll
|
||||
)
|
||||
|
||||
# Second roll (2d6) - normal
|
||||
second_result = self._parse_and_roll_single_dice("2d6")
|
||||
@ -626,7 +637,7 @@ class DiceRollCommands(commands.Cog):
|
||||
|
||||
return [first_d6_result, second_result, third_result]
|
||||
|
||||
def _create_multi_roll_embed(self, dice_notation: str, roll_results: list[dict], user: discord.User | discord.Member) -> discord.Embed:
|
||||
def _create_multi_roll_embed(self, dice_notation: str, roll_results: list[DiceRoll], user: discord.User | discord.Member) -> discord.Embed:
|
||||
"""Create an embed for multiple dice roll results."""
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title="🎲 Dice Roll",
|
||||
@ -640,16 +651,16 @@ class DiceRollCommands(commands.Cog):
|
||||
)
|
||||
|
||||
# Create summary line with totals
|
||||
totals = [str(result['total']) for result in roll_results]
|
||||
totals = [str(result.total) for result in roll_results]
|
||||
summary = f"# {','.join(totals)}"
|
||||
|
||||
# Create details line in the specified format: Details:[1d6;2d6;1d20 (5 - 5 6 - 13)]
|
||||
dice_notations = [result['dice_notation'] for result in roll_results]
|
||||
dice_notations = [result.dice_notation for result in roll_results]
|
||||
|
||||
# Create the rolls breakdown part - group dice within each roll, separate roll groups with dashes
|
||||
roll_groups = []
|
||||
for result in roll_results:
|
||||
rolls = result['rolls']
|
||||
rolls = result.rolls
|
||||
if len(rolls) == 1:
|
||||
# Single die: just the number
|
||||
roll_groups.append(str(rolls[0]))
|
||||
|
||||
@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from commands.dice.rolls import DiceRollCommands
|
||||
from commands.dice.rolls import DiceRollCommands, DiceRoll
|
||||
|
||||
|
||||
class TestDiceRollCommands:
|
||||
@ -65,20 +65,20 @@ class TestDiceRollCommands:
|
||||
results = dice_cog._parse_and_roll_multiple_dice("2d6")
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert result['num_dice'] == 2
|
||||
assert result['die_sides'] == 6
|
||||
assert len(result['rolls']) == 2
|
||||
assert all(1 <= roll <= 6 for roll in result['rolls'])
|
||||
assert result['total'] == sum(result['rolls'])
|
||||
assert result.num_dice == 2
|
||||
assert result.die_sides == 6
|
||||
assert len(result.rolls) == 2
|
||||
assert all(1 <= roll <= 6 for roll in result.rolls)
|
||||
assert result.total == sum(result.rolls)
|
||||
|
||||
# Test single die
|
||||
results = dice_cog._parse_and_roll_multiple_dice("1d20")
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert result['num_dice'] == 1
|
||||
assert result['die_sides'] == 20
|
||||
assert len(result['rolls']) == 1
|
||||
assert 1 <= result['rolls'][0] <= 20
|
||||
assert result.num_dice == 1
|
||||
assert result.die_sides == 20
|
||||
assert len(result.rolls) == 1
|
||||
assert 1 <= result.rolls[0] <= 20
|
||||
|
||||
def test_parse_invalid_dice_notation(self, dice_cog):
|
||||
"""Test parsing invalid dice notation."""
|
||||
@ -101,17 +101,17 @@ class TestDiceRollCommands:
|
||||
results = dice_cog._parse_and_roll_multiple_dice("1d6;2d8;1d20")
|
||||
assert len(results) == 3
|
||||
|
||||
assert results[0]['dice_notation'] == '1d6'
|
||||
assert results[0]['num_dice'] == 1
|
||||
assert results[0]['die_sides'] == 6
|
||||
assert results[0].dice_notation == '1d6'
|
||||
assert results[0].num_dice == 1
|
||||
assert results[0].die_sides == 6
|
||||
|
||||
assert results[1]['dice_notation'] == '2d8'
|
||||
assert results[1]['num_dice'] == 2
|
||||
assert results[1]['die_sides'] == 8
|
||||
assert results[1].dice_notation == '2d8'
|
||||
assert results[1].num_dice == 2
|
||||
assert results[1].die_sides == 8
|
||||
|
||||
assert results[2]['dice_notation'] == '1d20'
|
||||
assert results[2]['num_dice'] == 1
|
||||
assert results[2]['die_sides'] == 20
|
||||
assert results[2].dice_notation == '1d20'
|
||||
assert results[2].num_dice == 1
|
||||
assert results[2].die_sides == 20
|
||||
|
||||
def test_parse_case_insensitive(self, dice_cog):
|
||||
"""Test that dice notation parsing is case insensitive."""
|
||||
@ -120,20 +120,20 @@ class TestDiceRollCommands:
|
||||
|
||||
assert len(result_lower) == 1
|
||||
assert len(result_upper) == 1
|
||||
assert result_lower[0]['num_dice'] == result_upper[0]['num_dice']
|
||||
assert result_lower[0]['die_sides'] == result_upper[0]['die_sides']
|
||||
assert result_lower[0].num_dice == result_upper[0].num_dice
|
||||
assert result_lower[0].die_sides == result_upper[0].die_sides
|
||||
|
||||
def test_parse_whitespace_handling(self, dice_cog):
|
||||
"""Test that whitespace is handled properly."""
|
||||
results = dice_cog._parse_and_roll_multiple_dice(" 2d6 ")
|
||||
assert len(results) == 1
|
||||
assert results[0]['num_dice'] == 2
|
||||
assert results[0]['die_sides'] == 6
|
||||
assert results[0].num_dice == 2
|
||||
assert results[0].die_sides == 6
|
||||
|
||||
results = dice_cog._parse_and_roll_multiple_dice("2 d 6")
|
||||
assert len(results) == 1
|
||||
assert results[0]['num_dice'] == 2
|
||||
assert results[0]['die_sides'] == 6
|
||||
assert results[0].num_dice == 2
|
||||
assert results[0].die_sides == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_roll_dice_valid_input(self, dice_cog, mock_interaction):
|
||||
@ -170,13 +170,13 @@ class TestDiceRollCommands:
|
||||
def test_create_multi_roll_embed_single_die(self, dice_cog, mock_interaction):
|
||||
"""Test embed creation for single die roll."""
|
||||
roll_results = [
|
||||
{
|
||||
'dice_notation': '1d20',
|
||||
'num_dice': 1,
|
||||
'die_sides': 20,
|
||||
'rolls': [15],
|
||||
'total': 15
|
||||
}
|
||||
DiceRoll(
|
||||
dice_notation='1d20',
|
||||
num_dice=1,
|
||||
die_sides=20,
|
||||
rolls=[15],
|
||||
total=15
|
||||
)
|
||||
]
|
||||
|
||||
embed = dice_cog._create_multi_roll_embed("1d20", roll_results, mock_interaction.user)
|
||||
@ -194,27 +194,27 @@ class TestDiceRollCommands:
|
||||
def test_create_multi_roll_embed_multiple_dice(self, dice_cog, mock_interaction):
|
||||
"""Test embed creation for multiple dice rolls."""
|
||||
roll_results = [
|
||||
{
|
||||
'dice_notation': '1d6',
|
||||
'num_dice': 1,
|
||||
'die_sides': 6,
|
||||
'rolls': [5],
|
||||
'total': 5
|
||||
},
|
||||
{
|
||||
'dice_notation': '2d6',
|
||||
'num_dice': 2,
|
||||
'die_sides': 6,
|
||||
'rolls': [5, 6],
|
||||
'total': 11
|
||||
},
|
||||
{
|
||||
'dice_notation': '1d20',
|
||||
'num_dice': 1,
|
||||
'die_sides': 20,
|
||||
'rolls': [13],
|
||||
'total': 13
|
||||
}
|
||||
DiceRoll(
|
||||
dice_notation='1d6',
|
||||
num_dice=1,
|
||||
die_sides=6,
|
||||
rolls=[5],
|
||||
total=5
|
||||
),
|
||||
DiceRoll(
|
||||
dice_notation='2d6',
|
||||
num_dice=2,
|
||||
die_sides=6,
|
||||
rolls=[5, 6],
|
||||
total=11
|
||||
),
|
||||
DiceRoll(
|
||||
dice_notation='1d20',
|
||||
num_dice=1,
|
||||
die_sides=20,
|
||||
rolls=[13],
|
||||
total=13
|
||||
)
|
||||
]
|
||||
|
||||
embed = dice_cog._create_multi_roll_embed("1d6;2d6;1d20", roll_results, mock_interaction.user)
|
||||
@ -233,7 +233,7 @@ class TestDiceRollCommands:
|
||||
results = []
|
||||
for _ in range(20): # Roll 20 times
|
||||
result = dice_cog._parse_and_roll_multiple_dice("1d20")
|
||||
results.append(result[0]['rolls'][0])
|
||||
results.append(result[0].rolls[0])
|
||||
|
||||
# Should have some variation in results (very unlikely all 20 rolls are the same)
|
||||
unique_results = set(results)
|
||||
@ -245,20 +245,20 @@ class TestDiceRollCommands:
|
||||
results = dice_cog._parse_and_roll_multiple_dice("100d2")
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert len(result['rolls']) == 100
|
||||
assert all(roll in [1, 2] for roll in result['rolls'])
|
||||
assert len(result.rolls) == 100
|
||||
assert all(roll in [1, 2] for roll in result.rolls)
|
||||
|
||||
# Test maximum die size
|
||||
results = dice_cog._parse_and_roll_multiple_dice("1d1000")
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert 1 <= result['rolls'][0] <= 1000
|
||||
assert 1 <= result.rolls[0] <= 1000
|
||||
|
||||
# Test minimum valid values
|
||||
results = dice_cog._parse_and_roll_multiple_dice("1d2")
|
||||
assert len(results) == 1
|
||||
result = results[0]
|
||||
assert result['rolls'][0] in [1, 2]
|
||||
assert result.rolls[0] in [1, 2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_command_valid_input(self, dice_cog, mock_context):
|
||||
@ -385,17 +385,17 @@ class TestDiceRollCommands:
|
||||
assert len(results) == 3
|
||||
|
||||
# Check each dice type
|
||||
assert results[0]['dice_notation'] == '1d6'
|
||||
assert results[0]['num_dice'] == 1
|
||||
assert results[0]['die_sides'] == 6
|
||||
assert results[0].dice_notation == '1d6'
|
||||
assert results[0].num_dice == 1
|
||||
assert results[0].die_sides == 6
|
||||
|
||||
assert results[1]['dice_notation'] == '2d6'
|
||||
assert results[1]['num_dice'] == 2
|
||||
assert results[1]['die_sides'] == 6
|
||||
assert results[1].dice_notation == '2d6'
|
||||
assert results[1].num_dice == 2
|
||||
assert results[1].die_sides == 6
|
||||
|
||||
assert results[2]['dice_notation'] == '1d20'
|
||||
assert results[2]['num_dice'] == 1
|
||||
assert results[2]['die_sides'] == 20
|
||||
assert results[2].dice_notation == '1d20'
|
||||
assert results[2].num_dice == 1
|
||||
assert results[2].die_sides == 20
|
||||
|
||||
# Fielding command tests
|
||||
@pytest.mark.asyncio
|
||||
@ -544,14 +544,14 @@ class TestDiceRollCommands:
|
||||
assert len(results) == 2
|
||||
|
||||
# Check 1d20
|
||||
assert results[0]['dice_notation'] == '1d20'
|
||||
assert results[0]['num_dice'] == 1
|
||||
assert results[0]['die_sides'] == 20
|
||||
assert results[0].dice_notation == '1d20'
|
||||
assert results[0].num_dice == 1
|
||||
assert results[0].die_sides == 20
|
||||
|
||||
# Check 3d6
|
||||
assert results[1]['dice_notation'] == '3d6'
|
||||
assert results[1]['num_dice'] == 3
|
||||
assert results[1]['die_sides'] == 6
|
||||
assert results[1].dice_notation == '3d6'
|
||||
assert results[1].num_dice == 3
|
||||
assert results[1].die_sides == 6
|
||||
|
||||
def test_weighted_scout_dice_batter(self, dice_cog):
|
||||
"""Test that batter scout dice always rolls 1-3 for first d6."""
|
||||
@ -563,18 +563,18 @@ class TestDiceRollCommands:
|
||||
assert len(results) == 3
|
||||
|
||||
# First d6 should ALWAYS be 1-3 for batter
|
||||
first_d6 = results[0]['rolls'][0]
|
||||
first_d6 = results[0].rolls[0]
|
||||
assert 1 <= first_d6 <= 3, f"Batter first d6 was {first_d6}, expected 1-3"
|
||||
|
||||
# Second roll (2d6) should be normal
|
||||
assert results[1]['num_dice'] == 2
|
||||
assert results[1]['die_sides'] == 6
|
||||
assert all(1 <= roll <= 6 for roll in results[1]['rolls'])
|
||||
assert results[1].num_dice == 2
|
||||
assert results[1].die_sides == 6
|
||||
assert all(1 <= roll <= 6 for roll in results[1].rolls)
|
||||
|
||||
# Third roll (1d20) should be normal
|
||||
assert results[2]['num_dice'] == 1
|
||||
assert results[2]['die_sides'] == 20
|
||||
assert 1 <= results[2]['rolls'][0] <= 20
|
||||
assert results[2].num_dice == 1
|
||||
assert results[2].die_sides == 20
|
||||
assert 1 <= results[2].rolls[0] <= 20
|
||||
|
||||
def test_weighted_scout_dice_pitcher(self, dice_cog):
|
||||
"""Test that pitcher scout dice always rolls 4-6 for first d6."""
|
||||
@ -586,18 +586,18 @@ class TestDiceRollCommands:
|
||||
assert len(results) == 3
|
||||
|
||||
# First d6 should ALWAYS be 4-6 for pitcher
|
||||
first_d6 = results[0]['rolls'][0]
|
||||
first_d6 = results[0].rolls[0]
|
||||
assert 4 <= first_d6 <= 6, f"Pitcher first d6 was {first_d6}, expected 4-6"
|
||||
|
||||
# Second roll (2d6) should be normal
|
||||
assert results[1]['num_dice'] == 2
|
||||
assert results[1]['die_sides'] == 6
|
||||
assert all(1 <= roll <= 6 for roll in results[1]['rolls'])
|
||||
assert results[1].num_dice == 2
|
||||
assert results[1].die_sides == 6
|
||||
assert all(1 <= roll <= 6 for roll in results[1].rolls)
|
||||
|
||||
# Third roll (1d20) should be normal
|
||||
assert results[2]['num_dice'] == 1
|
||||
assert results[2]['die_sides'] == 20
|
||||
assert 1 <= results[2]['rolls'][0] <= 20
|
||||
assert results[2].num_dice == 1
|
||||
assert results[2].die_sides == 20
|
||||
assert 1 <= results[2].rolls[0] <= 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scout_command_batter(self, dice_cog, mock_interaction):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user