major-domo-database/app/dependencies.py
Cal Corum e6a325ac8f Add CACHE_ENABLED env var to toggle Redis caching (v2.2.1)
- Set CACHE_ENABLED=false to disable caching without stopping Redis
- Defaults to true (caching enabled)
- Useful for draft periods requiring real-time data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-10 07:59:54 -06:00

783 lines
30 KiB
Python

import datetime
import hashlib
import hmac
import json
import logging
import os
from functools import wraps
from typing import Optional

import requests
from fastapi import HTTPException, Response
from fastapi.security import OAuth2PasswordBearer
from redis import Redis
# Current date as "YYYY-M-D" (no zero padding), used for log file naming.
# Captured from a single now() call so year/month/day cannot straddle midnight
# (the original called now() three times).
_now = datetime.datetime.now()
date = f'{_now.year}-{_now.month}-{_now.day}'

logger = logging.getLogger('discord_app')

# File-based logging configuration kept for reference; logging is currently
# configured by the application entry point instead.
# log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
# logging.basicConfig(
#     filename=f'logs/database/{date}.log',
#     format='%(asctime)s - sba-database - %(levelname)s - %(message)s',
#     level=log_level
# )
# Redis configuration (all overridable via environment variables).
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost')
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379'))
REDIS_DB = int(os.environ.get('REDIS_DB', '0'))
# Any value other than (case-insensitive) "true" disables caching.
CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true'

# Initialize the Redis client; on any failure the app keeps running with
# caching disabled (redis_client stays None).
redis_client = None
if CACHE_ENABLED:
    try:
        redis_client = Redis(
            host=REDIS_HOST,
            port=REDIS_PORT,
            db=REDIS_DB,
            decode_responses=True,
            socket_connect_timeout=5,
            socket_timeout=5,
        )
        # Verify the server is actually reachable before trusting the client.
        redis_client.ping()
        logger.info(f"Redis connected successfully at {REDIS_HOST}:{REDIS_PORT}")
    except Exception as e:
        logger.warning(f"Redis connection failed: {e}. Caching will be disabled.")
        redis_client = None
else:
    logger.info("Caching disabled via CACHE_ENABLED=false")
# OAuth2 bearer scheme used by protected endpoints; tokens are POSTed to /token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# PRIVATE_IN_SCHEMA env var toggles whether private routes appear in the
# OpenAPI schema; only a case-insensitive "true" enables it. priv_help keeps
# the original intermediate name (False, or the uppercased env value).
priv_help = os.environ.get('PRIVATE_IN_SCHEMA', '').upper() or False
PRIVATE_IN_SCHEMA = priv_help == 'TRUE'
def valid_token(token):
    """Return True when *token* matches the API_TOKEN environment variable.

    Uses hmac.compare_digest for a constant-time comparison so the check does
    not leak token contents through timing differences. If API_TOKEN is unset
    or empty, every token is rejected — the previous plain equality check
    accepted a missing token whenever the variable was also missing.

    Args:
        token: The bearer token supplied by the caller (expected str).

    Returns:
        bool: True only for a non-empty, exact match.
    """
    expected = os.environ.get('API_TOKEN')
    if not expected or not isinstance(token, str):
        return False
    return hmac.compare_digest(token, expected)
