CLAUDE: Refactor dice roll functions to use DiceRoll dataclass

Replaced dictionary return values with a DiceRoll dataclass for better
type safety and cleaner code.

Changes:
- Added DiceRoll dataclass with fields: dice_notation, num_dice, die_sides, rolls, total
- Updated _parse_and_roll_single_dice() to return Optional[DiceRoll]
- Updated _parse_and_roll_multiple_dice() to return list[DiceRoll]
- Updated _roll_weighted_scout_dice() to return list[DiceRoll]
- Updated _create_multi_roll_embed() to accept list[DiceRoll]
- Updated _create_fielding_embed() to accept list[DiceRoll]
- Changed all dict key access (result['total']) to dataclass attributes (result.total)
- Updated logging statements to use dataclass attributes
- Updated all 34 test cases to use DiceRoll dataclass

Benefits:
- Improved type safety with explicit dataclass types
- Better IDE autocomplete and type checking
- More maintainable code with clear data structures
- No runtime changes - all functionality preserved

All 34 dice command tests pass.
This commit is contained in:
Cal Corum 2025-10-14 14:28:19 -05:00
parent 4cab227109
commit b61cad2478
2 changed files with 125 additions and 114 deletions

View File

@ -6,6 +6,7 @@ Implements slash commands for dice rolling functionality required for gameplay.
import random
import re
from typing import Optional
from dataclasses import dataclass
import discord
from discord.ext import commands
@ -15,6 +16,16 @@ from utils.decorators import logged_command
from views.embeds import EmbedColors, EmbedTemplate
@dataclass
class DiceRoll:
"""Represents the result of a dice roll."""
dice_notation: str
num_dice: int
die_sides: int
rolls: list[int]
total: int
class DiceRollCommands(commands.Cog):
"""Dice rolling command handlers for gameplay."""
@ -195,7 +206,7 @@ class DiceRollCommands(commands.Cog):
dice_notation = "1d20;3d6"
roll_results = self._parse_and_roll_multiple_dice(dice_notation)
self.logger.info("SA Fielding dice rolled successfully", position=parsed_position, d20=roll_results[0]['total'], d6_total=roll_results[1]['total'])
self.logger.info("SA Fielding dice rolled successfully", position=parsed_position, d20=roll_results[0].total, d6_total=roll_results[1].total)
# Create fielding embed
embed = self._create_fielding_embed(parsed_position, roll_results, ctx.author)
@ -222,11 +233,11 @@ class DiceRollCommands(commands.Cog):
return position_map.get(pos)
def _create_fielding_embed(self, position: str, roll_results: list[dict], user) -> discord.Embed:
def _create_fielding_embed(self, position: str, roll_results: list[DiceRoll], user) -> discord.Embed:
"""Create an embed for fielding roll results."""
d20_result = roll_results[0]['total']
d6_total = roll_results[1]['total']
d6_rolls = roll_results[1]['rolls']
d20_result = roll_results[0].total
d6_total = roll_results[1].total
d6_rolls = roll_results[1].rolls
# Create base embed
embed = EmbedTemplate.create_base_embed(
@ -549,7 +560,7 @@ class DiceRollCommands(commands.Cog):
}
return errors.get(d6_total, 'No error')
def _parse_and_roll_multiple_dice(self, dice_notation: str) -> list[dict]:
def _parse_and_roll_multiple_dice(self, dice_notation: str) -> list[DiceRoll]:
"""Parse dice notation (supports multiple rolls) and return roll results."""
# Split by semicolon for multiple rolls
dice_parts = [part.strip() for part in dice_notation.split(';')]
@ -563,7 +574,7 @@ class DiceRollCommands(commands.Cog):
return results
def _parse_and_roll_single_dice(self, dice_notation: str) -> Optional[dict]:
def _parse_and_roll_single_dice(self, dice_notation: str) -> Optional[DiceRoll]:
"""Parse single dice notation and return roll results."""
# Clean the input
dice_notation = dice_notation.strip().lower().replace(' ', '')
@ -586,15 +597,15 @@ class DiceRollCommands(commands.Cog):
rolls = [random.randint(1, die_sides) for _ in range(num_dice)]
total = sum(rolls)
return {
'dice_notation': dice_notation,
'num_dice': num_dice,
'die_sides': die_sides,
'rolls': rolls,
'total': total
}
return DiceRoll(
dice_notation=dice_notation,
num_dice=num_dice,
die_sides=die_sides,
rolls=rolls,
total=total
)
def _roll_weighted_scout_dice(self, card_type: str) -> list[dict]:
def _roll_weighted_scout_dice(self, card_type: str) -> list[DiceRoll]:
"""
Roll scouting dice with weighted first d6 based on card type.
@ -602,7 +613,7 @@ class DiceRollCommands(commands.Cog):
card_type: Either "batter" (1-3) or "pitcher" (4-6) for first d6
Returns:
List of 3 roll result dicts: weighted 1d6, normal 2d6, normal 1d20
List of 3 roll result dataclasses: weighted 1d6, normal 2d6, normal 1d20
"""
# First die (1d6) - weighted based on card type
if card_type == "batter":
@ -610,13 +621,13 @@ class DiceRollCommands(commands.Cog):
else: # pitcher
first_roll = random.randint(4, 6)
first_d6_result = {
'dice_notation': '1d6',
'num_dice': 1,
'die_sides': 6,
'rolls': [first_roll],
'total': first_roll
}
first_d6_result = DiceRoll(
dice_notation='1d6',
num_dice=1,
die_sides=6,
rolls=[first_roll],
total=first_roll
)
# Second roll (2d6) - normal
second_result = self._parse_and_roll_single_dice("2d6")
@ -626,7 +637,7 @@ class DiceRollCommands(commands.Cog):
return [first_d6_result, second_result, third_result]
def _create_multi_roll_embed(self, dice_notation: str, roll_results: list[dict], user: discord.User | discord.Member) -> discord.Embed:
def _create_multi_roll_embed(self, dice_notation: str, roll_results: list[DiceRoll], user: discord.User | discord.Member) -> discord.Embed:
"""Create an embed for multiple dice roll results."""
embed = EmbedTemplate.create_base_embed(
title="🎲 Dice Roll",
@ -640,16 +651,16 @@ class DiceRollCommands(commands.Cog):
)
# Create summary line with totals
totals = [str(result['total']) for result in roll_results]
totals = [str(result.total) for result in roll_results]
summary = f"# {','.join(totals)}"
# Create details line in the specified format: Details:[1d6;2d6;1d20 (5 - 5 6 - 13)]
dice_notations = [result['dice_notation'] for result in roll_results]
dice_notations = [result.dice_notation for result in roll_results]
# Create the rolls breakdown part - group dice within each roll, separate roll groups with dashes
roll_groups = []
for result in roll_results:
rolls = result['rolls']
rolls = result.rolls
if len(rolls) == 1:
# Single die: just the number
roll_groups.append(str(rolls[0]))

