Team cache validating properly and tests green

This commit is contained in:
Cal Corum 2024-10-12 11:36:09 -05:00
parent 57610fe8a7
commit 5fe91c0729
5 changed files with 46 additions and 20 deletions

View File

@ -33,3 +33,4 @@ README.md
**/tests
**/storage
*_legacy.py
pytest.ini

View File

@ -8,18 +8,19 @@ from discord.ext import commands
from helpers import PD_PLAYERS_ROLE_NAME
from in_game.game_helpers import PUBLIC_FIELDS_CATEGORY_NAME
from in_game.data_cache import get_pd_team
from in_game.gameplay_db import Session, engine, create_db_and_tables, select, Game
from in_game.gameplay_db import Session, engine, create_db_and_tables, select, Game, get_team
def get_games_by_channel(session: Session, channel_id: int) -> list[Game]:
# TODO: test .all() on empty return
return session.exec(select(Game).where(Game.channel_id == channel_id)).all()
def get_channel_game_or_none(session: Session, channel_id: int) -> Game | None:
all_games = get_games_by_channel(session, channel_id)
if len(all_games) > 1:
pass # TODO: raise an exception
err = 'Too many games found in get_channel_game_or_none'
logging.error(f'cogs.gameplay - get_channel_game_or_none - channel_id: {channel_id} / {err}')
raise LookupError(err)
elif len(all_games) == 0:
return None
return all_games[0]
@ -55,17 +56,18 @@ class Gameplay(commands.Cog):
league: Literal['Minor League', 'Flashback', 'Major League', 'Hall of Fame'],
away_team_abbrev: str, home_team_abbrev: str, sp_card_id: int, num_innings: Literal[9, 3] = 9
):
await interaction.response.defer()
await interaction.response.send_message(content=f'Let\'s get your game set up. First, I\'ll check for conflicts...')
with Session(engine) as session:
conflict = get_games_by_channel(session, channel_id=interaction.channel_id)
if len(conflict) > 0:
conflict = get_channel_game_or_none(session, interaction.channel_id)
if conflict is not None:
await interaction.edit_original_response(
content=f'Ope. There is already a game going on in this channel. Please wait for it to complete '
f'before starting a new one.'
)
return
await interaction.edit_original_response(content=f'Now to check that you\'re in the right channel category...')
if interaction.channel.category is None or interaction.channel.category.name != PUBLIC_FIELDS_CATEGORY_NAME:
await interaction.edit_original_response(
content=f'Why don\'t you head down to one of the Public Fields that way other humans can help if anything '
@ -73,10 +75,11 @@ class Gameplay(commands.Cog):
)
return
away_team = await get_pd_team(away_team_abbrev)
await interaction.edit_original_response(content=f'Now to find this away team **{away_team_abbrev.upper()}**')
away_team = await get_team(session, team_abbrev=away_team_abbrev)
await interaction.edit_original_response(
content=f'This channel is ripe for the picking!'
content=f'Hey {away_team.gmname}! {interaction.channel.name} is available so the {away_team.sname} are welcome to play!'
)

View File