def update_season_batting_stats(player_ids, season, db_connection):
    """
    Update season batting stats for specific players in a given season.

    Recalculates counting and rate stats from stratplay data and upserts the
    rows into the seasonbattingstats table (ON CONFLICT player_id, season).

    Args:
        player_ids: A single player id (int) or a list of player ids.
        season: The season whose stats should be rebuilt.
        db_connection: Database handle exposing execute_sql(query, params).

    Raises:
        Re-raises any exception from the database layer after logging it.
    """
    # Normalize a lone int id to a list FIRST, so the emptiness guard below
    # cannot misfire on a falsy-but-valid id of 0 (previously checked after).
    if isinstance(player_ids, int):
        player_ids = [player_ids]
    if not player_ids:
        logger.warning("update_season_batting_stats called with empty player_ids list")
        return
    logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}")
    try:
        # Recalculate from raw plays and upsert in a single statement.
        # NOTE(review): OBP/wOBA add ibb to bb separately, which implies sp.bb
        # excludes intentional walks — confirm against the stratplay schema.
        query = """
        WITH batting_stats AS (
            SELECT
                p.id AS player_id,
                p.name,
                p.sbaplayer_id,
                p.team_id AS player_team_id,
                t.abbrev AS player_team_abbrev,
                sg.season,
                -- Counting statistics (summed from StratPlays)
                SUM(sp.pa) AS pa,
                SUM(sp.ab) AS ab,
                SUM(sp.run) AS run,
                SUM(sp.hit) AS hit,
                SUM(sp.double) AS double,
                SUM(sp.triple) AS triple,
                SUM(sp.homerun) AS homerun,
                SUM(sp.rbi) AS rbi,
                SUM(sp.bb) AS bb,
                SUM(sp.so) AS so,
                SUM(sp.bphr) AS bphr,
                SUM(sp.bpfo) AS bpfo,
                SUM(sp.bp1b) AS bp1b,
                SUM(sp.bplo) AS bplo,
                SUM(sp.gidp) AS gidp,
                SUM(sp.hbp) AS hbp,
                SUM(sp.sac) AS sac,
                SUM(sp.ibb) AS ibb,
                -- Calculated statistics using formulas
                CASE
                    WHEN SUM(sp.ab) > 0
                    THEN ROUND(SUM(sp.hit)::DECIMAL / SUM(sp.ab), 3)
                    ELSE 0.000
                END AS avg,
                CASE
                    WHEN SUM(sp.pa) > 0
                    THEN ROUND((SUM(sp.hit) + SUM(sp.bb) + SUM(sp.hbp) + SUM(sp.ibb))::DECIMAL / SUM(sp.pa), 3)
                    ELSE 0.000
                END AS obp,
                CASE
                    WHEN SUM(sp.ab) > 0
                    THEN ROUND((SUM(sp.hit) + SUM(sp.double) + 2 * SUM(sp.triple) + 3 *
                                SUM(sp.homerun))::DECIMAL / SUM(sp.ab), 3)
                    ELSE 0.000
                END AS slg,
                CASE
                    WHEN SUM(sp.pa) > 0 AND SUM(sp.ab) > 0
                    THEN ROUND(
                        ((SUM(sp.hit) + SUM(sp.bb) + SUM(sp.hbp) + SUM(sp.ibb))::DECIMAL / SUM(sp.pa)) +
                        ((SUM(sp.hit) + SUM(sp.double) + 2 * SUM(sp.triple) + 3 *
                          SUM(sp.homerun))::DECIMAL / SUM(sp.ab)), 3)
                    ELSE 0.000
                END AS ops,
                -- wOBA calculation (simplified version)
                CASE
                    WHEN SUM(sp.pa) > 0
                    THEN ROUND((0.690 * SUM(sp.bb) + 0.722 * SUM(sp.hbp) + 0.888 * (SUM(sp.hit) -
                                SUM(sp.double) - SUM(sp.triple) - SUM(sp.homerun)) +
                                1.271 * SUM(sp.double) + 1.616 * SUM(sp.triple) + 2.101 *
                                SUM(sp.homerun))::DECIMAL / SUM(sp.pa), 3)
                    ELSE 0.000
                END AS woba,
                CASE
                    WHEN SUM(sp.pa) > 0
                    THEN ROUND(SUM(sp.so)::DECIMAL / SUM(sp.pa) * 100, 1)
                    ELSE 0.0
                END AS k_pct
            FROM stratplay sp
            JOIN stratgame sg ON sg.id = sp.game_id
            JOIN player p ON p.id = sp.batter_id
            JOIN team t ON t.id = p.team_id
            WHERE sg.season = %s AND p.id = ANY(%s)
            GROUP BY p.id, p.name, p.sbaplayer_id, p.team_id, t.abbrev, sg.season
        ),
        running_stats AS (
            SELECT
                sp.runner_id AS player_id,
                sg.season,
                SUM(sp.sb) AS sb,
                SUM(sp.cs) AS cs
            FROM stratplay sp
            JOIN stratgame sg ON sg.id = sp.game_id
            WHERE sg.season = %s AND sp.runner_id IS NOT NULL AND sp.runner_id = ANY(%s)
            GROUP BY sp.runner_id, sg.season
        )
        INSERT INTO seasonbattingstats (
            player_id, sbaplayer_id, team_id, season, name, player_team_id, player_team_abbrev,
            pa, ab, run, hit, double, triple, homerun, rbi, bb, so, bphr, bpfo, bp1b, bplo, gidp, hbp, sac, ibb,
            avg, obp, slg, ops, woba, k_pct, sb, cs
        )
        SELECT
            bs.player_id, bs.sbaplayer_id, bs.player_team_id, bs.season, bs.name, bs.player_team_id, bs.player_team_abbrev,
            bs.pa, bs.ab, bs.run, bs.hit, bs.double, bs.triple, bs.homerun, bs.rbi, bs.bb, bs.so,
            bs.bphr, bs.bpfo, bs.bp1b, bs.bplo, bs.gidp, bs.hbp, bs.sac, bs.ibb,
            bs.avg, bs.obp, bs.slg, bs.ops, bs.woba, bs.k_pct,
            COALESCE(rs.sb, 0) AS sb,
            COALESCE(rs.cs, 0) AS cs
        FROM batting_stats bs
        LEFT JOIN running_stats rs ON bs.player_id = rs.player_id AND bs.season = rs.season
        ON CONFLICT (player_id, season)
        DO UPDATE SET
            sbaplayer_id = EXCLUDED.sbaplayer_id,
            team_id = EXCLUDED.team_id,
            name = EXCLUDED.name,
            player_team_id = EXCLUDED.player_team_id,
            player_team_abbrev = EXCLUDED.player_team_abbrev,
            pa = EXCLUDED.pa,
            ab = EXCLUDED.ab,
            run = EXCLUDED.run,
            hit = EXCLUDED.hit,
            double = EXCLUDED.double,
            triple = EXCLUDED.triple,
            homerun = EXCLUDED.homerun,
            rbi = EXCLUDED.rbi,
            bb = EXCLUDED.bb,
            so = EXCLUDED.so,
            bphr = EXCLUDED.bphr,
            bpfo = EXCLUDED.bpfo,
            bp1b = EXCLUDED.bp1b,
            bplo = EXCLUDED.bplo,
            gidp = EXCLUDED.gidp,
            hbp = EXCLUDED.hbp,
            sac = EXCLUDED.sac,
            ibb = EXCLUDED.ibb,
            avg = EXCLUDED.avg,
            obp = EXCLUDED.obp,
            slg = EXCLUDED.slg,
            ops = EXCLUDED.ops,
            woba = EXCLUDED.woba,
            k_pct = EXCLUDED.k_pct,
            sb = EXCLUDED.sb,
            cs = EXCLUDED.cs;
        """
        # Season and ids are passed twice: once for the batting CTE, once for
        # the baserunning CTE.
        db_connection.execute_sql(query, [season, player_ids, season, player_ids])
        logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}")
    except Exception as e:
        logger.error(f"Error updating season batting stats: {e}")
        raise
