fix: remove commented-out dead code blocks (#31)
All checks were successful
Build Docker Image / build (pull_request) Successful in 2m12s

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-05 13:34:43 -06:00
parent ddf5f77da4
commit 06794c27a1
5 changed files with 1249 additions and 849 deletions

File diff suppressed because it is too large Load Diff

View File

@ -11,22 +11,14 @@ from fastapi import HTTPException, Response
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from redis import Redis from redis import Redis
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}"
logger = logging.getLogger('discord_app') logger = logging.getLogger("discord_app")
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
# log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
# logging.basicConfig(
# filename=f'logs/database/{date}.log',
# format='%(asctime)s - sba-database - %(levelname)s - %(message)s',
# level=log_level
# )
# Redis configuration # Redis configuration
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379')) REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
REDIS_DB = int(os.environ.get('REDIS_DB', '0')) REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true' CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true"
# Initialize Redis client with connection error handling # Initialize Redis client with connection error handling
if not CACHE_ENABLED: if not CACHE_ENABLED:
@ -40,7 +32,7 @@ else:
db=REDIS_DB, db=REDIS_DB,
decode_responses=True, decode_responses=True,
socket_connect_timeout=5, socket_connect_timeout=5,
socket_timeout=5 socket_timeout=5,
) )
# Test connection # Test connection
redis_client.ping() redis_client.ping()
@ -50,12 +42,16 @@ else:
redis_client = None redis_client = None
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper() priv_help = (
PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False False
if not os.environ.get("PRIVATE_IN_SCHEMA")
else os.environ.get("PRIVATE_IN_SCHEMA").upper()
)
PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False
def valid_token(token): def valid_token(token):
return token == os.environ.get('API_TOKEN') return token == os.environ.get("API_TOKEN")
def update_season_batting_stats(player_ids, season, db_connection): def update_season_batting_stats(player_ids, season, db_connection):
@ -63,17 +59,19 @@ def update_season_batting_stats(player_ids, season, db_connection):
Update season batting stats for specific players in a given season. Update season batting stats for specific players in a given season.
Recalculates stats from stratplay data and upserts into seasonbattingstats table. Recalculates stats from stratplay data and upserts into seasonbattingstats table.
""" """
if not player_ids: if not player_ids:
logger.warning("update_season_batting_stats called with empty player_ids list") logger.warning("update_season_batting_stats called with empty player_ids list")
return return
# Convert single player_id to list for consistency # Convert single player_id to list for consistency
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season batting stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert batting stats # SQL query to recalculate and upsert batting stats
query = """ query = """
@ -217,12 +215,14 @@ def update_season_batting_stats(player_ids, season, db_connection):
sb = EXCLUDED.sb, sb = EXCLUDED.sb,
cs = EXCLUDED.cs; cs = EXCLUDED.cs;
""" """
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season batting stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season batting stats: {e}") logger.error(f"Error updating season batting stats: {e}")
raise raise
@ -233,17 +233,19 @@ def update_season_pitching_stats(player_ids, season, db_connection):
Update season pitching stats for specific players in a given season. Update season pitching stats for specific players in a given season.
Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table. Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table.
""" """
if not player_ids: if not player_ids:
logger.warning("update_season_pitching_stats called with empty player_ids list") logger.warning("update_season_pitching_stats called with empty player_ids list")
return return
# Convert single player_id to list for consistency # Convert single player_id to list for consistency
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season pitching stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert pitching stats # SQL query to recalculate and upsert pitching stats
query = """ query = """
@ -460,12 +462,14 @@ def update_season_pitching_stats(player_ids, season, db_connection):
rbipercent = EXCLUDED.rbipercent, rbipercent = EXCLUDED.rbipercent,
re24 = EXCLUDED.re24; re24 = EXCLUDED.re24;
""" """
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season pitching stats: {e}") logger.error(f"Error updating season pitching stats: {e}")
raise raise
@ -474,26 +478,24 @@ def update_season_pitching_stats(player_ids, season, db_connection):
def send_webhook_message(message: str) -> bool: def send_webhook_message(message: str) -> bool:
""" """
Send a message to Discord via webhook. Send a message to Discord via webhook.
Args: Args:
message: The message content to send message: The message content to send
Returns: Returns:
bool: True if successful, False otherwise bool: True if successful, False otherwise
""" """
webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW" webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW"
try: try:
payload = { payload = {"content": message}
"content": message
}
response = requests.post(webhook_url, json=payload, timeout=10) response = requests.post(webhook_url, json=payload, timeout=10)
response.raise_for_status() response.raise_for_status()
logger.info(f"Webhook message sent successfully: {message[:100]}...") logger.info(f"Webhook message sent successfully: {message[:100]}...")
return True return True
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error(f"Failed to send webhook message: {e}") logger.error(f"Failed to send webhook message: {e}")
return False return False
@ -502,99 +504,106 @@ def send_webhook_message(message: str) -> bool:
return False return False
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True): def cache_result(
ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True
):
""" """
Decorator to cache function results in Redis with parameter normalization. Decorator to cache function results in Redis with parameter normalization.
Args: Args:
ttl: Time to live in seconds (default: 5 minutes) ttl: Time to live in seconds (default: 5 minutes)
key_prefix: Prefix for cache keys (default: "api") key_prefix: Prefix for cache keys (default: "api")
normalize_params: Remove None/empty values to reduce cache variations (default: True) normalize_params: Remove None/empty values to reduce cache variations (default: True)
Usage: Usage:
@cache_result(ttl=600, key_prefix="stats") @cache_result(ttl=600, key_prefix="stats")
async def get_player_stats(player_id: int, season: Optional[int] = None): async def get_player_stats(player_id: int, season: Optional[int] = None):
# expensive operation # expensive operation
return stats return stats
# These will use the same cache entry when normalize_params=True: # These will use the same cache entry when normalize_params=True:
# get_player_stats(123, None) and get_player_stats(123) # get_player_stats(123, None) and get_player_stats(123)
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
# Skip caching if Redis is not available # Skip caching if Redis is not available
if redis_client is None: if redis_client is None:
return await func(*args, **kwargs) return await func(*args, **kwargs)
try: try:
# Normalize parameters to reduce cache variations # Normalize parameters to reduce cache variations
normalized_kwargs = kwargs.copy() normalized_kwargs = kwargs.copy()
if normalize_params: if normalize_params:
# Remove None values and empty collections # Remove None values and empty collections
normalized_kwargs = { normalized_kwargs = {
k: v for k, v in kwargs.items() k: v
for k, v in kwargs.items()
if v is not None and v != [] and v != "" and v != {} if v is not None and v != [] and v != "" and v != {}
} }
# Generate more readable cache key # Generate more readable cache key
args_str = "_".join(str(arg) for arg in args if arg is not None) args_str = "_".join(str(arg) for arg in args if arg is not None)
kwargs_str = "_".join([ kwargs_str = "_".join(
f"{k}={v}" for k, v in sorted(normalized_kwargs.items()) [f"{k}={v}" for k, v in sorted(normalized_kwargs.items())]
]) )
# Combine args and kwargs for cache key # Combine args and kwargs for cache key
key_parts = [key_prefix, func.__name__] key_parts = [key_prefix, func.__name__]
if args_str: if args_str:
key_parts.append(args_str) key_parts.append(args_str)
if kwargs_str: if kwargs_str:
key_parts.append(kwargs_str) key_parts.append(kwargs_str)
cache_key = ":".join(key_parts) cache_key = ":".join(key_parts)
# Truncate very long cache keys to prevent Redis key size limits # Truncate very long cache keys to prevent Redis key size limits
if len(cache_key) > 200: if len(cache_key) > 200:
cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}" cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}"
# Try to get from cache # Try to get from cache
cached_result = redis_client.get(cache_key) cached_result = redis_client.get(cache_key)
if cached_result is not None: if cached_result is not None:
logger.debug(f"Cache hit for key: {cache_key}") logger.debug(f"Cache hit for key: {cache_key}")
return json.loads(cached_result) return json.loads(cached_result)
# Cache miss - execute function # Cache miss - execute function
logger.debug(f"Cache miss for key: {cache_key}") logger.debug(f"Cache miss for key: {cache_key}")
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# Skip caching for Response objects (like CSV downloads) as they can't be properly serialized # Skip caching for Response objects (like CSV downloads) as they can't be properly serialized
if not isinstance(result, Response): if not isinstance(result, Response):
# Store in cache with TTL # Store in cache with TTL
redis_client.setex( redis_client.setex(
cache_key, cache_key,
ttl, ttl,
json.dumps(result, default=str, ensure_ascii=False) json.dumps(result, default=str, ensure_ascii=False),
) )
else: else:
logger.debug(f"Skipping cache for Response object from {func.__name__}") logger.debug(
f"Skipping cache for Response object from {func.__name__}"
)
return result return result
except Exception as e: except Exception as e:
# If caching fails, log error and continue without caching # If caching fails, log error and continue without caching
logger.error(f"Cache error for {func.__name__}: {e}") logger.error(f"Cache error for {func.__name__}: {e}")
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper return wrapper
return decorator return decorator
def invalidate_cache(pattern: str = "*"): def invalidate_cache(pattern: str = "*"):
""" """
Invalidate cache entries matching a pattern. Invalidate cache entries matching a pattern.
Args: Args:
pattern: Redis pattern to match keys (default: "*" for all) pattern: Redis pattern to match keys (default: "*" for all)
Usage: Usage:
invalidate_cache("stats:*") # Clear all stats cache invalidate_cache("stats:*") # Clear all stats cache
invalidate_cache("api:get_player_*") # Clear specific player cache invalidate_cache("api:get_player_*") # Clear specific player cache
@ -602,12 +611,14 @@ def invalidate_cache(pattern: str = "*"):
if redis_client is None: if redis_client is None:
logger.warning("Cannot invalidate cache: Redis not available") logger.warning("Cannot invalidate cache: Redis not available")
return 0 return 0
try: try:
keys = redis_client.keys(pattern) keys = redis_client.keys(pattern)
if keys: if keys:
deleted = redis_client.delete(*keys) deleted = redis_client.delete(*keys)
logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}") logger.info(
f"Invalidated {deleted} cache entries matching pattern: {pattern}"
)
return deleted return deleted
else: else:
logger.debug(f"No cache entries found matching pattern: {pattern}") logger.debug(f"No cache entries found matching pattern: {pattern}")
@ -620,13 +631,13 @@ def invalidate_cache(pattern: str = "*"):
def get_cache_stats() -> dict: def get_cache_stats() -> dict:
""" """
Get Redis cache statistics. Get Redis cache statistics.
Returns: Returns:
dict: Cache statistics including memory usage, key count, etc. dict: Cache statistics including memory usage, key count, etc.
""" """
if redis_client is None: if redis_client is None:
return {"status": "unavailable", "message": "Redis not connected"} return {"status": "unavailable", "message": "Redis not connected"}
try: try:
info = redis_client.info() info = redis_client.info()
return { return {
@ -634,7 +645,7 @@ def get_cache_stats() -> dict:
"memory_used": info.get("used_memory_human", "unknown"), "memory_used": info.get("used_memory_human", "unknown"),
"total_keys": redis_client.dbsize(), "total_keys": redis_client.dbsize(),
"connected_clients": info.get("connected_clients", 0), "connected_clients": info.get("connected_clients", 0),
"uptime_seconds": info.get("uptime_in_seconds", 0) "uptime_seconds": info.get("uptime_in_seconds", 0),
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting cache stats: {e}") logger.error(f"Error getting cache stats: {e}")
@ -642,34 +653,35 @@ def get_cache_stats() -> dict:
def add_cache_headers( def add_cache_headers(
max_age: int = 300, max_age: int = 300,
cache_type: str = "public", cache_type: str = "public",
vary: Optional[str] = None, vary: Optional[str] = None,
etag: bool = False etag: bool = False,
): ):
""" """
Decorator to add HTTP cache headers to FastAPI responses. Decorator to add HTTP cache headers to FastAPI responses.
Args: Args:
max_age: Cache duration in seconds (default: 5 minutes) max_age: Cache duration in seconds (default: 5 minutes)
cache_type: "public", "private", or "no-cache" (default: "public") cache_type: "public", "private", or "no-cache" (default: "public")
vary: Vary header value (e.g., "Accept-Encoding, Authorization") vary: Vary header value (e.g., "Accept-Encoding, Authorization")
etag: Whether to generate ETag based on response content etag: Whether to generate ETag based on response content
Usage: Usage:
@add_cache_headers(max_age=1800, cache_type="public") @add_cache_headers(max_age=1800, cache_type="public")
async def get_static_data(): async def get_static_data():
return {"data": "static content"} return {"data": "static content"}
@add_cache_headers(max_age=60, cache_type="private", vary="Authorization") @add_cache_headers(max_age=60, cache_type="private", vary="Authorization")
async def get_user_data(): async def get_user_data():
return {"data": "user specific"} return {"data": "user specific"}
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# Handle different response types # Handle different response types
if isinstance(result, Response): if isinstance(result, Response):
response = result response = result
@ -677,38 +689,41 @@ def add_cache_headers(
# Convert to Response with JSON content # Convert to Response with JSON content
response = Response( response = Response(
content=json.dumps(result, default=str, ensure_ascii=False), content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json" media_type="application/json",
) )
else: else:
# Handle other response types # Handle other response types
response = Response(content=str(result)) response = Response(content=str(result))
# Build Cache-Control header # Build Cache-Control header
cache_control_parts = [cache_type] cache_control_parts = [cache_type]
if cache_type != "no-cache" and max_age > 0: if cache_type != "no-cache" and max_age > 0:
cache_control_parts.append(f"max-age={max_age}") cache_control_parts.append(f"max-age={max_age}")
response.headers["Cache-Control"] = ", ".join(cache_control_parts) response.headers["Cache-Control"] = ", ".join(cache_control_parts)
# Add Vary header if specified # Add Vary header if specified
if vary: if vary:
response.headers["Vary"] = vary response.headers["Vary"] = vary
# Add ETag if requested # Add ETag if requested
if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))): if etag and (
hasattr(result, "__dict__") or isinstance(result, (dict, list))
):
content_hash = hashlib.md5( content_hash = hashlib.md5(
json.dumps(result, default=str, sort_keys=True).encode() json.dumps(result, default=str, sort_keys=True).encode()
).hexdigest() ).hexdigest()
response.headers["ETag"] = f'"{content_hash}"' response.headers["ETag"] = f'"{content_hash}"'
# Add Last-Modified header with current time for dynamic content # Add Last-Modified header with current time for dynamic content
response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime( response.headers["Last-Modified"] = datetime.datetime.now(
"%a, %d %b %Y %H:%M:%S GMT" datetime.timezone.utc
) ).strftime("%a, %d %b %Y %H:%M:%S GMT")
return response return response
return wrapper return wrapper
return decorator return decorator
@ -718,52 +733,59 @@ def handle_db_errors(func):
Ensures proper cleanup of database connections and provides consistent error handling. Ensures proper cleanup of database connections and provides consistent error handling.
Includes comprehensive logging with function context, timing, and stack traces. Includes comprehensive logging with function context, timing, and stack traces.
""" """
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
import time import time
import traceback import traceback
from .db_engine import db # Import here to avoid circular imports from .db_engine import db # Import here to avoid circular imports
start_time = time.time() start_time = time.time()
func_name = f"{func.__module__}.{func.__name__}" func_name = f"{func.__module__}.{func.__name__}"
# Sanitize arguments for logging (exclude sensitive data) # Sanitize arguments for logging (exclude sensitive data)
safe_args = [] safe_args = []
safe_kwargs = {} safe_kwargs = {}
try: try:
# Log sanitized arguments (avoid logging tokens, passwords, etc.) # Log sanitized arguments (avoid logging tokens, passwords, etc.)
for arg in args: for arg in args:
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object if hasattr(arg, "__dict__") and hasattr(
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})") arg, "url"
): # FastAPI Request object
safe_args.append(
f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})"
)
else: else:
safe_args.append(str(arg)[:100]) # Truncate long values safe_args.append(str(arg)[:100]) # Truncate long values
for key, value in kwargs.items(): for key, value in kwargs.items():
if key.lower() in ['token', 'password', 'secret', 'key']: if key.lower() in ["token", "password", "secret", "key"]:
safe_kwargs[key] = '[REDACTED]' safe_kwargs[key] = "[REDACTED]"
else: else:
safe_kwargs[key] = str(value)[:100] # Truncate long values safe_kwargs[key] = str(value)[:100] # Truncate long values
logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}") logger.info(
f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}"
)
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s") logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s")
return result return result
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
error_trace = traceback.format_exc() error_trace = traceback.format_exc()
logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s") logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s")
logger.error(f"Function args: {safe_args}") logger.error(f"Function args: {safe_args}")
logger.error(f"Function kwargs: {safe_kwargs}") logger.error(f"Function kwargs: {safe_kwargs}")
logger.error(f"Exception: {str(e)}") logger.error(f"Exception: {str(e)}")
logger.error(f"Full traceback:\n{error_trace}") logger.error(f"Full traceback:\n{error_trace}")
try: try:
logger.info(f"Attempting database rollback for {func_name}") logger.info(f"Attempting database rollback for {func_name}")
db.rollback() db.rollback()
@ -775,8 +797,12 @@ def handle_db_errors(func):
db.close() db.close()
logger.info(f"Database connection closed for {func_name}") logger.info(f"Database connection closed for {func_name}")
except Exception as close_error: except Exception as close_error:
logger.error(f"Error closing database connection in {func_name}: {close_error}") logger.error(
f"Error closing database connection in {func_name}: {close_error}"
raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}') )
raise HTTPException(
status_code=500, detail=f"Database error in {func_name}: {str(e)}"
)
return wrapper return wrapper

