445 lines
18 KiB
Python
445 lines
18 KiB
Python
import pytest
|
|
import asyncio
|
|
import pandas as pd
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
import sys
|
|
|
|
# Stand-in for the pybaseball package so that importing automated_data_fetcher
# does not require pybaseball (or network access) during the unit tests.
mock_pb = Mock()
mock_pb.cache = Mock()
mock_pb.cache.enable = Mock()
mock_pb.batting_stats_bref = Mock()
mock_pb.pitching_stats_bref = Mock()
mock_pb.batting_stats = Mock()
mock_pb.pitching_stats = Mock()
mock_pb.batting_stats_range = Mock()
mock_pb.pitching_stats_range = Mock()
mock_pb.get_splits = Mock()

# Substitute the module's external dependencies only for the duration of the import.
with patch.dict('sys.modules', {
    'pybaseball': mock_pb,
    'creation_helpers': Mock(),
    'exceptions': Mock(),
}):
    import automated_data_fetcher
    from automated_data_fetcher import (
        DataFetcher,
        LiveSeriesDataFetcher,
        fetch_season_data,
        fetch_live_series_data,
    )

# On exit, patch.dict restores sys.modules to its previous contents, which also
# evicts the automated_data_fetcher module imported above. Re-register it so
# string targets such as patch('automated_data_fetcher.pb.batting_stats') resolve
# to this same module object (whose `pb` is mock_pb) instead of triggering a
# fresh import against the real pybaseball and patching a different module.
sys.modules['automated_data_fetcher'] = automated_data_fetcher
|
|
|
|
|
|
class TestDataFetcher:
    """Test cases for the DataFetcher class.

    Async tests carry an explicit ``@pytest.mark.asyncio``: pytest-asyncio's
    default (strict) mode skips unmarked coroutine tests, and the marker is
    harmless when the project runs in auto mode.
    """

    @pytest.fixture
    def fetcher(self):
        """Yield a 2023 "Season" DataFetcher whose output goes to a temp directory.

        ``output_dir`` points at a not-yet-created ``test_output`` folder inside
        a TemporaryDirectory that is removed when the test finishes.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            fetcher = DataFetcher(2023, "Season")
            # Redirect output away from the default location into the temp dir.
            fetcher.output_dir = Path(tmp_dir) / "test_output"
            yield fetcher

    @pytest.fixture
    def sample_batting_data(self):
        """Three-player batting table including running columns and string IDfg ids."""
        return pd.DataFrame({
            'Name': ['Player A', 'Player B', 'Player C'],
            'Team': ['NYY', 'LAD', 'BOS'],
            'G': [162, 140, 120],
            'PA': [650, 580, 450],
            'H': [180, 160, 120],
            'HR': [30, 25, 15],
            'RBI': [100, 85, 65],
            'SB': [20, 5, 8],
            'CS': [5, 2, 3],
            'SB%': [0.8, 0.714, 0.727],
            'GDP': [15, 12, 8],
            'R': [95, 80, 55],
            'BB': [65, 55, 40],
            'SO': [150, 120, 90],
            'IDfg': ['12345', '67890', '11111'],
        })

    @pytest.fixture
    def sample_pitching_data(self):
        """Two-pitcher pitching table."""
        return pd.DataFrame({
            'Name': ['Pitcher A', 'Pitcher B'],
            'Team': ['NYY', 'LAD'],
            'W': [15, 12],
            'L': [8, 10],
            'ERA': [3.25, 4.15],
            'G': [32, 30],
            'GS': [32, 30],
            'IP': [200.1, 180.2],
            'H': [180, 190],
            'HR': [25, 30],
            'BB': [60, 70],
            'SO': [220, 180],
        })

    @pytest.fixture
    def sample_splits_data(self):
        """Splits table with two handedness rows and two home/away rows."""
        return pd.DataFrame({
            'Split': ['vs LHP', 'vs RHP', 'Home', 'Away'],
            'G': [80, 82, 81, 81],
            'PA': [320, 330, 325, 325],
            'H': [85, 95, 90, 90],
            'AVG': [.280, .295, .285, .285],
            'OBP': [.350, .365, .360, .355],
            'SLG': [.450, .480, .465, .465],
        })

    def test_init(self, fetcher):
        """Test DataFetcher initialization."""
        assert fetcher.season == 2023
        assert fetcher.cardset_type == "Season"
        assert fetcher.cache_enabled == True
        # Note: fetcher.output_dir is overridden in the fixture to use temp directory

    def test_ensure_output_dir(self, fetcher):
        """Test output directory creation."""
        assert not fetcher.output_dir.exists()
        fetcher.ensure_output_dir()
        assert fetcher.output_dir.exists()

    def test_get_csv_filename(self, fetcher):
        """Test CSV filename mapping, including the fallback for unknown types."""
        assert fetcher._get_csv_filename('pitching') == 'pitching.csv'
        assert fetcher._get_csv_filename('running') == 'running.csv'
        assert fetcher._get_csv_filename('batting_basic') == 'batter-stats.csv'
        assert fetcher._get_csv_filename('pitching_basic') == 'pitcher-stats.csv'
        assert fetcher._get_csv_filename('unknown_type') == 'unknown_type.csv'

    def test_transform_for_card_creation_batting_splits(self, fetcher, sample_splits_data):
        """Test batting splits transformation."""
        result = fetcher._transform_for_card_creation(sample_splits_data, 'batting_splits')

        # Should filter to only handedness splits
        expected_splits = ['vs LHP', 'vs RHP']
        assert all(split in expected_splits for split in result['Split'].values)
        assert len(result) == 2

    def test_transform_for_card_creation_running(self, fetcher, sample_batting_data):
        """Test running stats transformation."""
        result = fetcher._transform_for_card_creation(sample_batting_data, 'running')

        # Should include only running-related columns. all() over an empty
        # column set is vacuously true, so also require at least one column.
        expected_cols = ['Name', 'SB', 'CS', 'SB%', 'GDP']
        assert len(result.columns) > 0
        assert all(col in expected_cols for col in result.columns)

    def test_save_data_to_csv(self, fetcher, sample_batting_data):
        """Test saving data to CSV."""
        fetcher.ensure_output_dir()

        data = {'batting_basic': sample_batting_data}
        fetcher.save_data_to_csv(data)

        # Check file was created
        expected_file = fetcher.output_dir / 'batter-stats.csv'
        assert expected_file.exists()

        # Verify content
        saved_data = pd.read_csv(expected_file)
        assert len(saved_data) == len(sample_batting_data)
        assert 'Name' in saved_data.columns

    def test_save_data_to_csv_empty_dataframe(self, fetcher):
        """Test saving empty dataframe."""
        fetcher.ensure_output_dir()

        empty_data = {'empty_set': pd.DataFrame()}
        fetcher.save_data_to_csv(empty_data)

        # Should not create file for empty data
        expected_file = fetcher.output_dir / 'empty_set.csv'
        assert not expected_file.exists()

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.batting_stats_bref')
    @patch('automated_data_fetcher.pb.pitching_stats_bref')
    async def test_fetch_baseball_reference_data(self, mock_pitching, mock_batting, fetcher,
                                                 sample_batting_data, sample_pitching_data):
        """Test fetching Baseball Reference data."""
        # Mock pybaseball functions
        mock_batting.return_value = sample_batting_data
        mock_pitching.return_value = sample_pitching_data

        # Mock player ID and splits functions
        with patch.object(fetcher, '_get_active_players', return_value=['12345', '67890']):
            with patch.object(fetcher, '_fetch_player_splits', return_value={
                'batting': pd.DataFrame(), 'pitching': pd.DataFrame()
            }):
                result = await fetcher.fetch_baseball_reference_data()

        # Verify data structure
        assert 'pitching' in result
        assert 'running' in result
        assert 'batting_splits' in result
        assert 'pitching_splits' in result

        # Verify data content
        assert len(result['pitching']) == 2
        assert len(result['running']) == 3

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.batting_stats')
    @patch('automated_data_fetcher.pb.pitching_stats')
    async def test_fetch_fangraphs_data(self, mock_pitching, mock_batting, fetcher,
                                        sample_batting_data, sample_pitching_data):
        """Test fetching FanGraphs data."""
        # Mock pybaseball functions
        mock_batting.return_value = sample_batting_data
        mock_pitching.return_value = sample_pitching_data

        result = await fetcher.fetch_fangraphs_data()

        # Verify data structure
        assert 'batting_basic' in result
        assert 'pitching_basic' in result

        # Verify function calls
        mock_batting.assert_called_once_with(2023, 2023)
        mock_pitching.assert_called_once_with(2023, 2023)

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.batting_stats_range')
    @patch('automated_data_fetcher.pb.pitching_stats_range')
    async def test_fetch_fangraphs_data_with_dates(self, mock_pitching, mock_batting, fetcher,
                                                   sample_batting_data, sample_pitching_data):
        """Test fetching FanGraphs data with date range."""
        # Mock pybaseball functions
        mock_batting.return_value = sample_batting_data
        mock_pitching.return_value = sample_pitching_data

        start_date = "2023-03-01"
        end_date = "2023-09-01"
        result = await fetcher.fetch_fangraphs_data(start_date, end_date)

        # Same result structure as the season-wide fetch
        assert 'batting_basic' in result
        assert 'pitching_basic' in result

        # Verify function calls with date parameters
        mock_batting.assert_called_once_with(start_date, end_date)
        mock_pitching.assert_called_once_with(start_date, end_date)

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.get_all_pybaseball_ids')
    async def test_get_active_players_existing_function(self, mock_get_ids, fetcher):
        """Test getting player IDs using existing function."""
        mock_get_ids.return_value = ['12345', '67890', '11111']

        result = await fetcher._get_active_players()

        assert result == ['12345', '67890', '11111']
        mock_get_ids.assert_called_once_with(2023)

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.get_all_pybaseball_ids')
    @patch('automated_data_fetcher.pb.batting_stats')
    async def test_get_active_players_fallback(self, mock_batting, mock_get_ids, fetcher,
                                               sample_batting_data):
        """Test getting player IDs with fallback to FanGraphs."""
        # Mock existing function to fail
        mock_get_ids.side_effect = Exception("Function not available")
        mock_batting.return_value = sample_batting_data

        result = await fetcher._get_active_players()

        # Should fallback to FanGraphs data
        expected_ids = ['12345', '67890', '11111']
        assert result == expected_ids

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.get_splits')
    async def test_fetch_player_splits(self, mock_get_splits, fetcher, sample_splits_data):
        """Test fetching player splits."""
        # Mock get_splits to return sample data
        mock_get_splits.return_value = sample_splits_data

        player_ids = ['12345', '67890']
        result = await fetcher._fetch_player_splits(player_ids)

        # Verify structure
        assert 'batting' in result
        assert 'pitching' in result

        # Verify splits were called for each player
        assert mock_get_splits.call_count == 4  # 2 players * 2 split types
|
|
|
|
|
|
class TestLiveSeriesDataFetcher:
    """Test cases for the LiveSeriesDataFetcher class."""

    @pytest.fixture
    def live_fetcher(self):
        """Yield a 2023 LiveSeriesDataFetcher at 81 games whose output goes to a temp dir."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            fetcher = LiveSeriesDataFetcher(2023, 81)  # Half season
            fetcher.output_dir = Path(tmp_dir) / "test_output"
            yield fetcher

    def test_init(self, live_fetcher):
        """Test LiveSeriesDataFetcher initialization."""
        assert live_fetcher.season == 2023
        assert live_fetcher.cardset_type == "Live"
        assert live_fetcher.games_played == 81
        assert live_fetcher.start_date == "2023-03-01"

    def test_calculate_end_date(self, live_fetcher):
        """Test end date calculation."""
        # 81 games should be roughly half season (90 days)
        end_date = live_fetcher._calculate_end_date(81)

        # Should be a valid date string
        assert len(end_date) == 10  # YYYY-MM-DD format
        assert end_date.startswith("2023")

        # Should be after start date (ISO dates compare correctly as strings)
        assert end_date > "2023-03-01"

        # Test full season
        full_season_end = live_fetcher._calculate_end_date(162)
        assert full_season_end > end_date

    # Explicit asyncio mark: strict-mode pytest-asyncio skips unmarked coroutine tests.
    @pytest.mark.asyncio
    @patch.object(DataFetcher, 'fetch_baseball_reference_data')
    @patch.object(DataFetcher, 'fetch_fangraphs_data')
    async def test_fetch_live_data(self, mock_fg_data, mock_bref_data, live_fetcher):
        """Test fetching live series data."""
        # Mock return values
        mock_bref_data.return_value = {'pitching': pd.DataFrame(), 'running': pd.DataFrame()}
        mock_fg_data.return_value = {'batting_basic': pd.DataFrame()}

        result = await live_fetcher.fetch_live_data()

        # Verify both data sources were called
        mock_bref_data.assert_called_once()
        mock_fg_data.assert_called_once_with(live_fetcher.start_date, live_fetcher.end_date)

        # Verify combined result
        assert 'pitching' in result
        assert 'running' in result
        assert 'batting_basic' in result