def update_season_pitching_stats(player_ids, season, db_connection):
    """
    Update season pitching stats for specific players in a given season.

    Recalculates stats from stratplay (per-plate-appearance data) and decision
    (wins/losses/saves) data and upserts the rows into the seasonpitchingstats
    table (ON CONFLICT player_id, season).

    Args:
        player_ids: A single player id (int) or a list of player ids.
        season: The season whose stats should be rebuilt.
        db_connection: Database handle exposing execute_sql(query, params).

    Raises:
        Re-raises any exception from the database layer after logging it.
    """
    # Normalize a lone int id to a list FIRST, so the emptiness guard below
    # cannot misfire on a falsy-but-valid id of 0 (previously checked after).
    if isinstance(player_ids, int):
        player_ids = [player_ids]
    if not player_ids:
        logger.warning("update_season_pitching_stats called with empty player_ids list")
        return
    logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}")
    try:
        # Recalculate from raw plays + decisions and upsert in one statement.
        # NOTE(review): ERA uses e_run*27/outs (i.e. ER*9/IP with IP = outs/3)
        # and WHIP adds ibb to bb, implying sp.bb excludes intentional walks —
        # confirm against the stratplay schema. wpa/re24 are sign-flipped on
        # insert (pitcher perspective, presumably); verify intent.
        query = """
        WITH pitching_stats AS (
            SELECT
                p.id AS player_id,
                p.name,
                p.sbaplayer_id,
                p.team_id AS player_team_id,
                t.abbrev AS player_team_abbrev,
                sg.season,
                -- Counting statistics (summed from StratPlays)
                SUM(sp.pa) AS tbf,
                SUM(sp.ab) AS ab,
                SUM(sp.run) AS run,
                SUM(sp.e_run) AS e_run,
                SUM(sp.hit) AS hits,
                SUM(sp.double) AS double,
                SUM(sp.triple) AS triple,
                SUM(sp.homerun) AS homerun,
                SUM(sp.bb) AS bb,
                SUM(sp.so) AS so,
                SUM(sp.hbp) AS hbp,
                SUM(sp.sac) AS sac,
                SUM(sp.ibb) AS ibb,
                SUM(sp.gidp) AS gidp,
                SUM(sp.sb) AS sb,
                SUM(sp.cs) AS cs,
                SUM(sp.bphr) AS bphr,
                SUM(sp.bpfo) AS bpfo,
                SUM(sp.bp1b) AS bp1b,
                SUM(sp.bplo) AS bplo,
                SUM(sp.wild_pitch) AS wp,
                SUM(sp.balk) AS balk,
                SUM(sp.outs) AS outs,
                COALESCE(SUM(sp.wpa), 0) AS wpa,
                COALESCE(SUM(sp.re24_primary), 0) AS re24,
                -- Calculated statistics using formulas
                CASE
                    WHEN SUM(sp.outs) > 0
                    THEN ROUND((SUM(sp.e_run) * 27)::DECIMAL / SUM(sp.outs), 2)
                    ELSE 0.00
                END AS era,
                CASE
                    WHEN SUM(sp.outs) > 0
                    THEN ROUND(((SUM(sp.bb) + SUM(sp.hit) + SUM(sp.ibb)) * 3)::DECIMAL / SUM(sp.outs), 2)
                    ELSE 0.00
                END AS whip,
                CASE
                    WHEN SUM(sp.ab) > 0
                    THEN ROUND(SUM(sp.hit)::DECIMAL / SUM(sp.ab), 3)
                    ELSE 0.000
                END AS avg,
                CASE
                    WHEN SUM(sp.pa) > 0
                    THEN ROUND((SUM(sp.hit) + SUM(sp.bb) + SUM(sp.hbp) + SUM(sp.ibb))::DECIMAL / SUM(sp.pa), 3)
                    ELSE 0.000
                END AS obp,
                CASE
                    WHEN SUM(sp.ab) > 0
                    THEN ROUND((SUM(sp.hit) + SUM(sp.double) + 2 * SUM(sp.triple) + 3 *
                                SUM(sp.homerun))::DECIMAL / SUM(sp.ab), 3)
                    ELSE 0.000
                END AS slg,
                CASE
                    WHEN SUM(sp.pa) > 0 AND SUM(sp.ab) > 0
                    THEN ROUND(
                        ((SUM(sp.hit) + SUM(sp.bb) + SUM(sp.hbp) + SUM(sp.ibb))::DECIMAL / SUM(sp.pa)) +
                        ((SUM(sp.hit) + SUM(sp.double) + 2 * SUM(sp.triple) + 3 *
                          SUM(sp.homerun))::DECIMAL / SUM(sp.ab)), 3)
                    ELSE 0.000
                END AS ops,
                -- wOBA calculation (same as batting)
                CASE
                    WHEN SUM(sp.pa) > 0
                    THEN ROUND((0.690 * SUM(sp.bb) + 0.722 * SUM(sp.hbp) + 0.888 * (SUM(sp.hit) -
                                SUM(sp.double) - SUM(sp.triple) - SUM(sp.homerun)) +
                                1.271 * SUM(sp.double) + 1.616 * SUM(sp.triple) + 2.101 *
                                SUM(sp.homerun))::DECIMAL / SUM(sp.pa), 3)
                    ELSE 0.000
                END AS woba,
                -- Rate stats
                CASE
                    WHEN SUM(sp.outs) > 0
                    THEN ROUND((SUM(sp.hit) * 9)::DECIMAL / (SUM(sp.outs) / 3.0), 1)
                    ELSE 0.0
                END AS hper9,
                CASE
                    WHEN SUM(sp.outs) > 0
                    THEN ROUND((SUM(sp.so) * 9)::DECIMAL / (SUM(sp.outs) / 3.0), 1)
                    ELSE 0.0
                END AS kper9,
                CASE
                    WHEN SUM(sp.outs) > 0
                    THEN ROUND((SUM(sp.bb) * 9)::DECIMAL / (SUM(sp.outs) / 3.0), 1)
                    ELSE 0.0
                END AS bbper9,
                CASE
                    WHEN SUM(sp.bb) > 0
                    THEN ROUND(SUM(sp.so)::DECIMAL / SUM(sp.bb), 2)
                    ELSE 0.0
                END AS kperbb
            FROM stratplay sp
            JOIN stratgame sg ON sg.id = sp.game_id
            JOIN player p ON p.id = sp.pitcher_id
            JOIN team t ON t.id = p.team_id
            WHERE sg.season = %s AND p.id = ANY(%s) AND sp.pitcher_id IS NOT NULL
            GROUP BY p.id, p.name, p.sbaplayer_id, p.team_id, t.abbrev, sg.season
        ),
        decision_stats AS (
            SELECT
                d.pitcher_id AS player_id,
                sg.season,
                SUM(d.win) AS win,
                SUM(d.loss) AS loss,
                SUM(d.hold) AS hold,
                SUM(d.is_save) AS saves,
                SUM(d.b_save) AS bsave,
                SUM(d.irunners) AS ir,
                SUM(d.irunners_scored) AS irs,
                SUM(d.is_start::INTEGER) AS gs,
                COUNT(d.game_id) AS games
            FROM decision d
            JOIN stratgame sg ON sg.id = d.game_id
            WHERE sg.season = %s AND d.pitcher_id = ANY(%s)
            GROUP BY d.pitcher_id, sg.season
        )
        INSERT INTO seasonpitchingstats (
            player_id, sbaplayer_id, team_id, season, name, player_team_id, player_team_abbrev,
            tbf, outs, games, gs, win, loss, hold, saves, bsave, ir, irs,
            ab, run, e_run, hits, double, triple, homerun, bb, so, hbp, sac, ibb, gidp, sb, cs,
            bphr, bpfo, bp1b, bplo, wp, balk,
            wpa, era, whip, avg, obp, slg, ops, woba, hper9, kper9, bbper9, kperbb,
            lob_2outs, rbipercent, re24
        )
        SELECT
            ps.player_id, ps.sbaplayer_id, ps.player_team_id, ps.season, ps.name, ps.player_team_id, ps.player_team_abbrev,
            ps.tbf, ps.outs, COALESCE(ds.games, 0), COALESCE(ds.gs, 0),
            COALESCE(ds.win, 0), COALESCE(ds.loss, 0), COALESCE(ds.hold, 0),
            COALESCE(ds.saves, 0), COALESCE(ds.bsave, 0), COALESCE(ds.ir, 0), COALESCE(ds.irs, 0),
            ps.ab, ps.run, ps.e_run, ps.hits, ps.double, ps.triple, ps.homerun, ps.bb, ps.so,
            ps.hbp, ps.sac, ps.ibb, ps.gidp, ps.sb, ps.cs,
            ps.bphr, ps.bpfo, ps.bp1b, ps.bplo, ps.wp, ps.balk,
            ps.wpa * -1, ps.era, ps.whip, ps.avg, ps.obp, ps.slg, ps.ops, ps.woba,
            ps.hper9, ps.kper9, ps.bbper9, ps.kperbb,
            0.0, 0.0, COALESCE(ps.re24 * -1, 0.0)
        FROM pitching_stats ps
        LEFT JOIN decision_stats ds ON ps.player_id = ds.player_id AND ps.season = ds.season
        ON CONFLICT (player_id, season)
        DO UPDATE SET
            sbaplayer_id = EXCLUDED.sbaplayer_id,
            team_id = EXCLUDED.team_id,
            name = EXCLUDED.name,
            player_team_id = EXCLUDED.player_team_id,
            player_team_abbrev = EXCLUDED.player_team_abbrev,
            tbf = EXCLUDED.tbf,
            outs = EXCLUDED.outs,
            games = EXCLUDED.games,
            gs = EXCLUDED.gs,
            win = EXCLUDED.win,
            loss = EXCLUDED.loss,
            hold = EXCLUDED.hold,
            saves = EXCLUDED.saves,
            bsave = EXCLUDED.bsave,
            ir = EXCLUDED.ir,
            irs = EXCLUDED.irs,
            ab = EXCLUDED.ab,
            run = EXCLUDED.run,
            e_run = EXCLUDED.e_run,
            hits = EXCLUDED.hits,
            double = EXCLUDED.double,
            triple = EXCLUDED.triple,
            homerun = EXCLUDED.homerun,
            bb = EXCLUDED.bb,
            so = EXCLUDED.so,
            hbp = EXCLUDED.hbp,
            sac = EXCLUDED.sac,
            ibb = EXCLUDED.ibb,
            gidp = EXCLUDED.gidp,
            sb = EXCLUDED.sb,
            cs = EXCLUDED.cs,
            bphr = EXCLUDED.bphr,
            bpfo = EXCLUDED.bpfo,
            bp1b = EXCLUDED.bp1b,
            bplo = EXCLUDED.bplo,
            wp = EXCLUDED.wp,
            balk = EXCLUDED.balk,
            wpa = EXCLUDED.wpa,
            era = EXCLUDED.era,
            whip = EXCLUDED.whip,
            avg = EXCLUDED.avg,
            obp = EXCLUDED.obp,
            slg = EXCLUDED.slg,
            ops = EXCLUDED.ops,
            woba = EXCLUDED.woba,
            hper9 = EXCLUDED.hper9,
            kper9 = EXCLUDED.kper9,
            bbper9 = EXCLUDED.bbper9,
            kperbb = EXCLUDED.kperbb,
            lob_2outs = EXCLUDED.lob_2outs,
            rbipercent = EXCLUDED.rbipercent,
            re24 = EXCLUDED.re24;
        """
        # Season and ids are passed twice: once for the pitching CTE, once for
        # the decisions CTE.
        db_connection.execute_sql(query, [season, player_ids, season, player_ids])
        logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}")
    except Exception as e:
        logger.error(f"Error updating season pitching stats: {e}")
        raise
