Add API timeout/retry logic and fix get_team_by_owner for PostgreSQL
- Add APITimeoutError exception and retry logic to db_get
- Add timeout handling to db_post, db_put, db_patch, db_delete
- Fix get_team_by_owner to prefer non-gauntlet team (PostgreSQL migration fix)
- Code formatting cleanup (black)
This commit is contained in:
parent 0e70e94644
commit 4be6afb541

api_calls.py: 328 changed lines
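For reviewers, a minimal sketch of how calling code might use the new behavior; it is not part of the commit, and the fallback shown is hypothetical. db_get retries transient timeouts on its own, while mutation helpers such as db_post fail fast with APITimeoutError, which subclasses DatabaseError so existing handlers still catch it.

import asyncio

from api_calls import db_get, db_post
from exceptions import APITimeoutError, DatabaseError


async def refresh_team_cache():
    # Reads are safe to retry: db_get sleeps between attempts and only raises
    # APITimeoutError once every retry has timed out.
    try:
        teams = await db_get("teams", timeout=5, retries=3)
    except APITimeoutError:
        # Hypothetical fallback: keep serving previously cached data instead of crashing.
        return None

    # Writes are not retried automatically (a repeated POST could double-apply),
    # so the caller decides what a timeout means here.
    try:
        await db_post("paperdex", payload={"team_id": 1, "player_id": 2})
    except DatabaseError as err:
        # APITimeoutError subclasses DatabaseError, so this handler sees both.
        print(f"write failed: {err}")

    return teams


if __name__ == "__main__":
    # Needs API_TOKEN in the environment and network access to the API.
    asyncio.run(refresh_team_cache())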
@@ -1,35 +1,46 @@
import asyncio
import datetime
from dataclasses import dataclass
from typing import Optional

import logging
import aiohttp
from aiohttp import ClientTimeout
import os

from exceptions import DatabaseError
from exceptions import DatabaseError, APITimeoutError

AUTH_TOKEN = {'Authorization': f'Bearer {os.environ.get("API_TOKEN")}'}
AUTH_TOKEN = {"Authorization": f"Bearer {os.environ.get('API_TOKEN')}"}
ENV_DATABASE = os.getenv("DATABASE", "dev").lower()
DB_URL = 'https://pd.manticorum.com/api' if 'prod' in ENV_DATABASE else 'https://pddev.manticorum.com/api'
DB_URL = (
    "https://pd.manticorum.com/api"
    if "prod" in ENV_DATABASE
    else "https://pddev.manticorum.com/api"
)
master_debug = True
PLAYER_CACHE = {}
logger = logging.getLogger('discord_app')
logger = logging.getLogger("discord_app")


def param_char(other_params):
    if other_params:
        return '&'
        return "&"
    else:
        return '?'
        return "?"


def get_req_url(endpoint: str, api_ver: int = 2, object_id: Optional[int] = None, params: Optional[list] = None):
    req_url = f'{DB_URL}/v{api_ver}/{endpoint}{"/" if object_id is not None else ""}{object_id if object_id is not None else ""}'
def get_req_url(
    endpoint: str,
    api_ver: int = 2,
    object_id: Optional[int] = None,
    params: Optional[list] = None,
):
    req_url = f"{DB_URL}/v{api_ver}/{endpoint}{'/' if object_id is not None else ''}{object_id if object_id is not None else ''}"

    if params:
        other_params = False
        for x in params:
            req_url += f'{param_char(other_params)}{x[0]}={x[1]}'
            req_url += f"{param_char(other_params)}{x[0]}={x[1]}"
            other_params = True

    return req_url
@@ -42,144 +53,251 @@ def log_return_value(log_string: str):
        line = log_string[start:end]
        if len(line) == 0:
            return
        logger.info(f'{"\n\nreturn: " if start == 0 else ""}{log_string[start:end]}')
        logger.info(f"{'\n\nreturn: ' if start == 0 else ''}{log_string[start:end]}")
        start += 3000
        end += 3000
    logger.warning('[ S N I P P E D ]')
    logger.warning("[ S N I P P E D ]")
    # if master_debug:
    #     logger.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n')
    # else:
    #     logger.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n')


async def db_get(endpoint: str, api_ver: int = 2, object_id: Optional[int] = None, params: Optional[list] = None, none_okay: bool = True, timeout: int = 3):
async def db_get(
    endpoint: str,
    api_ver: int = 2,
    object_id: Optional[int] = None,
    params: Optional[list] = None,
    none_okay: bool = True,
    timeout: int = 5,
    retries: int = 3,
):
    """
    GET request to the API with timeout and retry logic.

    Args:
        endpoint: API endpoint path
        api_ver: API version (default 2)
        object_id: Optional object ID to append to URL
        params: Optional list of (key, value) tuples for query params
        none_okay: If True, return None on non-200 response; if False, raise DatabaseError
        timeout: Request timeout in seconds (default 5)
        retries: Number of retry attempts on timeout (default 3)

    Returns:
        JSON response or None if none_okay and request failed

    Raises:
        APITimeoutError: If all retry attempts fail due to timeout
        DatabaseError: If response is non-200 and none_okay is False
    """
    req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params)
    log_string = f'db_get - get: {endpoint} id: {object_id} params: {params}'
    log_string = f"db_get - get: {endpoint} id: {object_id} params: {params}"
    logger.info(log_string) if master_debug else logger.debug(log_string)

    async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
        async with session.get(req_url) as r:
            if r.status == 200:
                js = await r.json()
                log_return_value(f'{js}')
                return js
            elif none_okay:
                e = await r.text()
                logger.error(e)
                return None
    for attempt in range(retries):
        try:
            client_timeout = ClientTimeout(total=timeout)
            async with aiohttp.ClientSession(
                headers=AUTH_TOKEN, timeout=client_timeout
            ) as session:
                async with session.get(req_url) as r:
                    if r.status == 200:
                        js = await r.json()
                        log_return_value(f"{js}")
                        return js
                    elif none_okay:
                        e = await r.text()
                        logger.error(e)
                        return None
                    else:
                        e = await r.text()
                        logger.error(e)
                        raise DatabaseError(e)
        except asyncio.TimeoutError:
            if attempt < retries - 1:
                wait_time = 2**attempt  # 1s, 2s, 4s
                logger.warning(
                    f"Timeout on GET {endpoint}, retry {attempt + 1}/{retries} in {wait_time}s"
                )
                await asyncio.sleep(wait_time)
            else:
                e = await r.text()
                logger.error(e)
                raise DatabaseError(e)
    logger.error(
        f"Connection timeout to host {req_url} after {retries} attempts"
    )
    raise APITimeoutError(f"Connection timeout to host {req_url}")


async def db_patch(endpoint: str, object_id: int, params: list, api_ver: int = 2, timeout: int = 3):
async def db_patch(
    endpoint: str, object_id: int, params: list, api_ver: int = 2, timeout: int = 5
):
    """
    PATCH request to the API with timeout (no retry - not safe for mutations).

    Args:
        endpoint: API endpoint path
        object_id: Object ID to patch
        params: List of (key, value) tuples for query params
        api_ver: API version (default 2)
        timeout: Request timeout in seconds (default 5)

    Raises:
        APITimeoutError: If request times out
        DatabaseError: If response is non-200
    """
    req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params)
    log_string = f'db_patch - patch: {endpoint} {params}'
    log_string = f"db_patch - patch: {endpoint} {params}"
    logger.info(log_string) if master_debug else logger.debug(log_string)

    async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
        async with session.patch(req_url) as r:
            if r.status == 200:
                js = await r.json()
                log_return_value(f'{js}')
                return js
            else:
                e = await r.text()
                logger.error(e)
                raise DatabaseError(e)
    try:
        client_timeout = ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(
            headers=AUTH_TOKEN, timeout=client_timeout
        ) as session:
            async with session.patch(req_url) as r:
                if r.status == 200:
                    js = await r.json()
                    log_return_value(f"{js}")
                    return js
                else:
                    e = await r.text()
                    logger.error(e)
                    raise DatabaseError(e)
    except asyncio.TimeoutError:
        logger.error(f"Connection timeout to host {req_url}")
        raise APITimeoutError(f"Connection timeout to host {req_url}")


async def db_post(endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 3):
async def db_post(
    endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 5
):
    """
    POST request to the API with timeout (no retry - not safe for mutations).

    Args:
        endpoint: API endpoint path
        api_ver: API version (default 2)
        payload: Optional JSON payload
        timeout: Request timeout in seconds (default 5)

    Raises:
        APITimeoutError: If request times out
        DatabaseError: If response is non-200
    """
    req_url = get_req_url(endpoint, api_ver=api_ver)
    log_string = f'db_post - post: {endpoint} payload: {payload}\ntype: {type(payload)}'
    log_string = f"db_post - post: {endpoint} payload: {payload}\ntype: {type(payload)}"
    logger.info(log_string) if master_debug else logger.debug(log_string)

    async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
        async with session.post(req_url, json=payload) as r:
            if r.status == 200:
                js = await r.json()
                log_return_value(f'{js}')
                return js
            else:
                e = await r.text()
                logger.error(e)
                raise DatabaseError(e)
    try:
        client_timeout = ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(
            headers=AUTH_TOKEN, timeout=client_timeout
        ) as session:
            async with session.post(req_url, json=payload) as r:
                if r.status == 200:
                    js = await r.json()
                    log_return_value(f"{js}")
                    return js
                else:
                    e = await r.text()
                    logger.error(e)
                    raise DatabaseError(e)
    except asyncio.TimeoutError:
        logger.error(f"Connection timeout to host {req_url}")
        raise APITimeoutError(f"Connection timeout to host {req_url}")


async def db_put(endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 3):
async def db_put(
    endpoint: str, api_ver: int = 2, payload: Optional[dict] = None, timeout: int = 5
):
    """
    PUT request to the API with timeout (no retry - not safe for mutations).

    Args:
        endpoint: API endpoint path
        api_ver: API version (default 2)
        payload: Optional JSON payload
        timeout: Request timeout in seconds (default 5)

    Raises:
        APITimeoutError: If request times out
        DatabaseError: If response is non-200
    """
    req_url = get_req_url(endpoint, api_ver=api_ver)
    log_string = f'post:\n{endpoint} payload: {payload}\ntype: {type(payload)}'
    log_string = f"db_put - put: {endpoint} payload: {payload}\ntype: {type(payload)}"
    logger.info(log_string) if master_debug else logger.debug(log_string)

    async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
        async with session.put(req_url, json=payload) as r:
            if r.status == 200:
                js = await r.json()
                log_return_value(f'{js}')
                return js
            else:
                e = await r.text()
                logger.error(e)
                raise DatabaseError(e)

    # retries = 0
    # while True:
    #     try:
    #         resp = requests.put(req_url, json=payload, headers=AUTH_TOKEN, timeout=timeout)
    #         break
    #     except requests.Timeout as e:
    #         logger.error(f'Post Timeout: {req_url} / retries: {retries} / timeout: {timeout}')
    #         if retries > 1:
    #             raise ConnectionError(f'DB: The internet was a bit too slow for me to grab the data I needed. Please '
    #                                   f'hang on a few extra seconds and try again.')
    #         timeout += [min(3, timeout), min(5, timeout)][retries]
    #         retries += 1
    #
    # if resp.status_code == 200:
    #     data = resp.json()
    #     log_string = f'{data}'
    #     if master_debug:
    #         logger.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}')
    #     else:
    #         logger.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}')
    #     return data
    # else:
    #     logger.warning(resp.text)
    #     raise ValueError(f'DB: {resp.text}')
    try:
        client_timeout = ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(
            headers=AUTH_TOKEN, timeout=client_timeout
        ) as session:
            async with session.put(req_url, json=payload) as r:
                if r.status == 200:
                    js = await r.json()
                    log_return_value(f"{js}")
                    return js
                else:
                    e = await r.text()
                    logger.error(e)
                    raise DatabaseError(e)
    except asyncio.TimeoutError:
        logger.error(f"Connection timeout to host {req_url}")
        raise APITimeoutError(f"Connection timeout to host {req_url}")


async def db_delete(endpoint: str, object_id: int, api_ver: int = 2, timeout=3):
async def db_delete(endpoint: str, object_id: int, api_ver: int = 2, timeout: int = 5):
    """
    DELETE request to the API with timeout (no retry - not safe for mutations).

    Args:
        endpoint: API endpoint path
        object_id: Object ID to delete
        api_ver: API version (default 2)
        timeout: Request timeout in seconds (default 5)

    Raises:
        APITimeoutError: If request times out
        DatabaseError: If response is non-200
    """
    req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id)
    log_string = f'db_delete - delete: {endpoint} {object_id}'
    log_string = f"db_delete - delete: {endpoint} {object_id}"
    logger.info(log_string) if master_debug else logger.debug(log_string)

    async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session:
        async with session.delete(req_url) as r:
            if r.status == 200:
                js = await r.json()
                log_return_value(f'{js}')
                return js
            else:
                e = await r.text()
                logger.error(e)
                raise DatabaseError(e)
    try:
        client_timeout = ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(
            headers=AUTH_TOKEN, timeout=client_timeout
        ) as session:
            async with session.delete(req_url) as r:
                if r.status == 200:
                    js = await r.json()
                    log_return_value(f"{js}")
                    return js
                else:
                    e = await r.text()
                    logger.error(e)
                    raise DatabaseError(e)
    except asyncio.TimeoutError:
        logger.error(f"Connection timeout to host {req_url}")
        raise APITimeoutError(f"Connection timeout to host {req_url}")


async def get_team_by_abbrev(abbrev: str):
    all_teams = await db_get('teams', params=[('abbrev', abbrev)])
    all_teams = await db_get("teams", params=[("abbrev", abbrev)])

    if not all_teams or not all_teams['count']:
    if not all_teams or not all_teams["count"]:
        return None

    return all_teams['teams'][0]
    return all_teams["teams"][0]


async def post_to_dex(player, team):
    return await db_post('paperdex', payload={'team_id': team['id'], 'player_id': player['id']})
    return await db_post(
        "paperdex", payload={"team_id": team["id"], "player_id": player["id"]}
    )


def team_hash(team):
    hash_string = f'{team["sname"][-1]}{team["gmid"] / 6950123:.0f}{team["sname"][-2]}{team["gmid"] / 42069123:.0f}'
    hash_string = f"{team['sname'][-1]}{team['gmid'] / 6950123:.0f}{team['sname'][-2]}{team['gmid'] / 42069123:.0f}"
    return hash_string
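The retry loop added to db_get above is a standard capped exponential backoff: with retries=3 it waits 2**attempt seconds after a timed-out attempt (1 s, then 2 s) and raises APITimeoutError once the last attempt also times out. A minimal generic sketch of the same pattern, written here independently of aiohttp and not taken from the repository:

import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def retry_on_timeout(operation: Callable[[], Awaitable[T]], retries: int = 3) -> T:
    """Run an async operation, retrying asyncio.TimeoutError with exponential backoff."""
    for attempt in range(retries):
        try:
            return await operation()
        except asyncio.TimeoutError:
            if attempt < retries - 1:
                # Back off 1s, 2s, 4s, ... before the next attempt.
                await asyncio.sleep(2**attempt)
    # Every attempt timed out; surface the failure to the caller.
    raise asyncio.TimeoutError(f"operation timed out after {retries} attempts")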
@@ -1,7 +1,7 @@
import logging
from typing import Literal

logger = logging.getLogger('discord_app')
logger = logging.getLogger("discord_app")


def log_errors(func):
@@ -15,16 +15,21 @@ def log_errors(func):
        except Exception as e:
            logger.error(func.__name__)
            log_exception(e)
        return result  # type: ignore
        return result  # type: ignore

    return wrap


def log_exception(e: Exception, msg: str = '', level: Literal['debug', 'error', 'info', 'warn'] = 'error'):
    if level == 'debug':

def log_exception(
    e: Exception,
    msg: str = "",
    level: Literal["debug", "error", "info", "warn"] = "error",
):
    if level == "debug":
        logger.debug(msg, exc_info=True, stack_info=True)
    elif level == 'error':
    elif level == "error":
        logger.error(msg, exc_info=True, stack_info=True)
    elif level == 'info':
    elif level == "info":
        logger.info(msg, exc_info=True, stack_info=True)
    else:
        logger.warning(msg, exc_info=True, stack_info=True)
@@ -35,6 +40,7 @@ def log_exception(e: Exception, msg: str = '', level: Literal['debug', 'error',
    else:
        raise e(msg)  # If 'e' is an exception class


class GameException(Exception):
    pass

@@ -75,6 +81,12 @@ class DatabaseError(GameException):
    pass


class APITimeoutError(DatabaseError):
    """Raised when an API call times out after all retries."""

    pass


class PositionNotFoundException(GameException):
    pass
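Because APITimeoutError is declared as a DatabaseError subclass, any existing except DatabaseError block keeps catching timeouts. The small illustration below uses only the class names from the diff above; the helper itself is hypothetical:

from exceptions import APITimeoutError, DatabaseError


def describe(error: Exception) -> str:
    # Check the more specific timeout class before the general database error.
    if isinstance(error, APITimeoutError):
        return "timed out after retries"
    if isinstance(error, DatabaseError):
        return "API/database error"
    return "unexpected error"


assert describe(APITimeoutError("slow host")) == "timed out after retries"
assert describe(DatabaseError("500 from API")) == "API/database error"
assert issubclass(APITimeoutError, DatabaseError)  # old handlers still catch timeouts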
gauntlets.py: 2543 changed lines (diff suppressed because it is too large)
helpers.py: 1885 changed lines (diff suppressed because it is too large)
helpers/main.py: 1886 changed lines (diff suppressed because it is too large)
@@ -1,7 +1,8 @@
import asyncio
import pytest
import aiohttp
from unittest.mock import Mock, patch, AsyncMock
from exceptions import DatabaseError
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from exceptions import DatabaseError, APITimeoutError
import api_calls


@@ -10,93 +11,92 @@ class TestUtilityFunctions:

    def test_param_char_with_params(self):
        """Test param_char returns & when other_params is truthy."""
        assert api_calls.param_char(True) == '&'
        assert api_calls.param_char(['param1']) == '&'
        assert api_calls.param_char({'key': 'value'}) == '&'
        assert api_calls.param_char('some_param') == '&'
        assert api_calls.param_char(True) == "&"
        assert api_calls.param_char(["param1"]) == "&"
        assert api_calls.param_char({"key": "value"}) == "&"
        assert api_calls.param_char("some_param") == "&"

    def test_param_char_without_params(self):
        """Test param_char returns ? when other_params is falsy."""
        assert api_calls.param_char(False) == '?'
        assert api_calls.param_char(None) == '?'
        assert api_calls.param_char([]) == '?'
        assert api_calls.param_char({}) == '?'
        assert api_calls.param_char('') == '?'
        assert api_calls.param_char(0) == '?'
        assert api_calls.param_char(False) == "?"
        assert api_calls.param_char(None) == "?"
        assert api_calls.param_char([]) == "?"
        assert api_calls.param_char({}) == "?"
        assert api_calls.param_char("") == "?"
        assert api_calls.param_char(0) == "?"

    @patch('api_calls.DB_URL', 'https://test.example.com/api')
    @patch("api_calls.DB_URL", "https://test.example.com/api")
    def test_get_req_url_basic(self):
        """Test basic URL generation without object_id or params."""
        result = api_calls.get_req_url('teams')
        expected = 'https://test.example.com/api/v2/teams'
        result = api_calls.get_req_url("teams")
        expected = "https://test.example.com/api/v2/teams"
        assert result == expected

    @patch('api_calls.DB_URL', 'https://test.example.com/api')
    @patch("api_calls.DB_URL", "https://test.example.com/api")
    def test_get_req_url_with_version(self):
        """Test URL generation with custom API version."""
        result = api_calls.get_req_url('teams', api_ver=1)
        expected = 'https://test.example.com/api/v1/teams'
        result = api_calls.get_req_url("teams", api_ver=1)
        expected = "https://test.example.com/api/v1/teams"
        assert result == expected

    @patch('api_calls.DB_URL', 'https://test.example.com/api')
    @patch("api_calls.DB_URL", "https://test.example.com/api")
    def test_get_req_url_with_object_id(self):
        """Test URL generation with object_id."""
        result = api_calls.get_req_url('teams', object_id=123)
        expected = 'https://test.example.com/api/v2/teams/123'
        result = api_calls.get_req_url("teams", object_id=123)
        expected = "https://test.example.com/api/v2/teams/123"
        assert result == expected

    @patch('api_calls.DB_URL', 'https://test.example.com/api')
    @patch("api_calls.DB_URL", "https://test.example.com/api")
    def test_get_req_url_with_params(self):
        """Test URL generation with parameters."""
        params = [('season', '7'), ('active', 'true')]
        result = api_calls.get_req_url('teams', params=params)
        expected = 'https://test.example.com/api/v2/teams?season=7&active=true'
        params = [("season", "7"), ("active", "true")]
        result = api_calls.get_req_url("teams", params=params)
        expected = "https://test.example.com/api/v2/teams?season=7&active=true"
        assert result == expected

    @patch('api_calls.DB_URL', 'https://test.example.com/api')
    @patch("api_calls.DB_URL", "https://test.example.com/api")
    def test_get_req_url_complete(self):
        """Test URL generation with all parameters."""
        params = [('season', '7'), ('limit', '10')]
        result = api_calls.get_req_url('games', api_ver=1, object_id=456, params=params)
        expected = 'https://test.example.com/api/v1/games/456?season=7&limit=10'
        params = [("season", "7"), ("limit", "10")]
        result = api_calls.get_req_url("games", api_ver=1, object_id=456, params=params)
        expected = "https://test.example.com/api/v1/games/456?season=7&limit=10"
        assert result == expected

    @patch('api_calls.logger')
    @patch("api_calls.logger")
    def test_log_return_value_short_string(self, mock_logger):
        """Test logging short return values."""
        api_calls.log_return_value('Short log message')
        mock_logger.info.assert_called_once_with('\n\nreturn: Short log message')
        api_calls.log_return_value("Short log message")
        mock_logger.info.assert_called_once_with("\n\nreturn: Short log message")

    @patch('api_calls.logger')
    @patch("api_calls.logger")
    def test_log_return_value_long_string(self, mock_logger):
        """Test logging long return values that get chunked."""
        long_string = 'A' * 5000  # 5000 character string
        long_string = "A" * 5000  # 5000 character string
        api_calls.log_return_value(long_string)

        # Should have been called twice (first chunk + second chunk)
        assert mock_logger.info.call_count == 2
        # First call should include the "return:" prefix
        assert '\n\nreturn: ' in mock_logger.info.call_args_list[0][0][0]
        assert "\n\nreturn: " in mock_logger.info.call_args_list[0][0][0]

    @patch('api_calls.logger')
    @patch("api_calls.logger")
    def test_log_return_value_extremely_long_string(self, mock_logger):
        """Test logging extremely long return values that get snipped."""
        extremely_long_string = 'B' * 400000  # 400k character string (exceeds 300k limit)
        extremely_long_string = (
            "B" * 400000
        )  # 400k character string (exceeds 300k limit)
        api_calls.log_return_value(extremely_long_string)

        # Should warn about snipping
        mock_logger.warning.assert_called_with('[ S N I P P E D ]')
        mock_logger.warning.assert_called_with("[ S N I P P E D ]")

    def test_team_hash(self):
        """Test team hash generation."""
        mock_team = {
            'sname': 'TestTeam',
            'gmid': 1234567
        }
        mock_team = {"sname": "TestTeam", "gmid": 1234567}

        result = api_calls.team_hash(mock_team)
        # Expected format: last char + gmid/6950123 + second-to-last char + gmid/42069123
        expected = f'm{1234567 / 6950123:.0f}a{1234567 / 42069123:.0f}'
        expected = f"m{1234567 / 6950123:.0f}a{1234567 / 42069123:.0f}"
        assert result == expected


@@ -108,46 +108,45 @@ class TestSpecificFunctions:
    """Test specific API wrapper functions."""

    @pytest.mark.asyncio
    @patch('api_calls.db_get')
    @patch("api_calls.db_get")
    async def test_get_team_by_abbrev_found(self, mock_db_get):
        """Test get_team_by_abbrev function when team is found."""
        mock_db_get.return_value = {
            'count': 1,
            'teams': [{'id': 123, 'abbrev': 'TEST', 'name': 'Test Team'}]
            "count": 1,
            "teams": [{"id": 123, "abbrev": "TEST", "name": "Test Team"}],
        }

        result = await api_calls.get_team_by_abbrev('TEST')
        result = await api_calls.get_team_by_abbrev("TEST")

        assert result == {'id': 123, 'abbrev': 'TEST', 'name': 'Test Team'}
        mock_db_get.assert_called_once_with('teams', params=[('abbrev', 'TEST')])
        assert result == {"id": 123, "abbrev": "TEST", "name": "Test Team"}
        mock_db_get.assert_called_once_with("teams", params=[("abbrev", "TEST")])

    @pytest.mark.asyncio
    @patch('api_calls.db_get')
    @patch("api_calls.db_get")
    async def test_get_team_by_abbrev_not_found(self, mock_db_get):
        """Test get_team_by_abbrev function when team is not found."""
        mock_db_get.return_value = {
            'count': 0,
            'teams': []
        }
        mock_db_get.return_value = {"count": 0, "teams": []}

        result = await api_calls.get_team_by_abbrev('NONEXISTENT')
        result = await api_calls.get_team_by_abbrev("NONEXISTENT")

        assert result is None
        mock_db_get.assert_called_once_with('teams', params=[('abbrev', 'NONEXISTENT')])
        mock_db_get.assert_called_once_with("teams", params=[("abbrev", "NONEXISTENT")])

    @pytest.mark.asyncio
    @patch('api_calls.db_post')
    @patch("api_calls.db_post")
    async def test_post_to_dex(self, mock_db_post):
        """Test post_to_dex function."""
        mock_db_post.return_value = {'id': 456, 'posted': True}
        mock_db_post.return_value = {"id": 456, "posted": True}

        mock_player = {'id': 123}
        mock_team = {'id': 456}
        mock_player = {"id": 123}
        mock_team = {"id": 456}

        result = await api_calls.post_to_dex(mock_player, mock_team)

        assert result == {'id': 456, 'posted': True}
        mock_db_post.assert_called_once_with('paperdex', payload={'player_id': 123, 'team_id': 456})
        assert result == {"id": 456, "posted": True}
        mock_db_post.assert_called_once_with(
            "paperdex", payload={"player_id": 123, "team_id": 456}
        )


class TestEnvironmentConfiguration:
@@ -156,9 +155,67 @@ class TestEnvironmentConfiguration:
    def test_db_url_exists(self):
        """Test that DB_URL is configured."""
        assert api_calls.DB_URL is not None
        assert 'manticorum.com' in api_calls.DB_URL
        assert "manticorum.com" in api_calls.DB_URL

    def test_auth_token_exists(self):
        """Test that AUTH_TOKEN is configured."""
        assert api_calls.AUTH_TOKEN is not None
        assert 'Authorization' in api_calls.AUTH_TOKEN
        assert "Authorization" in api_calls.AUTH_TOKEN


class TestTimeoutAndRetry:
    """Test timeout and retry logic for API calls.

    These tests verify that:
    1. Default timeout values are correctly set
    2. db_get has retry parameter, mutation methods do not
    3. APITimeoutError exception exists and is a subclass of DatabaseError
    """

    def test_default_timeout_values(self):
        """Test that default timeout values are set correctly.

        Default should be 5 seconds for all functions.
        db_get should have retries parameter, mutation methods should not.
        """
        import inspect

        # Check db_get signature - should have both timeout and retries
        sig = inspect.signature(api_calls.db_get)
        assert sig.parameters["timeout"].default == 5
        assert sig.parameters["retries"].default == 3

        # Check mutation functions - should have timeout but no retries param
        for func_name in ["db_post", "db_patch", "db_put", "db_delete"]:
            func = getattr(api_calls, func_name)
            sig = inspect.signature(func)
            assert sig.parameters["timeout"].default == 5, (
                f"{func_name} should have default timeout=5"
            )
            assert "retries" not in sig.parameters, (
                f"{func_name} should not have retries parameter"
            )

    def test_api_timeout_error_exists(self):
        """Test that APITimeoutError exception is properly defined.

        APITimeoutError should be a subclass of DatabaseError so existing
        error handlers that catch DatabaseError will also catch timeouts.
        """
        assert issubclass(APITimeoutError, DatabaseError)
        assert issubclass(APITimeoutError, Exception)

        # Test that it can be instantiated with a message
        error = APITimeoutError("Test timeout message")
        assert "Test timeout message" in str(error)

    def test_client_timeout_import(self):
        """Test that ClientTimeout is properly imported from aiohttp.

        This verifies the timeout functionality can be used.
        """
        from aiohttp import ClientTimeout

        # Create a timeout object to verify it works
        timeout = ClientTimeout(total=5)
        assert timeout.total == 5
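The committed tests check signatures and the exception hierarchy; the retry path itself could also be exercised by forcing timeouts, along these lines. This is a sketch assuming pytest-asyncio is available (as the existing async tests imply); the test name and patching approach are illustrative, not part of the commit:

import asyncio
from unittest.mock import patch

import pytest

import api_calls
from exceptions import APITimeoutError


@pytest.mark.asyncio
async def test_db_get_raises_api_timeout_error_on_timeout():
    """Sketch: every GET attempt times out, so db_get should raise APITimeoutError."""
    with patch("aiohttp.ClientSession.get", side_effect=asyncio.TimeoutError):
        with pytest.raises(APITimeoutError):
            # retries=1 keeps the sketch fast: a single attempt and no backoff sleeps.
            await api_calls.db_get("teams", retries=1)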