Create ManagerAi model
This commit is contained in:
parent
1b87bfdb92
commit
10c68c02b1
@ -1,11 +1,15 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import pydantic
|
||||||
|
|
||||||
from sqlmodel import Session, SQLModel, create_engine, select, or_, Field, Relationship
|
from sqlmodel import Session, SQLModel, create_engine, select, or_, Field, Relationship
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
from api_calls import db_get, db_post
|
from api_calls import db_get, db_post
|
||||||
|
from in_game.managerai_responses import JumpResponse
|
||||||
|
|
||||||
|
|
||||||
sqlite_url = 'sqlite:///storage/gameplay.db'
|
sqlite_url = 'sqlite:///storage/gameplay.db'
|
||||||
@ -14,6 +18,119 @@ engine = create_engine(sqlite_url, echo=False, connect_args=connect_args)
|
|||||||
CACHE_LIMIT = 1209600 # in seconds
|
CACHE_LIMIT = 1209600 # in seconds
|
||||||
|
|
||||||
|
|
||||||
|
class ManagerAiBase(SQLModel):
|
||||||
|
id: int | None = Field(primary_key=True)
|
||||||
|
name: str = Field(index=True)
|
||||||
|
steal: int | None = Field(default=5)
|
||||||
|
running: int | None = Field(default=5)
|
||||||
|
hold: int | None = Field(default=5)
|
||||||
|
catcher_throw: int | None = Field(default=5)
|
||||||
|
uncapped_home: int | None = Field(default=5)
|
||||||
|
uncapped_third: int | None = Field(default=5)
|
||||||
|
uncapped_trail: int | None = Field(default=5)
|
||||||
|
bullpen_matchup: int | None = Field(default=5)
|
||||||
|
behind_aggression: int | None = Field(default=5)
|
||||||
|
ahead_aggression: int | None = Field(default=5)
|
||||||
|
decide_throw: int | None = Field(default=5)
|
||||||
|
|
||||||
|
|
||||||
|
class ManagerAi(ManagerAiBase, table=True):
|
||||||
|
def create_ai(session: Session = None):
|
||||||
|
def get_new_ai(this_session: Session):
|
||||||
|
all_ai = session.exec(select(ManagerAi.id)).all()
|
||||||
|
if len(all_ai) == 0:
|
||||||
|
logging.info(f'Creating ManagerAI records')
|
||||||
|
new_ai = [
|
||||||
|
ManagerAi(
|
||||||
|
name='Balanced'
|
||||||
|
),
|
||||||
|
ManagerAi(
|
||||||
|
name='Yolo',
|
||||||
|
steal=10,
|
||||||
|
running=10,
|
||||||
|
hold=5,
|
||||||
|
catcher_throw=10,
|
||||||
|
uncapped_home=10,
|
||||||
|
uncapped_third=10,
|
||||||
|
uncapped_trail=10,
|
||||||
|
bullpen_matchup=3,
|
||||||
|
behind_aggression=10,
|
||||||
|
ahead_aggression=10,
|
||||||
|
decide_throw=10
|
||||||
|
),
|
||||||
|
ManagerAi(
|
||||||
|
name='Safe',
|
||||||
|
steal=3,
|
||||||
|
running=3,
|
||||||
|
hold=8,
|
||||||
|
catcher_throw=5,
|
||||||
|
uncapped_home=5,
|
||||||
|
uncapped_third=3,
|
||||||
|
uncapped_trail=5,
|
||||||
|
bullpen_matchup=8,
|
||||||
|
behind_aggression=5,
|
||||||
|
ahead_aggression=1,
|
||||||
|
decide_throw=1
|
||||||
|
)
|
||||||
|
]
|
||||||
|
for x in new_ai:
|
||||||
|
session.add(x)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
with Session(engine) as session:
|
||||||
|
get_new_ai(session)
|
||||||
|
else:
|
||||||
|
get_new_ai(session)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_jump(self, to_base: Literal[2, 3, 4], num_outs: Literal[0, 1, 2], run_diff: int) -> JumpResponse | None:
|
||||||
|
this_resp = JumpResponse()
|
||||||
|
if to_base == 2:
|
||||||
|
match self.steal:
|
||||||
|
case 10:
|
||||||
|
this_resp.min_safe = 12 + num_outs
|
||||||
|
case self.steal if self.steal > 8 and run_diff <= 5:
|
||||||
|
this_resp.min_safe = 13 + num_outs
|
||||||
|
case self.steal if self.steal > 6 and run_diff <= 5:
|
||||||
|
this_resp.min_safe = 14 + num_outs
|
||||||
|
case self.steal if self.steal > 4 and num_outs < 2 and run_diff <= 5:
|
||||||
|
this_resp.min_safe = 15 + num_outs
|
||||||
|
case self.steal if self.steal > 2 and num_outs < 2 and run_diff <= 5:
|
||||||
|
this_resp.min_safe = 16 + num_outs
|
||||||
|
case _:
|
||||||
|
this_resp = 17 + num_outs
|
||||||
|
|
||||||
|
if self.steal > 7 and num_outs < 2 and run_diff <= 5:
|
||||||
|
this_resp.run_if_auto_jump = True
|
||||||
|
elif self.steal < 5:
|
||||||
|
this_resp.must_auto_jump = True
|
||||||
|
|
||||||
|
elif to_base == 3:
|
||||||
|
match self.steal:
|
||||||
|
case 10:
|
||||||
|
this_resp.min_safe = 12 + num_outs
|
||||||
|
case self.steal if self.steal > 6 and num_outs < 2 and run_diff <= 5:
|
||||||
|
this_resp.min_safe = 15 + num_outs
|
||||||
|
case _:
|
||||||
|
this_resp.min_safe = None
|
||||||
|
|
||||||
|
if self.steal == 10 and num_outs < 2 and run_diff <= 5:
|
||||||
|
this_resp.run_if_auto_jump = True
|
||||||
|
elif self.steal <= 5:
|
||||||
|
this_resp.must_auto_jump = True
|
||||||
|
|
||||||
|
elif run_diff == -1:
|
||||||
|
match self.steal:
|
||||||
|
case self.steal if self.steal == 10:
|
||||||
|
this_resp.min_safe = 5
|
||||||
|
case self.steal if self.steal > 5:
|
||||||
|
this_resp.min_safe = 7
|
||||||
|
|
||||||
|
return this_resp
|
||||||
|
|
||||||
|
|
||||||
class GameCardsetLink(SQLModel, table=True):
|
class GameCardsetLink(SQLModel, table=True):
|
||||||
game_id: int | None = Field(default=None, foreign_key='game.id', primary_key=True)
|
game_id: int | None = Field(default=None, foreign_key='game.id', primary_key=True)
|
||||||
cardset_id: int | None = Field(default=None, foreign_key='cardset.id', primary_key=True)
|
cardset_id: int | None = Field(default=None, foreign_key='cardset.id', primary_key=True)
|
||||||
@ -494,6 +611,7 @@ BEGIN DEVELOPMENT HELPERS
|
|||||||
|
|
||||||
def create_db_and_tables():
|
def create_db_and_tables():
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
ManagerAi.create_ai()
|
||||||
|
|
||||||
|
|
||||||
def create_test_games():
|
def create_test_games():
|
||||||
|
|||||||
6
in_game/managerai_responses.py
Normal file
6
in_game/managerai_responses.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
import pydantic
|
||||||
|
|
||||||
|
class JumpResponse(pydantic.BaseModel):
|
||||||
|
min_safe: int | None = None
|
||||||
|
must_auto_jump: bool = False
|
||||||
|
run_if_auto_jump: bool = False
|
||||||
@ -4,7 +4,7 @@ from sqlmodel import Session, SQLModel, create_engine
|
|||||||
from sqlmodel.pool import StaticPool
|
from sqlmodel.pool import StaticPool
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from in_game.gameplay_models import Card, Cardset, Game, GameCardsetLink, Lineup, Play, Team, Player
|
from in_game.gameplay_models import Card, Cardset, Game, GameCardsetLink, Lineup, ManagerAi, Play, Team, Player
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name='session')
|
@pytest.fixture(name='session')
|
||||||
@ -175,4 +175,6 @@ def session_fixture():
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
all_ai = ManagerAi.create_ai(session)
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
|
|||||||
30
tests/gameplay_models/test_managerai_model.py
Normal file
30
tests/gameplay_models/test_managerai_model.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from in_game.gameplay_models import ManagerAi
|
||||||
|
from factory import session_fixture
|
||||||
|
from in_game.managerai_responses import JumpResponse
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_ai(session: Session):
|
||||||
|
all_ai = session.exec(select(ManagerAi)).all()
|
||||||
|
|
||||||
|
assert len(all_ai) == 3
|
||||||
|
assert ManagerAi.create_ai(session) == True
|
||||||
|
|
||||||
|
all_ai = session.exec(select(ManagerAi)).all()
|
||||||
|
|
||||||
|
assert len(all_ai) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_jump(session: Session):
|
||||||
|
balanced_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Balanced')).one()
|
||||||
|
aggressive_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Yolo')).one()
|
||||||
|
|
||||||
|
bal_second_22 = balanced_ai.check_jump(to_base=2, num_outs=0)
|
||||||
|
agg_second_20 = aggressive_ai.check_jump(to_base=2, num_outs=0)
|
||||||
|
agg_second_22 = aggressive_ai.check_jump(to_base=2, num_outs=2)
|
||||||
|
|
||||||
|
assert bal_second_22 == JumpResponse(min_safe=15)
|
||||||
|
assert balanced_ai.check_jump(to_base=4, num_outs=2) is None
|
||||||
|
assert agg_second_20 == JumpResponse(min_safe=12, run_if_auto_jump=True)
|
||||||
|
assert agg_second_22.run_if_auto_jump == False
|
||||||
Loading…
Reference in New Issue
Block a user