def send_webhook_message(message: str) -> bool:
    """
    Send a message to Discord via webhook.

    Args:
        message: The message content to send

    Returns:
        bool: True if successful, False otherwise
    """
    # SECURITY: the webhook URL embeds a secret token. Prefer the
    # DISCORD_WEBHOOK_URL env var; the hard-coded value is kept only as a
    # backward-compatible fallback and should be rotated out of source control.
    webhook_url = os.environ.get(
        'DISCORD_WEBHOOK_URL',
        "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW",
    )
    try:
        payload = {"content": message}
        response = requests.post(webhook_url, json=payload, timeout=10)
        response.raise_for_status()
        # Log only a prefix of the message to keep log lines bounded.
        logger.info(f"Webhook message sent successfully: {message[:100]}...")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to send webhook message: {e}")
        return False
    except Exception as e:
        logger.error(f"Unexpected error sending webhook message: {e}")
        return False
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True):
    """
    Decorator to cache async function results in Redis with parameter
    normalization.

    Args:
        ttl: Time to live in seconds (default: 5 minutes)
        key_prefix: Prefix for cache keys (default: "api")
        normalize_params: Remove None/empty values to reduce cache variations
            (default: True)

    Usage:
        @cache_result(ttl=600, key_prefix="stats")
        async def get_player_stats(player_id: int, season: Optional[int] = None):
            # expensive operation
            return stats

        # These will use the same cache entry when normalize_params=True:
        # get_player_stats(123, None) and get_player_stats(123)

    Note:
        Cached values round-trip through json.dumps(default=str)/json.loads,
        so non-JSON types (dates, Decimals, ...) come back as strings.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Skip caching if Redis is not available (connection failed or
            # CACHE_ENABLED=false).
            if redis_client is None:
                return await func(*args, **kwargs)
            try:
                # Normalize parameters to reduce cache variations.
                normalized_kwargs = kwargs.copy()
                if normalize_params:
                    # Drop None values and empty collections so explicit and
                    # omitted "empty" arguments map to the same cache key.
                    normalized_kwargs = {
                        k: v for k, v in kwargs.items()
                        if v is not None and v != [] and v != "" and v != {}
                    }
                # Generate a readable cache key from positional and keyword args.
                args_str = "_".join(str(arg) for arg in args if arg is not None)
                kwargs_str = "_".join(
                    f"{k}={v}" for k, v in sorted(normalized_kwargs.items())
                )
                key_parts = [key_prefix, func.__name__]
                if args_str:
                    key_parts.append(args_str)
                if kwargs_str:
                    key_parts.append(kwargs_str)
                cache_key = ":".join(key_parts)
                # Truncate very long cache keys. Use a stable digest instead of
                # the builtin hash(): hash() is salted per process
                # (PYTHONHASHSEED), so it would generate different keys across
                # workers/restarts and silently defeat the cache.
                if len(cache_key) > 200:
                    digest = hashlib.sha256(cache_key.encode("utf-8")).hexdigest()
                    cache_key = f"{key_prefix}:{func.__name__}:{digest}"
                # Try to get from cache.
                cached_result = redis_client.get(cache_key)
                if cached_result is not None:
                    logger.debug(f"Cache hit for key: {cache_key}")
                    return json.loads(cached_result)
                # Cache miss - execute function.
                logger.debug(f"Cache miss for key: {cache_key}")
                result = await func(*args, **kwargs)
                # Skip caching for Response objects (like CSV downloads) as
                # they can't be properly serialized.
                if not isinstance(result, Response):
                    redis_client.setex(
                        cache_key,
                        ttl,
                        json.dumps(result, default=str, ensure_ascii=False)
                    )
                else:
                    logger.debug(f"Skipping cache for Response object from {func.__name__}")
                return result
            except Exception as e:
                # If caching fails, log and fall back to an uncached call.
                logger.error(f"Cache error for {func.__name__}: {e}")
                return await func(*args, **kwargs)
        return wrapper
    return decorator
def invalidate_cache(pattern: str = "*"):
    """
    Invalidate cache entries matching a pattern.

    Args:
        pattern: Redis pattern to match keys (default: "*" for all)

    Returns:
        int: Number of cache entries deleted (0 when Redis is unavailable).

    Usage:
        invalidate_cache("stats:*")  # Clear all stats cache
        invalidate_cache("api:get_player_*")  # Clear specific player cache
    """
    if redis_client is None:
        logger.warning("Cannot invalidate cache: Redis not available")
        return 0
    try:
        # Use SCAN (scan_iter) rather than KEYS: KEYS blocks the Redis server
        # while it walks the entire keyspace, which the Redis docs warn against
        # in production.
        keys = list(redis_client.scan_iter(match=pattern))
        if keys:
            deleted = redis_client.delete(*keys)
            logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}")
            return deleted
        logger.debug(f"No cache entries found matching pattern: {pattern}")
        return 0
    except Exception as e:
        logger.error(f"Error invalidating cache: {e}")
        return 0
def get_cache_stats() -> dict:
    """
    Get Redis cache statistics.

    Returns:
        dict: Cache statistics including memory usage, key count, etc.
        When Redis is unavailable or errors out, a dict with a "status" of
        "unavailable"/"error" and a "message" is returned instead.
    """
    if redis_client is None:
        return {"status": "unavailable", "message": "Redis not connected"}
    try:
        server_info = redis_client.info()
        stats = {"status": "connected"}
        stats["memory_used"] = server_info.get("used_memory_human", "unknown")
        stats["total_keys"] = redis_client.dbsize()
        stats["connected_clients"] = server_info.get("connected_clients", 0)
        stats["uptime_seconds"] = server_info.get("uptime_in_seconds", 0)
        return stats
    except Exception as e:
        logger.error(f"Error getting cache stats: {e}")
        return {"status": "error", "message": str(e)}
def add_cache_headers(
    max_age: int = 300,
    cache_type: str = "public",
    vary: Optional[str] = None,
    etag: bool = False
):
    """
    Decorator to add HTTP cache headers to FastAPI responses.

    Args:
        max_age: Cache duration in seconds (default: 5 minutes)
        cache_type: "public", "private", or "no-cache" (default: "public")
        vary: Vary header value (e.g. "Accept-Encoding, Authorization")
        etag: Whether to generate an ETag based on response content

    Usage:
        @add_cache_headers(max_age=1800, cache_type="public")
        async def get_static_data():
            return {"data": "static content"}

        @add_cache_headers(max_age=60, cache_type="private", vary="Authorization")
        async def get_user_data():
            return {"data": "user specific"}
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)
            # Wrap non-Response results so headers can be attached.
            if isinstance(result, Response):
                response = result
            elif isinstance(result, (dict, list)):
                response = Response(
                    content=json.dumps(result, default=str, ensure_ascii=False),
                    media_type="application/json",
                )
            else:
                response = Response(content=str(result))
            # Cache-Control, e.g. "public, max-age=300".
            directives = [cache_type]
            if cache_type != "no-cache" and max_age > 0:
                directives.append(f"max-age={max_age}")
            response.headers["Cache-Control"] = ", ".join(directives)
            # Optional Vary header.
            if vary:
                response.headers["Vary"] = vary
            # Optional content-derived ETag (md5 of the sorted JSON payload).
            if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))):
                payload = json.dumps(result, default=str, sort_keys=True).encode()
                response.headers["ETag"] = f'"{hashlib.md5(payload).hexdigest()}"'
            # Last-Modified reflects generation time of this dynamic content.
            response.headers["Last-Modified"] = datetime.datetime.now(
                datetime.timezone.utc
            ).strftime("%a, %d %b %Y %H:%M:%S GMT")
            return response
        return wrapper
    return decorator