|
|
|
|
|
|
class TestUtilityFunctions:
    """Test cases for the module-level fetch_* convenience functions."""

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.DataFetcher')
    async def test_fetch_season_data(self, mock_fetcher_class):
        """Test fetch_season_data function."""
        # Create mock fetcher instance
        mock_fetcher = Mock()
        mock_fetcher.fetch_baseball_reference_data = AsyncMock(return_value={'pitching': pd.DataFrame()})
        mock_fetcher.fetch_fangraphs_data = AsyncMock(return_value={'batting_basic': pd.DataFrame()})
        mock_fetcher.save_data_to_csv = Mock()
        mock_fetcher.output_dir = Path("test/output")
        mock_fetcher_class.return_value = mock_fetcher

        # Capture print output
        with patch('builtins.print') as mock_print:
            await fetch_season_data(2023)

        # Verify fetcher was created and methods called
        mock_fetcher_class.assert_called_once_with(2023, "Season")
        mock_fetcher.fetch_baseball_reference_data.assert_called_once()
        mock_fetcher.fetch_fangraphs_data.assert_called_once()
        mock_fetcher.save_data_to_csv.assert_called_once()

        # Verify print output includes completion message. Join all positional
        # args as text so bare print() calls (no args) and non-string arguments
        # do not raise IndexError/TypeError.
        printed = [
            " ".join(str(arg) for arg in recorded.args)
            for recorded in mock_print.call_args_list
        ]
        assert any("AUTOMATED DOWNLOAD COMPLETE" in line for line in printed)

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.LiveSeriesDataFetcher')
    async def test_fetch_live_series_data(self, mock_fetcher_class):
        """Test fetch_live_series_data function."""
        # Create mock fetcher instance
        mock_fetcher = Mock()
        mock_fetcher.fetch_live_data = AsyncMock(return_value={'live_data': pd.DataFrame()})
        mock_fetcher.save_data_to_csv = Mock()
        mock_fetcher_class.return_value = mock_fetcher

        await fetch_live_series_data(2023, 81)

        # Verify fetcher was created and methods called
        mock_fetcher_class.assert_called_once_with(2023, 81)
        mock_fetcher.fetch_live_data.assert_called_once()
        mock_fetcher.save_data_to_csv.assert_called_once()