@ -1,6 +1,7 @@
import datetime
import logging
from sqlmodel import Session, SQLModel, create_engine, select, Field, Relationship
from sqlalchemy import func
from db_calls import db_get
from helpers import PD_SEASON
@ -70,7 +71,7 @@ class Lineup(SQLModel, table=True):
game: Game = Relationship(back_populates='lineups')
class Team(SQLModel, table=True):
class TeamBase(SQLModel):
id: int = Field(primary_key=True)
abbrev: str = Field(index=True)
sname: str
@ -91,6 +92,10 @@ class Team(SQLModel, table=True):
created: datetime.datetime | None = Field(default=datetime.datetime.now())
class Team(TeamBase, table=True):
pass
async def get_team(
session: Session, team_id: int | None = None, gm_id: int | None = None, team_abbrev: str | None = None, skip_cache: bool = False) -> Team:
if team_id is None and gm_id is None and team_abbrev is None:
@ -105,11 +110,13 @@ async def get_team(
if gm_id is not None:
statement = select(Team).where(Team.gmid == gm_id)
else:
statement = select(Team).where(Team.abbrev.lower() == team_abbrev.lower())
this_team = session.exec(statement).one_or_none
statement = select(Team).where(func.lower(Team.abbrev) == team_abbrev.lower())
this_team = session.exec(statement).one_or_none()
if this_team is not None:
logging.info(f'we found a team: {this_team} / created: {this_team.created}')
tdelta = datetime.datetime.now() - this_team.created
logging.info(f'tdelta: {tdelta}')
if tdelta.total_seconds() < 1209600:
return this_team
else:
@ -117,7 +124,11 @@ async def get_team(
session.commit()
def cache_team(json_data: dict) -> Team:
db_team = Team.model_validate(t_query)
logging.info(f'gameplay_db - get_team - cache_team - writing a team to cache: {json_data}')
valid_team = TeamBase.model_validate(json_data, from_attributes=True)
logging.info(f'gameplay_db - get_team - cache_team - valid_team: {valid_team}')
db_team = Team.model_validate(valid_team)
logging.info(f'gameplay_db - get_team - cache_team - db_team: {db_team}')
session.add(db_team)
session.commit()
session.refresh(db_team)
@ -131,13 +142,13 @@ async def get_team(
elif gm_id is not None:
t_query = await db_get('teams', params=[('season', PD_SEASON), ('gm_id', gm_id)])
if t_query['count'] != 0:
for team in [x for x in t_query['teams'] if 'gauntlet' not in x.abbrev.lower()]:
for team in [x for x in t_query['teams'] if 'gauntlet' not in x['abbrev'].lower()]:
return cache_team(team)
elif team_abbrev is not None:
t_query = await db_get('teams', params=[('season', PD_SEASON), ('abbrev', team_abbrev)])
if t_query['count'] != 0:
for team in [x for x in t_query['teams'] if 'gauntlet' not in x.abbrev.lower()]:
for team in [x for x in t_query['teams'] if 'gauntlet' not in x['abbrev'].lower()]:
return cache_team(team)
err = 'Team not found'
@ -212,10 +223,18 @@ def select_speed_testing():
print(f'len(games): {len(games)}')
def select_all_testing():
with Session(engine) as session:
game_search = session.exec(select(Team)).all()
for game in game_search:
print(f'Game: {game}')
def main():
# create_db_and_tables()
# create_test_games()
select_speed_testing()
# select_all_testing()
if __name__ == "__main__":

View File

@ -1,6 +1,6 @@
from sqlmodel import Session
from in_game.gameplay_db import Game
from in_game.gameplay_db import Game, select
from factory import session_fixture, new_games_fixture
@ -37,3 +37,9 @@ def test_create_game(session: Session, new_games: list[Game]):
assert game_2.game_type == 'minor-league'
def test_select_all_empty(session: Session):
games = session.exec(select(Game)).all()
assert len(games) == 0

View File

@ -33,20 +33,17 @@ def test_create_incomplete_team(session: Session, new_teams: list[Team]):
async def test_team_cache(session: Session, new_teams: list[Team]):
team_1 = new_teams[0]
team_2 = new_teams[1]
team_3 = new_teams[3]
session.add(team_1)
session.add(team_2)
session.add(team_3)
session.commit()
new_team_1 = await get_team(session, team_id=team_1.id)
new_team_2 = await get_team(session, team_id=team_2.id)
new_team_3 = await get_team(session, team_id=team_3.id)
new_team_3 = await get_team(session, team_abbrev='BAL')
assert team_1.created == new_team_1.created
assert team_2.created == new_team_2.created
assert (datetime.datetime.now() - team_3.created).total_seconds() > 1209600
assert (datetime.datetime.now() - new_team_3.created).total_seconds() < 1209600