Create ManagerAi model

This commit is contained in:
Cal Corum 2024-10-16 14:09:11 -05:00
parent 1b87bfdb92
commit 10c68c02b1
4 changed files with 157 additions and 1 deletions

View File

@ -1,11 +1,15 @@
import datetime
import logging
from typing import Literal
import discord
import pydantic
from sqlmodel import Session, SQLModel, create_engine, select, or_, Field, Relationship
from sqlalchemy import func
from api_calls import db_get, db_post
from in_game.managerai_responses import JumpResponse
sqlite_url = 'sqlite:///storage/gameplay.db'
@ -14,6 +18,119 @@ engine = create_engine(sqlite_url, echo=False, connect_args=connect_args)
CACHE_LIMIT = 1209600 # in seconds
class ManagerAiBase(SQLModel):
id: int | None = Field(primary_key=True)
name: str = Field(index=True)
steal: int | None = Field(default=5)
running: int | None = Field(default=5)
hold: int | None = Field(default=5)
catcher_throw: int | None = Field(default=5)
uncapped_home: int | None = Field(default=5)
uncapped_third: int | None = Field(default=5)
uncapped_trail: int | None = Field(default=5)
bullpen_matchup: int | None = Field(default=5)
behind_aggression: int | None = Field(default=5)
ahead_aggression: int | None = Field(default=5)
decide_throw: int | None = Field(default=5)
class ManagerAi(ManagerAiBase, table=True):
def create_ai(session: Session = None):
def get_new_ai(this_session: Session):
all_ai = session.exec(select(ManagerAi.id)).all()
if len(all_ai) == 0:
logging.info(f'Creating ManagerAI records')
new_ai = [
ManagerAi(
name='Balanced'
),
ManagerAi(
name='Yolo',
steal=10,
running=10,
hold=5,
catcher_throw=10,
uncapped_home=10,
uncapped_third=10,
uncapped_trail=10,
bullpen_matchup=3,
behind_aggression=10,
ahead_aggression=10,
decide_throw=10
),
ManagerAi(
name='Safe',
steal=3,
running=3,
hold=8,
catcher_throw=5,
uncapped_home=5,
uncapped_third=3,
uncapped_trail=5,
bullpen_matchup=8,
behind_aggression=5,
ahead_aggression=1,
decide_throw=1
)
]
for x in new_ai:
session.add(x)
session.commit()
if session is None:
with Session(engine) as session:
get_new_ai(session)
else:
get_new_ai(session)
return True
def check_jump(self, to_base: Literal[2, 3, 4], num_outs: Literal[0, 1, 2], run_diff: int) -> JumpResponse | None:
this_resp = JumpResponse()
if to_base == 2:
match self.steal:
case 10:
this_resp.min_safe = 12 + num_outs
case self.steal if self.steal > 8 and run_diff <= 5:
this_resp.min_safe = 13 + num_outs
case self.steal if self.steal > 6 and run_diff <= 5:
this_resp.min_safe = 14 + num_outs
case self.steal if self.steal > 4 and num_outs < 2 and run_diff <= 5:
this_resp.min_safe = 15 + num_outs
case self.steal if self.steal > 2 and num_outs < 2 and run_diff <= 5:
this_resp.min_safe = 16 + num_outs
case _:
this_resp = 17 + num_outs
if self.steal > 7 and num_outs < 2 and run_diff <= 5:
this_resp.run_if_auto_jump = True
elif self.steal < 5:
this_resp.must_auto_jump = True
elif to_base == 3:
match self.steal:
case 10:
this_resp.min_safe = 12 + num_outs
case self.steal if self.steal > 6 and num_outs < 2 and run_diff <= 5:
this_resp.min_safe = 15 + num_outs
case _:
this_resp.min_safe = None
if self.steal == 10 and num_outs < 2 and run_diff <= 5:
this_resp.run_if_auto_jump = True
elif self.steal <= 5:
this_resp.must_auto_jump = True
elif run_diff == -1:
match self.steal:
case self.steal if self.steal == 10:
this_resp.min_safe = 5
case self.steal if self.steal > 5:
this_resp.min_safe = 7
return this_resp
class GameCardsetLink(SQLModel, table=True):
game_id: int | None = Field(default=None, foreign_key='game.id', primary_key=True)
cardset_id: int | None = Field(default=None, foreign_key='cardset.id', primary_key=True)
@ -494,6 +611,7 @@ BEGIN DEVELOPMENT HELPERS
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
ManagerAi.create_ai()
def create_test_games():

View File

@ -0,0 +1,6 @@
import pydantic
class JumpResponse(pydantic.BaseModel):
min_safe: int | None = None
must_auto_jump: bool = False
run_if_auto_jump: bool = False

View File

@ -4,7 +4,7 @@ from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
from typing import Literal
from in_game.gameplay_models import Card, Cardset, Game, GameCardsetLink, Lineup, Play, Team, Player
from in_game.gameplay_models import Card, Cardset, Game, GameCardsetLink, Lineup, ManagerAi, Play, Team, Player
@pytest.fixture(name='session')
@ -175,4 +175,6 @@ def session_fixture():
session.commit()
all_ai = ManagerAi.create_ai(session)
yield session

View File

@ -0,0 +1,30 @@
from sqlmodel import Session, select
from in_game.gameplay_models import ManagerAi
from factory import session_fixture
from in_game.managerai_responses import JumpResponse
def test_create_ai(session: Session):
all_ai = session.exec(select(ManagerAi)).all()
assert len(all_ai) == 3
assert ManagerAi.create_ai(session) == True
all_ai = session.exec(select(ManagerAi)).all()
assert len(all_ai) == 3
def test_check_jump(session: Session):
balanced_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Balanced')).one()
aggressive_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Yolo')).one()
bal_second_22 = balanced_ai.check_jump(to_base=2, num_outs=0)
agg_second_20 = aggressive_ai.check_jump(to_base=2, num_outs=0)
agg_second_22 = aggressive_ai.check_jump(to_base=2, num_outs=2)
assert bal_second_22 == JumpResponse(min_safe=15)
assert balanced_ai.check_jump(to_base=4, num_outs=2) is None
assert agg_second_20 == JumpResponse(min_safe=12, run_if_auto_jump=True)
assert agg_second_22.run_if_auto_jump == False