View File

@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import discord
from discord.ext import commands
from commands.dice.rolls import DiceRollCommands
from commands.dice.rolls import DiceRollCommands, DiceRoll
class TestDiceRollCommands:
@ -65,20 +65,20 @@ class TestDiceRollCommands:
results = dice_cog._parse_and_roll_multiple_dice("2d6")
assert len(results) == 1
result = results[0]
assert result['num_dice'] == 2
assert result['die_sides'] == 6
assert len(result['rolls']) == 2
assert all(1 <= roll <= 6 for roll in result['rolls'])
assert result['total'] == sum(result['rolls'])
assert result.num_dice == 2
assert result.die_sides == 6
assert len(result.rolls) == 2
assert all(1 <= roll <= 6 for roll in result.rolls)
assert result.total == sum(result.rolls)
# Test single die
results = dice_cog._parse_and_roll_multiple_dice("1d20")
assert len(results) == 1
result = results[0]
assert result['num_dice'] == 1
assert result['die_sides'] == 20
assert len(result['rolls']) == 1
assert 1 <= result['rolls'][0] <= 20
assert result.num_dice == 1
assert result.die_sides == 20
assert len(result.rolls) == 1
assert 1 <= result.rolls[0] <= 20
def test_parse_invalid_dice_notation(self, dice_cog):
"""Test parsing invalid dice notation."""
@ -101,17 +101,17 @@ class TestDiceRollCommands:
results = dice_cog._parse_and_roll_multiple_dice("1d6;2d8;1d20")
assert len(results) == 3
assert results[0]['dice_notation'] == '1d6'
assert results[0]['num_dice'] == 1
assert results[0]['die_sides'] == 6
assert results[0].dice_notation == '1d6'
assert results[0].num_dice == 1
assert results[0].die_sides == 6
assert results[1]['dice_notation'] == '2d8'
assert results[1]['num_dice'] == 2
assert results[1]['die_sides'] == 8
assert results[1].dice_notation == '2d8'
assert results[1].num_dice == 2
assert results[1].die_sides == 8
assert results[2]['dice_notation'] == '1d20'
assert results[2]['num_dice'] == 1
assert results[2]['die_sides'] == 20
assert results[2].dice_notation == '1d20'
assert results[2].num_dice == 1
assert results[2].die_sides == 20
def test_parse_case_insensitive(self, dice_cog):
"""Test that dice notation parsing is case insensitive."""
@ -120,20 +120,20 @@ class TestDiceRollCommands:
assert len(result_lower) == 1
assert len(result_upper) == 1
assert result_lower[0]['num_dice'] == result_upper[0]['num_dice']
assert result_lower[0]['die_sides'] == result_upper[0]['die_sides']
assert result_lower[0].num_dice == result_upper[0].num_dice
assert result_lower[0].die_sides == result_upper[0].die_sides
def test_parse_whitespace_handling(self, dice_cog):
"""Test that whitespace is handled properly."""
results = dice_cog._parse_and_roll_multiple_dice(" 2d6 ")
assert len(results) == 1
assert results[0]['num_dice'] == 2
assert results[0]['die_sides'] == 6
assert results[0].num_dice == 2
assert results[0].die_sides == 6
results = dice_cog._parse_and_roll_multiple_dice("2 d 6")
assert len(results) == 1
assert results[0]['num_dice'] == 2
assert results[0]['die_sides'] == 6
assert results[0].num_dice == 2
assert results[0].die_sides == 6
@pytest.mark.asyncio
async def test_roll_dice_valid_input(self, dice_cog, mock_interaction):
@ -170,13 +170,13 @@ class TestDiceRollCommands:
def test_create_multi_roll_embed_single_die(self, dice_cog, mock_interaction):
"""Test embed creation for single die roll."""
roll_results = [
{
'dice_notation': '1d20',
'num_dice': 1,
'die_sides': 20,
'rolls': [15],
'total': 15
}
DiceRoll(
dice_notation='1d20',
num_dice=1,
die_sides=20,
rolls=[15],
total=15
)
]
embed = dice_cog._create_multi_roll_embed("1d20", roll_results, mock_interaction.user)
@ -194,27 +194,27 @@ class TestDiceRollCommands:
def test_create_multi_roll_embed_multiple_dice(self, dice_cog, mock_interaction):
"""Test embed creation for multiple dice rolls."""
roll_results = [
{
'dice_notation': '1d6',
'num_dice': 1,
'die_sides': 6,
'rolls': [5],
'total': 5
},
{
'dice_notation': '2d6',
'num_dice': 2,
'die_sides': 6,
'rolls': [5, 6],
'total': 11
},
{
'dice_notation': '1d20',
'num_dice': 1,
'die_sides': 20,
'rolls': [13],
'total': 13
}
DiceRoll(
dice_notation='1d6',
num_dice=1,
die_sides=6,
rolls=[5],
total=5
),
DiceRoll(
dice_notation='2d6',
num_dice=2,
die_sides=6,
rolls=[5, 6],
total=11
),
DiceRoll(
dice_notation='1d20',
num_dice=1,
die_sides=20,
rolls=[13],
total=13
)
]
embed = dice_cog._create_multi_roll_embed("1d6;2d6;1d20", roll_results, mock_interaction.user)
@ -233,7 +233,7 @@ class TestDiceRollCommands:
results = []
for _ in range(20): # Roll 20 times
result = dice_cog._parse_and_roll_multiple_dice("1d20")
results.append(result[0]['rolls'][0])
results.append(result[0].rolls[0])
# Should have some variation in results (very unlikely all 20 rolls are the same)
unique_results = set(results)
@ -245,20 +245,20 @@ class TestDiceRollCommands:
results = dice_cog._parse_and_roll_multiple_dice("100d2")
assert len(results) == 1
result = results[0]
assert len(result['rolls']) == 100
assert all(roll in [1, 2] for roll in result['rolls'])
assert len(result.rolls) == 100
assert all(roll in [1, 2] for roll in result.rolls)
# Test maximum die size
results = dice_cog._parse_and_roll_multiple_dice("1d1000")
assert len(results) == 1
result = results[0]
assert 1 <= result['rolls'][0] <= 1000
assert 1 <= result.rolls[0] <= 1000
# Test minimum valid values
results = dice_cog._parse_and_roll_multiple_dice("1d2")
assert len(results) == 1
result = results[0]
assert result['rolls'][0] in [1, 2]
assert result.rolls[0] in [1, 2]
@pytest.mark.asyncio
async def test_prefix_command_valid_input(self, dice_cog, mock_context):
@ -385,17 +385,17 @@ class TestDiceRollCommands:
assert len(results) == 3
# Check each dice type
assert results[0]['dice_notation'] == '1d6'
assert results[0]['num_dice'] == 1
assert results[0]['die_sides'] == 6
assert results[0].dice_notation == '1d6'
assert results[0].num_dice == 1
assert results[0].die_sides == 6
assert results[1]['dice_notation'] == '2d6'
assert results[1]['num_dice'] == 2
assert results[1]['die_sides'] == 6
assert results[1].dice_notation == '2d6'
assert results[1].num_dice == 2
assert results[1].die_sides == 6
assert results[2]['dice_notation'] == '1d20'
assert results[2]['num_dice'] == 1
assert results[2]['die_sides'] == 20
assert results[2].dice_notation == '1d20'
assert results[2].num_dice == 1
assert results[2].die_sides == 20
# Fielding command tests
@pytest.mark.asyncio
@ -544,14 +544,14 @@ class TestDiceRollCommands:
assert len(results) == 2
# Check 1d20
assert results[0]['dice_notation'] == '1d20'
assert results[0]['num_dice'] == 1
assert results[0]['die_sides'] == 20
assert results[0].dice_notation == '1d20'
assert results[0].num_dice == 1
assert results[0].die_sides == 20
# Check 3d6
assert results[1]['dice_notation'] == '3d6'
assert results[1]['num_dice'] == 3
assert results[1]['die_sides'] == 6
assert results[1].dice_notation == '3d6'
assert results[1].num_dice == 3
assert results[1].die_sides == 6
def test_weighted_scout_dice_batter(self, dice_cog):
"""Test that batter scout dice always rolls 1-3 for first d6."""
@ -563,18 +563,18 @@ class TestDiceRollCommands:
assert len(results) == 3
# First d6 should ALWAYS be 1-3 for batter
first_d6 = results[0]['rolls'][0]
first_d6 = results[0].rolls[0]
assert 1 <= first_d6 <= 3, f"Batter first d6 was {first_d6}, expected 1-3"
# Second roll (2d6) should be normal
assert results[1]['num_dice'] == 2
assert results[1]['die_sides'] == 6
assert all(1 <= roll <= 6 for roll in results[1]['rolls'])
assert results[1].num_dice == 2
assert results[1].die_sides == 6
assert all(1 <= roll <= 6 for roll in results[1].rolls)
# Third roll (1d20) should be normal
assert results[2]['num_dice'] == 1
assert results[2]['die_sides'] == 20
assert 1 <= results[2]['rolls'][0] <= 20
assert results[2].num_dice == 1
assert results[2].die_sides == 20
assert 1 <= results[2].rolls[0] <= 20
def test_weighted_scout_dice_pitcher(self, dice_cog):
"""Test that pitcher scout dice always rolls 4-6 for first d6."""
@ -586,18 +586,18 @@ class TestDiceRollCommands:
assert len(results) == 3
# First d6 should ALWAYS be 4-6 for pitcher
first_d6 = results[0]['rolls'][0]
first_d6 = results[0].rolls[0]
assert 4 <= first_d6 <= 6, f"Pitcher first d6 was {first_d6}, expected 4-6"
# Second roll (2d6) should be normal
assert results[1]['num_dice'] == 2
assert results[1]['die_sides'] == 6
assert all(1 <= roll <= 6 for roll in results[1]['rolls'])
assert results[1].num_dice == 2
assert results[1].die_sides == 6
assert all(1 <= roll <= 6 for roll in results[1].rolls)
# Third roll (1d20) should be normal
assert results[2]['num_dice'] == 1
assert results[2]['die_sides'] == 20
assert 1 <= results[2]['rolls'][0] <= 20
assert results[2].num_dice == 1
assert results[2].die_sides == 20
assert 1 <= results[2].rolls[0] <= 20
@pytest.mark.asyncio
async def test_scout_command_batter(self, dice_cog, mock_interaction):