def handle_db_errors(func):
    """
    Decorator to handle database connection errors and transaction rollbacks.

    Ensures proper cleanup of database connections and provides consistent
    error handling, with logging of function context, timing, and stack
    traces. HTTPExceptions raised by the wrapped endpoint (e.g. deliberate
    404s/403s) are re-raised with their original status code after cleanup
    instead of being masked as generic 500s, which the previous version did.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        import time
        import traceback
        from .db_engine import db  # Import here to avoid circular imports
        start_time = time.time()
        func_name = f"{func.__module__}.{func.__name__}"
        # Sanitize arguments for logging (exclude sensitive data).
        safe_args = []
        safe_kwargs = {}
        try:
            # Log sanitized arguments (avoid logging tokens, passwords, etc.)
            for arg in args:
                if hasattr(arg, '__dict__') and hasattr(arg, 'url'):  # FastAPI Request object
                    safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})")
                else:
                    safe_args.append(str(arg)[:100])  # Truncate long values
            for key, value in kwargs.items():
                if key.lower() in ['token', 'password', 'secret', 'key']:
                    safe_kwargs[key] = '[REDACTED]'
                else:
                    safe_kwargs[key] = str(value)[:100]  # Truncate long values
            logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}")
            result = await func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s")
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            error_trace = traceback.format_exc()
            logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s")
            logger.error(f"Function args: {safe_args}")
            logger.error(f"Function kwargs: {safe_kwargs}")
            logger.error(f"Exception: {str(e)}")
            logger.error(f"Full traceback:\n{error_trace}")
            try:
                logger.info(f"Attempting database rollback for {func_name}")
                db.rollback()
                logger.info(f"Database rollback successful for {func_name}")
            except Exception as rollback_error:
                logger.error(f"Rollback failed in {func_name}: {rollback_error}")
            finally:
                try:
                    db.close()
                    logger.info(f"Database connection closed for {func_name}")
                except Exception as close_error:
                    logger.error(f"Error closing database connection in {func_name}: {close_error}")
            # Preserve deliberate HTTP errors raised by the endpoint; only wrap
            # unexpected failures as a generic 500.
            if isinstance(e, HTTPException):
                raise
            raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}')
    return wrapper