From 512efe98c4bb0bb9a9363928e47ec534caeec000 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Sat, 12 Oct 2024 18:22:13 -0500 Subject: [PATCH] db_calls -> api_calls gameplay_db -> gameplay_models new-game campaign in progress added Player model --- db_calls.py => api_calls.py | 352 +++++++++--------- cogs/admins.py | 4 +- cogs/economy.py | 2 +- cogs/gameplay.py | 143 +++++-- cogs/gameplay_legacy.py | 2 +- cogs/players.py | 2 +- db_calls_gameplay.py | 2 +- docker-compose.yml | 2 - gauntlets.py | 2 +- helpers.py | 10 +- in_game/ai_manager.py | 2 +- in_game/data_cache.py | 2 +- in_game/game_helpers.py | 2 +- .../{gameplay_db.py => gameplay_models.py} | 58 ++- in_game/gameplay_queries.py | 21 ++ migrations/env.py | 2 +- tests/test_gameplay_db_game.py | 57 ++- tests/test_gameplay_db_lineup.py | 2 +- 18 files changed, 430 insertions(+), 237 deletions(-) rename db_calls.py => api_calls.py (97%) rename in_game/{gameplay_db.py => gameplay_models.py} (81%) create mode 100644 in_game/gameplay_queries.py diff --git a/db_calls.py b/api_calls.py similarity index 97% rename from db_calls.py rename to api_calls.py index 3c4a4db..0065586 100644 --- a/db_calls.py +++ b/api_calls.py @@ -1,176 +1,176 @@ -import datetime -from dataclasses import dataclass -from typing import Optional - -import logging -import aiohttp -import os - -AUTH_TOKEN = {'Authorization': f'Bearer {os.environ.get("API_TOKEN")}'} -DB_URL = 'https://pd.manticorum.com/api' -master_debug = True -alt_database = 'dev' -PLAYER_CACHE = {} - -if alt_database == 'dev': - DB_URL = 'https://pddev.manticorum.com/api' - - -def param_char(other_params): - if other_params: - return '&' - else: - return '?' - - -def get_req_url(endpoint: str, api_ver: int = 2, object_id: int = None, params: list = None): - req_url = f'{DB_URL}/v{api_ver}/{endpoint}{"/" if object_id is not None else ""}{object_id if object_id is not None else ""}' - - if params: - other_params = False - for x in params: - req_url += f'{param_char(other_params)}{x[0]}={x[1]}' - other_params = True - - return req_url - - -def log_return_value(log_string: str): - if master_debug: - logging.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n') - else: - logging.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n') - - -async def db_get(endpoint: str, api_ver: int = 2, object_id: int = None, params: list = None, none_okay: bool = True, - timeout: int = 3): - req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params) - log_string = f'db_get - get: {endpoint} id: {object_id} params: {params}' - logging.info(log_string) if master_debug else logging.debug(log_string) - - async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: - async with session.get(req_url) as r: - if r.status == 200: - js = await r.json() - log_return_value(f'{js}') - return js - elif none_okay: - e = await r.text() - logging.error(e) - return None - else: - e = await r.text() - logging.error(e) - raise ValueError(f'DB: {e}') - - -async def db_patch(endpoint: str, object_id: int, params: list, api_ver: int = 2, timeout: int = 3): - req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params) - log_string = f'db_patch - patch: {endpoint} {params}' - logging.info(log_string) if master_debug else logging.debug(log_string) - - async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: - async with session.patch(req_url) as r: - if r.status == 200: - js = await r.json() - log_return_value(f'{js}') - return js - else: - e = await r.text() - logging.error(e) - raise ValueError(f'DB: {e}') - - -async def db_post(endpoint: str, api_ver: int = 2, payload: dict = None, timeout: int = 3): - req_url = get_req_url(endpoint, api_ver=api_ver) - log_string = f'db_post - post: {endpoint} payload: {payload}\ntype: {type(payload)}' - logging.info(log_string) if master_debug else logging.debug(log_string) - - async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: - async with session.post(req_url, json=payload) as r: - if r.status == 200: - js = await r.json() - log_return_value(f'{js}') - return js - else: - e = await r.text() - logging.error(e) - raise ValueError(f'DB: {e}') - - -async def db_put(endpoint: str, api_ver: int = 2, payload: dict = None, timeout: int = 3): - req_url = get_req_url(endpoint, api_ver=api_ver) - log_string = f'post:\n{endpoint} payload: {payload}\ntype: {type(payload)}' - logging.info(log_string) if master_debug else logging.debug(log_string) - - async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: - async with session.put(req_url, json=payload) as r: - if r.status == 200: - js = await r.json() - log_return_value(f'{js}') - return js - else: - e = await r.text() - logging.error(e) - raise ValueError(f'DB: {e}') - - # retries = 0 - # while True: - # try: - # resp = requests.put(req_url, json=payload, headers=AUTH_TOKEN, timeout=timeout) - # break - # except requests.Timeout as e: - # logging.error(f'Post Timeout: {req_url} / retries: {retries} / timeout: {timeout}') - # if retries > 1: - # raise ConnectionError(f'DB: The internet was a bit too slow for me to grab the data I needed. Please ' - # f'hang on a few extra seconds and try again.') - # timeout += [min(3, timeout), min(5, timeout)][retries] - # retries += 1 - # - # if resp.status_code == 200: - # data = resp.json() - # log_string = f'{data}' - # if master_debug: - # logging.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}') - # else: - # logging.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}') - # return data - # else: - # logging.warning(resp.text) - # raise ValueError(f'DB: {resp.text}') - - -async def db_delete(endpoint: str, object_id: int, api_ver: int = 2, timeout=3): - req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id) - log_string = f'db_delete - delete: {endpoint} {object_id}' - logging.info(log_string) if master_debug else logging.debug(log_string) - - async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: - async with session.delete(req_url) as r: - if r.status == 200: - js = await r.json() - log_return_value(f'{js}') - return js - else: - e = await r.text() - logging.error(e) - raise ValueError(f'DB: {e}') - - -async def get_team_by_abbrev(abbrev: str): - all_teams = await db_get('teams', params=[('abbrev', abbrev)]) - - if not all_teams or not all_teams['count']: - return None - - return all_teams['teams'][0] - - -async def post_to_dex(player, team): - return await db_post('paperdex', payload={'team_id': team['id'], 'player_id': player['id']}) - - -def team_hash(team): - hash_string = f'{team["sname"][-1]}{team["gmid"] / 6950123:.0f}{team["sname"][-2]}{team["gmid"] / 42069123:.0f}' - return hash_string - +import datetime +from dataclasses import dataclass +from typing import Optional + +import logging +import aiohttp +import os + +AUTH_TOKEN = {'Authorization': f'Bearer {os.environ.get("API_TOKEN")}'} +DB_URL = 'https://pd.manticorum.com/api' +master_debug = True +alt_database = 'dev' +PLAYER_CACHE = {} + +if alt_database == 'dev': + DB_URL = 'https://pddev.manticorum.com/api' + + +def param_char(other_params): + if other_params: + return '&' + else: + return '?' + + +def get_req_url(endpoint: str, api_ver: int = 2, object_id: int = None, params: list = None): + req_url = f'{DB_URL}/v{api_ver}/{endpoint}{"/" if object_id is not None else ""}{object_id if object_id is not None else ""}' + + if params: + other_params = False + for x in params: + req_url += f'{param_char(other_params)}{x[0]}={x[1]}' + other_params = True + + return req_url + + +def log_return_value(log_string: str): + if master_debug: + logging.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n') + else: + logging.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}\n') + + +async def db_get(endpoint: str, api_ver: int = 2, object_id: int = None, params: list = None, none_okay: bool = True, + timeout: int = 3): + req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params) + log_string = f'db_get - get: {endpoint} id: {object_id} params: {params}' + logging.info(log_string) if master_debug else logging.debug(log_string) + + async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: + async with session.get(req_url) as r: + if r.status == 200: + js = await r.json() + log_return_value(f'{js}') + return js + elif none_okay: + e = await r.text() + logging.error(e) + return None + else: + e = await r.text() + logging.error(e) + raise ValueError(f'DB: {e}') + + +async def db_patch(endpoint: str, object_id: int, params: list, api_ver: int = 2, timeout: int = 3): + req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id, params=params) + log_string = f'db_patch - patch: {endpoint} {params}' + logging.info(log_string) if master_debug else logging.debug(log_string) + + async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: + async with session.patch(req_url) as r: + if r.status == 200: + js = await r.json() + log_return_value(f'{js}') + return js + else: + e = await r.text() + logging.error(e) + raise ValueError(f'DB: {e}') + + +async def db_post(endpoint: str, api_ver: int = 2, payload: dict = None, timeout: int = 3): + req_url = get_req_url(endpoint, api_ver=api_ver) + log_string = f'db_post - post: {endpoint} payload: {payload}\ntype: {type(payload)}' + logging.info(log_string) if master_debug else logging.debug(log_string) + + async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: + async with session.post(req_url, json=payload) as r: + if r.status == 200: + js = await r.json() + log_return_value(f'{js}') + return js + else: + e = await r.text() + logging.error(e) + raise ValueError(f'DB: {e}') + + +async def db_put(endpoint: str, api_ver: int = 2, payload: dict = None, timeout: int = 3): + req_url = get_req_url(endpoint, api_ver=api_ver) + log_string = f'post:\n{endpoint} payload: {payload}\ntype: {type(payload)}' + logging.info(log_string) if master_debug else logging.debug(log_string) + + async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: + async with session.put(req_url, json=payload) as r: + if r.status == 200: + js = await r.json() + log_return_value(f'{js}') + return js + else: + e = await r.text() + logging.error(e) + raise ValueError(f'DB: {e}') + + # retries = 0 + # while True: + # try: + # resp = requests.put(req_url, json=payload, headers=AUTH_TOKEN, timeout=timeout) + # break + # except requests.Timeout as e: + # logging.error(f'Post Timeout: {req_url} / retries: {retries} / timeout: {timeout}') + # if retries > 1: + # raise ConnectionError(f'DB: The internet was a bit too slow for me to grab the data I needed. Please ' + # f'hang on a few extra seconds and try again.') + # timeout += [min(3, timeout), min(5, timeout)][retries] + # retries += 1 + # + # if resp.status_code == 200: + # data = resp.json() + # log_string = f'{data}' + # if master_debug: + # logging.info(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}') + # else: + # logging.debug(f'return: {log_string[:1200]}{" [ S N I P P E D ]" if len(log_string) > 1200 else ""}') + # return data + # else: + # logging.warning(resp.text) + # raise ValueError(f'DB: {resp.text}') + + +async def db_delete(endpoint: str, object_id: int, api_ver: int = 2, timeout=3): + req_url = get_req_url(endpoint, api_ver=api_ver, object_id=object_id) + log_string = f'db_delete - delete: {endpoint} {object_id}' + logging.info(log_string) if master_debug else logging.debug(log_string) + + async with aiohttp.ClientSession(headers=AUTH_TOKEN) as session: + async with session.delete(req_url) as r: + if r.status == 200: + js = await r.json() + log_return_value(f'{js}') + return js + else: + e = await r.text() + logging.error(e) + raise ValueError(f'DB: {e}') + + +async def get_team_by_abbrev(abbrev: str): + all_teams = await db_get('teams', params=[('abbrev', abbrev)]) + + if not all_teams or not all_teams['count']: + return None + + return all_teams['teams'][0] + + +async def post_to_dex(player, team): + return await db_post('paperdex', payload={'team_id': team['id'], 'player_id': player['id']}) + + +def team_hash(team): + hash_string = f'{team["sname"][-1]}{team["gmid"] / 6950123:.0f}{team["sname"][-2]}{team["gmid"] / 42069123:.0f}' + return hash_string + diff --git a/cogs/admins.py b/cogs/admins.py index d6c8830..efd7d9b 100644 --- a/cogs/admins.py +++ b/cogs/admins.py @@ -1,10 +1,10 @@ import csv import json -import db_calls +import api_calls import db_calls_gameplay from helpers import * -from db_calls import * +from api_calls import * from discord import Member from discord.ext import commands, tasks from discord import app_commands diff --git a/cogs/economy.py b/cogs/economy.py index 4684da7..9d7a5ca 100644 --- a/cogs/economy.py +++ b/cogs/economy.py @@ -17,7 +17,7 @@ from discord.app_commands import Choice import datetime import pygsheets -from db_calls import db_get, db_post, db_patch, db_delete, get_team_by_abbrev +from api_calls import db_get, db_post, db_patch, db_delete, get_team_by_abbrev from help_text import * # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' diff --git a/cogs/gameplay.py b/cogs/gameplay.py index 46a0a76..ec81387 100644 --- a/cogs/gameplay.py +++ b/cogs/gameplay.py @@ -1,29 +1,18 @@ +import enum import logging from typing import Literal import discord from discord import app_commands +from discord.app_commands import Choice from discord.ext import commands -from helpers import PD_PLAYERS_ROLE_NAME +from api_calls import db_get +from helpers import PD_PLAYERS_ROLE_NAME, user_has_role from in_game.game_helpers import PUBLIC_FIELDS_CATEGORY_NAME from in_game.data_cache import get_pd_team -from in_game.gameplay_db import Session, engine, create_db_and_tables, select, Game, get_team - - -def get_games_by_channel(session: Session, channel_id: int) -> list[Game]: - return session.exec(select(Game).where(Game.channel_id == channel_id)).all() - - -def get_channel_game_or_none(session: Session, channel_id: int) -> Game | None: - all_games = get_games_by_channel(session, channel_id) - if len(all_games) > 1: - err = 'Too many games found in get_channel_game_or_none' - logging.error(f'cogs.gameplay - get_channel_game_or_none - channel_id: {channel_id} / {err}') - raise LookupError(err) - elif len(all_games) == 0: - return None - return all_games[0] +from in_game.gameplay_models import Session, engine, create_db_and_tables, select, Game, get_team +from in_game.gameplay_queries import get_channel_game_or_none, get_games_by_team_id class Gameplay(commands.Cog): @@ -38,25 +27,26 @@ class Gameplay(commands.Cog): await ctx.send(f'{error[:1600]}') group_new_game = app_commands.Group(name='new-game', description='Start a new baseball game') - + @group_new_game.command(name='mlb-campaign', description='Start a new MLB campaign game against an AI') @app_commands.describe( sp_card_id='Light gray number to the left of the pitcher\'s name on your depth chart' ) - # @app_commands.rename( - # league='campaign', - # away_team_abbrev='away team abbrev', - # home_team_abbrev='home team abbrev', - # sp_card_id='sp card id', - # num_innings='number of innings' - # ) + @app_commands.choices(league=[ + # Choice(name='Minor League', value='minor-league'), + # Choice(name='Flashback', value='flashback'), + # Choice(name='Major League', value='major-league'), + # Choice(name='Hall of Fame', value='hall-of-fame') + Choice(name='minor-league', value='Minor League'), + Choice(name='flashback', value='Flashback'), + Choice(name='major-league', value='Major League'), + Choice(name='hall-of-fame', value='Hall of Fame') + ]) @app_commands.checks.has_any_role(PD_PLAYERS_ROLE_NAME) async def new_game_mlb_campaign_command( - self, interaction: discord.Interaction, - league: Literal['Minor League', 'Flashback', 'Major League', 'Hall of Fame'], - away_team_abbrev: str, home_team_abbrev: str, sp_card_id: int, num_innings: Literal[9, 3] = 9 + self, interaction: discord.Interaction, league: Choice[str], away_team_abbrev: str, home_team_abbrev: str, sp_card_id: int ): - await interaction.response.send_message(content=f'Let\'s get your game set up. First, I\'ll check for conflicts...') + await interaction.response.defer() with Session(engine) as session: conflict = get_channel_game_or_none(session, interaction.channel_id) @@ -67,7 +57,7 @@ class Gameplay(commands.Cog): ) return - await interaction.edit_original_response(content=f'Now to check that you\'re in the right channel category...') + # await interaction.edit_original_response(content=f'Now to check that you\'re in the right channel category...') if interaction.channel.category is None or interaction.channel.category.name != PUBLIC_FIELDS_CATEGORY_NAME: await interaction.edit_original_response( content=f'Why don\'t you head down to one of the Public Fields that way other humans can help if anything ' @@ -75,12 +65,93 @@ class Gameplay(commands.Cog): ) return - await interaction.edit_original_response(content=f'Now to find this away team **{away_team_abbrev.upper()}**') - away_team = await get_team(session, team_abbrev=away_team_abbrev) - - await interaction.edit_original_response( - content=f'Hey {away_team.gmname}! {interaction.channel.name} is available so the {away_team.sname} are welcome to play!' - ) + # await interaction.edit_original_response(content=f'Now to find this away team **{away_team_abbrev.upper()}**') + try: + away_team = await get_team(session, team_abbrev=away_team_abbrev) + except LookupError as e: + await interaction.edit_original_response( + content=f'Hm. I\'m not sure who **{away_team_abbrev}** is - check on that and try again!' + ) + return + try: + home_team = await get_team(session, team_abbrev=home_team_abbrev) + except LookupError as e: + await interaction.edit_original_response( + content=f'Hm. I\'m not sure who **{home_team_abbrev}** is - check on that and try again!' + ) + return + + if not away_team.is_ai ^ home_team.is_ai: + await interaction.edit_original_response( + content=f'I don\'t see an AI team in this MLB Campaign game. Run `/new-game mlb-campaign` again with an AI for a campaign game or `/new-game ` for a PvP game.' + ) + return + + ai_team = away_team if away_team.is_ai else home_team + human_team = away_team if home_team.is_ai else away_team + + conflict_games = get_games_by_team_id(session, team_id=human_team.id) + if len(conflict_games) > 0: + await interaction.edit_original_response( + content=f'Ope. The {human_team.sname} are already playing over in {interaction.guild.get_channel(conflict_games[0].channel_id).mention}' + ) + return + + current = await db_get('current') + week_num = current['week'] + logging.info(f'gameplay - new_game_mlb_campaign - Season: {current["season"]} / Week: {week_num} / Away Team: {away_team.description} / Home Team: {home_team.description}') + + def role_error(required_role: str, league_name: str, lower_league: str): + return f'Ope. Looks like you haven\'t received the **{required_role}** role, yet!\n\nTo play **{league_name}** games, you need to defeat all 30 MLB teams in the {lower_league} campaign. You can see your progress with `/record`.\n\nIf you have completed the {lower_league} campaign, go ping Cal to get your new role!' + + if league.name == 'flashback': + if not user_has_role(interaction.user, 'PD - Major League'): + await interaction.edit_original_response( + content=role_error('PD - Major League', league_name='Flashback', lower_league='Minor League') + ) + return + elif league.name == 'major-league': + if not user_has_role(interaction.user, 'PD - Major League'): + await interaction.edit_original_response( + content=role_error('PD - Major League', league_name='Major League', lower_league='Minor League') + ) + return + elif league.name == 'hall-of-fame': + if not user_has_role(interaction.user, 'PD - Hall of Fame'): + await interaction.edit_original_response( + content=role_error('PD - Hall of Fame', league_name='Hall of Fame', lower_league='Major League') + ) + return + + this_game = Game( + away_team_id=away_team.id, + home_team_id=home_team.id, + channel_id=interaction.channel_id, + season=current['season'], + week_num=week_num, + first_message=None if interaction.message is None else interaction.message.channel.id, + ai_team='away' if away_team.is_ai else 'home', + game_type=league.name + ) + # session.add(this_game) + # session.commit() + # session.refresh(this_game) + + game_info_log = f'Game {this_game.id} ({league.value}) between {away_team.description} and {home_team.description} / first message: {this_game.first_message}' + logging.info(game_info_log) + await interaction.channel.send(content=game_info_log) + + # Get Human SP card + + # Get AI SP + + # Get AI Lineup + + session.delete(this_game) + session.commit() + + await interaction.channel.send(content='I also deleted that game for ~~science~~ testing.') + diff --git a/cogs/gameplay_legacy.py b/cogs/gameplay_legacy.py index 462e275..4d95461 100644 --- a/cogs/gameplay_legacy.py +++ b/cogs/gameplay_legacy.py @@ -25,7 +25,7 @@ from in_game.game_helpers import single_onestar, single_wellhit, double_twostar, runner_on_first, runner_on_second, runner_on_third, gb_result_1, gb_result_2, gb_result_3, gb_result_4, \ gb_result_5, gb_result_6, gb_result_7, gb_result_8, gb_result_9, gb_result_10, gb_result_11, gb_result_12, \ gb_result_13, gb_decide, show_outfield_cards, legal_check, get_pitcher -from db_calls import db_get, db_patch, db_post, db_delete, get_team_by_abbrev +from api_calls import db_get, db_patch, db_post, db_delete, get_team_by_abbrev from db_calls_gameplay import StratGame, StratPlay, post_game, patch_game, get_game_team, post_lineups, make_sub, get_player, player_link, get_team_lineups, \ get_current_play, post_play, get_one_lineup, advance_runners, patch_play, complete_play, get_batting_stats, \ get_pitching_stats, undo_play, get_latest_play, advance_one_runner, count_team_games, \ diff --git a/cogs/players.py b/cogs/players.py index ef04961..d6be12d 100644 --- a/cogs/players.py +++ b/cogs/players.py @@ -24,7 +24,7 @@ import helpers # # from in_game import data_cache, simulations from in_game.data_cache import get_pd_pitchingcard, get_pd_battingcard, get_pd_player from in_game.simulations import get_pos_embeds, get_result -from db_calls import db_get, db_post, db_patch, get_team_by_abbrev +from api_calls import db_get, db_post, db_patch, get_team_by_abbrev from helpers import PD_PLAYERS_ROLE_NAME, IMAGES, PD_SEASON, random_conf_gif, fuzzy_player_search, ALL_MLB_TEAMS, \ fuzzy_search, get_channel, display_cards, get_card_embeds, get_team_embed, cardset_search, get_blank_team_card, \ get_team_by_owner, get_rosters, get_roster_sheet, legal_channel, random_conf_word, embed_pagination, get_cal_user, \ diff --git a/db_calls_gameplay.py b/db_calls_gameplay.py index 23777a9..07f8a19 100644 --- a/db_calls_gameplay.py +++ b/db_calls_gameplay.py @@ -12,7 +12,7 @@ from playhouse.shortcuts import model_to_dict from dataclasses import dataclass from helpers import SBA_SEASON, PD_SEASON, get_player_url, get_sheets -from db_calls import db_get +from api_calls import db_get from in_game.data_cache import get_pd_player, CardPosition, BattingCard, get_pd_team db = SqliteDatabase( diff --git a/docker-compose.yml b/docker-compose.yml index b170b61..3c9c2c5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,3 @@ -version: '3' - services: discord-app: image: paper-dynasty-discordapp:sqlmodel-rebuild diff --git a/gauntlets.py b/gauntlets.py index 1efb710..a955b3e 100644 --- a/gauntlets.py +++ b/gauntlets.py @@ -8,7 +8,7 @@ import discord from in_game import ai_manager import helpers from helpers import RARITY, get_or_create_role, send_to_channel, get_channel -from db_calls import db_get, db_post, db_delete, db_patch +from api_calls import db_get, db_post, db_delete, db_patch async def wipe_team(this_team, interaction: discord.Interaction, delete_team: bool = False, delete_runs: bool = False): diff --git a/helpers.py b/helpers.py index 30049a0..adca95a 100644 --- a/helpers.py +++ b/helpers.py @@ -10,7 +10,7 @@ import discord import pygsheets import requests from discord.ext import commands -from db_calls import * +from api_calls import * from bs4 import BeautifulSoup from difflib import get_close_matches @@ -3270,3 +3270,11 @@ def random_from_list(data_list: list): item = data_list[random.randint(0, len(data_list) - 1)] logging.info(f'random_from_list: {item}') return item + + +def user_has_role(user: discord.User | discord.Member, role_name: str) -> bool: + for x in user.roles: + if x.name == role_name: + return True + + return False diff --git a/in_game/ai_manager.py b/in_game/ai_manager.py index 928b8be..f0bab4e 100644 --- a/in_game/ai_manager.py +++ b/in_game/ai_manager.py @@ -7,7 +7,7 @@ import random from db_calls_gameplay import StratPlay, StratGame, get_one_lineup, get_manager, get_team_lineups, \ get_last_inning_end_play, make_sub, get_player, StratLineup, get_pitching_stats, patch_play, patch_lineup, \ get_one_game -from db_calls import db_get, db_post +from api_calls import db_get, db_post from peewee import * from typing import Optional, Literal diff --git a/in_game/data_cache.py b/in_game/data_cache.py index b4a831f..88e04dc 100644 --- a/in_game/data_cache.py +++ b/in_game/data_cache.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import logging from typing import Optional, Literal -from db_calls import db_get +from api_calls import db_get from helpers import PD_SEASON PLAYER_CACHE = {} diff --git a/in_game/game_helpers.py b/in_game/game_helpers.py index c53c9b8..fa5db1f 100644 --- a/in_game/game_helpers.py +++ b/in_game/game_helpers.py @@ -7,7 +7,7 @@ import discord from db_calls_gameplay import StratGame, StratPlay, StratLineup, StratManagerAi, patch_play, advance_runners, \ complete_play, get_team_lineups, get_or_create_bullpen, get_player, get_sheets, make_sub, get_one_lineup, \ advance_one_runner, get_one_lineup, ai_batting, get_manager -from db_calls import db_get, db_post +from api_calls import db_get, db_post from helpers import Pagination, get_team_embed, image_embed, Confirm from typing import Literal, Optional diff --git a/in_game/gameplay_db.py b/in_game/gameplay_models.py similarity index 81% rename from in_game/gameplay_db.py rename to in_game/gameplay_models.py index 0c284de..45c6814 100644 --- a/in_game/gameplay_db.py +++ b/in_game/gameplay_models.py @@ -1,9 +1,9 @@ import datetime import logging -from sqlmodel import Session, SQLModel, create_engine, select, Field, Relationship +from sqlmodel import Session, SQLModel, create_engine, select, or_, Field, Relationship from sqlalchemy import func -from db_calls import db_get +from api_calls import db_get from helpers import PD_SEASON @@ -47,11 +47,13 @@ class Game(SQLModel, table=True): # return f'Game {self.id} / Week {self.week_num} / Type {self.game_type}' -class Cardset(SQLModel, table=True): +class CardsetBase(SQLModel): id: int | None = Field(default=None, primary_key=True) name: str ranked_legal: bool | None = Field(default=False) + +class Cardset(CardsetBase, table=True): game_links: list[GameCardsetLink] = Relationship(back_populates='cardset') @@ -91,6 +93,10 @@ class TeamBase(SQLModel): is_ai: bool created: datetime.datetime | None = Field(default=datetime.datetime.now()) + @property + def description(self) -> str: + return f'{self.id}. {self.abbrev} {self.lname}, {"is_ai" if self.is_ai else "human"}' + class Team(TeamBase, table=True): pass @@ -114,9 +120,9 @@ async def get_team( this_team = session.exec(statement).one_or_none() if this_team is not None: - logging.info(f'we found a team: {this_team} / created: {this_team.created}') + logging.debug(f'we found a team: {this_team} / created: {this_team.created}') tdelta = datetime.datetime.now() - this_team.created - logging.info(f'tdelta: {tdelta}') + logging.debug(f'tdelta: {tdelta}') if tdelta.total_seconds() < 1209600: return this_team else: @@ -124,11 +130,11 @@ async def get_team( session.commit() def cache_team(json_data: dict) -> Team: - logging.info(f'gameplay_db - get_team - cache_team - writing a team to cache: {json_data}') + # logging.info(f'gameplay_db - get_team - cache_team - writing a team to cache: {json_data}') valid_team = TeamBase.model_validate(json_data, from_attributes=True) - logging.info(f'gameplay_db - get_team - cache_team - valid_team: {valid_team}') + # logging.info(f'gameplay_db - get_team - cache_team - valid_team: {valid_team}') db_team = Team.model_validate(valid_team) - logging.info(f'gameplay_db - get_team - cache_team - db_team: {db_team}') + # logging.info(f'gameplay_db - get_team - cache_team - db_team: {db_team}') session.add(db_team) session.commit() session.refresh(db_team) @@ -154,7 +160,41 @@ async def get_team( err = 'Team not found' logging.error(f'gameplay_db - get_team - {err}') raise LookupError(err) - + + +class PlayerBase(SQLModel): + id: int | None = Field(primary_key=True) + name: str + cost: int + image: str + mlbclub: str + franchise: str + cardset: dict + set_num: int + rarity: dict + pos_1: str + description: str + quantity: int | None = Field(default=999) + image2: str | None = Field(default=None) + pos_2: str | None = Field(default=None) + pos_3: str | None = Field(default=None) + pos_4: str | None = Field(default=None) + pos_5: str | None = Field(default=None) + pos_6: str | None = Field(default=None) + pos_7: str | None = Field(default=None) + pos_8: str | None = Field(default=None) + headshot: str | None = Field(default=None) + vanity_card: str | None = Field(default=None) + strat_code: str | None = Field(default=None) + bbref_id: str | None = Field(default=None) + fangr_id: str | None = Field(default=None) + mlbplayer_id: int | None = Field(default=None) + created: datetime.datetime | None = Field(default=datetime.datetime.now()) + + +class Player(PlayerBase, table=True): + pass + """ diff --git a/in_game/gameplay_queries.py b/in_game/gameplay_queries.py new file mode 100644 index 0000000..52732b3 --- /dev/null +++ b/in_game/gameplay_queries.py @@ -0,0 +1,21 @@ +import logging +from in_game.gameplay_models import Session, select, or_, Game + + +def get_games_by_channel(session: Session, channel_id: int) -> list[Game]: + return session.exec(select(Game).where(Game.channel_id == channel_id, Game.active)).all() + + +def get_channel_game_or_none(session: Session, channel_id: int) -> Game | None: + all_games = get_games_by_channel(session, channel_id) + if len(all_games) > 1: + err = 'Too many games found in get_channel_game_or_none' + logging.error(f'cogs.gameplay - get_channel_game_or_none - channel_id: {channel_id} / {err}') + raise LookupError(err) + elif len(all_games) == 0: + return None + return all_games[0] + + +def get_games_by_team_id(session: Session, team_id: int) -> list[Game]: + return session.exec(select(Game).where(Game.active, or_(Game.away_team_id == team_id, Game.home_team_id == team_id))).all() \ No newline at end of file diff --git a/migrations/env.py b/migrations/env.py index b33e098..3f64547 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -5,7 +5,7 @@ from sqlalchemy import pool from alembic import context -from in_game.gameplay_db import sqlite_url, Game +from in_game.gameplay_models import sqlite_url, Game from sqlmodel import SQLModel # this is the Alembic Config object, which provides diff --git a/tests/test_gameplay_db_game.py b/tests/test_gameplay_db_game.py index 3c3e0b8..e99bcb0 100644 --- a/tests/test_gameplay_db_game.py +++ b/tests/test_gameplay_db_game.py @@ -1,6 +1,8 @@ +import pytest from sqlmodel import Session -from in_game.gameplay_db import Game, select +from in_game.gameplay_models import Game, select +from in_game.gameplay_queries import get_games_by_channel, get_channel_game_or_none, get_games_by_team_id from factory import session_fixture, new_games_fixture @@ -43,3 +45,56 @@ def test_select_all_empty(session: Session): assert len(games) == 0 +def test_games_by_channel(session: Session, new_games: list[Game]): + game_1 = new_games[0] + game_2 = new_games[1] + game_3 = new_games[2] + session.add(game_1) + session.add(game_2) + session.add(game_3) + session.commit() + + assert get_channel_game_or_none(session, 1234) is not None + assert get_channel_game_or_none(session, 5678) is not None + assert get_channel_game_or_none(session, 69) is None + + game_2.active = True + session.add(game_2) + session.commit() + + with pytest.raises(LookupError) as exc_info: + get_channel_game_or_none(session, 5678) + + assert str(exc_info) == "" + + game_2.active = False + game_3.active = False + session.add(game_2) + session.add(game_3) + session.commit() + + assert get_channel_game_or_none(session, 5678) is None + + +def test_games_by_team(session: Session, new_games: list[Game]): + game_1 = new_games[1] + game_2 = new_games[2] + session.add(game_1) + session.add(game_2) + session.commit() + + assert len(get_games_by_team_id(session, team_id=3)) == 1 + + game_1.active = True + session.add(game_1) + session.commit() + + assert len(get_games_by_team_id(session, team_id=3)) == 2 + + game_1.active = False + game_2.active = False + session.add(game_1) + session.add(game_2) + session.commit() + + assert len(get_games_by_team_id(session, team_id=3)) == 0 diff --git a/tests/test_gameplay_db_lineup.py b/tests/test_gameplay_db_lineup.py index e3ea134..01c7043 100644 --- a/tests/test_gameplay_db_lineup.py +++ b/tests/test_gameplay_db_lineup.py @@ -1,6 +1,6 @@ from sqlmodel import Session, select -from in_game.gameplay_db import Game, Lineup +from in_game.gameplay_models import Game, Lineup from factory import session_fixture, new_games_with_lineups_fixture, new_games_fixture