View File

@ -7,41 +7,58 @@ from fastapi import Depends, FastAPI, Request
from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
# from fastapi.openapi.docs import get_swagger_ui_html from .routers_v3 import (
# from fastapi.openapi.utils import get_openapi current,
players,
results,
schedules,
standings,
teams,
transactions,
battingstats,
pitchingstats,
fieldingstats,
draftpicks,
draftlist,
managers,
awards,
draftdata,
keepers,
stratgame,
stratplay,
injuries,
decisions,
divisions,
sbaplayers,
custom_commands,
help_commands,
views,
)
from .routers_v3 import current, players, results, schedules, standings, teams, transactions, battingstats, pitchingstats, fieldingstats, draftpicks, draftlist, managers, awards, draftdata, keepers, stratgame, stratplay, injuries, decisions, divisions, sbaplayers, custom_commands, help_commands, views log_level = logging.INFO if os.environ.get("LOG_LEVEL") == "INFO" else logging.WARNING
logger = logging.getLogger("discord_app")
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
log_level = logging.INFO if os.environ.get('LOG_LEVEL') == 'INFO' else logging.WARNING
# logging.basicConfig(
# filename=f'logs/database/{date}.log',
# format='%(asctime)s - sba-database - %(levelname)s - %(message)s',
# level=log_level
# )
logger = logging.getLogger('discord_app')
logger.setLevel(log_level) logger.setLevel(log_level)
handler = RotatingFileHandler( handler = RotatingFileHandler(
filename='./logs/sba-database.log', filename="./logs/sba-database.log",
# encoding='utf-8', # encoding='utf-8',
maxBytes=8 * 1024 * 1024, # 8 MiB maxBytes=8 * 1024 * 1024, # 8 MiB
backupCount=5, # Rotate through 5 files backupCount=5, # Rotate through 5 files
) )
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
app = FastAPI( app = FastAPI(
# root_path='/api', # root_path='/api',
responses={404: {'description': 'Not found'}}, responses={404: {"description": "Not found"}},
docs_url='/api/docs', docs_url="/api/docs",
redoc_url='/api/redoc' redoc_url="/api/redoc",
) )
logger.info(f'Starting up now...') logger.info(f"Starting up now...")
app.include_router(current.router) app.include_router(current.router)
@ -70,20 +87,17 @@ app.include_router(custom_commands.router)
app.include_router(help_commands.router) app.include_router(help_commands.router)
app.include_router(views.router) app.include_router(views.router)
logger.info(f'Loaded all routers.') logger.info(f"Loaded all routers.")
@app.get("/api/docs", include_in_schema=False) @app.get("/api/docs", include_in_schema=False)
async def get_docs(req: Request): async def get_docs(req: Request):
print(req.scope) print(req.scope)
return get_swagger_ui_html(openapi_url=req.scope.get('root_path')+'/openapi.json', title='Swagger') return get_swagger_ui_html(
openapi_url=req.scope.get("root_path") + "/openapi.json", title="Swagger"
)
@app.get("/api/openapi.json", include_in_schema=False) @app.get("/api/openapi.json", include_in_schema=False)
async def openapi(): async def openapi():
return get_openapi(title='SBa API Docs', version=f'0.1.1', routes=app.routes) return get_openapi(title="SBa API Docs", version=f"0.1.1", routes=app.routes)
# @app.get("/api")
# async def root():
# return {"message": "Hello Bigger Applications!"}

