Compare commits
10 Commits
06ff92df6c
...
38a411fd3e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38a411fd3e | ||
|
|
4be6afb541 | ||
|
|
0e70e94644 | ||
|
|
22d15490dd | ||
|
|
329658ce8d | ||
|
|
541c5bbc1e | ||
|
|
08fd7bec75 | ||
|
|
565afd0183 | ||
|
|
5aa88e4e3d | ||
|
|
c3054971eb |
328
api_calls.py
328
api_calls.py
@ -1,35 +1,46 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
import os
|
||||
|
||||
from exceptions import DatabaseError
|
||||
from exceptions import DatabaseError, APITimeoutError
|
||||
|
||||
AUTH_TOKEN = {'Authorization': f'Bearer {os.environ.get("API_TOKEN")}'}
|
||||
AUTH_TOKEN = {"Authorization": f"Bearer {os.environ.get('API_TOKEN')}"}
|
||||
ENV_DATABASE = os.getenv("DATABASE", "dev").lower()
|
||||
DB_URL = 'https://pd.manticorum.com/api' if 'prod' in ENV_DATABASE else 'https://pddev.manticorum.com/api'
|
||||
DB_URL = (
|
||||
"https://pd.manticorum.com/api"
|
||||
if "prod" in ENV_DATABASE
|
||||
else "https://pddev.manticorum.com/api"
|
||||
)
|
||||
master_debug = True
|
||||
PLAYER_CACHE = {}
|
||||
logger = logging.getLogger('discord_app')
|
||||
logger = logging.getLogger("discord_app")
|
||||
|
||||
|
||||
def param_char(other_params):
|
||||
if other_params:
|
||||
return '&'
|
||||
return "&"
|
||||
else:
|
||||
return '?'
|
||||
return "?"
|
||||
|
||||
|
||||
def get_req_url(endpoint: str, api_ver: int = 2, object_id: Optional[int] = None, params: Optional[list] = None):
|
||||
req_url = f'{DB_URL}/v{api_ver}/{endpoint}{"/" if object_id is not None else ""}{object_id if object_id is not None else ""}'
|
||||
def get_req_url(
|
||||
endpoint: str,
|
||||
api_ver: int = 2,
|
||||
object_id: Optional[int] = None,
|
||||
params: Optional[list] = None,
|
||||
):
|
||||
req_url = f"{DB_URL}/v{api_ver}/{endpoint}{'/' if object_id is not None else ''}{object_id if object_id is not None else ''}"
|
||||
|
||||
if params:
|
||||
other_params = False
|
||||
for x in params:
|
||||
req_url += f'{param_char(other_params)}{x[0]}={x[1]}'
|
||||
req_url += f"{param_char(other_params)}{x[0]}={x[1]}"
|
||||
other_params = True
|
||||
|
||||
return req_url
|
||||
@ -42,144 +53,251 @@ def log_return_value(log_string: str):
|
||||
line = log_string[start:end]
|
||||
if len(line) == 0:
|
||||
return
|
||||
logger.info(f'{"\n\nreturn: " if start == 0 else ""}{log_string[start:end]}')
|
||||
logger.info(f"{'\n\nreturn: ' if start == 0 else ''}{log_string[start:end]}")
|
||||
start += 3000
|
||||
end += 3000
|
||||
logger.warning('[ S N I P P E D ]')
|
||||
logger.warning("[ S N I P P E D ]")
|
||||
# if master_debug:
|
||||
# logger.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n')
|
||||
# else:
|
||||
# logger.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n')
|
||||
|
||||
|
||||
async def db_get(endpoint: str, api_ver: int = 2, object_id: Optional[int] = None, params: Optional[list] = None, none_okay: bool = True, timeout: int = 3):
|
||||
async def db_get(
|
||||
endpoint: str,
|
||||
api_ver: int = 2,
|
||||
object_id: Optional[int] = None,
|
||||
params: Optional[list] = None,
|
||||
none_okay: bool = True,
|
||||
timeout: int = 5,
|
||||
retries: int = 3,
|
||||
):
|
||||
"""
|
||||
GET request to the API with timeout and retry logic.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint path
|
||||
api_ver: API version (default 2)
|
||||
object_id: Optional object ID to append to URL
|
||||
params: Optional list of (key, value) tuples for query params
|
||||
none_okay: If True, return None on non-200 response; if False, raise DatabaseError
|
||||
timeout: Request timeout in seconds (default 5)
|
||||
retries: Number of retry attempts on timeout (default 3)
|
||||
|
||||
Returns:
|
||||
JSON response or None if none_okay and request failed
|
||||
|
||||
Raises:
|
||||
APITimeoutError: If all retry attempts fail due to timeout
|
||||
DatabaseError: If response is non-200 and none_okay is False
|
||||
"""
|
||||
req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params)
|
||||
log_string = f'db_get - get: {endpoint} id: {object_id} params: {params}'
|
||||
log_string = f"db_get - get: {endpoint} id: {object_id} params: {params}"
|
||||
logger.info(log_string) if master_debug else logger.debug(log_string)
|
||||
|
||||
async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
|
||||
async with session.get(req_url) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f'{js}')
|
||||
return js
|
||||
elif none_okay:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
return None
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(
|
||||
headers=AUTH_TOKEN, timeout=client_timeout
|
||||
) as session:
|
||||
async with session.get(req_url) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f"{js}")
|
||||
return js
|
||||
elif none_okay:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
return None
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
except asyncio.TimeoutError:
|
||||
if attempt < retries - 1:
|
||||
wait_time = 2**attempt # 1s, 2s, 4s
|
||||
logger.warning(
|
||||
f"Timeout on GET {endpoint}, retry {attempt + 1}/{retries} in {wait_time}s"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
logger.error(
|
||||
f"Connection timeout to host {req_url} after {retries} attempts"
|
||||
)
|
||||
raise APITimeoutError(f"Connection timeout to host {req_url}")
|
||||
|
||||
|
||||
async def db_patch(endpoint: str, object_id: int, params: list, api_ver: int = 2, timeout: int = 3):
|
||||
async def db_patch(
|
||||
endpoint: str, object_id: int, params: list, api_ver: int = 2, timeout: int = 5
|
||||
):
|
||||
"""
|
||||
PATCH request to the API with timeout (no retry - not safe for mutations).
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint path
|
||||
object_id: Object ID to patch
|
||||
params: List of (key, value) tuples for query params
|
||||
api_ver: API version (default 2)
|
||||
timeout: Request timeout in seconds (default 5)
|
||||
|
||||
Raises:
|
||||
APITimeoutError: If request times out
|
||||
DatabaseError: If response is non-200
|
||||
"""
|
||||
req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params)
|
||||
log_string = f'db_patch - patch: {endpoint} {params}'
|
||||
log_string = f"db_patch - patch: {endpoint} {params}"
|
||||
logger.info(log_string) if master_debug else logger.debug(log_string)
|
||||
|
||||
async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
|
||||
async with session.patch(req_url) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f'{js}')
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
try:
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(
|
||||
headers=AUTH_TOKEN, timeout=client_timeout
|
||||
) as session:
|
||||
async with session.patch(req_url) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f"{js}")
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Connection timeout to host {req_url}")
|
||||
raise APITimeoutError(f"Connection timeout to host {req_url}")
|
||||
|
||||
|
||||
async def db_post(endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 3):
|
||||
async def db_post(
|
||||
endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 5
|
||||
):
|
||||
"""
|
||||
POST request to the API with timeout (no retry - not safe for mutations).
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint path
|
||||
api_ver: API version (default 2)
|
||||
payload: Optional JSON payload
|
||||
timeout: Request timeout in seconds (default 5)
|
||||
|
||||
Raises:
|
||||
APITimeoutError: If request times out
|
||||
DatabaseError: If response is non-200
|
||||
"""
|
||||
req_url = get_req_url(endpoint, api_ver=api_ver)
|
||||
log_string = f'db_post - post: {endpoint} payload: {payload}\ntype: {type(payload)}'
|
||||
log_string = f"db_post - post: {endpoint} payload: {payload}\ntype: {type(payload)}"
|
||||
logger.info(log_string) if master_debug else logger.debug(log_string)
|
||||
|
||||
async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
|
||||
async with session.post(req_url, json=payload) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f'{js}')
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
try:
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(
|
||||
headers=AUTH_TOKEN, timeout=client_timeout
|
||||
) as session:
|
||||
async with session.post(req_url, json=payload) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f"{js}")
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Connection timeout to host {req_url}")
|
||||
raise APITimeoutError(f"Connection timeout to host {req_url}")
|
||||
|
||||
|
||||
async def db_put(endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 3):
|
||||
async def db_put(
|
||||
endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 5
|
||||
):
|
||||
"""
|
||||
PUT request to the API with timeout (no retry - not safe for mutations).
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint path
|
||||
api_ver: API version (default 2)
|
||||
payload: Optional JSON payload
|
||||
timeout: Request timeout in seconds (default 5)
|
||||
|
||||
Raises:
|
||||
APITimeoutError: If request times out
|
||||
DatabaseError: If response is non-200
|
||||
"""
|
||||
req_url = get_req_url(endpoint, api_ver=api_ver)
|
||||
log_string = f'post:\n{endpoint} payload: {payload}\ntype: {type(payload)}'
|
||||
log_string = f"db_put - put: {endpoint} payload: {payload}\ntype: {type(payload)}"
|
||||
logger.info(log_string) if master_debug else logger.debug(log_string)
|
||||
|
||||
async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
|
||||
async with session.put(req_url, json=payload) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f'{js}')
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
|
||||
# retries = 0
|
||||
# while True:
|
||||
# try:
|
||||
# resp = requests.put(req_url, json=payload, headers=AUTH_TOKEN, timeout=timeout)
|
||||
# break
|
||||
# except requests.Timeout as e:
|
||||
# logger.error(f'Post Timeout: {req_url} / retries: {retries} / timeout: {timeout}')
|
||||
# if retries > 1:
|
||||
# raise ConnectionError(f'DB: The internet was a bit too slow for me to grab the data I needed. Please '
|
||||
# f'hang on a few extra seconds and try again.')
|
||||
# timeout += [min(3, timeout), min(5, timeout)][retries]
|
||||
# retries += 1
|
||||
#
|
||||
# if resp.status_code == 200:
|
||||
# data = resp.json()
|
||||
# log_string = f'{data}'
|
||||
# if master_debug:
|
||||
# logger.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}')
|
||||
# else:
|
||||
# logger.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}')
|
||||
# return data
|
||||
# else:
|
||||
# logger.warning(resp.text)
|
||||
# raise ValueError(f'DB: {resp.text}')
|
||||
try:
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(
|
||||
headers=AUTH_TOKEN, timeout=client_timeout
|
||||
) as session:
|
||||
async with session.put(req_url, json=payload) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f"{js}")
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Connection timeout to host {req_url}")
|
||||
raise APITimeoutError(f"Connection timeout to host {req_url}")
|
||||
|
||||
|
||||
async def db_delete(endpoint: str, object_id: int, api_ver: int = 2, timeout=3):
|
||||
async def db_delete(endpoint: str, object_id: int, api_ver: int = 2, timeout: int = 5):
|
||||
"""
|
||||
DELETE request to the API with timeout (no retry - not safe for mutations).
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint path
|
||||
object_id: Object ID to delete
|
||||
api_ver: API version (default 2)
|
||||
timeout: Request timeout in seconds (default 5)
|
||||
|
||||
Raises:
|
||||
APITimeoutError: If request times out
|
||||
DatabaseError: If response is non-200
|
||||
"""
|
||||
req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id)
|
||||
log_string = f'db_delete - delete: {endpoint} {object_id}'
|
||||
log_string = f"db_delete - delete: {endpoint} {object_id}"
|
||||
logger.info(log_string) if master_debug else logger.debug(log_string)
|
||||
|
||||
async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
|
||||
async with session.delete(req_url) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f'{js}')
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
try:
|
||||
client_timeout = ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(
|
||||
headers=AUTH_TOKEN, timeout=client_timeout
|
||||
) as session:
|
||||
async with session.delete(req_url) as r:
|
||||
if r.status == 200:
|
||||
js = await r.json()
|
||||
log_return_value(f"{js}")
|
||||
return js
|
||||
else:
|
||||
e = await r.text()
|
||||
logger.error(e)
|
||||
raise DatabaseError(e)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Connection timeout to host {req_url}")
|
||||
raise APITimeoutError(f"Connection timeout to host {req_url}")
|
||||
|
||||
|
||||
async def get_team_by_abbrev(abbrev: str):
|
||||
all_teams = await db_get('teams', params=[('abbrev', abbrev)])
|
||||
all_teams = await db_get("teams", params=[("abbrev", abbrev)])
|
||||
|
||||
if not all_teams or not all_teams['count']:
|
||||
if not all_teams or not all_teams["count"]:
|
||||
return None
|
||||
|
||||
return all_teams['teams'][0]
|
||||
return all_teams["teams"][0]
|
||||
|
||||
|
||||
async def post_to_dex(player, team):
|
||||
return await db_post('paperdex', payload={'team_id': team['id'], 'player_id': player['id']})
|
||||
return await db_post(
|
||||
"paperdex", payload={"team_id": team["id"], "player_id": player["id"]}
|
||||
)
|
||||
|
||||
|
||||
def team_hash(team):
|
||||
hash_string = f'{team["sname"][-1]}{team["gmid"] / 6950123:.0f}{team["sname"][-2]}{team["gmid"] / 42069123:.0f}'
|
||||
hash_string = f"{team['sname'][-1]}{team['gmid'] / 6950123:.0f}{team['sname'][-2]}{team['gmid'] / 42069123:.0f}"
|
||||
return hash_string
|
||||
|
||||
|
||||
@ -682,8 +682,10 @@ class Economy(commands.Cog):
|
||||
logger.debug(f'pack: {pack}')
|
||||
logger.debug(f'pack cardset: {pack["pack_cardset"]}')
|
||||
if pack['pack_team'] is None and pack['pack_cardset'] is None:
|
||||
if pack['pack_type']['name'] in p_data:
|
||||
p_group = pack['pack_type']['name']
|
||||
p_group = pack['pack_type']['name']
|
||||
# Add to p_data if this is a new pack type
|
||||
if p_group not in p_data:
|
||||
p_data[p_group] = []
|
||||
|
||||
elif pack['pack_team'] is not None:
|
||||
if pack['pack_type']['name'] == 'Standard':
|
||||
@ -1250,14 +1252,14 @@ class Economy(commands.Cog):
|
||||
anchor_all_stars = await db_get(
|
||||
'players/random',
|
||||
params=[
|
||||
('min_rarity', 3), ('max_rarity', 3), ('franchise', team_choice), ('pos_exclude', 'RP'), ('limit', 1),
|
||||
('min_rarity', 3), ('max_rarity', 3), ('franchise', normalize_franchise(team_choice)), ('pos_exclude', 'RP'), ('limit', 1),
|
||||
('in_packs', True)
|
||||
]
|
||||
)
|
||||
anchor_starters = await db_get(
|
||||
'players/random',
|
||||
params=[
|
||||
('min_rarity', 2), ('max_rarity', 2), ('franchise', team_choice), ('pos_exclude', 'RP'), ('limit', 2),
|
||||
('min_rarity', 2), ('max_rarity', 2), ('franchise', normalize_franchise(team_choice)), ('pos_exclude', 'RP'), ('limit', 2),
|
||||
('in_packs', True)
|
||||
]
|
||||
)
|
||||
@ -1476,7 +1478,7 @@ class Economy(commands.Cog):
|
||||
'is_ai': True
|
||||
})
|
||||
|
||||
p_query = await db_get('players', params=[('franchise', lname)])
|
||||
p_query = await db_get('players', params=[('franchise', sname)])
|
||||
|
||||
this_pack = await db_post(
|
||||
'packs/one',
|
||||
@ -1503,7 +1505,7 @@ class Economy(commands.Cog):
|
||||
total_cards = 0
|
||||
total_teams = 0
|
||||
for team in ai_teams['teams']:
|
||||
all_players = await db_get('players', params=[('franchise', team['lname'])])
|
||||
all_players = await db_get('players', params=[('franchise', team['sname'])])
|
||||
|
||||
new_players = []
|
||||
if all_players:
|
||||
|
||||
@ -263,8 +263,10 @@ class Packs(commands.Cog):
|
||||
logger.debug(f'pack: {pack}')
|
||||
logger.debug(f'pack cardset: {pack["pack_cardset"]}')
|
||||
if pack['pack_team'] is None and pack['pack_cardset'] is None:
|
||||
if pack['pack_type']['name'] in p_data:
|
||||
p_group = pack['pack_type']['name']
|
||||
p_group = pack['pack_type']['name']
|
||||
# Add to p_data if this is a new pack type
|
||||
if p_group not in p_data:
|
||||
p_data[p_group] = []
|
||||
|
||||
elif pack['pack_team'] is not None:
|
||||
if pack['pack_type']['name'] == 'Standard':
|
||||
|
||||
@ -115,6 +115,7 @@ ALL_MLB_TEAMS = {
|
||||
'New York Mets': ['NYM', 'Mets'],
|
||||
'New York Yankees': ['NYY', 'Yankees'],
|
||||
'Oakland Athletics': ['OAK', 'Athletics'],
|
||||
'Athletics': ['OAK', 'Athletics'], # Alias for post-Oakland move
|
||||
'Philadelphia Phillies': ['PHI', 'Phillies'],
|
||||
'Pittsburgh Pirates': ['PIT', 'Pirates'],
|
||||
'San Diego Padres': ['SDP', 'Padres'],
|
||||
@ -157,6 +158,7 @@ IMAGES = {
|
||||
'New York Mets': f'{PD_IMAGE_BUCKET}/mvp/new-york-mets.gif',
|
||||
'New York Yankees': f'{PD_IMAGE_BUCKET}/mvp/new-york-yankees.gif',
|
||||
'Oakland Athletics': f'{PD_IMAGE_BUCKET}/mvp/oakland-athletics.gif',
|
||||
'Athletics': f'{PD_IMAGE_BUCKET}/mvp/oakland-athletics.gif', # Alias for post-Oakland move
|
||||
'Philadelphia Phillies': f'{PD_IMAGE_BUCKET}/mvp/philadelphia-phillies.gif',
|
||||
'Pittsburgh Pirates': f'{PD_IMAGE_BUCKET}/mvp/pittsburgh-pirates.gif',
|
||||
'San Diego Padres': f'{PD_IMAGE_BUCKET}/mvp/san-diego-padres.gif',
|
||||
|
||||
@ -6,7 +6,7 @@ Contains all Select classes for various team, cardset, and pack selections.
|
||||
import logging
|
||||
import discord
|
||||
from typing import Literal, Optional
|
||||
from helpers.constants import ALL_MLB_TEAMS, IMAGES
|
||||
from helpers.constants import ALL_MLB_TEAMS, IMAGES, normalize_franchise
|
||||
|
||||
logger = logging.getLogger('discord_app')
|
||||
|
||||
@ -23,6 +23,7 @@ AL_TEAM_IDS = {
|
||||
'Minnesota Twins': 17,
|
||||
'New York Yankees': 19,
|
||||
'Oakland Athletics': 20,
|
||||
'Athletics': 20, # Alias for post-Oakland move
|
||||
'Seattle Mariners': 24,
|
||||
'Tampa Bay Rays': 27,
|
||||
'Texas Rangers': 28,
|
||||
@ -464,7 +465,9 @@ class SelectUpdatePlayerTeam(discord.ui.Select):
|
||||
from discord_ui.confirmations import Confirm
|
||||
from helpers import player_desc, send_to_channel
|
||||
|
||||
if self.values[0] == self.player['franchise'] or self.values[0] == self.player['mlbclub']:
|
||||
# Check if already assigned - compare against both normalized franchise and full mlbclub
|
||||
normalized_selection = normalize_franchise(self.values[0])
|
||||
if normalized_selection == self.player['franchise'] or self.values[0] == self.player['mlbclub']:
|
||||
await interaction.response.send_message(
|
||||
content=f'Thank you for the help, but it looks like somebody beat you to it! '
|
||||
f'**{player_desc(self.player)}** is already assigned to the **{self.player["mlbclub"]}**.'
|
||||
@ -492,7 +495,7 @@ class SelectUpdatePlayerTeam(discord.ui.Select):
|
||||
await question.delete()
|
||||
|
||||
await db_patch('players', object_id=self.player['player_id'], params=[
|
||||
('mlbclub', self.values[0]), ('franchise', self.values[0])
|
||||
('mlbclub', self.values[0]), ('franchise', normalize_franchise(self.values[0]))
|
||||
])
|
||||
await db_post(f'teams/{self.reporting_team["id"]}/money/25')
|
||||
await send_to_channel(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
logger = logging.getLogger('discord_app')
|
||||
logger = logging.getLogger("discord_app")
|
||||
|
||||
|
||||
def log_errors(func):
|
||||
@ -15,26 +15,32 @@ def log_errors(func):
|
||||
except Exception as e:
|
||||
logger.error(func.__name__)
|
||||
log_exception(e)
|
||||
return result # type: ignore
|
||||
|
||||
return result # type: ignore
|
||||
|
||||
return wrap
|
||||
|
||||
def log_exception(e: Exception, msg: str = '', level: Literal['debug', 'error', 'info', 'warn'] = 'error'):
|
||||
if level == 'debug':
|
||||
|
||||
def log_exception(
|
||||
e: Exception,
|
||||
msg: str = "",
|
||||
level: Literal["debug", "error", "info", "warn"] = "error",
|
||||
):
|
||||
if level == "debug":
|
||||
logger.debug(msg, exc_info=True, stack_info=True)
|
||||
elif level == 'error':
|
||||
elif level == "error":
|
||||
logger.error(msg, exc_info=True, stack_info=True)
|
||||
elif level == 'info':
|
||||
elif level == "info":
|
||||
logger.info(msg, exc_info=True, stack_info=True)
|
||||
else:
|
||||
logger.warning(msg, exc_info=True, stack_info=True)
|
||||
|
||||
|
||||
# Check if 'e' is an exception class or instance
|
||||
if isinstance(e, Exception):
|
||||
raise e # If 'e' is already an instance of an exception
|
||||
else:
|
||||
raise e(msg) # If 'e' is an exception class
|
||||
|
||||
|
||||
class GameException(Exception):
|
||||
pass
|
||||
|
||||
@ -75,6 +81,12 @@ class DatabaseError(GameException):
|
||||
pass
|
||||
|
||||
|
||||
class APITimeoutError(DatabaseError):
|
||||
"""Raised when an API call times out after all retries."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PositionNotFoundException(GameException):
|
||||
pass
|
||||
|
||||
|
||||
2553
gauntlets.py
2553
gauntlets.py
File diff suppressed because it is too large
Load Diff
1943
helpers.py
1943
helpers.py
File diff suppressed because it is too large
Load Diff
@ -8,43 +8,42 @@ import discord
|
||||
from typing import Literal
|
||||
|
||||
# Season Configuration
|
||||
SBA_SEASON = 11
|
||||
PD_SEASON = 9
|
||||
ranked_cardsets = [20, 21, 22, 17, 18, 19]
|
||||
SBA_SEASON = 12
|
||||
PD_SEASON = 10
|
||||
ranked_cardsets = [24, 25, 26, 27, 28, 29]
|
||||
LIVE_CARDSET_ID = 27
|
||||
LIVE_PROMO_CARDSET_ID = 28
|
||||
MAX_CARDSET_ID = 30
|
||||
|
||||
# Cardset Configuration
|
||||
CARDSETS = {
|
||||
'Ranked': {
|
||||
'ranked': {
|
||||
'primary': ranked_cardsets,
|
||||
'human': ranked_cardsets
|
||||
},
|
||||
'Minor League': {
|
||||
'primary': [20, 8], # 1998, Mario
|
||||
'secondary': [6], # 2013
|
||||
'human': [x for x in range(1, MAX_CARDSET_ID)]
|
||||
'minor-league': {
|
||||
'primary': [27, 8], # 2005, Mario
|
||||
'secondary': [24], # 2025
|
||||
'human': [x for x in range(1, 30)]
|
||||
},
|
||||
'Major League': {
|
||||
'primary': [20, 21, 17, 18, 12, 6, 7, 8], # 1998, 1998 Promos, 2024, 24 Promos, 2008, 2013, 2012, Mario
|
||||
'major-league': {
|
||||
'primary': [27, 28, 24, 25, 13, 14, 6, 8], # 2005 + Promos, 2025 + Promos, 2018 + Promos, 2012, Mario
|
||||
'secondary': [5, 3], # 2019, 2022
|
||||
'human': ranked_cardsets
|
||||
},
|
||||
'Hall of Fame': {
|
||||
'primary': [x for x in range(1, MAX_CARDSET_ID)],
|
||||
'secondary': [],
|
||||
'hall-of-fame': {
|
||||
'primary': [x for x in range(1, 30)],
|
||||
'human': ranked_cardsets
|
||||
},
|
||||
'Flashback': {
|
||||
'primary': [5, 1, 3, 9, 8], # 2019, 2021, 2022, 2023, Mario
|
||||
'secondary': [13, 5], # 2018, 2019
|
||||
'human': [5, 1, 3, 9, 8] # 2019, 2021, 2022, 2023
|
||||
'flashback': {
|
||||
'primary': [13, 5, 1, 3, 8], # 2018, 2019, 2021, 2022, Mario
|
||||
'secondary': [24], # 2025
|
||||
'human': [13, 5, 1, 3, 8] # 2018, 2019, 2021, 2022
|
||||
},
|
||||
'gauntlet-3': {
|
||||
'primary': [13], # 2018
|
||||
'secondary': [5, 11, 9], # 2019, 2016, 2023
|
||||
'human': [x for x in range(1, MAX_CARDSET_ID)]
|
||||
'human': [x for x in range(1, 30)]
|
||||
},
|
||||
'gauntlet-4': {
|
||||
'primary': [3, 6, 16], # 2022, 2013, Backyard Baseball
|
||||
@ -54,17 +53,26 @@ CARDSETS = {
|
||||
'gauntlet-5': {
|
||||
'primary': [17, 8], # 2024, Mario
|
||||
'secondary': [13], # 2018
|
||||
'human': [x for x in range(1, MAX_CARDSET_ID)]
|
||||
'human': [x for x in range(1, 30)]
|
||||
},
|
||||
'gauntlet-6': {
|
||||
'primary': [20, 8], # 1998, Mario
|
||||
'secondary': [12], # 2008
|
||||
'human': [x for x in range(1, MAX_CARDSET_ID)]
|
||||
'human': [x for x in range(1, 30)]
|
||||
},
|
||||
'gauntlet-7': {
|
||||
'primary': [5, 23], # 2019, Brilliant Stars
|
||||
'secondary': [1], # 2021
|
||||
'human': [x for x in range(1, MAX_CARDSET_ID)]
|
||||
'human': [x for x in range(1, 30)]
|
||||
},
|
||||
'gauntlet-8': {
|
||||
'primary': [24], # 2025
|
||||
'secondary': [17],
|
||||
'human': [24, 25, 22, 23]
|
||||
},
|
||||
'gauntlet-9': {
|
||||
'primary': [27], # 2005
|
||||
'secondary': [24] # 2025
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,6 +115,7 @@ ALL_MLB_TEAMS = {
|
||||
'New York Mets': ['NYM', 'Mets'],
|
||||
'New York Yankees': ['NYY', 'Yankees'],
|
||||
'Oakland Athletics': ['OAK', 'Athletics'],
|
||||
'Athletics': ['OAK', 'Athletics'], # Alias for post-Oakland move
|
||||
'Philadelphia Phillies': ['PHI', 'Phillies'],
|
||||
'Pittsburgh Pirates': ['PIT', 'Pirates'],
|
||||
'San Diego Padres': ['SDP', 'Padres'],
|
||||
@ -119,6 +128,48 @@ ALL_MLB_TEAMS = {
|
||||
'Washington Nationals': ['WSN', 'WAS', 'Nationals'],
|
||||
}
|
||||
|
||||
# Franchise normalization: Convert city+team names to city-agnostic team names
|
||||
# This enables cross-era player matching (e.g., 'Oakland Athletics' -> 'Athletics')
|
||||
FRANCHISE_NORMALIZE = {
|
||||
'Arizona Diamondbacks': 'Diamondbacks',
|
||||
'Atlanta Braves': 'Braves',
|
||||
'Baltimore Orioles': 'Orioles',
|
||||
'Boston Red Sox': 'Red Sox',
|
||||
'Chicago Cubs': 'Cubs',
|
||||
'Chicago White Sox': 'White Sox',
|
||||
'Cincinnati Reds': 'Reds',
|
||||
'Cleveland Guardians': 'Guardians',
|
||||
'Colorado Rockies': 'Rockies',
|
||||
'Detroit Tigers': 'Tigers',
|
||||
'Houston Astros': 'Astros',
|
||||
'Kansas City Royals': 'Royals',
|
||||
'Los Angeles Angels': 'Angels',
|
||||
'Los Angeles Dodgers': 'Dodgers',
|
||||
'Miami Marlins': 'Marlins',
|
||||
'Milwaukee Brewers': 'Brewers',
|
||||
'Minnesota Twins': 'Twins',
|
||||
'New York Mets': 'Mets',
|
||||
'New York Yankees': 'Yankees',
|
||||
'Oakland Athletics': 'Athletics',
|
||||
'Philadelphia Phillies': 'Phillies',
|
||||
'Pittsburgh Pirates': 'Pirates',
|
||||
'San Diego Padres': 'Padres',
|
||||
'San Francisco Giants': 'Giants',
|
||||
'Seattle Mariners': 'Mariners',
|
||||
'St Louis Cardinals': 'Cardinals',
|
||||
'St. Louis Cardinals': 'Cardinals',
|
||||
'Tampa Bay Rays': 'Rays',
|
||||
'Texas Rangers': 'Rangers',
|
||||
'Toronto Blue Jays': 'Blue Jays',
|
||||
'Washington Nationals': 'Nationals',
|
||||
}
|
||||
|
||||
|
||||
def normalize_franchise(franchise: str) -> str:
|
||||
"""Convert city+team name to team-only (e.g., 'Oakland Athletics' -> 'Athletics')"""
|
||||
return FRANCHISE_NORMALIZE.get(franchise, franchise)
|
||||
|
||||
|
||||
# Image URLs
|
||||
IMAGES = {
|
||||
'logo': f'{PD_IMAGE_BUCKET}/sba-logo.png',
|
||||
@ -149,6 +200,7 @@ IMAGES = {
|
||||
'New York Mets': f'{PD_IMAGE_BUCKET}/mvp/new-york-mets.gif',
|
||||
'New York Yankees': f'{PD_IMAGE_BUCKET}/mvp/new-york-yankees.gif',
|
||||
'Oakland Athletics': f'{PD_IMAGE_BUCKET}/mvp/oakland-athletics.gif',
|
||||
'Athletics': f'{PD_IMAGE_BUCKET}/mvp/oakland-athletics.gif', # Alias for post-Oakland move
|
||||
'Philadelphia Phillies': f'{PD_IMAGE_BUCKET}/mvp/philadelphia-phillies.gif',
|
||||
'Pittsburgh Pirates': f'{PD_IMAGE_BUCKET}/mvp/pittsburgh-pirates.gif',
|
||||
'San Diego Padres': f'{PD_IMAGE_BUCKET}/mvp/san-diego-padres.gif',
|
||||
@ -292,7 +344,7 @@ RARITY = {
|
||||
|
||||
# Discord UI Options
|
||||
SELECT_CARDSET_OPTIONS = [
|
||||
discord.SelectOption(label='2005 Live', value='27'),
|
||||
discord.SelectOption(label='2005 Season', value='27'),
|
||||
discord.SelectOption(label='2025 Season', value='24'),
|
||||
discord.SelectOption(label='2025 Promos', value='25'),
|
||||
discord.SelectOption(label='1998 Season', value='20'),
|
||||
|
||||
1944
helpers/main.py
1944
helpers/main.py
File diff suppressed because it is too large
Load Diff
@ -13,7 +13,8 @@ from typing import Optional, Literal
|
||||
|
||||
from in_game import data_cache
|
||||
from in_game.gameplay_models import Play, Session, Game, Team, Lineup
|
||||
from in_game.gameplay_queries import get_or_create_ai_card, get_player_id_from_dict, get_player_or_none
|
||||
from in_game.gameplay_queries import get_or_create_ai_card, get_player_id_from_dict, get_player_or_none, get_pitcher_scouting_or_none
|
||||
from exceptions import DatabaseError
|
||||
|
||||
db = SqliteDatabase(
|
||||
'storage/ai-database.db',
|
||||
@ -342,11 +343,48 @@ async def get_starting_pitcher(
|
||||
sp_rank = 5
|
||||
logger.info(f'chosen rank: {sp_rank}')
|
||||
|
||||
sp_query = await db_get(
|
||||
f'teams/{this_team.id}/sp/{league_name}?sp_rank={sp_rank}{this_game.cardset_param_string}'
|
||||
)
|
||||
this_player = await get_player_or_none(session, get_player_id_from_dict(sp_query))
|
||||
sp_card = await get_or_create_ai_card(session, this_player, this_team)
|
||||
# Try to get a pitcher with valid pitching data, retrying with different ranks if needed
|
||||
original_rank = sp_rank
|
||||
tried_ranks = set()
|
||||
direction = 1 # 1 = incrementing, -1 = decrementing
|
||||
|
||||
while len(tried_ranks) < 5:
|
||||
tried_ranks.add(sp_rank)
|
||||
logger.info(f'Trying sp_rank: {sp_rank}')
|
||||
|
||||
sp_query = await db_get(
|
||||
f'teams/{this_team.id}/sp/{league_name}?sp_rank={sp_rank}{this_game.cardset_param_string}'
|
||||
)
|
||||
this_player = await get_player_or_none(session, get_player_id_from_dict(sp_query))
|
||||
sp_card = await get_or_create_ai_card(session, this_player, this_team)
|
||||
|
||||
# Validate pitcher has pitching data
|
||||
try:
|
||||
pitcher_scouting = await get_pitcher_scouting_or_none(session, sp_card)
|
||||
if pitcher_scouting is not None:
|
||||
sp_card.pitcherscouting = pitcher_scouting
|
||||
session.add(sp_card)
|
||||
session.commit()
|
||||
session.refresh(sp_card)
|
||||
logger.info(f'Found valid pitcher at rank {sp_rank}: {this_player.name_with_desc}')
|
||||
break
|
||||
else:
|
||||
logger.warning(f'Pitcher at rank {sp_rank} ({this_player.name_with_desc}) returned None for pitcherscouting')
|
||||
except DatabaseError:
|
||||
logger.warning(f'Pitcher at rank {sp_rank} ({this_player.name_with_desc}) lacks pitching data, trying another')
|
||||
|
||||
# Adjust rank: increment first, if we hit 6, switch to decrementing from original
|
||||
sp_rank += direction
|
||||
if sp_rank > 5:
|
||||
direction = -1
|
||||
sp_rank = original_rank - 1
|
||||
if sp_rank < 1:
|
||||
# Find any untried rank
|
||||
untried = [r for r in range(1, 6) if r not in tried_ranks]
|
||||
if untried:
|
||||
sp_rank = untried[0]
|
||||
else:
|
||||
break
|
||||
|
||||
return Lineup(
|
||||
team=this_team,
|
||||
|
||||
@ -237,12 +237,18 @@ class Game(SQLModel, table=True):
|
||||
for line in all_lineups:
|
||||
logger.info(f'line in all_lineups: {line}')
|
||||
if with_links:
|
||||
name_string = {line.player.name_card_link("batting" if line.position != "P" else "pitching")}
|
||||
name_string = line.player.name_card_link("batting" if line.position != "P" else "pitching")
|
||||
else:
|
||||
name_string = f'{line.player.name_with_desc}'
|
||||
|
||||
if line.position == 'P':
|
||||
this_hand = line.card.pitcherscouting.pitchingcard.hand
|
||||
if line.card.pitcherscouting:
|
||||
this_hand = line.card.pitcherscouting.pitchingcard.hand
|
||||
elif line.card.batterscouting:
|
||||
# Fallback to batting hand if pitcherscouting is missing
|
||||
this_hand = line.card.batterscouting.battingcard.hand
|
||||
else:
|
||||
this_hand = '?'
|
||||
else:
|
||||
this_hand = line.card.batterscouting.battingcard.hand
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ sqlmodel
|
||||
alembic
|
||||
pytest
|
||||
pytest-asyncio
|
||||
numpy<2
|
||||
pandas
|
||||
psycopg2-binary
|
||||
aiohttp
|
||||
|
||||
@ -1,102 +1,102 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
import aiohttp
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from exceptions import DatabaseError
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from exceptions import DatabaseError, APITimeoutError
|
||||
import api_calls
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test utility functions in api_calls."""
|
||||
|
||||
|
||||
def test_param_char_with_params(self):
|
||||
"""Test param_char returns & when other_params is truthy."""
|
||||
assert api_calls.param_char(True) == '&'
|
||||
assert api_calls.param_char(['param1']) == '&'
|
||||
assert api_calls.param_char({'key': 'value'}) == '&'
|
||||
assert api_calls.param_char('some_param') == '&'
|
||||
|
||||
assert api_calls.param_char(True) == "&"
|
||||
assert api_calls.param_char(["param1"]) == "&"
|
||||
assert api_calls.param_char({"key": "value"}) == "&"
|
||||
assert api_calls.param_char("some_param") == "&"
|
||||
|
||||
def test_param_char_without_params(self):
|
||||
"""Test param_char returns ? when other_params is falsy."""
|
||||
assert api_calls.param_char(False) == '?'
|
||||
assert api_calls.param_char(None) == '?'
|
||||
assert api_calls.param_char([]) == '?'
|
||||
assert api_calls.param_char({}) == '?'
|
||||
assert api_calls.param_char('') == '?'
|
||||
assert api_calls.param_char(0) == '?'
|
||||
|
||||
@patch('api_calls.DB_URL', 'https://test.example.com/api')
|
||||
assert api_calls.param_char(False) == "?"
|
||||
assert api_calls.param_char(None) == "?"
|
||||
assert api_calls.param_char([]) == "?"
|
||||
assert api_calls.param_char({}) == "?"
|
||||
assert api_calls.param_char("") == "?"
|
||||
assert api_calls.param_char(0) == "?"
|
||||
|
||||
@patch("api_calls.DB_URL", "https://test.example.com/api")
|
||||
def test_get_req_url_basic(self):
|
||||
"""Test basic URL generation without object_id or params."""
|
||||
result = api_calls.get_req_url('teams')
|
||||
expected = 'https://test.example.com/api/v2/teams'
|
||||
result = api_calls.get_req_url("teams")
|
||||
expected = "https://test.example.com/api/v2/teams"
|
||||
assert result == expected
|
||||
|
||||
@patch('api_calls.DB_URL', 'https://test.example.com/api')
|
||||
|
||||
@patch("api_calls.DB_URL", "https://test.example.com/api")
|
||||
def test_get_req_url_with_version(self):
|
||||
"""Test URL generation with custom API version."""
|
||||
result = api_calls.get_req_url('teams', api_ver=1)
|
||||
expected = 'https://test.example.com/api/v1/teams'
|
||||
result = api_calls.get_req_url("teams", api_ver=1)
|
||||
expected = "https://test.example.com/api/v1/teams"
|
||||
assert result == expected
|
||||
|
||||
@patch('api_calls.DB_URL', 'https://test.example.com/api')
|
||||
|
||||
@patch("api_calls.DB_URL", "https://test.example.com/api")
|
||||
def test_get_req_url_with_object_id(self):
|
||||
"""Test URL generation with object_id."""
|
||||
result = api_calls.get_req_url('teams', object_id=123)
|
||||
expected = 'https://test.example.com/api/v2/teams/123'
|
||||
result = api_calls.get_req_url("teams", object_id=123)
|
||||
expected = "https://test.example.com/api/v2/teams/123"
|
||||
assert result == expected
|
||||
|
||||
@patch('api_calls.DB_URL', 'https://test.example.com/api')
|
||||
|
||||
@patch("api_calls.DB_URL", "https://test.example.com/api")
|
||||
def test_get_req_url_with_params(self):
|
||||
"""Test URL generation with parameters."""
|
||||
params = [('season', '7'), ('active', 'true')]
|
||||
result = api_calls.get_req_url('teams', params=params)
|
||||
expected = 'https://test.example.com/api/v2/teams?season=7&active=true'
|
||||
params = [("season", "7"), ("active", "true")]
|
||||
result = api_calls.get_req_url("teams", params=params)
|
||||
expected = "https://test.example.com/api/v2/teams?season=7&active=true"
|
||||
assert result == expected
|
||||
|
||||
@patch('api_calls.DB_URL', 'https://test.example.com/api')
|
||||
|
||||
@patch("api_calls.DB_URL", "https://test.example.com/api")
|
||||
def test_get_req_url_complete(self):
|
||||
"""Test URL generation with all parameters."""
|
||||
params = [('season', '7'), ('limit', '10')]
|
||||
result = api_calls.get_req_url('games', api_ver=1, object_id=456, params=params)
|
||||
expected = 'https://test.example.com/api/v1/games/456?season=7&limit=10'
|
||||
params = [("season", "7"), ("limit", "10")]
|
||||
result = api_calls.get_req_url("games", api_ver=1, object_id=456, params=params)
|
||||
expected = "https://test.example.com/api/v1/games/456?season=7&limit=10"
|
||||
assert result == expected
|
||||
|
||||
@patch('api_calls.logger')
|
||||
|
||||
@patch("api_calls.logger")
|
||||
def test_log_return_value_short_string(self, mock_logger):
|
||||
"""Test logging short return values."""
|
||||
api_calls.log_return_value('Short log message')
|
||||
mock_logger.info.assert_called_once_with('\n\nreturn: Short log message')
|
||||
|
||||
@patch('api_calls.logger')
|
||||
api_calls.log_return_value("Short log message")
|
||||
mock_logger.info.assert_called_once_with("\n\nreturn: Short log message")
|
||||
|
||||
@patch("api_calls.logger")
|
||||
def test_log_return_value_long_string(self, mock_logger):
|
||||
"""Test logging long return values that get chunked."""
|
||||
long_string = 'A' * 5000 # 5000 character string
|
||||
long_string = "A" * 5000 # 5000 character string
|
||||
api_calls.log_return_value(long_string)
|
||||
|
||||
|
||||
# Should have been called twice (first chunk + second chunk)
|
||||
assert mock_logger.info.call_count == 2
|
||||
# First call should include the "return:" prefix
|
||||
assert '\n\nreturn: ' in mock_logger.info.call_args_list[0][0][0]
|
||||
|
||||
@patch('api_calls.logger')
|
||||
assert "\n\nreturn: " in mock_logger.info.call_args_list[0][0][0]
|
||||
|
||||
@patch("api_calls.logger")
|
||||
def test_log_return_value_extremely_long_string(self, mock_logger):
|
||||
"""Test logging extremely long return values that get snipped."""
|
||||
extremely_long_string = 'B' * 400000 # 400k character string (exceeds 300k limit)
|
||||
extremely_long_string = (
|
||||
"B" * 400000
|
||||
) # 400k character string (exceeds 300k limit)
|
||||
api_calls.log_return_value(extremely_long_string)
|
||||
|
||||
|
||||
# Should warn about snipping
|
||||
mock_logger.warning.assert_called_with('[ S N I P P E D ]')
|
||||
|
||||
mock_logger.warning.assert_called_with("[ S N I P P E D ]")
|
||||
|
||||
def test_team_hash(self):
|
||||
"""Test team hash generation."""
|
||||
mock_team = {
|
||||
'sname': 'TestTeam',
|
||||
'gmid': 1234567
|
||||
}
|
||||
|
||||
mock_team = {"sname": "TestTeam", "gmid": 1234567}
|
||||
|
||||
result = api_calls.team_hash(mock_team)
|
||||
# Expected format: last char + gmid/6950123 + second-to-last char + gmid/42069123
|
||||
expected = f'm{1234567 / 6950123:.0f}a{1234567 / 42069123:.0f}'
|
||||
expected = f"m{1234567 / 6950123:.0f}a{1234567 / 42069123:.0f}"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@ -106,59 +106,116 @@ class TestUtilityFunctions:
|
||||
|
||||
class TestSpecificFunctions:
|
||||
"""Test specific API wrapper functions."""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api_calls.db_get')
|
||||
@patch("api_calls.db_get")
|
||||
async def test_get_team_by_abbrev_found(self, mock_db_get):
|
||||
"""Test get_team_by_abbrev function when team is found."""
|
||||
mock_db_get.return_value = {
|
||||
'count': 1,
|
||||
'teams': [{'id': 123, 'abbrev': 'TEST', 'name': 'Test Team'}]
|
||||
"count": 1,
|
||||
"teams": [{"id": 123, "abbrev": "TEST", "name": "Test Team"}],
|
||||
}
|
||||
|
||||
result = await api_calls.get_team_by_abbrev('TEST')
|
||||
|
||||
assert result == {'id': 123, 'abbrev': 'TEST', 'name': 'Test Team'}
|
||||
mock_db_get.assert_called_once_with('teams', params=[('abbrev', 'TEST')])
|
||||
|
||||
|
||||
result = await api_calls.get_team_by_abbrev("TEST")
|
||||
|
||||
assert result == {"id": 123, "abbrev": "TEST", "name": "Test Team"}
|
||||
mock_db_get.assert_called_once_with("teams", params=[("abbrev", "TEST")])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api_calls.db_get')
|
||||
@patch("api_calls.db_get")
|
||||
async def test_get_team_by_abbrev_not_found(self, mock_db_get):
|
||||
"""Test get_team_by_abbrev function when team is not found."""
|
||||
mock_db_get.return_value = {
|
||||
'count': 0,
|
||||
'teams': []
|
||||
}
|
||||
|
||||
result = await api_calls.get_team_by_abbrev('NONEXISTENT')
|
||||
|
||||
mock_db_get.return_value = {"count": 0, "teams": []}
|
||||
|
||||
result = await api_calls.get_team_by_abbrev("NONEXISTENT")
|
||||
|
||||
assert result is None
|
||||
mock_db_get.assert_called_once_with('teams', params=[('abbrev', 'NONEXISTENT')])
|
||||
|
||||
mock_db_get.assert_called_once_with("teams", params=[("abbrev", "NONEXISTENT")])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api_calls.db_post')
|
||||
@patch("api_calls.db_post")
|
||||
async def test_post_to_dex(self, mock_db_post):
|
||||
"""Test post_to_dex function."""
|
||||
mock_db_post.return_value = {'id': 456, 'posted': True}
|
||||
|
||||
mock_player = {'id': 123}
|
||||
mock_team = {'id': 456}
|
||||
|
||||
mock_db_post.return_value = {"id": 456, "posted": True}
|
||||
|
||||
mock_player = {"id": 123}
|
||||
mock_team = {"id": 456}
|
||||
|
||||
result = await api_calls.post_to_dex(mock_player, mock_team)
|
||||
|
||||
assert result == {'id': 456, 'posted': True}
|
||||
mock_db_post.assert_called_once_with('paperdex', payload={'player_id': 123, 'team_id': 456})
|
||||
|
||||
assert result == {"id": 456, "posted": True}
|
||||
mock_db_post.assert_called_once_with(
|
||||
"paperdex", payload={"player_id": 123, "team_id": 456}
|
||||
)
|
||||
|
||||
|
||||
class TestEnvironmentConfiguration:
|
||||
"""Test environment-based configuration."""
|
||||
|
||||
|
||||
def test_db_url_exists(self):
|
||||
"""Test that DB_URL is configured."""
|
||||
assert api_calls.DB_URL is not None
|
||||
assert 'manticorum.com' in api_calls.DB_URL
|
||||
|
||||
assert "manticorum.com" in api_calls.DB_URL
|
||||
|
||||
def test_auth_token_exists(self):
|
||||
"""Test that AUTH_TOKEN is configured."""
|
||||
assert api_calls.AUTH_TOKEN is not None
|
||||
assert 'Authorization' in api_calls.AUTH_TOKEN
|
||||
assert "Authorization" in api_calls.AUTH_TOKEN
|
||||
|
||||
|
||||
class TestTimeoutAndRetry:
|
||||
"""Test timeout and retry logic for API calls.
|
||||
|
||||
These tests verify that:
|
||||
1. Default timeout values are correctly set
|
||||
2. db_get has retry parameter, mutation methods do not
|
||||
3. APITimeoutError exception exists and is a subclass of DatabaseError
|
||||
"""
|
||||
|
||||
def test_default_timeout_values(self):
|
||||
"""Test that default timeout values are set correctly.
|
||||
|
||||
Default should be 5 seconds for all functions.
|
||||
db_get should have retries parameter, mutation methods should not.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
# Check db_get signature - should have both timeout and retries
|
||||
sig = inspect.signature(api_calls.db_get)
|
||||
assert sig.parameters["timeout"].default == 5
|
||||
assert sig.parameters["retries"].default == 3
|
||||
|
||||
# Check mutation functions - should have timeout but no retries param
|
||||
for func_name in ["db_post", "db_patch", "db_put", "db_delete"]:
|
||||
func = getattr(api_calls, func_name)
|
||||
sig = inspect.signature(func)
|
||||
assert sig.parameters["timeout"].default == 5, (
|
||||
f"{func_name} should have default timeout=5"
|
||||
)
|
||||
assert "retries" not in sig.parameters, (
|
||||
f"{func_name} should not have retries parameter"
|
||||
)
|
||||
|
||||
def test_api_timeout_error_exists(self):
|
||||
"""Test that APITimeoutError exception is properly defined.
|
||||
|
||||
APITimeoutError should be a subclass of DatabaseError so existing
|
||||
error handlers that catch DatabaseError will also catch timeouts.
|
||||
"""
|
||||
assert issubclass(APITimeoutError, DatabaseError)
|
||||
assert issubclass(APITimeoutError, Exception)
|
||||
|
||||
# Test that it can be instantiated with a message
|
||||
error = APITimeoutError("Test timeout message")
|
||||
assert "Test timeout message" in str(error)
|
||||
|
||||
def test_client_timeout_import(self):
|
||||
"""Test that ClientTimeout is properly imported from aiohttp.
|
||||
|
||||
This verifies the timeout functionality can be used.
|
||||
"""
|
||||
from aiohttp import ClientTimeout
|
||||
|
||||
# Create a timeout object to verify it works
|
||||
timeout = ClientTimeout(total=5)
|
||||
assert timeout.total == 5
|
||||
|
||||
Loading…
Reference in New Issue
Block a user