Create ManagerAi model
This commit is contained in:
parent
1b87bfdb92
commit
10c68c02b1
@ -1,11 +1,15 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import discord
|
||||
import pydantic
|
||||
|
||||
from sqlmodel import Session, SQLModel, create_engine, select, or_, Field, Relationship
|
||||
from sqlalchemy import func
|
||||
|
||||
from api_calls import db_get, db_post
|
||||
from in_game.managerai_responses import JumpResponse
|
||||
|
||||
|
||||
sqlite_url = 'sqlite:///storage/gameplay.db'
|
||||
@ -14,6 +18,119 @@ engine = create_engine(sqlite_url, echo=False, connect_args=connect_args)
|
||||
CACHE_LIMIT = 1209600 # in seconds
|
||||
|
||||
|
||||
class ManagerAiBase(SQLModel):
|
||||
id: int | None = Field(primary_key=True)
|
||||
name: str = Field(index=True)
|
||||
steal: int | None = Field(default=5)
|
||||
running: int | None = Field(default=5)
|
||||
hold: int | None = Field(default=5)
|
||||
catcher_throw: int | None = Field(default=5)
|
||||
uncapped_home: int | None = Field(default=5)
|
||||
uncapped_third: int | None = Field(default=5)
|
||||
uncapped_trail: int | None = Field(default=5)
|
||||
bullpen_matchup: int | None = Field(default=5)
|
||||
behind_aggression: int | None = Field(default=5)
|
||||
ahead_aggression: int | None = Field(default=5)
|
||||
decide_throw: int | None = Field(default=5)
|
||||
|
||||
|
||||
class ManagerAi(ManagerAiBase, table=True):
|
||||
def create_ai(session: Session = None):
|
||||
def get_new_ai(this_session: Session):
|
||||
all_ai = session.exec(select(ManagerAi.id)).all()
|
||||
if len(all_ai) == 0:
|
||||
logging.info(f'Creating ManagerAI records')
|
||||
new_ai = [
|
||||
ManagerAi(
|
||||
name='Balanced'
|
||||
),
|
||||
ManagerAi(
|
||||
name='Yolo',
|
||||
steal=10,
|
||||
running=10,
|
||||
hold=5,
|
||||
catcher_throw=10,
|
||||
uncapped_home=10,
|
||||
uncapped_third=10,
|
||||
uncapped_trail=10,
|
||||
bullpen_matchup=3,
|
||||
behind_aggression=10,
|
||||
ahead_aggression=10,
|
||||
decide_throw=10
|
||||
),
|
||||
ManagerAi(
|
||||
name='Safe',
|
||||
steal=3,
|
||||
running=3,
|
||||
hold=8,
|
||||
catcher_throw=5,
|
||||
uncapped_home=5,
|
||||
uncapped_third=3,
|
||||
uncapped_trail=5,
|
||||
bullpen_matchup=8,
|
||||
behind_aggression=5,
|
||||
ahead_aggression=1,
|
||||
decide_throw=1
|
||||
)
|
||||
]
|
||||
for x in new_ai:
|
||||
session.add(x)
|
||||
session.commit()
|
||||
|
||||
if session is None:
|
||||
with Session(engine) as session:
|
||||
get_new_ai(session)
|
||||
else:
|
||||
get_new_ai(session)
|
||||
|
||||
return True
|
||||
|
||||
def check_jump(self, to_base: Literal[2, 3, 4], num_outs: Literal[0, 1, 2], run_diff: int) -> JumpResponse | None:
|
||||
this_resp = JumpResponse()
|
||||
if to_base == 2:
|
||||
match self.steal:
|
||||
case 10:
|
||||
this_resp.min_safe = 12 + num_outs
|
||||
case self.steal if self.steal > 8 and run_diff <= 5:
|
||||
this_resp.min_safe = 13 + num_outs
|
||||
case self.steal if self.steal > 6 and run_diff <= 5:
|
||||
this_resp.min_safe = 14 + num_outs
|
||||
case self.steal if self.steal > 4 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.min_safe = 15 + num_outs
|
||||
case self.steal if self.steal > 2 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.min_safe = 16 + num_outs
|
||||
case _:
|
||||
this_resp = 17 + num_outs
|
||||
|
||||
if self.steal > 7 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.run_if_auto_jump = True
|
||||
elif self.steal < 5:
|
||||
this_resp.must_auto_jump = True
|
||||
|
||||
elif to_base == 3:
|
||||
match self.steal:
|
||||
case 10:
|
||||
this_resp.min_safe = 12 + num_outs
|
||||
case self.steal if self.steal > 6 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.min_safe = 15 + num_outs
|
||||
case _:
|
||||
this_resp.min_safe = None
|
||||
|
||||
if self.steal == 10 and num_outs < 2 and run_diff <= 5:
|
||||
this_resp.run_if_auto_jump = True
|
||||
elif self.steal <= 5:
|
||||
this_resp.must_auto_jump = True
|
||||
|
||||
elif run_diff == -1:
|
||||
match self.steal:
|
||||
case self.steal if self.steal == 10:
|
||||
this_resp.min_safe = 5
|
||||
case self.steal if self.steal > 5:
|
||||
this_resp.min_safe = 7
|
||||
|
||||
return this_resp
|
||||
|
||||
|
||||
class GameCardsetLink(SQLModel, table=True):
|
||||
game_id: int | None = Field(default=None, foreign_key='game.id', primary_key=True)
|
||||
cardset_id: int | None = Field(default=None, foreign_key='cardset.id', primary_key=True)
|
||||
@ -494,6 +611,7 @@ BEGIN DEVELOPMENT HELPERS
|
||||
|
||||
def create_db_and_tables():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
ManagerAi.create_ai()
|
||||
|
||||
|
||||
def create_test_games():
|
||||
|
||||
6
in_game/managerai_responses.py
Normal file
6
in_game/managerai_responses.py
Normal file
@ -0,0 +1,6 @@
|
||||
import pydantic
|
||||
|
||||
class JumpResponse(pydantic.BaseModel):
|
||||
min_safe: int | None = None
|
||||
must_auto_jump: bool = False
|
||||
run_if_auto_jump: bool = False
|
||||
@ -4,7 +4,7 @@ from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from typing import Literal
|
||||
from in_game.gameplay_models import Card, Cardset, Game, GameCardsetLink, Lineup, Play, Team, Player
|
||||
from in_game.gameplay_models import Card, Cardset, Game, GameCardsetLink, Lineup, ManagerAi, Play, Team, Player
|
||||
|
||||
|
||||
@pytest.fixture(name='session')
|
||||
@ -175,4 +175,6 @@ def session_fixture():
|
||||
|
||||
session.commit()
|
||||
|
||||
all_ai = ManagerAi.create_ai(session)
|
||||
|
||||
yield session
|
||||
|
||||
30
tests/gameplay_models/test_managerai_model.py
Normal file
30
tests/gameplay_models/test_managerai_model.py
Normal file
@ -0,0 +1,30 @@
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from in_game.gameplay_models import ManagerAi
|
||||
from factory import session_fixture
|
||||
from in_game.managerai_responses import JumpResponse
|
||||
|
||||
|
||||
def test_create_ai(session: Session):
|
||||
all_ai = session.exec(select(ManagerAi)).all()
|
||||
|
||||
assert len(all_ai) == 3
|
||||
assert ManagerAi.create_ai(session) == True
|
||||
|
||||
all_ai = session.exec(select(ManagerAi)).all()
|
||||
|
||||
assert len(all_ai) == 3
|
||||
|
||||
|
||||
def test_check_jump(session: Session):
|
||||
balanced_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Balanced')).one()
|
||||
aggressive_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Yolo')).one()
|
||||
|
||||
bal_second_22 = balanced_ai.check_jump(to_base=2, num_outs=0)
|
||||
agg_second_20 = aggressive_ai.check_jump(to_base=2, num_outs=0)
|
||||
agg_second_22 = aggressive_ai.check_jump(to_base=2, num_outs=2)
|
||||
|
||||
assert bal_second_22 == JumpResponse(min_safe=15)
|
||||
assert balanced_ai.check_jump(to_base=4, num_outs=2) is None
|
||||
assert agg_second_20 == JumpResponse(min_safe=12, run_if_auto_jump=True)
|
||||
assert agg_second_22.run_if_auto_jump == False
|
||||
Loading…
Reference in New Issue
Block a user