View File

@ -3,15 +3,27 @@ from typing import List, Optional, Literal
import logging import logging
import pydantic import pydantic
from ..db_engine import db, BattingStat, Team, Player, Current, model_to_dict, chunked, fn, per_season_weeks from ..db_engine import (
from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors db,
BattingStat,
logger = logging.getLogger('discord_app') Team,
Player,
router = APIRouter( Current,
prefix='/api/v3/battingstats', model_to_dict,
tags=['battingstats'] chunked,
fn,
per_season_weeks,
) )
from ..dependencies import (
oauth2_scheme,
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
)
logger = logging.getLogger("discord_app")
router = APIRouter(prefix="/api/v3/battingstats", tags=["battingstats"])
class BatStatModel(pydantic.BaseModel): class BatStatModel(pydantic.BaseModel):
@ -60,29 +72,37 @@ class BatStatList(pydantic.BaseModel):
stats: List[BatStatModel] stats: List[BatStatModel]
@router.get('') @router.get("")
@handle_db_errors @handle_db_errors
async def get_batstats( async def get_batstats(
season: int, s_type: Optional[str] = 'regular', team_abbrev: list = Query(default=None), season: int,
player_name: list = Query(default=None), player_id: list = Query(default=None), s_type: Optional[str] = "regular",
week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None), team_abbrev: list = Query(default=None),
position: list = Query(default=None), limit: Optional[int] = None, sort: Optional[str] = None, player_name: list = Query(default=None),
short_output: Optional[bool] = True): player_id: list = Query(default=None),
if 'post' in s_type.lower(): week_start: Optional[int] = None,
week_end: Optional[int] = None,
game_num: list = Query(default=None),
position: list = Query(default=None),
limit: Optional[int] = None,
sort: Optional[str] = None,
short_output: Optional[bool] = True,
):
if "post" in s_type.lower():
all_stats = BattingStat.post_season(season) all_stats = BattingStat.post_season(season)
if all_stats.count() == 0: if all_stats.count() == 0:
db.close() db.close()
return {'count': 0, 'stats': []} return {"count": 0, "stats": []}
elif s_type.lower() in ['combined', 'total', 'all']: elif s_type.lower() in ["combined", "total", "all"]:
all_stats = BattingStat.combined_season(season) all_stats = BattingStat.combined_season(season)
if all_stats.count() == 0: if all_stats.count() == 0:
db.close() db.close()
return {'count': 0, 'stats': []} return {"count": 0, "stats": []}
else: else:
all_stats = BattingStat.regular_season(season) all_stats = BattingStat.regular_season(season)
if all_stats.count() == 0: if all_stats.count() == 0:
db.close() db.close()
return {'count': 0, 'stats': []} return {"count": 0, "stats": []}
if position is not None: if position is not None:
all_stats = all_stats.where(BattingStat.pos << [x.upper() for x in position]) all_stats = all_stats.where(BattingStat.pos << [x.upper() for x in position])
@ -93,7 +113,9 @@ async def get_batstats(
if player_id: if player_id:
all_stats = all_stats.where(BattingStat.player_id << player_id) all_stats = all_stats.where(BattingStat.player_id << player_id)
else: else:
p_query = Player.select_season(season).where(fn.Lower(Player.name) << [x.lower() for x in player_name]) p_query = Player.select_season(season).where(
fn.Lower(Player.name) << [x.lower() for x in player_name]
)
all_stats = all_stats.where(BattingStat.player << p_query) all_stats = all_stats.where(BattingStat.player << p_query)
if game_num: if game_num:
all_stats = all_stats.where(BattingStat.game == game_num) all_stats = all_stats.where(BattingStat.game == game_num)
@ -108,21 +130,19 @@ async def get_batstats(
db.close() db.close()
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f'Start week {start} is after end week {end} - cannot pull stats' detail=f"Start week {start} is after end week {end} - cannot pull stats",
) )
all_stats = all_stats.where( all_stats = all_stats.where((BattingStat.week >= start) & (BattingStat.week <= end))
(BattingStat.week >= start) & (BattingStat.week <= end)
)
if limit: if limit:
all_stats = all_stats.limit(limit) all_stats = all_stats.limit(limit)
if sort: if sort:
if sort == 'newest': if sort == "newest":
all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game) all_stats = all_stats.order_by(-BattingStat.week, -BattingStat.game)
return_stats = { return_stats = {
'count': all_stats.count(), "count": all_stats.count(),
'stats': [model_to_dict(x, recurse=not short_output) for x in all_stats], "stats": [model_to_dict(x, recurse=not short_output) for x in all_stats],
# 'stats': [{'id': x.id} for x in all_stats] # 'stats': [{'id': x.id} for x in all_stats]
} }
@ -130,52 +150,82 @@ async def get_batstats(
return return_stats return return_stats
@router.get('/totals') @router.get("/totals")
@handle_db_errors @handle_db_errors
async def get_totalstats( async def get_totalstats(
season: int, s_type: Literal['regular', 'post', 'total', None] = None, team_abbrev: list = Query(default=None), season: int,
team_id: list = Query(default=None), player_name: list = Query(default=None), s_type: Literal["regular", "post", "total", None] = None,
week_start: Optional[int] = None, week_end: Optional[int] = None, game_num: list = Query(default=None), team_abbrev: list = Query(default=None),
position: list = Query(default=None), sort: Optional[str] = None, player_id: list = Query(default=None), team_id: list = Query(default=None),
group_by: Literal['team', 'player', 'playerteam'] = 'player', short_output: Optional[bool] = False, player_name: list = Query(default=None),
min_pa: Optional[int] = 1, week: list = Query(default=None)): week_start: Optional[int] = None,
week_end: Optional[int] = None,
game_num: list = Query(default=None),
position: list = Query(default=None),
sort: Optional[str] = None,
player_id: list = Query(default=None),
group_by: Literal["team", "player", "playerteam"] = "player",
short_output: Optional[bool] = False,
min_pa: Optional[int] = 1,
week: list = Query(default=None),
):
if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1: if sum(1 for x in [s_type, (week_start or week_end), week] if x is not None) > 1:
raise HTTPException(status_code=400, detail=f'Only one of s_type, week_start/week_end, or week may be used.') raise HTTPException(
status_code=400,
detail=f"Only one of s_type, week_start/week_end, or week may be used.",
)
# Build SELECT fields conditionally based on group_by to match GROUP BY exactly # Build SELECT fields conditionally based on group_by to match GROUP BY exactly
select_fields = [] select_fields = []
if group_by == 'player': if group_by == "player":
select_fields = [BattingStat.player] select_fields = [BattingStat.player]
elif group_by == 'team': elif group_by == "team":
select_fields = [BattingStat.team] select_fields = [BattingStat.team]
elif group_by == 'playerteam': elif group_by == "playerteam":
select_fields = [BattingStat.player, BattingStat.team] select_fields = [BattingStat.player, BattingStat.team]
else: else:
# Default case # Default case
select_fields = [BattingStat.player] select_fields = [BattingStat.player]
all_stats = ( all_stats = (
BattingStat BattingStat.select(
.select(*select_fields, *select_fields,
fn.SUM(BattingStat.pa).alias('sum_pa'), fn.SUM(BattingStat.ab).alias('sum_ab'), fn.SUM(BattingStat.pa).alias("sum_pa"),
fn.SUM(BattingStat.run).alias('sum_run'), fn.SUM(BattingStat.hit).alias('sum_hit'), fn.SUM(BattingStat.ab).alias("sum_ab"),
fn.SUM(BattingStat.rbi).alias('sum_rbi'), fn.SUM(BattingStat.double).alias('sum_double'), fn.SUM(BattingStat.run).alias("sum_run"),
fn.SUM(BattingStat.triple).alias('sum_triple'), fn.SUM(BattingStat.hr).alias('sum_hr'), fn.SUM(BattingStat.hit).alias("sum_hit"),
fn.SUM(BattingStat.bb).alias('sum_bb'), fn.SUM(BattingStat.so).alias('sum_so'), fn.SUM(BattingStat.rbi).alias("sum_rbi"),
fn.SUM(BattingStat.hbp).alias('sum_hbp'), fn.SUM(BattingStat.sac).alias('sum_sac'), fn.SUM(BattingStat.double).alias("sum_double"),
fn.SUM(BattingStat.ibb).alias('sum_ibb'), fn.SUM(BattingStat.gidp).alias('sum_gidp'), fn.SUM(BattingStat.triple).alias("sum_triple"),
fn.SUM(BattingStat.sb).alias('sum_sb'), fn.SUM(BattingStat.cs).alias('sum_cs'), fn.SUM(BattingStat.hr).alias("sum_hr"),
fn.SUM(BattingStat.bphr).alias('sum_bphr'), fn.SUM(BattingStat.bpfo).alias('sum_bpfo'), fn.SUM(BattingStat.bb).alias("sum_bb"),
fn.SUM(BattingStat.bp1b).alias('sum_bp1b'), fn.SUM(BattingStat.bplo).alias('sum_bplo'), fn.SUM(BattingStat.so).alias("sum_so"),
fn.SUM(BattingStat.xba).alias('sum_xba'), fn.SUM(BattingStat.xbt).alias('sum_xbt'), fn.SUM(BattingStat.hbp).alias("sum_hbp"),
fn.SUM(BattingStat.xch).alias('sum_xch'), fn.SUM(BattingStat.xhit).alias('sum_xhit'), fn.SUM(BattingStat.sac).alias("sum_sac"),
fn.SUM(BattingStat.error).alias('sum_error'), fn.SUM(BattingStat.pb).alias('sum_pb'), fn.SUM(BattingStat.ibb).alias("sum_ibb"),
fn.SUM(BattingStat.sbc).alias('sum_sbc'), fn.SUM(BattingStat.csc).alias('sum_csc'), fn.SUM(BattingStat.gidp).alias("sum_gidp"),
fn.SUM(BattingStat.roba).alias('sum_roba'), fn.SUM(BattingStat.robs).alias('sum_robs'), fn.SUM(BattingStat.sb).alias("sum_sb"),
fn.SUM(BattingStat.raa).alias('sum_raa'), fn.SUM(BattingStat.rto).alias('sum_rto')) fn.SUM(BattingStat.cs).alias("sum_cs"),
.where(BattingStat.season == season) fn.SUM(BattingStat.bphr).alias("sum_bphr"),
.having(fn.SUM(BattingStat.pa) >= min_pa) fn.SUM(BattingStat.bpfo).alias("sum_bpfo"),
fn.SUM(BattingStat.bp1b).alias("sum_bp1b"),
fn.SUM(BattingStat.bplo).alias("sum_bplo"),
fn.SUM(BattingStat.xba).alias("sum_xba"),
fn.SUM(BattingStat.xbt).alias("sum_xbt"),
fn.SUM(BattingStat.xch).alias("sum_xch"),
fn.SUM(BattingStat.xhit).alias("sum_xhit"),
fn.SUM(BattingStat.error).alias("sum_error"),
fn.SUM(BattingStat.pb).alias("sum_pb"),
fn.SUM(BattingStat.sbc).alias("sum_sbc"),
fn.SUM(BattingStat.csc).alias("sum_csc"),
fn.SUM(BattingStat.roba).alias("sum_roba"),
fn.SUM(BattingStat.robs).alias("sum_robs"),
fn.SUM(BattingStat.raa).alias("sum_raa"),
fn.SUM(BattingStat.rto).alias("sum_rto"),
)
.where(BattingStat.season == season)
.having(fn.SUM(BattingStat.pa) >= min_pa)
) )
if True in [s_type is not None, week_start is not None, week_end is not None]: if True in [s_type is not None, week_start is not None, week_end is not None]:
@ -185,16 +235,20 @@ async def get_totalstats(
elif week_start is not None or week_end is not None: elif week_start is not None or week_end is not None:
if week_start is None or week_end is None: if week_start is None or week_end is None:
raise HTTPException( raise HTTPException(
status_code=400, detail='Both week_start and week_end must be included if either is used.' status_code=400,
detail="Both week_start and week_end must be included if either is used.",
)
weeks["start"] = week_start
if week_end < weeks["start"]:
raise HTTPException(
status_code=400,
detail="week_end must be greater than or equal to week_start",
) )
weeks['start'] = week_start
if week_end < weeks['start']:
raise HTTPException(status_code=400, detail='week_end must be greater than or equal to week_start')
else: else:
weeks['end'] = week_end weeks["end"] = week_end
all_stats = all_stats.where( all_stats = all_stats.where(
(BattingStat.week >= weeks['start']) & (BattingStat.week <= weeks['end']) (BattingStat.week >= weeks["start"]) & (BattingStat.week <= weeks["end"])
) )
elif week is not None: elif week is not None:
all_stats = all_stats.where(BattingStat.week << week) all_stats = all_stats.where(BattingStat.week << week)
@ -204,14 +258,20 @@ async def get_totalstats(
if position is not None: if position is not None:
p_list = [x.upper() for x in position] p_list = [x.upper() for x in position]
all_players = Player.select().where( all_players = Player.select().where(
(Player.pos_1 << p_list) | (Player.pos_2 << p_list) | (Player.pos_3 << p_list) | ( Player.pos_4 << p_list) | (Player.pos_1 << p_list)
(Player.pos_5 << p_list) | (Player.pos_6 << p_list) | (Player.pos_7 << p_list) | ( Player.pos_8 << p_list) | (Player.pos_2 << p_list)
| (Player.pos_3 << p_list)
| (Player.pos_4 << p_list)
| (Player.pos_5 << p_list)
| (Player.pos_6 << p_list)
| (Player.pos_7 << p_list)
| (Player.pos_8 << p_list)
) )
all_stats = all_stats.where(BattingStat.player << all_players) all_stats = all_stats.where(BattingStat.player << all_players)
if sort is not None: if sort is not None:
if sort == 'player': if sort == "player":
all_stats = all_stats.order_by(BattingStat.player) all_stats = all_stats.order_by(BattingStat.player)
elif sort == 'team': elif sort == "team":
all_stats = all_stats.order_by(BattingStat.team) all_stats = all_stats.order_by(BattingStat.team)
if group_by is not None: if group_by is not None:
# Use the same fields for GROUP BY as we used for SELECT # Use the same fields for GROUP BY as we used for SELECT
@ -227,75 +287,78 @@ async def get_totalstats(
all_teams = Team.select().where(Team.id << team_id) all_teams = Team.select().where(Team.id << team_id)
all_stats = all_stats.where(BattingStat.team << all_teams) all_stats = all_stats.where(BattingStat.team << all_teams)
elif team_abbrev is not None: elif team_abbrev is not None:
all_teams = Team.select().where(fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev]) all_teams = Team.select().where(
fn.Lower(Team.abbrev) << [x.lower() for x in team_abbrev]
)
all_stats = all_stats.where(BattingStat.team << all_teams) all_stats = all_stats.where(BattingStat.team << all_teams)
if player_name is not None: if player_name is not None:
all_players = Player.select().where(fn.Lower(Player.name) << [x.lower() for x in player_name]) all_players = Player.select().where(
fn.Lower(Player.name) << [x.lower() for x in player_name]
)
all_stats = all_stats.where(BattingStat.player << all_players) all_stats = all_stats.where(BattingStat.player << all_players)
elif player_id is not None: elif player_id is not None:
all_players = Player.select().where(Player.id << player_id) all_players = Player.select().where(Player.id << player_id)
all_stats = all_stats.where(BattingStat.player << all_players) all_stats = all_stats.where(BattingStat.player << all_players)
return_stats = { return_stats = {"count": all_stats.count(), "stats": []}
'count': all_stats.count(),
'stats': []
}
for x in all_stats: for x in all_stats:
# Handle player field based on grouping with safe access # Handle player field based on grouping with safe access
this_player = 'TOT' this_player = "TOT"
if 'player' in group_by and hasattr(x, 'player'): if "player" in group_by and hasattr(x, "player"):
this_player = x.player_id if short_output else model_to_dict(x.player, recurse=False) this_player = (
x.player_id if short_output else model_to_dict(x.player, recurse=False)
)
# Handle team field based on grouping with safe access # Handle team field based on grouping with safe access
this_team = 'TOT' this_team = "TOT"
if 'team' in group_by and hasattr(x, 'team'): if "team" in group_by and hasattr(x, "team"):
this_team = x.team_id if short_output else model_to_dict(x.team, recurse=False) this_team = (
x.team_id if short_output else model_to_dict(x.team, recurse=False)
return_stats['stats'].append({ )
'player': this_player,
'team': this_team, return_stats["stats"].append(
'pa': x.sum_pa, {
'ab': x.sum_ab, "player": this_player,
'run': x.sum_run, "team": this_team,
'hit': x.sum_hit, "pa": x.sum_pa,
'rbi': x.sum_rbi, "ab": x.sum_ab,
'double': x.sum_double, "run": x.sum_run,
'triple': x.sum_triple, "hit": x.sum_hit,
'hr': x.sum_hr, "rbi": x.sum_rbi,
'bb': x.sum_bb, "double": x.sum_double,
'so': x.sum_so, "triple": x.sum_triple,
'hbp': x.sum_hbp, "hr": x.sum_hr,
'sac': x.sum_sac, "bb": x.sum_bb,
'ibb': x.sum_ibb, "so": x.sum_so,
'gidp': x.sum_gidp, "hbp": x.sum_hbp,
'sb': x.sum_sb, "sac": x.sum_sac,
'cs': x.sum_cs, "ibb": x.sum_ibb,
'bphr': x.sum_bphr, "gidp": x.sum_gidp,
'bpfo': x.sum_bpfo, "sb": x.sum_sb,
'bp1b': x.sum_bp1b, "cs": x.sum_cs,
'bplo': x.sum_bplo "bphr": x.sum_bphr,
}) "bpfo": x.sum_bpfo,
"bp1b": x.sum_bp1b,
"bplo": x.sum_bplo,
}
)
db.close() db.close()
return return_stats return return_stats
# @router.get('/career/{player_name}') @router.patch("/{stat_id}", include_in_schema=PRIVATE_IN_SCHEMA)
# async def get_careerstats(
# s_type: Literal['regular', 'post', 'total'] = 'regular', player_name: list = Query(default=None)):
# pass # Keep Career Stats table and recalculate after posting stats
@router.patch('/{stat_id}', include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def patch_batstats(stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme)): async def patch_batstats(
stat_id: int, new_stats: BatStatModel, token: str = Depends(oauth2_scheme)
):
if not valid_token(token): if not valid_token(token):
logger.warning(f'patch_batstats - Bad Token: {token}') logger.warning(f"patch_batstats - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
if BattingStat.get_or_none(BattingStat.id == stat_id) is None: if BattingStat.get_or_none(BattingStat.id == stat_id) is None:
raise HTTPException(status_code=404, detail=f'Stat ID {stat_id} not found') raise HTTPException(status_code=404, detail=f"Stat ID {stat_id} not found")
BattingStat.update(**new_stats.dict()).where(BattingStat.id == stat_id).execute() BattingStat.update(**new_stats.dict()).where(BattingStat.id == stat_id).execute()
r_stat = model_to_dict(BattingStat.get_by_id(stat_id)) r_stat = model_to_dict(BattingStat.get_by_id(stat_id))
@ -303,12 +366,12 @@ async def patch_batstats(stat_id: int, new_stats: BatStatModel, token: str = Dep
return r_stat return r_stat
@router.post('', include_in_schema=PRIVATE_IN_SCHEMA) @router.post("", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)): async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f'post_batstats - Bad Token: {token}') logger.warning(f"post_batstats - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
all_stats = [] all_stats = []
@ -316,9 +379,13 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)
team = Team.get_or_none(Team.id == x.team_id) team = Team.get_or_none(Team.id == x.team_id)
this_player = Player.get_or_none(Player.id == x.player_id) this_player = Player.get_or_none(Player.id == x.player_id)
if team is None: if team is None:
raise HTTPException(status_code=404, detail=f'Team ID {x.team_id} not found') raise HTTPException(
status_code=404, detail=f"Team ID {x.team_id} not found"
)
if this_player is None: if this_player is None:
raise HTTPException(status_code=404, detail=f'Player ID {x.player_id} not found') raise HTTPException(
status_code=404, detail=f"Player ID {x.player_id} not found"
)
all_stats.append(BattingStat(**x.dict())) all_stats.append(BattingStat(**x.dict()))
@ -329,4 +396,4 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)
# Update career stats # Update career stats
db.close() db.close()
return f'Added {len(all_stats)} batting lines' return f"Added {len(all_stats)} batting lines"

View File

@ -4,15 +4,26 @@ import copy
import logging import logging
import pydantic import pydantic
from ..db_engine import db, Decision, StratGame, Player, model_to_dict, chunked, fn, Team from ..db_engine import (
from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors db,
Decision,
logger = logging.getLogger('discord_app') StratGame,
Player,
router = APIRouter( model_to_dict,
prefix='/api/v3/decisions', chunked,
tags=['decisions'] fn,
Team,
) )
from ..dependencies import (
oauth2_scheme,
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
)
logger = logging.getLogger("discord_app")
router = APIRouter(prefix="/api/v3/decisions", tags=["decisions"])
class DecisionModel(pydantic.BaseModel): class DecisionModel(pydantic.BaseModel):
@ -43,17 +54,31 @@ class DecisionReturnList(pydantic.BaseModel):
decisions: list[DecisionModel] decisions: list[DecisionModel]
@router.get('') @router.get("")
@handle_db_errors @handle_db_errors
async def get_decisions( async def get_decisions(
season: list = Query(default=None), week: list = Query(default=None), game_num: list = Query(default=None), season: list = Query(default=None),
s_type: Literal['regular', 'post', 'all', None] = None, team_id: list = Query(default=None), week: list = Query(default=None),
week_start: Optional[int] = None, week_end: Optional[int] = None, win: Optional[int] = None, game_num: list = Query(default=None),
loss: Optional[int] = None, hold: Optional[int] = None, save: Optional[int] = None, s_type: Literal["regular", "post", "all", None] = None,
b_save: Optional[int] = None, irunners: list = Query(default=None), irunners_scored: list = Query(default=None), team_id: list = Query(default=None),
game_id: list = Query(default=None), player_id: list = Query(default=None), week_start: Optional[int] = None,
limit: Optional[int] = None, short_output: Optional[bool] = False): week_end: Optional[int] = None,
all_dec = Decision.select().order_by(-Decision.season, -Decision.week, -Decision.game_num) win: Optional[int] = None,
loss: Optional[int] = None,
hold: Optional[int] = None,
save: Optional[int] = None,
b_save: Optional[int] = None,
irunners: list = Query(default=None),
irunners_scored: list = Query(default=None),
game_id: list = Query(default=None),
player_id: list = Query(default=None),
limit: Optional[int] = None,
short_output: Optional[bool] = False,
):
all_dec = Decision.select().order_by(
-Decision.season, -Decision.week, -Decision.game_num
)
if season is not None: if season is not None:
all_dec = all_dec.where(Decision.season << season) all_dec = all_dec.where(Decision.season << season)
@ -65,21 +90,13 @@ async def get_decisions(
all_dec = all_dec.where(Decision.game_id << game_id) all_dec = all_dec.where(Decision.game_id << game_id)
if player_id is not None: if player_id is not None:
all_dec = all_dec.where(Decision.pitcher << player_id) all_dec = all_dec.where(Decision.pitcher << player_id)
# # Need to allow for split-season stats
# if team_id is not None:
# all_teams = Team.select().where(Team.id << team_id)
# all_games = StratGame.select().where(
# (StratGame.away_team << all_teams) | (StratGame.home_team << all_teams))
# all_dec = all_dec.where(Decision.game << all_games)
# if team_id is not None:
# all_players = Player.select().where(Player.team_id << team_id)
# all_dec = all_dec.where(Decision.pitcher << all_players)
if team_id is not None: if team_id is not None:
s8_teams = [int(x) for x in team_id if int(x) <= 350] s8_teams = [int(x) for x in team_id if int(x) <= 350]
if season is not None and 8 in season or s8_teams: if season is not None and 8 in season or s8_teams:
all_teams = Team.select().where(Team.id << team_id) all_teams = Team.select().where(Team.id << team_id)
all_games = StratGame.select().where( all_games = StratGame.select().where(
(StratGame.away_team << all_teams) | (StratGame.home_team << all_teams)) (StratGame.away_team << all_teams) | (StratGame.home_team << all_teams)
)
all_dec = all_dec.where(Decision.game << all_games) all_dec = all_dec.where(Decision.game << all_games)
else: else:
all_teams = Team.select().where(Team.id << team_id) all_teams = Team.select().where(Team.id << team_id)
@ -87,9 +104,6 @@ async def get_decisions(
if s_type is not None: if s_type is not None:
all_games = StratGame.select().where(StratGame.season_type == s_type) all_games = StratGame.select().where(StratGame.season_type == s_type)
all_dec = all_dec.where(Decision.game << all_games) all_dec = all_dec.where(Decision.game << all_games)
# if team_id is not None:
# all_players = Player.select().where(Player.team_id << team_id)
# all_dec = all_dec.where(Decision.pitcher << all_players)
if week_start is not None: if week_start is not None:
all_dec = all_dec.where(Decision.week >= week_start) all_dec = all_dec.where(Decision.week >= week_start)
if week_end is not None: if week_end is not None:
@ -115,28 +129,38 @@ async def get_decisions(
all_dec = all_dec.limit(limit) all_dec = all_dec.limit(limit)
return_dec = { return_dec = {
'count': all_dec.count(), "count": all_dec.count(),
'decisions': [model_to_dict(x, recurse=not short_output) for x in all_dec] "decisions": [model_to_dict(x, recurse=not short_output) for x in all_dec],
} }
db.close() db.close()
return return_dec return return_dec
@router.patch('/{decision_id}', include_in_schema=PRIVATE_IN_SCHEMA) @router.patch("/{decision_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def patch_decision( async def patch_decision(
decision_id: int, win: Optional[int] = None, loss: Optional[int] = None, hold: Optional[int] = None, decision_id: int,
save: Optional[int] = None, b_save: Optional[int] = None, irunners: Optional[int] = None, win: Optional[int] = None,
irunners_scored: Optional[int] = None, rest_ip: Optional[int] = None, rest_required: Optional[int] = None, loss: Optional[int] = None,
token: str = Depends(oauth2_scheme)): hold: Optional[int] = None,
save: Optional[int] = None,
b_save: Optional[int] = None,
irunners: Optional[int] = None,
irunners_scored: Optional[int] = None,
rest_ip: Optional[int] = None,
rest_required: Optional[int] = None,
token: str = Depends(oauth2_scheme),
):
if not valid_token(token): if not valid_token(token):
logger.warning(f'patch_decision - Bad Token: {token}') logger.warning(f"patch_decision - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
this_dec = Decision.get_or_none(Decision.id == decision_id) this_dec = Decision.get_or_none(Decision.id == decision_id)
if this_dec is None: if this_dec is None:
db.close() db.close()
raise HTTPException(status_code=404, detail=f'Decision ID {decision_id} not found') raise HTTPException(
status_code=404, detail=f"Decision ID {decision_id} not found"
)
if win is not None: if win is not None:
this_dec.win = win this_dec.win = win
@ -163,22 +187,28 @@ async def patch_decision(
return d_result return d_result
else: else:
db.close() db.close()
raise HTTPException(status_code=500, detail=f'Unable to patch decision {decision_id}') raise HTTPException(
status_code=500, detail=f"Unable to patch decision {decision_id}"
)
@router.post('', include_in_schema=PRIVATE_IN_SCHEMA) @router.post("", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def post_decisions(dec_list: DecisionList, token: str = Depends(oauth2_scheme)): async def post_decisions(dec_list: DecisionList, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f'post_decisions - Bad Token: {token}') logger.warning(f"post_decisions - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
new_dec = [] new_dec = []
for x in dec_list.decisions: for x in dec_list.decisions:
if StratGame.get_or_none(StratGame.id == x.game_id) is None: if StratGame.get_or_none(StratGame.id == x.game_id) is None:
raise HTTPException(status_code=404, detail=f'Game ID {x.game_id} not found') raise HTTPException(
status_code=404, detail=f"Game ID {x.game_id} not found"
)
if Player.get_or_none(Player.id == x.pitcher_id) is None: if Player.get_or_none(Player.id == x.pitcher_id) is None:
raise HTTPException(status_code=404, detail=f'Player ID {x.pitcher_id} not found') raise HTTPException(
status_code=404, detail=f"Player ID {x.pitcher_id} not found"
)
new_dec.append(x.dict()) new_dec.append(x.dict())
@ -187,49 +217,53 @@ async def post_decisions(dec_list: DecisionList, token: str = Depends(oauth2_sch
Decision.insert_many(batch).on_conflict_ignore().execute() Decision.insert_many(batch).on_conflict_ignore().execute()
db.close() db.close()
return f'Inserted {len(new_dec)} decisions' return f"Inserted {len(new_dec)} decisions"
@router.delete('/{decision_id}', include_in_schema=PRIVATE_IN_SCHEMA) @router.delete("/{decision_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def delete_decision(decision_id: int, token: str = Depends(oauth2_scheme)): async def delete_decision(decision_id: int, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f'delete_decision - Bad Token: {token}') logger.warning(f"delete_decision - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
this_dec = Decision.get_or_none(Decision.id == decision_id) this_dec = Decision.get_or_none(Decision.id == decision_id)
if this_dec is None: if this_dec is None:
db.close() db.close()
raise HTTPException(status_code=404, detail=f'Decision ID {decision_id} not found') raise HTTPException(
status_code=404, detail=f"Decision ID {decision_id} not found"
)
count = this_dec.delete_instance() count = this_dec.delete_instance()
db.close() db.close()
if count == 1: if count == 1:
return f'Decision {decision_id} has been deleted' return f"Decision {decision_id} has been deleted"
else: else:
raise HTTPException(status_code=500, detail=f'Decision {decision_id} could not be deleted') raise HTTPException(
status_code=500, detail=f"Decision {decision_id} could not be deleted"
)
@router.delete('/game/{game_id}', include_in_schema=PRIVATE_IN_SCHEMA) @router.delete("/game/{game_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def delete_decisions_game(game_id: int, token: str = Depends(oauth2_scheme)): async def delete_decisions_game(game_id: int, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f'delete_decisions_game - Bad Token: {token}') logger.warning(f"delete_decisions_game - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
this_game = StratGame.get_or_none(StratGame.id == game_id) this_game = StratGame.get_or_none(StratGame.id == game_id)
if not this_game: if not this_game:
db.close() db.close()
raise HTTPException(status_code=404, detail=f'Game ID {game_id} not found') raise HTTPException(status_code=404, detail=f"Game ID {game_id} not found")
count = Decision.delete().where(Decision.game == this_game).execute() count = Decision.delete().where(Decision.game == this_game).execute()
db.close() db.close()
if count > 0: if count > 0:
return f'Deleted {count} decisions matching Game ID {game_id}' return f"Deleted {count} decisions matching Game ID {game_id}"
else: else:
raise HTTPException(status_code=500, detail=f'No decisions matching Game ID {game_id} were deleted') raise HTTPException(
status_code=500,
detail=f"No decisions matching Game ID {game_id} were deleted",
)