|
|
|
|
|
|
class TestErrorHandling:
    """Test error handling scenarios (all async, hence the explicit asyncio marks)."""

    @pytest.fixture
    def fetcher(self):
        """Create a DataFetcher instance for error testing."""
        return DataFetcher(2023, "Season")

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.pitching_stats_bref')
    async def test_fetch_baseball_reference_data_error(self, mock_pitching, fetcher):
        """Test error handling in Baseball Reference data fetch."""
        # Mock function to raise an exception
        mock_pitching.side_effect = Exception("Network error")

        with pytest.raises(Exception, match="Error fetching Baseball Reference data"):
            await fetcher.fetch_baseball_reference_data()

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.batting_stats')
    async def test_fetch_fangraphs_data_error(self, mock_batting, fetcher):
        """Test error handling in FanGraphs data fetch."""
        # Mock function to raise an exception
        mock_batting.side_effect = Exception("API error")

        with pytest.raises(Exception, match="Error fetching FanGraphs data"):
            await fetcher.fetch_fangraphs_data()

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.get_all_pybaseball_ids')
    @patch('automated_data_fetcher.pb.batting_stats')
    async def test_get_active_players_complete_failure(self, mock_batting, mock_get_ids, fetcher):
        """Test complete failure in getting player IDs."""
        # Mock both functions to fail
        mock_get_ids.side_effect = Exception("Function error")
        mock_batting.side_effect = Exception("API error")

        result = await fetcher._get_active_players()

        # Should return empty list when all methods fail
        assert result == []

    @pytest.mark.asyncio
    @patch('automated_data_fetcher.pb.get_splits')
    async def test_fetch_player_splits_individual_errors(self, mock_get_splits, fetcher):
        """Test handling individual player split fetch errors."""
        # Mock get_splits to fail for one specific player id only
        def side_effect(player_id, **kwargs):
            if player_id == 'bad_player':
                raise Exception("Player not found")
            return pd.DataFrame({'Split': ['vs LHP'], 'AVG': [.250]})

        mock_get_splits.side_effect = side_effect

        player_ids = ['good_player', 'bad_player', 'another_good_player']
        result = await fetcher._fetch_player_splits(player_ids)

        # Should handle errors gracefully and return data for successful players
        assert 'batting' in result
        assert 'pitching' in result

        # Should have been called for all players despite errors
        assert mock_get_splits.call_count == 6  # 3 players * 2 split types
|
|
|
|
|
|
# Integration test markers
@pytest.mark.integration
class TestIntegration:
    """Integration tests that require network access."""

    # The asyncio mark ensures the coroutine actually runs once the skip is removed.
    @pytest.mark.skip(reason="Requires network access and may be slow")
    @pytest.mark.asyncio
    async def test_real_data_fetch(self):
        """Test fetching real data from pybaseball (skipped by default)."""
        fetcher = DataFetcher(2022, "Season")  # Use a complete season

        # This would actually call pybaseball APIs
        # Only run when specifically testing integration
        try:
            fg_data = await fetcher.fetch_fangraphs_data()
            assert 'batting_basic' in fg_data
            assert 'pitching_basic' in fg_data
        except Exception as e:
            pytest.skip(f"Network error during integration test: {e}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this file's tests and propagate pytest's exit status to the shell,
    # so a failing run is visible to scripts/CI that execute the file directly.
    sys.exit(pytest.main([__file__, '-v']))