# Formatting standardized with black; ruff auto-fixes applied.
"""
|
|
Retrosheet CSV Format Transformer
|
|
|
|
This module transforms newer Retrosheet CSV formats into the legacy format
|
|
expected by retrosheet_data.py. Includes smart caching to avoid redundant
|
|
transformations.
|
|
|
|
Author: Claude Code
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
|
|
# Module-level logger, named after this module per the logging convention.
# (The original wrapped __name__ in a no-op f-string; plain __name__ is
# equivalent and idiomatic.)
logger = logging.getLogger(__name__)

def get_normalized_csv_path(source_path: str) -> str:
    """
    Generate the cached/normalized CSV path from source path.

    The cache file lives next to the source file, with ``_normalized``
    inserted before the extension (e.g. ``events.csv`` ->
    ``events_normalized.csv``).

    Args:
        source_path: Path to the source CSV file

    Returns:
        Path to the normalized cache file
    """
    src = Path(source_path)
    normalized_name = "".join((src.stem, "_normalized", src.suffix))
    return str(src.parent / normalized_name)

def needs_transformation(source_path: str, cache_path: str) -> bool:
    """
    Check if transformation is needed based on file modification times.

    A missing cache file or a source file modified after the cache was
    written both invalidate the cache.

    Args:
        source_path: Path to source CSV
        cache_path: Path to cached normalized CSV

    Returns:
        True if transformation needed, False if cache is valid
    """
    if not os.path.exists(cache_path):
        # Lazy %-style args: the logging module formats only if the
        # record is actually emitted (preferred over f-strings here).
        logger.info("Cache file not found: %s", cache_path)
        return True

    source_mtime = os.path.getmtime(source_path)
    cache_mtime = os.path.getmtime(cache_path)

    # A source newer than the cache means the cache is stale.
    if source_mtime > cache_mtime:
        logger.info("Source file is newer than cache, transformation needed")
        return True

    logger.info("Using cached normalized file: %s", cache_path)
    return False

def transform_event_type(row: pd.Series) -> str:
    """
    Derive event_type from boolean columns in new format.

    Priority order matches baseball scoring conventions.
    """
    # Ordered (column, label) pairs; the first flag set to 1 wins.
    priority = (
        ("hr", "home run"),
        ("triple", "triple"),
        ("double", "double"),
        ("single", "single"),
        ("walk", "walk"),
        ("iw", "walk"),  # intentional walk counts as a walk
        ("k", "strikeout"),
        ("hbp", "hit by pitch"),
    )
    for column, label in priority:
        if row[column] == 1:
            return label
    # No flag set: treat the plate appearance as a generic out.
    return "generic out"

def transform_batted_ball_type(row: pd.Series) -> str:
    """
    Derive batted_ball_type from boolean columns.

    Returns 'f' (fly), 'G' (ground), 'l' (line), or empty string.
    """
    # First matching trajectory flag wins; codes mirror the legacy format
    # (note the mixed case: lowercase fly/line, uppercase ground).
    for column, code in (("fly", "f"), ("ground", "G"), ("line", "l")):
        if row[column] == 1:
            return code
    return ""

def transform_hit_val(row: pd.Series) -> str:
    """
    Derive hit_val from hit type columns.

    Returns '1', '2', '3', '4' for singles through home runs.
    """
    # Check the most valuable hit first so overlapping flags resolve
    # the same way the legacy if/elif chain did.
    for column, value in (("hr", "4"), ("triple", "3"), ("double", "2"), ("single", "1")):
        if row[column] == 1:
            return value
    # Not a hit: empty string, matching the legacy format.
    return ""

def bool_to_tf(val) -> str:
    """Convert 1/0 or True/False to 't'/'f' strings."""
    # NaN/None (missing data) maps to 'f'; anything equal to 1 (which
    # includes True) maps to 't'; everything else maps to 'f'.
    if not pd.isna(val) and (val == 1 or val is True):
        return "t"
    return "f"

def transform_retrosheet_csv(source_path: str) -> pd.DataFrame:
    """
    Transform new Retrosheet CSV format to legacy format.

    Builds a fresh DataFrame whose columns use the legacy names expected
    by retrosheet_data.py: some are direct renames, some are derived from
    the new format's boolean flag columns, and some are placeholders left
    empty because nothing downstream currently needs them.

    Args:
        source_path: Path to source CSV file

    Returns:
        Transformed DataFrame in legacy format
    """
    logger.info(f"Reading source CSV: {source_path}")
    df = pd.read_csv(source_path, low_memory=False)

    logger.info(f"Transforming {len(df)} rows to legacy format")

    # Create new dataframe with legacy column names
    transformed = pd.DataFrame()

    # Simple renames (with case conversion for handedness)
    transformed["game_id"] = df["gid"]
    transformed["batter_id"] = df["batter"]
    transformed["pitcher_id"] = df["pitcher"]
    transformed["batter_hand"] = df["bathand"].str.lower()  # Convert R/L/B to r/l/b
    transformed["pitcher_hand"] = df["pithand"].str.lower()  # Convert R/L to r/l
    transformed["hit_location"] = df["loc"].astype(
        str
    )  # Ensure string type for .str operations

    # Derive event_type from multiple boolean columns (hr/triple/double/
    # single/walk/iw/k/hbp), falling back to "generic out".
    # NOTE(review): df.apply(axis=1) is row-wise Python and slow on large
    # files; vectorizing with np.select would be faster — verify before
    # changing, since priority order must be preserved.
    logger.info("Deriving event_type from hit/walk/strikeout columns")
    transformed["event_type"] = df.apply(transform_event_type, axis=1)

    # Derive batted_ball_type ('f'/'G'/'l' or '') from fly/ground/line flags
    logger.info("Deriving batted_ball_type from fly/ground/line columns")
    transformed["batted_ball_type"] = df.apply(
        transform_batted_ball_type, axis=1
    ).astype(str)

    # Derive hit_val ('1'..'4' or '') from the hit type flags
    logger.info("Deriving hit_val from hit type columns")
    transformed["hit_val"] = df.apply(transform_hit_val, axis=1).astype(str)

    # Boolean conversions to 't'/'f' format (legacy flag representation)
    logger.info("Converting boolean columns to 't'/'f' format")
    transformed["batter_event"] = df["pa"].apply(bool_to_tf)
    transformed["ab"] = df["ab"].apply(bool_to_tf)
    transformed["bunt"] = df["bunt"].apply(bool_to_tf)
    transformed["tp"] = df["tp"].apply(bool_to_tf)

    # Combine gdp + othdp for double play indicator: either a grounded-into
    # or an "other" double play sets dp to 't'.
    transformed["dp"] = (df["gdp"].fillna(0) + df["othdp"].fillna(0)).apply(
        lambda x: "t" if x > 0 else "f"
    )

    # Convert batter handedness to actual batting side for each PA
    # Switch hitters (B) bat left vs RHP and right vs LHP
    def get_result_batter_hand(row: pd.Series) -> str:
        # Normalize case before comparing; output is lowercase 'l'/'r'.
        bathand = row["bathand"].upper()
        pithand = row["pithand"].upper()

        if bathand == "B":  # Switch hitter
            # Switch hitters bat from opposite side of pitcher
            return "l" if pithand == "R" else "r"
        else:
            # Regular batters always bat from same side
            return bathand.lower()

    logger.info("Converting switch hitter handedness based on pitcher matchups")
    transformed["result_batter_hand"] = df.apply(get_result_batter_hand, axis=1)

    # Add placeholder columns that may be referenced but aren't critical for stats
    # These can be populated if needed in the future
    transformed["event_id"] = range(1, len(df) + 1)  # 1-based sequential id
    transformed["batting_team"] = ""
    transformed["inning"] = df["inning"] if "inning" in df.columns else ""
    transformed["outs"] = ""
    transformed["balls"] = ""
    transformed["strikes"] = ""
    transformed["pitch_seq"] = ""
    transformed["vis_score"] = ""
    transformed["home_score"] = ""
    transformed["result_batter_id"] = df["batter"]
    transformed["result_pitcher_id"] = df["pitcher"]
    # NOTE(review): result_pitcher_hand keeps the source's original case,
    # unlike pitcher_hand above which is lowercased — confirm whether
    # downstream expects lowercase here too.
    transformed["result_pitcher_hand"] = df["pithand"]
    transformed["def_c"] = ""
    transformed["def_1b"] = ""
    transformed["def_2b"] = ""
    transformed["def_3b"] = ""
    transformed["def_ss"] = ""
    transformed["def_lf"] = ""
    transformed["def_cf"] = ""
    transformed["def_rf"] = ""
    transformed["run_1b"] = ""
    transformed["run_2b"] = ""
    transformed["run_3b"] = ""
    transformed["event_scoring"] = ""
    transformed["leadoff"] = ""
    transformed["pinch_hit"] = ""
    transformed["batt_def_pos"] = ""
    transformed["batt_lineup_pos"] = ""
    # Sac hit/fly columns are optional in the source; default to 'f'.
    transformed["sac_hit"] = df["sh"].apply(bool_to_tf) if "sh" in df.columns else "f"
    transformed["sac_fly"] = df["sf"].apply(bool_to_tf) if "sf" in df.columns else "f"
    transformed["event_outs"] = ""
    transformed["rbi"] = ""
    # wp/pb are likewise optional; default to 'f' when absent.
    transformed["wild_pitch"] = (
        df["wp"].apply(bool_to_tf) if "wp" in df.columns else "f"
    )
    transformed["passed_ball"] = (
        df["pb"].apply(bool_to_tf) if "pb" in df.columns else "f"
    )
    transformed["fielded_by"] = ""
    transformed["foul_ground"] = ""

    logger.info(f"Transformation complete: {len(transformed)} rows")
    return transformed

def load_retrosheet_csv(
    source_path: str, force_transform: bool = False
) -> pd.DataFrame:
    """
    Load Retrosheet CSV, using cached normalized version if available.

    This is the main entry point for loading Retrosheet data. It handles:
    - Checking for cached normalized data
    - Transforming if needed
    - Saving transformed data for future use

    Args:
        source_path: Path to source Retrosheet CSV
        force_transform: If True, ignore cache and force transformation

    Returns:
        DataFrame in legacy format ready for retrosheet_data.py

    Raises:
        FileNotFoundError: If source_path does not exist.
    """
    logger.info(f"Loading Retrosheet CSV: {source_path}")

    if not os.path.exists(source_path):
        raise FileNotFoundError(f"Source file not found: {source_path}")

    cache_path = get_normalized_csv_path(source_path)

    # Fast path: a valid cache exists and the caller did not force a rebuild.
    if not force_transform and not needs_transformation(source_path, cache_path):
        logger.info(f"Loading from cache: {cache_path}")
        # Explicitly set dtypes for string columns to ensure .str accessor works
        string_columns = ("game_id", "hit_val", "hit_location", "batted_ball_type")
        return pd.read_csv(
            cache_path,
            dtype={column: "str" for column in string_columns},
            low_memory=False,
        )

    # Slow path: transform the source and persist the result for next time.
    df = transform_retrosheet_csv(source_path)

    logger.info(f"Saving normalized data to cache: {cache_path}")
    df.to_csv(cache_path, index=False)
    logger.info("Cache saved successfully")

    return df

if __name__ == "__main__":
|
|
# Test the transformer
|
|
import sys
|
|
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
)
|
|
|
|
if len(sys.argv) > 1:
|
|
test_file = sys.argv[1]
|
|
else:
|
|
test_file = "data-input/retrosheet/retrosheets_events_2005.csv"
|
|
|
|
print(f"\n{'='*60}")
|
|
print("Testing Retrosheet Transformer")
|
|
print(f"{'='*60}\n")
|
|
|
|
df = load_retrosheet_csv(test_file)
|
|
|
|
print("\nTransformed DataFrame Info:")
|
|
print(f"Shape: {df.shape}")
|
|
print(f"\nColumns: {list(df.columns)}")
|
|
print("\nSample rows:")
|
|
print(df.head(3))
|
|
|
|
print("\nEvent type distribution:")
|
|
print(df["event_type"].value_counts())
|
|
|
|
print("\nBatted ball type distribution:")
|
|
print(df["batted_ball_type"].value_counts())
|